├── PathRAG
│   ├── PathRAG.py
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── PathRAG.cpython-39.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── base.cpython-39.pyc
│   │   ├── llm.cpython-39.pyc
│   │   ├── operate.cpython-39.pyc
│   │   ├── prompt.cpython-39.pyc
│   │   ├── storage.cpython-39.pyc
│   │   └── utils.cpython-39.pyc
│   ├── base.py
│   ├── llm.py
│   ├── operate.py
│   ├── prompt.py
│   ├── storage.py
│   └── utils.py
├── README.md
├── requirements.txt
└── v1_test.py
/PathRAG/PathRAG.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from tqdm.asyncio import tqdm as tqdm_async
4 | from dataclasses import asdict, dataclass, field
5 | from datetime import datetime
6 | from functools import partial
7 | from typing import Type, cast
8 |
9 |
10 | from .llm import (
11 | gpt_4o_mini_complete,
12 | openai_embedding,
13 | )
14 | from .operate import (
15 | chunking_by_token_size,
16 | extract_entities,
17 | kg_query,
18 | )
19 |
20 | from .utils import (
21 | EmbeddingFunc,
22 | compute_mdhash_id,
23 | limit_async_func_call,
24 | convert_response_to_json,
25 | logger,
26 | set_logger,
27 | )
28 | from .base import (
29 | BaseGraphStorage,
30 | BaseKVStorage,
31 | BaseVectorStorage,
32 | StorageNameSpace,
33 | QueryParam,
34 | )
35 |
36 | from .storage import (
37 | JsonKVStorage,
38 | NanoVectorDBStorage,
39 | NetworkXStorage,
40 | )
41 |
42 |
43 |
44 |
45 | def lazy_external_import(module_name: str, class_name: str):
46 | """Lazily import a class from an external module based on the package of the caller."""
47 |
48 |
49 | import inspect
50 |
51 | caller_frame = inspect.currentframe().f_back
52 | module = inspect.getmodule(caller_frame)
53 | package = module.__package__ if module else None
54 |
55 | def import_class(*args, **kwargs):
56 | import importlib
57 |
58 |
59 | module = importlib.import_module(module_name, package=package)
60 |
61 |
62 | cls = getattr(module, class_name)
63 | return cls(*args, **kwargs)
64 |
65 | return import_class
66 |
67 |
68 | Neo4JStorage = lazy_external_import(".kg.neo4j_impl", "Neo4JStorage")
69 | OracleKVStorage = lazy_external_import(".kg.oracle_impl", "OracleKVStorage")
70 | OracleGraphStorage = lazy_external_import(".kg.oracle_impl", "OracleGraphStorage")
71 | OracleVectorDBStorage = lazy_external_import(".kg.oracle_impl", "OracleVectorDBStorage")
72 | MilvusVectorDBStorge = lazy_external_import(".kg.milvus_impl", "MilvusVectorDBStorge")
73 | MongoKVStorage = lazy_external_import(".kg.mongo_impl", "MongoKVStorage")
74 | ChromaVectorDBStorage = lazy_external_import(".kg.chroma_impl", "ChromaVectorDBStorage")
75 | TiDBKVStorage = lazy_external_import(".kg.tidb_impl", "TiDBKVStorage")
76 | TiDBVectorDBStorage = lazy_external_import(".kg.tidb_impl", "TiDBVectorDBStorage")
77 | AGEStorage = lazy_external_import(".kg.age_impl", "AGEStorage")
78 |
79 |
80 | def always_get_an_event_loop() -> asyncio.AbstractEventLoop:
81 | """
82 | Ensure that there is always an event loop available.
83 |
84 | This function tries to get the current event loop. If the current event loop is closed or does not exist,
85 | it creates a new event loop and sets it as the current event loop.
86 |
87 | Returns:
88 | asyncio.AbstractEventLoop: The current or newly created event loop.
89 | """
90 | try:
91 |
92 | current_loop = asyncio.get_event_loop()
93 | if current_loop.is_closed():
94 | raise RuntimeError("Event loop is closed.")
95 | return current_loop
96 |
97 | except RuntimeError:
98 |
99 | logger.info("Creating a new event loop in main thread.")
100 | new_loop = asyncio.new_event_loop()
101 | asyncio.set_event_loop(new_loop)
102 | return new_loop
103 |
104 |
105 | @dataclass
106 | class PathRAG:
107 | working_dir: str = field(
108 | default_factory=lambda: f"./PathRAG_cache_{datetime.now().strftime('%Y-%m-%d-%H:%M:%S')}"
109 | )
110 |
111 | embedding_cache_config: dict = field(
112 | default_factory=lambda: {
113 | "enabled": False,
114 | "similarity_threshold": 0.95,
115 | "use_llm_check": False,
116 | }
117 | )
118 | kv_storage: str = field(default="JsonKVStorage")
119 | vector_storage: str = field(default="NanoVectorDBStorage")
120 | graph_storage: str = field(default="NetworkXStorage")
121 |
122 | current_log_level = logger.level
123 | log_level: str = field(default=current_log_level)
124 |
125 |
126 | chunk_token_size: int = 1200
127 | chunk_overlap_token_size: int = 100
128 | tiktoken_model_name: str = "gpt-4o-mini"
129 |
130 |
131 | entity_extract_max_gleaning: int = 1
132 | entity_summary_to_max_tokens: int = 500
133 |
134 |
135 | node_embedding_algorithm: str = "node2vec"
136 | node2vec_params: dict = field(
137 | default_factory=lambda: {
138 | "dimensions": 1536,
139 | "num_walks": 10,
140 | "walk_length": 40,
141 | "window_size": 2,
142 | "iterations": 3,
143 | "random_seed": 3,
144 | }
145 | )
146 |
147 |
148 | embedding_func: EmbeddingFunc = field(default_factory=lambda: openai_embedding)
149 | embedding_batch_num: int = 32
150 | embedding_func_max_async: int = 16
151 |
152 |
153 | llm_model_func: callable = gpt_4o_mini_complete
154 | llm_model_name: str = "meta-llama/Llama-3.2-1B-Instruct"
155 | llm_model_max_token_size: int = 32768
156 | llm_model_max_async: int = 16
157 | llm_model_kwargs: dict = field(default_factory=dict)
158 |
159 |
160 | vector_db_storage_cls_kwargs: dict = field(default_factory=dict)
161 |
162 | enable_llm_cache: bool = True
163 |
164 |
165 | addon_params: dict = field(default_factory=dict)
166 | convert_response_to_json_func: callable = convert_response_to_json
167 |
168 | def __post_init__(self):
 169 |         log_file = "PathRAG.log"
170 | set_logger(log_file)
171 | logger.setLevel(self.log_level)
172 |
173 | logger.info(f"Logger initialized for working directory: {self.working_dir}")
174 |
175 |
176 | self.key_string_value_json_storage_cls: Type[BaseKVStorage] = (
177 | self._get_storage_class()[self.kv_storage]
178 | )
179 | self.vector_db_storage_cls: Type[BaseVectorStorage] = self._get_storage_class()[
180 | self.vector_storage
181 | ]
182 | self.graph_storage_cls: Type[BaseGraphStorage] = self._get_storage_class()[
183 | self.graph_storage
184 | ]
185 |
186 | if not os.path.exists(self.working_dir):
187 | logger.info(f"Creating working directory {self.working_dir}")
188 | os.makedirs(self.working_dir)
189 |
190 | self.llm_response_cache = (
191 | self.key_string_value_json_storage_cls(
192 | namespace="llm_response_cache",
193 | global_config=asdict(self),
194 | embedding_func=None,
195 | )
196 | if self.enable_llm_cache
197 | else None
198 | )
199 | self.embedding_func = limit_async_func_call(self.embedding_func_max_async)(
200 | self.embedding_func
201 | )
202 |
203 |
204 | self.full_docs = self.key_string_value_json_storage_cls(
205 | namespace="full_docs",
206 | global_config=asdict(self),
207 | embedding_func=self.embedding_func,
208 | )
209 | self.text_chunks = self.key_string_value_json_storage_cls(
210 | namespace="text_chunks",
211 | global_config=asdict(self),
212 | embedding_func=self.embedding_func,
213 | )
214 | self.chunk_entity_relation_graph = self.graph_storage_cls(
215 | namespace="chunk_entity_relation",
216 | global_config=asdict(self),
217 | embedding_func=self.embedding_func,
218 | )
219 |
220 |
221 | self.entities_vdb = self.vector_db_storage_cls(
222 | namespace="entities",
223 | global_config=asdict(self),
224 | embedding_func=self.embedding_func,
225 | meta_fields={"entity_name"},
226 | )
227 | self.relationships_vdb = self.vector_db_storage_cls(
228 | namespace="relationships",
229 | global_config=asdict(self),
230 | embedding_func=self.embedding_func,
231 | meta_fields={"src_id", "tgt_id"},
232 | )
233 | self.chunks_vdb = self.vector_db_storage_cls(
234 | namespace="chunks",
235 | global_config=asdict(self),
236 | embedding_func=self.embedding_func,
237 | )
238 |
239 | self.llm_model_func = limit_async_func_call(self.llm_model_max_async)(
240 | partial(
241 | self.llm_model_func,
242 | hashing_kv=self.llm_response_cache
243 | if self.llm_response_cache
244 | and hasattr(self.llm_response_cache, "global_config")
245 | else self.key_string_value_json_storage_cls(
 246 |                     namespace="llm_response_cache", global_config=asdict(self), embedding_func=None,
247 | ),
248 | **self.llm_model_kwargs,
249 | )
250 | )
251 |
252 | def _get_storage_class(self) -> Type[BaseGraphStorage]:
253 | return {
254 |
255 | "JsonKVStorage": JsonKVStorage,
256 | "OracleKVStorage": OracleKVStorage,
257 | "MongoKVStorage": MongoKVStorage,
258 | "TiDBKVStorage": TiDBKVStorage,
259 |
260 | "NanoVectorDBStorage": NanoVectorDBStorage,
261 | "OracleVectorDBStorage": OracleVectorDBStorage,
262 | "MilvusVectorDBStorge": MilvusVectorDBStorge,
263 | "ChromaVectorDBStorage": ChromaVectorDBStorage,
264 | "TiDBVectorDBStorage": TiDBVectorDBStorage,
265 |
266 | "NetworkXStorage": NetworkXStorage,
267 | "Neo4JStorage": Neo4JStorage,
268 | "OracleGraphStorage": OracleGraphStorage,
269 | "AGEStorage": AGEStorage,
270 |
271 | }
272 |
273 | def insert(self, string_or_strings):
274 |
275 | loop = always_get_an_event_loop()
276 | return loop.run_until_complete(self.ainsert(string_or_strings))
277 |
278 | async def ainsert(self, string_or_strings):
279 | update_storage = False
280 | try:
281 | if isinstance(string_or_strings, str):
282 | string_or_strings = [string_or_strings]
283 |
284 | new_docs = {
285 | compute_mdhash_id(c.strip(), prefix="doc-"): {"content": c.strip()}
286 | for c in string_or_strings
287 | }
288 | _add_doc_keys = await self.full_docs.filter_keys(list(new_docs.keys()))
289 | new_docs = {k: v for k, v in new_docs.items() if k in _add_doc_keys}
290 | if not len(new_docs):
291 | logger.warning("All docs are already in the storage")
292 | return
293 | update_storage = True
294 | logger.info(f"[New Docs] inserting {len(new_docs)} docs")
295 |
296 | inserting_chunks = {}
297 | for doc_key, doc in tqdm_async(
298 | new_docs.items(), desc="Chunking documents", unit="doc"
299 | ):
300 | chunks = {
301 | compute_mdhash_id(dp["content"], prefix="chunk-"): {
302 | **dp,
303 | "full_doc_id": doc_key,
304 | }
305 | for dp in chunking_by_token_size(
306 | doc["content"],
307 | overlap_token_size=self.chunk_overlap_token_size,
308 | max_token_size=self.chunk_token_size,
309 | tiktoken_model=self.tiktoken_model_name,
310 | )
311 | }
312 | inserting_chunks.update(chunks)
313 | _add_chunk_keys = await self.text_chunks.filter_keys(
314 | list(inserting_chunks.keys())
315 | )
316 | inserting_chunks = {
317 | k: v for k, v in inserting_chunks.items() if k in _add_chunk_keys
318 | }
319 | if not len(inserting_chunks):
320 | logger.warning("All chunks are already in the storage")
321 | return
322 | logger.info(f"[New Chunks] inserting {len(inserting_chunks)} chunks")
323 |
324 | await self.chunks_vdb.upsert(inserting_chunks)
325 |
326 | logger.info("[Entity Extraction]...")
327 | maybe_new_kg = await extract_entities(
328 | inserting_chunks,
329 | knowledge_graph_inst=self.chunk_entity_relation_graph,
330 | entity_vdb=self.entities_vdb,
331 | relationships_vdb=self.relationships_vdb,
332 | global_config=asdict(self),
333 | )
334 | if maybe_new_kg is None:
335 | logger.warning("No new entities and relationships found")
336 | return
337 | self.chunk_entity_relation_graph = maybe_new_kg
338 |
339 | await self.full_docs.upsert(new_docs)
340 | await self.text_chunks.upsert(inserting_chunks)
341 | finally:
342 | if update_storage:
343 | await self._insert_done()
344 |
345 | async def _insert_done(self):
346 | tasks = []
347 | for storage_inst in [
348 | self.full_docs,
349 | self.text_chunks,
350 | self.llm_response_cache,
351 | self.entities_vdb,
352 | self.relationships_vdb,
353 | self.chunks_vdb,
354 | self.chunk_entity_relation_graph,
355 | ]:
356 | if storage_inst is None:
357 | continue
358 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
359 | await asyncio.gather(*tasks)
360 |
361 | def insert_custom_kg(self, custom_kg: dict):
362 | loop = always_get_an_event_loop()
363 | return loop.run_until_complete(self.ainsert_custom_kg(custom_kg))
364 |
365 | async def ainsert_custom_kg(self, custom_kg: dict):
366 | update_storage = False
367 | try:
368 |
369 | all_chunks_data = {}
370 | chunk_to_source_map = {}
371 | for chunk_data in custom_kg.get("chunks", []):
372 | chunk_content = chunk_data["content"]
373 | source_id = chunk_data["source_id"]
374 | chunk_id = compute_mdhash_id(chunk_content.strip(), prefix="chunk-")
375 |
376 | chunk_entry = {"content": chunk_content.strip(), "source_id": source_id}
377 | all_chunks_data[chunk_id] = chunk_entry
378 | chunk_to_source_map[source_id] = chunk_id
379 | update_storage = True
380 |
381 | if self.chunks_vdb is not None and all_chunks_data:
382 | await self.chunks_vdb.upsert(all_chunks_data)
383 | if self.text_chunks is not None and all_chunks_data:
384 | await self.text_chunks.upsert(all_chunks_data)
385 |
386 |
387 | all_entities_data = []
388 | for entity_data in custom_kg.get("entities", []):
389 | entity_name = f'"{entity_data["entity_name"].upper()}"'
390 | entity_type = entity_data.get("entity_type", "UNKNOWN")
391 | description = entity_data.get("description", "No description provided")
392 |
393 | source_chunk_id = entity_data.get("source_id", "UNKNOWN")
394 | source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
395 |
396 |
397 | if source_id == "UNKNOWN":
398 | logger.warning(
399 | f"Entity '{entity_name}' has an UNKNOWN source_id. Please check the source mapping."
400 | )
401 |
402 |
403 | node_data = {
404 | "entity_type": entity_type,
405 | "description": description,
406 | "source_id": source_id,
407 | }
408 |
409 | await self.chunk_entity_relation_graph.upsert_node(
410 | entity_name, node_data=node_data
411 | )
412 | node_data["entity_name"] = entity_name
413 | all_entities_data.append(node_data)
414 | update_storage = True
415 |
416 |
417 | all_relationships_data = []
418 | for relationship_data in custom_kg.get("relationships", []):
419 | src_id = f'"{relationship_data["src_id"].upper()}"'
420 | tgt_id = f'"{relationship_data["tgt_id"].upper()}"'
421 | description = relationship_data["description"]
422 | keywords = relationship_data["keywords"]
423 | weight = relationship_data.get("weight", 1.0)
424 |
425 | source_chunk_id = relationship_data.get("source_id", "UNKNOWN")
426 | source_id = chunk_to_source_map.get(source_chunk_id, "UNKNOWN")
427 |
428 |
429 | if source_id == "UNKNOWN":
430 | logger.warning(
431 | f"Relationship from '{src_id}' to '{tgt_id}' has an UNKNOWN source_id. Please check the source mapping."
432 | )
433 |
434 |
435 | for need_insert_id in [src_id, tgt_id]:
436 | if not (
437 | await self.chunk_entity_relation_graph.has_node(need_insert_id)
438 | ):
439 | await self.chunk_entity_relation_graph.upsert_node(
440 | need_insert_id,
441 | node_data={
442 | "source_id": source_id,
443 | "description": "UNKNOWN",
444 | "entity_type": "UNKNOWN",
445 | },
446 | )
447 |
448 |
449 | await self.chunk_entity_relation_graph.upsert_edge(
450 | src_id,
451 | tgt_id,
452 | edge_data={
453 | "weight": weight,
454 | "description": description,
455 | "keywords": keywords,
456 | "source_id": source_id,
457 | },
458 | )
459 | edge_data = {
460 | "src_id": src_id,
461 | "tgt_id": tgt_id,
462 | "description": description,
463 | "keywords": keywords,
464 | }
465 | all_relationships_data.append(edge_data)
466 | update_storage = True
467 |
468 |
469 | if self.entities_vdb is not None:
470 | data_for_vdb = {
471 | compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
472 | "content": dp["entity_name"] + dp["description"],
473 | "entity_name": dp["entity_name"],
474 | }
475 | for dp in all_entities_data
476 | }
477 | await self.entities_vdb.upsert(data_for_vdb)
478 |
479 |
480 | if self.relationships_vdb is not None:
481 | data_for_vdb = {
482 | compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
483 | "src_id": dp["src_id"],
484 | "tgt_id": dp["tgt_id"],
485 | "content": dp["keywords"]
486 | + dp["src_id"]
487 | + dp["tgt_id"]
488 | + dp["description"],
489 | }
490 | for dp in all_relationships_data
491 | }
492 | await self.relationships_vdb.upsert(data_for_vdb)
493 | finally:
494 | if update_storage:
495 | await self._insert_done()
496 |
497 | def query(self, query: str, param: QueryParam = QueryParam()):
498 | loop = always_get_an_event_loop()
499 | return loop.run_until_complete(self.aquery(query, param))
500 |
501 | async def aquery(self, query: str, param: QueryParam = QueryParam()):
502 | if param.mode in ["hybrid"]:
 503 |             response = await kg_query(
504 | query,
505 | self.chunk_entity_relation_graph,
506 | self.entities_vdb,
507 | self.relationships_vdb,
508 | self.text_chunks,
509 | param,
510 | asdict(self),
511 | hashing_kv=self.llm_response_cache
512 | if self.llm_response_cache
513 | and hasattr(self.llm_response_cache, "global_config")
514 | else self.key_string_value_json_storage_cls(
 515 |                 namespace="llm_response_cache", global_config=asdict(self), embedding_func=None,
516 | ),
517 | )
 518 |             print("Query response is ready")
519 | else:
520 | raise ValueError(f"Unknown mode {param.mode}")
521 | await self._query_done()
522 | return response
523 |
524 |
525 | async def _query_done(self):
526 | tasks = []
527 | for storage_inst in [self.llm_response_cache]:
528 | if storage_inst is None:
529 | continue
530 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
531 | await asyncio.gather(*tasks)
532 |
533 | def delete_by_entity(self, entity_name: str):
534 | loop = always_get_an_event_loop()
535 | return loop.run_until_complete(self.adelete_by_entity(entity_name))
536 |
537 | async def adelete_by_entity(self, entity_name: str):
538 | entity_name = f'"{entity_name.upper()}"'
539 |
540 | try:
541 | await self.entities_vdb.delete_entity(entity_name)
542 | await self.relationships_vdb.delete_relation(entity_name)
543 | await self.chunk_entity_relation_graph.delete_node(entity_name)
544 |
545 | logger.info(
546 | f"Entity '{entity_name}' and its relationships have been deleted."
547 | )
548 | await self._delete_by_entity_done()
549 | except Exception as e:
550 | logger.error(f"Error while deleting entity '{entity_name}': {e}")
551 |
552 | async def _delete_by_entity_done(self):
553 | tasks = []
554 | for storage_inst in [
555 | self.entities_vdb,
556 | self.relationships_vdb,
557 | self.chunk_entity_relation_graph,
558 | ]:
559 | if storage_inst is None:
560 | continue
561 | tasks.append(cast(StorageNameSpace, storage_inst).index_done_callback())
562 | await asyncio.gather(*tasks)
563 |
--------------------------------------------------------------------------------
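A minimal usage sketch of the class above, assuming the default `gpt_4o_mini_complete` / `openai_embedding` functions and a configured OpenAI API key; the working directory, document text, and question are placeholders, not values from the repository:

```python
# Hedged usage sketch only; paths and texts below are illustrative.
import os

from PathRAG import PathRAG, QueryParam

os.environ["OPENAI_API_KEY"] = "sk-..."  # assumed: the default LLM/embedding funcs call OpenAI

rag = PathRAG(working_dir="./pathrag_cache")  # storage files are created under this directory

# insert() wraps the async ainsert() via always_get_an_event_loop()
rag.insert("PathRAG chunks documents, extracts entities and relations, and stores them in a graph.")

# aquery() only handles mode="hybrid"; any other mode raises ValueError
print(rag.query("What does PathRAG store in the graph?", param=QueryParam(mode="hybrid")))
```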
/PathRAG/__init__.py:
--------------------------------------------------------------------------------
1 | from .PathRAG import PathRAG as PathRAG, QueryParam as QueryParam
2 |
3 |
4 |
--------------------------------------------------------------------------------
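The package re-exports `PathRAG` and `QueryParam`. Judging from the `llm_model_func` / `llm_model_kwargs` fields in PathRAG.py, a non-default completion backend can be wired in by passing a callable. A sketch, assuming an OpenAI-compatible endpoint and reusing `openai_complete_if_cache` from PathRAG/llm.py (shown further below); the model name, base URL, and key are placeholders:

```python
# Sketch of overriding llm_model_func; endpoint, key, and model name are placeholders.
from PathRAG import PathRAG
from PathRAG.llm import openai_complete_if_cache


async def my_llm_func(prompt, system_prompt=None, history_messages=[], **kwargs):
    # PathRAG injects hashing_kv via functools.partial; openai_complete_if_cache pops it from kwargs.
    return await openai_complete_if_cache(
        "gpt-4o-mini",
        prompt,
        system_prompt=system_prompt,
        history_messages=history_messages,
        base_url="https://api.openai.com/v1",
        api_key="sk-...",
        **kwargs,
    )


rag = PathRAG(working_dir="./pathrag_cache", llm_model_func=my_llm_func)
```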
/PathRAG/__pycache__/PathRAG.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/PathRAG.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/base.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/base.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/llm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/llm.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/operate.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/operate.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/prompt.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/prompt.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/storage.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/storage.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/BUPT-GAMMA/PathRAG/d9be0879d7c32dc35f1bd38ffa4b04b11ba9b246/PathRAG/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/PathRAG/base.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import TypedDict, Union, Literal, Generic, TypeVar
3 |
4 | import numpy as np
5 |
6 | from .utils import EmbeddingFunc
7 |
8 | TextChunkSchema = TypedDict(
9 | "TextChunkSchema",
10 | {"tokens": int, "content": str, "full_doc_id": str, "chunk_order_index": int},
11 | )
12 |
13 | T = TypeVar("T")
14 |
15 |
16 | @dataclass
17 | class QueryParam:
 18 |     mode: Literal["hybrid"] = "hybrid"
19 | only_need_context: bool = False
20 | only_need_prompt: bool = False
21 | response_type: str = "Multiple Paragraphs"
22 | stream: bool = False
 23 |     top_k: int = 40
24 | max_token_for_text_unit: int = 4000
25 | max_token_for_global_context: int = 3000
26 | max_token_for_local_context: int = 5000
27 |
28 |
29 | @dataclass
30 | class StorageNameSpace:
31 | namespace: str
32 | global_config: dict
33 |
34 | async def index_done_callback(self):
35 |
36 | pass
37 |
38 | async def query_done_callback(self):
39 |
40 | pass
41 |
42 |
43 | @dataclass
44 | class BaseVectorStorage(StorageNameSpace):
45 | embedding_func: EmbeddingFunc
46 | meta_fields: set = field(default_factory=set)
47 |
48 | async def query(self, query: str, top_k: int) -> list[dict]:
49 | raise NotImplementedError
50 |
51 | async def upsert(self, data: dict[str, dict]):
52 |
53 | raise NotImplementedError
54 |
55 |
56 | @dataclass
57 | class BaseKVStorage(Generic[T], StorageNameSpace):
58 | embedding_func: EmbeddingFunc
59 |
60 | async def all_keys(self) -> list[str]:
61 | raise NotImplementedError
62 |
63 | async def get_by_id(self, id: str) -> Union[T, None]:
64 | raise NotImplementedError
65 |
66 | async def get_by_ids(
67 | self, ids: list[str], fields: Union[set[str], None] = None
68 | ) -> list[Union[T, None]]:
69 | raise NotImplementedError
70 |
71 | async def filter_keys(self, data: list[str]) -> set[str]:
72 |
73 | raise NotImplementedError
74 |
75 | async def upsert(self, data: dict[str, T]):
76 | raise NotImplementedError
77 |
78 | async def drop(self):
79 | raise NotImplementedError
80 |
81 |
82 | @dataclass
83 | class BaseGraphStorage(StorageNameSpace):
84 | embedding_func: EmbeddingFunc = None
85 |
86 | async def has_node(self, node_id: str) -> bool:
87 | raise NotImplementedError
88 |
89 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
90 | raise NotImplementedError
91 |
92 | async def node_degree(self, node_id: str) -> int:
93 | raise NotImplementedError
94 |
95 | async def edge_degree(self, src_id: str, tgt_id: str) -> int:
96 | raise NotImplementedError
97 |
 98 |     async def get_pagerank(self, node_id: str) -> float:
99 | raise NotImplementedError
100 |
101 | async def get_node(self, node_id: str) -> Union[dict, None]:
102 | raise NotImplementedError
103 |
104 | async def get_edge(
105 | self, source_node_id: str, target_node_id: str
106 | ) -> Union[dict, None]:
107 | raise NotImplementedError
108 |
109 | async def get_node_edges(
110 | self, source_node_id: str
111 | ) -> Union[list[tuple[str, str]], None]:
112 | raise NotImplementedError
113 |
 114 |     async def get_node_in_edges(
 115 |         self, source_node_id: str
 116 |     ) -> Union[list[tuple[str, str]], None]:
 117 |         raise NotImplementedError
 118 |     async def get_node_out_edges(
 119 |         self, source_node_id: str
 120 |     ) -> Union[list[tuple[str, str]], None]:
 121 |         raise NotImplementedError
122 |
123 | async def upsert_node(self, node_id: str, node_data: dict[str, str]):
124 | raise NotImplementedError
125 |
126 | async def upsert_edge(
127 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
128 | ):
129 | raise NotImplementedError
130 |
131 | async def delete_node(self, node_id: str):
132 | raise NotImplementedError
133 |
134 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
135 | raise NotImplementedError("Node embedding is not used in PathRag.")
136 |
--------------------------------------------------------------------------------
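The abstract base classes above define the contract a storage backend must satisfy before it can be registered in `_get_storage_class()`. As an illustration only (not part of the repository), a minimal in-memory key-value backend might subclass `BaseKVStorage` like this:

```python
# Illustrative sketch of a custom BaseKVStorage backend; the class name and behavior are assumptions.
from dataclasses import dataclass, field
from typing import Union

from PathRAG.base import BaseKVStorage


@dataclass
class InMemoryKVStorage(BaseKVStorage):
    _data: dict = field(default_factory=dict)

    async def all_keys(self) -> list[str]:
        return list(self._data.keys())

    async def get_by_id(self, id: str) -> Union[dict, None]:
        return self._data.get(id)

    async def get_by_ids(self, ids, fields=None):
        return [self._data.get(i) for i in ids]

    async def filter_keys(self, data: list[str]) -> set[str]:
        # Return the keys that are NOT yet stored, mirroring how ainsert() uses filter_keys().
        return {k for k in data if k not in self._data}

    async def upsert(self, data: dict):
        self._data.update(data)

    async def drop(self):
        self._data.clear()
```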
/PathRAG/llm.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import copy
3 | import json
4 | import os
5 | import re
6 | import struct
7 | from functools import lru_cache
8 | from typing import List, Dict, Callable, Any, Union, Optional
9 | import aioboto3
10 | import aiohttp
11 | import numpy as np
12 | import ollama
13 | import torch
14 | import time
15 | from openai import (
16 | AsyncOpenAI,
17 | APIConnectionError,
18 | RateLimitError,
19 | Timeout,
20 | AsyncAzureOpenAI,
21 | )
22 | from pydantic import BaseModel, Field
23 | from tenacity import (
24 | retry,
25 | stop_after_attempt,
26 | wait_exponential,
27 | retry_if_exception_type,
28 | )
29 | from transformers import AutoTokenizer, AutoModelForCausalLM
30 |
31 | from .utils import (
32 | wrap_embedding_func_with_attrs,
33 | locate_json_string_body_from_string,
34 | safe_unicode_decode,
35 | logger,
36 | )
37 |
38 | import sys
39 |
40 | if sys.version_info < (3, 9):
41 | from typing import AsyncIterator
42 | else:
43 | from collections.abc import AsyncIterator
44 |
45 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
46 |
47 |
48 | @retry(
49 | stop=stop_after_attempt(3),
50 | wait=wait_exponential(multiplier=1, min=4, max=10),
51 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
52 | )
53 | async def openai_complete_if_cache(
54 | model,
55 | prompt,
56 | system_prompt=None,
57 | history_messages=[],
58 | base_url="https://api.openai.com/v1",
59 | api_key="",
60 | **kwargs,
61 | ) -> str:
62 | if api_key:
63 | os.environ["OPENAI_API_KEY"] = api_key
64 | time.sleep(2)
65 | openai_async_client = (
66 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
67 | )
68 | kwargs.pop("hashing_kv", None)
69 | kwargs.pop("keyword_extraction", None)
70 | messages = []
71 | if system_prompt:
72 | messages.append({"role": "system", "content": system_prompt})
73 | messages.extend(history_messages)
74 | messages.append({"role": "user", "content": prompt})
75 |
76 |
77 | logger.debug("===== Query Input to LLM =====")
78 | logger.debug(f"Query: {prompt}")
79 | logger.debug(f"System prompt: {system_prompt}")
80 | logger.debug("Full context:")
81 | if "response_format" in kwargs:
82 | response = await openai_async_client.beta.chat.completions.parse(
83 | model=model, messages=messages, **kwargs
84 | )
85 | else:
86 | response = await openai_async_client.chat.completions.create(
87 | model=model, messages=messages, **kwargs
88 | )
89 |
90 | if hasattr(response, "__aiter__"):
91 |
92 | async def inner():
93 | async for chunk in response:
94 | content = chunk.choices[0].delta.content
95 | if content is None:
96 | continue
97 | if r"\u" in content:
98 | content = safe_unicode_decode(content.encode("utf-8"))
99 | yield content
100 |
101 | return inner()
102 | else:
103 | content = response.choices[0].message.content
104 | if r"\u" in content:
105 | content = safe_unicode_decode(content.encode("utf-8"))
106 | return content
107 |
108 |
109 | @retry(
110 | stop=stop_after_attempt(3),
111 | wait=wait_exponential(multiplier=1, min=4, max=10),
112 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
113 | )
114 | async def azure_openai_complete_if_cache(
115 | model,
116 | prompt,
117 | system_prompt=None,
118 | history_messages=[],
119 | base_url=None,
120 | api_key=None,
121 | api_version=None,
122 | **kwargs,
123 | ):
124 | if api_key:
125 | os.environ["AZURE_OPENAI_API_KEY"] = api_key
126 | if base_url:
127 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
128 | if api_version:
129 | os.environ["AZURE_OPENAI_API_VERSION"] = api_version
130 |
131 | openai_async_client = AsyncAzureOpenAI(
132 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
133 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
134 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
135 | )
136 | kwargs.pop("hashing_kv", None)
137 | messages = []
138 | if system_prompt:
139 | messages.append({"role": "system", "content": system_prompt})
140 | messages.extend(history_messages)
141 | if prompt is not None:
142 | messages.append({"role": "user", "content": prompt})
143 |
144 | response = await openai_async_client.chat.completions.create(
145 | model=model, messages=messages, **kwargs
146 | )
147 | content = response.choices[0].message.content
148 |
149 | return content
150 |
151 |
152 | class BedrockError(Exception):
153 | """Generic error for issues related to Amazon Bedrock"""
154 |
155 |
156 | @retry(
157 | stop=stop_after_attempt(5),
158 | wait=wait_exponential(multiplier=1, max=60),
159 | retry=retry_if_exception_type((BedrockError)),
160 | )
161 | async def bedrock_complete_if_cache(
162 | model,
163 | prompt,
164 | system_prompt=None,
165 | history_messages=[],
166 | aws_access_key_id=None,
167 | aws_secret_access_key=None,
168 | aws_session_token=None,
169 | **kwargs,
170 | ) -> str:
171 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
172 | "AWS_ACCESS_KEY_ID", aws_access_key_id
173 | )
174 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
175 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
176 | )
177 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
178 | "AWS_SESSION_TOKEN", aws_session_token
179 | )
180 | kwargs.pop("hashing_kv", None)
181 |
182 | messages = []
183 | for history_message in history_messages:
184 | message = copy.copy(history_message)
185 | message["content"] = [{"text": message["content"]}]
186 | messages.append(message)
187 |
188 |
189 | messages.append({"role": "user", "content": [{"text": prompt}]})
190 |
191 |
192 | args = {"modelId": model, "messages": messages}
193 |
194 |
195 | if system_prompt:
196 | args["system"] = [{"text": system_prompt}]
197 |
198 |
199 | inference_params_map = {
200 | "max_tokens": "maxTokens",
201 | "top_p": "topP",
202 | "stop_sequences": "stopSequences",
203 | }
204 | if inference_params := list(
205 | set(kwargs) & set(["max_tokens", "temperature", "top_p", "stop_sequences"])
206 | ):
207 | args["inferenceConfig"] = {}
208 | for param in inference_params:
209 | args["inferenceConfig"][inference_params_map.get(param, param)] = (
210 | kwargs.pop(param)
211 | )
212 |
213 |
214 | session = aioboto3.Session()
215 | async with session.client("bedrock-runtime") as bedrock_async_client:
216 | try:
217 | response = await bedrock_async_client.converse(**args, **kwargs)
218 | except Exception as e:
219 | raise BedrockError(e)
220 |
221 | return response["output"]["message"]["content"][0]["text"]
222 |
223 |
224 | @lru_cache(maxsize=1)
225 | def initialize_hf_model(model_name):
226 | hf_tokenizer = AutoTokenizer.from_pretrained(
227 | model_name, device_map="auto", trust_remote_code=True
228 | )
229 | hf_model = AutoModelForCausalLM.from_pretrained(
230 | model_name, device_map="auto", trust_remote_code=True
231 | )
232 | if hf_tokenizer.pad_token is None:
233 | hf_tokenizer.pad_token = hf_tokenizer.eos_token
234 |
235 | return hf_model, hf_tokenizer
236 |
237 |
238 | @retry(
239 | stop=stop_after_attempt(3),
240 | wait=wait_exponential(multiplier=1, min=4, max=10),
241 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
242 | )
243 | async def hf_model_if_cache(
244 | model,
245 | prompt,
246 | system_prompt=None,
247 | history_messages=[],
248 | **kwargs,
249 | ) -> str:
250 | model_name = model
251 | hf_model, hf_tokenizer = initialize_hf_model(model_name)
252 | messages = []
253 | if system_prompt:
254 | messages.append({"role": "system", "content": system_prompt})
255 | messages.extend(history_messages)
256 | messages.append({"role": "user", "content": prompt})
257 | kwargs.pop("hashing_kv", None)
258 | input_prompt = ""
259 | try:
260 | input_prompt = hf_tokenizer.apply_chat_template(
261 | messages, tokenize=False, add_generation_prompt=True
262 | )
263 | except Exception:
264 | try:
265 | ori_message = copy.deepcopy(messages)
266 | if messages[0]["role"] == "system":
267 | messages[1]["content"] = (
 268 |                     "<system>"
 269 |                     + messages[0]["content"]
 270 |                     + "</system>\n"
271 | + messages[1]["content"]
272 | )
273 | messages = messages[1:]
274 | input_prompt = hf_tokenizer.apply_chat_template(
275 | messages, tokenize=False, add_generation_prompt=True
276 | )
277 | except Exception:
278 | len_message = len(ori_message)
279 | for msgid in range(len_message):
280 | input_prompt = (
281 | input_prompt
282 | + "<"
283 | + ori_message[msgid]["role"]
284 | + ">"
285 | + ori_message[msgid]["content"]
 286 |                     + "</"
287 | + ori_message[msgid]["role"]
288 | + ">\n"
289 | )
290 |
291 | input_ids = hf_tokenizer(
292 | input_prompt, return_tensors="pt", padding=True, truncation=True
 293 |     )
 294 |     inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
 295 |     output = hf_model.generate(
 296 |         **inputs, max_new_tokens=512, num_return_sequences=1, early_stopping=True
297 | )
298 | response_text = hf_tokenizer.decode(
299 | output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
300 | )
301 |
302 | return response_text
303 |
304 |
305 | @retry(
306 | stop=stop_after_attempt(3),
307 | wait=wait_exponential(multiplier=1, min=4, max=10),
308 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
309 | )
310 | async def ollama_model_if_cache(
311 | model,
312 | prompt,
313 | system_prompt=None,
314 | history_messages=[],
315 | **kwargs,
316 | ) -> Union[str, AsyncIterator[str]]:
317 | stream = True if kwargs.get("stream") else False
318 | kwargs.pop("max_tokens", None)
319 | host = kwargs.pop("host", None)
320 | timeout = kwargs.pop("timeout", None)
321 | kwargs.pop("hashing_kv", None)
322 | ollama_client = ollama.AsyncClient(host=host, timeout=timeout)
323 | messages = []
324 | if system_prompt:
325 | messages.append({"role": "system", "content": system_prompt})
326 | messages.extend(history_messages)
327 | messages.append({"role": "user", "content": prompt})
328 |
329 | response = await ollama_client.chat(model=model, messages=messages, **kwargs)
330 | if stream:
331 | """cannot cache stream response"""
332 |
333 | async def inner():
334 | async for chunk in response:
335 | yield chunk["message"]["content"]
336 |
337 | return inner()
338 | else:
339 | return response["message"]["content"]
340 |
341 |
342 | @lru_cache(maxsize=1)
343 | def initialize_lmdeploy_pipeline(
344 | model,
345 | tp=1,
346 | chat_template=None,
347 | log_level="WARNING",
348 | model_format="hf",
349 | quant_policy=0,
350 | ):
351 | from lmdeploy import pipeline, ChatTemplateConfig, TurbomindEngineConfig
352 |
353 | lmdeploy_pipe = pipeline(
354 | model_path=model,
355 | backend_config=TurbomindEngineConfig(
356 | tp=tp, model_format=model_format, quant_policy=quant_policy
357 | ),
358 | chat_template_config=(
359 | ChatTemplateConfig(model_name=chat_template) if chat_template else None
360 | ),
361 | log_level="WARNING",
362 | )
363 | return lmdeploy_pipe
364 |
365 |
366 | @retry(
367 | stop=stop_after_attempt(3),
368 | wait=wait_exponential(multiplier=1, min=4, max=10),
369 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
370 | )
371 | async def lmdeploy_model_if_cache(
372 | model,
373 | prompt,
374 | system_prompt=None,
375 | history_messages=[],
376 | chat_template=None,
377 | model_format="hf",
378 | quant_policy=0,
379 | **kwargs,
380 | ) -> str:
381 | """
382 | Args:
383 | model (str): The path to the model.
384 | It could be one of the following options:
385 | - i) A local directory path of a turbomind model which is
386 | converted by `lmdeploy convert` command or download
387 | from ii) and iii).
388 | - ii) The model_id of a lmdeploy-quantized model hosted
389 | inside a model repo on huggingface.co, such as
390 | "InternLM/internlm-chat-20b-4bit",
391 | "lmdeploy/llama2-chat-70b-4bit", etc.
392 | - iii) The model_id of a model hosted inside a model repo
393 | on huggingface.co, such as "internlm/internlm-chat-7b",
394 | "Qwen/Qwen-7B-Chat ", "baichuan-inc/Baichuan2-7B-Chat"
395 | and so on.
396 | chat_template (str): needed when model is a pytorch model on
397 | huggingface.co, such as "internlm-chat-7b",
398 | "Qwen-7B-Chat ", "Baichuan2-7B-Chat" and so on,
399 | and when the model name of local path did not match the original model name in HF.
400 | tp (int): tensor parallel
401 | prompt (Union[str, List[str]]): input texts to be completed.
402 | do_preprocess (bool): whether pre-process the messages. Default to
403 | True, which means chat_template will be applied.
404 | skip_special_tokens (bool): Whether or not to remove special tokens
405 | in the decoding. Default to be True.
406 | do_sample (bool): Whether or not to use sampling, use greedy decoding otherwise.
407 | Default to be False, which means greedy decoding will be applied.
408 | """
409 | try:
410 | import lmdeploy
411 | from lmdeploy import version_info, GenerationConfig
412 | except Exception:
 413 |         raise ImportError("Please install lmdeploy before initializing the lmdeploy backend.")
414 | kwargs.pop("hashing_kv", None)
415 | kwargs.pop("response_format", None)
416 | max_new_tokens = kwargs.pop("max_tokens", 512)
417 | tp = kwargs.pop("tp", 1)
418 | skip_special_tokens = kwargs.pop("skip_special_tokens", True)
419 | do_preprocess = kwargs.pop("do_preprocess", True)
420 | do_sample = kwargs.pop("do_sample", False)
421 | gen_params = kwargs
422 |
423 | version = version_info
424 | if do_sample is not None and version < (0, 6, 0):
425 | raise RuntimeError(
426 | "`do_sample` parameter is not supported by lmdeploy until "
 427 |             f"v0.6.0, but currently using lmdeploy {lmdeploy.__version__}"
428 | )
429 | else:
430 | do_sample = True
431 | gen_params.update(do_sample=do_sample)
432 |
433 | lmdeploy_pipe = initialize_lmdeploy_pipeline(
434 | model=model,
435 | tp=tp,
436 | chat_template=chat_template,
437 | model_format=model_format,
438 | quant_policy=quant_policy,
439 | log_level="WARNING",
440 | )
441 |
442 | messages = []
443 | if system_prompt:
444 | messages.append({"role": "system", "content": system_prompt})
445 |
446 | messages.extend(history_messages)
447 | messages.append({"role": "user", "content": prompt})
448 |
449 | gen_config = GenerationConfig(
450 | skip_special_tokens=skip_special_tokens,
451 | max_new_tokens=max_new_tokens,
452 | **gen_params,
453 | )
454 |
455 | response = ""
456 | async for res in lmdeploy_pipe.generate(
457 | messages,
458 | gen_config=gen_config,
459 | do_preprocess=do_preprocess,
460 | stream_response=False,
461 | session_id=1,
462 | ):
463 | response += res.response
464 | return response
465 |
466 |
467 | class GPTKeywordExtractionFormat(BaseModel):
468 | high_level_keywords: List[str]
469 | low_level_keywords: List[str]
470 |
471 |
472 | async def openai_complete(
473 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
474 | ) -> Union[str, AsyncIterator[str]]:
475 | keyword_extraction = kwargs.pop("keyword_extraction", None)
476 | if keyword_extraction:
477 | kwargs["response_format"] = "json"
478 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
479 | return await openai_complete_if_cache(
480 | model_name,
481 | prompt,
482 | system_prompt=system_prompt,
483 | history_messages=history_messages,
484 | **kwargs,
485 | )
486 |
487 |
488 | async def gpt_4o_complete(
489 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
490 | ) -> str:
491 | keyword_extraction = kwargs.pop("keyword_extraction", None)
492 | if keyword_extraction:
493 | kwargs["response_format"] = GPTKeywordExtractionFormat
494 | return await openai_complete_if_cache(
495 | "gpt-4o",
496 | prompt,
497 | system_prompt=system_prompt,
498 | history_messages=history_messages,
499 | **kwargs,
500 | )
501 |
502 |
503 | async def gpt_4o_mini_complete(
504 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
505 | ) -> str:
506 | keyword_extraction = kwargs.pop("keyword_extraction", None)
507 | if keyword_extraction:
508 | kwargs["response_format"] = GPTKeywordExtractionFormat
509 | return await openai_complete_if_cache(
510 | "gpt-4o-mini",
511 | prompt,
512 | system_prompt=system_prompt,
513 | history_messages=history_messages,
514 | **kwargs,
515 | )
516 |
517 |
518 | async def nvidia_openai_complete(
519 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
520 | ) -> str:
521 | keyword_extraction = kwargs.pop("keyword_extraction", None)
522 | result = await openai_complete_if_cache(
523 | "nvidia/llama-3.1-nemotron-70b-instruct",
524 | prompt,
525 | system_prompt=system_prompt,
526 | history_messages=history_messages,
527 | base_url="https://integrate.api.nvidia.com/v1",
528 | **kwargs,
529 | )
530 | if keyword_extraction: # TODO: use JSON API
531 | return locate_json_string_body_from_string(result)
532 | return result
533 |
534 |
535 | async def azure_openai_complete(
536 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
537 | ) -> str:
538 | keyword_extraction = kwargs.pop("keyword_extraction", None)
539 | result = await azure_openai_complete_if_cache(
540 | "conversation-4o-mini",
541 | prompt,
542 | system_prompt=system_prompt,
543 | history_messages=history_messages,
544 | **kwargs,
545 | )
546 | if keyword_extraction: # TODO: use JSON API
547 | return locate_json_string_body_from_string(result)
548 | return result
549 |
550 |
551 | async def bedrock_complete(
552 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
553 | ) -> str:
554 | keyword_extraction = kwargs.pop("keyword_extraction", None)
555 | result = await bedrock_complete_if_cache(
556 | "anthropic.claude-3-haiku-20240307-v1:0",
557 | prompt,
558 | system_prompt=system_prompt,
559 | history_messages=history_messages,
560 | **kwargs,
561 | )
562 | if keyword_extraction: # TODO: use JSON API
563 | return locate_json_string_body_from_string(result)
564 | return result
565 |
566 |
567 | async def hf_model_complete(
568 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
569 | ) -> str:
570 | keyword_extraction = kwargs.pop("keyword_extraction", None)
571 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
572 | result = await hf_model_if_cache(
573 | model_name,
574 | prompt,
575 | system_prompt=system_prompt,
576 | history_messages=history_messages,
577 | **kwargs,
578 | )
579 | if keyword_extraction: # TODO: use JSON API
580 | return locate_json_string_body_from_string(result)
581 | return result
582 |
583 |
584 | async def ollama_model_complete(
585 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
586 | ) -> Union[str, AsyncIterator[str]]:
587 | keyword_extraction = kwargs.pop("keyword_extraction", None)
588 | if keyword_extraction:
589 | kwargs["format"] = "json"
590 | model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
591 | return await ollama_model_if_cache(
592 | model_name,
593 | prompt,
594 | system_prompt=system_prompt,
595 | history_messages=history_messages,
596 | **kwargs,
597 | )
598 |
599 |
600 | @retry(
601 | stop=stop_after_attempt(3),
602 | wait=wait_exponential(multiplier=1, min=4, max=10),
603 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
604 | )
605 | async def zhipu_complete_if_cache(
606 | prompt: Union[str, List[Dict[str, str]]],
607 | model: str = "glm-4-flashx",
608 | api_key: Optional[str] = None,
609 | system_prompt: Optional[str] = None,
610 | history_messages: List[Dict[str, str]] = [],
611 | **kwargs,
612 | ) -> str:
613 |
614 | try:
615 | from zhipuai import ZhipuAI
616 | except ImportError:
 617 |         raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
618 |
619 | if api_key:
620 | client = ZhipuAI(api_key=api_key)
621 | else:
622 | client = ZhipuAI()
623 |
624 | messages = []
625 |
626 | if not system_prompt:
 627 |         system_prompt = "You are a helpful assistant. Replace any sensitive words in the content with ***."
628 |
629 |
630 | if system_prompt:
631 | messages.append({"role": "system", "content": system_prompt})
632 | messages.extend(history_messages)
633 | messages.append({"role": "user", "content": prompt})
634 |
635 |
636 | logger.debug("===== Query Input to LLM =====")
637 | logger.debug(f"Query: {prompt}")
638 | logger.debug(f"System prompt: {system_prompt}")
639 |
640 |
641 | kwargs = {
642 | k: v for k, v in kwargs.items() if k not in ["hashing_kv", "keyword_extraction"]
643 | }
644 |
645 | response = client.chat.completions.create(model=model, messages=messages, **kwargs)
646 |
647 | return response.choices[0].message.content
648 |
649 |
650 | async def zhipu_complete(
651 | prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
652 | ):
653 |
654 | keyword_extraction = kwargs.pop("keyword_extraction", None)
655 |
656 | if keyword_extraction:
657 | extraction_prompt = """You are a helpful assistant that extracts keywords from text.
658 | Please analyze the content and extract two types of keywords:
659 | 1. High-level keywords: Important concepts and main themes
660 | 2. Low-level keywords: Specific details and supporting elements
661 |
662 | Return your response in this exact JSON format:
663 | {
664 | "high_level_keywords": ["keyword1", "keyword2"],
665 | "low_level_keywords": ["keyword1", "keyword2", "keyword3"]
666 | }
667 |
668 | Only return the JSON, no other text."""
669 |
670 |
671 | if system_prompt:
672 | system_prompt = f"{system_prompt}\n\n{extraction_prompt}"
673 | else:
674 | system_prompt = extraction_prompt
675 |
676 | try:
677 | response = await zhipu_complete_if_cache(
678 | prompt=prompt,
679 | system_prompt=system_prompt,
680 | history_messages=history_messages,
681 | **kwargs,
682 | )
683 |
684 |
685 | try:
686 | data = json.loads(response)
687 | return GPTKeywordExtractionFormat(
688 | high_level_keywords=data.get("high_level_keywords", []),
689 | low_level_keywords=data.get("low_level_keywords", []),
690 | )
691 | except json.JSONDecodeError:
692 |
693 | match = re.search(r"\{[\s\S]*\}", response)
694 | if match:
695 | try:
696 | data = json.loads(match.group())
697 | return GPTKeywordExtractionFormat(
698 | high_level_keywords=data.get("high_level_keywords", []),
699 | low_level_keywords=data.get("low_level_keywords", []),
700 | )
701 | except json.JSONDecodeError:
702 | pass
703 |
704 |
705 | logger.warning(
706 | f"Failed to parse keyword extraction response: {response}"
707 | )
708 | return GPTKeywordExtractionFormat(
709 | high_level_keywords=[], low_level_keywords=[]
710 | )
711 | except Exception as e:
712 | logger.error(f"Error during keyword extraction: {str(e)}")
713 | return GPTKeywordExtractionFormat(
714 | high_level_keywords=[], low_level_keywords=[]
715 | )
716 | else:
717 | return await zhipu_complete_if_cache(
718 | prompt=prompt,
719 | system_prompt=system_prompt,
720 | history_messages=history_messages,
721 | **kwargs,
722 | )
723 |
724 |
725 | @wrap_embedding_func_with_attrs(embedding_dim=1024, max_token_size=8192)
726 | @retry(
727 | stop=stop_after_attempt(3),
728 | wait=wait_exponential(multiplier=1, min=4, max=60),
729 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
730 | )
731 | async def zhipu_embedding(
732 | texts: list[str], model: str = "embedding-3", api_key: str = None, **kwargs
733 | ) -> np.ndarray:
734 |
735 | try:
736 | from zhipuai import ZhipuAI
737 | except ImportError:
 738 |         raise ImportError("Please install zhipuai before initializing the zhipuai backend.")
739 | if api_key:
740 | client = ZhipuAI(api_key=api_key)
741 | else:
742 | client = ZhipuAI()
743 |
744 | if isinstance(texts, str):
745 | texts = [texts]
746 |
747 | embeddings = []
748 | for text in texts:
749 | try:
750 | response = client.embeddings.create(model=model, input=[text], **kwargs)
751 | embeddings.append(response.data[0].embedding)
752 | except Exception as e:
753 | raise Exception(f"Error calling ChatGLM Embedding API: {str(e)}")
754 |
755 | return np.array(embeddings)
756 |
757 |
758 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8192)
759 | @retry(
760 | stop=stop_after_attempt(3),
761 | wait=wait_exponential(multiplier=1, min=4, max=60),
762 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
763 | )
764 | async def openai_embedding(
765 | texts: list[str],
766 | model: str = "text-embedding-3-small",
767 | base_url="https://api.openai.com/v1",
768 | api_key="",
769 | ) -> np.ndarray:
770 | if api_key:
771 | os.environ["OPENAI_API_KEY"] = api_key
772 |
773 | openai_async_client = (
774 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
775 | )
776 | response = await openai_async_client.embeddings.create(
777 | model=model, input=texts, encoding_format="float"
778 | )
779 | return np.array([dp.embedding for dp in response.data])
780 |
781 |
782 | async def fetch_data(url, headers, data):
783 | async with aiohttp.ClientSession() as session:
784 | async with session.post(url, headers=headers, json=data) as response:
785 | response_json = await response.json()
786 | data_list = response_json.get("data", [])
787 | return data_list
788 |
789 |
790 | async def jina_embedding(
791 | texts: list[str],
792 | dimensions: int = 1024,
793 | late_chunking: bool = False,
794 | base_url: str = None,
795 | api_key: str = None,
796 | ) -> np.ndarray:
797 | if api_key:
798 | os.environ["JINA_API_KEY"] = api_key
799 | url = "https://api.jina.ai/v1/embeddings" if not base_url else base_url
800 | headers = {
801 | "Content-Type": "application/json",
802 | "Authorization": f"Bearer {os.environ['JINA_API_KEY']}",
803 | }
804 | data = {
805 | "model": "jina-embeddings-v3",
806 | "normalized": True,
807 | "embedding_type": "float",
808 | "dimensions": f"{dimensions}",
809 | "late_chunking": late_chunking,
810 | "input": texts,
811 | }
812 | data_list = await fetch_data(url, headers, data)
813 | return np.array([dp["embedding"] for dp in data_list])
814 |
815 |
816 | @wrap_embedding_func_with_attrs(embedding_dim=2048, max_token_size=512)
817 | @retry(
818 | stop=stop_after_attempt(3),
819 | wait=wait_exponential(multiplier=1, min=4, max=60),
820 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
821 | )
822 | async def nvidia_openai_embedding(
823 | texts: list[str],
824 | model: str = "nvidia/llama-3.2-nv-embedqa-1b-v1",
825 | base_url: str = "https://integrate.api.nvidia.com/v1",
826 | api_key: str = None,
827 | input_type: str = "passage",
828 | trunc: str = "NONE",
829 | encode: str = "float",
830 | ) -> np.ndarray:
831 | if api_key:
832 | os.environ["OPENAI_API_KEY"] = api_key
833 |
834 | openai_async_client = (
835 | AsyncOpenAI() if base_url is None else AsyncOpenAI(base_url=base_url)
836 | )
837 | response = await openai_async_client.embeddings.create(
838 | model=model,
839 | input=texts,
840 | encoding_format=encode,
841 | extra_body={"input_type": input_type, "truncate": trunc},
842 | )
843 | return np.array([dp.embedding for dp in response.data])
844 |
845 |
846 | @wrap_embedding_func_with_attrs(embedding_dim=1536, max_token_size=8191)
847 | @retry(
848 | stop=stop_after_attempt(3),
849 | wait=wait_exponential(multiplier=1, min=4, max=10),
850 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
851 | )
852 | async def azure_openai_embedding(
853 | texts: list[str],
854 | model: str = "text-embedding-3-small",
855 | base_url: str = None,
856 | api_key: str = None,
857 | api_version: str = None,
858 | ) -> np.ndarray:
859 | if api_key:
860 | os.environ["AZURE_OPENAI_API_KEY"] = api_key
861 | if base_url:
862 | os.environ["AZURE_OPENAI_ENDPOINT"] = base_url
863 | if api_version:
864 | os.environ["AZURE_OPENAI_API_VERSION"] = api_version
865 |
866 | openai_async_client = AsyncAzureOpenAI(
867 | azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
868 | api_key=os.getenv("AZURE_OPENAI_API_KEY"),
869 | api_version=os.getenv("AZURE_OPENAI_API_VERSION"),
870 | )
871 |
872 | response = await openai_async_client.embeddings.create(
873 | model=model, input=texts, encoding_format="float"
874 | )
875 | return np.array([dp.embedding for dp in response.data])
876 |
877 |
878 | @retry(
879 | stop=stop_after_attempt(3),
880 | wait=wait_exponential(multiplier=1, min=4, max=60),
881 | retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
882 | )
883 | async def siliconcloud_embedding(
884 | texts: list[str],
885 | model: str = "netease-youdao/bce-embedding-base_v1",
886 | base_url: str = "https://api.siliconflow.cn/v1/embeddings",
887 | max_token_size: int = 512,
888 | api_key: str = None,
889 | ) -> np.ndarray:
890 | if api_key and not api_key.startswith("Bearer "):
891 | api_key = "Bearer " + api_key
892 |
893 | headers = {"Authorization": api_key, "Content-Type": "application/json"}
894 |
895 | truncate_texts = [text[0:max_token_size] for text in texts]
896 |
897 | payload = {"model": model, "input": truncate_texts, "encoding_format": "base64"}
898 |
899 | base64_strings = []
900 | async with aiohttp.ClientSession() as session:
901 | async with session.post(base_url, headers=headers, json=payload) as response:
902 | content = await response.json()
903 | if "code" in content:
904 | raise ValueError(content)
905 | base64_strings = [item["embedding"] for item in content["data"]]
906 |
907 | embeddings = []
908 | for string in base64_strings:
909 | decode_bytes = base64.b64decode(string)
910 | n = len(decode_bytes) // 4
911 | float_array = struct.unpack("<" + "f" * n, decode_bytes)
912 | embeddings.append(float_array)
913 | return np.array(embeddings)
914 |
915 |
916 |
917 | async def bedrock_embedding(
918 | texts: list[str],
919 | model: str = "amazon.titan-embed-text-v2:0",
920 | aws_access_key_id=None,
921 | aws_secret_access_key=None,
922 | aws_session_token=None,
923 | ) -> np.ndarray:
924 | os.environ["AWS_ACCESS_KEY_ID"] = os.environ.get(
925 | "AWS_ACCESS_KEY_ID", aws_access_key_id
926 | )
927 | os.environ["AWS_SECRET_ACCESS_KEY"] = os.environ.get(
928 | "AWS_SECRET_ACCESS_KEY", aws_secret_access_key
929 | )
930 | os.environ["AWS_SESSION_TOKEN"] = os.environ.get(
931 | "AWS_SESSION_TOKEN", aws_session_token
932 | )
933 |
934 | session = aioboto3.Session()
935 | async with session.client("bedrock-runtime") as bedrock_async_client:
936 | if (model_provider := model.split(".")[0]) == "amazon":
937 | embed_texts = []
938 | for text in texts:
939 | if "v2" in model:
940 | body = json.dumps(
941 | {
942 | "inputText": text,
943 |
944 | "embeddingTypes": ["float"],
945 | }
946 | )
947 | elif "v1" in model:
948 | body = json.dumps({"inputText": text})
949 | else:
950 | raise ValueError(f"Model {model} is not supported!")
951 |
952 | response = await bedrock_async_client.invoke_model(
953 | modelId=model,
954 | body=body,
955 | accept="application/json",
956 | contentType="application/json",
957 | )
958 |
959 | response_body = await response.get("body").json()
960 |
961 | embed_texts.append(response_body["embedding"])
962 | elif model_provider == "cohere":
963 | body = json.dumps(
964 | {"texts": texts, "input_type": "search_document", "truncate": "NONE"}
965 | )
966 |
967 | response = await bedrock_async_client.invoke_model(
968 | model=model,
969 | body=body,
970 | accept="application/json",
971 | contentType="application/json",
972 | )
973 |
974 | response_body = json.loads(response.get("body").read())
975 |
976 | embed_texts = response_body["embeddings"]
977 | else:
978 | raise ValueError(f"Model provider '{model_provider}' is not supported!")
979 |
980 | return np.array(embed_texts)
981 |
982 |
983 | async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
984 | device = next(embed_model.parameters()).device
985 | input_ids = tokenizer(
986 | texts, return_tensors="pt", padding=True, truncation=True
987 | ).input_ids.to(device)
988 | with torch.no_grad():
989 | outputs = embed_model(input_ids)
990 | embeddings = outputs.last_hidden_state.mean(dim=1)
991 | if embeddings.dtype == torch.bfloat16:
992 | return embeddings.detach().to(torch.float32).cpu().numpy()
993 | else:
994 | return embeddings.detach().cpu().numpy()
995 |
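A sketch of how `hf_embedding` could be wired to a Hugging Face encoder; the model name below is only an example, and any encoder that exposes `last_hidden_state` should work the same way.

```python
# Sketch: pairing hf_embedding with a Hugging Face tokenizer/model (example model name).
import asyncio
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
embed_model = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

async def demo_hf():
    vecs = await hf_embedding(["graph-based retrieval augmented generation"], tokenizer, embed_model)
    print(vecs.shape)  # (1, hidden_size): mean-pooled token embeddings

asyncio.run(demo_hf())
```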
996 |
997 | async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
998 | """
999 |     Deprecated in favor of `ollama_embed`.
1000 | """
1001 | embed_text = []
1002 | ollama_client = ollama.Client(**kwargs)
1003 | for text in texts:
1004 | data = ollama_client.embeddings(model=embed_model, prompt=text)
1005 | embed_text.append(data["embedding"])
1006 |
1007 |     return np.array(embed_text)
1008 |
1009 |
1010 | async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
1011 | ollama_client = ollama.Client(**kwargs)
1012 | data = ollama_client.embed(model=embed_model, input=texts)
1013 |     return np.array(data["embeddings"])
1014 |
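A short sketch of the newer `ollama_embed` path; the model name and host are assumptions, and the model must already be pulled into the local Ollama server.

```python
# Sketch: batch embedding through a local Ollama server (assumed defaults).
import asyncio

async def demo_ollama():
    embeddings = await ollama_embed(
        ["what is PathRAG?", "how are relational paths pruned?"],
        embed_model="nomic-embed-text",   # example model, must be pulled beforehand
        host="http://localhost:11434",    # forwarded to ollama.Client via **kwargs
    )
    print(len(embeddings), len(embeddings[0]))

asyncio.run(demo_ollama())
```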
1015 |
1016 | class Model(BaseModel):
1017 | """
1018 | This is a Pydantic model class named 'Model' that is used to define a custom language model.
1019 |
1020 | Attributes:
1021 | gen_func (Callable[[Any], str]): A callable function that generates the response from the language model.
1022 | The function should take any argument and return a string.
1023 | kwargs (Dict[str, Any]): A dictionary that contains the arguments to pass to the callable function.
1024 | This could include parameters such as the model name, API key, etc.
1025 |
1026 | Example usage:
1027 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]})
1028 |
1029 | In this example, 'openai_complete_if_cache' is the callable function that generates the response from the OpenAI model.
1030 | The 'kwargs' dictionary contains the model name and API key to be passed to the function.
1031 | """
1032 |
1033 | gen_func: Callable[[Any], str] = Field(
1034 | ...,
1035 | description="A function that generates the response from the llm. The response must be a string",
1036 | )
1037 | kwargs: Dict[str, Any] = Field(
1038 | ...,
1039 |         description="The arguments to pass to the callable function, e.g. the API key, model name, etc.",
1040 | )
1041 |
1042 | class Config:
1043 | arbitrary_types_allowed = True
1044 |
1045 |
1046 | class MultiModel:
1047 | """
1048 |     Distributes the load across multiple language models. Useful for working around low rate limits with certain API providers, especially on free tiers.
1049 |     It can also be used to split requests across different models or providers.
1050 |
1051 | Attributes:
1052 | models (List[Model]): A list of language models to be used.
1053 |
1054 | Usage example:
1055 | ```python
1056 | models = [
1057 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_1"]}),
1058 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_2"]}),
1059 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_3"]}),
1060 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_4"]}),
1061 | Model(gen_func=openai_complete_if_cache, kwargs={"model": "gpt-4", "api_key": os.environ["OPENAI_API_KEY_5"]}),
1062 | ]
1063 | multi_model = MultiModel(models)
1064 | rag = LightRAG(
1065 | llm_model_func=multi_model.llm_model_func
1066 |         # ...other args
1067 | )
1068 | ```
1069 | """
1070 |
1071 | def __init__(self, models: List[Model]):
1072 | self._models = models
1073 | self._current_model = 0
1074 |
1075 | def _next_model(self):
1076 | self._current_model = (self._current_model + 1) % len(self._models)
1077 | return self._models[self._current_model]
1078 |
1079 | async def llm_model_func(
1080 | self, prompt, system_prompt=None, history_messages=[], **kwargs
1081 | ) -> str:
1082 | kwargs.pop("model", None)
1083 | kwargs.pop("keyword_extraction", None)
1084 | kwargs.pop("mode", None)
1085 | next_model = self._next_model()
1086 | args = dict(
1087 | prompt=prompt,
1088 | system_prompt=system_prompt,
1089 | history_messages=history_messages,
1090 | **kwargs,
1091 | **next_model.kwargs,
1092 | )
1093 |
1094 | return await next_model.gen_func(**args)
1095 |
1096 |
1097 | if __name__ == "__main__":
1098 | import asyncio
1099 |
1100 | async def main():
1101 | result = await gpt_4o_mini_complete("How are you?")
1102 | print(result)
1103 |
1104 | asyncio.run(main())
1105 |
--------------------------------------------------------------------------------
/PathRAG/operate.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import json
3 | import re
4 | from tqdm.asyncio import tqdm as tqdm_async
5 | from typing import Union
6 | from collections import Counter, defaultdict
7 | import warnings
8 | import tiktoken
9 | import time
10 | import csv
11 | from .utils import (
12 | logger,
13 | clean_str,
14 | compute_mdhash_id,
15 | decode_tokens_by_tiktoken,
16 | encode_string_by_tiktoken,
17 | is_float_regex,
18 | list_of_list_to_csv,
19 | pack_user_ass_to_openai_messages,
20 | split_string_by_multi_markers,
21 | truncate_list_by_token_size,
22 | process_combine_contexts,
23 | compute_args_hash,
24 | handle_cache,
25 | save_to_cache,
26 | CacheData,
27 | )
28 | from .base import (
29 | BaseGraphStorage,
30 | BaseKVStorage,
31 | BaseVectorStorage,
32 | TextChunkSchema,
33 | QueryParam,
34 | )
35 | from .prompt import GRAPH_FIELD_SEP, PROMPTS
36 |
37 |
38 | def chunking_by_token_size(
39 | content: str, overlap_token_size=128, max_token_size=1024, tiktoken_model="gpt-4o"
40 | ):
41 | tokens = encode_string_by_tiktoken(content, model_name=tiktoken_model)
42 | results = []
43 | for index, start in enumerate(
44 | range(0, len(tokens), max_token_size - overlap_token_size)
45 | ):
46 | chunk_content = decode_tokens_by_tiktoken(
47 | tokens[start : start + max_token_size], model_name=tiktoken_model
48 | )
49 | results.append(
50 | {
51 | "tokens": min(max_token_size, len(tokens) - start),
52 | "content": chunk_content.strip(),
53 | "chunk_order_index": index,
54 | }
55 | )
56 | return results
57 |
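For orientation, a small sketch of how the chunker above strides through a document: with the defaults, consecutive chunks start `max_token_size - overlap_token_size = 896` tokens apart and share 128 tokens of overlap. The input file name is a placeholder.

```python
# Sketch: inspecting the sliding-window chunking (stride = max_token_size - overlap_token_size).
with open("corpus.txt", "r", encoding="utf-8") as f:   # placeholder input file
    text = f.read()

chunks = chunking_by_token_size(text, overlap_token_size=128, max_token_size=1024)
for chunk in chunks[:3]:
    print(chunk["chunk_order_index"], chunk["tokens"], chunk["content"][:60])
```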
58 |
59 | async def _handle_entity_relation_summary(
60 | entity_or_relation_name: str,
61 | description: str,
62 | global_config: dict,
63 | ) -> str:
64 | use_llm_func: callable = global_config["llm_model_func"]
65 | llm_max_tokens = global_config["llm_model_max_token_size"]
66 | tiktoken_model_name = global_config["tiktoken_model_name"]
67 | summary_max_tokens = global_config["entity_summary_to_max_tokens"]
68 | language = global_config["addon_params"].get(
69 | "language", PROMPTS["DEFAULT_LANGUAGE"]
70 | )
71 |
72 | tokens = encode_string_by_tiktoken(description, model_name=tiktoken_model_name)
73 | if len(tokens) < summary_max_tokens:
74 | return description
75 | prompt_template = PROMPTS["summarize_entity_descriptions"]
76 | use_description = decode_tokens_by_tiktoken(
77 | tokens[:llm_max_tokens], model_name=tiktoken_model_name
78 | )
79 | context_base = dict(
80 | entity_name=entity_or_relation_name,
81 | description_list=use_description.split(GRAPH_FIELD_SEP),
82 | language=language,
83 | )
84 | use_prompt = prompt_template.format(**context_base)
85 | logger.debug(f"Trigger summary: {entity_or_relation_name}")
86 | summary = await use_llm_func(use_prompt, max_tokens=summary_max_tokens)
87 | return summary
88 |
89 |
90 | async def _handle_single_entity_extraction(
91 | record_attributes: list[str],
92 | chunk_key: str,
93 | ):
94 | if len(record_attributes) < 4 or record_attributes[0] != '"entity"':
95 | return None
96 |
97 | entity_name = clean_str(record_attributes[1].upper())
98 | if not entity_name.strip():
99 | return None
100 | entity_type = clean_str(record_attributes[2].upper())
101 | entity_description = clean_str(record_attributes[3])
102 | entity_source_id = chunk_key
103 | return dict(
104 | entity_name=entity_name,
105 | entity_type=entity_type,
106 | description=entity_description,
107 | source_id=entity_source_id,
108 | )
109 |
110 |
111 | async def _handle_single_relationship_extraction(
112 | record_attributes: list[str],
113 | chunk_key: str,
114 | ):
115 | if len(record_attributes) < 5 or record_attributes[0] != '"relationship"':
116 | return None
117 |
118 | source = clean_str(record_attributes[1].upper())
119 | target = clean_str(record_attributes[2].upper())
120 | edge_description = clean_str(record_attributes[3])
121 |
122 | edge_keywords = clean_str(record_attributes[4])
123 | edge_source_id = chunk_key
124 | weight = (
125 | float(record_attributes[-1]) if is_float_regex(record_attributes[-1]) else 1.0
126 | )
127 | return dict(
128 | src_id=source,
129 | tgt_id=target,
130 | weight=weight,
131 | description=edge_description,
132 | keywords=edge_keywords,
133 | source_id=edge_source_id,
134 | )
135 |
136 |
137 | async def _merge_nodes_then_upsert(
138 | entity_name: str,
139 | nodes_data: list[dict],
140 | knowledge_graph_inst: BaseGraphStorage,
141 | global_config: dict,
142 | ):
143 | already_entity_types = []
144 | already_source_ids = []
145 | already_description = []
146 |
147 | already_node = await knowledge_graph_inst.get_node(entity_name)
148 | if already_node is not None:
149 | already_entity_types.append(already_node["entity_type"])
150 | already_source_ids.extend(
151 | split_string_by_multi_markers(already_node["source_id"], [GRAPH_FIELD_SEP])
152 | )
153 | already_description.append(already_node["description"])
154 |
155 | entity_type = sorted(
156 | Counter(
157 | [dp["entity_type"] for dp in nodes_data] + already_entity_types
158 | ).items(),
159 | key=lambda x: x[1],
160 | reverse=True,
161 | )[0][0]
162 | description = GRAPH_FIELD_SEP.join(
163 | sorted(set([dp["description"] for dp in nodes_data] + already_description))
164 | )
165 | source_id = GRAPH_FIELD_SEP.join(
166 | set([dp["source_id"] for dp in nodes_data] + already_source_ids)
167 | )
168 | description = await _handle_entity_relation_summary(
169 | entity_name, description, global_config
170 | )
171 | node_data = dict(
172 | entity_type=entity_type,
173 | description=description,
174 | source_id=source_id,
175 | )
176 | await knowledge_graph_inst.upsert_node(
177 | entity_name,
178 | node_data=node_data,
179 | )
180 | node_data["entity_name"] = entity_name
181 | return node_data
182 |
183 |
184 | async def _merge_edges_then_upsert(
185 | src_id: str,
186 | tgt_id: str,
187 | edges_data: list[dict],
188 | knowledge_graph_inst: BaseGraphStorage,
189 | global_config: dict,
190 | ):
191 | already_weights = []
192 | already_source_ids = []
193 | already_description = []
194 | already_keywords = []
195 |
196 | if await knowledge_graph_inst.has_edge(src_id, tgt_id):
197 | already_edge = await knowledge_graph_inst.get_edge(src_id, tgt_id)
198 | already_weights.append(already_edge["weight"])
199 | already_source_ids.extend(
200 | split_string_by_multi_markers(already_edge["source_id"], [GRAPH_FIELD_SEP])
201 | )
202 | already_description.append(already_edge["description"])
203 | already_keywords.extend(
204 | split_string_by_multi_markers(already_edge["keywords"], [GRAPH_FIELD_SEP])
205 | )
206 |
207 | weight = sum([dp["weight"] for dp in edges_data] + already_weights)
208 | description = GRAPH_FIELD_SEP.join(
209 | sorted(set([dp["description"] for dp in edges_data] + already_description))
210 | )
211 | keywords = GRAPH_FIELD_SEP.join(
212 | sorted(set([dp["keywords"] for dp in edges_data] + already_keywords))
213 | )
214 | source_id = GRAPH_FIELD_SEP.join(
215 | set([dp["source_id"] for dp in edges_data] + already_source_ids)
216 | )
217 | for need_insert_id in [src_id, tgt_id]:
218 | if not (await knowledge_graph_inst.has_node(need_insert_id)):
219 | await knowledge_graph_inst.upsert_node(
220 | need_insert_id,
221 | node_data={
222 | "source_id": source_id,
223 | "description": description,
224 | "entity_type": '"UNKNOWN"',
225 | },
226 | )
227 | description = await _handle_entity_relation_summary(
228 | f"({src_id}, {tgt_id})", description, global_config
229 | )
230 | await knowledge_graph_inst.upsert_edge(
231 | src_id,
232 | tgt_id,
233 | edge_data=dict(
234 | weight=weight,
235 | description=description,
236 | keywords=keywords,
237 | source_id=source_id,
238 | ),
239 | )
240 |
241 | edge_data = dict(
242 | src_id=src_id,
243 | tgt_id=tgt_id,
244 | description=description,
245 | keywords=keywords,
246 | )
247 |
248 | return edge_data
249 |
250 |
251 | async def extract_entities(
252 | chunks: dict[str, TextChunkSchema],
253 | knowledge_graph_inst: BaseGraphStorage,
254 | entity_vdb: BaseVectorStorage,
255 | relationships_vdb: BaseVectorStorage,
256 | global_config: dict,
257 | ) -> Union[BaseGraphStorage, None]:
258 |     await asyncio.sleep(20)  # brief pause without blocking the event loop
259 | use_llm_func: callable = global_config["llm_model_func"]
260 | entity_extract_max_gleaning = global_config["entity_extract_max_gleaning"]
261 |
262 | ordered_chunks = list(chunks.items())
263 |
264 | language = global_config["addon_params"].get(
265 | "language", PROMPTS["DEFAULT_LANGUAGE"]
266 | )
267 | entity_types = global_config["addon_params"].get(
268 | "entity_types", PROMPTS["DEFAULT_ENTITY_TYPES"]
269 | )
270 | example_number = global_config["addon_params"].get("example_number", None)
271 | if example_number and example_number < len(PROMPTS["entity_extraction_examples"]):
272 | examples = "\n".join(
273 | PROMPTS["entity_extraction_examples"][: int(example_number)]
274 | )
275 | else:
276 | examples = "\n".join(PROMPTS["entity_extraction_examples"])
277 |
278 | example_context_base = dict(
279 | tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
280 | record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
281 | completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
282 | entity_types=",".join(entity_types),
283 | language=language,
284 | )
285 |
286 | examples = examples.format(**example_context_base)
287 |
288 | entity_extract_prompt = PROMPTS["entity_extraction"]
289 | context_base = dict(
290 | tuple_delimiter=PROMPTS["DEFAULT_TUPLE_DELIMITER"],
291 | record_delimiter=PROMPTS["DEFAULT_RECORD_DELIMITER"],
292 | completion_delimiter=PROMPTS["DEFAULT_COMPLETION_DELIMITER"],
293 | entity_types=",".join(entity_types),
294 | examples=examples,
295 | language=language,
296 | )
297 |
298 | continue_prompt = PROMPTS["entiti_continue_extraction"]
299 | if_loop_prompt = PROMPTS["entiti_if_loop_extraction"]
300 |
301 | already_processed = 0
302 | already_entities = 0
303 | already_relations = 0
304 |
305 | async def _process_single_content(chunk_key_dp: tuple[str, TextChunkSchema]):
306 | nonlocal already_processed, already_entities, already_relations
307 | chunk_key = chunk_key_dp[0]
308 | chunk_dp = chunk_key_dp[1]
309 | content = chunk_dp["content"]
310 | hint_prompt = entity_extract_prompt.format(
311 | **context_base, input_text="{input_text}"
312 | ).format(**context_base, input_text=content)
313 |
314 | final_result = await use_llm_func(hint_prompt)
315 | history = pack_user_ass_to_openai_messages(hint_prompt, final_result)
316 | for now_glean_index in range(entity_extract_max_gleaning):
317 | glean_result = await use_llm_func(continue_prompt, history_messages=history)
318 |
319 | history += pack_user_ass_to_openai_messages(continue_prompt, glean_result)
320 | final_result += glean_result
321 | if now_glean_index == entity_extract_max_gleaning - 1:
322 | break
323 |
324 | if_loop_result: str = await use_llm_func(
325 | if_loop_prompt, history_messages=history
326 | )
327 | if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
328 | if if_loop_result != "yes":
329 | break
330 |
331 | records = split_string_by_multi_markers(
332 | final_result,
333 | [context_base["record_delimiter"], context_base["completion_delimiter"]],
334 | )
335 |
336 | maybe_nodes = defaultdict(list)
337 | maybe_edges = defaultdict(list)
338 | for record in records:
339 | record = re.search(r"\((.*)\)", record)
340 | if record is None:
341 | continue
342 | record = record.group(1)
343 | record_attributes = split_string_by_multi_markers(
344 | record, [context_base["tuple_delimiter"]]
345 | )
346 | if_entities = await _handle_single_entity_extraction(
347 | record_attributes, chunk_key
348 | )
349 | if if_entities is not None:
350 | maybe_nodes[if_entities["entity_name"]].append(if_entities)
351 | continue
352 |
353 | if_relation = await _handle_single_relationship_extraction(
354 | record_attributes, chunk_key
355 | )
356 | if if_relation is not None:
357 | maybe_edges[(if_relation["src_id"], if_relation["tgt_id"])].append(
358 | if_relation
359 | )
360 | already_processed += 1
361 | already_entities += len(maybe_nodes)
362 | already_relations += len(maybe_edges)
363 | now_ticks = PROMPTS["process_tickers"][
364 | already_processed % len(PROMPTS["process_tickers"])
365 | ]
366 | print(
367 | f"{now_ticks} Processed {already_processed} chunks, {already_entities} entities(duplicated), {already_relations} relations(duplicated)\r",
368 | end="",
369 | flush=True,
370 | )
371 | return dict(maybe_nodes), dict(maybe_edges)
372 |
373 | results = []
374 | for result in tqdm_async(
375 | asyncio.as_completed([_process_single_content(c) for c in ordered_chunks]),
376 | total=len(ordered_chunks),
377 | desc="Extracting entities from chunks",
378 | unit="chunk",
379 | ):
380 | results.append(await result)
381 |
382 | maybe_nodes = defaultdict(list)
383 | maybe_edges = defaultdict(list)
384 | for m_nodes, m_edges in results:
385 | for k, v in m_nodes.items():
386 | maybe_nodes[k].extend(v)
387 | for k, v in m_edges.items():
388 | maybe_edges[k].extend(v)
389 | logger.info("Inserting entities into storage...")
390 | all_entities_data = []
391 | for result in tqdm_async(
392 | asyncio.as_completed(
393 | [
394 | _merge_nodes_then_upsert(k, v, knowledge_graph_inst, global_config)
395 | for k, v in maybe_nodes.items()
396 | ]
397 | ),
398 | total=len(maybe_nodes),
399 | desc="Inserting entities",
400 | unit="entity",
401 | ):
402 | all_entities_data.append(await result)
403 |
404 | logger.info("Inserting relationships into storage...")
405 | all_relationships_data = []
406 | for result in tqdm_async(
407 | asyncio.as_completed(
408 | [
409 | _merge_edges_then_upsert(
410 | k[0], k[1], v, knowledge_graph_inst, global_config
411 | )
412 | for k, v in maybe_edges.items()
413 | ]
414 | ),
415 | total=len(maybe_edges),
416 | desc="Inserting relationships",
417 | unit="relationship",
418 | ):
419 | all_relationships_data.append(await result)
420 |
421 | if not len(all_entities_data) and not len(all_relationships_data):
422 | logger.warning(
423 |             "Didn't extract any entities or relationships; maybe your LLM is not working"
424 | )
425 | return None
426 |
427 | if not len(all_entities_data):
428 | logger.warning("Didn't extract any entities")
429 | if not len(all_relationships_data):
430 | logger.warning("Didn't extract any relationships")
431 |
432 | if entity_vdb is not None:
433 | data_for_vdb = {
434 | compute_mdhash_id(dp["entity_name"], prefix="ent-"): {
435 | "content": dp["entity_name"] + dp["description"],
436 | "entity_name": dp["entity_name"],
437 | }
438 | for dp in all_entities_data
439 | }
440 | await entity_vdb.upsert(data_for_vdb)
441 |
442 | if relationships_vdb is not None:
443 | data_for_vdb = {
444 | compute_mdhash_id(dp["src_id"] + dp["tgt_id"], prefix="rel-"): {
445 | "src_id": dp["src_id"],
446 | "tgt_id": dp["tgt_id"],
447 | "content": dp["keywords"]
448 | + dp["src_id"]
449 | + dp["tgt_id"]
450 | + dp["description"],
451 | }
452 | for dp in all_relationships_data
453 | }
454 | await relationships_vdb.upsert(data_for_vdb)
455 |
456 | return knowledge_graph_inst
457 |
458 |
459 |
460 | async def kg_query(
461 | query,
462 | knowledge_graph_inst: BaseGraphStorage,
463 | entities_vdb: BaseVectorStorage,
464 | relationships_vdb: BaseVectorStorage,
465 | text_chunks_db: BaseKVStorage[TextChunkSchema],
466 | query_param: QueryParam,
467 | global_config: dict,
468 | hashing_kv: BaseKVStorage = None,
469 | ) -> str:
470 |
471 | use_model_func = global_config["llm_model_func"]
472 | args_hash = compute_args_hash(query_param.mode, query)
473 | cached_response, quantized, min_val, max_val = await handle_cache(
474 | hashing_kv, args_hash, query, query_param.mode
475 | )
476 | if cached_response is not None:
477 | return cached_response
478 |
479 | example_number = global_config["addon_params"].get("example_number", None)
480 | if example_number and example_number < len(PROMPTS["keywords_extraction_examples"]):
481 | examples = "\n".join(
482 | PROMPTS["keywords_extraction_examples"][: int(example_number)]
483 | )
484 | else:
485 | examples = "\n".join(PROMPTS["keywords_extraction_examples"])
486 | language = global_config["addon_params"].get(
487 | "language", PROMPTS["DEFAULT_LANGUAGE"]
488 | )
489 |
490 | if query_param.mode not in ["hybrid"]:
491 | logger.error(f"Unknown mode {query_param.mode} in kg_query")
492 | return PROMPTS["fail_response"]
493 |
494 |
495 | kw_prompt_temp = PROMPTS["keywords_extraction"]
496 | kw_prompt = kw_prompt_temp.format(query=query, examples=examples, language=language)
497 | result = await use_model_func(kw_prompt, keyword_extraction=True)
498 | logger.info("kw_prompt result:")
499 | print(result)
500 | try:
501 |
502 | match = re.search(r"\{.*\}", result, re.DOTALL)
503 | if match:
504 | result = match.group(0)
505 | keywords_data = json.loads(result)
506 |
507 | hl_keywords = keywords_data.get("high_level_keywords", [])
508 | ll_keywords = keywords_data.get("low_level_keywords", [])
509 | else:
510 | logger.error("No JSON-like structure found in the result.")
511 | return PROMPTS["fail_response"]
512 |
513 |
514 | except json.JSONDecodeError as e:
515 | print(f"JSON parsing error: {e} {result}")
516 | return PROMPTS["fail_response"]
517 |
518 |
519 | if hl_keywords == [] and ll_keywords == []:
520 |         logger.warning("low_level_keywords and high_level_keywords are empty")
521 | return PROMPTS["fail_response"]
522 | if ll_keywords == [] and query_param.mode in ["hybrid"]:
523 | logger.warning("low_level_keywords is empty")
524 | return PROMPTS["fail_response"]
525 | else:
526 | ll_keywords = ", ".join(ll_keywords)
527 | if hl_keywords == [] and query_param.mode in ["hybrid"]:
528 | logger.warning("high_level_keywords is empty")
529 | return PROMPTS["fail_response"]
530 | else:
531 | hl_keywords = ", ".join(hl_keywords)
532 |
533 |
534 | keywords = [ll_keywords, hl_keywords]
535 |     context = await _build_query_context(
536 | keywords,
537 | knowledge_graph_inst,
538 | entities_vdb,
539 | relationships_vdb,
540 | text_chunks_db,
541 | query_param,
542 | )
543 |
544 |
545 |
546 | if query_param.only_need_context:
547 | return context
548 | if context is None:
549 | return PROMPTS["fail_response"]
550 | sys_prompt_temp = PROMPTS["rag_response"]
551 | sys_prompt = sys_prompt_temp.format(
552 | context_data=context, response_type=query_param.response_type
553 | )
554 | if query_param.only_need_prompt:
555 | return sys_prompt
556 | response = await use_model_func(
557 | query,
558 | system_prompt=sys_prompt,
559 | stream=query_param.stream,
560 | )
561 | if isinstance(response, str) and len(response) > len(sys_prompt):
562 | response = (
563 | response.replace(sys_prompt, "")
564 | .replace("user", "")
565 | .replace("model", "")
566 | .replace(query, "")
567 | .replace("", "")
568 | .replace("", "")
569 | .strip()
570 | )
571 |
572 |
573 | await save_to_cache(
574 | hashing_kv,
575 | CacheData(
576 | args_hash=args_hash,
577 | content=response,
578 | prompt=query,
579 | quantized=quantized,
580 | min_val=min_val,
581 | max_val=max_val,
582 | mode=query_param.mode,
583 | ),
584 | )
585 | return response
586 |
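For reference, a hedged sketch of the `QueryParam` knobs that `kg_query` reads; the concrete values shown are illustrative assumptions, and in normal use the storage arguments are wired up by the `PathRAG` class rather than passed by hand.

```python
# Sketch only: the QueryParam fields consumed by kg_query above (values are illustrative).
param = QueryParam(
    mode="hybrid",              # kg_query only accepts "hybrid"
    only_need_context=False,    # True -> return the assembled context instead of an answer
    only_need_prompt=False,     # True -> return the filled system prompt
    response_type="Multiple Paragraphs",
    top_k=40,                   # candidate entities/relationships pulled from the vector stores
    stream=False,
)
```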
587 |
588 | async def _build_query_context(
589 | query: list,
590 | knowledge_graph_inst: BaseGraphStorage,
591 | entities_vdb: BaseVectorStorage,
592 | relationships_vdb: BaseVectorStorage,
593 | text_chunks_db: BaseKVStorage[TextChunkSchema],
594 | query_param: QueryParam,
595 | ):
596 | ll_entities_context, ll_relations_context, ll_text_units_context = "", "", ""
597 | hl_entities_context, hl_relations_context, hl_text_units_context = "", "", ""
598 |
599 |     ll_keywords, hl_keywords = query[0], query[1]
600 | if query_param.mode in ["local", "hybrid"]:
601 |         if ll_keywords == "":
602 | ll_entities_context, ll_relations_context, ll_text_units_context = (
603 | "",
604 | "",
605 | "",
606 | )
607 | warnings.warn(
608 | "Low Level context is None. Return empty Low entity/relationship/source"
609 | )
610 | query_param.mode = "global"
611 | else:
612 | (
613 | ll_entities_context,
614 | ll_relations_context,
615 | ll_text_units_context,
616 | ) = await _get_node_data(
617 |                 ll_keywords,
618 | knowledge_graph_inst,
619 | entities_vdb,
620 | text_chunks_db,
621 | query_param,
622 | )
623 | if query_param.mode in ["hybrid"]:
624 |         if hl_keywords == "":
625 | hl_entities_context, hl_relations_context, hl_text_units_context = (
626 | "",
627 | "",
628 | "",
629 | )
630 | warnings.warn(
631 | "High Level context is None. Return empty High entity/relationship/source"
632 | )
633 | query_param.mode = "local"
634 | else:
635 | (
636 | hl_entities_context,
637 | hl_relations_context,
638 | hl_text_units_context,
639 | ) = await _get_edge_data(
640 |                 hl_keywords,
641 | knowledge_graph_inst,
642 | relationships_vdb,
643 | text_chunks_db,
644 | query_param,
645 | )
646 | if (
647 | hl_entities_context == ""
648 | and hl_relations_context == ""
649 | and hl_text_units_context == ""
650 | ):
651 |             logger.warning("No high-level context found. Switching to local mode.")
652 | query_param.mode = "local"
653 |     entities_context, relations_context, text_units_context = "", "", ""
654 |     if query_param.mode == "hybrid":
655 |         entities_context, relations_context, text_units_context = combine_contexts(
656 |             [hl_entities_context, hl_relations_context],
657 |             [ll_entities_context, ll_relations_context],
658 |             [hl_text_units_context, ll_text_units_context],
659 |         )
660 |
661 | return f"""
662 | -----global-information-----
663 | -----high-level entity information-----
664 | ```csv
665 | {hl_entities_context}
666 | ```
667 | -----high-level relationship information-----
668 | ```csv
669 | {hl_relations_context}
670 | ```
671 | -----Sources-----
672 | ```csv
673 | {text_units_context}
674 | ```
675 | -----local-information-----
676 | -----low-level entity information-----
677 | ```csv
678 | {ll_entities_context}
679 | ```
680 | -----low-level relationship information-----
681 | ```csv
682 | {ll_relations_context}
683 | ```
684 | """
685 |
686 | async def _get_node_data(
687 | query,
688 | knowledge_graph_inst: BaseGraphStorage,
689 | entities_vdb: BaseVectorStorage,
690 | text_chunks_db: BaseKVStorage[TextChunkSchema],
691 | query_param: QueryParam,
692 | ):
693 |
694 | results = await entities_vdb.query(query, top_k=query_param.top_k)
695 | if not len(results):
696 | return "", "", ""
697 |
698 | node_datas = await asyncio.gather(
699 | *[knowledge_graph_inst.get_node(r["entity_name"]) for r in results]
700 | )
701 | if not all([n is not None for n in node_datas]):
702 | logger.warning("Some nodes are missing, maybe the storage is damaged")
703 |
704 |
705 | node_degrees = await asyncio.gather(
706 | *[knowledge_graph_inst.node_degree(r["entity_name"]) for r in results]
707 | )
708 | node_datas = [
709 | {**n, "entity_name": k["entity_name"], "rank": d}
710 | for k, n, d in zip(results, node_datas, node_degrees)
711 | if n is not None
712 | ]
713 | use_text_units = await _find_most_related_text_unit_from_entities(
714 | node_datas, query_param, text_chunks_db, knowledge_graph_inst
715 | )
716 |
717 |
718 |     use_relations = await _find_most_related_edges_from_entities3(
719 | node_datas, query_param, knowledge_graph_inst
720 | )
721 |
722 | logger.info(
723 |         f"Local query uses {len(node_datas)} entities, {len(use_relations)} relations, {len(use_text_units)} text units"
724 | )
725 |
726 |
727 | entites_section_list = [["id", "entity", "type", "description", "rank"]]
728 | for i, n in enumerate(node_datas):
729 | entites_section_list.append(
730 | [
731 | i,
732 | n["entity_name"],
733 | n.get("entity_type", "UNKNOWN"),
734 | n.get("description", "UNKNOWN"),
735 | n["rank"],
736 | ]
737 | )
738 | entities_context = list_of_list_to_csv(entites_section_list)
739 |
740 |     relations_section_list = [["id", "context"]]
741 |     for i, e in enumerate(use_relations):
742 |         relations_section_list.append([i, e])
743 |     relations_context = list_of_list_to_csv(relations_section_list)
744 |
745 | text_units_section_list = [["id", "content"]]
746 | for i, t in enumerate(use_text_units):
747 | text_units_section_list.append([i, t["content"]])
748 | text_units_context = list_of_list_to_csv(text_units_section_list)
749 |
750 |     return entities_context, relations_context, text_units_context
751 |
752 |
753 | async def _find_most_related_text_unit_from_entities(
754 | node_datas: list[dict],
755 | query_param: QueryParam,
756 | text_chunks_db: BaseKVStorage[TextChunkSchema],
757 | knowledge_graph_inst: BaseGraphStorage,
758 | ):
759 | text_units = [
760 | split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
761 | for dp in node_datas
762 | ]
763 | edges = await asyncio.gather(
764 | *[knowledge_graph_inst.get_node_edges(dp["entity_name"]) for dp in node_datas]
765 | )
766 | all_one_hop_nodes = set()
767 | for this_edges in edges:
768 | if not this_edges:
769 | continue
770 | all_one_hop_nodes.update([e[1] for e in this_edges])
771 |
772 | all_one_hop_nodes = list(all_one_hop_nodes)
773 | all_one_hop_nodes_data = await asyncio.gather(
774 | *[knowledge_graph_inst.get_node(e) for e in all_one_hop_nodes]
775 | )
776 |
777 |
778 | all_one_hop_text_units_lookup = {
779 | k: set(split_string_by_multi_markers(v["source_id"], [GRAPH_FIELD_SEP]))
780 | for k, v in zip(all_one_hop_nodes, all_one_hop_nodes_data)
781 | if v is not None and "source_id" in v
782 | }
783 |
784 | all_text_units_lookup = {}
785 | for index, (this_text_units, this_edges) in enumerate(zip(text_units, edges)):
786 | for c_id in this_text_units:
787 | if c_id not in all_text_units_lookup:
788 | all_text_units_lookup[c_id] = {
789 | "data": await text_chunks_db.get_by_id(c_id),
790 | "order": index,
791 | "relation_counts": 0,
792 | }
793 |
794 | if this_edges:
795 | for e in this_edges:
796 | if (
797 | e[1] in all_one_hop_text_units_lookup
798 | and c_id in all_one_hop_text_units_lookup[e[1]]
799 | ):
800 | all_text_units_lookup[c_id]["relation_counts"] += 1
801 |
802 |
803 | all_text_units = [
804 | {"id": k, **v}
805 | for k, v in all_text_units_lookup.items()
806 | if v is not None and v.get("data") is not None and "content" in v["data"]
807 | ]
808 |
809 | if not all_text_units:
810 | logger.warning("No valid text units found")
811 | return []
812 |
813 | all_text_units = sorted(
814 | all_text_units, key=lambda x: (x["order"], -x["relation_counts"])
815 | )
816 |
817 | all_text_units = truncate_list_by_token_size(
818 | all_text_units,
819 | key=lambda x: x["data"]["content"],
820 | max_token_size=query_param.max_token_for_text_unit,
821 | )
822 |
823 | all_text_units = [t["data"] for t in all_text_units]
824 | return all_text_units
825 |
826 | async def _get_edge_data(
827 | keywords,
828 | knowledge_graph_inst: BaseGraphStorage,
829 | relationships_vdb: BaseVectorStorage,
830 | text_chunks_db: BaseKVStorage[TextChunkSchema],
831 | query_param: QueryParam,
832 | ):
833 | results = await relationships_vdb.query(keywords, top_k=query_param.top_k)
834 |
835 | if not len(results):
836 | return "", "", ""
837 |
838 | edge_datas = await asyncio.gather(
839 | *[knowledge_graph_inst.get_edge(r["src_id"], r["tgt_id"]) for r in results]
840 | )
841 |
842 | if not all([n is not None for n in edge_datas]):
843 | logger.warning("Some edges are missing, maybe the storage is damaged")
844 | edge_degree = await asyncio.gather(
845 | *[knowledge_graph_inst.edge_degree(r["src_id"], r["tgt_id"]) for r in results]
846 | )
847 | edge_datas = [
848 | {"src_id": k["src_id"], "tgt_id": k["tgt_id"], "rank": d, **v}
849 | for k, v, d in zip(results, edge_datas, edge_degree)
850 | if v is not None
851 | ]
852 | edge_datas = sorted(
853 | edge_datas, key=lambda x: (x["rank"], x["weight"]), reverse=True
854 | )
855 | edge_datas = truncate_list_by_token_size(
856 | edge_datas,
857 | key=lambda x: x["description"],
858 | max_token_size=query_param.max_token_for_global_context,
859 | )
860 |
861 | use_entities = await _find_most_related_entities_from_relationships(
862 | edge_datas, query_param, knowledge_graph_inst
863 | )
864 | use_text_units = await _find_related_text_unit_from_relationships(
865 | edge_datas, query_param, text_chunks_db, knowledge_graph_inst
866 | )
867 | logger.info(
868 |         f"Global query uses {len(use_entities)} entities, {len(edge_datas)} relations, {len(use_text_units)} text units"
869 | )
870 |
871 | relations_section_list = [
872 | ["id", "source", "target", "description", "keywords", "weight", "rank"]
873 | ]
874 | for i, e in enumerate(edge_datas):
875 | relations_section_list.append(
876 | [
877 | i,
878 | e["src_id"],
879 | e["tgt_id"],
880 | e["description"],
881 | e["keywords"],
882 | e["weight"],
883 | e["rank"],
884 | ]
885 | )
886 | relations_context = list_of_list_to_csv(relations_section_list)
887 |
888 | entites_section_list = [["id", "entity", "type", "description", "rank"]]
889 | for i, n in enumerate(use_entities):
890 | entites_section_list.append(
891 | [
892 | i,
893 | n["entity_name"],
894 | n.get("entity_type", "UNKNOWN"),
895 | n.get("description", "UNKNOWN"),
896 | n["rank"],
897 | ]
898 | )
899 | entities_context = list_of_list_to_csv(entites_section_list)
900 |
901 | text_units_section_list = [["id", "content"]]
902 | for i, t in enumerate(use_text_units):
903 | text_units_section_list.append([i, t["content"]])
904 | text_units_context = list_of_list_to_csv(text_units_section_list)
905 | return entities_context, relations_context, text_units_context
906 |
907 |
908 | async def _find_most_related_entities_from_relationships(
909 | edge_datas: list[dict],
910 | query_param: QueryParam,
911 | knowledge_graph_inst: BaseGraphStorage,
912 | ):
913 | entity_names = []
914 | seen = set()
915 |
916 | for e in edge_datas:
917 | if e["src_id"] not in seen:
918 | entity_names.append(e["src_id"])
919 | seen.add(e["src_id"])
920 | if e["tgt_id"] not in seen:
921 | entity_names.append(e["tgt_id"])
922 | seen.add(e["tgt_id"])
923 |
924 | node_datas = await asyncio.gather(
925 | *[knowledge_graph_inst.get_node(entity_name) for entity_name in entity_names]
926 | )
927 |
928 | node_degrees = await asyncio.gather(
929 | *[knowledge_graph_inst.node_degree(entity_name) for entity_name in entity_names]
930 | )
931 | node_datas = [
932 | {**n, "entity_name": k, "rank": d}
933 | for k, n, d in zip(entity_names, node_datas, node_degrees)
934 | ]
935 |
936 | node_datas = truncate_list_by_token_size(
937 | node_datas,
938 | key=lambda x: x["description"],
939 | max_token_size=query_param.max_token_for_local_context,
940 | )
941 |
942 | return node_datas
943 |
944 |
945 | async def _find_related_text_unit_from_relationships(
946 | edge_datas: list[dict],
947 | query_param: QueryParam,
948 | text_chunks_db: BaseKVStorage[TextChunkSchema],
949 | knowledge_graph_inst: BaseGraphStorage,
950 | ):
951 | text_units = [
952 | split_string_by_multi_markers(dp["source_id"], [GRAPH_FIELD_SEP])
953 | for dp in edge_datas
954 | ]
955 | all_text_units_lookup = {}
956 |
957 | for index, unit_list in enumerate(text_units):
958 | for c_id in unit_list:
959 | if c_id not in all_text_units_lookup:
960 | chunk_data = await text_chunks_db.get_by_id(c_id)
961 |
962 | if chunk_data is not None and "content" in chunk_data:
963 | all_text_units_lookup[c_id] = {
964 | "data": chunk_data,
965 | "order": index,
966 | }
967 |
968 | if not all_text_units_lookup:
969 | logger.warning("No valid text chunks found")
970 | return []
971 |
972 | all_text_units = [{"id": k, **v} for k, v in all_text_units_lookup.items()]
973 | all_text_units = sorted(all_text_units, key=lambda x: x["order"])
974 |
975 |
976 | valid_text_units = [
977 | t for t in all_text_units if t["data"] is not None and "content" in t["data"]
978 | ]
979 |
980 | if not valid_text_units:
981 | logger.warning("No valid text chunks after filtering")
982 | return []
983 |
984 | truncated_text_units = truncate_list_by_token_size(
985 | valid_text_units,
986 | key=lambda x: x["data"]["content"],
987 | max_token_size=query_param.max_token_for_text_unit,
988 | )
989 |
990 | all_text_units: list[TextChunkSchema] = [t["data"] for t in truncated_text_units]
991 |
992 | return all_text_units
993 |
994 |
995 | def combine_contexts(entities, relationships, sources):
996 |
997 | hl_entities, ll_entities = entities[0], entities[1]
998 | hl_relationships, ll_relationships = relationships[0], relationships[1]
999 | hl_sources, ll_sources = sources[0], sources[1]
1000 |
1001 | combined_entities = process_combine_contexts(hl_entities, ll_entities)
1002 |
1003 | combined_relationships = process_combine_contexts(
1004 | hl_relationships, ll_relationships
1005 | )
1006 |
1007 | combined_sources = process_combine_contexts(hl_sources, ll_sources)
1008 |
1009 | return combined_entities, combined_relationships, combined_sources
1010 |
1011 |
1012 | import networkx as nx
1013 | from collections import defaultdict
1014 | async def find_paths_and_edges_with_stats(graph, target_nodes):
1015 |
1016 | result = defaultdict(lambda: {"paths": [], "edges": set()})
1017 | path_stats = {"1-hop": 0, "2-hop": 0, "3-hop": 0}
1018 | one_hop_paths = []
1019 | two_hop_paths = []
1020 | three_hop_paths = []
1021 |
1022 | async def dfs(current, target, path, depth):
1023 |
1024 | if depth > 3:
1025 | return
1026 | if current == target:
1027 | result[(path[0], target)]["paths"].append(list(path))
1028 | for u, v in zip(path[:-1], path[1:]):
1029 | result[(path[0], target)]["edges"].add(tuple(sorted((u, v))))
1030 | if depth == 1:
1031 | path_stats["1-hop"] += 1
1032 | one_hop_paths.append(list(path))
1033 | elif depth == 2:
1034 | path_stats["2-hop"] += 1
1035 | two_hop_paths.append(list(path))
1036 | elif depth == 3:
1037 | path_stats["3-hop"] += 1
1038 | three_hop_paths.append(list(path))
1039 | return
1040 | neighbors = graph.neighbors(current)
1041 | for neighbor in neighbors:
1042 | if neighbor not in path:
1043 | await dfs(neighbor, target, path + [neighbor], depth + 1)
1044 |
1045 | for node1 in target_nodes:
1046 | for node2 in target_nodes:
1047 | if node1 != node2:
1048 | await dfs(node1, node2, [node1], 0)
1049 |
1050 | for key in result:
1051 | result[key]["edges"] = list(result[key]["edges"])
1052 |
1053 | return dict(result), path_stats , one_hop_paths, two_hop_paths, three_hop_paths
1054 | def bfs_weighted_paths(G, path, source, target, threshold, alpha):
1055 | results = []
1056 | edge_weights = defaultdict(float)
1057 | node = source
1058 | follow_dict = {}
1059 |
1060 | for p in path:
1061 | for i in range(len(p) - 1):
1062 | current = p[i]
1063 | next_num = p[i + 1]
1064 |
1065 | if current in follow_dict:
1066 | follow_dict[current].add(next_num)
1067 | else:
1068 | follow_dict[current] = {next_num}
1069 |
1070 | for neighbor in follow_dict[node]:
1071 | edge_weights[(node, neighbor)] += 1/len(follow_dict[node])
1072 |
1073 | if neighbor == target:
1074 | results.append(([node, neighbor]))
1075 | continue
1076 |
1077 | if edge_weights[(node, neighbor)] > threshold:
1078 |
1079 | for second_neighbor in follow_dict[neighbor]:
1080 | weight = edge_weights[(node, neighbor)] * alpha / len(follow_dict[neighbor])
1081 | edge_weights[(neighbor, second_neighbor)] += weight
1082 |
1083 | if second_neighbor == target:
1084 | results.append(([node, neighbor, second_neighbor]))
1085 | continue
1086 |
1087 | if edge_weights[(neighbor, second_neighbor)] > threshold:
1088 |
1089 | for third_neighbor in follow_dict[second_neighbor]:
1090 | weight = edge_weights[(neighbor, second_neighbor)] * alpha / len(follow_dict[second_neighbor])
1091 | edge_weights[(second_neighbor, third_neighbor)] += weight
1092 |
1093 | if third_neighbor == target :
1094 | results.append(([node, neighbor, second_neighbor, third_neighbor]))
1095 | continue
1096 | path_weights = []
1097 | for p in path:
1098 | path_weight = 0
1099 | for i in range(len(p) - 1):
1100 | edge = (p[i], p[i + 1])
1101 | path_weight += edge_weights.get(edge, 0)
1102 | path_weights.append(path_weight/(len(p)-1))
1103 |
1104 | combined = [(p, w) for p, w in zip(path, path_weights)]
1105 |
1106 | return combined
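A toy sketch (not repository code) of the path-scoring flow used by `_find_most_related_edges_from_entities3` below: enumerate every path of at most three hops between two seed entities, then spread weight along those paths with decay factor `alpha`, so each path is scored by the average weight of its edges.

```python
# Toy sketch: exercising find_paths_and_edges_with_stats and bfs_weighted_paths on a tiny graph.
import asyncio
import networkx as nx

async def demo_paths():
    G = nx.Graph()
    G.add_edges_from([("A", "B"), ("B", "C"), ("A", "C"), ("C", "D")])
    result, stats, one_hop, two_hop, three_hop = await find_paths_and_edges_with_stats(G, ["A", "D"])
    paths = result[("A", "D")]["paths"]              # e.g. A-C-D (2 hops) and A-B-C-D (3 hops)
    scored = bfs_weighted_paths(G, paths, "A", "D", threshold=0.3, alpha=0.8)
    for path, weight in sorted(scored, key=lambda x: x[1], reverse=True):
        print(path, round(weight, 3))

asyncio.run(demo_paths())
```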
1107 | async def _find_most_related_edges_from_entities3(
1108 | node_datas: list[dict],
1109 | query_param: QueryParam,
1110 | knowledge_graph_inst: BaseGraphStorage,
1111 | ):
1112 |
1113 | G = nx.Graph()
1114 | edges = await knowledge_graph_inst.edges()
1115 | nodes = await knowledge_graph_inst.nodes()
1116 |
1117 | for u, v in edges:
1118 | G.add_edge(u, v)
1119 | G.add_nodes_from(nodes)
1120 | source_nodes = [dp["entity_name"] for dp in node_datas]
1121 | result, path_stats, one_hop_paths, two_hop_paths, three_hop_paths = await find_paths_and_edges_with_stats(G, source_nodes)
1122 |
1123 |
1124 | threshold = 0.3
1125 | alpha = 0.8
1126 | all_results = []
1127 |
1128 | for node1 in source_nodes:
1129 | for node2 in source_nodes:
1130 | if node1 != node2:
1131 | if (node1, node2) in result:
1132 | sub_G = nx.Graph()
1133 | paths = result[(node1,node2)]['paths']
1134 | edges = result[(node1,node2)]['edges']
1135 | sub_G.add_edges_from(edges)
1136 | results = bfs_weighted_paths(G, paths, node1, node2, threshold, alpha)
1137 | all_results+= results
1138 | all_results = sorted(all_results, key=lambda x: x[1], reverse=True)
1139 | seen = set()
1140 | result_edge = []
1141 | for edge, weight in all_results:
1142 | sorted_edge = tuple(sorted(edge))
1143 | if sorted_edge not in seen:
1144 | seen.add(sorted_edge)
1145 | result_edge.append((edge, weight))
1146 |
1147 |
1148 | length_1 = int(len(one_hop_paths)/2)
1149 | length_2 = int(len(two_hop_paths)/2)
1150 | length_3 = int(len(three_hop_paths)/2)
1151 | results = []
1152 | if one_hop_paths!=[]:
1153 | results = one_hop_paths[0:length_1]
1154 | if two_hop_paths!=[]:
1155 | results = results + two_hop_paths[0:length_2]
1156 | if three_hop_paths!=[]:
1157 | results =results + three_hop_paths[0:length_3]
1158 |
1159 | length = len(results)
1160 | total_edges = 15
1161 | if length < total_edges:
1162 | total_edges = length
1163 | sort_result = []
1164 | if result_edge:
1165 | if len(result_edge)>total_edges:
1166 | sort_result = result_edge[0:total_edges]
1167 | else :
1168 | sort_result = result_edge
1169 | final_result = []
1170 | for edge, weight in sort_result:
1171 | final_result.append(edge)
1172 |
1173 | relationship = []
1174 |
1175 | for path in final_result:
1176 | if len(path) == 4:
1177 | s_name,b1_name,b2_name,t_name = path[0],path[1],path[2],path[3]
1178 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
1179 | edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1])
1180 | edge2 = await knowledge_graph_inst.get_edge(path[2],path[3]) or await knowledge_graph_inst.get_edge(path[3], path[2])
1181 |             if edge0 is None or edge1 is None or edge2 is None:
1182 |                 print(path, "edge missing")
1183 |                 if edge0 is None:
1184 |                     print("edge0 missing")
1185 |                 if edge1 is None:
1186 |                     print("edge1 missing")
1187 |                 if edge2 is None:
1188 |                     print("edge2 missing")
1189 |                 continue
1190 | e1 = "through edge ("+edge0["keywords"]+") to connect to "+s_name+" and "+b1_name+"."
1191 | e2 = "through edge ("+edge1["keywords"]+") to connect to "+b1_name+" and "+b2_name+"."
1192 | e3 = "through edge ("+edge2["keywords"]+") to connect to "+b2_name+" and "+t_name+"."
1193 | s = await knowledge_graph_inst.get_node(s_name)
1194 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
1195 | b1 = await knowledge_graph_inst.get_node(b1_name)
1196 | b1 = "The entity "+b1_name+" is a "+b1["entity_type"]+" with the description("+b1["description"]+")"
1197 | b2 = await knowledge_graph_inst.get_node(b2_name)
1198 | b2 = "The entity "+b2_name+" is a "+b2["entity_type"]+" with the description("+b2["description"]+")"
1199 | t = await knowledge_graph_inst.get_node(t_name)
1200 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
1201 | relationship.append([s+e1+b1+"and"+b1+e2+b2+"and"+b2+e3+t])
1202 | elif len(path) == 3:
1203 | s_name,b_name,t_name = path[0],path[1],path[2]
1204 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
1205 | edge1 = await knowledge_graph_inst.get_edge(path[1],path[2]) or await knowledge_graph_inst.get_edge(path[2], path[1])
1206 |             if edge0 is None or edge1 is None:
1207 |                 print(path, "edge missing")
1208 | continue
1209 | e1 = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+b_name+"."
1210 | e2 = "through edge("+edge1["keywords"]+") to connect to "+b_name+" and "+t_name+"."
1211 | s = await knowledge_graph_inst.get_node(s_name)
1212 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
1213 | b = await knowledge_graph_inst.get_node(b_name)
1214 | b = "The entity "+b_name+" is a "+b["entity_type"]+" with the description("+b["description"]+")"
1215 | t = await knowledge_graph_inst.get_node(t_name)
1216 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
1217 | relationship.append([s+e1+b+"and"+b+e2+t])
1218 | elif len(path) == 2:
1219 | s_name,t_name = path[0],path[1]
1220 | edge0 = await knowledge_graph_inst.get_edge(path[0], path[1]) or await knowledge_graph_inst.get_edge(path[1], path[0])
1221 |             if edge0 is None:
1222 |                 print(path, "edge missing")
1223 | continue
1224 | e = "through edge("+edge0["keywords"]+") to connect to "+s_name+" and "+t_name+"."
1225 | s = await knowledge_graph_inst.get_node(s_name)
1226 | s = "The entity "+s_name+" is a "+s["entity_type"]+" with the description("+s["description"]+")"
1227 | t = await knowledge_graph_inst.get_node(t_name)
1228 | t = "The entity "+t_name+" is a "+t["entity_type"]+" with the description("+t["description"]+")"
1229 | relationship.append([s+e+t])
1230 |
1231 |
1232 | relationship = truncate_list_by_token_size(
1233 | relationship,
1234 | key=lambda x: x[0],
1235 | max_token_size=query_param.max_token_for_local_context,
1236 | )
1237 |
1238 | reversed_relationship = relationship[::-1]
1239 | return reversed_relationship
--------------------------------------------------------------------------------
/PathRAG/prompt.py:
--------------------------------------------------------------------------------
1 | GRAPH_FIELD_SEP = "<SEP>"
2 |
3 | PROMPTS = {}
4 |
5 | PROMPTS["DEFAULT_LANGUAGE"] = "English"
6 | PROMPTS["DEFAULT_TUPLE_DELIMITER"] = "<|>"
7 | PROMPTS["DEFAULT_RECORD_DELIMITER"] = "##"
8 | PROMPTS["DEFAULT_COMPLETION_DELIMITER"] = "<|COMPLETE|>"
9 | PROMPTS["process_tickers"] = ["⠋", "⠙", "⠹", "⠸", "⠼", "⠴", "⠦", "⠧", "⠇", "⠏"]
10 |
11 | PROMPTS["DEFAULT_ENTITY_TYPES"] = ["organization", "person", "geo", "event", "category"]
12 |
13 | PROMPTS["entity_extraction"] = """-Goal-
14 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
15 | Use {language} as output language.
16 |
17 | -Steps-
18 | 1. Identify all entities. For each identified entity, extract the following information:
19 | - entity_name: Name of the entity, use same language as input text. If English, capitalize the name.
20 | - entity_type: One of the following types: [{entity_types}]
21 | - entity_description: Comprehensive description of the entity's attributes and activities
22 | Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
23 |
24 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
25 | For each pair of related entities, extract the following information:
26 | - source_entity: name of the source entity, as identified in step 1
27 | - target_entity: name of the target entity, as identified in step 1
28 | - relationship_description: explanation as to why you think the source entity and the target entity are related to each other
29 | - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
30 | - relationship_keywords: one or more high-level key words that summarize the overarching nature of the relationship, focusing on concepts or themes rather than specific details
31 | Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_keywords>{tuple_delimiter}<relationship_strength>)
32 |
33 | 3. Identify high-level key words that summarize the main concepts, themes, or topics of the entire text. These should capture the overarching ideas present in the document.
34 | Format the content-level key words as ("content_keywords"{tuple_delimiter}<high_level_keywords>)
35 |
36 | 4. Return output in {language} as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
37 |
38 | 5. When finished, output {completion_delimiter}
39 |
40 | ######################
41 | -Examples-
42 | ######################
43 | {examples}
44 |
45 | #############################
46 | -Real Data-
47 | ######################
48 | Entity_types: {entity_types}
49 | Text: {input_text}
50 | ######################
51 | Output:
52 | """
53 |
54 | PROMPTS["entity_extraction_examples"] = [
55 | """Example 1:
56 |
57 | Entity_types: [person, technology, mission, organization, location]
58 | Text:
59 | while Alex clenched his jaw, the buzz of frustration dull against the backdrop of Taylor's authoritarian certainty. It was this competitive undercurrent that kept him alert, the sense that his and Jordan's shared commitment to discovery was an unspoken rebellion against Cruz's narrowing vision of control and order.
60 |
61 | Then Taylor did something unexpected. They paused beside Jordan and, for a moment, observed the device with something akin to reverence. “If this tech can be understood..." Taylor said, their voice quieter, "It could change the game for us. For all of us.”
62 |
63 | The underlying dismissal earlier seemed to falter, replaced by a glimpse of reluctant respect for the gravity of what lay in their hands. Jordan looked up, and for a fleeting heartbeat, their eyes locked with Taylor's, a wordless clash of wills softening into an uneasy truce.
64 |
65 | It was a small transformation, barely perceptible, but one that Alex noted with an inward nod. They had all been brought here by different paths
66 | ################
67 | Output:
68 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is a character who experiences frustration and is observant of the dynamics among other characters."){record_delimiter}
69 | ("entity"{tuple_delimiter}"Taylor"{tuple_delimiter}"person"{tuple_delimiter}"Taylor is portrayed with authoritarian certainty and shows a moment of reverence towards a device, indicating a change in perspective."){record_delimiter}
70 | ("entity"{tuple_delimiter}"Jordan"{tuple_delimiter}"person"{tuple_delimiter}"Jordan shares a commitment to discovery and has a significant interaction with Taylor regarding a device."){record_delimiter}
71 | ("entity"{tuple_delimiter}"Cruz"{tuple_delimiter}"person"{tuple_delimiter}"Cruz is associated with a vision of control and order, influencing the dynamics among other characters."){record_delimiter}
72 | ("entity"{tuple_delimiter}"The Device"{tuple_delimiter}"technology"{tuple_delimiter}"The Device is central to the story, with potential game-changing implications, and is revered by Taylor."){record_delimiter}
73 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Taylor"{tuple_delimiter}"Alex is affected by Taylor's authoritarian certainty and observes changes in Taylor's attitude towards the device."{tuple_delimiter}"power dynamics, perspective shift"{tuple_delimiter}7){record_delimiter}
74 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Jordan"{tuple_delimiter}"Alex and Jordan share a commitment to discovery, which contrasts with Cruz's vision."{tuple_delimiter}"shared goals, rebellion"{tuple_delimiter}6){record_delimiter}
75 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"Jordan"{tuple_delimiter}"Taylor and Jordan interact directly regarding the device, leading to a moment of mutual respect and an uneasy truce."{tuple_delimiter}"conflict resolution, mutual respect"{tuple_delimiter}8){record_delimiter}
76 | ("relationship"{tuple_delimiter}"Jordan"{tuple_delimiter}"Cruz"{tuple_delimiter}"Jordan's commitment to discovery is in rebellion against Cruz's vision of control and order."{tuple_delimiter}"ideological conflict, rebellion"{tuple_delimiter}5){record_delimiter}
77 | ("relationship"{tuple_delimiter}"Taylor"{tuple_delimiter}"The Device"{tuple_delimiter}"Taylor shows reverence towards the device, indicating its importance and potential impact."{tuple_delimiter}"reverence, technological significance"{tuple_delimiter}9){record_delimiter}
78 | ("content_keywords"{tuple_delimiter}"power dynamics, ideological conflict, discovery, rebellion"){completion_delimiter}
79 | #############################""",
80 | """Example 2:
81 |
82 | Entity_types: [person, technology, mission, organization, location]
83 | Text:
84 | They were no longer mere operatives; they had become guardians of a threshold, keepers of a message from a realm beyond stars and stripes. This elevation in their mission could not be shackled by regulations and established protocols—it demanded a new perspective, a new resolve.
85 |
86 | Tension threaded through the dialogue of beeps and static as communications with Washington buzzed in the background. The team stood, a portentous air enveloping them. It was clear that the decisions they made in the ensuing hours could redefine humanity's place in the cosmos or condemn them to ignorance and potential peril.
87 |
88 | Their connection to the stars solidified, the group moved to address the crystallizing warning, shifting from passive recipients to active participants. Mercer's latter instincts gained precedence— the team's mandate had evolved, no longer solely to observe and report but to interact and prepare. A metamorphosis had begun, and Operation: Dulce hummed with the newfound frequency of their daring, a tone set not by the earthly
89 | #############
90 | Output:
91 | ("entity"{tuple_delimiter}"Washington"{tuple_delimiter}"location"{tuple_delimiter}"Washington is a location where communications are being received, indicating its importance in the decision-making process."){record_delimiter}
92 | ("entity"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"mission"{tuple_delimiter}"Operation: Dulce is described as a mission that has evolved to interact and prepare, indicating a significant shift in objectives and activities."){record_delimiter}
93 | ("entity"{tuple_delimiter}"The team"{tuple_delimiter}"organization"{tuple_delimiter}"The team is portrayed as a group of individuals who have transitioned from passive observers to active participants in a mission, showing a dynamic change in their role."){record_delimiter}
94 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Washington"{tuple_delimiter}"The team receives communications from Washington, which influences their decision-making process."{tuple_delimiter}"decision-making, external influence"{tuple_delimiter}7){record_delimiter}
95 | ("relationship"{tuple_delimiter}"The team"{tuple_delimiter}"Operation: Dulce"{tuple_delimiter}"The team is directly involved in Operation: Dulce, executing its evolved objectives and activities."{tuple_delimiter}"mission evolution, active participation"{tuple_delimiter}9){completion_delimiter}
96 | ("content_keywords"{tuple_delimiter}"mission evolution, decision-making, active participation, cosmic significance"){completion_delimiter}
97 | #############################""",
98 | """Example 3:
99 |
100 | Entity_types: [person, role, technology, organization, event, location, concept]
101 | Text:
102 | their voice slicing through the buzz of activity. "Control may be an illusion when facing an intelligence that literally writes its own rules," they stated stoically, casting a watchful eye over the flurry of data.
103 |
104 | "It's like it's learning to communicate," offered Sam Rivera from a nearby interface, their youthful energy boding a mix of awe and anxiety. "This gives 'talking to strangers' a whole new meaning."
105 |
106 | Alex surveyed his team—each face a study in concentration, determination, and not a small measure of trepidation. "This might well be our first contact," he acknowledged, "And we need to be ready for whatever answers back."
107 |
108 | Together, they stood on the edge of the unknown, forging humanity's response to a message from the heavens. The ensuing silence was palpable—a collective introspection about their role in this grand cosmic play, one that could rewrite human history.
109 |
110 | The encrypted dialogue continued to unfold, its intricate patterns showing an almost uncanny anticipation
111 | #############
112 | Output:
113 | ("entity"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"person"{tuple_delimiter}"Sam Rivera is a member of a team working on communicating with an unknown intelligence, showing a mix of awe and anxiety."){record_delimiter}
114 | ("entity"{tuple_delimiter}"Alex"{tuple_delimiter}"person"{tuple_delimiter}"Alex is the leader of a team attempting first contact with an unknown intelligence, acknowledging the significance of their task."){record_delimiter}
115 | ("entity"{tuple_delimiter}"Control"{tuple_delimiter}"concept"{tuple_delimiter}"Control refers to the ability to manage or govern, which is challenged by an intelligence that writes its own rules."){record_delimiter}
116 | ("entity"{tuple_delimiter}"Intelligence"{tuple_delimiter}"concept"{tuple_delimiter}"Intelligence here refers to an unknown entity capable of writing its own rules and learning to communicate."){record_delimiter}
117 | ("entity"{tuple_delimiter}"First Contact"{tuple_delimiter}"event"{tuple_delimiter}"First Contact is the potential initial communication between humanity and an unknown intelligence."){record_delimiter}
118 | ("entity"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"event"{tuple_delimiter}"Humanity's Response is the collective action taken by Alex's team in response to a message from an unknown intelligence."){record_delimiter}
119 | ("relationship"{tuple_delimiter}"Sam Rivera"{tuple_delimiter}"Intelligence"{tuple_delimiter}"Sam Rivera is directly involved in the process of learning to communicate with the unknown intelligence."{tuple_delimiter}"communication, learning process"{tuple_delimiter}9){record_delimiter}
120 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"First Contact"{tuple_delimiter}"Alex leads the team that might be making the First Contact with the unknown intelligence."{tuple_delimiter}"leadership, exploration"{tuple_delimiter}10){record_delimiter}
121 | ("relationship"{tuple_delimiter}"Alex"{tuple_delimiter}"Humanity's Response"{tuple_delimiter}"Alex and his team are the key figures in Humanity's Response to the unknown intelligence."{tuple_delimiter}"collective action, cosmic significance"{tuple_delimiter}8){record_delimiter}
122 | ("relationship"{tuple_delimiter}"Control"{tuple_delimiter}"Intelligence"{tuple_delimiter}"The concept of Control is challenged by the Intelligence that writes its own rules."{tuple_delimiter}"power dynamics, autonomy"{tuple_delimiter}7){record_delimiter}
123 | ("content_keywords"{tuple_delimiter}"first contact, control, communication, cosmic significance"){completion_delimiter}
124 | #############################""",
125 | ]
126 |
127 | PROMPTS[
128 | "summarize_entity_descriptions"
129 | ] = """You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
130 | You are given one or two entities and a list of descriptions, all related to the same entity or group of entities.
131 | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
132 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
133 | Make sure it is written in third person, and include the entity names so we have the full context.
134 | Use {language} as output language.
135 |
136 | #######
137 | -Data-
138 | Entities: {entity_name}
139 | Description List: {description_list}
140 | #######
141 | Output:
142 | """
143 |
144 | PROMPTS[
145 | "entiti_continue_extraction"
146 | ] = """MANY entities were missed in the last extraction. Add them below using the same format:
147 | """
148 |
149 | PROMPTS[
150 | "entiti_if_loop_extraction"
151 | ] = """It appears some entities may have still been missed. Answer YES | NO if there are still entities that need to be added.
152 | """
153 |
154 | PROMPTS["fail_response"] = "Sorry, I'm not able to provide an answer to that question."
155 |
156 | PROMPTS["rag_response"] = """---Role---
157 |
158 | You are a helpful assistant responding to questions about data in the tables provided.
159 |
160 |
161 | ---Goal---
162 |
163 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge.
164 | If you don't know the answer, just say so. Do not make anything up.
165 | Do not include information where the supporting evidence for it is not provided.
166 |
167 | ---Target response length and format---
168 |
169 | {response_type}
170 |
171 | ---Data tables---
172 |
173 | {context_data}
174 |
175 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
176 | """
177 |
178 | PROMPTS["keywords_extraction"] = """---Role---
179 |
180 | You are a helpful assistant tasked with identifying both high-level and low-level keywords in the user's query.
181 |
182 | ---Goal---
183 |
184 | Given the query, list both high-level and low-level keywords. High-level keywords focus on overarching concepts or themes, while low-level keywords focus on specific entities, details, or concrete terms.
185 |
186 | ---Instructions---
187 |
188 | - Output the keywords in JSON format.
189 | - The JSON should have two keys:
190 | - "high_level_keywords" for overarching concepts or themes.
191 | - "low_level_keywords" for specific entities or details.
192 |
193 | ######################
194 | -Examples-
195 | ######################
196 | {examples}
197 |
198 | #############################
199 | -Real Data-
200 | ######################
201 | Query: {query}
202 | ######################
203 | The `Output` should be human-readable text, not unicode escape sequences. Keep the same language as `Query`.
204 | Output:
205 |
206 | """
207 |
208 | PROMPTS["keywords_extraction_examples"] = [
209 | """Example 1:
210 |
211 | Query: "How does international trade influence global economic stability?"
212 | ################
213 | Output:
214 | {{
215 | "high_level_keywords": ["International trade", "Global economic stability", "Economic impact"],
216 | "low_level_keywords": ["Trade agreements", "Tariffs", "Currency exchange", "Imports", "Exports"]
217 | }}
218 | #############################""",
219 | """Example 2:
220 |
221 | Query: "What are the environmental consequences of deforestation on biodiversity?"
222 | ################
223 | Output:
224 | {{
225 | "high_level_keywords": ["Environmental consequences", "Deforestation", "Biodiversity loss"],
226 | "low_level_keywords": ["Species extinction", "Habitat destruction", "Carbon emissions", "Rainforest", "Ecosystem"]
227 | }}
228 | #############################""",
229 | """Example 3:
230 |
231 | Query: "What is the role of education in reducing poverty?"
232 | ################
233 | Output:
234 | {{
235 | "high_level_keywords": ["Education", "Poverty reduction", "Socioeconomic development"],
236 | "low_level_keywords": ["School access", "Literacy rates", "Job training", "Income inequality"]
237 | }}
238 | #############################""",
239 | ]
240 |
241 |
242 | PROMPTS["naive_rag_response"] = """---Role---
243 |
244 | You are a helpful assistant responding to questions about the documents provided.
245 |
246 |
247 | ---Goal---
248 |
249 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input documents appropriate for the response length and format, and incorporating any relevant general knowledge.
250 | If you don't know the answer, just say so. Do not make anything up.
251 | Do not include information where the supporting evidence for it is not provided.
252 |
253 | ---Target response length and format---
254 |
255 | {response_type}
256 |
257 | ---Documents---
258 |
259 | {content_data}
260 |
261 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown.
262 | """
263 |
264 | PROMPTS[
265 | "similarity_check"
266 | ] = """Please analyze the similarity between these two questions:
267 |
268 | Question 1: {original_prompt}
269 | Question 2: {cached_prompt}
270 |
271 | Please evaluate the following two points and provide a similarity score between 0 and 1 directly:
272 | 1. Whether these two questions are semantically similar
273 | 2. Whether the answer to Question 2 can be used to answer Question 1
274 | Similarity score criteria:
275 | 0: Completely unrelated or answer cannot be reused, including but not limited to:
276 | - The questions have different topics
277 | - The locations mentioned in the questions are different
278 | - The times mentioned in the questions are different
279 | - The specific individuals mentioned in the questions are different
280 | - The specific events mentioned in the questions are different
281 | - The background information in the questions is different
282 | - The key conditions in the questions are different
283 | 1: Identical and answer can be directly reused
284 | 0.5: Partially related and answer needs modification to be used
285 | Return only a number between 0-1, without any additional content.
286 | """
287 |
--------------------------------------------------------------------------------
/PathRAG/storage.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import html
3 | import os
4 | from tqdm.asyncio import tqdm as tqdm_async
5 | from dataclasses import dataclass
6 | from typing import Any, Union, cast
7 | import networkx as nx
8 | import numpy as np
9 | from nano_vectordb import NanoVectorDB
10 |
11 | from .utils import (
12 | logger,
13 | load_json,
14 | write_json,
15 | compute_mdhash_id,
16 | )
17 |
18 | from .base import (
19 | BaseGraphStorage,
20 | BaseKVStorage,
21 | BaseVectorStorage,
22 | )
23 |
24 |
25 | @dataclass
26 | class JsonKVStorage(BaseKVStorage):
27 | def __post_init__(self):
28 | working_dir = self.global_config["working_dir"]
29 | self._file_name = os.path.join(working_dir, f"kv_store_{self.namespace}.json")
30 | self._data = load_json(self._file_name) or {}
31 | logger.info(f"Loaded KV {self.namespace} with {len(self._data)} records")
32 |
33 | async def all_keys(self) -> list[str]:
34 | return list(self._data.keys())
35 |
36 | async def index_done_callback(self):
37 | write_json(self._data, self._file_name)
38 |
39 | async def get_by_id(self, id):
40 | return self._data.get(id, None)
41 |
42 | async def get_by_ids(self, ids, fields=None):
43 | if fields is None:
44 | return [self._data.get(id, None) for id in ids]
45 | return [
46 | (
47 | {k: v for k, v in self._data[id].items() if k in fields}
48 | if self._data.get(id, None)
49 | else None
50 | )
51 | for id in ids
52 | ]
53 |
54 | async def filter_keys(self, data: list[str]) -> set[str]:
55 | return set([s for s in data if s not in self._data])
56 |
57 | async def upsert(self, data: dict[str, dict]):
58 | left_data = {k: v for k, v in data.items() if k not in self._data}
59 | self._data.update(left_data)
60 | return left_data
61 |
62 | async def drop(self):
63 | self._data = {}
64 |
65 |
66 | @dataclass
67 | class NanoVectorDBStorage(BaseVectorStorage):
68 | cosine_better_than_threshold: float = 0.2
69 |
70 | def __post_init__(self):
71 | self._client_file_name = os.path.join(
72 | self.global_config["working_dir"], f"vdb_{self.namespace}.json"
73 | )
74 | self._max_batch_size = self.global_config["embedding_batch_num"]
75 | self._client = NanoVectorDB(
76 | self.embedding_func.embedding_dim, storage_file=self._client_file_name
77 | )
78 | self.cosine_better_than_threshold = self.global_config.get(
79 | "cosine_better_than_threshold", self.cosine_better_than_threshold
80 | )
81 |
82 | async def upsert(self, data: dict[str, dict]):
83 | logger.info(f"Inserting {len(data)} vectors into {self.namespace}")
84 | if not len(data):
85 | logger.warning("Attempted to insert empty data into the vector DB")
86 | return []
87 | list_data = [
88 | {
89 | "__id__": k,
90 | **{k1: v1 for k1, v1 in v.items() if k1 in self.meta_fields},
91 | }
92 | for k, v in data.items()
93 | ]
94 | contents = [v["content"] for v in data.values()]
95 | batches = [
96 | contents[i : i + self._max_batch_size]
97 | for i in range(0, len(contents), self._max_batch_size)
98 | ]
99 |
100 | async def wrapped_task(batch):
101 | result = await self.embedding_func(batch)
102 | pbar.update(1)
103 | return result
104 |
105 | embedding_tasks = [wrapped_task(batch) for batch in batches]
106 | pbar = tqdm_async(
107 | total=len(embedding_tasks), desc="Generating embeddings", unit="batch"
108 | )
109 | embeddings_list = await asyncio.gather(*embedding_tasks)
110 |
111 | embeddings = np.concatenate(embeddings_list)
112 | if len(embeddings) == len(list_data):
113 | for i, d in enumerate(list_data):
114 | d["__vector__"] = embeddings[i]
115 | results = self._client.upsert(datas=list_data)
116 | return results
117 | else:
118 | # sometimes the embedding is not returned correctly. just log it.
119 | logger.error(
120 | f"embedding is not 1-1 with data, {len(embeddings)} != {len(list_data)}"
121 | )
122 |
123 | async def query(self, query: str, top_k=5):
124 | embedding = await self.embedding_func([query])
125 | embedding = embedding[0]
126 | results = self._client.query(
127 | query=embedding,
128 | top_k=top_k,
129 | better_than_threshold=self.cosine_better_than_threshold,
130 | )
131 | results = [
132 | {**dp, "id": dp["__id__"], "distance": dp["__metrics__"]} for dp in results
133 | ]
134 | return results
135 |
136 | @property
137 | def client_storage(self):
138 | return getattr(self._client, "_NanoVectorDB__storage")
139 |
140 | async def delete_entity(self, entity_name: str):
141 | try:
142 | entity_id = [compute_mdhash_id(entity_name, prefix="ent-")]
143 |
144 | if self._client.get(entity_id):
145 | self._client.delete(entity_id)
146 | logger.info(f"Entity {entity_name} has been deleted.")
147 | else:
148 | logger.info(f"No entity found with name {entity_name}.")
149 | except Exception as e:
150 | logger.error(f"Error while deleting entity {entity_name}: {e}")
151 |
152 | async def delete_relation(self, entity_name: str):
153 | try:
154 | relations = [
155 | dp
156 | for dp in self.client_storage["data"]
157 | if dp["src_id"] == entity_name or dp["tgt_id"] == entity_name
158 | ]
159 | ids_to_delete = [relation["__id__"] for relation in relations]
160 |
161 | if ids_to_delete:
162 | self._client.delete(ids_to_delete)
163 | logger.info(
164 | f"All relations related to entity {entity_name} have been deleted."
165 | )
166 | else:
167 | logger.info(f"No relations found for entity {entity_name}.")
168 | except Exception as e:
169 | logger.error(
170 | f"Error while deleting relations for entity {entity_name}: {e}"
171 | )
172 |
173 | async def index_done_callback(self):
174 | self._client.save()
175 |
176 |
177 | @dataclass
178 | class NetworkXStorage(BaseGraphStorage):
179 | @staticmethod
180 | def load_nx_graph(file_name) -> nx.DiGraph:
181 | if os.path.exists(file_name):
182 | return nx.read_graphml(file_name)
183 | return None
184 | # def load_nx_graph(file_name) -> nx.Graph:
185 | # if os.path.exists(file_name):
186 | # return nx.read_graphml(file_name)
187 | # return None
188 |
189 | @staticmethod
190 | def write_nx_graph(graph: nx.DiGraph, file_name):
191 | logger.info(
192 | f"Writing graph with {graph.number_of_nodes()} nodes, {graph.number_of_edges()} edges"
193 | )
194 | nx.write_graphml(graph, file_name)
195 |
196 | @staticmethod
197 | def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph:
198 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
199 | Return the largest connected component of the graph, with nodes and edges sorted in a stable way.
200 | """
201 | from graspologic.utils import largest_connected_component
202 |
203 | graph = graph.copy()
204 | graph = cast(nx.Graph, largest_connected_component(graph))
205 | node_mapping = {
206 | node: html.unescape(node.upper().strip()) for node in graph.nodes()
207 | } # type: ignore
208 | graph = nx.relabel_nodes(graph, node_mapping)
209 | return NetworkXStorage._stabilize_graph(graph)
210 |
211 | @staticmethod
212 | def _stabilize_graph(graph: nx.Graph) -> nx.Graph:
213 | """Refer to https://github.com/microsoft/graphrag/index/graph/utils/stable_lcc.py
214 | Ensure an undirected graph with the same relationships will always be read the same way.
215 | """
216 | fixed_graph = nx.DiGraph() if graph.is_directed() else nx.Graph()
217 |
218 | sorted_nodes = graph.nodes(data=True)
219 | sorted_nodes = sorted(sorted_nodes, key=lambda x: x[0])
220 |
221 | fixed_graph.add_nodes_from(sorted_nodes)
222 | edges = list(graph.edges(data=True))
223 |
224 | if not graph.is_directed():
225 |
226 | def _sort_source_target(edge):
227 | source, target, edge_data = edge
228 | if source > target:
229 | temp = source
230 | source = target
231 | target = temp
232 | return source, target, edge_data
233 |
234 | edges = [_sort_source_target(edge) for edge in edges]
235 |
236 | def _get_edge_key(source: Any, target: Any) -> str:
237 | return f"{source} -> {target}"
238 |
239 | edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1]))
240 |
241 | fixed_graph.add_edges_from(edges)
242 | return fixed_graph
243 |
244 | def __post_init__(self):
245 | self._graphml_xml_file = os.path.join(
246 | self.global_config["working_dir"], f"graph_{self.namespace}.graphml"
247 | )
248 | preloaded_graph = NetworkXStorage.load_nx_graph(self._graphml_xml_file)
249 | if preloaded_graph is not None:
250 | logger.info(
251 | f"Loaded graph from {self._graphml_xml_file} with {preloaded_graph.number_of_nodes()} nodes, {preloaded_graph.number_of_edges()} edges"
252 | )
253 | self._graph = preloaded_graph or nx.DiGraph()
254 | self._node_embed_algorithms = {
255 | "node2vec": self._node2vec_embed,
256 | }
257 |
258 | async def index_done_callback(self):
259 | NetworkXStorage.write_nx_graph(self._graph, self._graphml_xml_file)
260 |
261 | async def has_node(self, node_id: str) -> bool:
262 | return self._graph.has_node(node_id)
263 |
264 | async def has_edge(self, source_node_id: str, target_node_id: str) -> bool:
265 | return self._graph.has_edge(source_node_id, target_node_id)
266 |
267 | async def get_node(self, node_id: str) -> Union[dict, None]:
268 | return self._graph.nodes.get(node_id)
269 |
270 | async def node_degree(self, node_id: str) -> int:
271 | return self._graph.degree(node_id)
272 |
273 | async def edge_degree(self, src_id: str, tgt_id: str) -> int:
274 | return self._graph.degree(src_id) + self._graph.degree(tgt_id)
275 |
276 | async def get_edge(
277 | self, source_node_id: str, target_node_id: str
278 | ) -> Union[dict, None]:
279 | return self._graph.edges.get((source_node_id, target_node_id))
280 |
281 | async def get_node_edges(self, source_node_id: str):
282 | if self._graph.has_node(source_node_id):
283 | return list(self._graph.edges(source_node_id))
284 | return None
285 | async def get_node_in_edges(self, source_node_id: str):
286 | if self._graph.has_node(source_node_id):
287 | return list(self._graph.in_edges(source_node_id))
288 | return None
289 | async def get_node_out_edges(self, source_node_id: str):
290 | if self._graph.has_node(source_node_id):
291 | return list(self._graph.out_edges(source_node_id))
292 | return None
293 |
294 | async def get_pagerank(self, source_node_id: str):
295 | pagerank_list = nx.pagerank(self._graph)
296 | if source_node_id in pagerank_list:
297 | return pagerank_list[source_node_id]
298 | else:
299 | logger.warning(f"PageRank lookup failed: node {source_node_id} not found in graph")
300 |
301 | async def upsert_node(self, node_id: str, node_data: dict[str, str]):
302 | self._graph.add_node(node_id, **node_data)
303 |
304 | async def upsert_edge(
305 | self, source_node_id: str, target_node_id: str, edge_data: dict[str, str]
306 | ):
307 | self._graph.add_edge(source_node_id, target_node_id, **edge_data)
308 |
309 | async def delete_node(self, node_id: str):
310 | """
311 | Delete a node from the graph based on the specified node_id.
312 |
313 | :param node_id: The node_id to delete
314 | """
315 | if self._graph.has_node(node_id):
316 | self._graph.remove_node(node_id)
317 | logger.info(f"Node {node_id} deleted from the graph.")
318 | else:
319 | logger.warning(f"Node {node_id} not found in the graph for deletion.")
320 |
321 | async def embed_nodes(self, algorithm: str) -> tuple[np.ndarray, list[str]]:
322 | if algorithm not in self._node_embed_algorithms:
323 | raise ValueError(f"Node embedding algorithm {algorithm} not supported")
324 | return await self._node_embed_algorithms[algorithm]()
325 |
326 | # @TODO: NOT USED
327 | async def _node2vec_embed(self):
328 | from graspologic import embed
329 |
330 | embeddings, nodes = embed.node2vec_embed(
331 | self._graph,
332 | **self.global_config["node2vec_params"],
333 | )
334 |
335 | nodes_ids = [self._graph.nodes[node_id]["id"] for node_id in nodes]
336 | return embeddings, nodes_ids
337 |
338 | async def edges(self):
339 | return self._graph.edges()
340 | async def nodes(self):
341 | return self._graph.nodes()
342 |
--------------------------------------------------------------------------------
/PathRAG/utils.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import html
3 | import io
4 | import csv
5 | import json
6 | import logging
7 | import os
8 | import re
9 | from dataclasses import dataclass
10 | from functools import wraps
11 | from hashlib import md5
12 | from typing import Any, Union, List, Optional
13 | import xml.etree.ElementTree as ET
14 |
15 | import numpy as np
16 | import tiktoken
17 |
18 | from PathRAG.prompt import PROMPTS
19 |
20 |
21 | class UnlimitedSemaphore:
22 | """A no-op async context manager used in place of a semaphore when no concurrency limit is enforced."""
23 |
24 | async def __aenter__(self):
25 | pass
26 |
27 | async def __aexit__(self, exc_type, exc, tb):
28 | pass
29 |
30 |
31 | ENCODER = None
32 |
33 | logger = logging.getLogger("PathRAG")
34 |
35 |
36 | def set_logger(log_file: str):
37 | logger.setLevel(logging.DEBUG)
38 |
39 | file_handler = logging.FileHandler(log_file)
40 | file_handler.setLevel(logging.DEBUG)
41 |
42 | formatter = logging.Formatter(
43 | "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
44 | )
45 | file_handler.setFormatter(formatter)
46 |
47 | if not logger.handlers:
48 | logger.addHandler(file_handler)
49 |
50 |
51 | @dataclass
52 | class EmbeddingFunc:
53 | embedding_dim: int
54 | max_token_size: int
55 | func: callable
56 | concurrent_limit: int = 16
57 |
58 | def __post_init__(self):
59 | if self.concurrent_limit != 0:
60 | self._semaphore = asyncio.Semaphore(self.concurrent_limit)
61 | else:
62 | self._semaphore = UnlimitedSemaphore()
63 |
64 | async def __call__(self, *args, **kwargs) -> np.ndarray:
65 | async with self._semaphore:
66 | return await self.func(*args, **kwargs)
67 |
68 |
69 | def locate_json_string_body_from_string(content: str) -> Union[str, None]:
70 | """Extract the first {...} JSON-like block from a string, stripping newlines and normalizing single quotes to double quotes; returns None if nothing is found."""
71 | try:
72 | maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
73 | if maybe_json_str is not None:
74 | maybe_json_str = maybe_json_str.group(0)
75 | maybe_json_str = maybe_json_str.replace("\\n", "")
76 | maybe_json_str = maybe_json_str.replace("\n", "")
77 | maybe_json_str = maybe_json_str.replace("'", '"')
78 |
79 | return maybe_json_str
80 | except Exception:
81 | pass
82 |
83 |
84 | return None
85 |
86 |
87 | def convert_response_to_json(response: str) -> dict:
88 | json_str = locate_json_string_body_from_string(response)
89 | assert json_str is not None, f"Unable to parse JSON from response: {response}"
90 | try:
91 | data = json.loads(json_str)
92 | return data
93 | except json.JSONDecodeError as e:
94 | logger.error(f"Failed to parse JSON: {json_str}")
95 | raise e from None
96 |
97 |
98 | def compute_args_hash(*args):
99 | return md5(str(args).encode()).hexdigest()
100 |
101 |
102 | def compute_mdhash_id(content, prefix: str = ""):
103 | return prefix + md5(content.encode()).hexdigest()
104 |
105 |
106 | def limit_async_func_call(max_size: int, waitting_time: float = 0.0001):
107 | """Decorator factory that limits the number of concurrent invocations of an async function to max_size."""
108 |
109 | def final_decro(func):
110 |
111 | __current_size = 0
112 |
113 | @wraps(func)
114 | async def wait_func(*args, **kwargs):
115 | nonlocal __current_size
116 | while __current_size >= max_size:
117 | await asyncio.sleep(waitting_time)
118 | __current_size += 1
119 | result = await func(*args, **kwargs)
120 | __current_size -= 1
121 | return result
122 |
123 | return wait_func
124 |
125 | return final_decro
126 |
127 |
128 | def wrap_embedding_func_with_attrs(**kwargs):
129 | """Decorator that wraps a raw embedding function in an EmbeddingFunc, attaching attributes such as embedding_dim and max_token_size."""
130 |
131 | def final_decro(func) -> EmbeddingFunc:
132 | new_func = EmbeddingFunc(**kwargs, func=func)
133 | return new_func
134 |
135 | return final_decro
136 |
137 |
138 | def load_json(file_name):
139 | if not os.path.exists(file_name):
140 | return None
141 | with open(file_name, encoding="utf-8") as f:
142 | return json.load(f)
143 |
144 |
145 | def write_json(json_obj, file_name):
146 | with open(file_name, "w", encoding="utf-8") as f:
147 | json.dump(json_obj, f, indent=2, ensure_ascii=False)
148 |
149 |
150 | def encode_string_by_tiktoken(content: str, model_name: str = "gpt-4o-mini"):
151 | global ENCODER
152 | if ENCODER is None:
153 | ENCODER = tiktoken.encoding_for_model(model_name)
154 | tokens = ENCODER.encode(content)
155 | return tokens
156 |
157 |
158 | def decode_tokens_by_tiktoken(tokens: list[int], model_name: str = "gpt-4o-mini"):
159 | global ENCODER
160 | if ENCODER is None:
161 | ENCODER = tiktoken.encoding_for_model(model_name)
162 | content = ENCODER.decode(tokens)
163 | return content
164 |
165 |
166 | def pack_user_ass_to_openai_messages(*args: str):
167 | roles = ["user", "assistant"]
168 | return [
169 | {"role": roles[i % 2], "content": content} for i, content in enumerate(args)
170 | ]
171 |
172 |
173 | def split_string_by_multi_markers(content: str, markers: list[str]) -> list[str]:
174 |
175 | if not markers:
176 | return [content]
177 | results = re.split("|".join(re.escape(marker) for marker in markers), content)
178 | return [r.strip() for r in results if r.strip()]
179 |
180 |
181 |
182 | def clean_str(input: Any) -> str:
183 | """Unescape HTML entities, strip surrounding whitespace, and remove control characters; non-string input is returned unchanged."""
184 |
185 | if not isinstance(input, str):
186 | return input
187 |
188 | result = html.unescape(input.strip())
189 |
190 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result)
191 |
192 |
193 | def is_float_regex(value):
194 | return bool(re.match(r"^[-+]?[0-9]*\.?[0-9]+$", value))
195 |
196 |
197 | def truncate_list_by_token_size(list_data: list, key: callable, max_token_size: int):
198 | """Truncate list_data so that the cumulative token count of key(item) does not exceed max_token_size."""
199 | if max_token_size <= 0:
200 | return []
201 | tokens = 0
202 | for i, data in enumerate(list_data):
203 | tokens += len(encode_string_by_tiktoken(key(data)))
204 | if tokens > max_token_size:
205 | return list_data[:i]
206 | return list_data
207 |
208 |
209 | def list_of_list_to_csv(data: List[List[str]]) -> str:
210 | output = io.StringIO()
211 | writer = csv.writer(output)
212 | writer.writerows(data)
213 | return output.getvalue()
214 |
215 |
216 | def csv_string_to_list(csv_string: str) -> List[List[str]]:
217 | output = io.StringIO(csv_string)
218 | reader = csv.reader(output)
219 | return [row for row in reader]
220 |
221 |
222 | def save_data_to_file(data, file_name):
223 | with open(file_name, "w", encoding="utf-8") as f:
224 | json.dump(data, f, ensure_ascii=False, indent=4)
225 |
226 |
227 | def xml_to_json(xml_file):
228 | try:
229 | tree = ET.parse(xml_file)
230 | root = tree.getroot()
231 |
232 | print(f"Root element: {root.tag}")
233 | print(f"Root attributes: {root.attrib}")
234 |
235 | data = {"nodes": [], "edges": []}
236 | namespace = {"": "http://graphml.graphdrawing.org/xmlns"}
237 |
238 | for node in root.findall(".//node", namespace):
239 | node_data = {
240 | "id": node.get("id").strip('"'),
241 | "entity_type": node.find("./data[@key='d0']", namespace).text.strip('"')
242 | if node.find("./data[@key='d0']", namespace) is not None
243 | else "",
244 | "description": node.find("./data[@key='d1']", namespace).text
245 | if node.find("./data[@key='d1']", namespace) is not None
246 | else "",
247 | "source_id": node.find("./data[@key='d2']", namespace).text
248 | if node.find("./data[@key='d2']", namespace) is not None
249 | else "",
250 | }
251 | data["nodes"].append(node_data)
252 |
253 | for edge in root.findall(".//edge", namespace):
254 | edge_data = {
255 | "source": edge.get("source").strip('"'),
256 | "target": edge.get("target").strip('"'),
257 | "weight": float(edge.find("./data[@key='d3']", namespace).text)
258 | if edge.find("./data[@key='d3']", namespace) is not None
259 | else 0.0,
260 | "description": edge.find("./data[@key='d4']", namespace).text
261 | if edge.find("./data[@key='d4']", namespace) is not None
262 | else "",
263 | "keywords": edge.find("./data[@key='d5']", namespace).text
264 | if edge.find("./data[@key='d5']", namespace) is not None
265 | else "",
266 | "source_id": edge.find("./data[@key='d6']", namespace).text
267 | if edge.find("./data[@key='d6']", namespace) is not None
268 | else "",
269 | }
270 | data["edges"].append(edge_data)
271 |
272 | print(f"Found {len(data['nodes'])} nodes and {len(data['edges'])} edges")
273 |
274 | return data
275 | except ET.ParseError as e:
276 | print(f"Error parsing XML file: {e}")
277 | return None
278 | except Exception as e:
279 | print(f"An error occurred: {e}")
280 | return None
281 |
282 |
283 | def process_combine_contexts(hl, ll):
284 | header = None
285 | list_hl = csv_string_to_list(hl.strip())
286 | list_ll = csv_string_to_list(ll.strip())
287 |
288 | if list_hl:
289 | header = list_hl[0]
290 | list_hl = list_hl[1:]
291 | if list_ll:
292 | header = list_ll[0]
293 | list_ll = list_ll[1:]
294 | if header is None:
295 | return ""
296 |
297 | if list_hl:
298 | list_hl = [",".join(item[1:]) for item in list_hl if item]
299 | if list_ll:
300 | list_ll = [",".join(item[1:]) for item in list_ll if item]
301 |
302 | combined_sources = []
303 | seen = set()
304 |
305 | for item in list_hl + list_ll:
306 | if item and item not in seen:
307 | combined_sources.append(item)
308 | seen.add(item)
309 |
310 | combined_sources_result = [",\t".join(header)]
311 |
312 | for i, item in enumerate(combined_sources, start=1):
313 | combined_sources_result.append(f"{i},\t{item}")
314 |
315 | combined_sources_result = "\n".join(combined_sources_result)
316 |
317 | return combined_sources_result
318 |
319 |
320 | async def get_best_cached_response(
321 | hashing_kv,
322 | current_embedding,
323 | similarity_threshold=0.95,
324 | mode="default",
325 | use_llm_check=False,
326 | llm_func=None,
327 | original_prompt=None,
328 | ) -> Union[str, None]:
329 | """Return the cached response most similar to current_embedding if it clears similarity_threshold (optionally re-verified via llm_func); otherwise None."""
330 | mode_cache = await hashing_kv.get_by_id(mode)
331 | if not mode_cache:
332 | return None
333 |
334 | best_similarity = -1
335 | best_response = None
336 | best_prompt = None
337 | best_cache_id = None
338 |
339 |
340 | for cache_id, cache_data in mode_cache.items():
341 | if cache_data["embedding"] is None:
342 | continue
343 |
344 |
345 | cached_quantized = np.frombuffer(
346 | bytes.fromhex(cache_data["embedding"]), dtype=np.uint8
347 | ).reshape(cache_data["embedding_shape"])
348 | cached_embedding = dequantize_embedding(
349 | cached_quantized,
350 | cache_data["embedding_min"],
351 | cache_data["embedding_max"],
352 | )
353 |
354 | similarity = cosine_similarity(current_embedding, cached_embedding)
355 | if similarity > best_similarity:
356 | best_similarity = similarity
357 | best_response = cache_data["return"]
358 | best_prompt = cache_data["original_prompt"]
359 | best_cache_id = cache_id
360 |
361 | if best_similarity > similarity_threshold:
362 |
363 | if use_llm_check and llm_func and original_prompt and best_prompt:
364 | compare_prompt = PROMPTS["similarity_check"].format(
365 | original_prompt=original_prompt, cached_prompt=best_prompt
366 | )
367 |
368 | try:
369 | llm_result = await llm_func(compare_prompt)
370 | llm_result = llm_result.strip()
371 | llm_similarity = float(llm_result)
372 |
373 |
374 | best_similarity = llm_similarity
375 | if best_similarity < similarity_threshold:
376 | log_data = {
377 | "event": "llm_check_cache_rejected",
378 | "original_question": original_prompt[:100] + "..."
379 | if len(original_prompt) > 100
380 | else original_prompt,
381 | "cached_question": best_prompt[:100] + "..."
382 | if len(best_prompt) > 100
383 | else best_prompt,
384 | "similarity_score": round(best_similarity, 4),
385 | "threshold": similarity_threshold,
386 | }
387 | logger.info(json.dumps(log_data, ensure_ascii=False))
388 | return None
389 | except Exception as e:
390 | logger.warning(f"LLM similarity check failed: {e}")
391 | return None
392 |
393 | prompt_display = (
394 | best_prompt[:50] + "..." if len(best_prompt) > 50 else best_prompt
395 | )
396 | log_data = {
397 | "event": "cache_hit",
398 | "mode": mode,
399 | "similarity": round(best_similarity, 4),
400 | "cache_id": best_cache_id,
401 | "original_prompt": prompt_display,
402 | }
403 | logger.info(json.dumps(log_data, ensure_ascii=False))
404 | return best_response
405 | return None
406 |
407 |
408 | def cosine_similarity(v1, v2):
409 |
410 | dot_product = np.dot(v1, v2)
411 | norm1 = np.linalg.norm(v1)
412 | norm2 = np.linalg.norm(v2)
413 | return dot_product / (norm1 * norm2)
414 |
415 |
416 | def quantize_embedding(embedding: np.ndarray, bits=8) -> tuple:
417 | """Min-max quantize a float embedding to uint8, mapping [min_val, max_val] onto [0, 2**bits - 1]; returns (quantized, min_val, max_val)."""
418 |
419 | min_val = embedding.min()
420 | max_val = embedding.max()
421 |
422 |
423 | scale = (2**bits - 1) / (max_val - min_val)
424 | quantized = np.round((embedding - min_val) * scale).astype(np.uint8)
425 |
426 | return quantized, min_val, max_val
427 |
428 |
429 | def dequantize_embedding(
430 | quantized: np.ndarray, min_val: float, max_val: float, bits=8
431 | ) -> np.ndarray:
432 | """Invert quantize_embedding: map uint8 values back to float32 values in [min_val, max_val]."""
433 | scale = (max_val - min_val) / (2**bits - 1)
434 | return (quantized * scale + min_val).astype(np.float32)
435 |
436 |
437 | async def handle_cache(hashing_kv, args_hash, prompt, mode="default"):
438 | """Look up a cached LLM response; returns (cached_response_or_None, quantized_embedding, min_val, max_val)."""
439 | if hashing_kv is None:
440 | return None, None, None, None
441 |
442 |
443 | if mode == "naive":
444 | mode_cache = await hashing_kv.get_by_id(mode) or {}
445 | if args_hash in mode_cache:
446 | return mode_cache[args_hash]["return"], None, None, None
447 | return None, None, None, None
448 |
449 |
450 | embedding_cache_config = hashing_kv.global_config.get(
451 | "embedding_cache_config",
452 | {"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False},
453 | )
454 | is_embedding_cache_enabled = embedding_cache_config["enabled"]
455 | use_llm_check = embedding_cache_config.get("use_llm_check", False)
456 |
457 | quantized = min_val = max_val = None
458 | if is_embedding_cache_enabled:
459 |
460 | embedding_model_func = hashing_kv.global_config["embedding_func"]["func"]
461 | llm_model_func = hashing_kv.global_config.get("llm_model_func")
462 |
463 | current_embedding = await embedding_model_func([prompt])
464 | quantized, min_val, max_val = quantize_embedding(current_embedding[0])
465 | best_cached_response = await get_best_cached_response(
466 | hashing_kv,
467 | current_embedding[0],
468 | similarity_threshold=embedding_cache_config["similarity_threshold"],
469 | mode=mode,
470 | use_llm_check=use_llm_check,
471 | llm_func=llm_model_func if use_llm_check else None,
472 | original_prompt=prompt if use_llm_check else None,
473 | )
474 | if best_cached_response is not None:
475 | return best_cached_response, None, None, None
476 | else:
477 |
478 | mode_cache = await hashing_kv.get_by_id(mode) or {}
479 | if args_hash in mode_cache:
480 | return mode_cache[args_hash]["return"], None, None, None
481 |
482 | return None, quantized, min_val, max_val
483 |
484 |
485 | @dataclass
486 | class CacheData:
487 | args_hash: str
488 | content: str
489 | prompt: str
490 | quantized: Optional[np.ndarray] = None
491 | min_val: Optional[float] = None
492 | max_val: Optional[float] = None
493 | mode: str = "default"
494 |
495 |
496 | async def save_to_cache(hashing_kv, cache_data: CacheData):
497 | if hashing_kv is None or hasattr(cache_data.content, "__aiter__"):
498 | return
499 |
500 | mode_cache = await hashing_kv.get_by_id(cache_data.mode) or {}
501 |
502 | mode_cache[cache_data.args_hash] = {
503 | "return": cache_data.content,
504 | "embedding": cache_data.quantized.tobytes().hex()
505 | if cache_data.quantized is not None
506 | else None,
507 | "embedding_shape": cache_data.quantized.shape
508 | if cache_data.quantized is not None
509 | else None,
510 | "embedding_min": cache_data.min_val,
511 | "embedding_max": cache_data.max_val,
512 | "original_prompt": cache_data.prompt,
513 | }
514 |
515 | await hashing_kv.upsert({cache_data.mode: mode_cache})
516 |
517 |
518 | def safe_unicode_decode(content):
519 | unicode_escape_pattern = re.compile(r"\\u([0-9a-fA-F]{4})")
520 | def replace_unicode_escape(match):
521 | return chr(int(match.group(1), 16))
522 |
523 | decoded_content = unicode_escape_pattern.sub(
524 | replace_unicode_escape, content.decode("utf-8")
525 | )
526 |
527 | return decoded_content
528 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | The code for the paper **"PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths"**.
2 | ## Install
3 | ```bash
4 | cd PathRAG
5 | pip install -e .
6 | ```
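
If you prefer not to use an editable install, installing the listed dependencies directly with `pip install -r requirements.txt` and running the scripts from the repository root should also work (an assumption based on the repository layout, not an officially documented path).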
7 | ## Quick Start
8 | * You can quickly try out this project with the `v1_test.py` file.
9 | * If you use OpenAI models, set your OpenAI API key (`api_key="sk-..."`) in the `v1_test.py` and `llm.py` files.
10 | * Prepare your retrieval document, e.g. `text.txt`.
11 | * Use the following Python snippet in the `v1_test.py` file to initialize PathRAG and perform queries.
12 |
13 | ```python
14 | import os
15 | from PathRAG import PathRAG, QueryParam
16 | from PathRAG.llm import gpt_4o_mini_complete
17 |
18 | WORKING_DIR = "./your_working_dir"
19 | api_key="your_api_key"
20 | os.environ["OPENAI_API_KEY"] = api_key
21 | base_url="https://api.openai.com/v1"
22 | os.environ["OPENAI_API_BASE"]=base_url
23 |
24 |
25 | if not os.path.exists(WORKING_DIR):
26 | os.mkdir(WORKING_DIR)
27 |
28 | rag = PathRAG(
29 | working_dir=WORKING_DIR,
30 | llm_model_func=gpt_4o_mini_complete,
31 | )
32 |
33 | data_file="./text.txt"
34 | question="your_question"
35 | with open(data_file) as f:
36 | rag.insert(f.read())
37 |
38 | print(rag.query(question, param=QueryParam(mode="hybrid")))
39 | ```
40 | ## Parameter modification
41 | You can adjust the relevant parameters in the `base.py` and `operate.py` files; a hypothetical sketch of overriding the query parameters is shown below.
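
A minimal sketch of building a `QueryParam` explicitly (only `mode` appears elsewhere in this README; the commented-out fields are assumptions meant to illustrate the kind of knobs defined in `base.py`):

```python
from PathRAG import QueryParam

# "mode" is the only field used elsewhere in this repository; the commented
# fields below are hypothetical examples, verify the real names in base.py.
param = QueryParam(
    mode="hybrid",
    # response_type="Multiple Paragraphs",  # assumed field
    # top_k=40,                             # assumed field
)

# `rag` here is the instance built in the Quick Start snippet.
print(rag.query("your_question", param=param))
```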
42 |
43 | ## Batch Insert
44 | ```python
45 | import os
46 | folder_path = "your_folder_path"
47 |
48 | txt_files = [f for f in os.listdir(folder_path) if f.endswith(".txt")]
49 | for file_name in txt_files:
50 | file_path = os.path.join(folder_path, file_name)
51 | with open(file_path, "r", encoding="utf-8") as file:
52 | rag.insert(file.read())
53 | ```
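
This snippet assumes `rag` has already been initialized as in the Quick Start above.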
54 |
55 | ## Cite
56 | Please cite our paper if you use this code in your own work:
57 | ```bibtex
58 | @article{chen2025pathrag,
59 | title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths},
60 | author={Chen, Boyu and Guo, Zirui and Yang, Zidan and Chen, Yuluo and Chen, Junze and Liu, Zhenghao and Shi, Chuan and Yang, Cheng},
61 | journal={arXiv preprint arXiv:2502.14902},
62 | year={2025}
63 | }
64 | ```
65 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | aioboto3
3 | aiohttp
4 |
5 | # database packages
6 | graspologic
7 | hnswlib
8 | nano-vectordb
9 | neo4j
10 | networkx
11 | ollama
12 | openai
13 | oracledb
14 | psycopg[binary,pool]
15 | pymilvus
16 | pymongo
17 | pymysql
18 | pyvis
19 | # lmdeploy[all]
20 | sqlalchemy
21 | tenacity
22 |
23 |
24 | # LLM packages
25 | tiktoken
26 | torch
27 | transformers
28 | xxhash
29 |
--------------------------------------------------------------------------------
/v1_test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PathRAG import PathRAG, QueryParam
3 | from PathRAG.llm import gpt_4o_mini_complete
4 |
5 | WORKING_DIR = ""
6 |
7 | api_key=""
8 | os.environ["OPENAI_API_KEY"] = api_key
9 | base_url="https://api.openai.com/v1"
10 | os.environ["OPENAI_API_BASE"]=base_url
11 |
12 |
13 | if not os.path.exists(WORKING_DIR):
14 | os.mkdir(WORKING_DIR)
15 |
16 | rag = PathRAG(
17 | working_dir=WORKING_DIR,
18 | llm_model_func=gpt_4o_mini_complete,
19 | )
20 |
21 | data_file=""
22 | question=""
23 | with open(data_file) as f:
24 | rag.insert(f.read())
25 |
26 | print(rag.query(question, param=QueryParam(mode="hybrid")))
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
--------------------------------------------------------------------------------