├── .gitignore ├── README.md ├── data ├── monopoly.pdf └── ticket_to_ride.pdf ├── get_embedding_function.py ├── populate_database.py ├── query_data.py ├── requirements.txt └── test_rag.py /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .DS_Store 3 | backup 4 | chroma -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # rag-tutorial-v2 2 | -------------------------------------------------------------------------------- /data/monopoly.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pixegami/rag-tutorial-v2/5e71164a3ab0f78e734e0d79b8c82d4954197fc7/data/monopoly.pdf -------------------------------------------------------------------------------- /data/ticket_to_ride.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pixegami/rag-tutorial-v2/5e71164a3ab0f78e734e0d79b8c82d4954197fc7/data/ticket_to_ride.pdf -------------------------------------------------------------------------------- /get_embedding_function.py: -------------------------------------------------------------------------------- 1 | from langchain_community.embeddings.ollama import OllamaEmbeddings 2 | from langchain_community.embeddings.bedrock import BedrockEmbeddings 3 | 4 | 5 | def get_embedding_function(): 6 | embeddings = BedrockEmbeddings( 7 | credentials_profile_name="default", region_name="us-east-1" 8 | ) 9 | # Local alternative (no AWS credentials needed): embeddings = OllamaEmbeddings(model="nomic-embed-text") 10 | return embeddings 11 | -------------------------------------------------------------------------------- /populate_database.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from langchain.document_loaders.pdf import 
PyPDFDirectoryLoader 5 | from langchain_text_splitters import RecursiveCharacterTextSplitter 6 | from langchain.schema.document import Document 7 | from get_embedding_function import get_embedding_function 8 | from langchain.vectorstores.chroma import Chroma 9 | 10 | 11 | CHROMA_PATH = "chroma" 12 | DATA_PATH = "data" 13 | 14 | 15 | def main(): 16 | 17 | # Check if the database should be cleared (using the --reset flag). 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument("--reset", action="store_true", help="Reset the database.") 20 | args = parser.parse_args() 21 | if args.reset: 22 | print("✨ Clearing Database") 23 | clear_database() 24 | 25 | # Create (or update) the data store. 26 | documents = load_documents() 27 | chunks = split_documents(documents) 28 | add_to_chroma(chunks) 29 | 30 | 31 | def load_documents(): 32 | document_loader = PyPDFDirectoryLoader(DATA_PATH) 33 | return document_loader.load() 34 | 35 | 36 | def split_documents(documents: list[Document]): 37 | text_splitter = RecursiveCharacterTextSplitter( 38 | chunk_size=800, 39 | chunk_overlap=80, 40 | length_function=len, 41 | is_separator_regex=False, 42 | ) 43 | return text_splitter.split_documents(documents) 44 | 45 | 46 | def add_to_chroma(chunks: list[Document]): 47 | # Load the existing database. 48 | db = Chroma( 49 | persist_directory=CHROMA_PATH, embedding_function=get_embedding_function() 50 | ) 51 | 52 | # Calculate Page IDs. 53 | chunks_with_ids = calculate_chunk_ids(chunks) 54 | 55 | # Add or Update the documents. 56 | existing_items = db.get(include=[]) # IDs are always included by default 57 | existing_ids = set(existing_items["ids"]) 58 | print(f"Number of existing documents in DB: {len(existing_ids)}") 59 | 60 | # Only add documents that don't exist in the DB. 
61 | new_chunks = [] 62 | for chunk in chunks_with_ids: 63 | if chunk.metadata["id"] not in existing_ids: 64 | new_chunks.append(chunk) 65 | 66 | if len(new_chunks): 67 | print(f"👉 Adding new documents: {len(new_chunks)}") 68 | new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks] 69 | db.add_documents(new_chunks, ids=new_chunk_ids) 70 | db.persist() 71 | else: 72 | print("✅ No new documents to add") 73 | 74 | 75 | def calculate_chunk_ids(chunks): 76 | 77 | # This will create IDs like "data/monopoly.pdf:6:2" 78 | # Page Source : Page Number : Chunk Index 79 | 80 | last_page_id = None 81 | current_chunk_index = 0 82 | 83 | for chunk in chunks: 84 | source = chunk.metadata.get("source") 85 | page = chunk.metadata.get("page") 86 | current_page_id = f"{source}:{page}" 87 | 88 | # If the page ID is the same as the last one, increment the index. 89 | if current_page_id == last_page_id: 90 | current_chunk_index += 1 91 | else: 92 | current_chunk_index = 0 93 | 94 | # Calculate the chunk ID. 95 | chunk_id = f"{current_page_id}:{current_chunk_index}" 96 | last_page_id = current_page_id 97 | 98 | # Add it to the page meta-data. 
99 | chunk.metadata["id"] = chunk_id 100 | 101 | return chunks 102 | 103 | 104 | def clear_database(): 105 | if os.path.exists(CHROMA_PATH): 106 | shutil.rmtree(CHROMA_PATH) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /query_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from langchain.vectorstores.chroma import Chroma 3 | from langchain.prompts import ChatPromptTemplate 4 | from langchain_community.llms.ollama import Ollama 5 | 6 | from get_embedding_function import get_embedding_function 7 | 8 | CHROMA_PATH = "chroma" 9 | 10 | PROMPT_TEMPLATE = """ 11 | Answer the question based only on the following context: 12 | 13 | {context} 14 | 15 | --- 16 | 17 | Answer the question based on the above context: {question} 18 | """ 19 | 20 | 21 | def main(): 22 | # Create CLI: a single positional argument holding the question text. 23 | parser = argparse.ArgumentParser() 24 | parser.add_argument("query_text", type=str, help="The query text.") 25 | args = parser.parse_args() 26 | query_text = args.query_text 27 | query_rag(query_text) 28 | 29 | 30 | def query_rag(query_text: str): 31 | # Prepare the DB (expects the Chroma store created by populate_database.py). 32 | embedding_function = get_embedding_function() 33 | db = Chroma(persist_directory=CHROMA_PATH, embedding_function=embedding_function) 34 | 35 | # Search the DB for the chunks most similar to the query. 
36 | results = db.similarity_search_with_score(query_text, k=5) 37 | 38 | context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results]) 39 | prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE) 40 | prompt = prompt_template.format(context=context_text, question=query_text) 41 | # print(prompt)  # uncomment to inspect the fully rendered prompt 42 | 43 | model = Ollama(model="mistral") 44 | response_text = model.invoke(prompt) 45 | 46 | sources = [doc.metadata.get("id", None) for doc, _score in results] 47 | formatted_response = f"Response: {response_text}\nSources: {sources}" 48 | print(formatted_response) 49 | return response_text 50 | 51 | 52 | if __name__ == "__main__": 53 | main() 54 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pypdf 2 | langchain 3 | chromadb # Vector storage 4 | pytest 5 | boto3 6 | -------------------------------------------------------------------------------- /test_rag.py: -------------------------------------------------------------------------------- 1 | from query_data import query_rag 2 | from langchain_community.llms.ollama import Ollama 3 | 4 | EVAL_PROMPT = """ 5 | Expected Response: {expected_response} 6 | Actual Response: {actual_response} 7 | --- 8 | (Answer with 'true' or 'false') Does the actual response match the expected response? 9 | """ 10 | 11 | 12 | def test_monopoly_rules(): 13 | assert query_and_validate( 14 | question="How much total money does a player start with in Monopoly? (Answer with the number only)", 15 | expected_response="$1500", 16 | ) 17 | 18 | 19 | def test_ticket_to_ride_rules(): 20 | assert query_and_validate( 21 | question="How many points does the longest continuous train get in Ticket to Ride? 
(Answer with the number only)", 22 | expected_response="10 points", 23 | ) 24 | 25 | 26 | def query_and_validate(question: str, expected_response: str): 27 | response_text = query_rag(question) 28 | prompt = EVAL_PROMPT.format( 29 | expected_response=expected_response, actual_response=response_text 30 | ) 31 | 32 | model = Ollama(model="mistral") 33 | evaluation_results_str = model.invoke(prompt) 34 | evaluation_results_str_cleaned = evaluation_results_str.strip().lower() 35 | 36 | print(prompt) 37 | 38 | if "true" in evaluation_results_str_cleaned: 39 | # Print response in Green if it is correct. 40 | print("\033[92m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m") 41 | return True 42 | elif "false" in evaluation_results_str_cleaned: 43 | # Print response in Red if it is incorrect. 44 | print("\033[91m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m") 45 | return False 46 | else: 47 | raise ValueError( 48 | f"Invalid evaluation result. Cannot determine if 'true' or 'false'." 49 | ) 50 | --------------------------------------------------------------------------------