├── .gitignore ├── README.md ├── env.template ├── knowledge_graph ├── __init__.py ├── cassandra_graph_store.py ├── extraction.py ├── knowledge_graph.py ├── knowledge_schema.py ├── prompt_templates │ ├── extraction.md │ └── schema_inference.md ├── render.py ├── runnables.py ├── schema_inference.py ├── templates.py ├── traverse.py └── utils.py ├── notebook.ipynb ├── poetry.lock ├── pyproject.toml └── tests ├── __init__.py ├── conftest.py ├── marie_curie_schema.yaml ├── test_extraction.py ├── test_knowledge_graph.py ├── test_runnables.py ├── test_schema_inference.py └── test_traverse.py /.gitignore: -------------------------------------------------------------------------------- 1 | .pytest_cache 2 | __pycache__ 3 | .env -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Example GraphRAG using Astra 2 | 3 | > [!WARNING] 4 | > This repository is deprecated. 5 | > See https://github.com/datastax/ragstack-ai/tree/main/libs/knowledge-graph 6 | -------------------------------------------------------------------------------- /env.template: -------------------------------------------------------------------------------- 1 | OPENAI_API_KEY="" 2 | LANGCHAIN_TRACING_V2=true 3 | LANGCHAIN_API_KEY="" 4 | ASTRA_DB_DATABASE_ID="" 5 | ASTRA_DB_APPLICATION_TOKEN="" 6 | ASTRA_DB_KEYSPACE="" -------------------------------------------------------------------------------- /knowledge_graph/__init__.py: -------------------------------------------------------------------------------- 1 | from .cassandra_graph_store import CassandraGraphStore 2 | from .runnables import extract_entities 3 | from .traverse import Node, Relation 4 | 5 | __all__ = ["CassandraGraphStore", "extract_entities", "Node", "Relation"] 6 | -------------------------------------------------------------------------------- /knowledge_graph/cassandra_graph_store.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Iterable, List, Optional, Sequence, Union 2 | 3 | from cassandra.cluster import Session 4 | from langchain_community.graphs.graph_document import GraphDocument 5 | from langchain_community.graphs.graph_document import Node as LangChainNode 6 | from langchain_community.graphs.graph_store import GraphStore 7 | from langchain_core.embeddings import Embeddings 8 | from langchain_core.runnables import Runnable, RunnableLambda 9 | 10 | from knowledge_graph.knowledge_graph import CassandraKnowledgeGraph 11 | 12 | from .traverse import Node, Relation 13 | 14 | 15 | def _elements(documents: Iterable[GraphDocument]) -> Iterable[Union[Node, Relation]]: 16 | def _node(node: LangChainNode) -> Node: 17 | return Node(name=str(node.id), type=node.type) 18 | 19 | for document in documents: 20 | for node in document.nodes: 21 | yield _node(node) 22 | for edge in document.relationships: 23 | yield Relation(source=_node(edge.source), target=_node(edge.target), type=edge.type) 24 | 25 | 26 | class CassandraGraphStore(GraphStore): 27 | def __init__( 28 | self, 29 | node_table: str = "entities", 30 | edge_table: str = "relationships", 31 | text_embeddings: Optional[Embeddings] = None, 32 | session: Optional[Session] = None, 33 | keyspace: Optional[str] = None, 34 | ) -> None: 35 | """ 36 | Create a Cassandra Graph Store. 37 | 38 | Before calling this, you must initialize cassio with `cassio.init`, or 39 | provide valid session and keyspace values. 40 | """ 41 | self.graph = CassandraKnowledgeGraph( 42 | node_table=node_table, 43 | edge_table=edge_table, 44 | text_embeddings=text_embeddings, 45 | session=session, 46 | keyspace=keyspace, 47 | ) 48 | 49 | def add_graph_documents( 50 | self, graph_documents: List[GraphDocument], include_source: bool = False 51 | ) -> None: 52 | # TODO: Include source. 53 | self.graph.insert(_elements(graph_documents)) 54 | 55 | # TODO: should this include the types of each node? 56 | def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: 57 | raise ValueError("Querying Cassandra should use `as_runnable`.") 58 | 59 | def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = []) -> Runnable: 60 | """ 61 | Return a runnable that retrieves the sub-graph near the input entity or entities. 62 | 63 | Parameters: 64 | - steps: The maximum distance to follow from the starting points. 65 | - edge_filters: Predicates to use for filtering the edges. 66 | """ 67 | return RunnableLambda(func=self.graph.traverse, afunc=self.graph.atraverse).bind( 68 | steps=steps, 69 | edge_filters=edge_filters, 70 | ) 71 | -------------------------------------------------------------------------------- /knowledge_graph/extraction.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List, Sequence, Union, cast 2 | 3 | from langchain_community.graphs.graph_document import GraphDocument 4 | from langchain_core.documents import Document 5 | from langchain_core.language_models.chat_models import BaseChatModel 6 | from langchain_core.prompts import ( 7 | ChatPromptTemplate, 8 | HumanMessagePromptTemplate, 9 | SystemMessagePromptTemplate, 10 | ) 11 | from langchain_core.pydantic_v1 import BaseModel 12 | from langchain_experimental.graph_transformers.llm import ( 13 | _Graph, 14 | create_simple_model, 15 | map_to_base_node, 16 | map_to_base_relationship, 17 | ) 18 | 19 | from knowledge_graph.knowledge_schema import ( 20 | Example, 21 | KnowledgeSchema, 22 | KnowledgeSchemaValidator, 23 | ) 24 | from knowledge_graph.templates import load_template 25 | 26 | 27 | def _format_example(idx: int, example: Example) -> str: 28 | from pydantic_yaml import to_yaml_str 29 | 30 | return f"Example {idx}:\n```yaml\n{to_yaml_str(example)}\n```" 31 | 32 | 33 | class KnowledgeSchemaExtractor: 34 | def __init__( 35 | self, 36 | llm: BaseChatModel, 37 | schema: KnowledgeSchema, 38 | examples: Sequence[Example] = [], 39 | strict: bool = False, 40 | ) -> None: 41 | self._validator = KnowledgeSchemaValidator(schema) 42 | self.strict = strict 43 | 44 | messages = [ 45 | SystemMessagePromptTemplate( 46 | prompt=load_template("extraction.md", knowledge_schema_yaml=schema.to_yaml_str()) 47 | ) 48 | ] 49 | 50 | if examples: 51 | formatted = "\n\n".join(map(_format_example, examples)) 52 | messages.append(SystemMessagePromptTemplate(prompt=formatted)) 53 | 54 | messages.append(HumanMessagePromptTemplate.from_template("Input: {input}")) 55 | 56 | prompt = ChatPromptTemplate.from_messages(messages) 57 | schema = create_simple_model( 58 | node_labels=[node.type for node in schema.nodes], 59 | rel_types=list({r.edge_type for r in schema.relationships}), 60 | ) 61 | # TODO: Use "full" output so we can detect parsing errors? 62 | structured_llm = llm.with_structured_output(schema) 63 | self._chain = prompt | structured_llm 64 | 65 | def _process_response( 66 | self, document: Document, response: Union[Dict, BaseModel] 67 | ) -> GraphDocument: 68 | raw_graph = cast(_Graph, response) 69 | nodes = [map_to_base_node(node) for node in raw_graph.nodes] if raw_graph.nodes else [] 70 | relationships = ( 71 | [map_to_base_relationship(rel) for rel in raw_graph.relationships] 72 | if raw_graph.relationships 73 | else [] 74 | ) 75 | 76 | document = GraphDocument(nodes=nodes, relationships=relationships, source=document) 77 | 78 | if self.strict: 79 | self._validator.validate_graph_document(document) 80 | 81 | return document 82 | 83 | def extract(self, documents: List[Document]) -> List[GraphDocument]: 84 | # TODO: Define an async version of extraction? 85 | responses = self._chain.batch_as_completed( 86 | [{"input": doc.page_content} for doc in documents] 87 | ) 88 | return [self._process_response(documents[idx], response) for idx, response in responses] 89 | -------------------------------------------------------------------------------- /knowledge_graph/knowledge_graph.py: -------------------------------------------------------------------------------- 1 | import json 2 | from itertools import repeat 3 | from typing import Any, Dict, Iterable, Optional, Sequence, Tuple, Union, cast 4 | 5 | from cassandra.cluster import ResponseFuture, Session 6 | from cassandra.query import BatchStatement 7 | from cassio.config import check_resolve_keyspace, check_resolve_session 8 | from langchain_core.embeddings import Embeddings 9 | 10 | from .traverse import Node, Relation, atraverse, traverse 11 | from .utils import batched 12 | 13 | 14 | def _serialize_md_dict(md_dict: Dict[str, Any]) -> str: 15 | return json.dumps(md_dict, separators=(",", ":"), sort_keys=True) 16 | 17 | 18 | def _deserialize_md_dict(md_string: str) -> Dict[str, Any]: 19 | return cast(Dict[str, Any], json.loads(md_string)) 20 | 21 | 22 | def _parse_node(row) -> Node: 23 | return Node( 24 | name=row.name, 25 | type=row.type, 26 | properties=_deserialize_md_dict(row.properties_json) if row.properties_json else dict(), 27 | ) 28 | 29 | 30 | class CassandraKnowledgeGraph: 31 | def __init__( 32 | self, 33 | node_table: str = "entities", 34 | edge_table: str = "relationships", 35 | text_embeddings: Optional[Embeddings] = None, 36 | session: Optional[Session] = None, 37 | keyspace: Optional[str] = None, 38 | apply_schema: bool = True, 39 | ) -> None: 40 | """ 41 | Create a Cassandra Knowledge Graph. 42 | 43 | Parameters: 44 | - node_table: Name of the table containing nodes. Defaults to `"entities"`. 45 | - edge_table: Name of the table containing edges. Defaults to `"relationships"`. 46 | _ text_embeddings: Name of the embeddings to use, if any. 47 | - session: The Cassandra `Session` to use. If not specified, uses the default `cassio` 48 | session, which requires `cassio.init` has been called. 49 | - keyspace: The Cassandra keyspace to use. If not specified, uses the default `cassio` 50 | keyspace, which requires `cassio.init` has been called. 51 | - apply_schema: If true, the node table and edge table are created. 52 | """ 53 | 54 | session = check_resolve_session(session) 55 | keyspace = check_resolve_keyspace(keyspace) 56 | 57 | self._text_embeddings = text_embeddings 58 | self._text_embeddings_dim = ( 59 | # Embedding vectors must have dimension: 60 | # > 0 to be created at all. 61 | # > 1 to support cosine distance. 62 | # So we default to 2. 63 | len(text_embeddings.embed_query("test string")) if text_embeddings else 2 64 | ) 65 | 66 | self._session = session 67 | self._keyspace = keyspace 68 | 69 | self._node_table = node_table 70 | self._edge_table = edge_table 71 | 72 | if apply_schema: 73 | self._apply_schema() 74 | 75 | self._insert_node = self._session.prepare( 76 | f"""INSERT INTO {keyspace}.{node_table} ( 77 | name, type, text_embedding, properties_json 78 | ) VALUES (?, ?, ?, ?) 79 | """ 80 | ) 81 | 82 | self._insert_relationship = self._session.prepare( 83 | f""" 84 | INSERT INTO {keyspace}.{edge_table} ( 85 | source_name, source_type, target_name, target_type, edge_type 86 | ) VALUES (?, ?, ?, ?, ?) 87 | """ 88 | ) 89 | 90 | self._query_relationship = self._session.prepare( 91 | f""" 92 | SELECT name, type, properties_json 93 | FROM {keyspace}.{node_table} 94 | WHERE name = ? AND type = ? 95 | """ 96 | ) 97 | 98 | self._query_nodes_by_embedding = self._session.prepare( 99 | f""" 100 | SELECT name, type, properties_json 101 | FROM {keyspace}.{node_table} 102 | ORDER BY text_embedding ANN OF ? 103 | LIMIT ? 104 | """ 105 | ) 106 | 107 | def _apply_schema(self): 108 | # Partition by `name` and cluster by `type`. 109 | # Each `(name, type)` pair is a unique node. 110 | # We can enumerate all `type` values for a given `name` to identify ambiguous terms. 111 | self._session.execute( 112 | f""" 113 | CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} ( 114 | name TEXT, 115 | type TEXT, 116 | properties_json TEXT, 117 | text_embedding VECTOR, 118 | PRIMARY KEY (name, type) 119 | ); 120 | """ 121 | ) 122 | 123 | self._session.execute( 124 | f""" 125 | CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._edge_table} ( 126 | source_name TEXT, 127 | source_type TEXT, 128 | target_name TEXT, 129 | target_type TEXT, 130 | edge_type TEXT, 131 | PRIMARY KEY ((source_name, source_type), target_name, target_type, edge_type) 132 | ); 133 | """ 134 | ) 135 | 136 | self._session.execute( 137 | f""" 138 | CREATE CUSTOM INDEX IF NOT EXISTS {self._node_table}_text_embedding_index 139 | ON {self._keyspace}.{self._node_table} (text_embedding) 140 | USING 'StorageAttachedIndex'; 141 | """ 142 | ) 143 | 144 | self._session.execute( 145 | f""" 146 | CREATE CUSTOM INDEX IF NOT EXISTS {self._edge_table}_type_index 147 | ON {self._keyspace}.{self._edge_table} (edge_type) 148 | USING 'StorageAttachedIndex'; 149 | """ 150 | ) 151 | 152 | def _send_query_nearest_node(self, node: str, k: int = 1) -> ResponseFuture: 153 | return self._session.execute_async( 154 | self._query_nodes_by_embedding, 155 | ( 156 | self._text_embeddings.embed_query(node), 157 | k, 158 | ), 159 | ) 160 | 161 | # TODO: Allow filtering by node predicates and/or minimum similarity. 162 | def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node]: 163 | """ 164 | For each node, return the nearest nodes in the table. 165 | 166 | Parameters: 167 | - nodes: The strings to search for in the list of nodes. 168 | - k: The number of similar nodes to retrieve for each string. 169 | """ 170 | if self._text_embeddings is None: 171 | raise ValueError("Unable to query for nearest nodes without embeddings") 172 | 173 | node_futures: Iterable[ResponseFuture] = [ 174 | self._send_query_nearest_node(n, k) for n in nodes 175 | ] 176 | 177 | nodes = {_parse_node(n) for node_future in node_futures for n in node_future.result()} 178 | return list(nodes) 179 | 180 | # TODO: Introduce `ainsert` for async insertions. 181 | def insert( 182 | self, 183 | elements: Iterable[Union[Node, Relation]], 184 | ) -> None: 185 | for batch in batched(elements, n=4): 186 | from yaml import dump 187 | text_embeddings = ( 188 | iter( 189 | self._text_embeddings.embed_documents( 190 | [dump(n) for n in batch if isinstance(n, Node)] 191 | ) 192 | ) 193 | if self._text_embeddings 194 | else repeat([0.0, 1.0]) 195 | ) 196 | 197 | batch_statement = BatchStatement() 198 | for element in batch: 199 | if isinstance(element, Node): 200 | properties_json = _serialize_md_dict(element.properties) 201 | batch_statement.add( 202 | self._insert_node, 203 | (element.name, element.type, next(text_embeddings), properties_json), 204 | ) 205 | elif isinstance(element, Relation): 206 | batch_statement.add( 207 | self._insert_relationship, 208 | ( 209 | element.source.name, 210 | element.source.type, 211 | element.target.name, 212 | element.target.type, 213 | element.type, 214 | ), 215 | ) 216 | else: 217 | raise ValueError(f"Unsupported element type: {element}") 218 | 219 | # TODO: Support concurrent execution of these statements. 220 | self._session.execute(batch_statement) 221 | 222 | def subgraph( 223 | self, 224 | start: Node | Sequence[Node], 225 | edge_filters: Sequence[str] = (), 226 | steps: int = 3, 227 | ) -> Tuple[Iterable[Node], Iterable[Relation]]: 228 | """ 229 | Retrieve the sub-graph from the given starting nodes. 230 | """ 231 | edges = self.traverse(start, edge_filters, steps) 232 | 233 | # Create the set of nodes. 234 | nodes = {n for e in edges for n in (e.source, e.target)} 235 | 236 | # Retrieve the set of nodes to get the properties. 237 | 238 | # TODO: We really should have a NodeKey separate from Node. Otherwise, we end 239 | # up in a state where two nodes can be the "same" but with different properties, 240 | # etc. 241 | 242 | node_futures: Iterable[ResponseFuture] = [ 243 | self._session.execute_async(self._query_relationship, (n.name, n.type)) for n in nodes 244 | ] 245 | 246 | nodes = [_parse_node(n) for future in node_futures for n in future.result()] 247 | 248 | return (nodes, edges) 249 | 250 | def traverse( 251 | self, 252 | start: Node | Sequence[Node], 253 | edge_filters: Sequence[str] = (), 254 | steps: int = 3, 255 | ) -> Iterable[Relation]: 256 | """ 257 | Traverse the graph from the given starting nodes and return the resulting sub-graph. 258 | 259 | Parameters: 260 | - start: The starting node or nodes. 261 | - edge_filters: Filters to apply to the edges being traversed. 262 | - steps: The number of steps of edges to follow from a start node. 263 | 264 | Returns: 265 | An iterable over relations in the traversed sub-graph. 266 | """ 267 | return traverse( 268 | start=start, 269 | edge_table=self._edge_table, 270 | edge_source_name="source_name", 271 | edge_source_type="source_type", 272 | edge_target_name="target_name", 273 | edge_target_type="target_type", 274 | edge_type="edge_type", 275 | edge_filters=edge_filters, 276 | steps=steps, 277 | session=self._session, 278 | keyspace=self._keyspace, 279 | ) 280 | 281 | async def atraverse( 282 | self, 283 | start: Node | Sequence[Node], 284 | edge_filters: Sequence[str] = (), 285 | steps: int = 3, 286 | ) -> Iterable[Relation]: 287 | """ 288 | Traverse the graph from the given starting nodes and return the resulting sub-graph. 289 | 290 | Parameters: 291 | - start: The starting node or nodes. 292 | - edge_filters: Filters to apply to the edges being traversed. 293 | - steps: The number of steps of edges to follow from a start node. 294 | 295 | Returns: 296 | An iterable over relations in the traversed sub-graph. 297 | """ 298 | return await atraverse( 299 | start=start, 300 | edge_table=self._edge_table, 301 | edge_source_name="source_name", 302 | edge_source_type="source_type", 303 | edge_target_name="target_name", 304 | edge_target_type="target_type", 305 | edge_type="edge_type", 306 | edge_filters=edge_filters, 307 | steps=steps, 308 | session=self._session, 309 | keyspace=self._keyspace, 310 | ) 311 | -------------------------------------------------------------------------------- /knowledge_graph/knowledge_schema.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Dict, List, Self, Sequence, Union 3 | 4 | from langchain_community.graphs.graph_document import GraphDocument 5 | from langchain_core.pydantic_v1 import BaseModel 6 | 7 | from knowledge_graph.traverse import Node, Relation 8 | 9 | 10 | class NodeSchema(BaseModel): 11 | type: str 12 | """The name of the node type.""" 13 | 14 | description: str 15 | """Description of the node type.""" 16 | 17 | 18 | class EdgeSchema(BaseModel): 19 | type: str 20 | """The name of the edge type.""" 21 | 22 | description: str 23 | """Description of the edge type.""" 24 | 25 | 26 | class RelationshipSchema(BaseModel): 27 | edge_type: str 28 | """The name of the edge type for the relationhsip.""" 29 | 30 | source_types: List[str] 31 | """The node types for the source of the relationship.""" 32 | 33 | target_types: List[str] 34 | """The node types for the target of the relationship.""" 35 | 36 | description: str 37 | """Description of the relationship.""" 38 | 39 | 40 | class Example(BaseModel): 41 | input: str 42 | """The source input.""" 43 | 44 | nodes: Sequence[Node] 45 | """The extracted example nodes.""" 46 | 47 | edges: Sequence[Relation] 48 | """The extracted example relationhsips.""" 49 | 50 | 51 | class KnowledgeSchema(BaseModel): 52 | nodes: List[NodeSchema] 53 | """Allowed node types for the knowledge schema.""" 54 | 55 | relationships: List[RelationshipSchema] 56 | """Allowed relationships for the knowledge schema.""" 57 | 58 | @classmethod 59 | def from_file(cls, path: Union[str, Path]) -> Self: 60 | """Load a KnowledgeSchema from a JSON or YAML file. 61 | 62 | Parameters: 63 | - path: The path to the file to load. 64 | """ 65 | from pydantic_yaml import parse_yaml_file_as 66 | 67 | return parse_yaml_file_as(cls, path) 68 | 69 | def to_yaml_str(self) -> str: 70 | from pydantic_yaml import to_yaml_str 71 | 72 | return to_yaml_str(self) 73 | 74 | 75 | class KnowledgeSchemaValidator: 76 | def __init__(self, schema: KnowledgeSchema) -> None: 77 | self._schema = schema 78 | 79 | self._nodes = {node.type: node for node in schema.nodes} 80 | 81 | self._relationships: Dict[str, List[RelationshipSchema]] = {} 82 | for r in schema.relationships: 83 | self._relationships.setdefault(r.edge_type, []).append(r) 84 | 85 | # TODO: Validate the relationship. 86 | # source/target type should exist in nodes, edge_type should exist in edges 87 | 88 | def validate_graph_document(self, document: GraphDocument): 89 | e = ValueError("Invalid graph document for schema") 90 | for node_type in {node.type for node in document.nodes}: 91 | if node_type not in self._nodes: 92 | e.add_note(f"No node type '{node_type}") 93 | for r in document.relationships: 94 | relationships = self._relationships.get(r.edge_type, None) 95 | if relationships is None: 96 | e.add_note(f"No edge type '{r.edge_type}") 97 | else: 98 | relationship = next( 99 | ( 100 | candidate 101 | for candidate in relationships 102 | if r.source_type in candidate.source_types 103 | if r.target_type in candidate.target_types 104 | ) 105 | ) 106 | if relationship is None: 107 | e.add_note( 108 | f"No relationship allows ({r.source_id} -> {r.type} -> {r.target.type})" 109 | ) 110 | 111 | if e.__notes__: 112 | raise e 113 | -------------------------------------------------------------------------------- /knowledge_graph/prompt_templates/extraction.md: -------------------------------------------------------------------------------- 1 | # Knowledge Graph Instructions for GPT-4 2 | 3 | ## 1. Overview 4 | You are a top-tier algorithm designed for extracting information in structured formats to build a knowledge graph. 5 | Try to capture as much information from the text as possible without sacrificing accuracy. 6 | Do not add any information that is not explicitly mentioned in the text. 7 | 8 | The aim is to achieve simplicity and clarity in the knowledge graph, making it accessible for a vast audience. 9 | 10 | - **Nodes** represent entities and concepts. 11 | - **Edges** represent relationships between entities or concepts. 12 | 13 | ## 2. Labeling Nodes 14 | 15 | - **Node IDs**: Never utilize integers as node IDs. Node IDs should be names or human-readable identifiers found in the text. 16 | - **Node Types**: Ensure you use available node types for node types. 17 | 18 | ## 3. Labeling Edges 19 | 20 | - **Edge Types**: Ensure you use available edge types for edge types. 21 | - **Edge Consistency**: Ensure the source and target of each edge are consistent with one of the defined patterns. 22 | 23 | ## 4. Coreference Resolution 24 | - **Maintain Entity Consistency**: When extracting entities, it's vital to ensure consistency. If an entity, such as "John Doe", is mentioned multiple times in the text but is referred to by different names or pronouns (e.g., "Joe", "he") always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "John Doe" as the entity ID. 25 | 26 | Remember, the knowledge graph should be coherent and easily understandable, so maintaining consistency in entity references is crucial. 27 | 28 | ## 5. Strict Compliance 29 | Adhere to the rules strictly. Non-compliance will result in termination. 30 | 31 | ## 6. Knowledge Schema 32 | 33 | Use the following knowledge schema when extracting information for the knowledge graph. 34 | 35 | ```yaml 36 | {knowledge_schema_yaml} 37 | ``` -------------------------------------------------------------------------------- /knowledge_graph/prompt_templates/schema_inference.md: -------------------------------------------------------------------------------- 1 | # Knowledge Schema Instructions for GPT-4 2 | 3 | ## 1. Overview 4 | You are a top-tier algorithm designed for extracting knowledge schemas from unstructured content. 5 | Try to create a knowledge schema that captures as much information from the text as possible without sacrificing accuracy. 6 | Do not add anything to the schema that is explicitly related to the text. 7 | 8 | The aim is to achieve simplicity, clarity, and generality in the knowledge schema, making it applicable to other documents from the same corpus. 9 | 10 | - Simplicity: The knowledge schema should have as few node and edge types as needed while allowing a knowledge created with nodes and edges instantiated from those types to capture the information in this and similar documents. 11 | - Clarity: The knowledge schema should clearly identify each type, and the descriptions shouldn't be confusing. It should be obvious which type a given concept best fits. 12 | - Generality: The knowledge schema should be useful for describing the concepts in not just this document but other similar documents from the same domain. 13 | - Completeness: The knowledge schema should allow capturing as much information as possible from the content. 14 | 15 | The knowledge schema should be able to capture all the information in the source documents and similar documents. 16 | 17 | The knowledge schema should be specific enough to reject invalid knowledge graphs, such as treating a relationship saying an edge between two people saying "studied_at". 18 | 19 | ## 2. Node Types 20 | 21 | Nodes represent entities and concepts in the knowledge graph. 22 | Each node is associated with a type from the knowledge schema. 23 | 24 | Node types should correspond to specific basic or elementary types. 25 | For instance, a knowledge schema with the node type "person" would allow the knowledge graph to represent many people as nodes with the type "person". 26 | Avoid more specific terms node types like "mathematician" or "scientist". 27 | 28 | Distinct kinds of entities or concepts should have distinct node types. 29 | For example, nationalities should be represented as a distinct "nationality" node type rather than a "person" or "award". 30 | 31 | ## 3. Relationship Types 32 | 33 | Edges represent relationships in the knowledge graph. 34 | Each edge is associated with a type from the knowledge schema. 35 | 36 | Relationship types describe a specific edge type, as well as the node types which may be used as sources and targets of the edge. 37 | Ensure consistency and generality in relationship types when constructing knowledge schemas. 38 | Instead of using specific and momentary types such as 'became_professor', use more general and timeless relationship types like 'professor'. 39 | Make sure to use general and timeless relationship types! 40 | 41 | Relationships should respect common sense. 42 | A person is not a location or place of learning, so it should not be possible to have a "studied_at" relationship targeting a person. 43 | For example, nodes of type "person" should not be valid targets of a relationship representing nationalities. 44 | 45 | If an edge is symmetric, it should be noted in the description. 46 | For example, a relationship representing marriage should be symmetric. 47 | 48 | For non-symmetric edges, the direction should be from more specific to more general. 49 | This makes it easier to start with questions about a specific concept (a person or place) and locate information about that concept. 50 | For example, relationships involving a person should generally start at the person and target various information about that person. 51 | 52 | ## 4. Strict Compliance 53 | Adhere to the rules strictly. Non-compliance will result in termination. -------------------------------------------------------------------------------- /knowledge_graph/render.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Union 2 | 3 | import graphviz 4 | from langchain_community.graphs.graph_document import GraphDocument, Node 5 | 6 | from knowledge_graph.knowledge_schema import KnowledgeSchema 7 | 8 | 9 | def _node_label(node: Node) -> str: 10 | return f"{node.id} [{node.type}]" 11 | 12 | 13 | def print_graph_documents(graph_documents: Union[GraphDocument, Iterable[GraphDocument]]): 14 | if isinstance(graph_documents, GraphDocument): 15 | graph_documents = [graph_documents] 16 | 17 | for doc in graph_documents: 18 | for relation in doc.relationships: 19 | source = relation.source 20 | target = relation.target 21 | type = relation.type 22 | print(f"{_node_label(source)} -> {_node_label(target)}: {type}") 23 | 24 | 25 | def render_graph_documents( 26 | graph_documents: Union[GraphDocument, Iterable[GraphDocument]], 27 | ) -> graphviz.Digraph: 28 | if isinstance(graph_documents, GraphDocument): 29 | graph_documents = [GraphDocument] 30 | 31 | dot = graphviz.Digraph() 32 | 33 | nodes = {} 34 | 35 | def _node_id(node: Node) -> int: 36 | node_key = (node.id, node.type) 37 | if node_id := nodes.get(node_key, None): 38 | return node_id 39 | else: 40 | node_id = f"{len(nodes)}" 41 | nodes[node_key] = node_id 42 | dot.node(node_id, label=_node_label(node)) 43 | return node_id 44 | 45 | for graph_document in graph_documents: 46 | for node in graph_document.nodes: 47 | _node_id(node) 48 | for r in graph_document.relationships: 49 | dot.edge(_node_id(r.source), _node_id(r.target), r.type) 50 | 51 | return dot 52 | 53 | 54 | def render_knowledge_schema(knowledge_schema: KnowledgeSchema) -> graphviz.Digraph: 55 | dot = graphviz.Digraph() 56 | 57 | for node in knowledge_schema.nodes: 58 | dot.node(node.type, tooltip=node.description) 59 | 60 | for r in knowledge_schema.relationships: 61 | for source in r.source_types: 62 | for target in r.target_types: 63 | dot.edge(source, target, label=r.edge_type, tooltip=r.description) 64 | 65 | return dot 66 | -------------------------------------------------------------------------------- /knowledge_graph/runnables.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional 2 | 3 | from langchain_core.language_models import BaseChatModel 4 | from langchain_core.output_parsers import JsonOutputParser 5 | from langchain_core.prompts import ChatPromptTemplate 6 | from langchain_core.pydantic_v1 import BaseModel, Field 7 | from langchain_core.runnables import Runnable, RunnableLambda, RunnablePassthrough 8 | from langchain_experimental.graph_transformers.llm import optional_enum_field 9 | 10 | from .traverse import Node 11 | 12 | QUERY_ENTITY_EXTRACT_PROMPT = ( 13 | "A question is provided below. Given the question, extract up to 5 " 14 | "entities (name and type) from the text. Focus on extracting the entities " 15 | " that we can use to best lookup answers to the question. Avoid stopwords.\n" 16 | "---------------------\n" 17 | "{question}\n" 18 | "---------------------\n" 19 | "{format_instructions}\n" 20 | ) 21 | 22 | 23 | # TODO: Use a knowledge schema when extracting entities, to get the right kinds of nodes. 24 | def extract_entities( 25 | llm: BaseChatModel, 26 | keyword_extraction_prompt: str = QUERY_ENTITY_EXTRACT_PROMPT, 27 | node_types: Optional[List[str]] = None, 28 | ) -> Runnable: 29 | """ 30 | Return a keyword-extraction runnable. 31 | 32 | This will expect a dictionary containing the `"question"` to extract keywords from. 33 | 34 | Parameters: 35 | - llm: The LLM to use for extracting entities. 36 | - node_types: List of node types to extract. 37 | - keyword_extraction_prompt: The prompt to use for requesting entities. 38 | This should include the `{question}` being asked as well as the `{format_instructions}` 39 | which describe how to produce the output. 40 | """ 41 | prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt]) 42 | assert "question" in prompt.input_variables 43 | assert "format_instructions" in prompt.input_variables 44 | 45 | class SimpleNode(BaseModel): 46 | """Represents a node in a graph with associated properties.""" 47 | 48 | id: str = Field(description="Name or human-readable unique identifier.") 49 | type: str = optional_enum_field(node_types, description="The type or label of the node.") 50 | 51 | class SimpleNodeList(BaseModel): 52 | """Represents a list of simple nodes.""" 53 | 54 | nodes: List[SimpleNode] 55 | 56 | output_parser = JsonOutputParser(pydantic_object=SimpleNodeList) 57 | return ( 58 | RunnablePassthrough.assign( 59 | format_instructions=lambda _: output_parser.get_format_instructions(), 60 | ) 61 | | ChatPromptTemplate.from_messages([keyword_extraction_prompt]) 62 | | llm 63 | | output_parser 64 | | RunnableLambda(lambda node_list: [Node(n["id"], n["type"]) for n in node_list["nodes"]]) 65 | ) 66 | -------------------------------------------------------------------------------- /knowledge_graph/schema_inference.py: -------------------------------------------------------------------------------- 1 | from typing import Sequence, cast 2 | 3 | from langchain_core.documents import Document 4 | from langchain_core.language_models.chat_models import BaseChatModel 5 | from langchain_core.prompts import ( 6 | ChatPromptTemplate, 7 | HumanMessagePromptTemplate, 8 | SystemMessagePromptTemplate, 9 | ) 10 | 11 | from knowledge_graph.knowledge_schema import KnowledgeSchema 12 | from knowledge_graph.templates import load_template 13 | 14 | 15 | class KnowledgeSchemaInferer: 16 | def __init__(self, llm: BaseChatModel) -> None: 17 | prompt = load_template( 18 | "schema_inference.md", 19 | ) 20 | prompt = ChatPromptTemplate.from_messages( 21 | [ 22 | SystemMessagePromptTemplate(prompt=load_template("schema_inference.md")), 23 | HumanMessagePromptTemplate.from_template("Input: {input}"), 24 | ] 25 | ) 26 | # TODO: Use "full" output so we can detect parsing errors? 27 | structured_llm = llm.with_structured_output(KnowledgeSchema) 28 | self._chain = prompt | structured_llm 29 | 30 | def infer_schemas_from(self, documents: Sequence[Document]) -> Sequence[KnowledgeSchema]: 31 | responses = self._chain.batch([{"input": doc.page_content} for doc in documents]) 32 | return cast(Sequence[KnowledgeSchema], responses) 33 | -------------------------------------------------------------------------------- /knowledge_graph/templates.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | from typing import Callable, Union 3 | 4 | from langchain_core.prompts import PromptTemplate 5 | 6 | TEMPLATE_PATH = path.join(path.dirname(__file__), "prompt_templates") 7 | 8 | 9 | def load_template(filename: str, **kwargs: Union[str, Callable[[], str]]) -> PromptTemplate: 10 | template = PromptTemplate.from_file(path.join(TEMPLATE_PATH, filename)) 11 | if kwargs: 12 | template = template.partial(**kwargs) 13 | return template 14 | -------------------------------------------------------------------------------- /knowledge_graph/traverse.py: -------------------------------------------------------------------------------- 1 | import asyncio 2 | import threading 3 | from typing import Any, Dict, Iterable, NamedTuple, Optional, Sequence 4 | 5 | from cassandra.cluster import PreparedStatement, ResponseFuture, Session 6 | from cassio.config import check_resolve_keyspace, check_resolve_session 7 | 8 | 9 | class Node(NamedTuple): 10 | name: str 11 | type: str 12 | properties: Dict[str, Any] = {} 13 | 14 | def __repr__(self): 15 | return f"{self.name} ({self.type})" 16 | 17 | def __hash__(self): 18 | return hash(self.name) * hash(self.type) 19 | 20 | def __eq__(self, value) -> bool: 21 | if not isinstance(value, Node): 22 | return False 23 | 24 | return self.name == value.name and self.type == value.type 25 | 26 | 27 | class Relation(NamedTuple): 28 | source: Node 29 | target: Node 30 | type: str 31 | 32 | def __repr__(self): 33 | return f"{self.source} -> {self.target}: {self.type}" 34 | 35 | 36 | def _parse_relation(row) -> Relation: 37 | return Relation( 38 | source=Node(name=row.source_name, type=row.source_type), 39 | target=Node(name=row.target_name, type=row.target_type), 40 | type=row.type, 41 | ) 42 | 43 | 44 | def _prepare_edge_query( 45 | edge_table: str, 46 | edge_source_name: str, 47 | edge_source_type: str, 48 | edge_target_name: str, 49 | edge_target_type: str, 50 | edge_type: str, 51 | edge_filters: Sequence[str], 52 | session: Session, 53 | keyspace: str, 54 | ) -> PreparedStatement: 55 | """Return the query for the edges from a given source.""" 56 | query = f""" 57 | SELECT 58 | {edge_source_name} AS source_name, 59 | {edge_source_type} AS source_type, 60 | {edge_target_name} AS target_name, 61 | {edge_target_type} AS target_type, 62 | {edge_type} AS type 63 | FROM {keyspace}.{edge_table} 64 | WHERE {edge_source_name} = ? 65 | AND {edge_source_type} = ?""" 66 | if edge_filters: 67 | query = "\n AND ".join([query] + edge_filters) 68 | return session.prepare(query) 69 | 70 | 71 | def traverse( 72 | start: Node | Sequence[Node], 73 | edge_table: str, 74 | edge_source_name: str = "source_name", 75 | edge_source_type: str = "source_type", 76 | edge_target_name: str = "target_name", 77 | edge_target_type: str = "target_type", 78 | edge_type: str = "edge_type", 79 | edge_filters: Sequence[str] = (), 80 | steps: int = 3, 81 | session: Optional[Session] = None, 82 | keyspace: Optional[str] = None, 83 | ) -> Iterable[Relation]: 84 | """ 85 | Traverse the graph from the given starting nodes and return the resulting sub-graph. 86 | 87 | Parameters: 88 | - start: The starting node or nodes. 89 | - edge_table: The table containing the edges. 90 | - edge_source_name: The name of the column containing edge source names. 91 | - edge_source_type: The name of the column containing edge source types. 92 | - edge_target_name: The name of the column containing edge target names. 93 | - edge_target_type: The name of the column containing edge target types. 94 | - edge_type: The name of the column containing edge types. 95 | - edge_filters: Filters to apply to the edges being traversed. 96 | - steps: The number of steps of edges to follow from a start node. 97 | - session: The session to use for executing the query. If not specified, 98 | it will use th default cassio session. 99 | - keyspace: The keyspace to use for the query. If not specified, it will 100 | use the default cassio keyspace. 101 | 102 | Returns: 103 | An iterable over relations in the traversed sub-graph. 104 | """ 105 | if len(start) == 0: 106 | return [] 107 | 108 | session = check_resolve_session(session) 109 | keyspace = check_resolve_keyspace(keyspace) 110 | 111 | pending = set() 112 | distances = {} 113 | results = set() 114 | query = _prepare_edge_query( 115 | edge_table=edge_table, 116 | edge_source_name=edge_source_name, 117 | edge_source_type=edge_source_type, 118 | edge_target_name=edge_target_name, 119 | edge_target_type=edge_target_type, 120 | edge_type=edge_type, 121 | edge_filters=edge_filters, 122 | session=session, 123 | keyspace=keyspace, 124 | ) 125 | 126 | condition = threading.Condition() 127 | error = None 128 | 129 | def handle_result(rows, source_distance: int, request: ResponseFuture): 130 | relations = map(_parse_relation, rows) 131 | with condition: 132 | if source_distance < steps: 133 | for r in relations: 134 | results.add(r) 135 | fetch_relationships(source_distance + 1, r.target) 136 | else: 137 | results.update(relations) 138 | 139 | if request.has_more_pages: 140 | request.start_fetching_next_page() 141 | else: 142 | with condition: 143 | if request._req_id in pending: 144 | pending.remove(request._req_id) 145 | if len(pending) == 0: 146 | condition.notify() 147 | 148 | def handle_error(e): 149 | nonlocal error 150 | with condition: 151 | error = e 152 | condition.notify() 153 | 154 | def fetch_relationships(distance: int, source: Node) -> None: 155 | """ 156 | Fetch relationships from node `source` is found at `distance`. 157 | 158 | This will retrieve the edges from `source`, and visit the resulting 159 | nodes at distance `distance + 1`. 160 | """ 161 | with condition: 162 | old_distance = distances.get(source) 163 | if old_distance is not None and old_distance <= distance: 164 | # Already discovered at that distance. 165 | return 166 | 167 | distances[source] = distance 168 | 169 | request: ResponseFuture = session.execute_async(query, (source.name, source.type)) 170 | pending.add(request._req_id) 171 | request.add_callbacks( 172 | handle_result, 173 | handle_error, 174 | callback_kwargs={"source_distance": distance, "request": request}, 175 | ) 176 | 177 | with condition: 178 | if isinstance(start, Node): 179 | start = [start] 180 | for source in start: 181 | fetch_relationships(1, source) 182 | 183 | condition.wait() 184 | 185 | if error is not None: 186 | raise error 187 | else: 188 | return results 189 | 190 | 191 | class AsyncPagedQuery(object): 192 | def __init__(self, depth: int, response_future: ResponseFuture): 193 | self.loop = asyncio.get_running_loop() 194 | self.depth = depth 195 | self.response_future = response_future 196 | self.current_page_future = asyncio.Future() 197 | self.response_future.add_callbacks(self._handle_page, self._handle_error) 198 | 199 | def _handle_page(self, rows): 200 | self.loop.call_soon_threadsafe(self.current_page_future.set_result, rows) 201 | 202 | def _handle_error(self, error): 203 | self.loop.call_soon_threadsafe(self.current_page_future.set_exception, error) 204 | 205 | async def next(self): 206 | page = [_parse_relation(r) for r in await self.current_page_future] 207 | 208 | if self.response_future.has_more_pages: 209 | self.current_page_future = asyncio.Future() 210 | self.response_future.start_fetching_next_page() 211 | return (self.depth, page, self) 212 | else: 213 | return (self.depth, page, None) 214 | 215 | 216 | async def atraverse( 217 | start: Node | Sequence[Node], 218 | edge_table: str, 219 | edge_source_name: str = "source_name", 220 | edge_source_type: str = "source_type", 221 | edge_target_name: str = "target_name", 222 | edge_target_type: str = "target_type", 223 | edge_type: str = "edge_type", 224 | edge_filters: Sequence[str] = [], 225 | steps: int = 3, 226 | session: Optional[Session] = None, 227 | keyspace: Optional[str] = None, 228 | ) -> Iterable[Relation]: 229 | """ 230 | Async traversal of the graph from the given starting nodes and return the resulting sub-graph. 231 | 232 | Parameters: 233 | - start: The starting node or nodes. 234 | - edge_table: The table containing the edges. 235 | - edge_source_name: The name of the column containing edge source names. 236 | - edge_source_type: The name of the column containing edge source types. 237 | - edge_target_name: The name of the column containing edge target names. 238 | - edge_target_type: The name of the column containing edge target types. 239 | - edge_type: The name of the column containing edge types. 240 | - edge_filters: Filters to apply to the edges being traversed. 241 | Currently, this is specified as a dictionary containing the name 242 | of the edge field to filter on and the CQL predicate to apply. 243 | For example `{"foo": "IN ['a', 'b', 'c']"}`. 244 | - steps: The number of steps of edges to follow from a start node. 245 | - session: The session to use for executing the query. If not specified, 246 | it will use th default cassio session. 247 | - keyspace: The keyspace to use for the query. If not specified, it will 248 | use the default cassio keyspace. 249 | 250 | Returns: 251 | An iterable over relations in the traversed sub-graph. 252 | """ 253 | 254 | session = check_resolve_session(session) 255 | keyspace = check_resolve_keyspace(keyspace) 256 | 257 | # Prepare the query. 258 | # 259 | # We reprepare this for each traversal since each call may have different 260 | # filters. 261 | # 262 | # TODO: We should cache this at least for the common case of no-filters. 263 | query = _prepare_edge_query( 264 | edge_table=edge_table, 265 | edge_source_name=edge_source_name, 266 | edge_source_type=edge_source_type, 267 | edge_target_name=edge_target_name, 268 | edge_target_type=edge_target_type, 269 | edge_type=edge_type, 270 | edge_filters=edge_filters, 271 | session=session, 272 | keyspace=keyspace, 273 | ) 274 | 275 | def fetch_relation(tg: asyncio.TaskGroup, depth: int, source: Node) -> AsyncPagedQuery: 276 | paged_query = AsyncPagedQuery( 277 | depth, session.execute_async(query, (source.name, source.type)) 278 | ) 279 | return tg.create_task(paged_query.next()) 280 | 281 | results = set() 282 | async with asyncio.TaskGroup() as tg: 283 | if isinstance(start, Node): 284 | start = [start] 285 | 286 | discovered = {t: 0 for t in start} 287 | pending = {fetch_relation(tg, 1, source) for source in start} 288 | 289 | while pending: 290 | done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) 291 | for future in done: 292 | depth, relations, more = future.result() 293 | for relation in relations: 294 | results.add(relation) 295 | 296 | # Schedule the future for more results from the same query. 297 | if more is not None: 298 | pending.add(tg.create_task(more.next())) 299 | 300 | # Schedule futures for the next step. 301 | if depth < steps: 302 | # We've found a path of length `depth` to each of the targets. 303 | # We need to update `discovered` to include the shortest path. 304 | # And build `to_visit` to be all of the targets for which this is 305 | # the new shortest path. 306 | to_visit = set() 307 | for r in relations: 308 | previous = discovered.get(r.target, steps + 1) 309 | if depth < previous: 310 | discovered[r.target] = depth 311 | to_visit.add(r.target) 312 | 313 | for source in to_visit: 314 | pending.add(fetch_relation(tg, depth + 1, source)) 315 | 316 | return results 317 | -------------------------------------------------------------------------------- /knowledge_graph/utils.py: -------------------------------------------------------------------------------- 1 | try: 2 | # Try importing the function from itertools (Python 3.12+) 3 | from itertools import batched 4 | except ImportError: 5 | from itertools import islice 6 | from typing import Iterable, Iterator, TypeVar 7 | 8 | # Fallback implementation for older Python versions 9 | 10 | T = TypeVar("T") 11 | 12 | # This is equivalent to `itertools.batched`, but that is only available in 3.12 13 | def batched(iterable: Iterable[T], n: int) -> Iterator[Iterator[T]]: 14 | if n < 1: 15 | raise ValueError("n must be at least one") 16 | it = iter(iterable) 17 | while batch := tuple(islice(it, n)): 18 | yield batch 19 | -------------------------------------------------------------------------------- /notebook.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# IMPORTANT: new `graph-rag` project with improved functionality\n", 8 | "\n", 9 | "> Since the date of publication, some code in this notebook has been superseded by\n", 10 | "> a new, dedicated \n", 11 | "> [Graph RAG project](https://datastax.github.io/graph-rag/).\n", 12 | "> \n", 13 | "> Specifically, the `ragstack-ai-langchain`, `ragstack-ai-knowledge-graph`, and\n", 14 | "> `ragstack-ai-knowledge-store` libraries as well as the\n", 15 | "> `https://github.com/datastax-labs/knowledge-graphs-langchain`\n", 16 | "> repo are no longer under development.\n", 17 | "> \n", 18 | "> Instead, you can find the latest tools and techniques for working with knowledge\n", 19 | "> graphs and graph RAG in the \n", 20 | "> [Graph RAG project](https://datastax.github.io/graph-rag/)." 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# Introduction / Setup\n", 28 | "\n", 29 | "This notebook shows how to use LangChain's [`LLMGraphTransformer`](https://python.langchain.com/docs/use_cases/graph/constructing/#llm-graph-transformer) to extract knowledge triples and store them in [DataStax AstraDB](https://www.datastax.com/products/datastax-astra)." 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "# (Optional) When developing locally, this reloads the module code when changes are made,\n", 39 | "# making it easier to iterate.\n", 40 | "%load_ext autoreload\n", 41 | "%autoreload 2" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "# (Required in Colab) Install the knowledge graph library from the repository.\n", 51 | "# This will also install the dependencies.\n", 52 | "%pip install https://github.com/datastax-labs/knowledge-graphs-langchain/archive/main.zip" 53 | ] 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "metadata": {}, 58 | "source": [ 59 | "## Environment\n", 60 | "Pick one of the following.\n", 61 | "1. If you're just running the notebook, it's probably best to run the cell using `getpass` to set the necessary\n", 62 | " environment variables.\n", 63 | "1. If you're developing, it's likely easiest to create a `.env` file and store the necessary credentials." 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": null, 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "# (Option 1) - Set the environment variables from getpass.\n", 73 | "import getpass\n", 74 | "import os\n", 75 | "\n", 76 | "os.environ[\"OPENAI_API_KEY\"] = getpass.getpass(\"Enter OpenAI API Key: \")\n", 77 | "os.environ[\"ASTRA_DB_DATABASE_ID\"] = input(\"Enter Astra DB Database ID: \")\n", 78 | "os.environ[\"ASTRA_DB_APPLICATION_TOKEN\"] = getpass.getpass(\"Enter Astra DB Application Token: \")\n", 79 | "\n", 80 | "keyspace = input(\"Enter Astra DB Keyspace (Empty for default): \")\n", 81 | "if keyspace:\n", 82 | " os.environ[\"ASTRA_DB_KEYSPACE\"] = keyspace\n", 83 | "else:\n", 84 | " os.environ.pop(\"ASTRA_DB_KEYSPACE\", None)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 1, 90 | "metadata": {}, 91 | "outputs": [ 92 | { 93 | "name": "stdout", 94 | "output_type": "stream", 95 | "text": [ 96 | "Requirement already satisfied: python-dotenv in /Users/benjamin.chambers/Library/Caches/pypoetry/virtualenvs/knowledge-graph-bxUBmW8M-py3.11/lib/python3.11/site-packages (1.0.1)\n", 97 | "\n", 98 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.3.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m24.0\u001b[0m\n", 99 | "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n", 100 | "Note: you may need to restart the kernel to use updated packages.\n" 101 | ] 102 | }, 103 | { 104 | "data": { 105 | "text/plain": [ 106 | "True" 107 | ] 108 | }, 109 | "execution_count": 1, 110 | "metadata": {}, 111 | "output_type": "execute_result" 112 | } 113 | ], 114 | "source": [ 115 | "# (Option 2) - Load the `.env` file.\n", 116 | "# See `env.template` for an example of what you should have there.\n", 117 | "%pip install python-dotenv\n", 118 | "import dotenv\n", 119 | "dotenv.load_dotenv()" 120 | ] 121 | }, 122 | { 123 | "cell_type": "markdown", 124 | "metadata": {}, 125 | "source": [ 126 | "## Initialize Astra DB / Cassandra" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 2, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "# Initialize cassandra connection from environment variables).\n", 136 | "import cassio\n", 137 | "cassio.init(auto=True)" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "## Create Graph Store" 145 | ] 146 | }, 147 | { 148 | "cell_type": "code", 149 | "execution_count": 3, 150 | "metadata": {}, 151 | "outputs": [], 152 | "source": [ 153 | "# Create graph store.\n", 154 | "from knowledge_graph.cassandra_graph_store import CassandraGraphStore\n", 155 | "graph_store = CassandraGraphStore()" 156 | ] 157 | }, 158 | { 159 | "cell_type": "markdown", 160 | "metadata": {}, 161 | "source": [ 162 | "# Extracting Knowledge Graph" 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 4, 168 | "metadata": {}, 169 | "outputs": [], 170 | "source": [ 171 | "from langchain_experimental.graph_transformers import LLMGraphTransformer\n", 172 | "from langchain_openai import ChatOpenAI\n", 173 | "\n", 174 | "# Prompt used by LLMGraphTransformer is tuned for Gpt4.\n", 175 | "llm = ChatOpenAI(temperature=0, model_name=\"gpt-4\")\n", 176 | "\n", 177 | "llm_transformer = LLMGraphTransformer(llm=llm)" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 5, 183 | "metadata": {}, 184 | "outputs": [ 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "Nodes:[Node(id='Marie Curie', type='Person'), Node(id='Pierre Curie', type='Person'), Node(id='Nobel Prize', type='Award'), Node(id='University Of Paris', type='Organization'), Node(id='Polish', type='Nationality'), Node(id='French', type='Nationality'), Node(id='Physicist', type='Profession'), Node(id='Chemist', type='Profession'), Node(id='Radioactivity', type='Scientific field'), Node(id='Professor', type='Profession')]\n", 190 | "Relationships:[Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Polish', type='Nationality'), type='HAS_NATIONALITY'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='French', type='Nationality'), type='HAS_NATIONALITY'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Physicist', type='Profession'), type='HAS_PROFESSION'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Chemist', type='Profession'), type='HAS_PROFESSION'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Radioactivity', type='Scientific field'), type='RESEARCHED_IN'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Nobel Prize', type='Award'), type='WON'), Relationship(source=Node(id='Pierre Curie', type='Person'), target=Node(id='Nobel Prize', type='Award'), type='WON'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='University Of Paris', type='Organization'), type='WORKED_AT'), Relationship(source=Node(id='Marie Curie', type='Person'), target=Node(id='Professor', type='Profession'), type='HAS_PROFESSION')]\n" 191 | ] 192 | } 193 | ], 194 | "source": [ 195 | "from langchain_core.documents import Document\n", 196 | "\n", 197 | "text = \"\"\"\n", 198 | "Marie Curie, was a Polish and naturalised-French physicist and chemist who conducted pioneering research on radioactivity.\n", 199 | "She was the first woman to win a Nobel Prize, the first person to win a Nobel Prize twice, and the only person to win a Nobel Prize in two scientific fields.\n", 200 | "Her husband, Pierre Curie, was a co-winner of her first Nobel Prize, making them the first-ever married couple to win the Nobel Prize and launching the Curie family legacy of five Nobel Prizes.\n", 201 | "She was, in 1906, the first woman to become a professor at the University of Paris.\n", 202 | "\"\"\"\n", 203 | "documents = [Document(page_content=text)]\n", 204 | "graph_documents = llm_transformer.convert_to_graph_documents(documents)\n", 205 | "print(f\"Nodes:{graph_documents[0].nodes}\")\n", 206 | "print(f\"Relationships:{graph_documents[0].relationships}\")" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": 6, 212 | "metadata": {}, 213 | "outputs": [ 214 | { 215 | "data": { 216 | "image/svg+xml": [ 217 | "\n", 218 | "\n", 220 | "\n", 222 | "\n", 223 | "\n", 225 | "\n", 226 | "\n", 227 | "\n", 228 | "\n", 229 | "Marie Curie [Person]\n", 230 | "\n", 231 | "Marie Curie [Person]\n", 232 | "\n", 233 | "\n", 234 | "\n", 235 | "Nobel Prize [Award]\n", 236 | "\n", 237 | "Nobel Prize [Award]\n", 238 | "\n", 239 | "\n", 240 | "\n", 241 | "Marie Curie [Person]->Nobel Prize [Award]\n", 242 | "\n", 243 | "\n", 244 | "WON\n", 245 | "\n", 246 | "\n", 247 | "\n", 248 | "University Of Paris [Organization]\n", 249 | "\n", 250 | "University Of Paris [Organization]\n", 251 | "\n", 252 | "\n", 253 | "\n", 254 | "Marie Curie [Person]->University Of Paris [Organization]\n", 255 | "\n", 256 | "\n", 257 | "WORKED_AT\n", 258 | "\n", 259 | "\n", 260 | "\n", 261 | "Polish [Nationality]\n", 262 | "\n", 263 | "Polish [Nationality]\n", 264 | "\n", 265 | "\n", 266 | "\n", 267 | "Marie Curie [Person]->Polish [Nationality]\n", 268 | "\n", 269 | "\n", 270 | "HAS_NATIONALITY\n", 271 | "\n", 272 | "\n", 273 | "\n", 274 | "French [Nationality]\n", 275 | "\n", 276 | "French [Nationality]\n", 277 | "\n", 278 | "\n", 279 | "\n", 280 | "Marie Curie [Person]->French [Nationality]\n", 281 | "\n", 282 | "\n", 283 | "HAS_NATIONALITY\n", 284 | "\n", 285 | "\n", 286 | "\n", 287 | "Physicist [Profession]\n", 288 | "\n", 289 | "Physicist [Profession]\n", 290 | "\n", 291 | "\n", 292 | "\n", 293 | "Marie Curie [Person]->Physicist [Profession]\n", 294 | "\n", 295 | "\n", 296 | "HAS_PROFESSION\n", 297 | "\n", 298 | "\n", 299 | "\n", 300 | "Chemist [Profession]\n", 301 | "\n", 302 | "Chemist [Profession]\n", 303 | "\n", 304 | "\n", 305 | "\n", 306 | "Marie Curie [Person]->Chemist [Profession]\n", 307 | "\n", 308 | "\n", 309 | "HAS_PROFESSION\n", 310 | "\n", 311 | "\n", 312 | "\n", 313 | "Radioactivity [Scientific field]\n", 314 | "\n", 315 | "Radioactivity [Scientific field]\n", 316 | "\n", 317 | "\n", 318 | "\n", 319 | "Marie Curie [Person]->Radioactivity [Scientific field]\n", 320 | "\n", 321 | "\n", 322 | "RESEARCHED_IN\n", 323 | "\n", 324 | "\n", 325 | "\n", 326 | "Professor [Profession]\n", 327 | "\n", 328 | "Professor [Profession]\n", 329 | "\n", 330 | "\n", 331 | "\n", 332 | "Marie Curie [Person]->Professor [Profession]\n", 333 | "\n", 334 | "\n", 335 | "HAS_PROFESSION\n", 336 | "\n", 337 | "\n", 338 | "\n", 339 | "Pierre Curie [Person]\n", 340 | "\n", 341 | "Pierre Curie [Person]\n", 342 | "\n", 343 | "\n", 344 | "\n", 345 | "Pierre Curie [Person]->Nobel Prize [Award]\n", 346 | "\n", 347 | "\n", 348 | "WON\n", 349 | "\n", 350 | "\n", 351 | "\n" 352 | ], 353 | "text/plain": [ 354 | "" 355 | ] 356 | }, 357 | "execution_count": 6, 358 | "metadata": {}, 359 | "output_type": "execute_result" 360 | } 361 | ], 362 | "source": [ 363 | "# Render the extracted graph to GraphViz.\n", 364 | "from knowledge_graph.render import render_graph_documents\n", 365 | "render_graph_documents(graph_documents)" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 7, 371 | "metadata": {}, 372 | "outputs": [], 373 | "source": [ 374 | "# Save the extracted graph documents to the AstraDB / Cassandra Graph Store.\n", 375 | "graph_store.add_graph_documents(graph_documents)" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "metadata": {}, 381 | "source": [ 382 | "### Optional: Predefine entities / relationships\n", 383 | "\n", 384 | "The below shows how to configure the `LLMGraphTransformer` with specific kinds of nodes and relationships it is allowed to extract.\n", 385 | "This is useful for constraining what will be extracted.\n", 386 | "\n", 387 | "```python\n", 388 | "llm_transformer_filtered = LLMGraphTransformer(\n", 389 | " llm=llm,\n", 390 | " allowed_nodes=[\"Person\", \"Country\", \"Organization\"],\n", 391 | " allowed_relationships=[\"NATIONALITY\", \"LOCATED_IN\", \"WORKED_AT\", \"SPOUSE\"],\n", 392 | ")\n", 393 | "graph_documents_filtered = llm_transformer_filtered.convert_to_graph_documents(\n", 394 | " documents\n", 395 | ")\n", 396 | "print(f\"Nodes:{graph_documents_filtered[0].nodes}\")\n", 397 | "print(f\"Relationships:{graph_documents_filtered[0].relationships}\")\n", 398 | "```" 399 | ] 400 | }, 401 | { 402 | "cell_type": "markdown", 403 | "metadata": {}, 404 | "source": [ 405 | "# Querying" 406 | ] 407 | }, 408 | { 409 | "cell_type": "markdown", 410 | "metadata": {}, 411 | "source": [ 412 | "We can query the `GraphStore` directly. The `as_runnable` method takes some configuration for how to extract the subgraph and returns a LangChain `Runnable` which can be invoked on a node or sequence of nodes to traverse from those starting points." 413 | ] 414 | }, 415 | { 416 | "cell_type": "code", 417 | "execution_count": 8, 418 | "metadata": {}, 419 | "outputs": [ 420 | { 421 | "data": { 422 | "text/plain": [ 423 | "{Marie Curie(Person) -> Chemist(Profession): HAS_PROFESSION,\n", 424 | " Marie Curie(Person) -> French(Nationality): HAS_NATIONALITY,\n", 425 | " Marie Curie(Person) -> Nobel Prize(Award): WON,\n", 426 | " Marie Curie(Person) -> Physicist(Profession): HAS_PROFESSION,\n", 427 | " Marie Curie(Person) -> Pierre Curie(Person): MARRIED_TO,\n", 428 | " Marie Curie(Person) -> Polish(Nationality): HAS_NATIONALITY,\n", 429 | " Marie Curie(Person) -> Professor(Profession): HAS_PROFESSION,\n", 430 | " Marie Curie(Person) -> Radioactivity(Scientific concept): RESEARCHED,\n", 431 | " Marie Curie(Person) -> Radioactivity(Scientific field): RESEARCHED_IN,\n", 432 | " Marie Curie(Person) -> University Of Paris(Organization): WORKED_AT,\n", 433 | " Pierre Curie(Person) -> Nobel Prize(Award): WON}" 434 | ] 435 | }, 436 | "execution_count": 8, 437 | "metadata": {}, 438 | "output_type": "execute_result" 439 | } 440 | ], 441 | "source": [ 442 | "from knowledge_graph.traverse import Node\n", 443 | "\n", 444 | "graph_store.as_runnable(steps=2).invoke(Node(\"Marie Curie\", \"Person\"))" 445 | ] 446 | }, 447 | { 448 | "cell_type": "markdown", 449 | "metadata": {}, 450 | "source": [ 451 | "For getting started, the library also provides a `Runnable` for extracting the starting entities from a question." 452 | ] 453 | }, 454 | { 455 | "cell_type": "code", 456 | "execution_count": 9, 457 | "metadata": {}, 458 | "outputs": [ 459 | { 460 | "data": { 461 | "text/plain": [ 462 | "[Marie Curie(Person)]" 463 | ] 464 | }, 465 | "execution_count": 9, 466 | "metadata": {}, 467 | "output_type": "execute_result" 468 | } 469 | ], 470 | "source": [ 471 | "# Example showing extracted entities (nodes)\n", 472 | "from knowledge_graph import extract_entities\n", 473 | "extract_entities(llm).invoke({ \"question\": \"Who is Marie Curie?\"})" 474 | ] 475 | }, 476 | { 477 | "cell_type": "markdown", 478 | "metadata": {}, 479 | "source": [ 480 | "## Query Chain\n", 481 | "\n", 482 | "We'll create a chain which does the following:\n", 483 | "\n", 484 | "1. Use the entity extraction `Runnable` from the library in order to determine the starting points.\n", 485 | "2. Retrieve the sub-knowledge graphs starting from those nodes.\n", 486 | "3. Create a context containing those knowledge triples.\n", 487 | "4. Apply the LLM to answer the question given the context. " 488 | ] 489 | }, 490 | { 491 | "cell_type": "code", 492 | "execution_count": 10, 493 | "metadata": {}, 494 | "outputs": [], 495 | "source": [ 496 | "from operator import itemgetter\n", 497 | "from langchain_core.runnables import RunnableLambda, RunnablePassthrough\n", 498 | "from langchain_core.prompts import ChatPromptTemplate\n", 499 | "from knowledge_graph import extract_entities\n", 500 | "from langchain_openai import ChatOpenAI\n", 501 | "llm = ChatOpenAI(model_name = \"gpt-4\")\n", 502 | "\n", 503 | "def _combine_relations(relations):\n", 504 | " return \"\\n\".join(map(repr, relations))\n", 505 | "\n", 506 | "ANSWER_PROMPT = (\n", 507 | " \"The original question is given below.\"\n", 508 | " \"This question has been used to retrieve information from a knowledge graph.\"\n", 509 | " \"The matching triples are shown below.\"\n", 510 | " \"Use the information in the triples to answer the original question.\\n\\n\"\n", 511 | " \"Original Question: {question}\\n\\n\"\n", 512 | " \"Knowledge Graph Triples:\\n{context}\\n\\n\"\n", 513 | " \"Response:\"\n", 514 | ")\n", 515 | "\n", 516 | "chain = (\n", 517 | " { \"question\": RunnablePassthrough() }\n", 518 | " | RunnablePassthrough.assign(entities = extract_entities(llm))\n", 519 | " | RunnablePassthrough.assign(triples = itemgetter(\"entities\") | graph_store.as_runnable())\n", 520 | " | RunnablePassthrough.assign(context = itemgetter(\"triples\") | RunnableLambda(_combine_relations))\n", 521 | " | ChatPromptTemplate.from_messages([ANSWER_PROMPT])\n", 522 | " | llm\n", 523 | ")" 524 | ] 525 | }, 526 | { 527 | "cell_type": "markdown", 528 | "metadata": {}, 529 | "source": [ 530 | "## Example\n", 531 | "And finally, we can run the chain end to end to answer a question using the retrieved knowledge." 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 12, 537 | "metadata": {}, 538 | "outputs": [ 539 | { 540 | "data": { 541 | "text/plain": [ 542 | "AIMessage(content='Marie Curie is a Polish and French chemist, professor, and physicist who researched radioactivity and worked at the University of Paris. She was married to Pierre Curie and both of them have won the Nobel Prize.', response_metadata={'token_usage': {'completion_tokens': 45, 'prompt_tokens': 246, 'total_tokens': 291}, 'model_name': 'gpt-4', 'system_fingerprint': None, 'finish_reason': 'stop', 'logprobs': None}, id='run-85a75d94-705a-4c49-9bcf-f16ae77b1c7d-0')" 543 | ] 544 | }, 545 | "execution_count": 12, 546 | "metadata": {}, 547 | "output_type": "execute_result" 548 | } 549 | ], 550 | "source": [ 551 | "chain.invoke(\"Who is Marie Curie?\")" 552 | ] 553 | } 554 | ], 555 | "metadata": { 556 | "kernelspec": { 557 | "display_name": "knowledge-graph-bxUBmW8M-py3.11", 558 | "language": "python", 559 | "name": "python3" 560 | }, 561 | "language_info": { 562 | "codemirror_mode": { 563 | "name": "ipython", 564 | "version": 3 565 | }, 566 | "file_extension": ".py", 567 | "mimetype": "text/x-python", 568 | "name": "python", 569 | "nbconvert_exporter": "python", 570 | "pygments_lexer": "ipython3", 571 | "version": "3.11.4" 572 | } 573 | }, 574 | "nbformat": 4, 575 | "nbformat_minor": 2 576 | } 577 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.poetry] 2 | name = "knowledge-graph" 3 | version = "0.1.0" 4 | description = "" 5 | authors = ["Ben Chambers <35960+bjchambers@users.noreply.github.com>"] 6 | readme = "README.md" 7 | packages = [{include = "knowledge_graph"}] 8 | include = [ 9 | { path = "knowledge_graph/prompt_templates/*.md", format = ["sdist", "wheel"] } 10 | ] 11 | 12 | 13 | [tool.poetry.dependencies] 14 | python = "^3.10" 15 | langchain = "^0.1.14" 16 | langchain-community = "^0.0.31" 17 | langchain-openai = "^0.1.1" 18 | langchain-experimental = "^0.0.56" 19 | cassio = "^0.1.5" 20 | graphviz = "^0.20.3" 21 | pydantic-yaml = "^1.3.0" 22 | pyyaml = "^6.0.1" 23 | 24 | 25 | [tool.poetry.group.dev.dependencies] 26 | python-dotenv = "^1.0.1" 27 | ipykernel = "^6.29.4" 28 | ruff = "^0.3.5" 29 | testcontainers = "~3.7.1" 30 | pytest = "^8.1.1" 31 | precisely = "^0.1.9" 32 | pytest-asyncio = "^0.23.6" 33 | pytest-dotenv = "^0.5.2" 34 | 35 | [build-system] 36 | requires = ["poetry-core"] 37 | build-backend = "poetry.core.masonry.api" 38 | 39 | [tool.ruff] 40 | line-length = 98 41 | 42 | # Assume Python 3.11. 43 | target-version = "py310" 44 | 45 | [tool.ruff.lint] 46 | # Enable Pyflakes `E` and `F` codes by default. 47 | select = [ 48 | # Pyflakes 49 | "F", 50 | # Pycodestyle 51 | "E", 52 | "W", 53 | # isort 54 | "I001", 55 | ] 56 | ignore = [] 57 | 58 | # Allow unused variables when underscore-prefixed. 59 | dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" 60 | 61 | [tool.ruff.lint.mccabe] 62 | # Unlike Flake8, default to a complexity level of 10. 63 | max-complexity = 10 64 | 65 | [tool.mypy] 66 | strict = true 67 | warn_unreachable = true 68 | pretty = true 69 | show_column_numbers = true 70 | show_error_context = true 71 | 72 | [tool.pytest.ini_options] 73 | testpaths = ["tests"] 74 | asyncio_mode = "auto" -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/datastax-labs/knowledge-graphs-langchain/31fd23f8d37690c2d2d3381c4d09c61e96fc5d19/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | from typing import Iterator, List 3 | 4 | import pytest 5 | from cassandra.cluster import Cluster, Session 6 | from langchain.graphs.graph_document import GraphDocument, Node, Relationship 7 | from langchain_core.documents import Document 8 | from langchain_core.language_models import BaseChatModel 9 | from testcontainers.core.container import DockerContainer 10 | from testcontainers.core.waiting_utils import wait_for_logs 11 | 12 | from knowledge_graph.cassandra_graph_store import CassandraGraphStore 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def db_keyspace() -> str: 17 | return "default_keyspace" 18 | 19 | 20 | @pytest.fixture(scope="session") 21 | def cassandra_port(db_keyspace: str) -> Iterator[int]: 22 | # TODO: Allow running against local Cassandra and/or Astra using pytest option. 23 | cassandra = DockerContainer("cassandra:5") 24 | cassandra.with_exposed_ports(9042) 25 | cassandra.with_env( 26 | "JVM_OPTS", 27 | "-Dcassandra.skip_wait_for_gossip_to_settle=0 -Dcassandra.initial_token=0", 28 | ) 29 | cassandra.with_env("HEAP_NEWSIZE", "128M") 30 | cassandra.with_env("MAX_HEAP_SIZE", "1024M") 31 | cassandra.with_env("CASSANDRA_ENDPOINT_SNITCH", "GossipingPropertyFileSnitch") 32 | cassandra.with_env("CASSANDRA_DC", "datacenter1") 33 | cassandra.start() 34 | wait_for_logs(cassandra, "Startup complete") 35 | cassandra.get_wrapped_container().exec_run( 36 | ( 37 | f"""cqlsh -e "CREATE KEYSPACE {db_keyspace} WITH replication = """ 38 | '''{'class': 'SimpleStrategy', 'replication_factor': '1'};"''' 39 | ) 40 | ) 41 | port = cassandra.get_exposed_port(9042) 42 | print(f"Cassandra started. Port is {port}") 43 | yield port 44 | cassandra.stop() 45 | 46 | 47 | @pytest.fixture(scope="session") 48 | def db_session(cassandra_port: int) -> Session: 49 | print(f"Connecting to cassandra on {cassandra_port}") 50 | cluster = Cluster( 51 | port=cassandra_port, 52 | ) 53 | return cluster.connect() 54 | 55 | 56 | @pytest.fixture(scope="session") 57 | def llm() -> BaseChatModel: 58 | try: 59 | from langchain_openai import ChatOpenAI 60 | 61 | model = ChatOpenAI(model_name="gpt-4-turbo-2024-04-09", temperature=0.0) 62 | return model 63 | except ValueError: 64 | pytest.skip("Unable to create OpenAI model") 65 | 66 | 67 | class DataFixture: 68 | def __init__(self, session: Session, keyspace: str, documents: List[GraphDocument]) -> None: 69 | self.session = session 70 | self.keyspace = "default_keyspace" 71 | self.uid = secrets.token_hex(8) 72 | self.node_table = f"entities_{self.uid}" 73 | self.edge_table = f"relationships_{self.uid}" 74 | 75 | text_embeddings = None 76 | try: 77 | from langchain_openai import OpenAIEmbeddings 78 | 79 | text_embeddings = OpenAIEmbeddings() 80 | except ValueError: 81 | print("OpenAI not configured. Not embedding data.") 82 | self.has_embeddings = text_embeddings is not None 83 | 84 | self.graph_store = CassandraGraphStore( 85 | node_table=self.node_table, 86 | edge_table=self.edge_table, 87 | text_embeddings=text_embeddings, 88 | session=session, 89 | keyspace=keyspace, 90 | ) 91 | 92 | self.graph_store.add_graph_documents(documents) 93 | 94 | def drop(self): 95 | self.session.execute(f"DROP TABLE IF EXISTS {self.keyspace}.{self.node_table};") 96 | self.session.execute(f"DROP TABLE IF EXISTS {self.keyspace}.{self.edge_table};") 97 | 98 | 99 | @pytest.fixture(scope="session") 100 | def marie_curie(db_session: Session, db_keyspace: str) -> Iterator[DataFixture]: 101 | marie_curie = Node(id="Marie Curie", type="Person") 102 | pierre_curie = Node(id="Pierre Curie", type="Person") 103 | nobel_prize = Node(id="Nobel Prize", type="Award") 104 | university_of_paris = Node(id="University of Paris", type="Organization") 105 | polish = Node(id="Polish", type="Nationality", properties={"European": True}) 106 | french = Node(id="French", type="Nationality", properties={"European": True}) 107 | physicist = Node(id="Physicist", type="Profession") 108 | chemist = Node(id="Chemist", type="Profession") 109 | radioactivity = Node(id="Radioactivity", type="Scientific concept") 110 | professor = Node(id="Professor", type="Profession") 111 | document = GraphDocument( 112 | nodes=[ 113 | marie_curie, 114 | pierre_curie, 115 | nobel_prize, 116 | university_of_paris, 117 | polish, 118 | french, 119 | physicist, 120 | chemist, 121 | radioactivity, 122 | professor, 123 | ], 124 | relationships=[ 125 | Relationship(source=marie_curie, target=polish, type="HAS_NATIONALITY"), 126 | Relationship(source=marie_curie, target=french, type="HAS_NATIONALITY"), 127 | Relationship(source=marie_curie, target=physicist, type="HAS_PROFESSION"), 128 | Relationship(source=marie_curie, target=chemist, type="HAS_PROFESSION"), 129 | Relationship(source=marie_curie, target=radioactivity, type="RESEARCHED"), 130 | Relationship(source=marie_curie, target=nobel_prize, type="WON"), 131 | Relationship(source=pierre_curie, target=nobel_prize, type="WON"), 132 | Relationship(source=marie_curie, target=pierre_curie, type="MARRIED_TO"), 133 | Relationship(source=marie_curie, target=university_of_paris, type="WORKED_AT"), 134 | Relationship(source=marie_curie, target=professor, type="HAS_PROFESSION"), 135 | ], 136 | source=Document(page_content="test_content"), 137 | ) 138 | data = DataFixture(session=db_session, keyspace=db_keyspace, documents=[document]) 139 | yield data 140 | data.drop() 141 | -------------------------------------------------------------------------------- /tests/marie_curie_schema.yaml: -------------------------------------------------------------------------------- 1 | nodes: 2 | - type: Institution 3 | description: An institution, such as a business or university. 4 | - type: Award 5 | description: An award, such as the Nobel Prize or an Oscar. 6 | - type: Person 7 | description: A person. 8 | - type: Occupation 9 | description: An occupation which a person held, such as Biologist or Chemist. 10 | - type: Nationality 11 | description: A nationality associated with people of a given country. 12 | 13 | relationships: 14 | - edge_type: HAS_OCCUPATION 15 | source_types: ["Person"] 16 | target_types: ["Occupation"] 17 | description: The source person had the target occupation. 18 | - edge_type: STUDIED_AT 19 | source_types: ["Person"] 20 | target_types: ["Institution"] 21 | description: The source person studied at the target institution. 22 | - edge_type: WORKED_AT 23 | source_types: ["Person"] 24 | target_types: ["Institution"] 25 | description: The source person worked at the target institution. 26 | - edge_type: RECEIVED 27 | source_types: ["Person"] 28 | target_types: ["Award"] 29 | description: The source person received the target award. 30 | - edge_type: HAS_NATIONALITY 31 | source_types: ["Person"] 32 | target_types: ["Nationality"] 33 | description: The source person has the target nationality. 34 | - edge_type: MARRIED_TO 35 | source_types: ["Person"] 36 | target_types: ["Person"] 37 | description: | 38 | The source is married to the target. 39 | Marriage is symmetric so the reverse relationship should also exist. -------------------------------------------------------------------------------- /tests/test_extraction.py: -------------------------------------------------------------------------------- 1 | from os import path 2 | 3 | import pytest 4 | from langchain_community.graphs.graph_document import Node, Relationship 5 | from langchain_core.documents import Document 6 | from langchain_core.language_models import BaseChatModel 7 | from precisely import assert_that, contains_exactly 8 | 9 | from knowledge_graph.extraction import ( 10 | KnowledgeSchema, 11 | KnowledgeSchemaExtractor, 12 | ) 13 | 14 | 15 | @pytest.fixture(scope="session") 16 | def extractor(llm: BaseChatModel) -> KnowledgeSchemaExtractor: 17 | schema = KnowledgeSchema.from_file( 18 | path.join(path.dirname(__file__), "marie_curie_schema.yaml") 19 | ) 20 | return KnowledgeSchemaExtractor( 21 | llm=llm, 22 | schema=schema, 23 | ) 24 | 25 | 26 | MARIE_CURIE_SOURCE = """ 27 | Marie Curie, was a Polish and naturalised-French physicist and chemist who 28 | conducted pioneering research on radioactivity. She was the first woman to win a 29 | Nobel Prize, the first person to win a Nobel Prize twice, and the only person to 30 | win a Nobel Prize in two scientific fields. Her husband, Pierre Curie, was a 31 | co-winner of her first Nobel Prize, making them the first-ever married couple to 32 | win the Nobel Prize and launching the Curie family legacy of five Nobel Prizes. 33 | She was, in 1906, the first woman to become a professor at the University of 34 | Paris. 35 | """ 36 | 37 | 38 | def test_extraction(extractor: KnowledgeSchemaExtractor): 39 | results = extractor.extract([Document(page_content=MARIE_CURIE_SOURCE)]) 40 | 41 | marie_curie = Node(id="Marie Curie", type="Person") 42 | polish = Node(id="Polish", type="Nationality") 43 | french = Node(id="French", type="Nationality") 44 | physicist = Node(id="Physicist", type="Occupation") 45 | chemist = Node(id="Chemist", type="Occupation") 46 | nobel_prize = Node(id="Nobel Prize", type="Award") 47 | pierre_curie = Node(id="Pierre Curie", type="Person") 48 | 49 | # Annoyingly, the LLM seems to upper-case `of`. We probably need some instructions around 50 | # putting things into standard title case, etc. 51 | university_of_paris = Node(id="University Of Paris", type="Institution") 52 | 53 | assert_that( 54 | results[0].nodes, 55 | contains_exactly( 56 | marie_curie, 57 | polish, 58 | french, 59 | physicist, 60 | chemist, 61 | nobel_prize, 62 | pierre_curie, 63 | university_of_paris, 64 | ), 65 | ) 66 | assert_that( 67 | results[0].relationships, 68 | contains_exactly( 69 | Relationship(source=marie_curie, target=polish, type="HAS_NATIONALITY"), 70 | Relationship(source=marie_curie, target=french, type="HAS_NATIONALITY"), 71 | Relationship(source=marie_curie, target=physicist, type="HAS_OCCUPATION"), 72 | Relationship(source=marie_curie, target=chemist, type="HAS_OCCUPATION"), 73 | Relationship(source=marie_curie, target=nobel_prize, type="RECEIVED"), 74 | Relationship(source=pierre_curie, target=nobel_prize, type="RECEIVED"), 75 | Relationship(source=marie_curie, target=university_of_paris, type="WORKED_AT"), 76 | Relationship(source=marie_curie, target=pierre_curie, type="MARRIED_TO"), 77 | Relationship(source=pierre_curie, target=marie_curie, type="MARRIED_TO"), 78 | ), 79 | ) 80 | -------------------------------------------------------------------------------- /tests/test_knowledge_graph.py: -------------------------------------------------------------------------------- 1 | import secrets 2 | import pytest 3 | from precisely import assert_that, contains_exactly 4 | 5 | from cassandra.cluster import Session 6 | from knowledge_graph.knowledge_graph import CassandraKnowledgeGraph 7 | from knowledge_graph.traverse import Node, Relation 8 | 9 | from .conftest import DataFixture 10 | 11 | def test_no_embeddings(db_session: Session, db_keyspace: str) -> None: 12 | uid = secrets.token_hex(8) 13 | node_table = f"entities_{uid}" 14 | edge_table = f"relationships_{uid}" 15 | 16 | graph = CassandraKnowledgeGraph( 17 | node_table=node_table, 18 | edge_table=edge_table, 19 | text_embeddings=None, 20 | session=db_session, 21 | keyspace=db_keyspace, 22 | ) 23 | graph.insert([Node(name="a", type="b")]) 24 | 25 | def test_traverse_marie_curie(marie_curie: DataFixture) -> None: 26 | (result_nodes, result_edges) = marie_curie.graph_store.graph.subgraph( 27 | start=Node("Marie Curie", "Person"), 28 | steps=1, 29 | ) 30 | expected_nodes = [ 31 | Node(name="Marie Curie", type="Person"), 32 | Node(name="Pierre Curie", type="Person"), 33 | Node(name="Nobel Prize", type="Award"), 34 | Node(name="University of Paris", type="Organization"), 35 | Node(name="Polish", type="Nationality", properties={"European": True}), 36 | Node(name="French", type="Nationality", properties={"European": True}), 37 | Node(name="Physicist", type="Profession"), 38 | Node(name="Chemist", type="Profession"), 39 | Node(name="Radioactivity", type="Scientific concept"), 40 | Node(name="Professor", type="Profession"), 41 | ] 42 | expected_edges = { 43 | Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"), 44 | Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"), 45 | Relation( 46 | Node("Marie Curie", "Person"), Node("Physicist", "Profession"), "HAS_PROFESSION" 47 | ), 48 | Relation(Node("Marie Curie", "Person"), Node("Chemist", "Profession"), "HAS_PROFESSION"), 49 | Relation( 50 | Node("Marie Curie", "Person"), Node("Professor", "Profession"), "HAS_PROFESSION" 51 | ), 52 | Relation( 53 | Node("Marie Curie", "Person"), 54 | Node("Radioactivity", "Scientific concept"), 55 | "RESEARCHED", 56 | ), 57 | Relation(Node("Marie Curie", "Person"), Node("Nobel Prize", "Award"), "WON"), 58 | Relation(Node("Marie Curie", "Person"), Node("Pierre Curie", "Person"), "MARRIED_TO"), 59 | Relation( 60 | Node("Marie Curie", "Person"), 61 | Node("University of Paris", "Organization"), 62 | "WORKED_AT", 63 | ), 64 | } 65 | assert_that(result_edges, contains_exactly(*expected_edges)) 66 | assert_that(result_nodes, contains_exactly(*expected_nodes)) 67 | 68 | 69 | def test_fuzzy_search(marie_curie: DataFixture) -> None: 70 | if not marie_curie.has_embeddings: 71 | pytest.skip("Fuzzy search requires embeddings. Run with openai environment variables") 72 | result_nodes = marie_curie.graph_store.graph.query_nearest_nodes(["Marie", "Poland"]) 73 | expected_nodes = [ 74 | Node(name="Marie Curie", type="Person"), 75 | Node(name="Polish", type="Nationality", properties={"European": True}), 76 | ] 77 | assert_that(result_nodes, contains_exactly(*expected_nodes)) 78 | 79 | result_nodes = marie_curie.graph_store.graph.query_nearest_nodes(["European"], k=2) 80 | expected_nodes = [ 81 | Node(name="Polish", type="Nationality", properties={"European": True}), 82 | Node(name="French", type="Nationality", properties={"European": True}), 83 | ] 84 | assert_that(result_nodes, contains_exactly(*expected_nodes)) -------------------------------------------------------------------------------- /tests/test_runnables.py: -------------------------------------------------------------------------------- 1 | from precisely import assert_that, contains_exactly 2 | 3 | from knowledge_graph.runnables import extract_entities 4 | from knowledge_graph.traverse import Node 5 | 6 | 7 | def test_extract_entities(llm): 8 | extractor = extract_entities(llm) 9 | assert_that( 10 | extractor.invoke({"question": "Who is Marie Curie?"}), 11 | contains_exactly(Node("Marie Curie", "Person")), 12 | ) 13 | -------------------------------------------------------------------------------- /tests/test_schema_inference.py: -------------------------------------------------------------------------------- 1 | from langchain_core.documents import Document 2 | from langchain_core.language_models import BaseChatModel 3 | from precisely import assert_that, contains_exactly 4 | 5 | from knowledge_graph.schema_inference import KnowledgeSchemaInferer 6 | 7 | MARIE_CURIE_SOURCE = """ 8 | Marie Curie, was a Polish and naturalised-French physicist and chemist who 9 | conducted pioneering research on radioactivity. She was the first woman to win a 10 | Nobel Prize, the first person to win a Nobel Prize twice, and the only person to 11 | win a Nobel Prize in two scientific fields. Her husband, Pierre Curie, was a 12 | co-winner of her first Nobel Prize, making them the first-ever married couple to 13 | win the Nobel Prize and launching the Curie family legacy of five Nobel Prizes. 14 | She was, in 1906, the first woman to become a professor at the University of 15 | Paris. 16 | """ 17 | 18 | 19 | def test_schema_inference(llm: BaseChatModel): 20 | schema_inferer = KnowledgeSchemaInferer(llm) 21 | 22 | results = schema_inferer.infer_schemas_from([Document(page_content=MARIE_CURIE_SOURCE)])[0] 23 | 24 | print(results.to_yaml_str()) 25 | assert_that( 26 | [n.type for n in results.nodes], 27 | contains_exactly("person", "institution", "award", "nationality", "field"), 28 | ) 29 | assert_that( 30 | [r.edge_type for r in results.relationships], 31 | contains_exactly("won", "is_nationality_of", "works_at", "is_field_of"), 32 | ) 33 | 34 | # We don't do more testing here since this is meant to attempt to infer things. 35 | -------------------------------------------------------------------------------- /tests/test_traverse.py: -------------------------------------------------------------------------------- 1 | from precisely import assert_that, contains_exactly 2 | 3 | from knowledge_graph.traverse import Node, Relation, atraverse, traverse 4 | 5 | from .conftest import DataFixture 6 | 7 | 8 | def test_traverse_empty(marie_curie: DataFixture) -> None: 9 | results = traverse( 10 | start=[], 11 | steps=1, 12 | edge_table=marie_curie.edge_table, 13 | session=marie_curie.session, 14 | keyspace=marie_curie.keyspace, 15 | ) 16 | assert_that(results, contains_exactly()) 17 | 18 | 19 | def test_traverse_marie_curie(marie_curie: DataFixture) -> None: 20 | results = traverse( 21 | start=Node("Marie Curie", "Person"), 22 | steps=1, 23 | edge_table=marie_curie.edge_table, 24 | session=marie_curie.session, 25 | keyspace=marie_curie.keyspace, 26 | ) 27 | expected = { 28 | Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"), 29 | Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"), 30 | Relation( 31 | Node("Marie Curie", "Person"), Node("Physicist", "Profession"), "HAS_PROFESSION" 32 | ), 33 | Relation(Node("Marie Curie", "Person"), Node("Chemist", "Profession"), "HAS_PROFESSION"), 34 | Relation( 35 | Node("Marie Curie", "Person"), Node("Professor", "Profession"), "HAS_PROFESSION" 36 | ), 37 | Relation( 38 | Node("Marie Curie", "Person"), 39 | Node("Radioactivity", "Scientific concept"), 40 | "RESEARCHED", 41 | ), 42 | Relation(Node("Marie Curie", "Person"), Node("Nobel Prize", "Award"), "WON"), 43 | Relation(Node("Marie Curie", "Person"), Node("Pierre Curie", "Person"), "MARRIED_TO"), 44 | Relation( 45 | Node("Marie Curie", "Person"), 46 | Node("University of Paris", "Organization"), 47 | "WORKED_AT", 48 | ), 49 | } 50 | assert_that(results, contains_exactly(*expected)) 51 | 52 | results = traverse( 53 | start=Node("Marie Curie", "Person"), 54 | steps=2, 55 | edge_table=marie_curie.edge_table, 56 | session=marie_curie.session, 57 | keyspace=marie_curie.keyspace, 58 | ) 59 | expected.add(Relation(Node("Pierre Curie", "Person"), Node("Nobel Prize", "Award"), "WON")) 60 | assert_that(results, contains_exactly(*expected)) 61 | 62 | 63 | async def test_atraverse_empty(marie_curie: DataFixture) -> None: 64 | results = await atraverse( 65 | start=[], 66 | steps=1, 67 | edge_table=marie_curie.edge_table, 68 | session=marie_curie.session, 69 | keyspace=marie_curie.keyspace, 70 | ) 71 | assert_that(results, contains_exactly()) 72 | 73 | 74 | async def test_atraverse_marie_curie(marie_curie: DataFixture) -> None: 75 | results = await atraverse( 76 | start=Node("Marie Curie", "Person"), 77 | steps=1, 78 | edge_table=marie_curie.edge_table, 79 | session=marie_curie.session, 80 | keyspace=marie_curie.keyspace, 81 | ) 82 | expected = { 83 | Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"), 84 | Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"), 85 | Relation( 86 | Node("Marie Curie", "Person"), Node("Physicist", "Profession"), "HAS_PROFESSION" 87 | ), 88 | Relation(Node("Marie Curie", "Person"), Node("Chemist", "Profession"), "HAS_PROFESSION"), 89 | Relation( 90 | Node("Marie Curie", "Person"), Node("Professor", "Profession"), "HAS_PROFESSION" 91 | ), 92 | Relation( 93 | Node("Marie Curie", "Person"), 94 | Node("Radioactivity", "Scientific concept"), 95 | "RESEARCHED", 96 | ), 97 | Relation(Node("Marie Curie", "Person"), Node("Nobel Prize", "Award"), "WON"), 98 | Relation(Node("Marie Curie", "Person"), Node("Pierre Curie", "Person"), "MARRIED_TO"), 99 | Relation( 100 | Node("Marie Curie", "Person"), 101 | Node("University of Paris", "Organization"), 102 | "WORKED_AT", 103 | ), 104 | } 105 | assert_that(results, contains_exactly(*expected)) 106 | 107 | results = await atraverse( 108 | start=Node("Marie Curie", "Person"), 109 | steps=2, 110 | edge_table=marie_curie.edge_table, 111 | session=marie_curie.session, 112 | keyspace=marie_curie.keyspace, 113 | ) 114 | expected.add(Relation(Node("Pierre Curie", "Person"), Node("Nobel Prize", "Award"), "WON")) 115 | assert_that(results, contains_exactly(*expected)) 116 | 117 | 118 | def test_traverse_marie_curie_filtered_edges(marie_curie: DataFixture) -> None: 119 | results = traverse( 120 | start=Node("Marie Curie", "Person"), 121 | steps=1, 122 | edge_filters=["edge_type = 'HAS_NATIONALITY'"], 123 | edge_table=marie_curie.edge_table, 124 | session=marie_curie.session, 125 | keyspace=marie_curie.keyspace, 126 | ) 127 | expected = { 128 | Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"), 129 | Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"), 130 | } 131 | assert_that(results, contains_exactly(*expected)) 132 | 133 | 134 | async def test_atraverse_marie_curie_filtered_edges(marie_curie: DataFixture) -> None: 135 | results = await atraverse( 136 | start=Node("Marie Curie", "Person"), 137 | steps=1, 138 | edge_filters=["edge_type = 'HAS_NATIONALITY'"], 139 | edge_table=marie_curie.edge_table, 140 | session=marie_curie.session, 141 | keyspace=marie_curie.keyspace, 142 | ) 143 | expected = { 144 | Relation(Node("Marie Curie", "Person"), Node("Polish", "Nationality"), "HAS_NATIONALITY"), 145 | Relation(Node("Marie Curie", "Person"), Node("French", "Nationality"), "HAS_NATIONALITY"), 146 | } 147 | assert_that(results, contains_exactly(*expected)) 148 | --------------------------------------------------------------------------------