├── sample.env ├── src ├── __init__.py ├── pipeline.py ├── imageprocessing.py ├── docparser.py ├── doc_qa.py └── chunkers.py ├── requirements.txt ├── main.py └── README.md /sample.env: -------------------------------------------------------------------------------- 1 | LLAMA_CLOUD_API_KEY=... 2 | GOOGLE_API_KEY=... 3 | TAVILY_API_KEY=... -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- 1 | from .docparser import DocParser 2 | from .chunkers import Chunker, SemanticChunker, AgenticChunker 3 | from .imageprocessing import ImageProcessor 4 | from .doc_qa import QA, AgenticQA, indexing -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | chromadb==0.5.23 2 | img2table[easyocr]==1.4.0 3 | langchain==0.3.13 4 | langchain-chroma==0.1.4 5 | langchain_core==0.3.28 6 | langchain_experimental==0.3.4 7 | langchain_google_genai==2.0.7 8 | llama_parse==0.5.18 9 | opencv_contrib_python_headless==4.10.0.84 10 | opencv_python_headless==4.10.0.84 11 | pymupdf4llm==0.0.17 12 | PyMuPDF==1.24.14 13 | python-dotenv==1.0.1 14 | uuid6==2024.7.10 -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from src import pipeline 3 | 4 | parser = argparse.ArgumentParser() 5 | parser.add_argument( 6 | '--InputPath', 7 | help= 'Directory path containing files to be processed, or a single file path') 8 | 9 | parser.add_argument( 10 | '--parser_name', 11 | help='Specify the name of the parser to use for document processing. Possible values: ["LlamaParse", "pymupdf4llm"]' 12 | ) 13 | 14 | parser.add_argument( 15 | '--chunking_strategy', 16 | help='Define the chunking strategy to apply when processing documents. Possible values: ["semantic", "agentic"]' 17 | ) 18 | 19 | parser.add_argument( 20 | '--retrieval_strategy', 21 | help='Specify the retrieval strategy for querying indexed documents. Possible values:["semantic", "agentic"]' 22 | ) 23 | 24 | def main(): 25 | args = parser.parse_args() 26 | 27 | pipeline.pipeline(args.InputPath, 28 | parser_name=args.parser_name, 29 | chunking_strategy=args.chunking_strategy, 30 | retrieval_strategy=args.retrieval_strategy) 31 | 32 | 33 | if __name__ == '__main__': 34 | main() -------------------------------------------------------------------------------- /src/pipeline.py: -------------------------------------------------------------------------------- 1 | from .chunkers import Chunker 2 | from .docparser import DocParser 3 | from .imageprocessing import ImageProcessor 4 | from glob import glob 5 | from pathlib import Path 6 | from .doc_qa import QA, AgenticQA, indexing 7 | 8 | 9 | def list_supported_files(inputPath, supported_extensions= [".pdf"]): 10 | """ 11 | Lists all supported files in the given input path. 12 | 13 | Args: 14 | inputPath (str): The path where files are located. 15 | 16 | Returns: 17 | List[str]: A list of file paths with supported extensions. 
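        Example (illustrative paths): list_supported_files("./docs") -> ["./docs/report.pdf", "./docs/annex/specs.pdf"]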
18 | """ 19 | # Retrieve all files matching the input path and filter by supported extensions 20 | file_list = glob(f"{inputPath}/**/*", recursive=True) 21 | return [f for f in file_list if Path(f).suffix in supported_extensions] 22 | 23 | 24 | def pipeline(inputPath, 25 | parser_name, 26 | chunking_strategy, 27 | retrieval_strategy): 28 | 29 | parser= DocParser(parser_name= parser_name) 30 | chunker= Chunker(chunking_strategy) 31 | image_processor= ImageProcessor() 32 | 33 | files_list= list_supported_files(inputPath) 34 | chunks, image_documents= [], [] 35 | 36 | for file_path in files_list: 37 | print("processing started ...") 38 | 39 | text_docs= parser.parsing_function(file_path) 40 | parser.extract_tables(file_path) 41 | 42 | chunks.extend(chunker.build_chunks(text_docs, source= file_path)) 43 | image_documents.extend(image_processor.get_image_documents()) 44 | 45 | doc_indexing= indexing() 46 | retriever= doc_indexing.index_documents(chunks + image_documents) 47 | 48 | if retrieval_strategy == "agentic": 49 | agentic_qa= AgenticQA() 50 | agentic_qa.run(retriever) 51 | agentic_qa.query() 52 | else: 53 | qa = QA(retriever) 54 | qa.query() 55 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multimodal Agentic RAG 2 | 3 | ### Table of contents 4 | * [Overview](###Overview) 5 | * [Features](###Features) 6 | * [Installation](###Installation) 7 | * [Usage](###Usage) 8 | * [Perspectives](###Perspectives) 9 | 10 | ### Overview 11 | Retrieval-Augmented Generation (RAG) is an advanced AI framework that combines the retrieval capabilities of information retrieval systems (e.g., encoders and vector databases) with the generative power of large language models (LLMs). By leveraging external knowledge sources, such as user-provided documents, RAG delivers accurate, context-aware answers and generates factual, coherent responses. 12 | 13 | Traditional RAG systems primarily rely on static chunking and retrieval mechanisms, which may struggle to adapt dynamically to complex, multimodal data. To address this limitation, this project introduces an agentic approach to chunking and retrieval, adding significant value to the RAG process. 14 | 15 | ### Features 16 | **- chunking (Semantic or Agentic):** \ 17 | Semantic chunker split documents into semantically coherent, meaningful chunks. 18 | Agentic chunker goes further and simulates human judgment of text segmentation: start at the beginning of a document, group sentences based on context and topic, and continue this process iteratively until the entire document is segmented. \ 19 | (For more info: [Agentic Chunking: Enhancing RAG Answers for Completeness and Accuracy](https://gleen.ai/blog/agentic-chunking-enhancing-rag-answers-for-completeness-and-accuracy/)).\ 20 | **- Image and table detection:** \ 21 | Detecting images and tables using PyMuPDF and img2table respectively.\ 22 | **- Summarizing images and tables:** \ 23 | Using a multimodal LLM (eg. gemini-1.5-flash), create a text description of each image and table.\ 24 | **- Embedding:** \ 25 | Embed chunks, images and tables summaries using "text-embedding-004" model.\ 26 | **- Retrieval (Semantic or Agentic):** \ 27 | For a given query: semantic retrieval focuses on embedding-based similarity searches to retrieve information. Agentic retrieval includes 4 steps, following ReAct process: \ 28 | (1). Query rephrasing, with regards to chat history \ 29 | (2). 
semantic retrieval \ 30 | (3). Assess whether the retrieved documents are relevant and sufficient to answer the query \ 31 | (4). Accordingly, either use the retrieved documents or a web search engine to generate a relevant, sufficient and factual answer. 32 | 33 | ### Installation 34 | To run the app locally, the following steps are necessary: 35 | - Clone the repo: 36 | ```bash 37 | git clone https://github.com/AhmedAl93/multimodal-agentic-RAG.git 38 | cd multimodal-agentic-RAG/ 39 | ``` 40 | - Install the required Python packages: 41 | ```bash 42 | pip install -r requirements.txt 43 | ``` 44 | - Set up the environment variables in the .env file: 45 | ```bash 46 | LLAMA_CLOUD_API_KEY= 47 | GOOGLE_API_KEY= 48 | TAVILY_API_KEY= 49 | ``` 50 | 51 | ### Usage 52 | 1. Process input document(s): 53 | 54 | Run the following command: 55 | ```bash 56 | python main.py --InputPath <path> --parser_name <parser> --chunking_strategy <strategy> --retrieval_strategy <strategy> 57 | ``` 58 | Here are more details about the inputs: 59 | ```bash 60 | --InputPath: 'Directory path containing files to be processed, or a single file path' 61 | --parser_name: 'Specify the name of the parser to use for document processing. Possible values: ["LlamaParse", "pymupdf4llm"]' 62 | --chunking_strategy: 'Define the chunking strategy to apply when processing documents. Possible values: ["semantic", "agentic"]' 63 | --retrieval_strategy: 'Specify the retrieval strategy for querying indexed documents. Possible values: ["semantic", "agentic"]' 64 | ``` 65 | Currently, only PDF files are supported: if the input directory contains x PDFs, all x files will be processed. 66 | 67 | 2. Provide queries: 68 | In the terminal, you can provide multiple queries and get relevant answers. 69 | 70 | ### Perspectives 71 | In the near future, I plan to work on the following features: 72 | - Support file types other than PDF 73 | - Performance evaluation for different chunking and retrieval strategies 74 | - Support open-source LLMs 75 | - Support other vector DB providers 76 | - Assess and test new concepts: GraphRAG, ... 77 | - Cloud deployment -------------------------------------------------------------------------------- /src/imageprocessing.py: -------------------------------------------------------------------------------- 1 | from langchain_core.prompts import ChatPromptTemplate 2 | from langchain_core.messages import HumanMessage 3 | from typing import List 4 | import glob, time, base64, logging, uuid6 5 | from pathlib import Path 6 | from langchain_google_genai import ChatGoogleGenerativeAI 7 | from langchain_core.documents import Document 8 | from dotenv import find_dotenv, load_dotenv 9 | 10 | load_dotenv(find_dotenv()) 11 | logging.basicConfig(level=logging.INFO) 12 | logger = logging.getLogger(__name__) 13 | 14 | class ImageProcessor: 15 | def __init__(self): 16 | self.image_dir= "./parsed_assets/" 17 | self.llm = ChatGoogleGenerativeAI( 18 | model="gemini-1.5-flash", 19 | temperature=0 20 | ) 21 | 22 | @staticmethod 23 | def retry_with_delay(func, *args, delay=2, retries=30, **kwargs): 24 | """ 25 | Helper method to retry a function call with a delay. 26 | """ 27 | for attempt in range(retries): 28 | try: 29 | return func(*args, **kwargs) 30 | except Exception as e: 31 | logger.warning(f"Attempt {attempt + 1} failed: {e}.
Retrying...") 32 | time.sleep(delay) 33 | raise RuntimeError("Exceeded maximum retries.") 34 | 35 | def encode_image(self, image_path): 36 | """Getting the base64 string""" 37 | with open(image_path, "rb") as image_file: 38 | return base64.b64encode(image_file.read()).decode("utf-8") 39 | 40 | def image_summarize(self, img_base64): 41 | """Make image summary""" 42 | prompt = """You are an assistant tasked with summarizing images for retrieval. \ 43 | These summaries will be embedded and used to retrieve the raw image. \ 44 | Give a concise summary of the image that is well optimized for retrieval.""" 45 | # chat = ChatGoogleGenerativeAI(model="gpt-4-vision-preview", max_tokens=1024) 46 | 47 | msg = self.llm.invoke( 48 | [HumanMessage( 49 | content=[ 50 | {"type": "text", "text": prompt}, 51 | {"type": "image_url", 52 | "image_url": {"url": f"data:image/jpeg;base64,{img_base64}"}, 53 | },] 54 | )]) 55 | return msg.content 56 | 57 | def get_image_summaries(self): 58 | # image_paths: List[str])->List[str]: 59 | """ 60 | Generates summaries for a list of images using a generative AI model. 61 | 62 | Args: 63 | image_paths (List[str]): A list of file paths to images. 64 | 65 | Returns: 66 | List[str]: A list of textual summaries for each image. 67 | """ 68 | image_summaries = [] 69 | # for i, img_path in enumerate(image_paths): 70 | for img_path in sorted(glob.glob(f"{self.image_dir}*.png")): 71 | base64_image = self.encode_image(img_path) 72 | 73 | # img_base64_list.append(base64_image) 74 | # Append the AI-generated summary to the list 75 | image_summaries.append( 76 | self.retry_with_delay(self.image_summarize, base64_image) 77 | ) 78 | return image_summaries 79 | 80 | def get_image_documents(self)->List[Document]: 81 | """ 82 | Extracts images from files and generates corresponding text nodes with metadata. 83 | 84 | Args: 85 | files_to_process (List[str]): A list of file paths to extract images from. 86 | 87 | Returns: 88 | List[TextNode]: A list of nodes containing image summaries and metadata. 
89 | """ 90 | image_documents = [] 91 | # Generate summaries for the extracted images 92 | image_summaries = self.get_image_summaries() 93 | image_paths= sorted(glob.glob(f"{self.image_dir}*.png")) 94 | 95 | for summary, image_path in zip(image_summaries, image_paths): 96 | # Append the created node to the list 97 | image_documents.append( 98 | Document( 99 | page_content=summary, 100 | metadata={"source": Path(image_path).name}, 101 | id= str(uuid6.uuid6()), 102 | ) 103 | ) 104 | 105 | return image_documents 106 | -------------------------------------------------------------------------------- /src/docparser.py: -------------------------------------------------------------------------------- 1 | import pymupdf4llm, cv2, fitz, os, io 2 | from pathlib import Path 3 | from typing import List 4 | from llama_parse import LlamaParse 5 | from img2table.ocr import EasyOCR 6 | from img2table.document import PDF, Image 7 | 8 | from dotenv import find_dotenv, load_dotenv 9 | load_dotenv(find_dotenv()) 10 | 11 | class DocParser: 12 | def __init__(self, parser_name): 13 | self.parser_name= parser_name 14 | self.assets_dir= "./parsed_assets/" 15 | self.parser_function_map= { 16 | "LlamaParse": self.with_LlamaParse, 17 | "pymupdf4llm": self.with_pymupdf4llm 18 | } 19 | self.parsing_function= self.parser_function_map[parser_name] 20 | 21 | # Instantiation of OCR 22 | self.ocr = EasyOCR(lang=["en"]) 23 | # Ensure the save directory exists 24 | os.makedirs(self.assets_dir, exist_ok=True) 25 | 26 | def parse(self, file_path): 27 | text_docs= self.parsing_function(file_path) 28 | if self.parser_name=="LlamaParse": 29 | self.extract_images(file_path) 30 | 31 | self.extract_tables(file_path) 32 | return text_docs 33 | 34 | def with_LlamaParse(self, file_path): 35 | print("LLamaParse is being used ...") 36 | parser = LlamaParse(result_type="markdown", verbose=False) 37 | data= parser.load_data(file_path=file_path) 38 | text_docs= [x.text for x in data] 39 | return text_docs 40 | 41 | def with_pymupdf4llm(self, file_path): 42 | #No need for standalone image extraction step, already done here 43 | output = pymupdf4llm.to_markdown( 44 | file_path, 45 | write_images=True, 46 | image_path=self.assets_dir, 47 | extract_words= True, 48 | show_progress= False) 49 | 50 | text_docs= [x["text"].replace("-----", "") 51 | for x in output] 52 | return text_docs 53 | 54 | def extract_tables(self, file_path): 55 | # Instantiation of document, either an image or a PDF 56 | if Path(file_path).suffix==".pdf": 57 | doc = PDF(file_path) 58 | else: 59 | doc = Image(file_path) 60 | # Table extraction 61 | extracted_tables = doc.extract_tables(ocr=self.ocr, 62 | implicit_rows=True, 63 | implicit_columns=True, 64 | borderless_tables=True) 65 | 66 | margin= 20 67 | save_dir= Path(self.assets_dir) 68 | file_stem= Path(file_path).stem 69 | 70 | for p, (image, tables) in enumerate( 71 | zip(doc._images, 72 | extracted_tables.values())): 73 | for i, t in enumerate(tables): 74 | table_image= image[t.bbox.y1-margin:t.bbox.y2+margin, 75 | t.bbox.x1-margin:t.bbox.x2+margin] 76 | cv2.imwrite(save_dir.joinpath(f"{file_stem}_{p}_table{i}.png"), table_image) 77 | 78 | 79 | def extract_images(self, filepath): 80 | """ 81 | Extracts images from the provided files and saves them to the specified directory. 82 | 83 | Args: 84 | files_to_process (List[str]): List of file paths to extract images from. 85 | save_dir (str): Directory to save the extracted images. 
86 | 87 | Returns: 88 | None 89 | """ 90 | 91 | # Open the document using PyMuPDF 92 | doc = fitz.open(filepath) 93 | save_dir= Path(self.assets_dir) 94 | 95 | for p in range(len(doc)): 96 | page = doc[p] 97 | 98 | # Iterate through images on the page 99 | for i, img in enumerate(page.get_images(), start=1): 100 | xref = img[0] # Image reference ID 101 | 102 | # Extract image bytes 103 | base_image = doc.extract_image(xref) 104 | image_bytes = base_image["image"] 105 | 106 | # Create a PIL Image object from the bytes 107 | pil_image = PILImage.open(io.BytesIO(image_bytes)) 108 | 109 | # Save the image with a structured name 110 | image_name = f"{save_dir.joinpath(Path(filepath).stem)}_{p}_image{i}.png" 111 | pil_image.save(image_name) -------------------------------------------------------------------------------- /src/doc_qa.py: -------------------------------------------------------------------------------- 1 | from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder 2 | from langchain import hub 3 | from langchain_google_genai import ChatGoogleGenerativeAI 4 | from langchain.chains.combine_documents import create_stuff_documents_chain 5 | from langchain.chains import create_history_aware_retriever, create_retrieval_chain 6 | from langchain_core.tools import Tool 7 | from langchain_community.tools.tavily_search import TavilySearchResults 8 | from langchain.agents import AgentExecutor, create_react_agent 9 | from langchain_core.messages import AIMessage, HumanMessage 10 | 11 | from langchain_chroma import Chroma 12 | from langchain_google_genai import GoogleGenerativeAIEmbeddings 13 | import uuid6 14 | 15 | class indexing: 16 | def __init__(self): 17 | self.embedding_function= GoogleGenerativeAIEmbeddings(model="models/text-embedding-004") 18 | 19 | 20 | def index_documents(self, documents, 21 | collection_name="Agentic_retrieval", 22 | top_k= 3): 23 | vector_store = Chroma( 24 | collection_name= collection_name, 25 | embedding_function=self.embedding_function) 26 | 27 | vector_store.add_documents( 28 | documents=documents, 29 | ids=[str(uuid6.uuid6()) for _ in documents]) 30 | 31 | retriever = vector_store.as_retriever( 32 | search_type="similarity", 33 | search_kwargs={"k": top_k},) 34 | 35 | return retriever 36 | 37 | class QA: 38 | def __init__(self, retriever) -> None: 39 | self.system_template = """ 40 | Answer the user's questions based on the below context.
41 | If the context doesn't contain any relevant information to the question, don't make something up and just say "I don't know": 42 | 43 | 44 | {context} 45 | 46 | """ 47 | 48 | self.question_answering_prompt = ChatPromptTemplate.from_messages( 49 | [("system", self.system_template), 50 | MessagesPlaceholder(variable_name="messages"),] 51 | ) 52 | self.retriever= retriever 53 | self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash") 54 | 55 | self.qa_chain = create_stuff_documents_chain(self.llm, 56 | self.question_answering_prompt 57 | ) 58 | 59 | def query(self): 60 | while True: 61 | query = input("You: ") 62 | if query.lower() == "exit": 63 | break 64 | docs = self.retriever.invoke(query) 65 | 66 | response = self.qa_chain.invoke( 67 | {"context": docs, 68 | "messages": [HumanMessage(content=query)] 69 | } 70 | ) 71 | print(f"AI: {response}") 72 | 73 | 74 | class AgenticQA: 75 | def __init__(self) -> None: 76 | self.contextualize_q_system_prompt = ( 77 | "Given a chat history and the latest user question " 78 | "which might reference context in the chat history, " 79 | "formulate a standalone question which can be understood " 80 | "without the chat history. Do NOT answer the question, just " 81 | "reformulate it if needed and otherwise return it as is." 82 | ) 83 | self.chat_history = [] 84 | 85 | self.contextualize_q_prompt = ChatPromptTemplate.from_messages( 86 | [ 87 | ("system", self.contextualize_q_system_prompt), 88 | MessagesPlaceholder("chat_history"), 89 | ("human", "{input}"), 90 | ] 91 | ) 92 | 93 | self.qa_system_prompt = ( 94 | "You are an assistant for question-answering tasks. Use " 95 | "the following pieces of retrieved context to answer the " 96 | "question." 97 | "\n\n" 98 | "{context}" 99 | ) 100 | self.qa_prompt = ChatPromptTemplate.from_messages( 101 | [ 102 | ("system", self.qa_system_prompt), 103 | MessagesPlaceholder("chat_history"), 104 | ("human", "{input}"), 105 | ] 106 | ) 107 | self.react_docstore_prompt = hub.pull("aallali/react_tool_priority") 108 | self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash") 109 | 110 | def create_rag_chain(self, retriever): 111 | history_aware_retriever = create_history_aware_retriever( 112 | self.llm, retriever, self.contextualize_q_prompt 113 | ) 114 | question_answer_chain = create_stuff_documents_chain(self.llm, 115 | self.qa_prompt) 116 | 117 | self.rag_chain = create_retrieval_chain( 118 | history_aware_retriever, question_answer_chain) 119 | 120 | def create_rag_agent(self): 121 | self.agent = create_react_agent( 122 | llm=self.llm, 123 | tools=self.tools, 124 | prompt=self.react_docstore_prompt) 125 | 126 | def execute_rag_agent(self): 127 | self.agent_executor = AgentExecutor.from_agent_and_tools( 128 | agent=self.agent, 129 | tools=self.tools, 130 | handle_parsing_errors=True, 131 | verbose=True,) 132 | 133 | def run(self, retriever): 134 | self.create_rag_chain(retriever) 135 | 136 | self.tools = [ 137 | Tool( 138 | name="Answer Question", 139 | func=lambda query, **kwargs: self.rag_chain.invoke({ 140 | "input": query, 141 | "chat_history": kwargs.get("chat_history", []) 142 | }), 143 | description=( 144 | "A chat assistant tool designed to provide answers based on document knowledge. 
" 145 | "Maintains the context of previous questions and answers for continuity."), 146 | ), 147 | TavilySearchResults(max_results=2)] 148 | 149 | self.create_rag_agent() 150 | self.execute_rag_agent() 151 | 152 | def query(self): 153 | while True: 154 | query = input("You: ") 155 | if query.lower() == "exit": 156 | break 157 | response = self.agent_executor.invoke( 158 | {"input": query, 159 | "chat_history": self.chat_history}) 160 | print(f"AI: {response['output']}") 161 | 162 | # Update history 163 | self.chat_history.append(HumanMessage(content=query)) 164 | self.chat_history.append(AIMessage(content=response["output"])) -------------------------------------------------------------------------------- /src/chunkers.py: -------------------------------------------------------------------------------- 1 | 2 | from langchain_core.prompts import ChatPromptTemplate 3 | from typing import List 4 | from langchain_core.pydantic_v1 import BaseModel, Field 5 | from langchain import hub 6 | import time, logging, uuid6 7 | from langchain_core.documents import Document 8 | from dotenv import find_dotenv, load_dotenv 9 | from langchain_google_genai import ( 10 | GoogleGenerativeAIEmbeddings, 11 | ChatGoogleGenerativeAI) 12 | from langchain_experimental.text_splitter import SemanticChunker 13 | 14 | 15 | load_dotenv(find_dotenv()) 16 | 17 | logging.basicConfig(level=logging.INFO) 18 | logger = logging.getLogger(__name__) 19 | 20 | class Chunker: 21 | def __init__(self, strategy): 22 | self.semantic_chunker= SemanticChunker_langchain() 23 | self.agentic_chunker= AgenticChunker() 24 | self.strategy_chunker_map= { 25 | "semantic": self.semantic_chunker, 26 | "agentic": self.agentic_chunker 27 | } 28 | self.chunker= self.strategy_chunker_map[strategy] 29 | 30 | def build_chunks(self, texts, source): 31 | return self.chunker.build_chunks(texts, source) 32 | 33 | class SemanticChunker_langchain: 34 | #https://python.langchain.com/v0.2/docs/how_to/semantic-chunker/ 35 | def __init__(self): 36 | self.embed_model_name= "models/text-embedding-004" 37 | 38 | def build_chunks(self, texts, source): 39 | text_splitter = SemanticChunker( 40 | GoogleGenerativeAIEmbeddings( 41 | model=self.embed_model_name)) 42 | 43 | chunks= text_splitter.create_documents( 44 | texts=texts, 45 | metadatas= [{"source": source}]*len(texts) 46 | ) 47 | return chunks 48 | 49 | class ChunkMeta(BaseModel): 50 | title: str = Field(description="The title of the chunk.") 51 | summary: str = Field(description="The summary of the chunk.") 52 | 53 | class ChunkID(BaseModel): 54 | chunk_id: int = Field(description="The chunk id.") 55 | 56 | class Sentences(BaseModel): 57 | sentences: List[str] 58 | 59 | class AgenticChunker: 60 | def __init__(self): 61 | """ 62 | Initializes the AgenticChunker with: 63 | - An empty dictionary for storing chunks. 64 | - A large language model (LLM) for processing and summarizing text. 65 | - A placeholder for raw text input. 66 | """ 67 | self.chunks = {} 68 | self.llm = ChatGoogleGenerativeAI( 69 | model="gemini-1.5-flash", 70 | temperature=0 71 | ) 72 | # self.raw_text = "" 73 | 74 | @staticmethod 75 | def retry_with_delay(func, *args, delay=2, retries=30, **kwargs): 76 | """ 77 | Helper method to retry a function call with a delay. 78 | """ 79 | for attempt in range(retries): 80 | try: 81 | return func(*args, **kwargs) 82 | except Exception as e: 83 | logger.warning(f"Attempt {attempt + 1} failed: {e}. 
Retrying...") 84 | time.sleep(delay) 85 | raise RuntimeError("Exceeded maximum retries.") 86 | 87 | def extract_propositions_list(self, raw_text): 88 | """ 89 | Extracts a list of propositions from the raw text using an LLM. 90 | """ 91 | logger.info("Extracting propositions from raw text.") 92 | extraction_llm = self.llm.with_structured_output(Sentences) 93 | obj = hub.pull("wfh/proposal-indexing") 94 | extraction_chain = obj | extraction_llm 95 | self.propositions_list = self.retry_with_delay(extraction_chain.invoke, raw_text).sentences 96 | 97 | def build_chunks(self, raw_text, source=""): 98 | """ 99 | Processes the list of propositions and organizes them into chunks. 100 | """ 101 | chunks_as_documents=[] 102 | logger.info("Building chunks from propositions.") 103 | self.extract_propositions_list(raw_text) 104 | for proposition in self.propositions_list: 105 | self.find_chunk_and_push_proposition(proposition) 106 | 107 | for chunk_id in self.chunks: 108 | chunk_content= " ".join(self.chunks[chunk_id]["propositions"]) 109 | chunks_as_documents.append(Document( 110 | page_content=chunk_content, 111 | metadata={"source": f"{source}_{chunk_id}"}, 112 | id= str(uuid6.uuid6()), 113 | )) 114 | 115 | return chunks_as_documents 116 | 117 | def create_prompt_template(self, messages): 118 | """ 119 | Helper method to create prompt templates. 120 | """ 121 | return ChatPromptTemplate.from_messages(messages) 122 | 123 | def upsert_chunk(self, chunk_id, propositions): 124 | """ 125 | Creates or updates a chunk with the given propositions. 126 | """ 127 | summary_llm = self.llm.with_structured_output(ChunkMeta) 128 | prompt = self.create_prompt_template([ 129 | ("system", "Generate a new or updated summary and title based on the propositions."), 130 | ("user", "propositions:{propositions}") 131 | ]) 132 | summary_chain = prompt | summary_llm 133 | 134 | chunk_meta = self.retry_with_delay(summary_chain.invoke, {"propositions": propositions}) 135 | self.chunks[chunk_id] = { 136 | "summary": chunk_meta.summary, 137 | "title": chunk_meta.title, 138 | "propositions": propositions 139 | } 140 | 141 | def find_chunk_and_push_proposition(self, proposition): 142 | """ 143 | Finds the most relevant chunk for a proposition or creates a new one if none match. 144 | """ 145 | logger.info(f"Finding chunk for proposition: {proposition}") 146 | allocation_llm = self.llm.with_structured_output(ChunkID) 147 | allocation_prompt = self.create_prompt_template([ 148 | ("system", "Using the chunk IDs and summaries, determine the best chunk for the proposition. " 149 | "If no chunk matches, generate a new chunk ID. 
Return only the chunk ID."), 150 | ("user", "proposition:{proposition}\nchunks_summaries:{chunks_summaries}") 151 | ]) 152 | allocation_chain = allocation_prompt | allocation_llm 153 | 154 | chunks_summaries = { 155 | chunk_id: chunk["summary"] for chunk_id, chunk in self.chunks.items() 156 | } 157 | 158 | best_chunk_id = self.retry_with_delay( 159 | allocation_chain.invoke, { 160 | "proposition": proposition, 161 | "chunks_summaries": chunks_summaries 162 | } 163 | ).chunk_id 164 | 165 | if best_chunk_id not in self.chunks: 166 | logger.info(f"Creating new chunk for proposition: {proposition}") 167 | self.upsert_chunk(best_chunk_id, [proposition]) 168 | else: 169 | logger.info(f"Adding proposition to existing chunk ID: {best_chunk_id}") 170 | current_propositions = self.chunks[best_chunk_id]["propositions"] 171 | self.upsert_chunk(best_chunk_id, current_propositions + [proposition]) 172 | --------------------------------------------------------------------------------
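For reference, below is a minimal sketch of using the chunkers module on its own, outside the CLI pipeline. It assumes GOOGLE_API_KEY is set in the environment (the agentic chunker calls Gemini and pulls the "wfh/proposal-indexing" prompt from the LangChain hub); the sample text and the "notes.txt" source label are illustrative and not part of the repository.

```python
from src.chunkers import Chunker

# Illustrative standalone usage; assumes GOOGLE_API_KEY is set in the environment.
chunker = Chunker("agentic")
chunks = chunker.build_chunks(
    "Transformers process tokens in parallel. Attention relates every pair of tokens. "
    "RNNs, by contrast, process tokens one at a time.",
    source="notes.txt",
)
for chunk in chunks:
    # Each chunk is a langchain Document whose source encodes the originating file and chunk id.
    print(chunk.metadata["source"], "->", chunk.page_content)
```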