├── .python-version
├── knowledge_graph_rag
│   ├── utils
│   │   ├── __init__.py
│   │   ├── llm.py
│   │   ├── prompts.py
│   │   └── text_preprocessing.py
│   ├── documents_cluster.py
│   ├── vectordb.py
│   ├── __init__.py
│   ├── document.py
│   ├── knowledge_graph.py
│   └── documents_graph.py
├── assets
│   ├── documents_graph.png
│   └── knowledge_graph.png
├── .gitignore
├── setup.py
├── LICENSE
├── README.MD
└── examples
    └── documents_graph_usage.ipynb
/.python-version:
--------------------------------------------------------------------------------
1 | 3.12.2
2 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/utils/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/documents_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sarthakrastogi/graph-rag/HEAD/assets/documents_graph.png
--------------------------------------------------------------------------------
/assets/knowledge_graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/sarthakrastogi/graph-rag/HEAD/assets/knowledge_graph.png
--------------------------------------------------------------------------------
/knowledge_graph_rag/documents_cluster.py:
--------------------------------------------------------------------------------
class DocumentsCluster:
    """Placeholder for a future document-clustering component; it holds no state."""

    def __init__(self) -> None:
        """Create an empty cluster object; nothing is initialised yet."""
4 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/vectordb.py:
--------------------------------------------------------------------------------
class VectorDBCollection:
    """Placeholder handle for a vector-database collection.

    Only the backend name is recorded; no client or connection is created.
    """

    def __init__(self, vendor="chromadb") -> None:
        """Remember which vector-DB backend this collection targets.

        Args:
            vendor: Backend identifier string; defaults to ``"chromadb"``.
        """
        backend = vendor
        self.vendor = backend
4 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/__init__.py:
--------------------------------------------------------------------------------
1 | from .vectordb import VectorDBCollection
2 | from .documents_cluster import DocumentsCluster
3 | from .documents_graph import DocumentsGraph
4 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/utils/llm.py:
--------------------------------------------------------------------------------
1 | from litellm import completion
2 |
def llm_call(messages, model="gpt-3.5-turbo"):
    """Send a chat conversation to an LLM through litellm and return the reply text.

    Args:
        messages: Chat messages in the ``[{"role": ..., "content": ...}, ...]``
            form passed straight to ``litellm.completion``.
        model: litellm model identifier to query; defaults to ``"gpt-3.5-turbo"``.

    Returns:
        The ``content`` of the first choice's message.
    """
    # Forward the caller's model; previously the default was hard-coded here,
    # so the ``model`` argument had no effect.
    response = completion(model=model, messages=messages)
    return response.choices[0].message.content
--------------------------------------------------------------------------------
/knowledge_graph_rag/document.py:
--------------------------------------------------------------------------------
class Document:
    """A piece of text plus an optional embedding and provenance metadata."""

    def __init__(self, content, embedding=None, title="", source="") -> None:
        """Create a document.

        Args:
            content: The document text.
            embedding: Optional embedding vector; a new empty list when omitted.
            title: Optional title; empty string by default.
            source: Optional source/provenance string; empty string by default.
        """
        self.content = content
        # A literal ``[]`` default would be one list shared by every Document
        # built without an embedding, so mutations would leak between them.
        self.embedding = [] if embedding is None else embedding
        self.title = title
        self.source = source
7 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | .env
3 | utils/__pycache__
4 | graph-rag/__pycache__
5 | graph-rag/utils/__pycache__
6 | examples/__pycache__
7 | resources/__pycache__
8 | __pycache__
9 | .gitmodules
10 | usage.ipynb
11 | build/
12 | dist/
13 | knowledge_graph_rag.egg-info/
14 | cvd_vectors/
15 | med_graph.pickle
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
# Read the long description explicitly as UTF-8 so builds do not depend on the
# platform's default locale encoding.
with open("README.MD", "r", encoding="utf-8") as f:
    readme_content = f.read()

setup(
    name="knowledge_graph_rag",
    version="0.1.0",
    packages=find_packages(),
    long_description=readme_content,
    long_description_content_type="text/markdown",
    install_requires=[
        "numpy==1.24.0",
        "networkx==3.2.1",
        "nltk==3.8.1",
        "litellm==1.34.0",
        # Imported by the package modules (documents_graph.py and
        # knowledge_graph.py) but previously undeclared, which made
        # ``import knowledge_graph_rag`` fail after a clean install.
        "matplotlib",
        "scikit-learn",
        "tqdm",
    ],
)
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2024 Sarthak Rastogi
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in all
11 | copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
19 | SOFTWARE.
--------------------------------------------------------------------------------
/knowledge_graph_rag/utils/prompts.py:
--------------------------------------------------------------------------------
# System prompt for extracting a knowledge graph from one document.
# The example below must itself be valid JSON: the model tends to copy its
# shape, and KnowledgeGraph parses each reply with ``json.loads``.
knowledge_graph_creation_system_prompt = """
You are given a document and your task is to create a knowledge graph from it.
In the knowledge graph, entities such as people, places, objects, institutions, topics, ideas, etc. are represented as nodes.
Whereas the relationships and actions between them are represented as edges.

You will respond with a knowledge graph in the given JSON format:

[
    {"entity" : "Entity_name", "connections" : [
        {"entity" : "Connected_entity_1", "relationship" : "Relationship_with_connected_entity_1"},
        {"entity" : "Connected_entity_2", "relationship" : "Relationship_with_connected_entity_2"}
        ]
    },
    {"entity" : "Entity_name", "connections" : [
        {"entity" : "Connected_entity_1", "relationship" : "Relationship_with_connected_entity_1"},
        {"entity" : "Connected_entity_2", "relationship" : "Relationship_with_connected_entity_2"}
        ]
    }
]

Respond with the JSON only: no Markdown code fences, explanations, or any other text.
You must strictly respond in the given JSON format or your response will not be parsed correctly!
"""
23 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/utils/text_preprocessing.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | from nltk.corpus import stopwords
3 | from nltk.tokenize import word_tokenize
4 | from nltk.stem import WordNetLemmatizer
5 |
# NLTK data this module needs, as {downloader package id: nltk.data resource path}.
_NLTK_RESOURCES = {
    "punkt": "tokenizers/punkt",
    "stopwords": "corpora/stopwords",
    "wordnet": "corpora/wordnet",
    "omw-1.4": "corpora/omw-1.4",
}


def _ensure_nltk_data():
    """Download each required NLTK package that is not already installed locally.

    ``nltk.download`` contacts the NLTK index even when the data is present, so
    calling it unconditionally made every import depend on the network. Look the
    resource up first and only download on ``LookupError``.
    """
    for package, resource_path in _NLTK_RESOURCES.items():
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package)


_ensure_nltk_data()
11 |
12 |
def remove_stop_words_from_and_lemmatise_documents(documents):
    """Normalise each document for bag-of-words comparison.

    Each document is word-tokenised with NLTK. A token is kept only if it is
    purely alphabetic and its lower-cased form is not an English stop word.
    Kept tokens are lower-cased, lemmatised with WordNet and re-joined with
    single spaces.

    Args:
        documents: Iterable of raw text strings.

    Returns:
        A list holding one preprocessed string per input document, in order.
    """
    english_stop_words = set(stopwords.words("english"))
    lemmatise = WordNetLemmatizer().lemmatize

    cleaned_documents = []
    for text in documents:
        kept_tokens = []
        for token in word_tokenize(text):
            lowered = token.lower()
            # Drop stop words and anything containing digits/punctuation.
            if lowered not in english_stop_words and token.isalpha():
                kept_tokens.append(lemmatise(lowered))
        cleaned_documents.append(" ".join(kept_tokens))
    return cleaned_documents
42 |
--------------------------------------------------------------------------------
/README.MD:
--------------------------------------------------------------------------------
1 | # Knowledge Graph RAG
2 | ## Automatically create knowledge graphs + document networks to boost performance on RAG
3 |
4 | ### 1. Install Knowledge Graph RAG:
5 |
6 | `pip install knowledge_graph_rag`
7 |
8 | ### 2. Create a Knowledge Graph or a Document Graph:
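9 | Import the classes with `from knowledge_graph_rag.knowledge_graph import KnowledgeGraph` and `from knowledge_graph_rag import DocumentsGraph`.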
10 | ```
11 | # Creating KG on medical documents
12 | documents = ["Cardiovascular disease ...",
13 | "Emerging therapeutic interventions ...",
14 | "The epidemiological burden ...",
15 | "Cardiovascular disease also ...",
16 | "Advanced imaging techniques, ...",
17 | "Role of novel biomarkers ..."
18 | ]
19 | knowledge_graph = KnowledgeGraph(documents)
20 | knowledge_graph.create()
21 | knowledge_graph.plot()
22 | ```
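23 | ![Knowledge graph](https://raw.githubusercontent.com/sarthakrastogi/graph-rag/HEAD/assets/knowledge_graph.png)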
24 |
25 | ```
26 | documents_graph = DocumentsGraph(documents=documents)
27 | documents_graph.plot()
28 | ```
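29 | ![Documents graph](https://raw.githubusercontent.com/sarthakrastogi/graph-rag/HEAD/assets/documents_graph.png)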
30 |
31 | ### 3. Search knowledge graph entities or find interconnected documents, to augment your LLM context:
32 |
33 | ```
34 | knowledge_graph.search_document(user_query)
35 | ```
36 |
37 | ```
38 | >> Entity: cardiovascular disease
39 | -> antihypertensive agents (Relationship: involves treatment with)
40 | -> statins (Relationship: used to modulate dyslipidemia)
41 | -> antiplatelet therapy (Relationship: utilized to mitigate thrombosis risk)
42 | -> biomarkers (Relationship: detection and prognostication of acute coronary syndromes and heart failure)
43 | -> high-sensitivity troponins (Relationship: detection of acute coronary syndromes and heart failure)
44 | -> natriuretic peptides (Relationship: prognostication of acute coronary syndromes and heart failure)
45 | ```
46 |
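47 | Below, `vectordb_search_result` is a document returned by your own vector-DB search (see `examples/documents_graph_usage.ipynb`); it must be one of the documents the graph was built from, otherwise a `ValueError` is raised.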
48 |
49 | ```
50 | documents_containing_connected_terminology = documents_graph.find_connected_documents(vectordb_search_result)
51 | documents_containing_connected_terminology
52 | ```
53 |
54 | ```
55 | >> [{'document': 'emerging therapeutic intervention ...'},
56 | {'document': 'management cardiovascular ...'},
57 | {'document': 'role novel biomarkers ...'}]
58 | ```
59 |
60 |
61 | ## Star History
62 |
63 | [![Star History Chart](https://api.star-history.com/svg?repos=sarthakrastogi/graph-rag&type=Date)](https://star-history.com/#sarthakrastogi/graph-rag&Date)
--------------------------------------------------------------------------------
/knowledge_graph_rag/knowledge_graph.py:
--------------------------------------------------------------------------------
1 | from knowledge_graph_rag.utils.llm import llm_call
2 | import json
3 | import re
4 | from tqdm.notebook import tqdm
5 | import networkx as nx
6 | from collections import defaultdict, deque
7 | import matplotlib.pyplot as plt
8 |
9 | from .utils.prompts import knowledge_graph_creation_system_prompt
10 |
11 |
class KnowledgeGraph:
    """Build and query an LLM-extracted entity/relationship graph over documents.

    Each document is sent to the LLM with ``knowledge_graph_creation_system_prompt``.
    The JSON reply (a list of ``{"entity", "connections"}`` objects) is merged into
    a single ``networkx.DiGraph``. Each edge carries a ``relationship`` string and
    an integer ``weight`` counting how many times that edge was extracted.

    Call :meth:`create` before :meth:`plot` or :meth:`search_document`: both read
    ``self.G``, which only ``create`` sets.
    """

    # Matches a whole reply wrapped in a Markdown code fence (optionally tagged,
    # e.g. ```json) and captures the fenced body.
    _CODE_FENCE = re.compile(r"^```[A-Za-z]*\s*(.*?)\s*```$", re.DOTALL)

    def __init__(self, documents) -> None:
        """Store the raw documents; no LLM calls are made until :meth:`create`."""
        self.documents = documents

    def remove_trailing_commas(self, json_string):
        """Return ``json_string`` with any comma directly before ``]`` or ``}`` removed.

        LLMs often emit JavaScript-style trailing commas, which ``json.loads`` rejects.
        """
        return re.sub(r",\s*([\]}])", r"\1", json_string)

    def _parse_llm_json(self, response):
        """Parse one LLM reply into a list of entity objects.

        Strips a surrounding Markdown code fence (if any) and trailing commas,
        then ``json.loads`` the text. A single top-level object is wrapped in a
        list so callers can always iterate entities.

        Raises:
            json.JSONDecodeError: if the cleaned reply is still not valid JSON.
        """
        text = response.strip()
        fenced = self._CODE_FENCE.match(text)
        if fenced:
            text = fenced.group(1)
        parsed = json.loads(self.remove_trailing_commas(text))
        if isinstance(parsed, dict):
            parsed = [parsed]
        return parsed

    def create_knowledge_representations(self, documents):
        """Ask the LLM for a knowledge representation of each document.

        Replies are lower-cased before parsing, so entity names that differ only
        in case become the same graph node.

        Args:
            documents: Iterable of document strings.

        Returns:
            One list of entity objects per input document, in order.
        """
        knowledge_representations_of_individual_documents = []
        for document in tqdm(documents):
            messages = [
                {"role": "system", "content": knowledge_graph_creation_system_prompt},
                {"role": "user", "content": document},
            ]
            response = llm_call(messages=messages).lower()
            knowledge_representations_of_individual_documents.append(
                self._parse_llm_json(response)
            )
        return knowledge_representations_of_individual_documents

    def create_knowledge_graph_from_representations(self, representations):
        """Merge per-document representations into one directed graph.

        Each connection adds an edge ``entity -> connected entity``. If the edge
        already exists, the new relationship text is appended (comma-separated)
        and its ``weight`` is incremented. Entities without connections add no
        nodes.
        """
        G = nx.DiGraph()

        def add_edge(source, target, relationship):
            if G.has_edge(source, target):
                G[source][target]["relationship"] += f", {relationship}"
                G[source][target]["weight"] = G[source][target].get("weight", 1) + 1
            else:
                G.add_edge(source, target, relationship=relationship, weight=1)

        for rep in representations:
            for item in rep:
                source = item["entity"]
                if "connections" in item:
                    for conn in item["connections"]:
                        add_edge(source, conn["entity"], conn["relationship"])

        return G

    def create(self):
        """Extract representations for ``self.documents`` and build ``self.G`` from them."""
        self.knowledge_representations = self.create_knowledge_representations(
            self.documents
        )
        self.G = self.create_knowledge_graph_from_representations(
            self.knowledge_representations
        )

    def plot(self):
        """Draw ``self.G`` with matplotlib.

        Labels are truncated to 20 characters, edge width equals edge weight, and
        each edge is annotated with its relationship and weight.
        """
        pos = nx.spring_layout(self.G)
        plt.figure(figsize=(12, 8))

        nx.draw_networkx_nodes(
            self.G, pos, node_size=5000, node_color="skyblue", alpha=0.7
        )
        node_labels = {
            node: node[:20] + "..." if len(node) > 20 else node
            for node in self.G.nodes()
        }
        nx.draw_networkx_labels(
            self.G, pos, labels=node_labels, font_size=10, font_family="sans-serif"
        )

        for u, v, d in self.G.edges(data=True):
            weight = d.get("weight", 1)  # Default to 1 if weight is not present
            nx.draw_networkx_edges(
                self.G, pos, edgelist=[(u, v)], width=weight, alpha=0.5
            )

            # Edge label: (possibly truncated) relationship plus weight.
            edge_label = (
                f"{d['relationship'][:20]}...\n(w:{weight:.2f})"
                if len(d["relationship"]) > 20
                else f"{d['relationship']}\n(w:{weight:.2f})"
            )
            x = (pos[u][0] + pos[v][0]) / 2
            y = (pos[u][1] + pos[v][1]) / 2
            plt.text(
                x,
                y,
                edge_label,
                fontsize=8,
                ha="center",
                va="center",
                bbox=dict(facecolor="white", edgecolor="none", alpha=0.7),
            )

        plt.axis("off")
        plt.tight_layout()
        plt.show()

    def search_document(self, input_document, max_depth=3):
        """Describe the graph neighbourhood of entities mentioned in ``input_document``.

        Entities are extracted from ``input_document`` with the same LLM prompt.
        For every extracted top-level entity that is a node of ``self.G``, an
        ``"Entity: ..."`` header is emitted, followed by the edges reachable from
        it within ``max_depth`` hops (see :meth:`bfs_traversal`).

        Returns:
            All lines joined with newlines (empty string if nothing matched).
        """
        knowledge_representations_of_input_document = (
            self.create_knowledge_representations(documents=[input_document])
        )
        result = []
        for rep in knowledge_representations_of_input_document:
            for item in rep:
                source_entity = item["entity"]
                if source_entity in self.G:
                    result.append(f"\nEntity: {source_entity}")
                    result.extend(self.bfs_traversal(source_entity, max_depth))
        return "\n".join(result)

    def bfs_traversal(self, start_node, max_depth):
        """List outgoing edges reachable from ``start_node`` within ``max_depth`` hops.

        Breadth-first from ``start_node``. Each line describes one edge from an
        expanded node to a node that had not been expanded yet, so a node
        reachable via several such edges may be listed more than once.

        Returns:
            Lines of the form ``" -> neighbor (Relationship: ...)"``.
        """
        visited = set()
        queue = deque([(start_node, 0)])
        result = []
        while queue:
            node, depth = queue.popleft()
            if node in visited:
                continue
            visited.add(node)
            # Neighbours of a node at depth d sit at hop d + 1, so only nodes
            # strictly shallower than max_depth are expanded (previously nodes
            # at max_depth were expanded too, yielding max_depth + 1 hops).
            if depth >= max_depth:
                continue
            for neighbor in self.G.neighbors(node):
                if neighbor not in visited:
                    relationship = self.G[node][neighbor]["relationship"]
                    result.append(f" -> {neighbor} (Relationship: {relationship})")
                    queue.append((neighbor, depth + 1))
        return result
142 |
--------------------------------------------------------------------------------
/knowledge_graph_rag/documents_graph.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import networkx as nx
3 | import matplotlib.pyplot as plt
4 | from sklearn.feature_extraction.text import TfidfVectorizer
5 | from sklearn.metrics.pairwise import cosine_similarity
6 | import pickle
7 |
8 | from knowledge_graph_rag.utils.text_preprocessing import (
9 | remove_stop_words_from_and_lemmatise_documents,
10 | )
11 |
12 |
class DocumentsGraph:
    """Similarity graph over a fixed list of documents.

    Documents are normalised with ``remove_stop_words_from_and_lemmatise_documents``
    and vectorised with TF-IDF. Each document becomes a node keyed by its index,
    with the preprocessed text stored in the ``label`` attribute. Every pair with
    positive cosine similarity is joined by an undirected edge whose ``weight`` is
    that similarity.
    """

    def __init__(self, documents) -> None:
        """Preprocess ``documents`` and build ``self.G`` immediately."""
        self.documents = documents
        self.preprocessed_documents = remove_stop_words_from_and_lemmatise_documents(
            documents=documents
        )
        self.G = self.create_graph_from_documents()

    @staticmethod
    def _truncate_label(label, limit=20):
        """Return ``label`` cut to ``limit`` characters plus "..." when it is longer."""
        return label[:limit] + "..." if len(label) > limit else label

    def create_graph_from_documents(self):
        """Build the TF-IDF cosine-similarity graph for ``self.preprocessed_documents``.

        Returns an empty graph when there are no documents, since fitting TF-IDF
        on an empty corpus fails.
        """
        G = nx.Graph()
        if not self.preprocessed_documents:
            return G

        tfidf_matrix = TfidfVectorizer().fit_transform(self.preprocessed_documents)
        cosine_sim = cosine_similarity(tfidf_matrix)

        for i, doc in enumerate(self.preprocessed_documents):
            G.add_node(i, label=doc)

        n_docs = len(self.preprocessed_documents)
        for i in range(n_docs):
            for j in range(i + 1, n_docs):
                weight = cosine_sim[i, j]
                if weight > 0:  # only connect documents that share some terms
                    G.add_edge(i, j, weight=weight)

        return G

    def plot(self):
        """Draw the graph: truncated labels, edge width and label from similarity."""
        pos = nx.spring_layout(self.G)
        plt.figure(figsize=(12, 8))

        # Only long labels get "..." (matches KnowledgeGraph.plot).
        node_labels = {
            node: self._truncate_label(label)
            for node, label in nx.get_node_attributes(self.G, "label").items()
        }
        nx.draw_networkx_nodes(
            self.G, pos, node_size=5000, node_color="skyblue", alpha=0.7
        )
        nx.draw_networkx_labels(
            self.G, pos, labels=node_labels, font_size=10, font_family="sans-serif"
        )

        for u, v, d in self.G.edges(data=True):
            weight = d["weight"]
            # Similarities are small fractions; scale so edges are visibly thick.
            nx.draw_networkx_edges(
                self.G, pos, edgelist=[(u, v)], width=weight * 10, alpha=0.5
            )
            mid_edge = (pos[u] + pos[v]) / 2
            plt.text(
                mid_edge[0],
                mid_edge[1],
                f"{weight:.4f}",
                fontsize=9,
                ha="center",
                va="center",
            )

        plt.axis("off")
        plt.show()

    def find_connected_documents(self, input_sentence, N=3):
        """Return up to ``N`` graph neighbours of ``input_sentence``, most similar first.

        ``input_sentence`` is preprocessed like the corpus and must then equal one
        node's label exactly (e.g. a document returned by a vector search over
        the same documents).

        Returns:
            A list of ``{"document": <preprocessed text>}`` dicts.

        Raises:
            ValueError: if the preprocessed sentence is not a node label.
        """
        target_label = remove_stop_words_from_and_lemmatise_documents(
            documents=[input_sentence]
        )[0]
        node_index = next(
            (
                node
                for node, data in self.G.nodes(data=True)
                if data["label"] == target_label
            ),
            None,
        )
        if node_index is None:
            raise ValueError("The provided sentence is not in the graph.")

        neighbors = sorted(
            (
                (neighbor, self.G[node_index][neighbor]["weight"])
                for neighbor in self.G.neighbors(node_index)
            ),
            key=lambda pair: pair[1],
            reverse=True,
        )
        return [
            {"document": self.G.nodes[neighbor]["label"]}
            for neighbor, _weight in neighbors[:N]
        ]

    def find_k_closest_sentences(self, input_sentence, N=5):
        """Rank the corpus by TF-IDF cosine similarity to ``input_sentence``.

        TF-IDF is refitted on the corpus plus the (preprocessed) query, so the
        query's terms contribute to IDF.

        Returns:
            Up to ``N`` ``(preprocessed_document, score)`` tuples, highest score
            first, omitting documents with zero similarity. Empty when the corpus
            is empty.
        """
        if not self.preprocessed_documents:
            return []

        query = remove_stop_words_from_and_lemmatise_documents(
            documents=[input_sentence]
        )[0]
        all_docs = self.preprocessed_documents + [query]
        tfidf_matrix = TfidfVectorizer().fit_transform(all_docs)

        # Compare only the query row (last) against the corpus rows instead of
        # building the full n x n similarity matrix.
        similarity_scores = cosine_similarity(tfidf_matrix[-1], tfidf_matrix[:-1])[0]

        # Reverse to descending order first, then take N: ``[-N:]`` would
        # select every index when N == 0.
        closest_indices = np.argsort(similarity_scores)[::-1][:N]

        return [
            (self.preprocessed_documents[idx], similarity_scores[idx])
            for idx in closest_indices
            if similarity_scores[idx] > 0
        ]

    def save(self, graph_name):
        """Pickle ``self.G`` to ``<graph_name>.pickle``."""
        with open(f"{graph_name}.pickle", "wb") as f:
            pickle.dump(self.G, f)

    def load_from_file(self, graph_name):
        """Replace ``self.G`` with the graph pickled in ``<graph_name>.pickle``.

        Only the graph is restored; ``documents`` and ``preprocessed_documents``
        keep the values this instance was constructed with.

        Security: unpickling can execute arbitrary code, so only load trusted files.
        """
        with open(f"{graph_name}.pickle", "rb") as f:
            self.G = pickle.load(f)
149 |
--------------------------------------------------------------------------------
/examples/documents_graph_usage.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "\n",
8 | "
\n",
9 | ""
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [
17 | {
18 | "data": {
19 | "text/plain": [
20 | "'\\n%%capture\\n!pip install knowledge_graph_rag\\n!pip install numpy==1.24.0\\n!pip install chromadb\\n'"
21 | ]
22 | },
23 | "execution_count": 1,
24 | "metadata": {},
25 | "output_type": "execute_result"
26 | }
27 | ],
28 | "source": [
29 | "%%capture\n",
30 | "!pip install knowledge_graph_rag\n",
31 | "!pip install numpy==1.24.0\n",
32 | "!pip install chromadb"
33 | ]
34 | },
35 | {
36 | "cell_type": "code",
37 | "execution_count": 2,
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "%%capture\n",
42 | "from knowledge_graph_rag.document import Document\n",
43 | "from knowledge_graph_rag.documents_graph import DocumentsGraph"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 3,
49 | "metadata": {},
50 | "outputs": [],
51 | "source": [
52 | "\n",
53 | "documents = [\"Cardiovascular disease (CVD) encompasses a spectrum of disorders involving the heart and vasculature, prominently including atherosclerosis, characterized by endothelial dysfunction and the accumulation of lipid-laden plaques. These pathophysiological processes often precipitate myocardial infarction and cerebrovascular accidents, arising from the rupture of vulnerable plaques and subsequent thrombogenesis.\",\n",
54 | " \"Management of cardiovascular disease necessitates a multifaceted approach involving antihypertensive agents, statins to modulate dyslipidemia, and antiplatelet therapy to mitigate thrombosis risk.\",\n",
55 | " \"Emerging therapeutic interventions targeting molecular pathways, including PCSK9 inhibitors and SGLT2 inhibitors, show promise in reducing cardiovascular morbidity and mortality.\",\n",
56 | " \"The epidemiological burden of cardiovascular disease underscores the imperative for ongoing research into genetic predispositions and the optimization of primary and secondary prevention strategies.\",\n",
57 | " \"Cardiovascular disease also significantly intersects with metabolic syndrome, wherein insulin resistance and visceral adiposity contribute to endothelial dysfunction and systemic inflammation, further accelerating atherogenic processes.\",\n",
58 | " \"Advanced imaging techniques, such as coronary artery calcium scoring and carotid intima-media thickness measurement, enhance the stratification of cardiovascular risk, enabling more tailored therapeutic interventions.\",\n",
59 | " \"Role of novel biomarkers, including high-sensitivity troponins and natriuretic peptides, is pivotal in the early detection and prognostication of acute coronary syndromes and heart failure within the broader spectrum of cardiovascular disease.\"\n",
60 | "]"
61 | ]
62 | },
63 | {
64 | "cell_type": "markdown",
65 | "metadata": {},
66 | "source": [
67 | "Graph RAG can perform much better than standard RAG. Here’s when and how:\n",
68 | "\n",
69 | "When you want your LLM to understand the interconnection between your documents before arriving at its answer, Graph RAG is necessary.\n",
70 | "\n",
71 | "RAG returns search results based on semantic similarity. It fails to consider that, if doc A is selected as highly relevant, the docs containing data closely linked to doc A must be included in the context to give a full picture.\n",
72 | "\n",
73 | "This is where we need Graph RAG.\n",
74 | "\n",
75 | "Search results from a graph are more likely to give you a comprehensive view of the entity being searched and the info connected to it.\n",
76 | "\n",
77 | "Information on entities like people, institutions, etc. is often highly interconnected, and this might be the case for your data too.\n"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "### 1. Create a VectorDB"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 4,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "def get_embedding_batch(input_array):\n",
94 | " from openai import OpenAI\n",
95 | " client = OpenAI(api_key=\"YOUR_OPENAI_API_KEY\")\n",
96 | " response = client.embeddings.create(\n",
97 | " input=input_array,\n",
98 | " model=\"text-embedding-3-small\"\n",
99 | " )\n",
100 | " return [data.embedding for data in response.data]"
101 | ]
102 | },
103 | {
104 | "cell_type": "code",
105 | "execution_count": 5,
106 | "metadata": {},
107 | "outputs": [],
108 | "source": [
109 | "embeddings = get_embedding_batch(documents)\n",
110 | "vectors_collection = [{document : embedding} for document, embedding in zip(documents, embeddings)]"
111 | ]
112 | },
113 | {
114 | "cell_type": "code",
115 | "execution_count": 6,
116 | "metadata": {},
117 | "outputs": [],
118 | "source": [
119 | "import chromadb\n",
120 | "vectordb_name = \"cvd_vectors\"\n",
121 | "client = chromadb.PersistentClient(path=vectordb_name)\n",
122 | "collection = client.create_collection(vectordb_name)\n",
123 | "\n",
124 | "collection.add(\n",
125 | " embeddings=embeddings,\n",
126 | " documents=documents,\n",
127 | " metadatas=[{\"source\" : \"\"} for i in range(len(documents))],\n",
128 | " ids=list(map(str, range(len(documents))))\n",
129 | ")"
130 | ]
131 | },
132 | {
133 | "cell_type": "markdown",
134 | "metadata": {},
135 | "source": [
136 | "### 2. Create a Documents Graph"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 7,
142 | "metadata": {},
143 | "outputs": [
144 | {
145 | "data": {
146 | "image/png": "",
147 | "text/plain": [
148 | ""
149 | ]
150 | },
151 | "metadata": {},
152 | "output_type": "display_data"
153 | }
154 | ],
155 | "source": [
156 | "documents_graph = DocumentsGraph(documents=documents)\n",
157 | "documents_graph.plot()"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 8,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "documents_graph.save(\"med_graph\")"
167 | ]
168 | },
169 | {
170 | "cell_type": "code",
171 | "execution_count": 9,
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "user_query = \"How do advanced imaging techniques enhance cardiovascular risk stratification?\"\n",
176 | "query_embeddings = get_embedding_batch(user_query)"
177 | ]
178 | },
179 | {
180 | "cell_type": "markdown",
181 | "metadata": {},
182 | "source": [
183 | "### 3. Search vectorDB"
184 | ]
185 | },
186 | {
187 | "cell_type": "code",
188 | "execution_count": 10,
189 | "metadata": {},
190 | "outputs": [
191 | {
192 | "name": "stdout",
193 | "output_type": "stream",
194 | "text": [
195 | "Advanced imaging techniques, such as coronary artery calcium scoring and carotid intima-media thickness measurement, enhance the stratification of cardiovascular risk, enabling more tailored therapeutic interventions.\n"
196 | ]
197 | }
198 | ],
199 | "source": [
200 | "vectordb_search_result = collection.query(query_embeddings=query_embeddings, n_results=1)['documents'][0][0]\n",
201 | "print(vectordb_search_result)"
202 | ]
203 | },
204 | {
205 | "cell_type": "markdown",
206 | "metadata": {},
207 | "source": [
208 | "### 4. Search Documents Graph\n",
209 | "\n",
210 | "To find interconnected documents containing terminology / n-grams used in search result.\n",
211 | "\n",
212 | "Search results from a graph can give you a comprehensive view of the entity being searched and the info connected to it.\n"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 11,
218 | "metadata": {},
219 | "outputs": [
220 | {
221 | "data": {
222 | "text/plain": [
223 | "[{'document': 'emerging therapeutic intervention targeting molecular pathway including inhibitor inhibitor show promise reducing cardiovascular morbidity mortality'},\n",
224 | " {'document': 'management cardiovascular disease necessitates multifaceted approach involving antihypertensive agent statin modulate dyslipidemia antiplatelet therapy mitigate thrombosis risk'},\n",
225 | " {'document': 'role novel biomarkers including troponins natriuretic peptide pivotal early detection prognostication acute coronary syndrome heart failure within broader spectrum cardiovascular disease'}]"
226 | ]
227 | },
228 | "execution_count": 11,
229 | "metadata": {},
230 | "output_type": "execute_result"
231 | }
232 | ],
233 | "source": [
234 | "documents_containing_connected_terminology = documents_graph.find_connected_documents(vectordb_search_result)\n",
235 | "documents_containing_connected_terminology"
236 | ]
237 | },
238 | {
239 | "cell_type": "markdown",
240 | "metadata": {},
241 | "source": [
242 | "### 5. Augment interconnected documents into context"
243 | ]
244 | }
245 | ],
246 | "metadata": {
247 | "kernelspec": {
248 | "display_name": "base",
249 | "language": "python",
250 | "name": "python3"
251 | },
252 | "language_info": {
253 | "codemirror_mode": {
254 | "name": "ipython",
255 | "version": 3
256 | },
257 | "file_extension": ".py",
258 | "mimetype": "text/x-python",
259 | "name": "python",
260 | "nbconvert_exporter": "python",
261 | "pygments_lexer": "ipython3",
262 | "version": "3.12.2"
263 | }
264 | },
265 | "nbformat": 4,
266 | "nbformat_minor": 2
267 | }
268 |
--------------------------------------------------------------------------------