├── .env ├── .gitignore ├── README.md ├── app.py ├── data │   ├── alpha_society.pdf │   ├── beta_society.pdf │   └── gamma_society.pdf ├── embeddings │   ├── embeddings.py │   └── openai_embeddings.py ├── llm │   ├── llm.py │   └── llm_factory.py ├── populate_database.py ├── requirements.txt ├── retrieval │   └── rag_retriever.py ├── static │   ├── admin_settings.js │   ├── demo_img │   │   ├── rag_demo.mp4 │   │   ├── screenshot_1.jpg │   │   ├── screenshot_2.jpg │   │   ├── screenshot_3.jpg │   │   └── screenshot_4.jpg │   └── styles.css ├── templates │   ├── admin.html │   └── index.html └── test_rag.py /.env: -------------------------------------------------------------------------------- 1 | VECTOR_DB_OPENAI_PATH='chroma-openai' 2 | VECTOR_DB_OLLAMA_PATH='chroma-ollama' 3 | DATA_PATH='data' 4 | EMBEDDING_MODEL_NAME='openai' 5 | LLM_MODEL_TYPE='gpt' 6 | LLM_MODEL_NAME='gpt-3.5-turbo' 7 | NUM_RELEVANT_DOCS='3' 8 | OPENAI_API_KEY='YOUR_OPENAI_KEY_HERE' 9 | CLAUDE_API_KEY='YOUR_CLAUDE_KEY_HERE' 10 | 11 | # UNCOMMENT FOR LOCAL SETUP: 12 | 13 | #EMBEDDING_MODEL_NAME=ollama 14 | #LLM_MODEL_TYPE=ollama 15 | #LLM_MODEL_NAME=llama3:8b -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .DS_Store 3 | backup 4 | chroma-ollama 5 | chroma-openai -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | A simple local Retrieval-Augmented Generation (RAG) chatbot that answers questions using information drawn from your own PDF documents. 2 | 3 | (If you find this project useful, please consider leaving a star ⭐) 4 | 5 | ## What is Retrieval-Augmented Generation (RAG)? 6 |
7 | *(figure: RAG pipeline diagram)* 8 |
9 | Retrieval-Augmented Generation (RAG) is a technique that combines the strengths of information retrieval and natural language generation. In a RAG system, a retriever fetches relevant documents or text chunks from a database, and then a generator produces a response based on the retrieved context. 10 | 11 | 1. **Data Indexing** 12 | - Documents: This is the starting point where multiple documents are stored. 13 | - Vector DB: The documents are processed and indexed into a Vector Database. 14 | 15 | 2. **User Query** 16 | - A user query is input into the system, which interacts with the Vector Database. 17 | 18 | 3. **Data Retrieval & Generation** 19 | - Top-K Chunks: The Vector Database retrieves the top-K relevant chunks based on the user query. 20 | - LLM (Large Language Model): These chunks are then fed into a Large Language Model. 21 | - Response: The LLM generates a response based on the relevant chunks. 22 | 23 | ## 🏗️ Implementation Components 24 | For this project, I used the following components to build the RAG architecture: 25 | 1. **Chroma**: A vector database used to store and retrieve document embeddings efficiently. 26 | 2. **Flask**: Framework for rendering web pages and handling user interactions. 27 | 3. **Ollama**: Manages the local language model for generating responses. 28 | 4. **LangChain**: A framework for integrating language models and retrieval systems. 29 | 30 | ## 🛠️ Setup and Local Deployment 31 | 32 | 1. **Choose Your Setup**: 33 | - You have three options for setting up the LLM: 34 | 1. Local setup using Ollama. 35 | 2. Using the OpenAI API for GPT models. 36 | 3. Using the Anthropic API for Claude models. 37 | 38 | ### Option 1: Local Setup with Ollama 39 | 40 | - **Download and install Ollama on your PC**: 41 | - Visit [Ollama's official website](https://ollama.com/download) to download and install Ollama. Ensure you have sufficient hardware resources to run the local language model. 42 | - Pull an LLM of your choice: 43 | ```sh 44 | ollama pull <model_name> # e.g. ollama pull llama3:8b 45 | ``` 46 | ### Option 2: Use OpenAI API for GPT Models 47 | - **Set up OpenAI API**: You can sign up and get your API key from [OpenAI's website](https://openai.com/api/). 48 | 49 | ### Option 3: Use Anthropic API for Claude Models 50 | - **Set up Anthropic API**: You can sign up and get your API key from [Anthropic's website](https://www.anthropic.com/api). 51 | 52 | ## Common Steps 53 | 54 | 2. **Clone the repository and navigate to the project directory**: 55 | ```sh 56 | git clone https://github.com/enricollen/rag-conversational-agent.git 57 | cd rag-conversational-agent 58 | ``` 59 | 60 | 3. **Create a virtual environment**: 61 | ```sh 62 | python -m venv venv 63 | source venv/bin/activate # On Windows, use `venv\Scripts\activate` 64 | ``` 65 | 66 | 4. **Install the required libraries**: 67 | ```sh 68 | pip install -r requirements.txt 69 | ``` 70 | 71 | 5. **Insert your own PDFs into the /data folder** 72 | 73 | 6. **Run the populate_database script once to index the PDF files into the vector DB:** 74 | ```sh 75 | python populate_database.py 76 | ``` 77 | 78 | 7. **Run the application:** 79 | ```sh 80 | python app.py 81 | ``` 82 | 83 | 8. **Navigate to `http://localhost:5000/`** 84 | 85 | 9. **If needed, click on the ⚙️ icon to access the admin panel and adjust app parameters** 86 |
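Optionally, you can sanity-check retrieval from a Python shell before using the UI. This is a minimal sketch, not a file in the repo: it assumes you indexed with the OpenAI embeddings (step 6) and that your `.env` holds a valid key; the question string is only an example.

```python
import os
from dotenv import load_dotenv
from retrieval.rag_retriever import RAGRetriever

load_dotenv()
# build a retriever against the OpenAI-embeddings Chroma store
retriever = RAGRetriever(
    vector_db_path=os.getenv("VECTOR_DB_OPENAI_PATH"),
    embedding_model_name="openai",
    api_key=os.getenv("OPENAI_API_KEY"),
)
# print the id and similarity score of the top-3 chunks for a sample question
for doc, score in retriever.query("What does the alpha society do?", k=3):
    print(round(score, 3), doc.metadata.get("id"))
```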
87 | 10. **Perform a query** 88 | 89 | ## 🚀 Future Improvements 90 | Here are some ideas for future improvements: 91 | - [x] Add OpenAI GPT model compatibility (3.5 Turbo, 4, 4o) 92 | - [x] Add Anthropic Claude model compatibility (Claude 3.5 Sonnet, Claude 3 Sonnet, Claude 3 Opus, Claude 3 Haiku) 93 | - [x] Add unit testing to validate the responses given by the LLM 94 | - [x] Add an admin interface in the web UI to interactively choose parameters like the LLM, embedding model, etc. 95 | - [ ] Add LangChain Tools compatibility, allowing users to define custom Python functions that can be utilized by the LLMs. 96 | - [ ] Add web scraping in case none of the personal documents contain relevant info w.r.t. the query 97 | 98 | ## 📹 Demo Video 99 | Watch the demo video below to see the RAG Chatbot in action: 100 | 101 | [![YT Video](https://img.youtube.com/vi/_JVt5gwwZq0/0.jpg)](https://www.youtube.com/watch?v=_JVt5gwwZq0) 102 | 103 | The demo was run on my PC with the following specifications: 104 | - **Processor**: Intel(R) Core(TM) i7-14700K 3.40 GHz 105 | - **RAM**: 32.0 GB 106 | - **GPU**: NVIDIA GeForce RTX 3090 FE 24 GB 107 |
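Before diving into the sources, here is the whole pipeline condensed into a few lines — a sketch only, wiring together the classes defined in the modules below (the model names and the placeholder API key are illustrative):

```python
from llm.llm_factory import LLMFactory
from retrieval.rag_retriever import RAGRetriever

question = "your question here"  # illustrative

# retrieval: embed the query and pull the top-k chunks from Chroma
retriever = RAGRetriever(vector_db_path="chroma-openai", embedding_model_name="openai", api_key="sk-...")
results = retriever.query(question, k=3)
context, sources = retriever.format_results(results)

# generation: the retrieved context is stuffed into the prompt template
llm = LLMFactory.create_llm(model_type="gpt", model_name="gpt-3.5-turbo", api_key="sk-...")
answer = llm.generate_response(context=context, question=question)
print(answer, sources)
```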
""" 35 | global retriever, llm_model 36 | vector_db_path = get_vector_db_path(EMBEDDING_MODEL_NAME) 37 | 38 | # Select the appropriate API key based on the embedding model 39 | if EMBEDDING_MODEL_NAME == "openai": 40 | api_key = OPENAI_API_KEY 41 | else: 42 | api_key = CLAUDE_API_KEY 43 | 44 | retriever = RAGRetriever(vector_db_path=vector_db_path, embedding_model_name=EMBEDDING_MODEL_NAME, api_key=api_key) 45 | llm_model = LLMFactory.create_llm(model_type=LLM_MODEL_TYPE, model_name=LLM_MODEL_NAME, api_key=api_key) 46 | print(f"Instantiating model type: {LLM_MODEL_TYPE} | model name: {LLM_MODEL_NAME} | embedding model: {EMBEDDING_MODEL_NAME}") 47 | 48 | initialize_components() 49 | 50 | @app.route('/') 51 | def index(): 52 | return render_template('index.html') 53 | 54 | @app.route('/admin') 55 | def admin(): 56 | return render_template('admin.html', 57 | llm_model_name=LLM_MODEL_NAME, 58 | llm_model_type=LLM_MODEL_TYPE, 59 | embedding_model_name=EMBEDDING_MODEL_NAME, 60 | num_relevant_docs=NUM_RELEVANT_DOCS, 61 | openai_api_key=OPENAI_API_KEY) 62 | 63 | @app.route('/update_settings', methods=['POST']) 64 | def update_settings(): 65 | global LLM_MODEL_NAME, LLM_MODEL_TYPE, EMBEDDING_MODEL_NAME, NUM_RELEVANT_DOCS, OPENAI_API_KEY 66 | LLM_MODEL_NAME = request.form['llm_model_name'] 67 | LLM_MODEL_TYPE = request.form['llm_model_type'] 68 | EMBEDDING_MODEL_NAME = request.form['embedding_model_name'] 69 | NUM_RELEVANT_DOCS = int(request.form['num_relevant_docs']) 70 | OPENAI_API_KEY = request.form['openai_api_key'] 71 | 72 | # Update the .env file 73 | set_key(ENV_PATH, 'LLM_MODEL_NAME', LLM_MODEL_NAME) 74 | set_key(ENV_PATH, 'LLM_MODEL_TYPE', LLM_MODEL_TYPE) 75 | set_key(ENV_PATH, 'EMBEDDING_MODEL_NAME', EMBEDDING_MODEL_NAME) 76 | set_key(ENV_PATH, 'NUM_RELEVANT_DOCS', str(NUM_RELEVANT_DOCS)) 77 | set_key(ENV_PATH, 'OPENAI_API_KEY', OPENAI_API_KEY) 78 | 79 | # Reinitialize the components (llm and retriever objects) 80 | initialize_components() 81 | print(f"Updating model type: {LLM_MODEL_TYPE} | model name: {LLM_MODEL_NAME} | embedding model: {EMBEDDING_MODEL_NAME}") 82 | return redirect(url_for('admin')) 83 | 84 | @app.route('/query', methods=['POST']) 85 | def query(): 86 | query_text = request.json['query_text'] 87 | # Retrieve and format results 88 | results = retriever.query(query_text, k=NUM_RELEVANT_DOCS) 89 | enhanced_context_text, sources = retriever.format_results(results) 90 | # Generate response from LLM 91 | llm_response = llm_model.generate_response(context=enhanced_context_text, question=query_text) 92 | sources_html = "
".join(sources) 93 | response_text = f"{llm_response}

Sources:
{sources_html}

Response given by: {LLM_MODEL_NAME}" 94 | return jsonify(response=response_text) 95 | 96 | if __name__ == "__main__": 97 | app.run(debug=True) -------------------------------------------------------------------------------- /data/alpha_society.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/data/alpha_society.pdf -------------------------------------------------------------------------------- /data/beta_society.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/data/beta_society.pdf -------------------------------------------------------------------------------- /data/gamma_society.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/data/gamma_society.pdf -------------------------------------------------------------------------------- /embeddings/embeddings.py: -------------------------------------------------------------------------------- 1 | from embeddings.openai_embeddings import OpenAIEmbeddings 2 | from langchain_community.embeddings.ollama import OllamaEmbeddings 3 | from langchain_community.embeddings.bedrock import BedrockEmbeddings 4 | 5 | class Embeddings: 6 | def __init__(self, model_name: str, api_key: str = None): 7 | self.model_name = model_name 8 | self.api_key = api_key 9 | 10 | def get_embedding_function(self): 11 | if self.model_name == "ollama": 12 | return OllamaEmbeddings(model="mxbai-embed-large") 13 | elif self.model_name == "openai": 14 | if not self.api_key: 15 | raise ValueError("OpenAI API key must be provided for OpenAI embeddings") 16 | return OpenAIEmbeddings(api_key=self.api_key) 17 | elif self.model_name == "bedrock": 18 | return BedrockEmbeddings(credentials_profile_name="default", region_name="us-east-1") 19 | else: 20 | raise ValueError(f"Unsupported embedding model: {self.model_name}") -------------------------------------------------------------------------------- /embeddings/openai_embeddings.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | class OpenAIEmbeddings: 4 | """ 5 | class that implements two methods to be called from Chroma 6 | """ 7 | def __init__(self, api_key: str): 8 | self.client = OpenAI(api_key=api_key) 9 | 10 | def embed_documents(self, texts: list[str]): 11 | embeddings = [] 12 | for text in texts: 13 | response = self.client.embeddings.create(input=text, model="text-embedding-3-small") 14 | embeddings.append(response.data[0].embedding) 15 | return embeddings 16 | 17 | def embed_query(self, text: str): 18 | response = self.client.embeddings.create(input=text, model="text-embedding-3-small") 19 | return response.data[0].embedding -------------------------------------------------------------------------------- /llm/llm.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | from langchain_community.llms.ollama import Ollama 3 | from openai import OpenAI 4 | from langchain.prompts import ChatPromptTemplate 5 | import anthropic 6 | 7 | PROMPT_TEMPLATE = """ 8 | Basing only on the following context: 9 | 10 | {context} 11 | 12 | --- 13 | 14 | Answer the 
15 | Do not start the answer by saying that you are relying on the provided context; go straight to the response. 16 | """ 17 | 18 | class LLM(ABC): 19 | def __init__(self, model_name: str): 20 | self.model_name = model_name 21 | 22 | @abstractmethod 23 | def invoke(self, prompt: str) -> str: 24 | pass 25 | 26 | def generate_response(self, context: str, question: str) -> str: 27 | prompt_template = ChatPromptTemplate.from_template(PROMPT_TEMPLATE) 28 | prompt = prompt_template.format(context=context, question=question) 29 | response_text = self.invoke(prompt) 30 | return response_text 31 | 32 | class OllamaModel(LLM): 33 | def __init__(self, model_name: str): 34 | super().__init__(model_name) 35 | self.model = Ollama(model=model_name) 36 | 37 | def invoke(self, prompt: str) -> str: 38 | return self.model.invoke(prompt) 39 | 40 | class GPTModel(LLM): 41 | def __init__(self, model_name: str, api_key: str): 42 | super().__init__(model_name) 43 | self.client = OpenAI(api_key=api_key) 44 | 45 | def invoke(self, prompt: str) -> str: 46 | messages = [ 47 | #{"role": "system", "content": "You are a helpful assistant."}, 48 | {"role": "user", "content": prompt} 49 | ] 50 | response = self.client.chat.completions.create( 51 | model=self.model_name, 52 | messages=messages, 53 | max_tokens=150, 54 | n=1, 55 | stop=None, 56 | temperature=0.7, 57 | ) 58 | return response.choices[0].message.content.strip() 59 | 60 | class AnthropicModel(LLM): 61 | def __init__(self, model_name: str, api_key: str): 62 | super().__init__(model_name) 63 | self.client = anthropic.Anthropic(api_key=api_key) 64 | 65 | def invoke(self, prompt: str) -> str: 66 | messages = [ 67 | { 68 | "role": "user", 69 | "content": [ 70 | { 71 | "type": "text", 72 | "text": prompt 73 | } 74 | ] 75 | } 76 | ] 77 | response = self.client.messages.create( 78 | model=self.model_name, 79 | max_tokens=1000, 80 | temperature=0.7, 81 | messages=messages 82 | ) 83 | # Extract the plain text from the response content 84 | text_blocks = response.content 85 | plain_text = "\n".join(block.text for block in text_blocks if block.type == 'text') 86 | return plain_text -------------------------------------------------------------------------------- /llm/llm_factory.py: -------------------------------------------------------------------------------- 1 | from llm.llm import LLM, GPTModel, OllamaModel, AnthropicModel 2 | 3 | class LLMFactory: 4 | @staticmethod 5 | def create_llm(model_type: str, model_name: str, api_key: str = None) -> LLM: 6 | if model_type == 'ollama': 7 | return OllamaModel(model_name) 8 | elif model_type == 'gpt': 9 | return GPTModel(model_name, api_key) 10 | elif model_type == 'claude': 11 | return AnthropicModel(model_name, api_key) 12 | else: 13 | raise ValueError(f"Unsupported model type: {model_type}") -------------------------------------------------------------------------------- /populate_database.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import shutil 4 | from embeddings.embeddings import Embeddings 5 | from langchain_community.document_loaders import PyPDFDirectoryLoader 6 | from langchain_text_splitters import RecursiveCharacterTextSplitter 7 | from langchain.schema import Document 8 | from langchain_chroma import Chroma 9 | from dotenv import load_dotenv 10 | 11 | load_dotenv() 12 | 13 | OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') 14 | DATA_PATH = os.getenv('DATA_PATH') 15 | VECTOR_DB_OPENAI_PATH = os.getenv('VECTOR_DB_OPENAI_PATH')
16 | VECTOR_DB_OLLAMA_PATH = os.getenv('VECTOR_DB_OLLAMA_PATH') 17 | 18 | def main(): 19 | # check whether the database should be cleared or not (using the --reset flag) 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument("--reset", nargs="?", const="both", choices=["ollama", "openai", "both"], help="Reset the database.") 22 | parser.add_argument("--embedding-model", type=str, default="openai", help="The embedding model to use (ollama or openai).") 23 | args = parser.parse_args() 24 | 25 | if args.reset: 26 | reset_databases(args.reset) 27 | return 28 | 29 | # choose the embedding model 30 | embeddings = Embeddings(model_name=args.embedding_model, api_key=OPENAI_API_KEY) 31 | embedding_function = embeddings.get_embedding_function() 32 | 33 | # determine the correct path for the database based on the embedding model 34 | if args.embedding_model == "openai": 35 | db_path = VECTOR_DB_OPENAI_PATH 36 | elif args.embedding_model == "ollama": 37 | db_path = VECTOR_DB_OLLAMA_PATH 38 | else: 39 | raise ValueError("Unsupported embedding model specified.") 40 | 41 | # load the existing database 42 | db = Chroma( 43 | persist_directory=db_path, embedding_function=embedding_function 44 | ) 45 | 46 | # create (or update) the data store 47 | documents = load_documents() 48 | chunks = split_documents(documents) 49 | add_to_chroma(chunks, db) 50 | 51 | def reset_databases(reset_choice): 52 | if reset_choice in ["openai", "both"]: 53 | if ask_to_clear_database("openai"): 54 | print("✨ Rebuilding OpenAI Database") 55 | clear_database("openai") 56 | rebuild_database("openai") 57 | 58 | if reset_choice in ["ollama", "both"]: 59 | if ask_to_clear_database("ollama"): 60 | print("✨ Rebuilding Ollama Database") 61 | clear_database("ollama") 62 | rebuild_database("ollama") 63 | 64 | def ask_to_clear_database(embedding_model): 65 | response = input(f"Do you want to override the existing {embedding_model} database? 
(yes/no): ").strip().lower() 66 | return response == 'yes' 67 | 68 | def load_documents(): 69 | document_loader = PyPDFDirectoryLoader(DATA_PATH) 70 | return document_loader.load() 71 | 72 | def split_documents(documents: list[Document]): 73 | text_splitter = RecursiveCharacterTextSplitter( 74 | chunk_size=800, 75 | chunk_overlap=80, 76 | length_function=len, 77 | is_separator_regex=False, 78 | ) 79 | return text_splitter.split_documents(documents) 80 | 81 | def add_to_chroma(chunks: list[Document], db): 82 | # calculate Page IDs 83 | chunks_with_ids = calculate_chunk_ids(chunks) 84 | 85 | # Add or Update the documents 86 | existing_items = db.get(include=[]) # IDs are always included by default 87 | existing_ids = set(existing_items["ids"]) 88 | print(f"Number of existing documents in DB: {len(existing_ids)}") 89 | 90 | # only add documents that don't exist in the DB 91 | new_chunks = [] 92 | for chunk in chunks_with_ids: 93 | if chunk.metadata["id"] not in existing_ids: 94 | new_chunks.append(chunk) 95 | 96 | if len(new_chunks): 97 | print(f"➕ Adding new documents: {len(new_chunks)}") 98 | new_chunk_ids = [chunk.metadata["id"] for chunk in new_chunks] 99 | db.add_documents(new_chunks, ids=new_chunk_ids) 100 | else: 101 | print("✅ No new documents to add") 102 | 103 | def calculate_chunk_ids(chunks): 104 | # create IDs like "data/alpha_society.pdf:6:2" 105 | # Page Source : Page Number : Chunk Index 106 | last_page_id = None 107 | current_chunk_index = 0 108 | 109 | for chunk in chunks: 110 | source = chunk.metadata.get("source") 111 | page = chunk.metadata.get("page") 112 | current_page_id = f"{source}:{page}" 113 | 114 | # if the page ID is the same as the last one, increment the index 115 | if current_page_id == last_page_id: 116 | current_chunk_index += 1 117 | else: 118 | current_chunk_index = 0 119 | 120 | # calculate the unique chunk ID 121 | chunk_id = f"{current_page_id}:{current_chunk_index}" 122 | last_page_id = current_page_id 123 | 124 | # add it to the page meta-data 125 | chunk.metadata["id"] = chunk_id 126 | 127 | return chunks 128 | 129 | def clear_database(embedding_model): 130 | if embedding_model == "openai": 131 | db_path = VECTOR_DB_OPENAI_PATH 132 | elif embedding_model == "ollama": 133 | db_path = VECTOR_DB_OLLAMA_PATH 134 | else: 135 | raise ValueError("Unsupported embedding model specified.") 136 | 137 | if os.path.exists(db_path): 138 | shutil.rmtree(db_path) 139 | 140 | def rebuild_database(embedding_model): 141 | if embedding_model == "openai": 142 | embeddings = Embeddings(model_name="openai", api_key=OPENAI_API_KEY) 143 | db_path = VECTOR_DB_OPENAI_PATH 144 | elif embedding_model == "ollama": 145 | embeddings = Embeddings(model_name="ollama", api_key=OPENAI_API_KEY) 146 | db_path = VECTOR_DB_OLLAMA_PATH 147 | else: 148 | raise ValueError("Unsupported embedding model specified.") 149 | 150 | embedding_function = embeddings.get_embedding_function() 151 | 152 | # load the existing database 153 | db = Chroma( 154 | persist_directory=db_path, embedding_function=embedding_function 155 | ) 156 | 157 | # create (or update) the data store 158 | documents = load_documents() 159 | chunks = split_documents(documents) 160 | add_to_chroma(chunks, db) 161 | 162 | if __name__ == "__main__": 163 | main() -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pypdf 2 | langchain 3 | chromadb # vector db 4 | pytest 5 | boto3 6 | langchain_community 
7 | anthropic 8 | langchain-chroma 9 | flask 10 | openai 11 | python-dotenv -------------------------------------------------------------------------------- /retrieval/rag_retriever.py: -------------------------------------------------------------------------------- 1 | from langchain_chroma import Chroma 2 | from langchain.schema import Document 3 | from embeddings.embeddings import Embeddings 4 | 5 | class RAGRetriever: 6 | def __init__(self, vector_db_path: str, embedding_model_name: str, api_key: str): 7 | self.vector_db_path = vector_db_path 8 | embeddings = Embeddings(model_name=embedding_model_name, api_key=api_key) 9 | self.embedding_function = embeddings.get_embedding_function() 10 | self.db = Chroma(persist_directory=self.vector_db_path, embedding_function=self.embedding_function) 11 | 12 | def query(self, query_text: str, k: int = 4): 13 | # compute similarity between embeddings of query and of pdf text chunks 14 | results = self.db.similarity_search_with_score(query_text, k=k) 15 | return results 16 | 17 | def format_results(self, results: list[tuple[Document, float]]): 18 | enhanced_context_text = "\n\n---\n\n".join([doc.page_content for doc, _score in results]) 19 | sources = set(self.format_source(doc.metadata) for doc, _score in results) # set to ensure uniqueness 20 | return enhanced_context_text, list(sources) 21 | 22 | def format_source(self, metadata: dict): 23 | source = metadata.get("source", "unknown") 24 | page = metadata.get("page", "unknown") 25 | filename = source.replace("\\", "/").split("/")[-1] # extract filename (works for both Windows and POSIX paths) 26 | return f"{filename} page {page}" -------------------------------------------------------------------------------- /static/admin_settings.js: -------------------------------------------------------------------------------- 1 | // admin_settings.js 2 | 3 | function updateApiKeyField() { 4 | const llmModelType = document.getElementById('llm_model_type').value; 5 | const openaiApiKeyField = document.getElementById('openai_api_key_field'); 6 | const claudeApiKeyField = document.getElementById('claude_api_key_field'); 7 | const embeddingModelName = document.getElementById('embedding_model_name'); 8 | 9 | if (llmModelType === 'gpt') { 10 | openaiApiKeyField.style.display = 'block'; 11 | claudeApiKeyField.style.display = 'none'; 12 | updateEmbeddingModelNames('openai'); 13 | } else if (llmModelType === 'claude') { 14 | openaiApiKeyField.style.display = 'none'; 15 | claudeApiKeyField.style.display = 'block'; 16 | updateEmbeddingModelNames('ollama'); 17 | } else if (llmModelType === 'ollama') { 18 | openaiApiKeyField.style.display = 'none'; 19 | claudeApiKeyField.style.display = 'none'; 20 | updateEmbeddingModelNames('ollama'); 21 | } 22 | 23 | updateLlmModelNames(); 24 | } 25 | 26 | function updateLlmModelNames() { 27 | const llmModelType = document.getElementById('llm_model_type').value; 28 | const llmModelName = document.getElementById('llm_model_name'); 29 | const otherLlmModelName = document.getElementById('llm_model_name_other'); 30 | 31 | // Clear current options 32 | llmModelName.innerHTML = ''; 33 | 34 | let options = []; 35 | 36 | if (llmModelType === 'gpt') { 37 | options = [ 38 | { text: 'GPT 3.5', value: 'gpt-3.5-turbo' }, 39 | { text: 'GPT-4o', value: 'gpt-4o' }, 40 | { text: 'GPT-4', value: 'gpt-4' } 41 | ]; 42 | } else if (llmModelType === 'ollama') { 43 | options = [ 44 | { text: 'Llama3', value: 'llama3:8b' }, 45 | { text: 'Gemma 2', value: 'gemma2' }, 46 | { text: 'Mistral', value: 'mistral:7b' }, 47 | { text: 'Other', value: 'other' } 48 | ]; 49 | } else if (llmModelType === 'claude') { 50 | 
options = [ 51 | { text: 'Claude 3.5 Sonnet', value: 'claude-3-5-sonnet-20240620' }, 52 | { text: 'Claude 3 Opus', value: 'claude-3-opus-20240229' }, 53 | { text: 'Claude 3 Sonnet', value: 'claude-3-sonnet-20240229' }, 54 | { text: 'Claude 3 Haiku', value: 'claude-3-haiku-20240307' } 55 | ]; 56 | } 57 | 58 | options.forEach(option => { 59 | const opt = document.createElement('option'); 60 | opt.value = option.value; 61 | opt.textContent = option.text; 62 | llmModelName.appendChild(opt); 63 | }); 64 | 65 | // Show or hide the "Other" input field based on the initial value 66 | if (llmModelName.value === 'other') { 67 | otherLlmModelName.style.display = 'block'; 68 | } else { 69 | otherLlmModelName.style.display = 'none'; 70 | } 71 | 72 | // Add event listener to show/hide "Other" input field 73 | llmModelName.addEventListener('change', function() { 74 | if (llmModelName.value === 'other') { 75 | otherLlmModelName.style.display = 'block'; 76 | } else { 77 | otherLlmModelName.style.display = 'none'; 78 | } 79 | }); 80 | } 81 | 82 | function updateEmbeddingModelNames(selectedValue) { 83 | const embeddingModelName = document.getElementById('embedding_model_name'); 84 | 85 | // Clear current options 86 | embeddingModelName.innerHTML = ''; 87 | 88 | const option = document.createElement('option'); 89 | option.value = selectedValue; 90 | option.textContent = selectedValue.charAt(0).toUpperCase() + selectedValue.slice(1); 91 | embeddingModelName.appendChild(option); 92 | 93 | // Automatically select the option 94 | embeddingModelName.value = selectedValue; 95 | } 96 | 97 | function handleFormSubmission(event) { 98 | const llmModelName = document.getElementById('llm_model_name'); 99 | const llmModelNameOther = document.getElementById('llm_model_name_other'); 100 | const embeddingModelName = document.getElementById('embedding_model_name'); 101 | 102 | // If "Other" is selected, set the value to the text input value 103 | if (llmModelName.value === 'other') { 104 | llmModelName.value = llmModelNameOther.value; 105 | } 106 | 107 | // Enable embedding_model_name so its value can be submitted 108 | embeddingModelName.disabled = false; 109 | } 110 | 111 | window.onload = function() { 112 | updateApiKeyField(); 113 | updateLlmModelNames(); 114 | updateEmbeddingModelNames('openai'); // Default to OpenAI on load 115 | }; 116 | -------------------------------------------------------------------------------- /static/demo_img/rag_demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/static/demo_img/rag_demo.mp4 -------------------------------------------------------------------------------- /static/demo_img/screenshot_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/static/demo_img/screenshot_1.jpg -------------------------------------------------------------------------------- /static/demo_img/screenshot_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/static/demo_img/screenshot_2.jpg -------------------------------------------------------------------------------- /static/demo_img/screenshot_3.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/static/demo_img/screenshot_3.jpg -------------------------------------------------------------------------------- /static/demo_img/screenshot_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/enricollen/rag-conversational-agent/66e4bce9b1d5b94538f6507afefea5d60c05b5c4/static/demo_img/screenshot_4.jpg -------------------------------------------------------------------------------- /static/styles.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: Arial, sans-serif; 3 | background-color: #f4f4f9; 4 | color: #333; 5 | margin: 0; 6 | padding: 20px; 7 | display: flex; 8 | flex-direction: column; 9 | align-items: center; 10 | justify-content: center; 11 | height: 100vh; 12 | } 13 | 14 | .header { 15 | display: flex; 16 | justify-content: space-between; 17 | width: 100%; 18 | max-width: 600px; 19 | align-items: center; 20 | margin-bottom: 20px; 21 | } 22 | 23 | .settings-link { 24 | font-size: 24px; 25 | text-decoration: none; 26 | color: #333; 27 | margin-left: auto; 28 | } 29 | 30 | .settings-link:hover { 31 | color: #007BFF; 32 | } 33 | 34 | #chat-container { 35 | border: 1px solid #ddd; 36 | border-radius: 8px; 37 | background: #fff; 38 | max-width: 600px; 39 | width: 100%; 40 | height: 500px; /* Fixed height */ 41 | box-shadow: 0 0 10px rgba(0, 0, 0, 0.1); 42 | display: flex; 43 | flex-direction: column; 44 | overflow: hidden; 45 | } 46 | 47 | #chat { 48 | padding: 20px; 49 | flex-grow: 1; 50 | overflow-y: auto; 51 | } 52 | 53 | .message { 54 | margin: 10px 0; 55 | } 56 | 57 | .message.user { 58 | text-align: right; 59 | } 60 | 61 | .message.bot { 62 | text-align: left; 63 | color: #007BFF; 64 | } 65 | 66 | .input-container { 67 | display: flex; 68 | border-top: 1px solid #ddd; 69 | padding: 10px; 70 | background: #f9f9f9; 71 | } 72 | 73 | #query { 74 | flex-grow: 1; 75 | padding: 10px; 76 | border: 1px solid #ddd; 77 | border-radius: 4px; 78 | font-size: 16px; 79 | } 80 | 81 | button { 82 | padding: 10px 20px; 83 | margin-left: 10px; 84 | border: none; 85 | border-radius: 4px; 86 | background-color: #007BFF; 87 | color: #fff; 88 | font-size: 16px; 89 | cursor: pointer; 90 | } 91 | 92 | button:hover { 93 | background-color: #0056b3; 94 | } 95 | 96 | h1 { 97 | color: #333; 98 | } 99 | 100 | label { 101 | font-weight: bold; 102 | } 103 | 104 | #admin-container { 105 | margin-bottom: 20px; 106 | width: 100%; 107 | box-sizing: border-box; 108 | } 109 | 110 | #admin-container label, 111 | #admin-container input { 112 | display: block; 113 | margin-bottom: 10px; 114 | width: 100%; 115 | max-width: 600px; 116 | } 117 | 118 | button:hover { 119 | background-color: #45a049; 120 | } 121 | 122 | a { 123 | display: block; 124 | margin-top: 10px; 125 | color: #4CAF50; 126 | text-decoration: none; 127 | } 128 | 129 | a:hover { 130 | text-decoration: underline; 131 | } 132 | 133 | .admin-link { 134 | display: block; 135 | margin-top: 20px; 136 | color: #4CAF50; 137 | text-decoration: none; 138 | font-weight: bold; 139 | text-align: center; 140 | } 141 | 142 | .admin-link:hover { 143 | text-decoration: underline; 144 | } -------------------------------------------------------------------------------- /templates/admin.html: 
-------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html lang="en"> 3 | <head> 4 | <meta charset="UTF-8"> 5 | <meta name="viewport" content="width=device-width, initial-scale=1.0"> 6 | <title>⚙️ Admin Settings</title> 7 | <link rel="stylesheet" href="{{ url_for('static', filename='styles.css') }}"> 8 | <script src="{{ url_for('static', filename='admin_settings.js') }}"></script> 9 | </head> 10 | <body> 11 | <!-- form structure inferred from app.py and admin_settings.js; element ids must match those scripts --> 12 | <h1>⚙️ Admin Settings</h1> 13 | <div id="admin-container"> 14 | <form action="/update_settings" method="post" onsubmit="handleFormSubmission(event)"> 15 | <label for="llm_model_type">LLM Type:</label> 16 | <select id="llm_model_type" name="llm_model_type" onchange="updateApiKeyField()"> 17 | <option value="gpt" {% if llm_model_type == 'gpt' %}selected{% endif %}>GPT (OpenAI)</option> 18 | <option value="claude" {% if llm_model_type == 'claude' %}selected{% endif %}>Claude (Anthropic)</option> 19 | <option value="ollama" {% if llm_model_type == 'ollama' %}selected{% endif %}>Ollama (local)</option> 20 | </select> 21 | <label for="llm_model_name">LLM Model:</label> 22 | <select id="llm_model_name" name="llm_model_name"></select> 23 | <input type="text" id="llm_model_name_other" placeholder="Custom model name" style="display: none;"> 24 | <label for="embedding_model_name">Embedding Model:</label> 25 | <select id="embedding_model_name" name="embedding_model_name" disabled></select> 26 | <label for="num_relevant_docs">Number of Relevant Documents:</label> 27 | <input type="number" id="num_relevant_docs" name="num_relevant_docs" value="{{ num_relevant_docs }}"> 28 | <div id="openai_api_key_field"> 29 | <label for="openai_api_key">OpenAI API Key:</label> 30 | <input type="text" id="openai_api_key" name="openai_api_key" value="{{ openai_api_key }}"> 31 | </div> 32 | <div id="claude_api_key_field" style="display: none;"> 33 | <label for="claude_api_key">Claude API Key:</label> 34 | <input type="text" id="claude_api_key" name="claude_api_key"> 35 | </div> 36 | <button type="submit">Save Settings</button> 37 | </form> 38 | </div> 39 | <a href="/" class="admin-link">🔙 Go to Chat</a> 40 | </body> 41 | </html> -------------------------------------------------------------------------------- /templates/index.html: -------------------------------------------------------------------------------- 1 | <!DOCTYPE html> 2 | <html lang="en"> 3 | <head> 4 | <meta charset="UTF-8"> 5 | <meta name="viewport" content="width=device-width, initial-scale=1.0"> 6 | <title>🤖 Chat with Documents</title> 7 | <link rel="stylesheet" href="{{ url_for('static', filename='styles.css') }}"> 8 | </head> 9 | <body> 10 | <!-- page structure inferred from styles.css class names and the /query route in app.py --> 11 | <div class="header"> 12 | <h1>🤖 Chat with Documents</h1> 13 | <a href="/admin" class="settings-link">⚙️</a> 14 | </div> 15 | <div id="chat-container"> 16 | <div id="chat"></div> 17 | <div class="input-container"> 18 | <input type="text" id="query" placeholder="Ask a question about your documents..."> 19 | <button onclick="sendQuery()">Send</button> 20 | </div> 21 | </div> 22 | <script> 23 | // posts the query as JSON to /query and renders the HTML-formatted response 24 | function sendQuery() { 25 | const input = document.getElementById('query'); 26 | const queryText = input.value.trim(); 27 | if (!queryText) return; 28 | appendMessage('user', queryText); 29 | input.value = ''; 30 | fetch('/query', { 31 | method: 'POST', 32 | headers: { 'Content-Type': 'application/json' }, 33 | body: JSON.stringify({ query_text: queryText }) 34 | }) 35 | .then(response => response.json()) 36 | .then(data => appendMessage('bot', data.response)); 37 | } 38 | function appendMessage(sender, html) { 39 | const chat = document.getElementById('chat'); 40 | const message = document.createElement('div'); 41 | message.className = 'message ' + sender; 42 | message.innerHTML = html; 43 | chat.appendChild(message); 44 | chat.scrollTop = chat.scrollHeight; 45 | } 46 | </script> 47 | </body> 48 | </html> -------------------------------------------------------------------------------- /test_rag.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dotenv import load_dotenv 3 | from llm.llm import GPTModel, OllamaModel 4 | from llm.llm_factory import LLMFactory 5 | from retrieval.rag_retriever import RAGRetriever 6 | 7 | load_dotenv() 8 | 9 | VECTOR_DB_OPENAI_PATH = os.getenv('VECTOR_DB_OPENAI_PATH') 10 | VECTOR_DB_OLLAMA_PATH = os.getenv('VECTOR_DB_OLLAMA_PATH') 11 | LLM_MODEL_NAME = os.getenv('LLM_MODEL_NAME') # 'gpt-3.5-turbo', 'gpt-4o' or a local LLM like 'llama3:8b' 12 | LLM_MODEL_TYPE = os.getenv('LLM_MODEL_TYPE') # 'ollama', 'gpt', 'claude' 13 | EMBEDDING_MODEL_NAME = os.getenv('EMBEDDING_MODEL_NAME') # 'ollama' or 'openai' 14 | NUM_RELEVANT_DOCS = int(os.getenv('NUM_RELEVANT_DOCS')) 15 | OPENAI_API_KEY = os.getenv('OPENAI_API_KEY') 16 | CLAUDE_API_KEY = os.getenv('CLAUDE_API_KEY') 17 | 18 | EVAL_PROMPT = """ 19 | Expected Response: {expected_response} 20 | Actual Response: {actual_response} 21 | --- 22 | (Answer with 'true' or 'false') Does the actual response match the expected response? 23 | """ 24 | 25 | def get_vector_db_path(embedding_model_name): 26 | if embedding_model_name == "openai": 27 | return VECTOR_DB_OPENAI_PATH 28 | elif embedding_model_name == "ollama": 29 | return VECTOR_DB_OLLAMA_PATH 30 | else: 31 | raise ValueError(f"Unsupported embedding model: {embedding_model_name}") 32 | 33 | def get_api_key(embedding_model_name): 34 | if embedding_model_name == "openai": 35 | return OPENAI_API_KEY 36 | else: 37 | return CLAUDE_API_KEY 38 | 39 | # Initialize the retriever and the LLM once 40 | vector_db_path = get_vector_db_path(EMBEDDING_MODEL_NAME) 41 | api_key = get_api_key(EMBEDDING_MODEL_NAME) 42 | 43 | retriever = RAGRetriever(vector_db_path=vector_db_path, embedding_model_name=EMBEDDING_MODEL_NAME, api_key=api_key) 44 | llm_model = LLMFactory.create_llm(model_type=LLM_MODEL_TYPE, model_name=LLM_MODEL_NAME, api_key=api_key) 45 | 46 | print(LLM_MODEL_TYPE) 47 | print(LLM_MODEL_NAME) 48 | print(EMBEDDING_MODEL_NAME) 49 | 50 | def test_num_employees_alpha(): 51 | assert query_and_validate( 52 | question="How many people are in the head staff inside the alpha corporation? (Answer with the number only)", 53 | expected_response="4", 54 | retriever=retriever, 55 | llm_model=llm_model 56 | ) 57 | 58 | def test_company_field_beta(): 59 | assert query_and_validate( 60 | question="What is the field in which the beta enterprises operate? (Answer with few words)", 61 | expected_response="biotechnology and pharmaceuticals", 62 | retriever=retriever, 63 | llm_model=llm_model 64 | ) 65 | 66 | def test_foundation_year_gamma(): 67 | assert query_and_validate( 68 | question="When was the gamma innovation society founded? (Answer with the number only)", 69 | expected_response="2015", 70 | retriever=retriever, 71 | llm_model=llm_model 72 | ) 73 | 74 | def query_and_validate(question: str, expected_response: str, retriever, llm_model): 75 | """ 76 | Queries the language model (LLM) to get a response for the given question, and then validates this response 77 | against the expected response using the LLM itself. 78 | 79 | Parameters: 80 | question (str): The question to be asked to the LLM. 81 | expected_response (str): The expected response to validate against. 82 | retriever: An instance of the RAGRetriever used to retrieve relevant documents. 83 | llm_model: An instance of the LLM to generate responses. 
84 | 85 | Returns: 86 | bool: True if the LLM validates that the actual response matches the expected response, False otherwise. 87 | """ 88 | results = retriever.query(question, k=NUM_RELEVANT_DOCS) 89 | enhanced_context_text, sources = retriever.format_results(results) 90 | 91 | # Generate response from LLM 92 | response_text = llm_model.generate_response(context=enhanced_context_text, question=question) 93 | 94 | # Use the same LLM also for response validation 95 | prompt = EVAL_PROMPT.format( 96 | expected_response=expected_response, actual_response=response_text 97 | ) 98 | 99 | evaluation_results_str = llm_model.invoke(prompt) 100 | evaluation_results_str_cleaned = evaluation_results_str.strip().lower() 101 | 102 | print(question) 103 | print(prompt) 104 | 105 | if "true" in evaluation_results_str_cleaned: 106 | # Print response in Green if it is correct. 107 | print("\033[92m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m") 108 | return True 109 | elif "false" in evaluation_results_str_cleaned: 110 | # Print response in Red if it is incorrect. 111 | print("\033[91m" + f"Response: {evaluation_results_str_cleaned}" + "\033[0m") 112 | return False 113 | else: 114 | raise ValueError( 115 | "Invalid evaluation result. Cannot determine if 'true' or 'false'." 116 | ) 117 | 118 | if __name__ == "__main__": 119 | """ 120 | to run the tests, type 'pytest test_rag.py -s' in a terminal 121 | """ 122 | test_num_employees_alpha() 123 | test_company_field_beta() 124 | test_foundation_year_gamma() --------------------------------------------------------------------------------
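As an end-to-end complement to the unit tests above, you can also exercise the running app's /query endpoint directly. A minimal sketch — it assumes `python app.py` is running on Flask's default port and that the requests package is installed (it is not in requirements.txt); the question is taken from the tests:

```python
import requests

# same JSON contract the index.html script uses
resp = requests.post(
    "http://localhost:5000/query",
    json={"query_text": "When was the gamma innovation society founded?"},
)
print(resp.json()["response"])  # LLM answer + sources + model name, HTML-formatted
```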