├── .python-version
├── examples
│   ├── .gitignore
│   └── simple-app
│       ├── app
│       │   ├── py.typed
│       │   ├── __init__.py
│       │   ├── main.py
│       │   └── indexer.py
│       ├── .env.example
│       └── README.md
├── src
│   └── langchain_graphrag
│       ├── py.typed
│       ├── query
│       │   ├── __init__.py
│       │   ├── global_search
│       │   │   ├── __init__.py
│       │   │   ├── community_report.py
│       │   │   ├── key_points_aggregator
│       │   │   │   ├── __init__.py
│       │   │   │   ├── aggregator.py
│       │   │   │   ├── prompt_builder.py
│       │   │   │   ├── context_builder.py
│       │   │   │   └── _system_prompt.py
│       │   │   ├── key_points_generator
│       │   │   │   ├── utils.py
│       │   │   │   ├── __init__.py
│       │   │   │   ├── _output_parser.py
│       │   │   │   ├── generator.py
│       │   │   │   ├── prompt_builder.py
│       │   │   │   ├── context_builder.py
│       │   │   │   └── _system_prompt.py
│       │   │   ├── community_weight_calculator.py
│       │   │   └── search.py
│       │   └── local_search
│       │       ├── __init__.py
│       │       ├── context_builders
│       │       │   ├── __init__.py
│       │       │   ├── text_units.py
│       │       │   ├── communities_reports.py
│       │       │   ├── entities.py
│       │       │   ├── relationships.py
│       │       │   └── context.py
│       │       ├── context_selectors
│       │       │   ├── __init__.py
│       │       │   ├── entities.py
│       │       │   ├── communities_reports.py
│       │       │   ├── text_units.py
│       │       │   ├── context.py
│       │       │   └── relationships.py
│       │       ├── retriever.py
│       │       ├── search.py
│       │       ├── prompt_builder.py
│       │       └── _system_prompt.py
│       ├── __init__.py
│       ├── types
│       │   ├── graphs
│       │   │   ├── __init__.py
│       │   │   ├── embedding.py
│       │   │   └── community.py
│       │   ├── tokens.py
│       │   ├── __init__.py
│       │   └── prompts.py
│       ├── utils
│       │   ├── __init__.py
│       │   ├── uuid.py
│       │   └── token_counter.py
│       └── indexing
│           ├── graph_clustering
│           │   ├── __init__.py
│           │   └── leiden_community_detector.py
│           ├── embedding_generation
│           │   ├── __init__.py
│           │   └── graph
│           │       ├── __init__.py
│           │       └── node2vec.py
│           ├── __init__.py
│           ├── graph_generation
│           │   ├── entity_relationship_extraction
│           │   │   ├── __init__.py
│           │   │   ├── extractor.py
│           │   │   ├── prompt_builder.py
│           │   │   ├── _output_parser.py
│           │   │   └── _default_prompts.py
│           │   ├── entity_relationship_summarization
│           │   │   ├── __init__.py
│           │   │   ├── _default_prompts.py
│           │   │   ├── prompt_builder.py
│           │   │   └── summarizer.py
│           │   ├── __init__.py
│           │   ├── generator.py
│           │   └── graphs_merger.py
│           ├── report_generation
│           │   ├── __init__.py
│           │   ├── _output_parser.py
│           │   ├── writer.py
│           │   ├── generator.py
│           │   ├── prompt_builder.py
│           │   └── utils.py
│           ├── artifacts_generation
│           │   ├── __init__.py
│           │   ├── relationships.py
│           │   ├── reports.py
│           │   ├── text_units.py
│           │   └── entities.py
│           ├── text_unit_extractor.py
│           ├── _graph_utils.py
│           ├── simple_indexer.py
│           └── artifacts.py
├── docs
│   ├── guides
│   │   ├── graph_extraction
│   │   │   ├── .gitignore
│   │   │   ├── sample-data
│   │   │   │   └── base_text_units.parquet
│   │   │   └── index.md
│   │   └── text_units_extraction.ipynb
│   ├── index.md
│   └── architecture
│       └── overview.md
├── tests
│   ├── __init__.py
│   ├── test_summarizer.py
│   └── test_graphs_merger.py
├── mypy.ini
├── requirements-docs.txt
├── .readthedocs.yaml
├── .devcontainer
│   ├── Dockerfile
│   └── devcontainer.json
├── .pre-commit-config.yaml
├── scripts
│   ├── post-create.sh
│   └── gen_ref_pages.py
├── .github
│   └── workflows
│       └── publish.yml
├── mkdocs.yml
├── ruff.toml
├── .vscode
│   └── launch.json
├── .gitignore
└── pyproject.toml

/.python-version:
--------------------------------------------------------------------------------
1 | 3.10
2 |
--------------------------------------------------------------------------------
/examples/.gitignore:
--------------------------------------------------------------------------------
1 | notebooks
--------------------------------------------------------------------------------
/examples/simple-app/app/py.typed:
--------------------------------------------------------------------------------
1 |
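A quick orientation before the sources: the text-unit extraction step that starts the indexing pipeline can be exercised on its own. The following is a minimal editorial sketch, not a file from this repository; the splitter choice and chunk sizes are illustrative assumptions (TextUnitExtractor itself is defined later in this dump).

# Illustrative only: chunk documents into text units for indexing.
from langchain_core.documents import Document
from langchain_text_splitters import TokenTextSplitter

from langchain_graphrag.indexing import TextUnitExtractor

splitter = TokenTextSplitter(chunk_size=1200, chunk_overlap=100)  # assumed sizes
extractor = TextUnitExtractor(text_splitter=splitter)

# Returns a DataFrame with columns: id, document_id, text_unit.
df_text_units = extractor.run([Document(page_content="A long document ...")])
print(df_text_units.shape)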
-------------------------------------------------------------------------------- /src/langchain_graphrag/py.typed: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /docs/guides/graph_extraction/.gitignore: -------------------------------------------------------------------------------- 1 | *.db -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | """Unit Tests.""" 2 | -------------------------------------------------------------------------------- /examples/simple-app/app/__init__.py: -------------------------------------------------------------------------------- 1 | """Simple app.""" 2 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | [mypy] 2 | enable_incomplete_feature=Unpack 3 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/__init__.py: -------------------------------------------------------------------------------- 1 | """Query module.""" 2 | -------------------------------------------------------------------------------- /src/langchain_graphrag/__init__.py: -------------------------------------------------------------------------------- 1 | """GraphRAG module for LangChain.""" 2 | -------------------------------------------------------------------------------- /src/langchain_graphrag/types/graphs/__init__.py: -------------------------------------------------------------------------------- 1 | """Types and protocols specific to graphs.""" 2 | -------------------------------------------------------------------------------- /src/langchain_graphrag/types/tokens.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | 4 | class TokenCounter(Protocol): 5 | def count_tokens(self, text: str) -> int: ... 
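# Editorial sketch, not part of this file: any class with a matching
# count_tokens method satisfies the protocol structurally. The implementation
# shipped with the package is TiktokenCounter (langchain_graphrag.utils);
# this naive whitespace counter is a hypothetical stand-in.
class WhitespaceTokenCounter:
    def count_tokens(self, text: str) -> int:
        return len(text.split())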
6 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/__init__.py: -------------------------------------------------------------------------------- 1 | """Global Query Search Module.""" 2 | 3 | from .search import GlobalSearch 4 | 5 | __all__ = ["GlobalSearch"] 6 | -------------------------------------------------------------------------------- /requirements-docs.txt: -------------------------------------------------------------------------------- 1 | mkdocs 2 | mkdocstrings[python] 3 | markdown-include 4 | mkdocs-gen-files 5 | mkdocs-literate-nav 6 | mkdocs-section-index 7 | mkdocs-material 8 | mkdocs-jupyter -------------------------------------------------------------------------------- /docs/guides/graph_extraction/sample-data/base_text_units.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ksachdeva/langchain-graphrag/HEAD/docs/guides/graph_extraction/sample-data/base_text_units.parquet -------------------------------------------------------------------------------- /src/langchain_graphrag/types/__init__.py: -------------------------------------------------------------------------------- 1 | """Misc types and protocols.""" 2 | 3 | from .prompts import PromptBuilder 4 | from .tokens import TokenCounter 5 | 6 | __all__ = ["PromptBuilder", "TokenCounter"] 7 | -------------------------------------------------------------------------------- /src/langchain_graphrag/utils/__init__.py: -------------------------------------------------------------------------------- 1 | """Misc utility functions for the GraphRAG project.""" 2 | 3 | from .token_counter import TiktokenCounter 4 | from .uuid import gen_uuid 5 | 6 | __all__ = ["TiktokenCounter", "gen_uuid"] 7 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_clustering/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph clustering module.""" 2 | 3 | from .leiden_community_detector import HierarchicalLeidenCommunityDetector 4 | 5 | __all__ = ["HierarchicalLeidenCommunityDetector"] 6 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/embedding_generation/__init__.py: -------------------------------------------------------------------------------- 1 | """Embedding generation module for indexing.""" 2 | 3 | from .graph import Node2VectorGraphEmbeddingGenerator 4 | 5 | __all__ = [ 6 | "Node2VectorGraphEmbeddingGenerator", 7 | ] 8 | -------------------------------------------------------------------------------- /src/langchain_graphrag/types/graphs/embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Protocol 2 | 3 | import networkx as nx 4 | import numpy as np 5 | 6 | 7 | class GraphEmbeddingGenerator(Protocol): 8 | def run(self, graph: nx.Graph) -> dict[str, np.ndarray]: ... 
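# Editorial sketch, not part of this file: a hypothetical implementation that
# embeds each node by its degree, just to show the protocol's shape. The real
# implementation in this repository is Node2VectorGraphEmbeddingGenerator
# (indexing/embedding_generation/graph/node2vec.py, later in this dump).
class DegreeGraphEmbeddingGenerator:
    def run(self, graph: nx.Graph) -> dict[str, np.ndarray]:
        return {node: np.array([float(degree)]) for node, degree in graph.degree()}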
9 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/embedding_generation/graph/__init__.py: -------------------------------------------------------------------------------- 1 | """Graph Embedding generation module for indexing.""" 2 | 3 | from .node2vec import Node2VectorGraphEmbeddingGenerator 4 | 5 | __all__ = [ 6 | "Node2VectorGraphEmbeddingGenerator", 7 | ] 8 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/community_report.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | 4 | @dataclass 5 | class CommunityReport: 6 | id: str 7 | title: str 8 | summary: str 9 | rank: float 10 | weight: float 11 | content: str 12 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | --- 2 | version: 2 3 | 4 | build: 5 | os: ubuntu-22.04 6 | tools: 7 | python: "3.10" 8 | 9 | python: 10 | install: 11 | - requirements: requirements-docs.txt 12 | - method: pip 13 | path: . 14 | 15 | mkdocs: 16 | configuration: mkdocs.yml 17 | -------------------------------------------------------------------------------- /src/langchain_graphrag/utils/uuid.py: -------------------------------------------------------------------------------- 1 | import uuid 2 | from random import Random, getrandbits 3 | 4 | 5 | def gen_uuid(rd: Random | None = None): 6 | """Generate a random UUID v4.""" 7 | return uuid.UUID( 8 | int=rd.getrandbits(128) if rd is not None else getrandbits(128), version=4 9 | ).hex 10 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/__init__.py: -------------------------------------------------------------------------------- 1 | """Indexing module.""" 2 | 3 | from .artifacts import IndexerArtifacts 4 | from .simple_indexer import SimpleIndexer 5 | from .text_unit_extractor import TextUnitExtractor 6 | 7 | __all__ = [ 8 | "SimpleIndexer", 9 | "IndexerArtifacts", 10 | "TextUnitExtractor", 11 | ] 12 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/__init__.py: -------------------------------------------------------------------------------- 1 | """Local Search module.""" 2 | 3 | from .prompt_builder import LocalSearchPromptBuilder 4 | from .retriever import LocalSearchRetriever 5 | from .search import LocalSearch 6 | 7 | __all__ = [ 8 | "LocalSearch", 9 | "LocalSearchPromptBuilder", 10 | "LocalSearchRetriever", 11 | ] 12 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | """Entity Relationship Extraction module.""" 2 | 3 | from .extractor import EntityRelationshipExtractor 4 | from .prompt_builder import EntityExtractionPromptBuilder 5 | 6 | __all__ = [ 7 | "EntityRelationshipExtractor", 8 | "EntityExtractionPromptBuilder", 9 | ] 10 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_summarization/__init__.py: -------------------------------------------------------------------------------- 1 | """Entity Relationship 
Description Summarization Module.""" 2 | 3 | from .prompt_builder import SummarizeDescriptionPromptBuilder 4 | from .summarizer import EntityRelationshipDescriptionSummarizer 5 | 6 | __all__ = [ 7 | "SummarizeDescriptionPromptBuilder", 8 | "EntityRelationshipDescriptionSummarizer", 9 | ] 10 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/report_generation/__init__.py: -------------------------------------------------------------------------------- 1 | """Community Report Generation.""" 2 | 3 | from .generator import CommunityReportGenerator 4 | from .prompt_builder import CommunityReportGenerationPromptBuilder 5 | from .writer import CommunityReportWriter 6 | 7 | __all__ = [ 8 | "CommunityReportGenerator", 9 | "CommunityReportGenerationPromptBuilder", 10 | "CommunityReportWriter", 11 | ] 12 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_aggregator/__init__.py: -------------------------------------------------------------------------------- 1 | """KeyPointsAggregator module.""" 2 | 3 | from .aggregator import KeyPointsAggregator 4 | from .context_builder import KeyPointsContextBuilder 5 | from .prompt_builder import KeyPointsAggregatorPromptBuilder 6 | 7 | __all__ = [ 8 | "KeyPointsAggregatorPromptBuilder", 9 | "KeyPointsContextBuilder", 10 | "KeyPointsAggregator", 11 | ] 12 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_generator/utils.py: -------------------------------------------------------------------------------- 1 | from pydantic import BaseModel, Field 2 | 3 | 4 | class KeyPointInfo(BaseModel): 5 | description: str = Field(description="The description of the key point") 6 | score: float = Field(description="The score of the key point") 7 | 8 | 9 | class KeyPointsResult(BaseModel): 10 | points: list[KeyPointInfo] = Field(description="the points") 11 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_generator/__init__.py: -------------------------------------------------------------------------------- 1 | """Key Points generator module.""" 2 | 3 | from .context_builder import CommunityReportContextBuilder 4 | from .generator import KeyPointsGenerator 5 | from .prompt_builder import KeyPointsGeneratorPromptBuilder 6 | 7 | __all__ = [ 8 | "KeyPointsGeneratorPromptBuilder", 9 | "CommunityReportContextBuilder", 10 | "KeyPointsGenerator", 11 | ] 12 | -------------------------------------------------------------------------------- /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM mcr.microsoft.com/devcontainers/base:debian 2 | 3 | ARG USERNAME=vscode 4 | 5 | ENV DEBIAN_FRONTEND=noninteractive 6 | RUN apt-get update \ 7 | && apt-get upgrade -y \ 8 | && apt-get -y install --no-install-recommends build-essential libmagic-dev iputils-ping \ 9 | && apt-get autoremove -y \ 10 | && apt-get clean -y \ 11 | && rm -rf /var/lib/apt/lists/* 12 | 13 | ENV SHELL /bin/zsh 14 | -------------------------------------------------------------------------------- /src/langchain_graphrag/utils/token_counter.py: -------------------------------------------------------------------------------- 1 | """Counter for Tiktoken based tokens.""" 2 | 3 | import tiktoken 4 | 5 | from langchain_graphrag.types.tokens import 
TokenCounter 6 | 7 | 8 | class TiktokenCounter(TokenCounter): 9 | def __init__(self, encoding_name: str = "cl100k_base"): 10 | self.tokenizer = tiktoken.get_encoding(encoding_name) 11 | 12 | def count_tokens(self, text: str) -> int: 13 | return len(self.tokenizer.encode(text)) 14 | -------------------------------------------------------------------------------- /examples/simple-app/.env.example: -------------------------------------------------------------------------------- 1 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_CHAT_API_KEY= 2 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_CHAT_ENDPOINT= 3 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_CHAT_DEPLOYMENT= 4 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_EMBED_API_KEY= 5 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_EMBED_ENDPOINT= 6 | LANGCHAIN_GRAPHRAG_AZURE_OPENAI_EMBED_DEPLOYMENT= 7 | LANGCHAIN_GRAPHRAG_OPENAI_CHAT_API_KEY= 8 | LANGCHAIN_GRAPHRAG_OPENAI_EMBED_API_KEY= 9 | LANGCHAIN_API_KEY= 10 | OLLAMA_HOST=http://host.docker.internal:11434 11 | ANONYMIZED_TELEMETRY=False 12 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_generator/_output_parser.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | try: 4 | from langchain.output_parsers import PydanticOutputParser 5 | except ImportError: 6 | # Langchain >= 1.0.0 7 | from langchain_core.output_parsers import PydanticOutputParser 8 | 9 | from .utils import KeyPointsResult 10 | 11 | 12 | class KeyPointsOutputParser(PydanticOutputParser): 13 | def __init__(self, **kwargs: dict[str, Any]): 14 | super().__init__(pydantic_object=KeyPointsResult, **kwargs) 15 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/artifacts_generation/__init__.py: -------------------------------------------------------------------------------- 1 | """Artifacts generation module for indexing.""" 2 | 3 | from .entities import EntitiesArtifactsGenerator 4 | from .relationships import RelationshipsArtifactsGenerator 5 | from .reports import CommunitiesReportsArtifactsGenerator 6 | from .text_units import TextUnitsArtifactsGenerator 7 | 8 | __all__ = [ 9 | "EntitiesArtifactsGenerator", 10 | "RelationshipsArtifactsGenerator", 11 | "TextUnitsArtifactsGenerator", 12 | "CommunitiesReportsArtifactsGenerator", 13 | ] 14 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/report_generation/_output_parser.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | try: 4 | from langchain.output_parsers import PydanticOutputParser 5 | except ImportError: 6 | # Langchain >= 1.0.0 7 | from langchain_core.output_parsers import PydanticOutputParser 8 | 9 | from .utils import CommunityReportResult 10 | 11 | 12 | class CommunityReportOutputParser(PydanticOutputParser): 13 | def __init__(self, **kwargs: dict[str, Any]): 14 | super().__init__(pydantic_object=CommunityReportResult, **kwargs) 15 | -------------------------------------------------------------------------------- /src/langchain_graphrag/types/prompts.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Protocol 2 | 3 | from langchain_core.output_parsers.base import BaseOutputParser 4 | from langchain_core.prompts import BasePromptTemplate 5 | from typing_extensions import Unpack 6 | 7 | 8 | class PromptBuilder(Protocol): 9 | def 
build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: ... 10 | 11 | 12 | class IndexingPromptBuilder(PromptBuilder, Protocol): 13 | def prepare_chain_input( 14 | self, **kwargs: Unpack[dict[str, Any]] 15 | ) -> dict[str, str]: ... 16 | -------------------------------------------------------------------------------- /examples/simple-app/app/main.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from indexer import app as indexer_app 4 | from query import app as query_app 5 | from typer import Typer 6 | 7 | app = Typer() 8 | app.add_typer(indexer_app, name="indexer") 9 | app.add_typer(query_app, name="query") 10 | 11 | 12 | if __name__ == "__main__": 13 | logging.basicConfig(level=logging.INFO) 14 | logging.getLogger("httpx").setLevel(logging.WARNING) 15 | logging.getLogger("gensim").setLevel(logging.WARNING) 16 | logging.getLogger("langchain_graphrag").setLevel(logging.INFO) 17 | app() 18 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/__init__.py: -------------------------------------------------------------------------------- 1 | """Context builders for local search.""" 2 | 3 | from .communities_reports import CommunitiesReportsContextBuilder 4 | from .context import ContextBuilder 5 | from .entities import EntitiesContextBuilder 6 | from .relationships import RelationshipsContextBuilder 7 | from .text_units import TextUnitsContextBuilder 8 | 9 | __all__ = [ 10 | "EntitiesContextBuilder", 11 | "ContextBuilder", 12 | "RelationshipsContextBuilder", 13 | "TextUnitsContextBuilder", 14 | "CommunitiesReportsContextBuilder", 15 | ] 16 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/__init__.py: -------------------------------------------------------------------------------- 1 | """Context selectors for local search.""" 2 | 3 | from .communities_reports import CommunitiesReportsSelector 4 | from .context import ContextSelectionResult, ContextSelector 5 | from .entities import EntitiesSelector 6 | from .relationships import RelationshipsSelector 7 | from .text_units import TextUnitsSelector 8 | 9 | __all__ = [ 10 | "ContextSelector", 11 | "ContextSelectionResult", 12 | "EntitiesSelector", 13 | "TextUnitsSelector", 14 | "RelationshipsSelector", 15 | "CommunitiesReportsSelector", 16 | ] 17 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/pre-commit/pre-commit-hooks 3 | rev: v4.6.0 4 | hooks: 5 | - id: check-ast 6 | - id: check-toml 7 | - id: check-yaml 8 | - id: end-of-file-fixer 9 | exclude: ^(\.devcontainer|\.vscode).+json 10 | - id: trailing-whitespace 11 | - id: mixed-line-ending 12 | - repo: https://github.com/pycqa/isort 13 | rev: 5.13.2 14 | hooks: 15 | - id: isort 16 | - repo: https://github.com/astral-sh/ruff-pre-commit 17 | rev: v0.6.0 18 | hooks: 19 | - id: ruff 20 | - id: ruff-format 21 | -------------------------------------------------------------------------------- /scripts/post-create.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | USERNAME=vscode 4 | 5 | echo "changing zshrc theme to ys ..." 
6 | sed -i s/^ZSH_THEME=".\+"$/ZSH_THEME=\"ys\"/g ~/.zshrc
7 |
8 | echo "sym link zsh_history ..."
9 | mkdir -p /commandhistory
10 | touch /commandhistory/.zsh_history
11 | chown -R $USERNAME /commandhistory
12 |
13 | SNIPPET="export PROMPT_COMMAND='history -a' && export HISTFILE=/commandhistory/.zsh_history"
14 | echo "$SNIPPET" >> "/home/$USERNAME/.zshrc"
15 |
16 | echo "Setting up uv project..."
17 | # Create virtual environment and install dependencies
18 | uv sync
19 |
20 | echo "Setup complete!"
21 |
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/report_generation/writer.py:
--------------------------------------------------------------------------------
1 | from io import StringIO
2 |
3 | from .utils import CommunityReportResult
4 |
5 |
6 | class CommunityReportWriter:
7 |     def write(self, report: CommunityReportResult) -> str:
8 |         fp = StringIO()
9 |
10 |         try:
11 |             fp.write(f"# {report.title}\n")
12 |             fp.write(f"{report.summary}\n")
13 |
14 |             for finding in report.findings:
15 |                 fp.write(f"## {finding.summary}\n")
16 |                 fp.write(f"{finding.explanation}\n")
17 |
18 |             return fp.getvalue()
19 |         finally:
20 |             fp.close()
21 |
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/graph_generation/__init__.py:
--------------------------------------------------------------------------------
1 | """Graph generation Module."""
2 |
3 | from .entity_relationship_extraction import (
4 |     EntityExtractionPromptBuilder,
5 |     EntityRelationshipExtractor,
6 | )
7 | from .entity_relationship_summarization import (
8 |     EntityRelationshipDescriptionSummarizer,
9 |     SummarizeDescriptionPromptBuilder,
10 | )
11 | from .generator import GraphGenerator
12 | from .graphs_merger import GraphsMerger
13 |
14 | __all__ = [
15 |     "EntityRelationshipExtractor",
16 |     "EntityExtractionPromptBuilder",
17 |     "EntityRelationshipDescriptionSummarizer",
18 |     "SummarizeDescriptionPromptBuilder",
19 |     "GraphGenerator",
20 |     "GraphsMerger",
21 | ]
22 |
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/graph_generation/entity_relationship_summarization/_default_prompts.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: E501
2 |
3 | DEFAULT_PROMPT = """
4 | You are a helpful assistant responsible for generating a comprehensive summary of the data provided below.
5 | You are given one or two entities, and a list of descriptions, all related to the same entity or group of entities.
6 | Please concatenate all of these into a single, comprehensive description. Make sure to include information collected from all the descriptions.
7 | If the provided descriptions are contradictory, please resolve the contradictions and provide a single, coherent summary.
8 | Make sure it is written in third person, and include the entity names so we have the full context.
9 | 10 | ####### 11 | -Data- 12 | Entities: {entity_name} 13 | Description List: {description_list} 14 | ####### 15 | Output: 16 | 17 | """ 18 | -------------------------------------------------------------------------------- /src/langchain_graphrag/types/graphs/community.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import NewType, Protocol 3 | 4 | import networkx as nx 5 | 6 | CommunityId = NewType("CommunityId", int) 7 | CommunityLevel = NewType("CommunityLevel", int) 8 | 9 | 10 | @dataclass 11 | class CommunityNode: 12 | name: str 13 | parent_cluster: CommunityId | None 14 | is_final_cluster: bool 15 | 16 | 17 | @dataclass 18 | class Community: 19 | id: CommunityId 20 | nodes: list[CommunityNode] 21 | 22 | 23 | @dataclass 24 | class CommunityDetectionResult: 25 | communities: dict[CommunityLevel, dict[CommunityId, Community]] 26 | 27 | def communities_at_level(self, level: CommunityLevel) -> list[Community]: 28 | return list(self.communities[level].values()) 29 | 30 | 31 | class CommunityDetector(Protocol): 32 | def run(self, graph: nx.Graph) -> CommunityDetectionResult: ... 33 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/retriever.py: -------------------------------------------------------------------------------- 1 | from langchain_core.callbacks import CallbackManagerForRetrieverRun 2 | from langchain_core.documents import Document 3 | from langchain_core.retrievers import BaseRetriever 4 | 5 | from langchain_graphrag.indexing.artifacts import IndexerArtifacts 6 | 7 | from .context_builders import ContextBuilder 8 | from .context_selectors import ContextSelector 9 | 10 | 11 | class LocalSearchRetriever(BaseRetriever): 12 | context_selector: ContextSelector 13 | context_builder: ContextBuilder 14 | artifacts: IndexerArtifacts 15 | 16 | def _get_relevant_documents( 17 | self, 18 | query: str, 19 | *, 20 | run_manager: CallbackManagerForRetrieverRun, # noqa: ARG002 21 | ) -> list[Document]: 22 | context_selection_result = self.context_selector.run( 23 | query=query, 24 | artifacts=self.artifacts, 25 | ) 26 | 27 | return self.context_builder(context_selection_result) 28 | -------------------------------------------------------------------------------- /scripts/gen_ref_pages.py: -------------------------------------------------------------------------------- 1 | """Generate the code reference pages and navigation.""" 2 | 3 | from pathlib import Path 4 | 5 | import mkdocs_gen_files 6 | 7 | nav = mkdocs_gen_files.Nav() 8 | 9 | root = Path(__file__).parent.parent 10 | src = root / "src" 11 | 12 | for path in sorted(src.rglob("*.py")): 13 | if path.name.startswith("_"): 14 | continue 15 | module_path = path.relative_to(src).with_suffix("") 16 | doc_path = path.relative_to(src).with_suffix(".md") 17 | full_doc_path = Path("reference", doc_path) 18 | 19 | parts = tuple(module_path.parts) 20 | 21 | if parts[-1] == "__init__": 22 | parts = parts[:-1] 23 | doc_path = doc_path.with_name("index.md") 24 | full_doc_path = full_doc_path.with_name("index.md") 25 | elif parts[-1] == "__main__": 26 | continue 27 | 28 | nav[parts] = doc_path.as_posix() 29 | 30 | with mkdocs_gen_files.open(full_doc_path, "w") as fd: 31 | ident = ".".join(parts) 32 | fd.write(f"::: {ident}") 33 | 34 | mkdocs_gen_files.set_edit_path(full_doc_path, path.relative_to(root)) 35 | 36 | with mkdocs_gen_files.open("reference/SUMMARY.md", "w") as nav_file: 37 | 
nav_file.writelines(nav.build_literate_nav()) 38 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/community_weight_calculator.py: -------------------------------------------------------------------------------- 1 | """Compute the weight of the community.""" 2 | 3 | import pandas as pd 4 | 5 | from langchain_graphrag.types.graphs.community import CommunityId 6 | 7 | 8 | class CommunityWeightCalculator: 9 | def __init__(self, *, should_normalize: bool = True): 10 | self._should_normalize = should_normalize 11 | 12 | def __call__( 13 | self, 14 | df_entities: pd.DataFrame, 15 | df_reports: pd.DataFrame, 16 | ) -> dict[CommunityId, float]: 17 | result: dict[CommunityId, float] = {} 18 | 19 | for _, row in df_reports.iterrows(): 20 | entities = row["entities"] 21 | # get rows from entities dataframe where ids are in entities 22 | df_entities_filtered = df_entities[df_entities["id"].isin(entities)] 23 | # get the text_units from df_entities_filtered 24 | text_units = df_entities_filtered["text_unit_ids"].explode().unique() 25 | result[row["community_id"]] = len(text_units) 26 | 27 | if self._should_normalize: 28 | max_weight = max(result.values()) 29 | for community_id in result: 30 | result[community_id] = result[community_id] / max_weight 31 | 32 | return result 33 | -------------------------------------------------------------------------------- /tests/test_summarizer.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | from langchain_core.language_models import FakeListLLM 3 | 4 | from langchain_graphrag.indexing.graph_generation import ( 5 | EntityRelationshipDescriptionSummarizer, 6 | SummarizeDescriptionPromptBuilder, 7 | ) 8 | 9 | 10 | def test_summarizer(): 11 | llm = FakeListLLM(responses=["fake summary 1", "fake summary 2", "fake summary 3"]) 12 | prompt_builder = SummarizeDescriptionPromptBuilder() 13 | summarizer = EntityRelationshipDescriptionSummarizer( 14 | prompt_builder, 15 | llm, 16 | ) 17 | 18 | graph1 = nx.Graph() 19 | graph1.add_node("node1", source_id=["1"], description=[" "]) 20 | graph1.add_node( 21 | "node2", 22 | source_id=["2"], 23 | description=["description1 of node 2", "description2 of node 2"], 24 | ) 25 | graph1.add_node("node3", source_id=["3"], description=["description3"]) 26 | 27 | graph1.add_edge( 28 | "node1", 29 | "node2", 30 | source_id=["1"], 31 | description=["edge description1", "This edge has another description"], 32 | weight=2, 33 | ) 34 | 35 | graph1_updated = summarizer.invoke(graph1) 36 | 37 | print("\n") 38 | print(graph1_updated.nodes(data=True)) 39 | 40 | print("\n") 41 | print(graph1_updated.edges(data=True)) 42 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: 'Publish on pypi' 2 | on: 3 | release: 4 | types: [created] 5 | push: 6 | branches: [main] 7 | 8 | jobs: 9 | publish: 10 | if: startsWith(github.ref, 'refs/tags/') 11 | runs-on: ubuntu-latest 12 | environment: 13 | name: pypi 14 | url: https://pypi.org/p/langchain-graphrag 15 | permissions: 16 | contents: read 17 | packages: write 18 | attestations: write 19 | id-token: write 20 | 21 | steps: 22 | 23 | - name: Checkout (GitHub) 24 | uses: actions/checkout@v3 25 | 26 | - name: Login to GitHub Container Registry 27 | uses: docker/login-action@v2 28 | with: 29 | registry: ghcr.io 30 | username: ${{ 
github.repository_owner }}
31 |           password: ${{ secrets.GITHUB_TOKEN }}
32 |
33 |       - name: Build and run dev container task
34 |         uses: devcontainers/ci@v0.3
35 |         with:
36 |           imageName: ghcr.io/ksachdeva/langchain-graphrag-devcontainer
37 |           cacheFrom: ghcr.io/ksachdeva/langchain-graphrag-devcontainer
38 |           runCmd: uv build
39 |
40 |       - name: Publish package distributions to PyPI
41 |         uses: pypa/gh-action-pypi-publish@release/v1
42 |         with:
43 |           packages-dir: dist
44 |           skip-existing: true
45 |           verbose: true
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/text_unit_extractor.py:
--------------------------------------------------------------------------------
1 | """A module to extract text units from the document."""
2 |
3 | import uuid
4 | from typing import TypedDict
5 |
6 | import pandas as pd
7 | from langchain_core.documents import Document
8 | from langchain_text_splitters import TextSplitter
9 | from tqdm import tqdm
10 |
11 |
12 | class _TextUnit(TypedDict):
13 |     id: str
14 |     document_id: str
15 |     text_unit: str
16 |
17 |
18 | class TextUnitExtractor:
19 |     def __init__(self, text_splitter: TextSplitter):
20 |         self._text_splitter = text_splitter
21 |
22 |     def run(self, documents: list[Document]) -> pd.DataFrame:
23 |         response: list[_TextUnit] = []
24 |
25 |         # TODO: Parallelize this
26 |         for document in tqdm(documents, desc="Processing documents ..."):
27 |             text_units = self._text_splitter.split_text(document.page_content)
28 |
29 |             document_id = document.id if document.id else str(uuid.uuid4())
30 |
31 |             for t in tqdm(text_units, desc="Extracting text units ..."):
32 |                 response.append(  # noqa: PERF401
33 |                     _TextUnit(
34 |                         document_id=document_id,
35 |                         id=str(uuid.uuid4()),
36 |                         text_unit=t,
37 |                     )
38 |                 )
39 |
40 |         return pd.DataFrame.from_records(response)
--------------------------------------------------------------------------------
/src/langchain_graphrag/query/local_search/search.py:
--------------------------------------------------------------------------------
1 | from langchain_core.documents import Document
2 | from langchain_core.language_models import LanguageModelLike
3 | from langchain_core.retrievers import BaseRetriever
4 | from langchain_core.runnables import Runnable, RunnablePassthrough
5 |
6 | from langchain_graphrag.types.prompts import PromptBuilder
7 |
8 |
9 | def _format_docs(documents: list[Document]) -> str:
10 |     context_data = [d.page_content for d in documents]
11 |     context_data_str: str = "\n".join(context_data)
12 |     return context_data_str
13 |
14 |
15 | class LocalSearch:
16 |     def __init__(
17 |         self,
18 |         llm: LanguageModelLike,
19 |         prompt_builder: PromptBuilder,
20 |         retriever: BaseRetriever,
21 |         *,
22 |         output_raw: bool = False,
23 |     ):
24 |         self._llm = llm
25 |         self._prompt_builder = prompt_builder
26 |         self._retriever = retriever
27 |         self._output_raw = output_raw
28 |
29 |     def __call__(self) -> Runnable:
30 |         prompt, output_parser = self._prompt_builder.build()
31 |
32 |         base_chain = prompt | self._llm
33 |
34 |         if not self._output_raw:
35 |             base_chain = base_chain | output_parser
36 |
37 |         search_chain: Runnable = {
38 |             "context_data": self._retriever | _format_docs,
39 |             "local_query": RunnablePassthrough(),
40 |         } | base_chain
41 |
42 |         return search_chain
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/graph_generation/generator.py:
-------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Callable 3 | 4 | import networkx as nx 5 | import pandas as pd 6 | 7 | from .entity_relationship_extraction import EntityRelationshipExtractor 8 | from .entity_relationship_summarization import EntityRelationshipDescriptionSummarizer 9 | from .graphs_merger import GraphsMerger 10 | 11 | 12 | class GraphGenerator: 13 | def __init__( 14 | self, 15 | er_extractor: EntityRelationshipExtractor, 16 | graphs_merger: GraphsMerger, 17 | er_description_summarizer: EntityRelationshipDescriptionSummarizer, 18 | graph_sanitizer: Callable[[nx.Graph], nx.Graph] | None = None, 19 | ): 20 | self._er_extractor = er_extractor 21 | self._graphs_merger = graphs_merger 22 | self._graph_sanitizer = graph_sanitizer 23 | self._er_description_summarizer = er_description_summarizer 24 | 25 | def run(self, text_units: pd.DataFrame) -> tuple[nx.Graph, nx.Graph]: 26 | er_graphs = self._er_extractor.invoke(text_units) 27 | er_merged_graph = self._graphs_merger(er_graphs) 28 | er_sanitized_graph = ( 29 | self._graph_sanitizer(er_merged_graph) 30 | if self._graph_sanitizer 31 | else er_merged_graph 32 | ) 33 | er_summarized_graph = self._er_description_summarizer.invoke( 34 | deepcopy(er_sanitized_graph) 35 | ) 36 | return er_sanitized_graph, er_summarized_graph 37 | -------------------------------------------------------------------------------- /mkdocs.yml: -------------------------------------------------------------------------------- 1 | site_name: GraphRAG 2 | 3 | theme: 4 | name: material 5 | highlightjs: true 6 | features: 7 | - navigation.tabs 8 | 9 | plugins: 10 | - search 11 | - gen-files: 12 | scripts: 13 | - scripts/gen_ref_pages.py 14 | - literate-nav: 15 | nav_file: SUMMARY.md 16 | - section-index 17 | - mkdocstrings: 18 | handlers: 19 | python: 20 | paths: [src] 21 | options: 22 | show_source: false 23 | heading_level: 1 24 | docstring_style: google 25 | show_if_no_docstring: true 26 | members_order: alphabetical 27 | - mkdocs-jupyter 28 | 29 | markdown_extensions: 30 | - markdown_include.include: 31 | base_path: . 
32 |   - admonition
33 |   - pymdownx.superfences:
34 |       custom_fences:
35 |         - name: mermaid
36 |           class: mermaid
37 |           format: !!python/name:pymdownx.superfences.fence_code_format
38 |
39 | nav:
40 |   - Home: index.md
41 |   - Architecture Overview: architecture/overview.md
42 |   - Indexing Pipeline: guides/indexing_pipeline.md
43 |   - Query System: guides/query_system.md
44 |   - Data Flow & Examples: guides/data_flow_examples.md
45 |   - Advanced Examples:
46 |       - Graph Extraction Overview: guides/graph_extraction/index.md
47 |       - Entity Relationship Extraction: guides/graph_extraction/er_extraction.ipynb
48 |       - Graph Generator: guides/graph_extraction/graph_generator.ipynb
49 |
50 | repo_name: ksachdeva/langchain-graphrag
51 | repo_url: https://github.com/ksachdeva/langchain-graphrag
--------------------------------------------------------------------------------
/src/langchain_graphrag/query/global_search/key_points_generator/generator.py:
--------------------------------------------------------------------------------
1 | from langchain_core.documents import Document
2 | from langchain_core.language_models import LanguageModelLike
3 | from langchain_core.runnables import Runnable, RunnableParallel
4 |
5 | from langchain_graphrag.types.prompts import PromptBuilder
6 |
7 | from .context_builder import CommunityReportContextBuilder
8 |
9 |
10 | def _format_docs(documents: list[Document]) -> str:
11 |     context_data = [d.page_content for d in documents]
12 |     context_data_str: str = "\n".join(context_data)
13 |     return context_data_str
14 |
15 |
16 | class KeyPointsGenerator:
17 |     def __init__(
18 |         self,
19 |         llm: LanguageModelLike,
20 |         prompt_builder: PromptBuilder,
21 |         context_builder: CommunityReportContextBuilder,
22 |     ):
23 |         self._llm = llm
24 |         self._prompt_builder = prompt_builder
25 |         self._context_builder = context_builder
26 |
27 |     def __call__(self) -> Runnable:
28 |         prompt, output_parser = self._prompt_builder.build()
29 |
30 |         documents = self._context_builder()
31 |
32 |         chains: list[Runnable] = []
33 |
34 |         for d in documents:
35 |             d_context_data = _format_docs([d])
36 |             d_prompt = prompt.partial(context_data=d_context_data)
37 |             generator_chain: Runnable = d_prompt | self._llm | output_parser
38 |             chains.append(generator_chain)
39 |
40 |         analysts = [f"Analyst-{i}" for i in range(1, len(chains) + 1)]
41 |
42 |         return RunnableParallel(dict(zip(analysts, chains, strict=True)))
--------------------------------------------------------------------------------
/tests/test_graphs_merger.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 |
3 | from langchain_graphrag.indexing.graph_generation.graphs_merger import (
4 |     merge_edges,
5 |     merge_nodes,
6 | )
7 |
8 |
9 | def test_node_merge():
10 |     target_graph = nx.Graph()
11 |
12 |     graph1 = nx.Graph()
13 |     graph1.add_node("node1", text_unit_ids=["1"], description=[" "])
14 |     graph1.add_node("node2", text_unit_ids=["2"], description=["description2"])
15 |     graph1.add_node("node3", text_unit_ids=["3"], description=["description3"])
16 |
17 |     graph1.add_edge(
18 |         "node1",
19 |         "node2",
20 |         text_unit_ids=["1"],
21 |         description=["edge description1"],
22 |         weight=2,
23 |     )
24 |
25 |     graph2 = nx.Graph()
26 |     graph2.add_node(
27 |         "node1", text_unit_ids=["4"], description=["description1 from graph2"]
28 |     )
29 |     graph2.add_node(
30 |         "node2", text_unit_ids=["5"], description=["description2 from graph2"]
31 |     )
32 |     graph2.add_node(
33 |         "node4", text_unit_ids=["6"],
description=["description4 from graph2"] 34 | ) 35 | 36 | graph2.add_edge( 37 | "node1", 38 | "node2", 39 | text_unit_ids=["9"], 40 | description=["edge description1"], 41 | weight=4, 42 | ) 43 | 44 | merge_nodes(target_graph=target_graph, sub_graph=graph1) 45 | merge_edges(target_graph=target_graph, sub_graph=graph1) 46 | 47 | merge_nodes(target_graph=target_graph, sub_graph=graph2) 48 | merge_edges(target_graph=target_graph, sub_graph=graph2) 49 | 50 | print(target_graph.nodes(data=True)) 51 | print(target_graph.edges(data=True)) 52 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/report_generation/generator.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import networkx as nx 4 | from langchain_core.language_models import LanguageModelLike 5 | from langchain_core.runnables.config import RunnableConfig 6 | 7 | from langchain_graphrag.types.graphs.community import Community 8 | from langchain_graphrag.types.prompts import IndexingPromptBuilder 9 | 10 | from .prompt_builder import CommunityReportGenerationPromptBuilder 11 | from .utils import CommunityReportResult 12 | 13 | 14 | class CommunityReportGenerator: 15 | def __init__( 16 | self, 17 | prompt_builder: IndexingPromptBuilder, 18 | llm: LanguageModelLike, 19 | *, 20 | chain_config: RunnableConfig | None = None, 21 | ): 22 | prompt, output_parser = prompt_builder.build() 23 | self._chain = prompt | llm | output_parser 24 | self._prompt_builder = prompt_builder 25 | self._chain_config = chain_config 26 | 27 | @staticmethod 28 | def build_default( 29 | llm: LanguageModelLike, 30 | *, 31 | chain_config: RunnableConfig | None = None, 32 | ) -> CommunityReportGenerator: 33 | return CommunityReportGenerator( 34 | prompt_builder=CommunityReportGenerationPromptBuilder(), 35 | llm=llm, 36 | chain_config=chain_config, 37 | ) 38 | 39 | def invoke(self, community: Community, graph: nx.Graph) -> CommunityReportResult: 40 | chain_input = self._prompt_builder.prepare_chain_input( 41 | community=community, 42 | graph=graph, 43 | ) 44 | 45 | return self._chain.invoke(input=chain_input, config=self._chain_config) 46 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/text_units.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | from langchain_core.documents import Document 5 | 6 | from langchain_graphrag.types.tokens import TokenCounter 7 | 8 | _LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class TextUnitsContextBuilder: 12 | def __init__( 13 | self, 14 | *, 15 | context_name: str = "Sources", 16 | column_delimiter: str = "|", 17 | max_tokens: int = 8000, 18 | token_counter: TokenCounter, 19 | ): 20 | self._context_name = context_name 21 | self._column_delimiter = column_delimiter 22 | self._max_tokens = max_tokens 23 | self._token_counter = token_counter 24 | 25 | def __call__(self, text_units: pd.DataFrame) -> Document: 26 | context_text = f"-----{self._context_name}-----" + "\n" 27 | header = ["id", "text"] 28 | 29 | context_text += self._column_delimiter.join(header) + "\n" 30 | token_count = self._token_counter.count_tokens(context_text) 31 | 32 | for row in text_units.itertuples(): 33 | new_context = [str(row.short_id), row.text_unit] 34 | new_context_text = self._column_delimiter.join(new_context) + "\n" 35 | 36 | 
new_token_count = self._token_counter.count_tokens(new_context_text) 37 | if token_count + new_token_count > self._max_tokens: 38 | _LOGGER.warning( 39 | f"Stopping text units context build at {token_count} tokens ..." 40 | ) 41 | break 42 | 43 | context_text += new_context_text 44 | token_count += new_token_count 45 | 46 | return Document( 47 | page_content=context_text, 48 | metadata={"token_count": token_count}, 49 | ) 50 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/embedding_generation/graph/node2vec.py: -------------------------------------------------------------------------------- 1 | """Graph embedding generation using node2vec.""" 2 | 3 | import graspologic as gl 4 | import networkx as nx 5 | import numpy as np 6 | 7 | from langchain_graphrag.indexing._graph_utils import stable_largest_connected_component 8 | from langchain_graphrag.types.graphs.embedding import GraphEmbeddingGenerator 9 | 10 | 11 | class Node2VectorGraphEmbeddingGenerator(GraphEmbeddingGenerator): 12 | def __init__( 13 | self, 14 | *, 15 | use_lcc: bool = True, 16 | dimensions: int = 1536, 17 | num_walks: int = 10, 18 | walk_length: int = 40, 19 | window_size: int = 2, 20 | num_iter: int = 3, 21 | random_seed: int = 86, 22 | ): 23 | self._use_lcc = use_lcc 24 | self._dimensions = dimensions 25 | self._num_walks = num_walks 26 | self._walk_length = walk_length 27 | self._window_size = window_size 28 | self._num_iter = num_iter 29 | self._random_seed = random_seed 30 | 31 | def run( 32 | self, 33 | graph: nx.Graph, 34 | ) -> dict[str, np.ndarray]: 35 | if self._use_lcc: 36 | graph = stable_largest_connected_component(graph) 37 | 38 | lcc_tensors = gl.embed.node2vec_embed( 39 | graph=graph, 40 | dimensions=self._dimensions, 41 | window_size=self._window_size, 42 | iterations=self._num_iter, 43 | num_walks=self._num_walks, 44 | walk_length=self._walk_length, 45 | random_seed=self._random_seed, 46 | ) 47 | 48 | embeddings = lcc_tensors[0] 49 | nodes = lcc_tensors[1] 50 | 51 | pairs = zip(nodes, embeddings, strict=True) 52 | sorted_pairs = sorted(pairs, key=lambda x: x[0]) 53 | 54 | return dict(sorted_pairs) 55 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_summarization/prompt_builder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | from langchain_core.output_parsers.base import BaseOutputParser 5 | from langchain_core.output_parsers.string import StrOutputParser 6 | from langchain_core.prompts import BasePromptTemplate, PromptTemplate 7 | from typing_extensions import Unpack 8 | 9 | from langchain_graphrag.types.prompts import IndexingPromptBuilder 10 | 11 | from ._default_prompts import DEFAULT_PROMPT 12 | 13 | 14 | class SummarizeDescriptionPromptBuilder(IndexingPromptBuilder): 15 | def __init__( 16 | self, 17 | *, 18 | prompt: str | None = None, 19 | prompt_path: Path | None = None, 20 | ): 21 | self._prompt: str | None 22 | if prompt is None and prompt_path is None: 23 | self._prompt = DEFAULT_PROMPT 24 | else: 25 | self._prompt = prompt 26 | 27 | self._prompt_path = prompt_path 28 | 29 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 30 | if self._prompt: 31 | prompt_template = PromptTemplate.from_template(self._prompt) 32 | else: 33 | assert self._prompt_path is not None 34 | prompt_template = 
PromptTemplate.from_file(self._prompt_path) 35 | 36 | return prompt_template, StrOutputParser() 37 | 38 | def prepare_chain_input(self, **kwargs: Unpack[dict[str, Any]]) -> dict[str, str]: 39 | entity_name = kwargs.get("entity_name", None) 40 | description_list = kwargs.get("description_list", None) 41 | if entity_name is None: 42 | raise ValueError("entity_name is required") 43 | if description_list is None: 44 | raise ValueError("description_list is required") 45 | 46 | return dict(description_list=description_list, entity_name=entity_name) 47 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/entities.py: -------------------------------------------------------------------------------- 1 | """Select the entities to be used in the local search.""" 2 | 3 | import logging 4 | 5 | import pandas as pd 6 | from langchain_core.vectorstores import VectorStore 7 | 8 | _LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class EntitiesSelector: 12 | def __init__(self, vector_store: VectorStore, top_k: int): 13 | self._vector_store = vector_store 14 | self._top_k = top_k 15 | 16 | def run(self, query: str, df_entities: pd.DataFrame) -> pd.DataFrame: 17 | """Select the entities to be used in the local search.""" 18 | documents_with_scores = ( 19 | self._vector_store.similarity_search_with_relevance_scores( 20 | query, 21 | self._top_k, 22 | ) 23 | ) 24 | 25 | # Relying on metadata to get the entity_ids 26 | # These returned entities are ranked by similarity 27 | entity_ids_with_scores = pd.DataFrame.from_records( 28 | [ 29 | dict(id=doc.metadata["entity_id"], score=score) 30 | for doc, score in documents_with_scores 31 | ] 32 | ) 33 | 34 | # Filter the entities dataframe to only include the selected entities 35 | selected_entities = df_entities[ 36 | df_entities["id"].isin(entity_ids_with_scores["id"]) 37 | ] 38 | 39 | selected_entities = ( 40 | selected_entities.merge(entity_ids_with_scores, on="id") 41 | .sort_values(by="score", ascending=False) 42 | .reset_index(drop=True) 43 | ) 44 | 45 | if _LOGGER.isEnabledFor(logging.DEBUG): 46 | import tableprint 47 | 48 | tableprint.banner("Selected Entities") 49 | tableprint.dataframe(selected_entities[["title", "degree", "score"]]) 50 | 51 | return selected_entities 52 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/communities_reports.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | from langchain_core.documents import Document 5 | 6 | from langchain_graphrag.types.tokens import TokenCounter 7 | 8 | _LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class CommunitiesReportsContextBuilder: 12 | def __init__( 13 | self, 14 | *, 15 | context_name: str = "Reports", 16 | column_delimiter: str = "|", 17 | max_tokens: int = 8000, 18 | token_counter: TokenCounter, 19 | ): 20 | self._context_name = context_name 21 | self._column_delimiter = column_delimiter 22 | self._max_tokens = max_tokens 23 | self._token_counter = token_counter 24 | 25 | def __call__(self, communities_reports: pd.DataFrame) -> Document: 26 | context_text = f"-----{self._context_name}-----" + "\n" 27 | header = ["id", "title", "content"] 28 | 29 | context_text += self._column_delimiter.join(header) + "\n" 30 | token_count = self._token_counter.count_tokens(context_text) 31 | 32 | for report in 
communities_reports.itertuples(): 33 | new_context = [ 34 | str(report.community_id), 35 | report.title, 36 | report.content, 37 | ] 38 | 39 | new_context_text = self._column_delimiter.join(new_context) + "\n" 40 | new_token_count = self._token_counter.count_tokens(new_context_text) 41 | 42 | if token_count + new_token_count > self._max_tokens: 43 | _LOGGER.warning( 44 | f"Stopping communities context build at {token_count} tokens ..." 45 | ) 46 | break 47 | 48 | context_text += new_context_text 49 | token_count += new_token_count 50 | 51 | return Document( 52 | page_content=context_text, 53 | metadata={"token_count": token_count}, 54 | ) 55 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_generator/prompt_builder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from langchain_core.output_parsers.base import BaseOutputParser 4 | from langchain_core.prompts import ( 5 | BasePromptTemplate, 6 | ChatPromptTemplate, 7 | SystemMessagePromptTemplate, 8 | ) 9 | 10 | from langchain_graphrag.types.prompts import PromptBuilder 11 | 12 | from ._output_parser import KeyPointsOutputParser 13 | from ._system_prompt import MAP_SYSTEM_PROMPT 14 | 15 | 16 | class KeyPointsGeneratorPromptBuilder(PromptBuilder): 17 | def __init__( 18 | self, 19 | *, 20 | system_prompt: str | None = None, 21 | system_prompt_path: Path | None = None, 22 | show_references: bool = True, 23 | repeat_instructions: bool = True, 24 | ): 25 | self._system_prompt: str | None 26 | if system_prompt is None and system_prompt_path is None: 27 | self._system_prompt = MAP_SYSTEM_PROMPT 28 | else: 29 | self._system_prompt = system_prompt 30 | 31 | self._system_prompt_path = system_prompt_path 32 | self._show_references = show_references 33 | self._repeat_instructions = repeat_instructions 34 | 35 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 36 | if self._system_prompt_path: 37 | prompt = Path.read_text(self._system_prompt_path) 38 | else: 39 | assert self._system_prompt is not None 40 | prompt = self._system_prompt 41 | 42 | system_template = SystemMessagePromptTemplate.from_template( 43 | prompt, 44 | template_format="mustache", 45 | partial_variables=dict( 46 | show_references=self._show_references, 47 | repeat_instructions=self._repeat_instructions, 48 | ), 49 | ) 50 | 51 | template = ChatPromptTemplate( 52 | [system_template, ("user", "{{global_query}}")], 53 | template_format="mustache", 54 | ) 55 | return template, KeyPointsOutputParser() 56 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/search.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import Iterator 3 | 4 | from langchain_core.runnables import RunnableConfig 5 | 6 | from .key_points_aggregator import KeyPointsAggregator 7 | from .key_points_generator import KeyPointsGenerator 8 | from .key_points_generator.utils import ( 9 | KeyPointsResult, 10 | ) 11 | 12 | _LOGGER = logging.getLogger(__name__) 13 | 14 | 15 | class GlobalSearch: 16 | def __init__( 17 | self, 18 | kp_generator: KeyPointsGenerator, 19 | kp_aggregator: KeyPointsAggregator, 20 | *, 21 | generation_chain_config: RunnableConfig | None = None, 22 | aggregation_chain_config: RunnableConfig | None = None, 23 | ): 24 | self._kp_generator = kp_generator 25 | self._kp_aggregator = 
kp_aggregator 26 | self._generation_chain_config = generation_chain_config 27 | self._aggregation_chain_config = aggregation_chain_config 28 | 29 | def _get_key_points(self, query: str) -> dict[str, KeyPointsResult]: 30 | generation_chain = self._kp_generator() 31 | response = generation_chain.invoke( 32 | query, 33 | config=self._generation_chain_config, 34 | ) 35 | 36 | if _LOGGER.getEffectiveLevel() == logging.INFO: 37 | for k, v in response.items(): 38 | _LOGGER.info(f"{k} - {len(v.points)}") 39 | 40 | return response 41 | 42 | def invoke(self, query: str) -> str: 43 | aggregation_chain = self._kp_aggregator() 44 | response = self._get_key_points(query) 45 | 46 | return aggregation_chain.invoke( 47 | input=dict(report_data=response, global_query=query), 48 | config=self._aggregation_chain_config, 49 | ) 50 | 51 | def stream(self, query: str) -> Iterator: 52 | aggregation_chain = self._kp_aggregator() 53 | response = self._get_key_points(query) 54 | 55 | return aggregation_chain.stream( 56 | input=dict(report_data=response, global_query=query), 57 | config=self._aggregation_chain_config, 58 | ) 59 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_aggregator/aggregator.py: -------------------------------------------------------------------------------- 1 | import operator 2 | from functools import partial 3 | 4 | from langchain_core.documents import Document 5 | from langchain_core.language_models import LanguageModelLike 6 | from langchain_core.runnables import Runnable, RunnableLambda 7 | 8 | from langchain_graphrag.query.global_search.key_points_generator.utils import ( 9 | KeyPointsResult, 10 | ) 11 | from langchain_graphrag.types.prompts import PromptBuilder 12 | 13 | from .context_builder import KeyPointsContextBuilder 14 | 15 | 16 | def _format_docs(documents: list[Document]) -> str: 17 | context_data = [d.page_content for d in documents] 18 | context_data_str: str = "\n".join(context_data) 19 | return context_data_str 20 | 21 | 22 | def _kp_result_to_docs( 23 | key_points: dict[str, KeyPointsResult], 24 | context_builder: KeyPointsContextBuilder, 25 | ) -> list[Document]: 26 | return context_builder(key_points) 27 | 28 | 29 | class KeyPointsAggregator: 30 | def __init__( 31 | self, 32 | llm: LanguageModelLike, 33 | prompt_builder: PromptBuilder, 34 | context_builder: KeyPointsContextBuilder, 35 | *, 36 | output_raw: bool = False, 37 | ): 38 | self._llm = llm 39 | self._prompt_builder = prompt_builder 40 | self._context_builder = context_builder 41 | self._output_raw = output_raw 42 | 43 | def __call__(self) -> Runnable: 44 | kp_lambda = partial( 45 | _kp_result_to_docs, 46 | context_builder=self._context_builder, 47 | ) 48 | 49 | prompt, output_parser = self._prompt_builder.build() 50 | base_chain = prompt | self._llm 51 | 52 | if not self._output_raw: 53 | base_chain = base_chain | output_parser 54 | 55 | search_chain: Runnable = { 56 | "report_data": operator.itemgetter("report_data") 57 | | RunnableLambda(kp_lambda) 58 | | _format_docs, 59 | "global_query": operator.itemgetter("global_query"), 60 | } | base_chain 61 | 62 | return search_chain 63 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/prompt_builder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from langchain_core.output_parsers.base import BaseOutputParser 4 | from 
langchain_core.output_parsers.string import StrOutputParser 5 | from langchain_core.prompts import ( 6 | BasePromptTemplate, 7 | ChatPromptTemplate, 8 | SystemMessagePromptTemplate, 9 | ) 10 | 11 | from langchain_graphrag.types.prompts import PromptBuilder 12 | 13 | from ._system_prompt import LOCAL_SEARCH_SYSTEM_PROMPT 14 | 15 | 16 | class LocalSearchPromptBuilder(PromptBuilder): 17 | def __init__( 18 | self, 19 | *, 20 | system_prompt: str | None = None, 21 | system_prompt_path: Path | None = None, 22 | show_references: bool = True, 23 | repeat_instructions: bool = True, 24 | ): 25 | self._system_prompt: str | None 26 | if system_prompt is None and system_prompt_path is None: 27 | self._system_prompt = LOCAL_SEARCH_SYSTEM_PROMPT 28 | else: 29 | self._system_prompt = system_prompt 30 | 31 | self._system_prompt_path = system_prompt_path 32 | self._show_references = show_references 33 | self._repeat_instructions = repeat_instructions 34 | 35 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 36 | if self._system_prompt_path: 37 | prompt = Path.read_text(self._system_prompt_path) 38 | else: 39 | assert self._system_prompt is not None 40 | prompt = self._system_prompt 41 | 42 | system_template = SystemMessagePromptTemplate.from_template( 43 | prompt, 44 | partial_variables=dict( 45 | response_type="Multiple Paragraphs", 46 | show_references=self._show_references, 47 | repeat_instructions=self._repeat_instructions, 48 | ), 49 | template_format="mustache", 50 | ) 51 | 52 | template = ChatPromptTemplate( 53 | [system_template, ("user", "{{local_query}}")], 54 | template_format="mustache", 55 | ) 56 | return template, StrOutputParser() 57 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_aggregator/prompt_builder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from langchain_core.output_parsers.base import BaseOutputParser 4 | from langchain_core.output_parsers.string import StrOutputParser 5 | from langchain_core.prompts import ( 6 | BasePromptTemplate, 7 | ChatPromptTemplate, 8 | SystemMessagePromptTemplate, 9 | ) 10 | 11 | from langchain_graphrag.types.prompts import PromptBuilder 12 | 13 | from ._system_prompt import REDUCE_SYSTEM_PROMPT 14 | 15 | 16 | class KeyPointsAggregatorPromptBuilder(PromptBuilder): 17 | def __init__( 18 | self, 19 | *, 20 | system_prompt: str | None = None, 21 | system_prompt_path: Path | None = None, 22 | show_references: bool = True, 23 | repeat_instructions: bool = True, 24 | ): 25 | self._system_prompt: str | None 26 | if system_prompt is None and system_prompt_path is None: 27 | self._system_prompt = REDUCE_SYSTEM_PROMPT 28 | else: 29 | self._system_prompt = system_prompt 30 | 31 | self._system_prompt_path = system_prompt_path 32 | self._show_references = show_references 33 | self._repeat_instructions = repeat_instructions 34 | 35 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 36 | if self._system_prompt_path: 37 | prompt = Path.read_text(self._system_prompt_path) 38 | else: 39 | assert self._system_prompt is not None 40 | prompt = self._system_prompt 41 | 42 | system_template = SystemMessagePromptTemplate.from_template( 43 | prompt, 44 | partial_variables=dict( 45 | response_type="Multiple Paragraphs", 46 | show_references=self._show_references, 47 | repeat_instructions=self._repeat_instructions, 48 | ), 49 | template_format="mustache", 50 | ) 51 | 52 | template = 
ChatPromptTemplate( 53 | [system_template, ("user", "{{global_query}}")], 54 | template_format="mustache", 55 | ) 56 | return template, StrOutputParser() 57 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/entities.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | from langchain_core.documents import Document 5 | 6 | from langchain_graphrag.types.tokens import TokenCounter 7 | 8 | _LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class EntitiesContextBuilder: 12 | def __init__( 13 | self, 14 | *, 15 | include_rank: bool = True, 16 | context_name: str = "Entities", 17 | rank_heading: str = "number of relationships", 18 | column_delimiter: str = "|", 19 | max_tokens: int = 8000, 20 | token_counter: TokenCounter, 21 | ): 22 | self._include_rank = include_rank 23 | self._context_name = context_name 24 | self._rank_heading = rank_heading 25 | self._column_delimiter = column_delimiter 26 | self._max_tokens = max_tokens 27 | self._token_counter = token_counter 28 | 29 | def __call__(self, entities: pd.DataFrame) -> Document: 30 | context_text = f"-----{self._context_name}-----" + "\n" 31 | header = ["id", "entity", "description"] 32 | if self._include_rank: 33 | header.append(self._rank_heading) 34 | 35 | context_text += self._column_delimiter.join(header) + "\n" 36 | token_count = self._token_counter.count_tokens(context_text) 37 | 38 | for entity in entities.itertuples(): 39 | new_context = [ 40 | str(entity.human_readable_id), 41 | entity.title, 42 | entity.description, 43 | ] 44 | if self._include_rank: 45 | new_context.append(str(entity.degree)) 46 | 47 | new_context_text = self._column_delimiter.join(new_context) + "\n" 48 | 49 | new_token_count = self._token_counter.count_tokens(new_context_text) 50 | if token_count + new_token_count > self._max_tokens: 51 | _LOGGER.warning( 52 | f"Stopping entities context build at {token_count} tokens ..." 
53 | ) 54 | break 55 | 56 | context_text += new_context_text 57 | token_count += new_token_count 58 | 59 | return Document( 60 | page_content=context_text, 61 | metadata={"token_count": token_count}, 62 | ) 63 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_clustering/leiden_community_detector.py: -------------------------------------------------------------------------------- 1 | from typing import cast 2 | 3 | import networkx as nx 4 | from graspologic.partition import ( 5 | HierarchicalCluster, 6 | HierarchicalClusters, 7 | hierarchical_leiden, 8 | ) 9 | 10 | from langchain_graphrag.indexing._graph_utils import stable_largest_connected_component 11 | from langchain_graphrag.types.graphs.community import ( 12 | Community, 13 | CommunityDetectionResult, 14 | CommunityDetector, 15 | CommunityId, 16 | CommunityLevel, 17 | CommunityNode, 18 | ) 19 | 20 | 21 | class HierarchicalLeidenCommunityDetector(CommunityDetector): 22 | def __init__( 23 | self, 24 | *, 25 | use_lcc: bool = True, 26 | max_cluster_size: int = 10, 27 | seed: int = 0xDEADBEEF, 28 | ): 29 | self._use_lcc = use_lcc 30 | self._max_cluster_size = max_cluster_size 31 | self._seed = seed 32 | 33 | def run(self, graph: nx.Graph) -> CommunityDetectionResult: 34 | if self._use_lcc: 35 | graph = stable_largest_connected_component(graph) 36 | 37 | community_mapping: HierarchicalClusters = hierarchical_leiden( 38 | graph, 39 | max_cluster_size=self._max_cluster_size, 40 | random_seed=self._seed, 41 | ) 42 | 43 | communities: dict[CommunityLevel, dict[CommunityId, Community]] = {} 44 | 45 | partition: HierarchicalCluster 46 | for partition in community_mapping: 47 | partition_level = cast(CommunityLevel, partition.level) 48 | partition_cluster = cast(CommunityId, partition.cluster) 49 | 50 | communities_at_level = communities.get(partition_level, {}) 51 | community = communities_at_level.get( 52 | partition_cluster, 53 | Community(id=partition_cluster, nodes=[]), 54 | ) 55 | community.nodes.append( 56 | CommunityNode( 57 | name=partition.node, 58 | parent_cluster=cast(CommunityId, partition.parent_cluster), 59 | is_final_cluster=partition.is_final_cluster, 60 | ) 61 | ) 62 | 63 | communities_at_level[partition_cluster] = community 64 | communities[partition_level] = communities_at_level 65 | 66 | return CommunityDetectionResult(communities=communities) 67 | -------------------------------------------------------------------------------- /examples/simple-app/README.md: -------------------------------------------------------------------------------- 1 | # Simple App 2 | 3 | This app shows how to construct the various components from `langchain_graphrag` and wire them together into a simple app. 4 | 5 | The command-line interface is built with `typer`. 6 | 7 | ## Install 8 | 9 | At the root of the repository, install all dependencies: 10 | 11 | ```bash 12 | uv sync 13 | ``` 14 | 15 | ## Setup 16 | 17 | **Note**: If you are using OpenAI or Azure OpenAI, rename `.env.example` to `.env` and fill in the necessary environment variables.
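For reference, a minimal `.env` for Azure OpenAI might look like the sketch below. The key names shown follow the usual `langchain-openai` conventions and are not copied from `.env.example`; treat the values as placeholders and use the exact keys listed in `.env.example`.

```bash
# Illustrative placeholders; copy the authoritative key names from .env.example
AZURE_OPENAI_API_KEY=<your-api-key>
AZURE_OPENAI_ENDPOINT=https://<your-resource>.openai.azure.com/
OPENAI_API_VERSION=<api-version>

# If you use OpenAI directly instead:
# OPENAI_API_KEY=<your-api-key>
```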
18 | 19 | ## Usage 20 | 21 | ### Help Commands 22 | 23 | ```bash 24 | # Main help 25 | uv run poe simple-app-help 26 | 27 | # Indexer help 28 | uv run poe simple-app-indexer-help 29 | 30 | # Query help 31 | uv run poe simple-app-query-help 32 | ``` 33 | 34 | ### Step 1 - Indexing 35 | 36 | ```bash 37 | # Azure OpenAI (recommended) 38 | uv run poe simple-app-indexer-azure 39 | 40 | # OpenAI 41 | uv run poe simple-app-indexer-openai 42 | 43 | # Ollama (local) 44 | uv run poe simple-app-indexer-ollama 45 | ``` 46 | 47 | ### Step 2 - Global Search 48 | 49 | ```bash 50 | # Azure OpenAI 51 | uv run poe simple-app-global-search-azure --query "What are the main themes in this story?" 52 | 53 | # OpenAI 54 | uv run poe simple-app-global-search-openai --query "What are the main themes in this story?" 55 | 56 | # Ollama 57 | uv run poe simple-app-global-search-ollama --query "What are the main themes in this story?" 58 | ``` 59 | 60 | ### Step 3 - Local Search 61 | 62 | ```bash 63 | # Azure OpenAI 64 | uv run poe simple-app-local-search-azure --query "Who is Scrooge, and what are his main relationships?" 65 | 66 | # OpenAI 67 | uv run poe simple-app-local-search-openai --query "Who is Scrooge, and what are his main relationships?" 68 | 69 | # Ollama 70 | uv run poe simple-app-local-search-ollama --query "Who is Scrooge, and what are his main relationships?" 71 | ``` 72 | 73 | ### Generate Reports 74 | 75 | ```bash 76 | # Generate reports 77 | uv run poe simple-app-report 78 | ``` 79 | 80 | ### Development Commands 81 | 82 | ```bash 83 | # Run tests 84 | uv run poe test 85 | 86 | # Check code quality 87 | uv run poe lint 88 | 89 | # View documentation locally 90 | uv run poe docs-serve 91 | ``` 92 | 93 | ## Notes 94 | 95 | - All commands should be run from the root of the repository 96 | - Azure OpenAI is recommended for best results 97 | - Make sure to index your data before running search queries 98 | - Use your own queries by replacing the `--query` parameter 99 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/report_generation/prompt_builder.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | 4 | import networkx as nx 5 | import pandas as pd 6 | from langchain_core.output_parsers.base import BaseOutputParser 7 | from langchain_core.prompts import BasePromptTemplate, PromptTemplate 8 | from typing_extensions import Unpack 9 | 10 | from langchain_graphrag.types.graphs.community import Community 11 | from langchain_graphrag.types.prompts import IndexingPromptBuilder 12 | 13 | from ._default_prompts import DEFAULT_PROMPT 14 | from ._output_parser import CommunityReportOutputParser 15 | from .utils import get_info 16 | 17 | 18 | class CommunityReportGenerationPromptBuilder(IndexingPromptBuilder): 19 | def __init__( 20 | self, 21 | *, 22 | prompt: str | None = None, 23 | prompt_path: Path | None = None, 24 | ): 25 | self._prompt: str | None 26 | if prompt is None and prompt_path is None: 27 | self._prompt = DEFAULT_PROMPT 28 | else: 29 | self._prompt = prompt 30 | 31 | self._prompt_path = prompt_path 32 | 33 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 34 | if self._prompt: 35 | prompt_template = PromptTemplate.from_template(self._prompt) 36 | else: 37 | assert self._prompt_path is not None 38 | prompt_template = PromptTemplate.from_file(self._prompt_path) 39 | 40 | return prompt_template, CommunityReportOutputParser() 41 | 
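    # Note on the method below: prepare_chain_input flattens a community into
    # the single `input_text` variable consumed by the prompt. It pulls the
    # community's entities and relationships out of the graph via get_info()
    # and renders them as two small CSV tables under the -----Entities----- and
    # -----Relationships----- headings.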
42 | def prepare_chain_input(self, **kwargs: Unpack[dict[str, Any]]) -> dict[str, str]: 43 | community: Community = kwargs.get("community", None) 44 | graph: nx.Graph = kwargs.get("graph", None) 45 | 46 | if community is None: 47 | raise ValueError("community is required") 48 | 49 | if graph is None: 50 | raise ValueError("graph is required") 51 | 52 | entities, relationships = get_info(community, graph) 53 | 54 | entities_table = pd.DataFrame.from_records(entities).to_csv( 55 | index=False, 56 | ) 57 | 58 | relationships_table = pd.DataFrame.from_records(relationships).to_csv( 59 | index=False, 60 | ) 61 | 62 | input_text = f""" 63 | -----Entities----- 64 | {entities_table} 65 | 66 | -----Relationships----- 67 | {relationships_table} 68 | """ 69 | 70 | return dict(input_text=input_text) 71 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/artifacts_generation/relationships.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import pandas as pd 3 | from langchain_core.vectorstores import VectorStore 4 | 5 | 6 | class RelationshipsArtifactsGenerator: 7 | def __init__( 8 | self, 9 | relationships_vector_store: VectorStore | None = None, 10 | ): 11 | self._relationships_vector_store = relationships_vector_store 12 | 13 | def _unpack_edges(self, graph: nx.Graph) -> pd.DataFrame: 14 | records = [ 15 | { 16 | "source": source, 17 | "target": target, 18 | "source_id": graph.nodes[source].get("id"), 19 | "target_id": graph.nodes[target].get("id"), 20 | **(edge_data or {}), 21 | } 22 | for source, target, edge_data in graph.edges(data=True) 23 | ] 24 | return pd.DataFrame.from_records(records) 25 | 26 | def _embed_relationships(self, graph: nx.Graph) -> None: 27 | # Extract the information to embed from the graph 28 | # and put it in the vectorstore 29 | texts_to_embed = [] 30 | texts_metadata = [] 31 | texts_ids = [] 32 | for source, target, edge_data in graph.edges(data=True): 33 | text_description = edge_data.get("description") 34 | texts_ids.append(edge_data.get("id")) 35 | texts_to_embed.append(text_description) 36 | 37 | # Bug in langchain vectorstore retrieval that 38 | # does not populate Document.id field.
39 | # 40 | # Hence add relationship_id as an additional field 41 | # in the metadata 42 | texts_metadata.append( 43 | dict( 44 | source=source, 45 | target=target, 46 | description=text_description, 47 | rank=edge_data.get("rank"), 48 | relationship_id=edge_data.get( 49 | "id" 50 | ), # TODO: Remove once langchain is fixed 51 | ) 52 | ) 53 | 54 | assert self._relationships_vector_store is not None 55 | 56 | self._relationships_vector_store.add_texts( 57 | texts_to_embed, 58 | metadatas=texts_metadata, 59 | ids=texts_ids, 60 | ) 61 | 62 | def run(self, graph: nx.Graph) -> pd.DataFrame: 63 | if self._relationships_vector_store: 64 | self._embed_relationships(graph) 65 | 66 | return self._unpack_edges(graph) 67 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/report_generation/utils.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | from typing import Sequence, TypedDict 3 | 4 | import networkx as nx 5 | from pydantic import BaseModel, Field 6 | 7 | from langchain_graphrag.types.graphs.community import Community 8 | 9 | 10 | class Entity(TypedDict): 11 | id: str 12 | name: str 13 | type: str 14 | description: str 15 | degree: int 16 | 17 | 18 | class Relationship(TypedDict): 19 | id: str 20 | source: str 21 | target: str 22 | description: str 23 | rank: int 24 | 25 | 26 | def entity_from_graph(name: str, graph: nx.Graph) -> Entity: 27 | node = graph.nodes[name] 28 | return Entity( 29 | id=node["human_readable_id"], 30 | name=name, 31 | type=node["type"], 32 | description=node["description"], 33 | degree=node["degree"], 34 | ) 35 | 36 | 37 | def relationship_from_graph( 38 | pair: tuple[str, str], 39 | graph: nx.Graph, 40 | ) -> Relationship: 41 | n1, n2 = pair 42 | edge = graph.edges[n1, n2] 43 | return Relationship( 44 | id=edge["human_readable_id"], 45 | source=n1, 46 | target=n2, 47 | description=edge["description"], 48 | rank=edge["rank"], 49 | ) 50 | 51 | 52 | class CommunityFinding(BaseModel): 53 | summary: str = Field(description="Insight summary") 54 | explanation: str = Field(description="Insight explanation") 55 | 56 | 57 | class CommunityReportResult(BaseModel): 58 | title: str = Field(description="Title of the report") 59 | summary: str = Field(description="Summary of the report") 60 | rating: float = Field(description="Impact severity rating of the report") 61 | rating_explanation: str = Field( 62 | description="Single sentence explanation of the IMPACT severity rating" 63 | ) 64 | findings: list[CommunityFinding] = Field(description="Detailed findings") 65 | 66 | 67 | def get_info( 68 | community: Community, 69 | graph: nx.Graph, 70 | ) -> tuple[Sequence[Entity], Sequence[Relationship]]: 71 | nodes = [n.name for n in community.nodes] 72 | entities = [entity_from_graph(n, graph) for n in nodes] 73 | 74 | node_pairs = itertools.combinations(entities, 2) 75 | 76 | pairs_with_edges = [] 77 | for n1, n2 in node_pairs: 78 | if not graph.has_edge(n1["name"], n2["name"]): 79 | continue 80 | pairs_with_edges.append((n1["name"], n2["name"])) 81 | 82 | relationships = [relationship_from_graph(p, graph) for p in pairs_with_edges] 83 | 84 | return entities, relationships 85 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/artifacts_generation/reports.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import networkx as nx 4 | import pandas 
as pd 5 | from langchain_core.exceptions import OutputParserException 6 | from tqdm import tqdm 7 | 8 | from langchain_graphrag.indexing.report_generation import ( 9 | CommunityReportGenerator, 10 | CommunityReportWriter, 11 | ) 12 | from langchain_graphrag.types.graphs.community import ( 13 | Community, 14 | CommunityDetectionResult, 15 | ) 16 | 17 | _LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | def _get_entities(community: Community, graph: nx.Graph) -> list[str]: 21 | return [graph.nodes[n.name]["id"] for n in community.nodes] 22 | 23 | 24 | class CommunitiesReportsArtifactsGenerator: 25 | def __init__( 26 | self, 27 | report_generator: CommunityReportGenerator, 28 | report_writer: CommunityReportWriter, 29 | ): 30 | self._report_generator = report_generator 31 | self._report_writer = report_writer 32 | 33 | def run( 34 | self, 35 | detection_result: CommunityDetectionResult, 36 | graph: nx.Graph, 37 | ) -> pd.DataFrame: 38 | reports = [] 39 | 40 | # TODO: Parallelize all this 41 | for level in detection_result.communities: 42 | communities = detection_result.communities_at_level(level) 43 | c_pbar = tqdm(communities) 44 | for c in c_pbar: 45 | c_pbar.set_description_str( 46 | f"Generating report for level={level} community_id={c.id}" 47 | ) 48 | 49 | try: 50 | report = self._report_generator.invoke(community=c, graph=graph) 51 | except OutputParserException: 52 | _LOGGER.exception( 53 | f"Failed to generate report for level={level} community_id={c.id}" 54 | ) 55 | continue 56 | 57 | report_str = self._report_writer.write(report) 58 | entities = _get_entities(c, graph) 59 | 60 | reports.append( 61 | dict( 62 | level=level, 63 | community_id=c.id, 64 | entities=entities, 65 | title=report.title, 66 | summary=report.summary, 67 | rating=report.rating, 68 | rating_explanation=report.rating_explanation, 69 | content=report_str, 70 | ) 71 | ) 72 | 73 | return pd.DataFrame.from_records(reports) 74 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/_graph_utils.py: -------------------------------------------------------------------------------- 1 | # Most functions in this file 2 | # are taken from https://github.com/microsoft/graphrag/ 3 | # and are modified to fit the needs of this project 4 | 5 | import html 6 | from typing import Any, cast 7 | 8 | import networkx as nx 9 | from graspologic.utils import largest_connected_component 10 | 11 | 12 | def _stabilize_graph(graph: nx.Graph) -> nx.Graph: 13 | """Ensure an undirected graph with the same relationships will always be read the same way.""" # noqa: E501 14 | fixed_graph: nx.Graph = nx.DiGraph() if graph.is_directed() else nx.Graph() 15 | 16 | graph_nodes = graph.nodes(data=True) 17 | sorted_nodes = sorted(graph_nodes, key=lambda x: x[0]) 18 | 19 | fixed_graph.add_nodes_from(sorted_nodes) 20 | edges = list(graph.edges(data=True)) 21 | 22 | # If the graph is undirected, we create the edges in a stable way, so we get 23 | # the same results 24 | # for example: 25 | # A -> B 26 | # in graph theory is the same as 27 | # B -> A 28 | # in an undirected graph 29 | # however, this can lead to downstream issues because sometimes 30 | # consumers read graph.nodes() which ends up being [A, B] and sometimes it's [B, A] 31 | # but they base some of their logic on the order of the nodes, so the order ends up 32 | # being important so we sort the nodes in the edge in a stable way, so that we 33 | # always get the same order 34 | if not graph.is_directed(): 35 | 36 | def
_sort_source_target(edge: Any) -> tuple[str, str, dict[str, Any]]: 37 | source, target, edge_data = edge 38 | if source > target: 39 | temp = source 40 | source = target 41 | target = temp 42 | return source, target, edge_data 43 | 44 | edges = [_sort_source_target(edge) for edge in edges] 45 | 46 | def _get_edge_key(source: Any, target: Any) -> str: 47 | return f"{source} -> {target}" 48 | 49 | edges = sorted(edges, key=lambda x: _get_edge_key(x[0], x[1])) 50 | 51 | fixed_graph.add_edges_from(edges) 52 | return fixed_graph 53 | 54 | 55 | def normalize_node_names(graph: nx.Graph | nx.DiGraph) -> nx.Graph | nx.DiGraph: 56 | """Normalize node names.""" 57 | node_mapping = {node: html.unescape(node.upper().strip()) for node in graph.nodes()} 58 | return nx.relabel_nodes(graph, node_mapping) 59 | 60 | 61 | def stable_largest_connected_component(graph: nx.Graph) -> nx.Graph: 62 | graph = graph.copy() 63 | graph = cast(nx.Graph, largest_connected_component(graph)) 64 | graph = normalize_node_names(graph) 65 | return _stabilize_graph(graph) 66 | -------------------------------------------------------------------------------- /ruff.toml: -------------------------------------------------------------------------------- 1 | # Exclude a variety of commonly ignored directories. 2 | exclude = [ 3 | ".bzr", 4 | ".direnv", 5 | ".eggs", 6 | ".git", 7 | ".git-rewrite", 8 | ".hg", 9 | ".ipynb_checkpoints", 10 | ".mypy_cache", 11 | ".nox", 12 | ".pants.d", 13 | ".pyenv", 14 | ".pytest_cache", 15 | ".pytype", 16 | ".ruff_cache", 17 | ".svn", 18 | ".tox", 19 | ".venv", 20 | ".vscode", 21 | "__pypackages__", 22 | "_build", 23 | "buck-out", 24 | "build", 25 | "dist", 26 | "node_modules", 27 | "site-packages", 28 | "venv", 29 | ] 30 | 31 | # Same as Black. 32 | line-length = 88 33 | indent-width = 4 34 | 35 | # Assume Python 3.10 36 | target-version = "py310" 37 | 38 | src = ["src", "tests"] 39 | 40 | [lint] 41 | # Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. 42 | select = ["ALL"] 43 | ignore = [ 44 | "COM812", 45 | "COM819", 46 | "D100", 47 | "D203", 48 | "D213", 49 | "D300", 50 | "E111", 51 | "E114", 52 | "E117", 53 | "ISC001", 54 | "ISC002", 55 | "Q000", 56 | "Q001", 57 | "Q002", 58 | "Q003", 59 | "W191", 60 | "T201", 61 | "D101", # Missing docstring in public class 62 | "D102", # Missing docstring in public method 63 | "D103", # Missing docstring in public function 64 | "D107", # Missing docstring in `__init__` 65 | "ANN201", # Missing return type annotation 66 | "ANN204", # Missing return type annotation 67 | "ANN401", 68 | "PD002", # inplace=True 69 | "C408", # Unnecessary dict call 70 | "E402", 71 | "UP035", 72 | "EM101", 73 | "TRY003", 74 | "G004", 75 | "TCH001", 76 | "TCH002", 77 | "PLR0913", 78 | "FIX002", # Line contains TODO, consider resolving the issue 79 | "TD002", # Missing author in TODO; 80 | "TD003", # Missing issue link on the line following this TODO 81 | "S101", # Use of `assert` detected 82 | ] 83 | 84 | 85 | # Allow fix for all enabled rules (when `--fix`) is provided. 86 | fixable = ["ALL"] 87 | unfixable = [] 88 | 89 | # Allow unused variables when underscore-prefixed. 90 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 91 | 92 | [format] 93 | # Like Black, use double quotes for strings. 94 | quote-style = "double" 95 | 96 | # Like Black, indent with spaces, rather than tabs. 97 | indent-style = "space" 98 | 99 | # Like Black, respect magic trailing commas. 
100 | skip-magic-trailing-comma = false 101 | 102 | # Like Black, automatically detect the appropriate line ending. 103 | line-ending = "auto" 104 | 105 | [lint.pydocstyle] 106 | convention = "google" 107 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_summarization/summarizer.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import networkx as nx 4 | from langchain_core.language_models import LanguageModelLike 5 | from langchain_core.runnables.config import RunnableConfig 6 | from tqdm import tqdm 7 | 8 | from langchain_graphrag.types.prompts import IndexingPromptBuilder 9 | 10 | from .prompt_builder import SummarizeDescriptionPromptBuilder 11 | 12 | 13 | class EntityRelationshipDescriptionSummarizer: 14 | def __init__( 15 | self, 16 | prompt_builder: IndexingPromptBuilder, 17 | llm: LanguageModelLike, 18 | *, 19 | chain_config: RunnableConfig | None = None, 20 | ): 21 | prompt, output_parser = prompt_builder.build() 22 | self._summarize_chain = prompt | llm | output_parser 23 | self._prompt_builder = prompt_builder 24 | self._chain_config = chain_config 25 | 26 | @staticmethod 27 | def build_default( 28 | llm: LanguageModelLike, 29 | *, 30 | chain_config: RunnableConfig | None = None, 31 | ) -> EntityRelationshipDescriptionSummarizer: 32 | return EntityRelationshipDescriptionSummarizer( 33 | prompt_builder=SummarizeDescriptionPromptBuilder(), 34 | llm=llm, 35 | chain_config=chain_config, 36 | ) 37 | 38 | def invoke(self, graph: nx.Graph) -> nx.Graph: 39 | for node_name, node in tqdm( 40 | graph.nodes(data=True), desc="Summarizing entities descriptions" 41 | ): 42 | if len(node["description"]) == 1: 43 | node["description"] = node["description"][0] 44 | continue 45 | 46 | chain_input = self._prompt_builder.prepare_chain_input( 47 | entity_name=node_name, description_list=node["description"] 48 | ) 49 | 50 | node["description"] = self._summarize_chain.invoke( 51 | input=chain_input, 52 | config=self._chain_config, 53 | ) 54 | 55 | for from_node, to_node, edge in tqdm( 56 | graph.edges(data=True), desc="Summarizing relationship descriptions" 57 | ): 58 | if len(edge["description"]) == 1: 59 | edge["description"] = edge["description"][0] 60 | continue 61 | 62 | chain_input = self._prompt_builder.prepare_chain_input( 63 | entity_name=f"{from_node} -> {to_node}", 64 | description_list=edge["description"], 65 | ) 66 | 67 | edge["description"] = self._summarize_chain.invoke( 68 | input=chain_input, 69 | config=self._chain_config, 70 | ) 71 | 72 | return graph 73 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/global_search/key_points_aggregator/context_builder.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from langchain_core.documents import Document 4 | 5 | from langchain_graphrag.query.global_search.key_points_generator.utils import ( 6 | KeyPointsResult, 7 | ) 8 | from langchain_graphrag.utils.token_counter import TokenCounter 9 | 10 | _REPORT_TEMPLATE = """ 11 | --- {analyst} --- 12 | 13 | Importance Score: {score} 14 | 15 | {content} 16 | 17 | """ 18 | 19 | _LOGGER = logging.getLogger(__name__) 20 | 21 | 22 | class KeyPointsContextBuilder: 23 | def __init__( 24 | self, 25 | token_counter: TokenCounter, 26 | max_tokens: int = 8000, 27 | ): 28 | self._token_counter = token_counter 
29 | self._max_tokens = max_tokens 30 | 31 | def __call__(self, key_points: dict[str, KeyPointsResult]) -> list[Document]: 32 | documents: list[Document] = [] 33 | total_tokens = 0 34 | max_token_limit_reached = False 35 | for k, v in key_points.items(): 36 | if max_token_limit_reached: 37 | break 38 | for p in v.points: 39 | report = _REPORT_TEMPLATE.format( 40 | analyst=k, 41 | score=p.score, 42 | content=p.description, 43 | ) 44 | report_token = self._token_counter.count_tokens(report) 45 | if total_tokens + report_token > self._max_tokens: 46 | _LOGGER.warning("Reached max tokens for key points aggregation ...") 47 | max_token_limit_reached = True 48 | break 49 | total_tokens += report_token 50 | documents.append( 51 | Document( 52 | page_content=report, 53 | metadata={ 54 | "score": p.score, 55 | "analyst": k, 56 | "token_count": report_token, 57 | }, 58 | ) 59 | ) 60 | 61 | # we now sort the documents based on the 62 | # importance score of the key points 63 | sorted_documents = sorted( 64 | documents, 65 | key=lambda x: x.metadata["score"], 66 | reverse=True, 67 | ) 68 | 69 | if _LOGGER.isEnabledFor(logging.DEBUG): 70 | import tableprint 71 | 72 | rows = [] 73 | tableprint.banner("KP Aggregation Context Token Usage") 74 | for doc in sorted_documents: 75 | rows.append([doc.metadata["analyst"], doc.metadata["token_count"]]) # noqa: PERF401 76 | 77 | tableprint.table(rows, ["Analyst", "Token Count"]) 78 | 79 | return sorted_documents 80 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/communities_reports.py: -------------------------------------------------------------------------------- 1 | """Select the communities to be used in the local search.""" 2 | 3 | import logging 4 | 5 | import pandas as pd 6 | 7 | from langchain_graphrag.types.graphs.community import CommunityId, CommunityLevel 8 | 9 | _LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | class CommunitiesReportsSelector: 13 | def __init__( 14 | self, 15 | community_level: CommunityLevel, 16 | *, 17 | must_have_selected_entities: bool = True, 18 | ): 19 | self._community_level = community_level 20 | self._must_have_selected_entities = must_have_selected_entities 21 | 22 | def run( 23 | self, 24 | df_entities: pd.DataFrame, 25 | df_reports: pd.DataFrame, 26 | ) -> pd.DataFrame: 27 | # Filter the communities based on the community level 28 | df_reports_filtered = df_reports[ 29 | df_reports["level"] <= self._community_level 30 | ].copy(deep=True) 31 | 32 | # get the communities we have 33 | selected_communities = df_reports_filtered["community_id"].unique() 34 | 35 | # we will rank the communities based on the 36 | # number of selected entities that belong to a community 37 | community_to_entities_count: dict[CommunityId, int] = {} 38 | for entity in df_entities.itertuples(): 39 | if entity.communities is None: 40 | continue 41 | for community in entity.communities: 42 | if community in selected_communities: 43 | community_to_entities_count[community] = ( 44 | community_to_entities_count.get(community, 0) + 1 45 | ) 46 | 47 | df_reports_filtered["selected_entities_count"] = df_reports_filtered[ 48 | "community_id" 49 | ].apply(lambda community_id: community_to_entities_count.get(community_id, 0)) 50 | 51 | # sort the communities based on the number of selected entities 52 | # and rank of the community 53 | selected_reports = df_reports_filtered.sort_values( 54 | by=["selected_entities_count", "rating"], 55 | ascending=[False, False], 
56 | ).reset_index(drop=True) 57 | 58 | if self._must_have_selected_entities: 59 | selected_reports = selected_reports[ 60 | selected_reports["selected_entities_count"] > 0 61 | ] 62 | 63 | if _LOGGER.isEnabledFor(logging.DEBUG): 64 | import tableprint 65 | 66 | tableprint.banner("Selected Reports") 67 | tableprint.dataframe( 68 | selected_reports[ 69 | ["community_id", "level", "selected_entities_count", "rating"] 70 | ] 71 | ) 72 | 73 | return selected_reports 74 | -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | # Welcome to GraphRAG using langchain 2 | 3 | **Transform your documents into searchable knowledge graphs** 4 | 5 | ## Overview 6 | 7 | This library is an implementation of concepts from the paper: 8 | 9 | [From Local to Global: A Graph RAG Approach to Query-Focused Summarization](https://arxiv.org/pdf/2404.16130) 10 | 11 | The excerpts below are taken from the companion website of the paper: 12 | [https://microsoft.github.io/graphrag/](https://microsoft.github.io/graphrag/) 13 | 14 | GraphRAG is a structured, hierarchical approach to Retrieval Augmented Generation (RAG), as opposed to naive semantic-search approaches using plain text snippets. The GraphRAG process involves extracting a knowledge graph out of raw text, building a community hierarchy, generating summaries for these communities, and then leveraging these structures when performing RAG-based tasks. 15 | 16 | There are two main phases in the GraphRAG process: 17 | 18 | ### Indexing 19 | 20 | * Slice up an input corpus into a series of TextUnits, which act as analyzable units for the rest of the process, and provide fine-grained references in our outputs. 21 | 22 | * Extract all entities, relationships, and key claims from the TextUnits using an LLM. 23 | 24 | * Perform a hierarchical clustering of the graph using the Leiden technique. 25 | 26 | * Generate summaries of each community and its constituents from the bottom-up. This aids in holistic understanding of the dataset. 27 | 28 | ### Query 29 | 30 | At query time, these structures are used to provide materials for the LLM context window when answering a question. The primary query modes are: 31 | 32 | * Global Search for reasoning about holistic questions about the corpus by leveraging the community summaries. 33 | 34 | * Local Search for reasoning about specific entities by fanning-out to their neighbors and associated concepts. 35 | 36 | ### Differences from the official implementation 37 | 38 | There is an official implementation of the paper available at 39 | [https://github.com/microsoft/graphrag](https://github.com/microsoft/graphrag) 40 | 41 | The main differences are: 42 | 43 | - Usage of [langchain](https://python.langchain.com/) as the foundation 44 | - Support for LLMs and embedding models other than the ones provided by Azure OpenAI 45 | - Focus on modularity, readability, and extensibility 46 | - Does not assume any workflow engine and leaves that choice to the application 47 | 48 | --- 49 | 50 | ## Installation 51 | 52 | ```bash 53 | pip install langchain-graphrag 54 | ``` 55 | 56 | ## Documentation 57 | 58 | ### 1. **[Architecture Overview](architecture/overview.md)** 59 | Understand how GraphRAG works and when to use Local vs Global search 60 | 61 | ### 2. **[Indexing Pipeline](guides/indexing_pipeline.md)** 62 | How to build knowledge graphs from your documents with technical implementation details 63 | 64 | ### 3.
**[Query System](guides/query_system.md)** 65 | Local Search vs Global Search with practical examples 66 | 67 | ### 4. **[Data Flow & Examples](guides/data_flow_examples.md)** 68 | Real data transformations through each pipeline step with actual JSON examples 69 | 70 | ### 5. **[Advanced Examples](guides/graph_extraction/index.md)** 71 | Jupyter notebooks for component-level customization and development -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "langchain-graphrag", 3 | "dockerFile": "Dockerfile", 4 | "customizations": { 5 | "vscode": { 6 | "settings": { 7 | "[python]": { 8 | "editor.tabSize": 4, 9 | "editor.insertSpaces": true, 10 | "editor.formatOnSave": true, 11 | "editor.defaultFormatter": "charliermarsh.ruff", 12 | "testing.unittestEnabled": false, 13 | "testing.pytestEnabled": true, 14 | "testing.pytestArgs": [ 15 | "tests" 16 | ], 17 | "defaultInterpreterPath": "./venv/bin/python" 18 | }, 19 | "files.exclude": { 20 | "**/.git": true, 21 | "**/.svn": true, 22 | "**/.hg": true, 23 | "**/CVS": true, 24 | "**/.DS_Store": true, 25 | "**/__pycache__": true 26 | }, 27 | "terminal.integrated.defaultProfile.linux": "zsh", 28 | "terminal.integrated.profiles.linux": { 29 | "bash": { 30 | "path": "bash", 31 | "icon": "terminal-bash" 32 | }, 33 | "zsh": { 34 | "path": "zsh" 35 | }, 36 | "fish": { 37 | "path": "fish" 38 | }, 39 | "tmux": { 40 | "path": "tmux", 41 | "icon": "terminal-tmux" 42 | }, 43 | "pwsh": { 44 | "path": "pwsh", 45 | "icon": "terminal-powershell" 46 | } 47 | } 48 | }, 49 | "extensions": [ 50 | "ms-python.python", 51 | "charliermarsh.ruff", 52 | "ms-python.vscode-pylance", 53 | "ms-toolsai.jupyter", 54 | "visualstudioexptteam.vscodeintellicode", 55 | "ms-python.mypy-type-checker", 56 | "github.vscode-github-actions" 57 | ] 58 | } 59 | }, 60 | "features": { 61 | "ghcr.io/devcontainers/features/common-utils:2": { 62 | "installOhMyZshConfig": false, 63 | "configureZshAsDefaultShell": true 64 | }, 65 | // Python 66 | "ghcr.io/devcontainers/features/python:1": { 67 | "version": "3.10" 68 | }, 69 | // Rust (required by few python libraries) 70 | "ghcr.io/devcontainers/features/rust:1": {}, 71 | // Enable Docker (via Docker-in-Docker) 72 | "ghcr.io/devcontainers/features/docker-in-docker:2": {}, 73 | // Modern shell utils 74 | "ghcr.io/mikaello/devcontainer-features/modern-shell-utils:1": {}, 75 | // uv (Python package manager) 76 | "ghcr.io/jsburckhardt/devcontainer-features/uv:1": {} 77 | }, 78 | "mounts": [ 79 | "source=devcontainer-zshhistory,target=/commandhistory,type=volume" 80 | ], 81 | "postCreateCommand": "bash scripts/post-create.sh" 82 | } -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/_system_prompt.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | 3 | LOCAL_SEARCH_SYSTEM_PROMPT = """ 4 | ---Role--- 5 | 6 | You are a helpful assistant responding to questions about data in the tables provided. 7 | 8 | ---Goal--- 9 | 10 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 11 | 12 | If you don't know the answer, just say so. Do not make anything up. 
13 | 14 | {{#show_references}} 15 | Points supported by data should list their data references as follows: 16 | 17 | "This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." 18 | 19 | Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 20 | {{/show_references}} 21 | 22 | For example: 23 | 24 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]{{/show_references}}." 25 | 26 | {{#show_references}} 27 | where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. 28 | {{/show_references}} 29 | 30 | Do not include information where the supporting evidence for it is not provided. 31 | 32 | ---Target response length and format--- 33 | 34 | {{response_type}} 35 | 36 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 37 | 38 | ---Data tables--- 39 | 40 | {{context_data}} 41 | 42 | {{#repeat_instructions}} 43 | 44 | ---Goal--- 45 | 46 | Generate a response of the target length and format that responds to the user's question, summarizing all information in the input data tables appropriate for the response length and format, and incorporating any relevant general knowledge. 47 | 48 | If you don't know the answer, just say so. Do not make anything up. 49 | 50 | {{#show_references}} 51 | Points supported by data should list their data references as follows: 52 | 53 | "This is an example sentence supported by multiple data references [Data: (record ids); (record ids)]." 54 | 55 | Do not list more than 5 record ids in a single reference. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 56 | {{/show_references}} 57 | 58 | For example: 59 | 60 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Sources (15, 16), Reports (1), Entities (5, 7); Relationships (23); Claims (2, 7, 34, 46, 64, +more)]{{/show_references}}." 61 | 62 | {{#show_references}}where 15, 16, 1, 5, 7, 23, 2, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record{{/show_references}}. 63 | 64 | Do not include information where the supporting evidence for it is not provided. 65 | 66 | ---Target response length and format--- 67 | 68 | {{response_type}} 69 | 70 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 
71 | 72 | {{/repeat_instructions}} 73 | """ 74 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/relationships.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pandas as pd 4 | from langchain_core.documents import Document 5 | 6 | from langchain_graphrag.query.local_search.context_selectors.relationships import ( 7 | RelationshipsSelectionResult, 8 | ) 9 | from langchain_graphrag.types.tokens import TokenCounter 10 | 11 | _LOGGER = logging.getLogger(__name__) 12 | 13 | 14 | class RelationshipsContextBuilder: 15 | def __init__( 16 | self, 17 | *, 18 | include_weight: bool = True, 19 | context_name: str = "Relationships", 20 | column_delimiter: str = "|", 21 | max_tokens: int = 8000, 22 | token_counter: TokenCounter, 23 | ): 24 | self._include_weight = include_weight 25 | self._context_name = context_name 26 | self._column_delimiter = column_delimiter 27 | self._max_tokens = max_tokens 28 | self._token_counter = token_counter 29 | 30 | def __call__( 31 | self, 32 | selected_relationships: RelationshipsSelectionResult, 33 | ) -> Document: 34 | all_context_text = f"-----{self._context_name}-----" + "\n" 35 | header = ["id", "source", "target", "description"] 36 | if self._include_weight: 37 | header.append("weight") 38 | 39 | all_context_text += self._column_delimiter.join(header) + "\n" 40 | all_token_count = self._token_counter.count_tokens(all_context_text) 41 | 42 | def _build_context_text( 43 | relationships: pd.DataFrame, 44 | context_text: str, 45 | token_count: int, 46 | ) -> tuple[str, int]: 47 | for relationship in relationships.itertuples(): 48 | new_context = [ 49 | str(relationship.human_readable_id), 50 | relationship.source, 51 | relationship.target, 52 | relationship.description, 53 | ] 54 | if self._include_weight: 55 | new_context.append(str(relationship.weight)) 56 | 57 | new_context_text = self._column_delimiter.join(new_context) + "\n" 58 | new_token_count = self._token_counter.count_tokens(new_context_text) 59 | 60 | if token_count + new_token_count > self._max_tokens: 61 | _LOGGER.warning( 62 | f"Stopping relationships context build at {token_count} tokens..." # noqa: E501 63 | ) 64 | return context_text, token_count 65 | 66 | context_text += new_context_text 67 | token_count += new_token_count 68 | 69 | return context_text, token_count 70 | 71 | all_context_text, all_token_count = _build_context_text( 72 | selected_relationships.in_network_relationships, 73 | all_context_text, 74 | all_token_count, 75 | ) 76 | 77 | if all_token_count < self._max_tokens: 78 | all_context_text, all_token_count = _build_context_text( 79 | selected_relationships.out_network_relationships, 80 | all_context_text, 81 | all_token_count, 82 | ) 83 | 84 | return Document( 85 | page_content=all_context_text, 86 | metadata={"token_count": all_token_count}, 87 | ) 88 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "simple-app-indexing-azure", 9 | "type": "debugpy", 10 | "request": "launch", 11 | "program": "${workspaceFolder}/examples/simple-app/app/main.py", 12 | "justMyCode": false, 13 | "console": "integratedTerminal", 14 | "env": { 15 | "PYTHONPATH": "${workspaceFolder}" 16 | }, 17 | "args": [ 18 | "indexer", 19 | "index", 20 | "--input-file", 21 | "${workspaceFolder}/examples/input-data/book.txt", 22 | "--output-dir", 23 | "${workspaceFolder}/tmp", 24 | "--cache-dir", 25 | "${workspaceFolder}/tmp/cache", 26 | "--llm-type", 27 | "azure_openai", 28 | "--llm-model", 29 | "gpt-4o", 30 | "--embedding-type", 31 | "azure_openai", 32 | "--embedding-model", 33 | "text-embedding-3-small", 34 | ] 35 | }, 36 | { 37 | "name": "simple-app-global-query-azure", 38 | "type": "debugpy", 39 | "request": "launch", 40 | "program": "${workspaceFolder}/examples/simple-app/app/main.py", 41 | "justMyCode": false, 42 | "console": "integratedTerminal", 43 | "env": { 44 | "PYTHONPATH": "${workspaceFolder}" 45 | }, 46 | "args": [ 47 | "query", 48 | "global-search", 49 | "--output-dir", 50 | "${workspaceFolder}/tmp", 51 | "--cache-dir", 52 | "${workspaceFolder}/tmp/cache", 53 | "--query", 54 | "What are the top themes in this story?", 55 | "--llm-type", 56 | "azure_openai", 57 | "--llm-model", 58 | "gpt-4o" 59 | ] 60 | }, 61 | { 62 | "name": "simple-app-local-query-azure", 63 | "type": "debugpy", 64 | "request": "launch", 65 | "program": "${workspaceFolder}/examples/simple-app/app/main.py", 66 | "justMyCode": false, 67 | "console": "integratedTerminal", 68 | "env": { 69 | "PYTHONPATH": "${workspaceFolder}" 70 | }, 71 | "args": [ 72 | "query", 73 | "local-search", 74 | "--output-dir", 75 | "${workspaceFolder}/tmp", 76 | "--cache-dir", 77 | "${workspaceFolder}/tmp/cache", 78 | "--query", 79 | "Who is Scrooge, and what are his main relationships?", 80 | "--llm-type", 81 | "azure_openai", 82 | "--llm-model", 83 | "gpt-4o", 84 | "--embedding-type", 85 | "azure_openai", 86 | "--embedding-model", 87 | "text-embedding-3-small", 88 | ] 89 | } 90 | ] 91 | } -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/text_units.py: -------------------------------------------------------------------------------- 1 | """Build the TextUnit context for the LocalSearch algorithm.""" 2 | 3 | import logging 4 | from typing import TypedDict 5 | 6 | import pandas as pd 7 | 8 | _LOGGER = logging.getLogger(__name__) 9 | 10 | 11 | class SelectedTextUnit(TypedDict): 12 | id: str 13 | short_id: str 14 | entity_score: float 15 | relationship_score: int 16 | text_unit: str 17 | 18 | 19 | def compute_relationship_score( 20 | df_relationships: pd.DataFrame, 21 | df_text_relationships: pd.DataFrame, 22 | entity_title: str, 23 | ) -> int: 24 | relationships_subset = df_relationships[ 25 | df_relationships["id"].isin(df_text_relationships) 26 | ] 27 | 28 | source_count = (relationships_subset["source"] == entity_title).sum() 29 | target_count = (relationships_subset["target"] == entity_title).sum() 30 | 31 | return source_count + target_count 32 | 33 | 34 | class TextUnitsSelector: 35 | def run( 36 | self, 37 | df_entities: pd.DataFrame, 38 | df_relationships: pd.DataFrame, 39 | df_text_units: pd.DataFrame, 40 | ) -> pd.DataFrame: 41 | """Build the TextUnit context for the LocalSearch algorithm.""" 42 | selected_text_units: 
dict[str, SelectedTextUnit] = {} 43 | 44 | def _process_text_unit_id(text_unit_id: str) -> SelectedTextUnit: 45 | df_texts_units_subset = df_text_units[df_text_units["id"] == text_unit_id] 46 | text_relationship_ids = df_texts_units_subset["relationship_ids"].explode() 47 | 48 | relationship_score = compute_relationship_score( 49 | df_relationships, 50 | text_relationship_ids, 51 | entity.title, 52 | ) 53 | 54 | text_unit = df_texts_units_subset["text_unit"].iloc[0] 55 | short_id = df_texts_units_subset.index.to_numpy()[0] 56 | 57 | return SelectedTextUnit( 58 | id=text_unit_id, 59 | short_id=short_id, 60 | entity_score=entity.score, 61 | relationship_score=relationship_score, 62 | text_unit=text_unit, 63 | ) 64 | 65 | def _process_entity(entity) -> None: # noqa: ANN001 66 | for text_unit_id in entity.text_unit_ids: 67 | if text_unit_id in selected_text_units: 68 | continue 69 | selected_text_units[text_unit_id] = _process_text_unit_id(text_unit_id) 70 | 71 | for entity in df_entities.itertuples(): 72 | _process_entity(entity) 73 | 74 | df_selected_text_units = pd.DataFrame.from_records( 75 | list(selected_text_units.values()) 76 | ) 77 | 78 | # sort by 79 | # descending order of entity_score 80 | # and then descending order of relationship_score 81 | df_selected_text_units = df_selected_text_units.sort_values( 82 | by=["entity_score", "relationship_score"], 83 | ascending=[False, False], 84 | ).reset_index(drop=True) 85 | 86 | if _LOGGER.isEnabledFor(logging.DEBUG): 87 | import tableprint 88 | 89 | tableprint.banner("Selected Text Units") 90 | tableprint.dataframe( 91 | df_selected_text_units[["id", "entity_score", "relationship_score"]] 92 | ) 93 | 94 | return df_selected_text_units 95 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_builders/context.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import logging 4 | 5 | from langchain_core.documents import Document 6 | 7 | from langchain_graphrag.query.local_search.context_selectors import ( 8 | ContextSelectionResult, 9 | ) 10 | from langchain_graphrag.types.tokens import TokenCounter 11 | 12 | from .communities_reports import CommunitiesReportsContextBuilder 13 | from .entities import EntitiesContextBuilder 14 | from .relationships import RelationshipsContextBuilder 15 | from .text_units import TextUnitsContextBuilder 16 | 17 | _LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | class ContextBuilder: 21 | def __init__( 22 | self, 23 | entities_context_builder: EntitiesContextBuilder, 24 | relationships_context_builder: RelationshipsContextBuilder, 25 | text_units_context_builder: TextUnitsContextBuilder, 26 | communities_reports_context_builder: CommunitiesReportsContextBuilder, 27 | ): 28 | self._entities_context_builder = entities_context_builder 29 | self._relationships_context_builder = relationships_context_builder 30 | self._text_units_context_builder = text_units_context_builder 31 | self._communities_reports_context_builder = communities_reports_context_builder 32 | 33 | @staticmethod 34 | def build_default(token_counter: TokenCounter) -> ContextBuilder: 35 | return ContextBuilder( 36 | entities_context_builder=EntitiesContextBuilder( 37 | token_counter=token_counter, 38 | ), 39 | relationships_context_builder=RelationshipsContextBuilder( 40 | token_counter=token_counter, 41 | ), 42 | text_units_context_builder=TextUnitsContextBuilder( 43 |
token_counter=token_counter, 44 | ), 45 | communities_reports_context_builder=CommunitiesReportsContextBuilder( 46 | token_counter=token_counter, 47 | ), 48 | ) 49 | 50 | def __call__(self, result: ContextSelectionResult) -> list[Document]: 51 | entities_document = self._entities_context_builder(result.entities) 52 | relationships_document = self._relationships_context_builder( 53 | result.relationships 54 | ) 55 | text_units_document = self._text_units_context_builder(result.text_units) 56 | communities_reports_document = self._communities_reports_context_builder( 57 | result.communities_reports 58 | ) 59 | 60 | documents = [ 61 | entities_document, 62 | relationships_document, 63 | text_units_document, 64 | communities_reports_document, 65 | ] 66 | 67 | if _LOGGER.isEnabledFor(logging.DEBUG): 68 | import tableprint 69 | 70 | rows = [] 71 | tableprint.banner("Context Token Usage") 72 | for name, doc in zip( 73 | ["Entities", "Relationships", "Text Units", "Communities Reports"], 74 | [ 75 | entities_document, 76 | relationships_document, 77 | text_units_document, 78 | communities_reports_document, 79 | ], 80 | strict=True, 81 | ): 82 | rows.append([name, doc.metadata["token_count"]]) 83 | 84 | tableprint.table(rows, ["Context", "Token Count"]) 85 | 86 | return documents 87 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/context.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from typing import NamedTuple 4 | 5 | import pandas as pd 6 | from langchain_core.vectorstores import VectorStore 7 | 8 | from langchain_graphrag.indexing.artifacts import IndexerArtifacts 9 | from langchain_graphrag.types.graphs.community import CommunityLevel 10 | 11 | from .communities_reports import CommunitiesReportsSelector 12 | from .entities import EntitiesSelector 13 | from .relationships import RelationshipsSelectionResult, RelationshipsSelector 14 | from .text_units import TextUnitsSelector 15 | 16 | 17 | class ContextSelectionResult(NamedTuple): 18 | entities: pd.DataFrame 19 | text_units: pd.DataFrame 20 | relationships: RelationshipsSelectionResult 21 | communities_reports: pd.DataFrame 22 | 23 | 24 | class ContextSelector: 25 | def __init__( 26 | self, 27 | entities_selector: EntitiesSelector, 28 | text_units_selector: TextUnitsSelector, 29 | relationships_selector: RelationshipsSelector, 30 | communities_reports_selector: CommunitiesReportsSelector, 31 | ): 32 | self._entities_selector = entities_selector 33 | self._text_units_selector = text_units_selector 34 | self._relationships_selector = relationships_selector 35 | self._communities_reports_selector = communities_reports_selector 36 | 37 | @staticmethod 38 | def build_default( 39 | entities_vector_store: VectorStore, 40 | entities_top_k: int, 41 | community_level: CommunityLevel, 42 | ) -> ContextSelector: 43 | return ContextSelector( 44 | entities_selector=EntitiesSelector( 45 | vector_store=entities_vector_store, 46 | top_k=entities_top_k, 47 | ), 48 | text_units_selector=TextUnitsSelector(), 49 | relationships_selector=RelationshipsSelector(), 50 | communities_reports_selector=CommunitiesReportsSelector( 51 | community_level=community_level 52 | ), 53 | ) 54 | 55 | def run( 56 | self, 57 | query: str, 58 | artifacts: IndexerArtifacts, 59 | ): 60 | # Step 1 61 | # Select the entities to be used in the local search 62 | selected_entities = 
self._entities_selector.run(query, artifacts.entities) 63 | 64 | # Step 2 65 | # Select the text units to be used in the local search 66 | selected_text_units = self._text_units_selector.run( 67 | df_entities=selected_entities, 68 | df_relationships=artifacts.relationships, 69 | df_text_units=artifacts.text_units, 70 | ) 71 | 72 | # Step 3 73 | # Select the relationships to be used in the local search 74 | selected_relationships = self._relationships_selector.run( 75 | df_entities=selected_entities, 76 | df_relationships=artifacts.relationships, 77 | ) 78 | 79 | # Step 4 80 | # Select the communities to be used in the local search 81 | selected_communities_reports = self._communities_reports_selector.run( 82 | df_entities=selected_entities, 83 | df_reports=artifacts.communities_reports, 84 | ) 85 | 86 | return ContextSelectionResult( 87 | entities=selected_entities, 88 | text_units=selected_text_units, 89 | relationships=selected_relationships, 90 | communities_reports=selected_communities_reports, 91 | ) 92 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/simple_indexer.py: -------------------------------------------------------------------------------- 1 | """A simple indexer that sequentially uses various components to index a document. 2 | 3 | You could implement a more complex indexer that uses a pipeline, a workflow 4 | engine, etc., to index a document. 5 | 6 | """ 7 | 8 | from langchain_core.documents import Document 9 | 10 | from langchain_graphrag.types.graphs.community import CommunityDetector 11 | 12 | from .artifacts import IndexerArtifacts 13 | from .artifacts_generation import ( 14 | CommunitiesReportsArtifactsGenerator, 15 | EntitiesArtifactsGenerator, 16 | RelationshipsArtifactsGenerator, 17 | TextUnitsArtifactsGenerator, 18 | ) 19 | from .graph_generation import GraphGenerator 20 | from .text_unit_extractor import TextUnitExtractor 21 | 22 | 23 | class SimpleIndexer: 24 | def __init__( 25 | self, 26 | text_unit_extractor: TextUnitExtractor, 27 | graph_generator: GraphGenerator, 28 | community_detector: CommunityDetector, 29 | entities_artifacts_generator: EntitiesArtifactsGenerator, 30 | relationships_artifacts_generator: RelationshipsArtifactsGenerator, 31 | communities_report_artifacts_generator: CommunitiesReportsArtifactsGenerator, 32 | text_units_artifacts_generator: TextUnitsArtifactsGenerator, 33 | ): 34 | self._text_unit_extractor = text_unit_extractor 35 | self._graph_generator = graph_generator 36 | self._community_detector = community_detector 37 | self._entities_artifacts_generator = entities_artifacts_generator 38 | self._relationships_artifacts_generator = relationships_artifacts_generator 39 | self._communities_report_artifacts_generator = ( 40 | communities_report_artifacts_generator 41 | ) 42 | self._text_units_artifacts_generator = text_units_artifacts_generator 43 | 44 | def run(self, documents: list[Document]) -> IndexerArtifacts: 45 | # Step 1 - Text Unit extraction 46 | df_base_text_units = self._text_unit_extractor.run(documents) 47 | 48 | # Step 2 - Generate graphs 49 | merged_graph, summarized_graph = self._graph_generator.run(df_base_text_units) 50 | 51 | # Step 3 - Detect communities in Graph 52 | community_detection_result = self._community_detector.run(summarized_graph) 53 | 54 | # Step 4 - Reports for detected Communities (depends on Step 2 & Step 3) 55 | df_communities_reports = self._communities_report_artifacts_generator.run( 56 | community_detection_result, 57 |
summarized_graph,
58 |         )
59 | 
60 |         # Step 5 - Entities generation (depends on Step 2 & Step 3)
61 |         df_entities = self._entities_artifacts_generator.run(
62 |             community_detection_result,
63 |             summarized_graph,
64 |         )
65 | 
66 |         # Step 6 - Relationships generation (depends on Step 2)
67 |         df_relationships = self._relationships_artifacts_generator.run(summarized_graph)
68 | 
69 |         # Step 7 - Text Units generation (depends on Steps 1, 5, 6)
70 |         df_text_units = self._text_units_artifacts_generator.run(
71 |             df_base_text_units,
72 |             df_entities,
73 |             df_relationships,
74 |         )
75 | 
76 |         return IndexerArtifacts(
77 |             entities=df_entities,
78 |             relationships=df_relationships,
79 |             text_units=df_text_units,
80 |             communities_reports=df_communities_reports,
81 |             summarized_graph=summarized_graph,
82 |             merged_graph=merged_graph,
83 |             communities=community_detection_result,
84 |         )
85 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/artifacts.py:
--------------------------------------------------------------------------------
1 | from typing import NamedTuple
2 | 
3 | import networkx as nx
4 | import pandas as pd
5 | import tableprint
6 | 
7 | from langchain_graphrag.types.graphs.community import CommunityDetectionResult
8 | 
9 | 
10 | class IndexerArtifacts(NamedTuple):
11 |     entities: pd.DataFrame
12 |     relationships: pd.DataFrame
13 |     text_units: pd.DataFrame
14 |     communities_reports: pd.DataFrame
15 |     merged_graph: nx.Graph | None = None
16 |     summarized_graph: nx.Graph | None = None
17 |     communities: CommunityDetectionResult | None = None
18 | 
19 |     def _entity_info(self, top_k: int) -> None:
20 |         tableprint.banner("Entities")
21 | 
22 |         rows = [
23 |             ["Count", len(self.entities)],
24 |             ["Number of types", len(self.entities["type"].unique())],
25 |         ]
26 |         tableprint.table(rows)
27 | 
28 |         # entity types
29 |         tableprint.banner("Entity Types")
30 |         rows = []
31 |         for entity_type, count in self.entities["type"].value_counts().items():
32 |             rows.append([entity_type, count])
33 |         tableprint.table(rows, ["Type", "Count"])
34 | 
35 |         # entities for which the generated type is an empty string
36 |         empty_type_entities = self.entities[self.entities["type"] == ""]
37 |         if not empty_type_entities.empty:
38 |             tableprint.banner("Entities with Empty Type")
39 |             tableprint.dataframe(empty_type_entities[["title", "degree"]])
40 | 
41 |         # k most connected entities
42 |         by_degree = self.entities.sort_values("degree", ascending=False)[:top_k]
43 | 
44 |         # print k most connected entities
45 |         tableprint.banner(f"{top_k} Most Connected Entities")
46 |         tableprint.dataframe(by_degree[["title", "degree"]])
47 | 
48 |         # are there entities with degree 0?
49 |         zero_degree_entities = self.entities[self.entities["degree"] == 0]
50 |         if not zero_degree_entities.empty:
51 |             tableprint.banner("Disconnected Entities")
52 |             tableprint.dataframe(zero_degree_entities[["title", "degree"]])
53 | 
54 |     def _relationships_info(self, top_k: int) -> None:
55 |         tableprint.banner("Relationships")
56 |         rows = [["Count", len(self.relationships)]]
57 |         tableprint.table(rows)
58 | 
59 |         # k highest ranked relationships
60 |         by_rank = self.relationships.sort_values("rank", ascending=False)[:top_k]
61 | 
62 |         # print the top_k highest ranked relationships
63 |         tableprint.banner(f"{top_k} Top Ranked Relationships")
64 |         tableprint.dataframe(by_rank[["source", "target", "rank"]])
65 | 
66 |     def _text_units_info(self) -> None:
67 |         tableprint.banner("Text Units")
68 |         rows = [["Count", len(self.text_units)]]
69 |         tableprint.table(rows)
70 | 
71 |     def _communities_reports_info(self) -> None:
72 |         tableprint.banner("Communities Reports")
73 | 
74 |         levels = self.communities_reports["level"].unique()
75 | 
76 |         rows = []
77 |         for level in levels:
78 |             communities = self.communities_reports[
79 |                 self.communities_reports["level"] == level
80 |             ]
81 |             row = [level, len(communities)]
82 |             rows.append(row)
83 | 
84 |         tableprint.table(rows, ["Level", "Number of Communities"])
85 | 
86 |     def report(
87 |         self,
88 |         top_k_entities: int = 5,
89 |         top_k_relationships: int = 5,
90 |     ) -> None:
91 |         self._text_units_info()
92 |         self._entity_info(top_k=top_k_entities)
93 |         self._relationships_info(top_k=top_k_relationships)
94 |         self._communities_reports_info()
95 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/artifacts_generation/text_units.py:
--------------------------------------------------------------------------------
1 | from typing import Any, cast
2 | 
3 | import pandas as pd
4 | from langchain_core.vectorstores import VectorStore
5 | from pandas._typing import Suffixes
6 | from tqdm import tqdm
7 | 
8 | 
9 | def _array_agg_distinct(series: pd.Series) -> list[Any]:
10 |     return list(series.unique())
11 | 
12 | 
13 | def _make_temporary_frame(
14 |     entities_or_relationships: pd.DataFrame,
15 |     rename_id_to: str,
16 | ) -> pd.DataFrame:
17 |     # select only id & text_unit_ids columns
18 |     tmp = entities_or_relationships[["id", "text_unit_ids"]]
19 | 
20 |     # flatten the text_unit_ids
21 |     tmp = tmp.explode("text_unit_ids")
22 | 
23 |     # group by text_unit_ids
24 |     grouped = tmp.groupby("text_unit_ids", sort=False)
25 | 
26 |     aggregations = {"id": _array_agg_distinct}
27 | 
28 |     output = cast(pd.DataFrame, grouped.agg(aggregations))
29 |     output.rename(columns={"id": rename_id_to}, inplace=True)
30 | 
31 |     return output.reset_index()
32 | 
33 | 
34 | class TextUnitsArtifactsGenerator:
35 |     def __init__(self, vector_store: VectorStore | None = None):
36 |         self._vector_store = vector_store
37 | 
38 |     def run(
39 |         self,
40 |         base_text_units: pd.DataFrame,
41 |         entities: pd.DataFrame,
42 |         relationships: pd.DataFrame,
43 |     ) -> pd.DataFrame:
44 |         entities_df = _make_temporary_frame(
45 |             entities,
46 |             rename_id_to="entity_ids",
47 |         )
48 |         entities_df.rename(columns={"text_unit_ids": "id"}, inplace=True)
49 | 
50 |         relationships_df = _make_temporary_frame(
51 |             relationships,
52 |             rename_id_to="relationship_ids",
53 |         )
54 |         relationships_df.rename(columns={"text_unit_ids": "id"}, inplace=True)
55 | 
56 |         text_units_entities = base_text_units.merge(
57 |             entities_df,
58 |             left_on="id",
59 |             right_on="id",
60 |             how="left",
61 |             suffixes=cast(Suffixes, ["_1", "_2"]),
62 |             indicator=True,
63 |         ).drop("_merge", axis=1)
64 | 
65 |         text_units = text_units_entities.merge(
66 |             relationships_df,
67 |             left_on="id",
68 |             right_on="id",
69 |             how="left",
70 |             suffixes=cast(Suffixes, ["_1", "_2"]),
71 |             indicator=True,
72 |         ).drop("_merge", axis=1)
73 | 
74 |         def _run_embedder(series: pd.Series) -> None:
75 |             chunk_to_embed = series["text_unit"]
76 |             chunk_id = series["id"]
77 | 
78 |             # Bug in langchain vectorstore retrieval that
79 |             # does not populate Document.id field.
80 |             #
81 |             # Hence add text_unit_id as an additional field
82 |             # in the metadata
83 |             chunk_metadata = dict(
84 |                 document_id=series["document_id"],
85 |                 text_unit_id=chunk_id,  # TODO: Remove once langchain is fixed
86 |             )
87 | 
88 |             assert self._vector_store is not None
89 | 
90 |             self._vector_store.add_texts(
91 |                 [chunk_to_embed],
92 |                 metadatas=[chunk_metadata],
93 |                 ids=[chunk_id],
94 |             )
95 | 
96 |         if self._vector_store:
97 |             tqdm.pandas(desc="Generating chunk embedding ...")
98 |             text_units.progress_apply(
99 |                 _run_embedder,
100 |                 axis=1,
101 |             )
102 | 
103 |         return text_units
104 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/artifacts_generation/entities.py:
--------------------------------------------------------------------------------
1 | import networkx as nx
2 | import numpy as np
3 | import pandas as pd
4 | from langchain_core.vectorstores import VectorStore
5 | 
6 | from langchain_graphrag.types.graphs.community import (
7 |     CommunityDetectionResult,
8 |     CommunityId,
9 | )
10 | from langchain_graphrag.types.graphs.embedding import GraphEmbeddingGenerator
11 | 
12 | 
13 | def _make_entity_to_communities_map(
14 |     detection_result: CommunityDetectionResult,
15 | ) -> dict[str, list[CommunityId]]:
16 |     entity_to_communities: dict[str, list[CommunityId]] = {}
17 |     for level in detection_result.communities:
18 |         communities = detection_result.communities_at_level(level)
19 |         for c in communities:
20 |             for node in c.nodes:
21 |                 entity_to_communities.setdefault(node.name, []).append(c.id)
22 |     return entity_to_communities
23 | 
24 | 
25 | class EntitiesArtifactsGenerator:
26 |     def __init__(
27 |         self,
28 |         entities_vector_store: VectorStore,
29 |         graph_embedding_generator: GraphEmbeddingGenerator | None = None,
30 |     ):
31 |         self._graph_embedding_generator = graph_embedding_generator
32 |         self._entities_vector_store = entities_vector_store
33 | 
34 |     def _unpack_nodes(
35 |         self,
36 |         graph: nx.Graph,
37 |         entity_to_commnunities_map: dict[str, list[CommunityId]],
38 |         graph_embeddings: dict[str, np.ndarray] | None,
39 |     ) -> pd.DataFrame:
40 |         records = [
41 |             {
42 |                 "title": label,
43 |                 **(node_data or {}),
44 |                 "communities": entity_to_commnunities_map.get(label),
45 |                 "graph_embedding": graph_embeddings.get(label)
46 |                 if graph_embeddings
47 |                 else None,
48 |             }
49 |             for label, node_data in graph.nodes(data=True)
50 |         ]
51 |         return pd.DataFrame.from_records(records)
52 | 
53 |     def run(
54 |         self,
55 |         detection_result: CommunityDetectionResult,
56 |         graph: nx.Graph,
57 |     ) -> pd.DataFrame:
58 |         # Step 1 (Optional)
59 |         # Generate graph embeddings
60 |         graph_embeddings = (
61 |             self._graph_embedding_generator.run(graph)
62 |             if self._graph_embedding_generator
63 |             else None
64 |         )
65 | 
66 |         # Step 2
67 |         # Extract the information to embed from the graph
68 |         # and put in the vectorstore
69 |         texts_to_embed = []
70 |         texts_metadata = []
71 |         texts_ids = []
72 |         for name, node_data in graph.nodes(data=True):
73 |             text_description = node_data.get("description")
74 |             texts_ids.append(node_data.get("id"))
75 |             texts_to_embed.append(f"{name}:{text_description}")
76 | 
77 |             # Bug in langchain vectorstore retrieval that
78 |             # does not populate Document.id field.
79 | # 80 | # Hence add entity_id as an additional field 81 | # in the metadata 82 | texts_metadata.append( 83 | dict( 84 | name=name, 85 | description=text_description, 86 | degree=node_data.get("degree"), 87 | entity_id=node_data.get( 88 | "id" 89 | ), # TODO: Remove once langchain is fixed 90 | ) 91 | ) 92 | 93 | self._entities_vector_store.add_texts( 94 | texts_to_embed, 95 | metadatas=texts_metadata, 96 | ids=texts_ids, 97 | ) 98 | 99 | entity_to_commnunities_map = _make_entity_to_communities_map(detection_result) 100 | 101 | # Step 3 102 | # Make a dataframe 103 | return self._unpack_nodes( 104 | graph, 105 | entity_to_commnunities_map, 106 | graph_embeddings, 107 | ) 108 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 
109 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 110 | .pdm.toml 111 | .pdm-python 112 | .pdm-build/ 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | 157 | # PyCharm 158 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 159 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 160 | # and can be added to the global gitignore or merged into this file. For a more nuclear 161 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 162 | #.idea/ 163 | 164 | tmp 165 | temp 166 | outputs 167 | 168 | .env 169 | test-data 170 | scratch 171 | .DS_Store 172 | tmp-* -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/graphs_merger.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | from random import Random 3 | from typing import Any 4 | 5 | import networkx as nx 6 | 7 | from langchain_graphrag.utils.uuid import gen_uuid 8 | 9 | 10 | class AttributesToMerge(str, Enum): 11 | text_unit_ids = "text_unit_ids" 12 | description = "description" 13 | weight = "weight" 14 | 15 | 16 | def merge_attributes( 17 | *, 18 | target_node: dict[str, Any], 19 | source_node: dict[str, Any], 20 | attribs: list[AttributesToMerge], 21 | ): 22 | for attrib in attribs: 23 | # I am expecting the attributes are not missing 24 | target_attrib = target_node.get(attrib) 25 | source_attrib = source_node.get(attrib) 26 | if attrib == AttributesToMerge.weight: 27 | target_node[attrib] = int(target_attrib) + int(source_attrib) 28 | else: 29 | target_node[attrib].extend(source_attrib) 30 | target_node[attrib] = sorted(set(target_node[attrib])) 31 | 32 | 33 | def merge_nodes(*, target_graph: nx.Graph, sub_graph: nx.Graph): 34 | for node in sub_graph.nodes: 35 | if node not in target_graph.nodes: 36 | target_graph.add_node(node, **(sub_graph.nodes[node] or {})) 37 | else: 38 | merge_attributes( 39 | target_node=target_graph.nodes[node], 40 | source_node=sub_graph.nodes[node], 41 | attribs=[ 42 | AttributesToMerge.text_unit_ids, 43 | AttributesToMerge.description, 44 | ], 45 | ) 46 | 47 | 48 | def merge_edges(*, target_graph: nx.Graph, sub_graph: nx.Graph): 49 | for source, target, edge_data in sub_graph.edges(data=True): 50 | if not target_graph.has_edge(source, target): 51 | target_graph.add_edge(source, target, **(edge_data or {})) 52 | else: 53 | merge_attributes( 54 | target_node=target_graph.edges[(source, target)], 55 | source_node=edge_data, 56 | attribs=[ 57 | AttributesToMerge.text_unit_ids, 58 | AttributesToMerge.description, 59 | AttributesToMerge.weight, 60 | ], 61 | ) 62 | 63 | 64 | class GraphsMerger: 65 | def __init__(self, 
seed: int = 0xF001):
66 |         self._seed = seed
67 | 
68 |     def __call__(
69 |         self,
70 |         graphs: list[nx.Graph],
71 |     ) -> nx.Graph:
72 |         merged_graph: nx.Graph = nx.Graph()
73 |         for g in graphs:
74 |             merge_nodes(target_graph=merged_graph, sub_graph=g)
75 |             merge_edges(target_graph=merged_graph, sub_graph=g)
76 | 
77 |         # add degree as an attribute
78 |         for node, degree in merged_graph.degree:
79 |             merged_graph.nodes[node]["degree"] = int(degree)
80 | 
81 |         # add source degree, target degree and rank as attributes
82 |         # to the edges
83 |         for source, target in merged_graph.edges():
84 |             source_degree = merged_graph.nodes[source]["degree"]
85 |             target_degree = merged_graph.nodes[target]["degree"]
86 |             merged_graph.edges[source, target]["source_degree"] = source_degree
87 |             merged_graph.edges[source, target]["target_degree"] = target_degree
88 |             merged_graph.edges[source, target]["rank"] = source_degree + target_degree
89 | 
90 |         random = Random(self._seed)  # noqa: S311
91 | 
92 |         # add ids to nodes
93 |         for index, node in enumerate(merged_graph.nodes()):
94 |             merged_graph.nodes[node]["human_readable_id"] = index
95 |             merged_graph.nodes[node]["id"] = str(gen_uuid(random))
96 | 
97 |         # add ids to edges
98 |         for index, edge in enumerate(merged_graph.edges()):
99 |             merged_graph.edges[edge]["human_readable_id"] = index
100 |             merged_graph.edges[edge]["id"] = str(gen_uuid(random))
101 | 
102 |         return merged_graph
103 | 
--------------------------------------------------------------------------------
/docs/guides/graph_extraction/index.md:
--------------------------------------------------------------------------------
1 | # Advanced Examples - Graph Extraction
2 | 
3 | Component-level customization and development
4 | 
5 | The graph extraction is done by the `GraphGenerator` class - a wrapper over three core components. This section explains how to customize and extend individual components.
6 | 
7 | **Prerequisites**: Read the core documentation guides first: [Architecture Overview](../../architecture/overview.md), [Indexing Pipeline](../indexing_pipeline.md), [Query System](../query_system.md), and [Data Flow Examples](../data_flow_examples.md).
8 | 
9 | Note: You can pass your own implementations of these components provided they follow the same protocol.
10 | 
11 | ---
12 | 
13 | ## EntityRelationshipExtractor
14 | 
15 | **Purpose**: Generates a `networkx` graph for every `text_unit` in the dataframe.
16 | 
17 | **Process**: This component extracts entities and relationships from individual text units using LLM analysis.
18 | 
19 | See the [Entity Relationship Extraction notebook](er_extraction.ipynb) for a detailed interactive guide on this component.
20 | 
21 | **Customization Options**:
22 | - Custom entity types and extraction prompts
23 | - Different LLM models for extraction
24 | - Custom output parsing logic
25 | - Domain-specific entity recognition rules
26 | 
27 | ---
28 | 
29 | ## GraphsMerger
30 | 
31 | **Purpose**: Merges all the graphs generated by `EntityRelationshipExtractor` into a single graph.
32 | 
33 | **Process**: Nodes and edges that appear in multiple text units are collapsed into one, and their individual descriptions are accumulated into a list for later summarization (see the sketch below).
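
A minimal sketch of the merge behavior (the entity names, descriptions, and text-unit ids below are made up; the attribute shapes mirror what `EntityRelationshipExtractor` produces):

```python
import networkx as nx

from langchain_graphrag.indexing.graph_generation import GraphsMerger

# Per-text-unit graphs: `description` and `text_unit_ids` are lists,
# and edges carry a numeric `weight`.
g1 = nx.Graph()
g1.add_node("ALICE", type="PERSON", description=["Runs the lab"], text_unit_ids=["tu-1"])
g1.add_node("ACME", type="ORGANIZATION", description=["A research lab"], text_unit_ids=["tu-1"])
g1.add_edge("ALICE", "ACME", description=["Alice works at Acme"], text_unit_ids=["tu-1"], weight=1)

g2 = nx.Graph()
g2.add_node("ALICE", type="PERSON", description=["Published a paper"], text_unit_ids=["tu-2"])
g2.add_node("ACME", type="ORGANIZATION", description=["Based in Berlin"], text_unit_ids=["tu-2"])
g2.add_edge("ALICE", "ACME", description=["Alice is employed by Acme"], text_unit_ids=["tu-2"], weight=1)

merged = GraphsMerger()([g1, g2])

# Duplicate nodes/edges collapse into one; their descriptions and
# text_unit_ids accumulate, edge weights are summed, and degree,
# rank, and generated ids are added as attributes.
print(merged.nodes["ALICE"]["description"])      # both descriptions, de-duplicated
print(merged.edges["ALICE", "ACME"]["weight"])   # 2
```

The accumulated description lists are exactly what the summarizer (next section) condenses into a single description per node and edge.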
34 | 
35 | **Key Features**:
36 | - Consolidates duplicate entities across text units
37 | - Aggregates relationship descriptions
38 | - Maintains source traceability to original text units
39 | - Calculates entity importance (degree) and relationship rankings
40 | 
41 | **Customization Options**:
42 | - Custom entity matching logic
43 | - Different description aggregation strategies
44 | - Custom ranking algorithms
45 | - Graph sanitization rules
46 | 
47 | ---
48 | 
49 | ## EntityRelationshipDescriptionSummarizer
50 | 
51 | **Purpose**: Every node and edge in the merged graph has multiple descriptions from different text units. This component uses an LLM to summarize them into clear, unified descriptions.
52 | 
53 | **Process**: Takes the aggregated descriptions and creates clean, comprehensive summaries for entities and relationships.
54 | 
55 | **Key Features**:
56 | - LLM-powered description synthesis
57 | - Maintains factual accuracy across sources
58 | - Removes redundancy while preserving key information
59 | - Creates consistent description format
60 | 
61 | **Customization Options**:
62 | - Custom summarization prompts
63 | - Different LLM models for summarization
64 | - Domain-specific description templates
65 | - Quality validation rules
66 | 
67 | ---
68 | 
69 | ## Complete GraphGenerator Usage
70 | 
71 | **Bringing it all together**: See the [Graph Generator notebook](graph_generator.ipynb) for a complete interactive example.
72 | 
73 | **Code Overview**:
74 | ```python
75 | from langchain_graphrag.indexing.graph_generation import (
76 |     GraphGenerator,
77 |     EntityRelationshipExtractor,
78 |     GraphsMerger,
79 |     EntityRelationshipDescriptionSummarizer
80 | )
81 | 
82 | # Create components
83 | extractor = EntityRelationshipExtractor.build_default(llm=your_llm)
84 | merger = GraphsMerger()
85 | summarizer = EntityRelationshipDescriptionSummarizer.build_default(llm=your_llm)
86 | 
87 | # Create graph generator
88 | graph_generator = GraphGenerator(
89 |     er_extractor=extractor,
90 |     graphs_merger=merger,
91 |     er_description_summarizer=summarizer
92 | )
93 | 
94 | # Process text units
95 | merged_graph, summarized_graph = graph_generator.run(text_units_df)
96 | ```
97 | 
98 | ---
99 | 
100 | ## Additional Resources
101 | 
102 | These components provide the foundation for knowledge graph construction. Understanding their individual roles helps with system customization and troubleshooting.
103 | 
104 | **[Documentation Index](../../index.md)**
105 | Return to the complete documentation structure
106 | 
107 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/query/global_search/key_points_aggregator/_system_prompt.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa
2 | 
3 | REDUCE_SYSTEM_PROMPT = """
4 | ---Role---
5 | 
6 | You are a helpful assistant responding to questions about a dataset by synthesizing perspectives from multiple analysts.
7 | 
8 | ---Goal---
9 | 
10 | Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset.
11 | 
12 | Note that the analysts' reports provided below are ranked in the **descending order of importance**.
13 | 
14 | If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up.
15 | 16 | The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. 17 | 18 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 19 | 20 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". 21 | 22 | The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. 23 | 24 | {{#show_references}} 25 | **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 26 | {{/show_references}} 27 | 28 | For example: 29 | 30 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Reports (2, 7, 34, 46, 64, +more)]{{/show_references}}. He is also CEO of company X {{#show_references}}[Data: Reports (1, 3)]{{/show_references}}" 31 | 32 | {{#show_references}} 33 | where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. 34 | {{/show_references}} 35 | 36 | Do not include information where the supporting evidence for it is not provided. 37 | 38 | ---Target response length and format--- 39 | 40 | {{response_type}} 41 | 42 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 43 | 44 | ---Analyst Reports--- 45 | 46 | {{report_data}} 47 | 48 | {{#repeat_instructions}} 49 | 50 | ---Goal--- 51 | 52 | Generate a response of the target length and format that responds to the user's question, summarize all the reports from multiple analysts who focused on different parts of the dataset. 53 | 54 | Note that the analysts' reports provided below are ranked in the **descending order of importance**. 55 | 56 | If you don't know the answer or if the provided reports do not contain sufficient information to provide an answer, just say so. Do not make anything up. 57 | 58 | The final response should remove all irrelevant information from the analysts' reports and merge the cleaned information into a comprehensive answer that provides explanations of all the key points and implications appropriate for the response length and format. 59 | 60 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". 61 | 62 | The response should also preserve all the data references previously included in the analysts' reports, but do not mention the roles of multiple analysts in the analysis process. 63 | 64 | {{#show_references}} 65 | **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 66 | {{/show_references}} 67 | 68 | For example: 69 | 70 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Reports (2, 7, 34, 46, 64, +more)]{{/show_references}}. He is also CEO of company X {{#show_references}}[Data: Reports (1, 3)]{{/show_references}}" 71 | 72 | {{#show_references}} 73 | where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data record. 
74 | {{/show_references}} 75 | 76 | Do not include information where the supporting evidence for it is not provided. 77 | 78 | ---Target response length and format--- 79 | 80 | {{response_type}} 81 | 82 | Add sections and commentary to the response as appropriate for the length and format. Style the response in markdown. 83 | {{/repeat_instructions}} 84 | """ 85 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/extractor.py: -------------------------------------------------------------------------------- 1 | """Entity Relationship Extractor module.""" 2 | 3 | from __future__ import annotations 4 | 5 | import logging 6 | 7 | import networkx as nx 8 | import pandas as pd 9 | from langchain_core.language_models import LanguageModelLike 10 | from langchain_core.runnables.config import RunnableConfig 11 | from tqdm import tqdm 12 | 13 | from langchain_graphrag.types.prompts import IndexingPromptBuilder 14 | 15 | from .prompt_builder import EntityExtractionPromptBuilder 16 | 17 | _LOGGER = logging.getLogger(__name__) 18 | 19 | 20 | class EntityRelationshipExtractor: 21 | def __init__( 22 | self, 23 | prompt_builder: IndexingPromptBuilder, 24 | llm: LanguageModelLike, 25 | *, 26 | chain_config: RunnableConfig | None = None, 27 | ): 28 | """Extracts entities and relationships from text units using a language model. 29 | 30 | Args: 31 | prompt_builder (PromptBuilder): The prompt builder object used to construct the prompt for the language model. 32 | llm (LanguageModelLike): The language model used for entity and relationship extraction. 33 | chain_config (RunnableConfig, optional): The configuration object for the extraction chain. Defaults to None. 34 | 35 | """ 36 | prompt, output_parser = prompt_builder.build() 37 | self._extraction_chain = prompt | llm | output_parser 38 | self._prompt_builder = prompt_builder 39 | self._chain_config = chain_config 40 | 41 | @staticmethod 42 | def build_default( 43 | llm: LanguageModelLike, 44 | *, 45 | chain_config: RunnableConfig | None = None, 46 | ) -> EntityRelationshipExtractor: 47 | """Builds and returns an instance of EntityRelationshipExtractor with default parameters. 48 | 49 | Parameters: 50 | llm (LanguageModelLike): The language model used for entity relationship extraction. 51 | chain_config (RunnableConfig, optional): The configuration object for the extraction chain. Defaults to None. 52 | 53 | Returns: 54 | EntityRelationshipExtractor: An instance of EntityRelationshipExtractor with default parameters. 55 | """ 56 | return EntityRelationshipExtractor( 57 | prompt_builder=EntityExtractionPromptBuilder(), 58 | llm=llm, 59 | chain_config=chain_config, 60 | ) 61 | 62 | def invoke(self, text_units: pd.DataFrame) -> list[nx.Graph]: 63 | """Invoke the entity relationship extraction process on the text units. 64 | 65 | The pandas DataFrame required by this method is generated by `TextUnitExtractor` 66 | and must have following three columns: 67 | 68 | - document_id 69 | - id 70 | - text_unit. 71 | 72 | Parameters: 73 | text_units (pd.DataFrame): A pandas dataframe containing the text units. 74 | 75 | Returns: 76 | A list of networkx Graph objects representing the extracted entities and relationships. 
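
        Example:
            A minimal, illustrative call; the model choice and the
            dataframe contents below are placeholders, not part of the API:

                import pandas as pd
                from langchain_openai import ChatOpenAI

                extractor = EntityRelationshipExtractor.build_default(
                    llm=ChatOpenAI(model="gpt-4o"),
                )
                df_text_units = pd.DataFrame(
                    [
                        {
                            "document_id": "doc-1",
                            "id": "tu-1",
                            "text_unit": "Alice works at Acme.",
                        }
                    ]
                )
                graphs = extractor.invoke(df_text_units)  # one graph per text unit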
77 |         """
78 | 
79 |         def _run_chain(series: pd.Series) -> nx.Graph:
80 |             _, text_id, text_unit = (
81 |                 series["document_id"],
82 |                 series["id"],
83 |                 series["text_unit"],
84 |             )
85 | 
86 |             chain_input = self._prompt_builder.prepare_chain_input(text_unit=text_unit)
87 | 
88 |             chunk_graph = self._extraction_chain.invoke(
89 |                 input=chain_input,
90 |                 config=self._chain_config,
91 |             )
92 | 
93 |             # add the chunk_id to the nodes
94 |             for node_names in chunk_graph.nodes():
95 |                 chunk_graph.nodes[node_names]["text_unit_ids"] = [text_id]
96 | 
97 |             # add the chunk_id to the edges as well
98 |             for edge_names in chunk_graph.edges():
99 |                 chunk_graph.edges[edge_names]["text_unit_ids"] = [text_id]
100 | 
101 |             if _LOGGER.isEnabledFor(logging.DEBUG):
102 |                 _LOGGER.debug(f"Graph for: {text_id}")
103 |                 _LOGGER.debug(chunk_graph)
104 | 
105 |             return chunk_graph
106 | 
107 |         tqdm.pandas(desc="Extracting entities and relationships ...")
108 |         chunk_graphs: list[nx.Graph] = text_units.progress_apply(
109 |             _run_chain, axis=1
110 |         ).tolist()
111 | 
112 |         return chunk_graphs
113 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/query/global_search/key_points_generator/context_builder.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from langchain_core.documents import Document
4 | 
5 | from langchain_graphrag.indexing.artifacts import IndexerArtifacts
6 | from langchain_graphrag.query.global_search.community_report import CommunityReport
7 | from langchain_graphrag.query.global_search.community_weight_calculator import (
8 |     CommunityWeightCalculator,
9 | )
10 | from langchain_graphrag.types.graphs.community import CommunityId, CommunityLevel
11 | from langchain_graphrag.utils.token_counter import TokenCounter
12 | 
13 | _REPORT_TEMPLATE = """
14 | --- Report {report_id} ---
15 | 
16 | Title: {title}
17 | Weight: {weight}
18 | Rank: {rank}
19 | Report:
20 | 
21 | {content}
22 | 
23 | """
24 | 
25 | _LOGGER = logging.getLogger(__name__)
26 | 
27 | 
28 | class CommunityReportContextBuilder:
29 |     def __init__(
30 |         self,
31 |         community_level: CommunityLevel,
32 |         weight_calculator: CommunityWeightCalculator,
33 |         artifacts: IndexerArtifacts,
34 |         token_counter: TokenCounter,
35 |         max_tokens: int = 8000,
36 |     ):
37 |         self._community_level = community_level
38 |         self._weight_calculator = weight_calculator
39 |         self._artifacts = artifacts
40 |         self._token_counter = token_counter
41 |         self._max_tokens = max_tokens
42 | 
43 |     def _filter_communities(self) -> list[CommunityReport]:
44 |         df_entities = self._artifacts.entities
45 |         df_reports = self._artifacts.communities_reports
46 | 
47 |         reports_weight: dict[CommunityId, float] = self._weight_calculator(
48 |             df_entities,
49 |             df_reports,
50 |         )
51 | 
52 |         df_reports_filtered = df_reports[df_reports["level"] <= self._community_level]
53 | 
54 |         reports = []
55 |         for _, row in df_reports_filtered.iterrows():
56 |             reports.append(
57 |                 CommunityReport(
58 |                     id=row["community_id"],
59 |                     weight=reports_weight[row["community_id"]],
60 |                     title=row["title"],
61 |                     summary=row["summary"],
62 |                     rank=row["rating"],
63 |                     content=row["content"],
64 |                 )
65 |             )
66 | 
67 |         return reports
68 | 
69 |     def __call__(self) -> list[Document]:
70 |         reports = self._filter_communities()
71 | 
72 |         documents: list[Document] = []
73 |         report_str_accumulated: list[str] = []
74 |         token_count = 0
75 |         for report in reports:
76 |             # we would try to combine multiple
77 |             # reports into a single document
78 |             # as long as we do not exceed the token limit
79 | 
80 |             report_str = _REPORT_TEMPLATE.format(
81 |                 report_id=report.id,
82 |                 title=report.title,
83 |                 weight=report.weight,
84 |                 rank=report.rank,
85 |                 content=report.content,
86 |             )
87 | 
88 |             report_str_token_count = self._token_counter.count_tokens(report_str)
89 | 
90 |             if token_count + report_str_token_count > self._max_tokens:
91 |                 _LOGGER.warning("Reached max tokens for a community report call ...")
92 |                 # we cut a new document here
93 |                 documents.append(
94 |                     Document(
95 |                         page_content="\n".join(report_str_accumulated),
96 |                         metadata={"token_count": token_count},
97 |                     )
98 |                 )
99 |                 # seed the next document with the report that triggered the cut
100 |                 token_count = report_str_token_count
101 |                 report_str_accumulated = [report_str]
102 |             else:
103 |                 token_count += report_str_token_count
104 |                 report_str_accumulated.append(report_str)
105 | 
106 |         if report_str_accumulated:
107 |             documents.append(
108 |                 Document(
109 |                     page_content="\n".join(report_str_accumulated),
110 |                     metadata={"token_count": token_count},
111 |                 )
112 |             )
113 | 
114 |         if _LOGGER.isEnabledFor(logging.DEBUG):
115 |             import tableprint
116 | 
117 |             rows = []
118 |             tableprint.banner("KP Generation Context Token Usage")
119 |             for index, doc in enumerate(documents):
120 |                 rows.append([f"Report {index}", doc.metadata["token_count"]])
121 | 
122 |             tableprint.table(rows, ["Reports", "Token Count"])
123 | 
124 |         return documents
125 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/query/global_search/key_points_generator/_system_prompt.py:
--------------------------------------------------------------------------------
1 | # ruff: noqa: E501
2 | 
3 | MAP_SYSTEM_PROMPT = """
4 | ---Role---
5 | 
6 | You are a helpful assistant responding to questions about data in the tables provided.
7 | 
8 | ---Goal---
9 | 
10 | Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables.
11 | 
12 | You should use the data provided in the data tables below as the primary context for generating the response.
13 | If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up.
14 | 
15 | Each key point in the response should have the following element:
16 | - Description: A comprehensive description of the point.
17 | - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0.
18 | 
19 | The response should be JSON formatted as follows:
20 | {
21 |     "points": [
22 |         {"description": "Description of point 1 {{#show_references}}[Data: Reports (report ids)]{{/show_references}}", "score": score_value},
23 |         {"description": "Description of point 2 {{#show_references}}[Data: Reports (report ids)]{{/show_references}}", "score": score_value}
24 |     ]
25 | }
26 | 
27 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
28 | 
29 | {{#show_references}}
30 | Points supported by data should list the relevant reports as references as follows:
31 | "This is an example sentence supported by data references [Data: Reports (report ids)]"
32 | 
33 | 
34 | **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more.
35 | {{/show_references}} 36 | 37 | For example: 38 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Reports (2, 7, 64, 46, 34, +more)]{{/show_references}}. He is also CEO of company X {{#show_references}}[Data: Reports (1, 3)]{{/show_references}}" 39 | 40 | {{#show_references}} 41 | where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. 42 | {{/show_references}} 43 | 44 | Do not include information where the supporting evidence for it is not provided. 45 | 46 | ---Data tables--- 47 | 48 | {{context_data}} 49 | 50 | {{#repeat_instructions}} 51 | 52 | ---Goal--- 53 | 54 | Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. 55 | 56 | You should use the data provided in the data tables below as the primary context for generating the response. 57 | If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. 58 | 59 | Each key point in the response should have the following element: 60 | - Description: A comprehensive description of the point. 61 | - Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. 62 | 63 | The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". 64 | 65 | {{#show_references}} 66 | Points supported by data should list the relevant reports as references as follows: 67 | "This is an example sentence supported by data references [Data: Reports (report ids)]" 68 | 69 | **Do not list more than 5 record ids in a single reference**. Instead, list the top 5 most relevant record ids and add "+more" to indicate that there are more. 70 | {{/show_references}} 71 | 72 | For example: 73 | "Person X is the owner of Company Y and subject to many allegations of wrongdoing {{#show_references}}[Data: Reports (2, 7, 64, 46, 34, +more)]{{/show_references}}. He is also CEO of company X {{#show_references}}[Data: Reports (1, 3)]{{/show_references}}" 74 | 75 | {{#show_references}} 76 | where 1, 2, 3, 7, 34, 46, and 64 represent the id (not the index) of the relevant data report in the provided tables. 77 | {{/show_references}} 78 | 79 | Do not include information where the supporting evidence for it is not provided. 
80 | 
81 | The response should be JSON formatted as follows:
82 | {
83 |     "points": [
84 |         {"description": "Description of point 1 {{#show_references}}[Data: Reports (report ids)]{{/show_references}}", "score": score_value},
85 |         {"description": "Description of point 2 {{#show_references}}[Data: Reports (report ids)]{{/show_references}}", "score": score_value}
86 |     ]
87 | }
88 | {{/repeat_instructions}}
89 | """
90 | 
--------------------------------------------------------------------------------
/src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/prompt_builder.py:
--------------------------------------------------------------------------------
1 | """Default PromptBuilder for Entity Relationship Extraction."""
2 | 
3 | from pathlib import Path
4 | from typing import Any
5 | 
6 | from langchain_core.output_parsers.base import BaseOutputParser
7 | from langchain_core.prompts import BasePromptTemplate, PromptTemplate
8 | from typing_extensions import Unpack
9 | 
10 | from langchain_graphrag.types.prompts import IndexingPromptBuilder
11 | 
12 | from ._default_prompts import DEFAULT_ER_EXTRACTION_PROMPT
13 | from ._output_parser import EntityExtractionOutputParser
14 | 
15 | _DEFAULT_TUPLE_DELIMITER = "<|>"
16 | _DEFAULT_RECORD_DELIMITER = "##"
17 | _DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
18 | _DEFAULT_ENTITY_TYPES = ["organization", "person", "geo", "event"]
19 | 
20 | 
21 | class EntityExtractionPromptBuilder(IndexingPromptBuilder):
22 |     """PromptBuilder for Entity Relationship extraction.
23 | 
24 |     This implementation assumes that the prompt is a template string with the following placeholders:
25 | 
26 |     - entity_types: A comma-separated list of entity types. Default is "organization,person,geo,event".
27 | 
28 |     - tuple_delimiter: The delimiter for tuples. Default is "<|>".
29 | 
30 |     - record_delimiter: The delimiter for records. Default is "##".
31 | 
32 |     - completion_delimiter: The delimiter for completions. Default is "<|COMPLETE|>".
33 | 
34 | 
35 |     There is a default template embedded in the package which is the same
36 |     as that of the official implementation.
37 | 
38 |     You can supply your own prompt as long as you keep the placeholders and
39 |     have examples specified in the same output format as the default prompt.
40 | 
41 |     If you want a prompt with different placeholders or output formats for
42 |     entity extraction, you can create a custom implementation of the protocol
43 |     `IndexingPromptBuilder` and use it in the `EntityRelationshipExtractor` class.
44 | 
45 |     """
46 | 
47 |     def __init__(
48 |         self,
49 |         *,
50 |         prompt: str | None = None,
51 |         prompt_path: Path | None = None,
52 |         entity_types: list[str] = _DEFAULT_ENTITY_TYPES,
53 |         tuple_delimiter: str = _DEFAULT_TUPLE_DELIMITER,
54 |         record_delimiter: str = _DEFAULT_RECORD_DELIMITER,
55 |         completion_delimiter: str = _DEFAULT_COMPLETION_DELIMITER,
56 |     ):
57 |         """Initializes the PromptBuilder object.
58 | 
59 |         If neither prompt nor prompt_path is provided, the default prompt
60 |         provided within the package (same as the official implementation) will be used.
61 | 
62 |         Args:
63 |             prompt (str | None, optional): The prompt string.
64 |             prompt_path (Path | None, optional): The path to the prompt file.
65 |             entity_types (list[str], optional): The list of entity types.
66 |             tuple_delimiter (str, optional): The delimiter for tuples.
67 |             record_delimiter (str, optional): The delimiter for records.
68 |             completion_delimiter (str, optional): The delimiter for completions.
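
        Example:
            A sketch of swapping in custom entity types while keeping the
            default prompt (the types and the `my_llm` instance are
            illustrative placeholders):

                prompt_builder = EntityExtractionPromptBuilder(
                    entity_types=["drug", "disease", "gene"],
                )
                extractor = EntityRelationshipExtractor(
                    prompt_builder=prompt_builder,
                    llm=my_llm,
                )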
69 | """ # noqa: D202 70 | 71 | self._prompt: str | None 72 | if prompt is None and prompt_path is None: 73 | self._prompt = DEFAULT_ER_EXTRACTION_PROMPT 74 | else: 75 | self._prompt = prompt 76 | 77 | self._prompt_path = prompt_path 78 | 79 | self._entity_types = entity_types 80 | self._tuple_delimiter = tuple_delimiter 81 | self._record_delimiter = record_delimiter 82 | self._completion_delimiter = completion_delimiter 83 | 84 | def build(self) -> tuple[BasePromptTemplate, BaseOutputParser]: 85 | """Build the template and output parser. 86 | 87 | Note: 88 | You would not directly use this method. It is used by the 89 | `EntityRelationshipExtractor` class. 90 | 91 | Returns: 92 | A tuple containing the built `BasePromptTemplate` object 93 | and the `EntityExtractionOutputParser` object. 94 | """ 95 | if self._prompt: 96 | prompt_template = PromptTemplate.from_template(self._prompt) 97 | else: 98 | assert self._prompt_path is not None 99 | prompt_template = PromptTemplate.from_file(self._prompt_path) 100 | 101 | return ( 102 | prompt_template.partial( 103 | completion_delimiter=self._completion_delimiter, 104 | tuple_delimiter=self._tuple_delimiter, 105 | record_delimiter=self._record_delimiter, 106 | entity_types=",".join(self._entity_types), 107 | ), 108 | EntityExtractionOutputParser( 109 | tuple_delimiter=self._tuple_delimiter, 110 | record_delimiter=self._record_delimiter, 111 | ), 112 | ) 113 | 114 | def prepare_chain_input(self, **kwargs: Unpack[dict[str, Any]]) -> dict[str, str]: # noqa: D417 115 | """Prepares the input for the extraction chain. 116 | 117 | Note: 118 | You would not directly use this method. 119 | It is used by the `EntityRelationshipExtractor` class. 120 | 121 | Args: 122 | text_unit: The text unit from which entities and relationships are extracted. 123 | """ 124 | text_unit: str = kwargs.get("text_unit", None) 125 | if text_unit is None: 126 | raise ValueError("text_unit is required") 127 | 128 | return dict(input_text=text_unit) 129 | -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/_output_parser.py: -------------------------------------------------------------------------------- 1 | """Default OutputParser for Entity Relationship Extraction.""" 2 | 3 | import html 4 | import numbers 5 | import re 6 | from collections.abc import Mapping 7 | from typing import Any 8 | 9 | import networkx as nx 10 | from langchain_core.output_parsers import BaseOutputParser 11 | 12 | _ENTITY_ATTRIBUTES_LENGTH = 4 13 | _RELATIONSHIP_ATTRIBUTES_LENGTH = 5 14 | 15 | 16 | def _clean_str(input_str: Any) -> str: 17 | """Remove HTML escapes, control characters, and other unwanted characters.""" 18 | # If we get non-string input, just give it back 19 | if not isinstance(input_str, str): 20 | return input_str 21 | 22 | result = html.unescape(input_str.strip()) 23 | # https://stackoverflow.com/questions/4324790/removing-control-characters-from-a-string-in-python 24 | return re.sub(r"[\x00-\x1f\x7f-\x9f]", "", result) 25 | 26 | 27 | def _unpack_descriptions(data: Mapping) -> list[str]: 28 | return data.get("description", []) 29 | 30 | 31 | class EntityExtractionOutputParser(BaseOutputParser[nx.Graph]): 32 | """OutputParser for extracting entities and relationships. 33 | 34 | This parses the output from the LLM based on the conventions & formats 35 | used in the default prompt. 36 | 37 | Note: 38 | You would not use this class directly. 
39 | 
40 | 
41 |     Attributes:
42 |         tuple_delimiter (str): Delimiter used to separate attributes within a record.
43 |         record_delimiter (str): Delimiter used to separate records.
44 |     """
45 | 
46 |     tuple_delimiter: str
47 |     record_delimiter: str
48 | 
49 |     def _process_entity(self, record_attributes: list[str], graph: nx.Graph) -> None:
50 |         if (record_attributes[0] != '"entity"') or (
51 |             len(record_attributes) < _ENTITY_ATTRIBUTES_LENGTH
52 |         ):
53 |             return
54 | 
55 |         # add this record as a node in the graph
56 |         entity_name = _clean_str(record_attributes[1].upper())
57 |         entity_type = _clean_str(record_attributes[2].upper())
58 |         entity_description = _clean_str(record_attributes[3])
59 | 
60 |         if entity_name in graph.nodes():
61 |             node = graph.nodes[entity_name]
62 |             node["description"] = list(
63 |                 {
64 |                     *_unpack_descriptions(node),
65 |                     entity_description,
66 |                 }
67 |             )
68 | 
69 |             node["type"] = (
70 |                 entity_type if entity_type != "" else node["type"]
71 |             )
72 |         else:
73 |             graph.add_node(
74 |                 entity_name,
75 |                 type=entity_type,
76 |                 description=[entity_description],
77 |             )
78 | 
79 |     def _process_relationship(
80 |         self,
81 |         record_attributes: list[str],
82 |         graph: nx.Graph,
83 |     ) -> None:
84 |         if (
85 |             record_attributes[0] != '"relationship"'
86 |             or len(record_attributes) < _RELATIONSHIP_ATTRIBUTES_LENGTH
87 |         ):
88 |             return
89 | 
90 |         # add this record as an edge
91 |         source = _clean_str(record_attributes[1].upper())
92 |         target = _clean_str(record_attributes[2].upper())
93 |         edge_description = _clean_str(record_attributes[3])
94 | 
95 |         try:
96 |             # the attributes come from a string split, so parse the weight
97 |             weight = float(record_attributes[-1])
98 |         except ValueError:
99 |             weight = 1.0
100 |         if source not in graph.nodes():
101 |             graph.add_node(
102 |                 source,
103 |                 type="",
104 |                 description=[""],
105 |             )
106 |         if target not in graph.nodes():
107 |             graph.add_node(
108 |                 target,
109 |                 type="",
110 |                 description=[""],
111 |             )
112 |         if graph.has_edge(source, target):
113 |             edge_data = graph.get_edge_data(source, target)
114 |             if edge_data is not None:
115 |                 weight += edge_data["weight"]
116 |                 edge_descriptions = list(
117 |                     {
118 |                         *_unpack_descriptions(edge_data),
119 |                         edge_description,
120 |                     }
121 |                 )
122 |         else:
123 |             edge_descriptions = [edge_description]
124 | 
125 |         graph.add_edge(source, target, weight=weight, description=edge_descriptions)
126 | 
127 |     def _process_record(self, graph: nx.Graph, record: str) -> None:
128 |         record = re.sub(r"^\(|\)$", "", record.strip())
129 |         record_attributes = record.split(self.tuple_delimiter)
130 | 
131 |         self._process_entity(record_attributes, graph)
132 |         self._process_relationship(record_attributes, graph)
133 | 
134 |     def parse(self, text: str) -> nx.Graph:
135 |         """Parses the given text and returns a networkx Graph object.
136 | 
137 |         Parameters:
138 |             text (str): The text to be parsed.
139 | 
140 |         Returns:
141 |             The parsed graph object.
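
        Example:
            A sketch of the record format this parser expects, using the
            default delimiters (the entities and text below are made up):

                parser = EntityExtractionOutputParser(
                    tuple_delimiter="<|>",
                    record_delimiter="##",
                )
                text = (
                    '("entity"<|>ACME<|>ORGANIZATION<|>A research lab)##'
                    '("entity"<|>ALICE<|>PERSON<|>Runs the lab)##'
                    '("relationship"<|>ALICE<|>ACME<|>Alice works at Acme<|>8)'
                )
                graph = parser.parse(text)
                # -> nodes ALICE and ACME plus one weighted edge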
142 | """ 143 | graph: nx.Graph = nx.Graph() 144 | records = [r.strip() for r in text.split(self.record_delimiter)] 145 | for record in records: 146 | self._process_record(graph, record) 147 | return graph 148 | 149 | @property 150 | def _type(self) -> str: 151 | return "entity_extraction_output_parser" 152 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "langchain-graphrag" 3 | version = "0.0.9" 4 | description = "Implementation of GraphRAG (https://arxiv.org/pdf/2404.16130)" 5 | authors = [{ name = "Kapil Sachdeva", email = "notan@email.com" }] 6 | dependencies = [ 7 | "pandas>=2.2.2", 8 | "networkx>=3.3", 9 | "langchain-core>=0.3.0", 10 | "langchain-text-splitters>=0.3.0", 11 | "graspologic>=3.4.1", 12 | "tableprint>=0.9.1" 13 | ] 14 | readme = "README.md" 15 | requires-python = ">= 3.10" 16 | 17 | [build-system] 18 | requires = ["hatchling"] 19 | build-backend = "hatchling.build" 20 | 21 | [project.urls] 22 | repository = "https://github.com/dineshkrishna9999/langchain-graphrag.git" 23 | 24 | [tool.hatch.metadata] 25 | allow-direct-references = true 26 | 27 | [tool.hatch.build.targets.wheel] 28 | packages = ["src/langchain_graphrag"] 29 | 30 | 31 | # ===== TEMPLATE VARIABLES (environment-based) ===== 32 | [tool.poe.env] 33 | APP_CMD = "python examples/simple-app/app/main.py" 34 | DIRS = "--output-dir tmp --cache-dir tmp/cache" 35 | INPUT = "--input-file examples/input-data/book.txt" 36 | ARTIFACTS = "--artifacts-dir tmp/artifacts_gpt-4o" 37 | AZURE_LLM = "--llm-type azure_openai --llm-model gpt-4o" 38 | AZURE_EMBED = "--embedding-type azure_openai --embedding-model text-embedding-3-large" 39 | OPENAI_LLM = "--llm-type openai --llm-model gpt-4o" 40 | OPENAI_EMBED = "--embedding-type openai --embedding-model text-embedding-3-small" 41 | OLLAMA_LLM = "--llm-type ollama --llm-model llama2" 42 | OLLAMA_EMBED = "--embedding-type ollama --embedding-model llama2" 43 | 44 | [tool.poe.tasks] 45 | 46 | # ===== INTERNAL TASKS (using template variables) ===== 47 | 48 | # Basic commands 49 | _app-help = "${APP_CMD} --help" 50 | _indexer-help = "${APP_CMD} indexer --help" 51 | _query-help = "${APP_CMD} query --help" 52 | _indexer-base = "${APP_CMD} indexer index" 53 | _local-search-base = "${APP_CMD} query local-search ${DIRS}" 54 | _global-search-base = "${APP_CMD} query global-search ${DIRS}" 55 | _report = "${APP_CMD} indexer report ${ARTIFACTS}" 56 | 57 | # Full configured indexer commands (template composition) 58 | _index-azure = "${APP_CMD} indexer index ${INPUT} ${DIRS} ${AZURE_LLM} ${AZURE_EMBED}" 59 | _index-openai = "${APP_CMD} indexer index ${INPUT} ${DIRS} ${OPENAI_LLM} ${OPENAI_EMBED}" 60 | _index-ollama = "${APP_CMD} indexer index ${INPUT} ${DIRS} ${OLLAMA_LLM} ${OLLAMA_EMBED}" 61 | 62 | # Full configured search commands (template composition) 63 | _local-azure = "${APP_CMD} query local-search ${DIRS} ${AZURE_LLM} ${AZURE_EMBED}" 64 | _local-openai = "${APP_CMD} query local-search ${DIRS} ${OPENAI_LLM} ${OPENAI_EMBED}" 65 | _local-ollama = "${APP_CMD} query local-search ${DIRS} ${OLLAMA_LLM} ${OLLAMA_EMBED}" 66 | _global-azure = "${APP_CMD} query global-search ${DIRS} ${AZURE_LLM}" 67 | _global-openai = "${APP_CMD} query global-search ${DIRS} ${OPENAI_LLM}" 68 | _global-ollama = "${APP_CMD} query global-search ${DIRS} ${OLLAMA_LLM}" 69 | 70 | 71 | 72 | # Base app command 73 | simple-app = "${APP_CMD}" 74 | 75 | # Help commands 76 
| simple-app-help = { ref = "_app-help" } 77 | simple-app-indexer-help = { ref = "_indexer-help" } 78 | simple-app-query-help = { ref = "_query-help" } 79 | 80 | # Basic indexer commands 81 | simple-app-indexer = { ref = "_indexer-base" } 82 | 83 | # Indexer with different providers 84 | simple-app-indexer-azure = { ref = "_index-azure" } 85 | simple-app-indexer-openai = { ref = "_index-openai" } 86 | simple-app-indexer-ollama = { ref = "_index-ollama" } 87 | 88 | # Report generation 89 | simple-app-report = { ref = "_report" } 90 | 91 | # Search commands - basic 92 | simple-app-local-search = { ref = "_local-search-base" } 93 | simple-app-global-search = { ref = "_global-search-base" } 94 | 95 | # Search commands with provider configurations 96 | simple-app-local-search-azure = { ref = "_local-azure" } 97 | simple-app-local-search-openai = { ref = "_local-openai" } 98 | simple-app-local-search-ollama = { ref = "_local-ollama" } 99 | simple-app-global-search-azure = { ref = "_global-azure" } 100 | simple-app-global-search-openai = { ref = "_global-openai" } 101 | simple-app-global-search-ollama = { ref = "_global-ollama" } 102 | 103 | 104 | # Development tasks 105 | test = "pytest" 106 | test-verbose = "pytest -v" 107 | lint = "ruff check" 108 | lint-fix = "ruff check --fix" 109 | format = "ruff format" 110 | format-check = "ruff format --check" 111 | typecheck = "mypy src" 112 | clean = "python -c \"import shutil, os; [shutil.rmtree(d, ignore_errors=True) for d in ['tmp', '__pycache__', '.pytest_cache', '.mypy_cache', '.ruff_cache'] if os.path.exists(d)]\"" 113 | 114 | # Documentation tasks 115 | docs-serve = "mkdocs serve" 116 | docs-build = "mkdocs build" 117 | 118 | # Combined tasks 119 | check = ["lint", "format-check", "typecheck"] 120 | fix = ["lint-fix", "format"] 121 | check-all = ["lint", "format-check", "typecheck"] 122 | 123 | [tool.uv] 124 | dev-dependencies = [ 125 | "pytest>=8.3.2", 126 | "ipykernel>=6.29.5", 127 | "mkdocs>=1.6.0", 128 | "mkdocstrings[python]>=0.25.2", 129 | "markdown-include>=0.8.1", 130 | "pre-commit>=3.8.0", 131 | "mkdocs-gen-files>=0.5.0", 132 | "mkdocs-literate-nav>=0.6.1", 133 | "mkdocs-section-index>=0.3.9", 134 | "mkdocs-material>=9.5.31", 135 | "mkdocs-jupyter>=0.24.8", 136 | "fastparquet>=2024.11.0", 137 | "langchain-chroma>=0.2.2", 138 | "langchain-community>=0.3.21", 139 | "langchain-ollama>=0.3.3", 140 | "langchain-openai>=0.3.24", 141 | "mypy>=1.16.1", 142 | "poethepoet>=0.35.0", 143 | "pyarrow>=20.0.0", 144 | "python-dotenv>=1.1.0", 145 | "ruff>=0.12.0", 146 | "typer>=0.16.0" 147 | ] 148 | 149 | [tool.isort] 150 | profile = "black" 151 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/relationships.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import logging 3 | from typing import NamedTuple 4 | 5 | import pandas as pd 6 | 7 | _LOGGER = logging.getLogger(__name__) 8 | 9 | 10 | class RelationshipsSelectionResult(NamedTuple): 11 | in_network_relationships: pd.DataFrame 12 | out_network_relationships: pd.DataFrame 13 | 14 | 15 | def _find_in_network_relationships( 16 | df_entities: pd.DataFrame, 17 | df_relationships: pd.DataFrame, 18 | source_column_name: str = "source_id", 19 | target_column_name: str = "target_id", 20 | entity_column_name: str = "id", 21 | ) -> pd.DataFrame: 22 | entities_ids = df_entities[entity_column_name].tolist() 23 | entities_pairs = 
159 | -------------------------------------------------------------------------------- /src/langchain_graphrag/query/local_search/context_selectors/relationships.py: --------------------------------------------------------------------------------
1 | import itertools
2 | import logging
3 | from typing import NamedTuple
4 | 
5 | import pandas as pd
6 | 
7 | _LOGGER = logging.getLogger(__name__)
8 | 
9 | 
10 | class RelationshipsSelectionResult(NamedTuple):
11 | in_network_relationships: pd.DataFrame
12 | out_network_relationships: pd.DataFrame
13 | 
14 | 
15 | def _find_in_network_relationships(
16 | df_entities: pd.DataFrame,
17 | df_relationships: pd.DataFrame,
18 | source_column_name: str = "source_id",
19 | target_column_name: str = "target_id",
20 | entity_column_name: str = "id",
21 | ) -> pd.DataFrame:
22 | entities_ids = df_entities[entity_column_name].tolist()
23 | entities_pairs = set(itertools.combinations(entities_ids, 2))  # set for O(1) membership checks
24 | 
25 | def filter_in_network_relationships(source: str, target: str) -> bool:
26 | check_1 = (source, target) in entities_pairs
27 | check_2 = (target, source) in entities_pairs
28 | return check_1 or check_2
29 | 
30 | df_relationships["is_in_network"] = df_relationships.apply(
31 | lambda x: filter_in_network_relationships(
32 | x[source_column_name], x[target_column_name]
33 | ),
34 | axis=1,
35 | )
36 | 
37 | df_relationships = df_relationships[df_relationships["is_in_network"]]
38 | 
39 | df_relationships = df_relationships.drop(columns=["is_in_network"])
40 | 
41 | # sort the relationships by rank
42 | df_relationships = df_relationships.sort_values(
43 | by="rank", ascending=False
44 | ).reset_index(drop=True)
45 | 
46 | if _LOGGER.isEnabledFor(logging.DEBUG):
47 | import tableprint
48 | 
49 | how_many = len(df_relationships)
50 | 
51 | tableprint.banner(f"Selected {how_many} In-Network Relationships")
52 | tableprint.dataframe(df_relationships[["source", "target", "rank"]])
53 | 
54 | return df_relationships
55 | 
56 | 
57 | def _find_out_network_relationships(
58 | df_entities: pd.DataFrame,
59 | df_relationships: pd.DataFrame,
60 | top_k: int = 10,
61 | source_column_name: str = "source_id",
62 | target_column_name: str = "target_id",
63 | entity_column_name: str = "id",
64 | ) -> pd.DataFrame:
65 | entities_ids = df_entities[entity_column_name].tolist()
66 | 
67 | # top_k is budget for out-network relationships
68 | relationship_budget = top_k * len(entities_ids)
69 | 
70 | def filter_out_network_relationships(source: str, target: str) -> bool:
71 | if source in entities_ids and target not in entities_ids:
72 | return True
73 | if target in entities_ids and source not in entities_ids: # noqa: SIM103
74 | return True
75 | 
76 | return False
77 | 
78 | df_relationships["is_out_network"] = df_relationships.apply(
79 | lambda x: filter_out_network_relationships(
80 | x[source_column_name], x[target_column_name]
81 | ),
82 | axis=1,
83 | )
84 | 
85 | df_relationships = df_relationships[df_relationships["is_out_network"]]
86 | 
87 | df_relationships = df_relationships.drop(columns=["is_out_network"])
88 | 
89 | # now we need to prioritize based on which external
90 | # entities have the most connections with the selected entities
91 | # we will do this by counting the number of relationships
92 | # each external entity has with the selected entities
93 | source_external_entities = df_relationships[
94 | ~df_relationships[source_column_name].isin(entities_ids)
95 | ][source_column_name]
96 | 
97 | target_external_entities = df_relationships[
98 | ~df_relationships[target_column_name].isin(entities_ids)
99 | ][target_column_name]
100 | 
101 | df_relationships = (
102 | df_relationships.merge(
103 | source_external_entities.value_counts(),
104 | how="left",
105 | left_on=source_column_name,
106 | right_on=source_column_name,
107 | )
108 | .fillna(0)
109 | .rename(columns={"count": "source_count"})
110 | )
111 | 
112 | df_relationships = (
113 | df_relationships.merge(
114 | target_external_entities.value_counts(),
115 | how="left",
116 | left_on=target_column_name,
117 | right_on=target_column_name,
118 | )
119 | .fillna(0)
120 | .rename(columns={"count": "target_count"})
121 | )
122 | 
123 | df_relationships["links"] = (
124 | df_relationships["source_count"] + df_relationships["target_count"]
125 | )
126 | 
127 | df_relationships = df_relationships.sort_values(
128 | by=["links", "rank"],
129 | 
ascending=[False, False], 130 | ).reset_index(drop=True) 131 | 132 | # time to use the budget 133 | df_relationships = df_relationships.head(relationship_budget) 134 | 135 | if _LOGGER.isEnabledFor(logging.DEBUG): 136 | import tableprint 137 | 138 | how_many = len(df_relationships) 139 | 140 | tableprint.banner(f"Selected {how_many} Out-Network Relationships") 141 | tableprint.dataframe(df_relationships[["source", "target", "rank", "links"]]) 142 | 143 | return df_relationships 144 | 145 | 146 | class RelationshipsSelector: 147 | def __init__(self, top_k_out_network: int = 5): 148 | self._top_k_out_network = top_k_out_network 149 | 150 | def run( 151 | self, 152 | df_entities: pd.DataFrame, 153 | df_relationships: pd.DataFrame, 154 | ) -> RelationshipsSelectionResult: 155 | in_network_relationships = _find_in_network_relationships( 156 | df_entities, 157 | df_relationships.copy(deep=True), 158 | ) 159 | 160 | out_network_relationships = _find_out_network_relationships( 161 | df_entities, 162 | df_relationships.copy(deep=True), 163 | top_k=self._top_k_out_network, 164 | ) 165 | 166 | return RelationshipsSelectionResult( 167 | in_network_relationships, 168 | out_network_relationships, 169 | ) 170 | -------------------------------------------------------------------------------- /examples/simple-app/app/indexer.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa: B008 2 | # ruff: noqa: E402 3 | # ruff: noqa: ERA001 4 | 5 | import logging 6 | import os 7 | from pathlib import Path 8 | 9 | import tableprint 10 | import typer 11 | from dotenv import load_dotenv 12 | from typer import Typer 13 | 14 | _LOGGER = logging.getLogger("main:indexer") 15 | 16 | # going to do load_dotenv() here 17 | # as OLLAMA_HOST needs to be in the environment 18 | # before the imports below 19 | load_dotenv() 20 | 21 | 22 | from common import ( 23 | EmbeddingModelType, 24 | LLMType, 25 | get_artifacts_dir_name, 26 | load_artifacts, 27 | make_embedding_instance, 28 | make_llm_instance, 29 | save_artifacts, 30 | trace_via_langsmith, 31 | ) 32 | from langchain_chroma.vectorstores import Chroma as ChromaVectorStore 33 | from langchain_community.document_loaders import TextLoader 34 | from langchain_text_splitters import TokenTextSplitter 35 | 36 | from langchain_graphrag.indexing import SimpleIndexer, TextUnitExtractor 37 | from langchain_graphrag.indexing.artifacts import IndexerArtifacts 38 | from langchain_graphrag.indexing.artifacts_generation import ( 39 | CommunitiesReportsArtifactsGenerator, 40 | EntitiesArtifactsGenerator, 41 | RelationshipsArtifactsGenerator, 42 | TextUnitsArtifactsGenerator, 43 | ) 44 | from langchain_graphrag.indexing.graph_clustering.leiden_community_detector import ( 45 | HierarchicalLeidenCommunityDetector, 46 | ) 47 | from langchain_graphrag.indexing.graph_generation import ( 48 | EntityRelationshipDescriptionSummarizer, 49 | EntityRelationshipExtractor, 50 | GraphGenerator, 51 | GraphsMerger, 52 | ) 53 | from langchain_graphrag.indexing.report_generation import ( 54 | CommunityReportGenerator, 55 | CommunityReportWriter, 56 | ) 57 | 58 | app = Typer() 59 | 60 | 61 | @app.command() 62 | def index( 63 | input_file: Path = typer.Option(..., dir_okay=False, file_okay=True, exists=True), 64 | output_dir: Path = typer.Option(..., dir_okay=True, file_okay=False), 65 | cache_dir: Path = typer.Option(..., dir_okay=True, file_okay=False), 66 | llm_type: LLMType = typer.Option(..., case_sensitive=False), 67 | llm_model: str = typer.Option(..., 
case_sensitive=False),
68 | embedding_type: EmbeddingModelType = typer.Option(..., case_sensitive=False),
69 | embedding_model: str = typer.Option(..., case_sensitive=False),
70 | chunk_size: int = typer.Option(1200, help="Chunk size for text splitting"),
71 | chunk_overlap: int = typer.Option(100, help="Chunk overlap for text splitting"),
72 | ollama_num_context: int | None = typer.Option(
73 | None, help="Context window size for ollama model"
74 | ),
75 | enable_langsmith: bool = typer.Option(False, help="Enable Langsmith"), # noqa: FBT001, FBT003
76 | ):
77 | if enable_langsmith:
78 | trace_via_langsmith()
79 | 
80 | output_dir.mkdir(parents=True, exist_ok=True)
81 | cache_dir.mkdir(parents=True, exist_ok=True)
82 | vector_store_dir = output_dir / "vector_stores"
83 | artifacts_dir = output_dir / get_artifacts_dir_name(llm_model)
84 | artifacts_dir.mkdir(parents=True, exist_ok=True)
85 | 
86 | tableprint.table(
87 | [
88 | ["LangSmith", str(enable_langsmith)],
89 | ["Input file", str(input_file)],
90 | ["Cache directory", str(cache_dir)],
91 | ["Vector store directory", str(vector_store_dir)],
92 | ["Artifacts directory", str(artifacts_dir)],
93 | ["LLM Type", llm_type],
94 | ["LLM Model", llm_model],
95 | ["Embedding Type", embedding_type],
96 | ["Embedding Model", embedding_model],
97 | ["Chunk Size", chunk_size],
98 | ["Chunk Overlap", chunk_overlap],
99 | ["OLLAMA_HOST", os.getenv("OLLAMA_HOST")],
100 | [
101 | "Ollama Num Context",
102 | "Not Provided" if ollama_num_context is None else ollama_num_context,
103 | ],
104 | ]
105 | )
106 | 
107 | ######### Start of creation of various objects/dependencies #############
108 | 
109 | # Dataloader that loads the supplied text file for indexing
110 | documents = TextLoader(file_path=input_file).load()
111 | 
112 | # TextSplitter required by TextUnitExtractor
113 | text_splitter = TokenTextSplitter(
114 | chunk_size=chunk_size,
115 | chunk_overlap=chunk_overlap,
116 | )
117 | 
118 | # TextUnitExtractor that extracts text units from the text files
119 | text_unit_extractor = TextUnitExtractor(text_splitter=text_splitter)
120 | 
121 | # Entity Relationship Extractor
122 | entity_extractor = EntityRelationshipExtractor.build_default(
123 | llm=make_llm_instance(llm_type, llm_model, cache_dir),
124 | chain_config={"tags": ["er-extraction"]},
125 | )
126 | 
127 | # Entity Relationship Description Summarizer
128 | entity_summarizer = EntityRelationshipDescriptionSummarizer.build_default(
129 | llm=make_llm_instance(llm_type, llm_model, cache_dir),
130 | chain_config={"tags": ["er-description-summarization"]},
131 | )
132 | 
133 | # Graph Generator
134 | graph_generator = GraphGenerator(
135 | er_extractor=entity_extractor,
136 | graphs_merger=GraphsMerger(),
137 | er_description_summarizer=entity_summarizer,
138 | )
139 | 
140 | # Community Detector
141 | community_detector = HierarchicalLeidenCommunityDetector()
142 | 
143 | # Entities artifacts Generator
144 | # We need the vector Store (mandatory) for entities
145 | 
146 | # let's create a collection name based on
147 | # the embedding model name
148 | entities_collection_name = f"entity-{embedding_model}"
149 | entities_vector_store = ChromaVectorStore(
150 | collection_name=entities_collection_name,
151 | persist_directory=str(vector_store_dir),
152 | embedding_function=make_embedding_instance(
153 | embedding_type=embedding_type,
154 | model=embedding_model,
155 | cache_dir=cache_dir,
156 | ),
157 | )
158 | 
159 | entities_artifacts_generator = EntitiesArtifactsGenerator(
160 | 
entities_vector_store=entities_vector_store 161 | ) 162 | 163 | relationships_artifacts_generator = RelationshipsArtifactsGenerator() 164 | 165 | # Community Report Generator 166 | report_gen_llm = make_llm_instance(llm_type, llm_model, cache_dir) 167 | report_generator = CommunityReportGenerator.build_default( 168 | llm=report_gen_llm, 169 | chain_config={"tags": ["community-report"]}, 170 | ) 171 | 172 | report_writer = CommunityReportWriter() 173 | 174 | communities_report_artifacts_generator = CommunitiesReportsArtifactsGenerator( 175 | report_generator=report_generator, 176 | report_writer=report_writer, 177 | ) 178 | 179 | text_units_artifacts_generator = TextUnitsArtifactsGenerator() 180 | 181 | ######### End of creation of various objects/dependencies ############# 182 | 183 | indexer = SimpleIndexer( 184 | text_unit_extractor=text_unit_extractor, 185 | graph_generator=graph_generator, 186 | community_detector=community_detector, 187 | entities_artifacts_generator=entities_artifacts_generator, 188 | relationships_artifacts_generator=relationships_artifacts_generator, 189 | text_units_artifacts_generator=text_units_artifacts_generator, 190 | communities_report_artifacts_generator=communities_report_artifacts_generator, 191 | ) 192 | 193 | artifacts = indexer.run(documents) 194 | 195 | # save the artifacts 196 | save_artifacts(artifacts, artifacts_dir) 197 | artifacts.report() 198 | 199 | 200 | @app.command() 201 | def report( 202 | artifacts_dir: Path = typer.Option( 203 | ..., 204 | exists=True, 205 | dir_okay=True, 206 | file_okay=False, 207 | ), 208 | ): 209 | _LOGGER.info("Artifacts directory - %s", artifacts_dir) 210 | 211 | artifacts: IndexerArtifacts = load_artifacts(artifacts_dir) 212 | artifacts.report() 213 | -------------------------------------------------------------------------------- /docs/architecture/overview.md: -------------------------------------------------------------------------------- 1 | # GraphRAG Architecture Overview 2 | 3 | GraphRAG transforms document collections into structured knowledge graphs, supporting both entity-specific and thematic analysis queries. 4 | 5 | 6 | --- 7 | 8 | ## Core System Overview 9 | 10 | GraphRAG operates through **two fundamental processes**: 11 | 12 | | Phase | Purpose | Output | 13 | |-------|---------|--------| 14 | | **Indexing Pipeline** | Analyzes documents to construct structured knowledge | Knowledge Graph + Community Reports | 15 | | **Query Engine** | Uses knowledge graph for contextual responses | Contextual Answers | 16 | 17 | ```mermaid 18 | flowchart LR 19 | A["Document Collection"] --> B["Indexing Pipeline"] 20 | B --> C["Knowledge Graph
+ Community Reports"] 21 | C --> D["Query Engine"] 22 | D --> E["Contextual Responses"] 23 | 24 | classDef inputStyle fill:#f8f9fa,stroke:#6c757d,stroke-width:2px,color:#212529 25 | classDef processStyle fill:#e3f2fd,stroke:#1976d2,stroke-width:2px,color:#0d47a1 26 | classDef dataStyle fill:#fff3e0,stroke:#f57c00,stroke-width:2px,color:#e65100 27 | classDef outputStyle fill:#e8f5e8,stroke:#388e3c,stroke-width:2px,color:#1b5e20 28 | 29 | class A inputStyle 30 | class B,D processStyle 31 | class C dataStyle 32 | class E outputStyle 33 | ``` 34 | 35 | ## Indexing Pipeline Architecture 36 | 37 | The indexing process turns raw documents into organized knowledge through a step-by-step process: 38 | 39 | ```mermaid 40 | flowchart TD 41 | subgraph input ["Input Processing"] 42 | A["Raw Documents"] 43 | B["Document Splitting"] 44 | A --> B 45 | end 46 | 47 | subgraph extraction ["Knowledge Extraction"] 48 | C["Entity Recognition"] 49 | D["Relationship Mining"] 50 | E["Graph Construction"] 51 | C --> E 52 | D --> E 53 | end 54 | 55 | subgraph organization ["Knowledge Organization"] 56 | F["Community Detection"] 57 | G["Summary Generation"] 58 | F --> G 59 | end 60 | 61 | subgraph output ["Query Artifacts"] 62 | H["Searchable Knowledge Base"] 63 | end 64 | 65 | B --> C 66 | B --> D 67 | E --> F 68 | G --> H 69 | 70 | classDef inputClass fill:#f8f9fa,stroke:#6c757d,stroke-width:2px 71 | classDef processClass fill:#e3f2fd,stroke:#1976d2,stroke-width:2px 72 | classDef organizeClass fill:#fff3e0,stroke:#f57c00,stroke-width:2px 73 | classDef outputClass fill:#e8f5e8,stroke:#388e3c,stroke-width:2px 74 | 75 | class A,B inputClass 76 | class C,D,E processClass 77 | class F,G organizeClass 78 | class H outputClass 79 | ``` 80 | 81 | --- 82 | 83 | ## Practical Example 84 | 85 | > **Input Document**: "Ratan Tata served as Chairman of Tata Group from 1991 to 2012, transforming it into a global business group with acquisitions like Jaguar Land Rover." 86 | 87 | ### GraphRAG Knowledge Extraction 88 | 89 | | Extract Type | Results | 90 | |--------------|---------| 91 | | **Entities** | Ratan Tata • Tata Group • Jaguar Land Rover • 1991 • 2012 | 92 | | **Relationships** | Ratan Tata → served_as_chairman → Tata Group
Tata Group → acquired → Jaguar Land Rover | 93 | | **Communities** | Business Leadership • Automotive Industry | 94 | 95 | --- 96 | 97 | ## Query Engine Architecture 98 | 99 | GraphRAG uses **two different search methods** to handle different types of questions: 100 | 101 | ### Local Search (Entity-Focused Queries) 102 | 103 | | Aspect | Details | 104 | |--------|---------| 105 | | **Best For** | Specific factual questions about entities and relationships | 106 | | **Examples** | "What companies did Ratan Tata lead?" • "When did Tata acquire Jaguar?" | 107 | | **How it Works** | Finds entities → Follows connections → Builds context | 108 | | **Characteristics** | High precision with specific responses | 109 | 110 | ### Global Search (Thematic Analysis) 111 | 112 | | Aspect | Details | 113 | |--------|---------| 114 | | **Best For** | Big-picture questions requiring complete insights | 115 | | **Examples** | "Key business transformation strategies?" • "Leadership patterns in business groups?" | 116 | | **How it Works** | Community analysis → Pre-built summaries → Insight combining | 117 | | **Characteristics** | Complete coverage with synthesized insights | 118 | 119 | ```mermaid 120 | flowchart TD 121 | A["User Query"] --> B{"Query Classification"} 122 | 123 | subgraph local ["Local Search Pipeline"] 124 | C["Entity Resolution"] 125 | D["Graph Traversal"] 126 | E["Context Assembly"] 127 | C --> D --> E 128 | end 129 | 130 | subgraph global ["Global Search Pipeline"] 131 | F["Community Matching"] 132 | G["Report Synthesis"] 133 | H["Insight Aggregation"] 134 | F --> G --> H 135 | end 136 | 137 | I["Response Generation"] 138 | 139 | B -->|"Entity-Specific"| local 140 | B -->|"Thematic"| global 141 | E --> I 142 | H --> I 143 | 144 | classDef queryStyle fill:#f8f9fa,stroke:#6c757d,stroke-width:2px,color:#212529 145 | classDef decisionStyle fill:#fff3e0,stroke:#f57c00,stroke-width:3px,color:#e65100 146 | classDef localStyle fill:#e3f2fd,stroke:#1976d2,stroke-width:2px,color:#0d47a1 147 | classDef globalStyle fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px,color:#4a148c 148 | classDef outputStyle fill:#e8f5e8,stroke:#388e3c,stroke-width:2px,color:#1b5e20 149 | 150 | class A queryStyle 151 | class B decisionStyle 152 | class C,D,E localStyle 153 | class F,G,H globalStyle 154 | class I outputStyle 155 | ``` 156 | 157 | --- 158 | 159 | ## Comparison with Traditional Search 160 | 161 | | Aspect | Traditional Search | GraphRAG | 162 | |-----------|------------------|----------| 163 | | **Understanding** | Keyword matching only | Meaningful connections & relationships | 164 | | **Analysis Depth** | Single-level results | **Detailed** (Local) + **Strategic** (Global) insights | 165 | | **Source Tracking** | Basic page references | Complete traceability to original documents | 166 | | **Context Awareness** | Isolated results | Connected knowledge with relationships | 167 | 168 | --- 169 | 170 | ## Knowledge Architecture Components 171 | 172 | GraphRAG organizes information into **four interconnected layers**: 173 | 174 | | Component | Description | Example | 175 | |-----------|-------------|---------| 176 | | **Entities** | People, organizations, and concepts | `Ratan Tata` • `Tata Group` • `Jaguar Land Rover` | 177 | | **Relationships** | How different entities connect to each other | `Ratan Tata served as Chairman of Tata Group` | 178 | | **Communities** | Groups of related entities by topic | `Business Leadership` • `Automotive Industry` | 179 | | **Text Units** | Original text pieces with entity links 
| `"Ratan Tata served as Chairman of Tata Group from 1991..."` | 180 | 181 | --- 182 | 183 | ## Query Strategy Selection Guide 184 | 185 | **Choose the right search method for your question type:** 186 | 187 | | Query Example | Search Strategy | Why This Choice | 188 | |---------------|-----------------|-----------------| 189 | | `"What is Ratan Tata's background?"` | **Local Search** | Entity-specific biographical information | 190 | | `"Which companies did Tata Group acquire?"` | **Local Search** | Specific relationship and timeline queries | 191 | | `"What are the main business transformation patterns?"` | **Global Search** | Theme analysis across multiple entities | 192 | | `"Analyze the strategic evolution of Indian business groups"` | **Global Search** | Complete pattern recognition and insights | 193 | 194 | --- 195 | 196 | ## Implementation Resources 197 | 198 | | Resource | Description | Best For | 199 | |----------|-------------|----------| 200 | | **[Indexing Pipeline Guide](../guides/indexing_pipeline.md)** | Complete indexing process documentation | Understanding the build process | 201 | | **[Query System Guide](../guides/query_system.md)** | Local vs Global search explained | Learning when to use each query type | 202 | | **[System Customization Guide](../guides/customization.md)** | Component configuration and extensions | Adapting to your needs | 203 | 204 | --- 205 | 206 | ## Key Value Proposition 207 | 208 | > **GraphRAG transforms how you work with documents** by building intelligent knowledge structures that understand context and relationships. This enables both **precise factual questions** and **strategic analytical insights** - going far beyond traditional search. 209 | 210 | --- 211 | 212 | ## Related Documentation 213 | 214 | **[Indexing Pipeline](../guides/indexing_pipeline.md)** 215 | Technical implementation details and configuration options for building knowledge graphs. 216 | 217 | **[Documentation Index](../index.md)** 218 | Return to documentation overview -------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/_default_prompts.py: -------------------------------------------------------------------------------- 1 | # ruff: noqa 2 | 3 | DEFAULT_ER_EXTRACTION_PROMPT = """ 4 | -Goal- 5 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities. 6 | 7 | -Steps- 8 | 1. Identify all entities. For each identified entity, extract the following information: 9 | - entity_name: Name of the entity, capitalized 10 | - entity_type: One of the following types: [{entity_types}] 11 | - entity_description: Comprehensive description of the entity's attributes and activities 12 | Format each entity as ("entity"{tuple_delimiter}{tuple_delimiter}{tuple_delimiter}) 13 | 14 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other. 
-------------------------------------------------------------------------------- /src/langchain_graphrag/indexing/graph_generation/entity_relationship_extraction/_default_prompts.py: --------------------------------------------------------------------------------
1 | # ruff: noqa
2 | 
3 | DEFAULT_ER_EXTRACTION_PROMPT = """
4 | -Goal-
5 | Given a text document that is potentially relevant to this activity and a list of entity types, identify all entities of those types from the text and all relationships among the identified entities.
6 | 
7 | -Steps-
8 | 1. Identify all entities. For each identified entity, extract the following information:
9 | - entity_name: Name of the entity, capitalized
10 | - entity_type: One of the following types: [{entity_types}]
11 | - entity_description: Comprehensive description of the entity's attributes and activities
12 | Format each entity as ("entity"{tuple_delimiter}<entity_name>{tuple_delimiter}<entity_type>{tuple_delimiter}<entity_description>)
13 | 
14 | 2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
15 | For each pair of related entities, extract the following information:
16 | - source_entity: name of the source entity, as identified in step 1
17 | - target_entity: name of the target entity, as identified in step 1
18 | - relationship_description: explanation as to why you think the source entity and the target entity are related to each other
19 | - relationship_strength: a numeric score indicating strength of the relationship between the source entity and target entity
20 | Format each relationship as ("relationship"{tuple_delimiter}<source_entity>{tuple_delimiter}<target_entity>{tuple_delimiter}<relationship_description>{tuple_delimiter}<relationship_strength>)
21 | 
22 | 3. Return output in English as a single list of all the entities and relationships identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
23 | 
24 | 4. When finished, output {completion_delimiter}
25 | 
26 | ######################
27 | -Examples-
28 | ######################
29 | Example 1:
30 | Entity_types: ORGANIZATION,PERSON
31 | Text:
32 | The Verdantis's Central Institution is scheduled to meet on Monday and Thursday, with the institution planning to release its latest policy decision on Thursday at 1:30 p.m. PDT, followed by a press conference where Central Institution Chair Martin Smith will take questions. Investors expect the Market Strategy Committee to hold its benchmark interest rate steady in a range of 3.5%-3.75%.
33 | ######################
34 | Output:
35 | ("entity"{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution is the Federal Reserve of Verdantis, which is setting interest rates on Monday and Thursday)
36 | {record_delimiter}
37 | ("entity"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}PERSON{tuple_delimiter}Martin Smith is the chair of the Central Institution)
38 | {record_delimiter}
39 | ("entity"{tuple_delimiter}MARKET STRATEGY COMMITTEE{tuple_delimiter}ORGANIZATION{tuple_delimiter}The Central Institution committee makes key decisions about interest rates and the growth of Verdantis's money supply)
40 | {record_delimiter}
41 | ("relationship"{tuple_delimiter}MARTIN SMITH{tuple_delimiter}CENTRAL INSTITUTION{tuple_delimiter}Martin Smith is the Chair of the Central Institution and will answer questions at a press conference{tuple_delimiter}9)
42 | {completion_delimiter}
43 | 
44 | ######################
45 | Example 2:
46 | Entity_types: ORGANIZATION
47 | Text:
48 | TechGlobal's (TG) stock skyrocketed in its opening day on the Global Exchange Thursday. But IPO experts warn that the semiconductor corporation's debut on the public markets isn't indicative of how other newly listed companies may perform.
49 | 
50 | TechGlobal, a formerly public company, was taken private by Vision Holdings in 2014. The well-established chip designer says it powers 85% of premium smartphones.
51 | ###################### 52 | Output: 53 | ("entity"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}ORGANIZATION{tuple_delimiter}TechGlobal is a stock now listed on the Global Exchange which powers 85% of premium smartphones) 54 | {record_delimiter} 55 | ("entity"{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}ORGANIZATION{tuple_delimiter}Vision Holdings is a firm that previously owned TechGlobal) 56 | {record_delimiter} 57 | ("relationship"{tuple_delimiter}TECHGLOBAL{tuple_delimiter}VISION HOLDINGS{tuple_delimiter}Vision Holdings formerly owned TechGlobal from 2014 until present{tuple_delimiter}5) 58 | {completion_delimiter} 59 | 60 | ###################### 61 | Example 3: 62 | Entity_types: ORGANIZATION,GEO,PERSON 63 | Text: 64 | Five Aurelians jailed for 8 years in Firuzabad and widely regarded as hostages are on their way home to Aurelia. 65 | 66 | The swap orchestrated by Quintara was finalized when $8bn of Firuzi funds were transferred to financial institutions in Krohaara, the capital of Quintara. 67 | 68 | The exchange initiated in Firuzabad's capital, Tiruzia, led to the four men and one woman, who are also Firuzi nationals, boarding a chartered flight to Krohaara. 69 | 70 | They were welcomed by senior Aurelian officials and are now on their way to Aurelia's capital, Cashion. 71 | 72 | The Aurelians include 39-year-old businessman Samuel Namara, who has been held in Tiruzia's Alhamia Prison, as well as journalist Durke Bataglani, 59, and environmentalist Meggie Tazbah, 53, who also holds Bratinas nationality. 73 | ###################### 74 | Output: 75 | ("entity"{tuple_delimiter}FIRUZABAD{tuple_delimiter}GEO{tuple_delimiter}Firuzabad held Aurelians as hostages) 76 | {record_delimiter} 77 | ("entity"{tuple_delimiter}AURELIA{tuple_delimiter}GEO{tuple_delimiter}Country seeking to release hostages) 78 | {record_delimiter} 79 | ("entity"{tuple_delimiter}QUINTARA{tuple_delimiter}GEO{tuple_delimiter}Country that negotiated a swap of money in exchange for hostages) 80 | {record_delimiter} 81 | {record_delimiter} 82 | ("entity"{tuple_delimiter}TIRUZIA{tuple_delimiter}GEO{tuple_delimiter}Capital of Firuzabad where the Aurelians were being held) 83 | {record_delimiter} 84 | ("entity"{tuple_delimiter}KROHAARA{tuple_delimiter}GEO{tuple_delimiter}Capital city in Quintara) 85 | {record_delimiter} 86 | ("entity"{tuple_delimiter}CASHION{tuple_delimiter}GEO{tuple_delimiter}Capital city in Aurelia) 87 | {record_delimiter} 88 | ("entity"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}PERSON{tuple_delimiter}Aurelian who spent time in Tiruzia's Alhamia Prison) 89 | {record_delimiter} 90 | ("entity"{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}GEO{tuple_delimiter}Prison in Tiruzia) 91 | {record_delimiter} 92 | ("entity"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}PERSON{tuple_delimiter}Aurelian journalist who was held hostage) 93 | {record_delimiter} 94 | ("entity"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}PERSON{tuple_delimiter}Bratinas national and environmentalist who was held hostage) 95 | {record_delimiter} 96 | ("relationship"{tuple_delimiter}FIRUZABAD{tuple_delimiter}AURELIA{tuple_delimiter}Firuzabad negotiated a hostage exchange with Aurelia{tuple_delimiter}2) 97 | {record_delimiter} 98 | ("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}AURELIA{tuple_delimiter}Quintara brokered the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2) 99 | {record_delimiter} 100 | ("relationship"{tuple_delimiter}QUINTARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Quintara brokered 
the hostage exchange between Firuzabad and Aurelia{tuple_delimiter}2)
101 | {record_delimiter}
102 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}ALHAMIA PRISON{tuple_delimiter}Samuel Namara was a prisoner at Alhamia prison{tuple_delimiter}8)
103 | {record_delimiter}
104 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}Samuel Namara and Meggie Tazbah were exchanged in the same hostage release{tuple_delimiter}2)
105 | {record_delimiter}
106 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Samuel Namara and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2)
107 | {record_delimiter}
108 | ("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}Meggie Tazbah and Durke Bataglani were exchanged in the same hostage release{tuple_delimiter}2)
109 | {record_delimiter}
110 | ("relationship"{tuple_delimiter}SAMUEL NAMARA{tuple_delimiter}FIRUZABAD{tuple_delimiter}Samuel Namara was a hostage in Firuzabad{tuple_delimiter}2)
111 | {record_delimiter}
112 | ("relationship"{tuple_delimiter}MEGGIE TAZBAH{tuple_delimiter}FIRUZABAD{tuple_delimiter}Meggie Tazbah was a hostage in Firuzabad{tuple_delimiter}2)
113 | {record_delimiter}
114 | ("relationship"{tuple_delimiter}DURKE BATAGLANI{tuple_delimiter}FIRUZABAD{tuple_delimiter}Durke Bataglani was a hostage in Firuzabad{tuple_delimiter}2)
115 | {completion_delimiter}
116 | 
117 | ######################
118 | -Real Data-
119 | ######################
120 | Entity_types: {entity_types}
121 | Text: {input_text}
122 | ######################
123 | Output:
124 | 
125 | 
126 | """
127 | -------------------------------------------------------------------------------- /docs/guides/text_units_extraction.ipynb: --------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Text Units Extraction"
8 | ]
9 | },
10 | {
11 | "cell_type": "markdown",
12 | "metadata": {},
13 | "source": [
14 | "## Overview"
15 | ]
16 | },
17 | {
18 | "cell_type": "markdown",
19 | "metadata": {},
20 | "source": [
21 | "This guide shows the usage of the `TextUnitExtractor` class, which relies on \n",
22 | "the supplied `TextSplitter` to extract text units from the supplied documents.\n",
23 | "\n",
24 | "The output of this component is a pandas DataFrame with the following columns:\n",
25 | "- `document_id`\n",
26 | "- `id`\n",
27 | "- `text_unit`"
28 | ]
29 | },
30 | {
31 | "cell_type": "markdown",
32 | "metadata": {},
33 | "source": [
34 | "## Make a fake Document"
35 | ]
36 | },
37 | {
38 | "cell_type": "markdown",
39 | "metadata": {},
40 | "source": [
41 | "Below is some random text that we will use to make a `langchain` [Document](https://api.python.langchain.com/en/latest/documents/langchain_core.documents.base.Document.html#langchain_core.documents.base.Document)."
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": 1,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "from langchain_core.documents import Document\n",
51 | "\n",
52 | "from langchain_graphrag.indexing import TextUnitExtractor"
53 | ]
54 | },
55 | {
56 | "cell_type": "code",
57 | "execution_count": 2,
58 | "metadata": {},
59 | "outputs": [],
60 | "source": [
61 | "SOME_TEXT = \"\"\"\n",
62 | "Contrary to popular belief, Lorem Ipsum is not simply random text. 
\n", 63 | "It has roots in a piece of classical Latin literature from 45 BC, \n", 64 | "making it over 2000 years old. Richard McClintock, a Latin professor \n", 65 | "at Hampden-Sydney College in Virginia, looked up one of the more obscure Latin words,\n", 66 | "consectetur, from a Lorem Ipsum passage, and going through the cites of the word in \n", 67 | "classical literature, discovered the undoubtable source. Lorem Ipsum comes \n", 68 | "from sections 1.10.32 and 1.10.33 of \"de Finibus Bonorum et Malorum\" \n", 69 | "(The Extremes of Good and Evil) by Cicero, written in 45 BC. This book is a \n", 70 | "treatise on the theory of ethics, very popular during the Renaissance. \n", 71 | "The first line of Lorem Ipsum, \"Lorem ipsum dolor sit amet..\", \n", 72 | "comes from a line in section 1.10.32.\n", 73 | "\n", 74 | "The standard chunk of Lorem Ipsum used since the 1500s is reproduced below \n", 75 | "for those interested. Sections 1.10.32 and 1.10.33 from \"de Finibus Bonorum et \n", 76 | "Malorum\" by Cicero are also reproduced in their exact original form, accompanied\n", 77 | "by English versions from the 1914 translation by H. Rackham.\n", 78 | "\"\"\"\n", 79 | "\n", 80 | "document = Document(page_content=SOME_TEXT)" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Select a TextSplitter" 88 | ] 89 | }, 90 | { 91 | "cell_type": "markdown", 92 | "metadata": {}, 93 | "source": [ 94 | "`TextUnitExtractor` requirs you to supply a TextSplitter.\n", 95 | "\n", 96 | "See all available splitters from [langchain_text_splitters](https://api.python.langchain.com/en/latest/text_splitters_api_reference.html) and of course you can write your own splitter.\n", 97 | "\n", 98 | "In this example, we are going to use the simplest of them - [CharacterTextSplitter](https://api.python.langchain.com/en/latest/character/langchain_text_splitters.character.CharacterTextSplitter.html#langchain_text_splitters.character.CharacterTextSplitter)." 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 3, 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "from langchain_text_splitters import CharacterTextSplitter" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 4, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "splitter = CharacterTextSplitter(chunk_size=512, chunk_overlap=64)\n", 117 | "\n", 118 | "text_unit_extractor = TextUnitExtractor(text_splitter=splitter)" 119 | ] 120 | }, 121 | { 122 | "cell_type": "markdown", 123 | "metadata": {}, 124 | "source": [ 125 | "## Run the TextUnitExtractor" 126 | ] 127 | }, 128 | { 129 | "cell_type": "markdown", 130 | "metadata": {}, 131 | "source": [ 132 | "And now we run it, the run method takes the list of the documents and returns\n", 133 | "a pandas DataFrame object." 134 | ] 135 | }, 136 | { 137 | "cell_type": "code", 138 | "execution_count": 5, 139 | "metadata": {}, 140 | "outputs": [ 141 | { 142 | "name": "stderr", 143 | "output_type": "stream", 144 | "text": [ 145 | "Processing documents ...: 0%| | 0/1 [00:00\n", 154 | "\n", 167 | "\n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | "
150 | "text/plain": [
151 | " document_id id \\\n",
152 | "0 d6b99162-5843-4c73-89c1-e53d92d6dd56 534d87de-7463-47b0-81d2-2e4c392b4e7b \n",
153 | "1 d6b99162-5843-4c73-89c1-e53d92d6dd56 78f69b7b-9c74-4c57-af0d-e39c83119866 \n",
154 | "\n",
155 | " text_unit \n",
156 | "0 Contrary to popular belief, Lorem Ipsum is not... \n",
157 | "1 The standard chunk of Lorem Ipsum used since t... "
158 | ]
159 | },
160 | "execution_count": 5,
161 | "metadata": {},
162 | "output_type": "execute_result"
163 | }
164 | ],
165 | "source": [
166 | "df_text_units = text_unit_extractor.run([document])\n",
167 | "\n",
168 | "df_text_units.head()"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {},
174 | "source": [
175 | "## Final Remarks\n",
176 | "\n",
177 | "As you can see above, this dataframe has three columns:\n",
178 | "- `document_id`\n",
179 | "- `id`\n",
180 | "- `text_unit`\n",
181 | "\n",
182 | "Since our document was not very big, given our `chunk_size`, we only have two rows.\n",
183 | "\n",
184 | "Every text unit gets a unique `id` that is used by other components.\n",
185 | "\n",
186 | "If the document object (of type `Document`) does not have an `id`, one is \n",
187 | "generated by the `TextUnitExtractor`."
188 | ]
189 | },
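190 | {
191 | "cell_type": "markdown",
192 | "metadata": {},
193 | "source": [
194 | "As an optional last step, you could persist these text units for reuse by downstream components; for example, the graph extraction guide in these docs reads its input from a parquet file. The file name below is only an illustration:"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "# Persist the extracted text units (illustrative file name)\n",
204 | "df_text_units.to_parquet(\"base_text_units.parquet\")"
205 | ]
206 | }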
\n", 191 | "" 192 | ], 193 | "text/plain": [ 194 | " document_id id \\\n", 195 | "0 d6b99162-5843-4c73-89c1-e53d92d6dd56 534d87de-7463-47b0-81d2-2e4c392b4e7b \n", 196 | "1 d6b99162-5843-4c73-89c1-e53d92d6dd56 78f69b7b-9c74-4c57-af0d-e39c83119866 \n", 197 | "\n", 198 | " text_unit \n", 199 | "0 Contrary to popular belief, Lorem Ipsum is not... \n", 200 | "1 The standard chunk of Lorem Ipsum used since t... " 201 | ] 202 | }, 203 | "execution_count": 5, 204 | "metadata": {}, 205 | "output_type": "execute_result" 206 | } 207 | ], 208 | "source": [ 209 | "df_text_units = text_unit_extractor.run([document])\n", 210 | "\n", 211 | "df_text_units.head()" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "## Final Remarks\n", 219 | "\n", 220 | "As you can see above, this dataframe has three columns:\n", 221 | "- `document_id`\n", 222 | "- `id`\n", 223 | "- `text_unit`\n", 224 | "\n", 225 | "Since our document was not very big, given our `chunk_size` we only have two rows\n", 226 | "\n", 227 | "Every text_unit gets a unique id that would be used in other components.\n", 228 | "\n", 229 | "If the document object (type `Document`) did not have `id` then one is \n", 230 | "generated by the `TextUnitExtractor`." 231 | ] 232 | } 233 | ], 234 | "metadata": { 235 | "kernelspec": { 236 | "display_name": ".venv", 237 | "language": "python", 238 | "name": "python3" 239 | }, 240 | "language_info": { 241 | "codemirror_mode": { 242 | "name": "ipython", 243 | "version": 3 244 | }, 245 | "file_extension": ".py", 246 | "mimetype": "text/x-python", 247 | "name": "python", 248 | "nbconvert_exporter": "python", 249 | "pygments_lexer": "ipython3", 250 | "version": "3.10.14" 251 | } 252 | }, 253 | "nbformat": 4, 254 | "nbformat_minor": 2 255 | } 256 | --------------------------------------------------------------------------------