├── .env-sample
├── README.md
├── app.py
└── requirements.txt

--------------------------------------------------------------------------------
/.env-sample:
--------------------------------------------------------------------------------
OPENAI_API_KEY=""

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Golden-Retriever: A Framework for High-Fidelity Agentic Retrieval Augmented Generation

Golden-Retriever is a framework for high-fidelity retrieval-augmented generation over industrial knowledge bases. It combines jargon identification, context recognition, and question augmentation to overcome retrieval challenges in specialized domains.

## Features

- Jargon identification and definition retrieval
- Context recognition for domain-specific questions
- Dynamic question augmentation
- Retrieval-augmented generation using DSPy
- Adaptive answer generation with reasoning
- Extensible and customizable architecture

## Installation

```bash
git clone https://github.com/yourusername/golden-retriever.git
cd golden-retriever
pip install -r requirements.txt
```

## Configuration

Copy `.env-sample` to `.env` and set your OpenAI API key:

```
OPENAI_API_KEY=your_api_key_here
```

## Usage

```python
import dspy
from app import GoldenRetrieverRAG, ImprovedAnswerGenerator

# Initialize the framework
rag = GoldenRetrieverRAG()

# Set up the required modules
rag.identify_jargon = dspy.Predict("question -> jargon_terms")
rag.identify_context = dspy.Predict("question -> context")
rag.augment_question = dspy.ChainOfThought("question, jargon_definitions, context -> augmented_question")
rag.generate_answer = ImprovedAnswerGenerator()

# Optionally compile the RAG instance; see "Training and Evaluation" below
# for how `teleprompter`, `trainset`, and `devset` are constructed
compiled_rag = teleprompter.compile(rag, trainset=trainset, valset=devset)

# Ask a question
question = "What is the role of wear leveling in SSDs?"
result = compiled_rag(question)

print(result.answer)
```

## Training and Evaluation

The framework includes functionality for generating training data, compiling the RAG instance with a teleprompter, and evaluating the model's performance.
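A minimal sketch of that flow, reusing the helpers defined in `app.py` and the configured `rag` instance from the Usage section above:

```python
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from app import generate_and_load_trainset, improved_answer_evaluation, evaluate

# Build a small train/dev split from the bundled question-answer pairs
dataset = generate_and_load_trainset()
trainset, devset = dataset[:-5], dataset[-5:]

# Compile the configured `rag` instance against the combined
# ROUGE/semantic-similarity metric, then score it on the dev set
teleprompter = BootstrapFewShotWithRandomSearch(
    metric=improved_answer_evaluation,
    num_candidate_programs=10,
    num_threads=1,
)
compiled_rag = teleprompter.compile(rag, trainset=trainset, valset=devset)
print(evaluate(compiled_rag, devset))
```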
## Interactive Mode

Run the script to enter an interactive mode where you can ask questions and receive detailed responses, including jargon definitions, context, reasoning, and retrieved passages.
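To start it, run the script directly:

```bash
python app.py
```

Once compilation and evaluation finish, the script prompts `Enter a question (or 'quit' to exit):` and, for each query, prints the augmented question, jargon definitions, context, reasoning, answer, and retrieved passages.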
## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## License

This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.

## Acknowledgements

This implementation is based on the DSPy library and the concepts from the paper "Golden-Retriever: High-Fidelity Agentic Retrieval Augmented Generation for Industrial Knowledge Base".

--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
import dspy
import os
from dotenv import load_dotenv
import asyncio
import aiohttp
from cachetools import TTLCache
import logging
import json
import random
from dspy.teleprompt import BootstrapFewShotWithRandomSearch
from dspy import ColBERTv2
import backoff
import nest_asyncio
from rouge import Rouge
from sentence_transformers import SentenceTransformer, util

# Apply nest_asyncio to allow nested event loops
nest_asyncio.apply()

# Load environment variables and set up logging
load_dotenv()
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Configure DSPy
llm = dspy.OpenAI(
    model='gpt-3.5-turbo',
    api_key=os.environ['OPENAI_API_KEY'],
    max_tokens=2000
)
dspy.settings.configure(lm=llm)

# Initialize ColBERTv2 retriever
retriever = ColBERTv2(url='http://20.102.90.50:2017/wiki17_abstracts')
dspy.settings.configure(rm=retriever)


class QueryJargonDictionary(dspy.Module):
    def __init__(self):
        super().__init__()
        self.cache = TTLCache(maxsize=1000, ttl=3600)
        self.rate_limit = 1.0  # seconds to wait before each Wikipedia request
        # Keys are lowercase because lookups below use term.lower()
        self.local_dictionary = {
            # ... [previous dictionary entries remain unchanged] ...
            "wear leveling": "A technique used in SSDs to distribute write operations evenly across all the flash memory blocks, extending the lifespan of the drive by preventing premature wear-out of specific areas.",
            "ssds": "Solid State Drives, storage devices that use integrated circuit assemblies to store data persistently, offering faster access times and improved reliability compared to traditional hard disk drives.",
            "traditional storage interfaces": "Conventional methods of connecting storage devices to computers, such as SATA (Serial ATA) or SAS (Serial Attached SCSI), which are generally slower and less efficient than newer interfaces like NVMe.",
        }

    async def forward(self, jargon_terms):
        jargon_definitions = {}

        # Resolve all terms concurrently over a shared HTTP session
        async with aiohttp.ClientSession() as session:
            tasks = [self.get_jargon_definition(term, session) for term in jargon_terms]
            results = await asyncio.gather(*tasks)

        for term, definitions in results:
            jargon_definitions[term] = definitions

        return jargon_definitions

    @backoff.on_exception(backoff.expo, Exception, max_tries=3)
    async def get_jargon_definition(self, term, session):
        if term in self.cache:
            return term, self.cache[term]

        logging.info(f"Querying for term: {term}")

        # Check the local dictionary first
        if term.lower() in self.local_dictionary:
            self.cache[term] = {"local": self.local_dictionary[term.lower()]}
            return term, self.cache[term]

        definitions = {
            "wikipedia": await self.query_wikipedia(term, session),
        }

        # Remove None values
        definitions = {k: v for k, v in definitions.items() if v is not None}

        if not definitions:
            # Fall back to the LLM for a definition
            definitions["gpt"] = await self.query_gpt(term)

        self.cache[term] = definitions
        return term, definitions

    @backoff.on_exception(backoff.expo, Exception, max_tries=3)
    async def query_wikipedia(self, term, session):
        try:
            await asyncio.sleep(self.rate_limit)  # Rate limiting
            url = f"https://en.wikipedia.org/api/rest_v1/page/summary/{term}"
            async with session.get(url, headers={"User-Agent": "GoldenRetrieverBot/1.0"}) as response:
                if response.status == 200:
                    data = await response.json()
                    return data.get('extract')
                else:
                    logging.warning(f"Wikipedia returned status {response.status} for term {term}")
        except Exception as e:
            logging.error(f"Error querying Wikipedia for {term}: {e}")
        return None

    async def query_gpt(self, term):
        max_retries = 3
        for attempt in range(max_retries):
            try:
                prompt = f"Provide a brief definition for the term '{term}' in the context of computer storage technology:"
                response = dspy.Predict("term -> definition")(term=prompt).definition
                return response.strip()
            except Exception as e:
                logging.warning(f"Error querying GPT for {term} (attempt {attempt + 1}/{max_retries}): {e}")
                if attempt == max_retries - 1:
                    logging.error(f"Failed to query GPT for {term} after {max_retries} attempts")
                    return None
                await asyncio.sleep(2 ** attempt)  # Exponential backoff
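# Minimal standalone sketch of this module (hypothetical usage, not part of
# the pipeline below), assuming an event loop is available via asyncio.run:
#
#     qjd = QueryJargonDictionary()
#     definitions = asyncio.run(qjd.forward(["SSDs", "wear leveling"]))
#     # -> {"SSDs": {"local": "Solid State Drives, ..."}, "wear leveling": {...}}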
class ImprovedAnswerGenerator(dspy.Module):
    def __init__(self):
        super().__init__()
        self.generate_answer = dspy.ChainOfThought("original_question, augmented_question, jargon_definitions, context, retrieved_passages -> reasoning, comprehensive_answer")

    def forward(self, original_question, augmented_question, jargon_definitions, context, retrieved_passages):
        result = self.generate_answer(
            original_question=original_question,
            augmented_question=augmented_question,
            jargon_definitions=jargon_definitions,
            context=context,
            retrieved_passages=retrieved_passages
        )
        return result.reasoning, result.comprehensive_answer


class GoldenRetrieverRAG(dspy.Module):
    def __init__(self, num_passages=5):
        super().__init__()
        self.query_jargon_dictionary = QueryJargonDictionary()
        self.retrieve = dspy.Retrieve(k=num_passages)

        # Initialized as None; they must be set by the caller before use
        self.identify_jargon = None
        self.identify_context = None
        self.augment_question = None
        self.generate_answer = None

    async def forward(self, question):
        if not all([self.identify_jargon, self.identify_context, self.augment_question, self.generate_answer]):
            raise ValueError("Not all required modules have been set.")

        jargon_terms = self.identify_jargon(question=question).jargon_terms.strip().split(',')
        jargon_terms = [term.strip() for term in jargon_terms if len(term.strip().split()) <= 3]  # Keep only terms of three words or fewer
        jargon_definitions = await self.query_jargon_dictionary(jargon_terms)
        context = self.identify_context(question=question).context.strip()

        augmented_question = self.augment_question(
            question=question,
            jargon_definitions=json.dumps(jargon_definitions),
            context=context
        ).augmented_question.strip()

        retrieved_passages = self.retrieve(augmented_question).passages

        reasoning, answer = self.generate_answer(
            original_question=question,
            augmented_question=augmented_question,
            jargon_definitions=json.dumps(jargon_definitions),
            context=context,
            retrieved_passages=json.dumps(retrieved_passages)
        )

        return dspy.Prediction(
            original_question=question,
            augmented_question=augmented_question,
            jargon_definitions=jargon_definitions,
            context=context,
            reasoning=reasoning,
            answer=answer,
            retrieved_passages=retrieved_passages
        )

    def __call__(self, question):
        return asyncio.run(self.forward(question))


def generate_and_load_trainset(num_examples=20):
    questions = [
        "What is Flash Translation Layer (FTL) in computer storage technology?",
        "How does Error Correction Code (ECC) work in data storage?",
        "What are the advantages of NVMe over traditional storage interfaces?",
        "Explain the concept of wear leveling in SSDs.",
        "What is the difference between NOR and NAND flash memory?",
        "How does TRIM command improve SSD performance?",
        "What is the role of a controller in an SSD?",
        "Explain the concept of garbage collection in SSDs.",
        "What is over-provisioning in SSDs and why is it important?",
        "How does QLC NAND differ from TLC NAND?",
    ]

    answers = [
        "FTL is a layer that translates logical block addresses to physical addresses in flash memory, managing wear leveling and garbage collection.",
        "ECC detects and corrects errors in data storage by adding redundant bits, improving data reliability.",
        "NVMe offers lower latency, higher throughput, and more efficient queuing than traditional interfaces like SATA.",
        "Wear leveling distributes write operations evenly across all blocks of an SSD, preventing premature wear-out of specific areas.",
        "NOR flash allows random access to any memory location, while NAND flash reads and writes data in blocks, offering higher density.",
        "TRIM informs the SSD which blocks of data are no longer in use, improving garbage collection and write performance.",
        "An SSD controller manages data transfer between the computer and flash memory chips, handling tasks like wear leveling and error correction.",
        "Garbage collection in SSDs consolidates valid data and erases invalid data blocks, freeing up space for new writes.",
        "Over-provisioning reserves extra space in an SSD, improving performance, endurance, and allowing for more efficient garbage collection.",
        "QLC NAND stores 4 bits per cell, offering higher capacity but lower endurance compared to TLC NAND, which stores 3 bits per cell.",
    ]

    # Sample question-answer pairs (with replacement) into dspy Examples
    trainset = []
    for _ in range(num_examples):
        idx = random.randint(0, len(questions) - 1)
        example = dspy.Example(question=questions[idx], answer=answers[idx])
        trainset.append(example.with_inputs('question'))  # Mark 'question' as the input field

    return trainset


# Shared scorers for the evaluation metric, loaded once rather than per call
rouge = Rouge()
similarity_model = SentenceTransformer('all-MiniLM-L6-v2')


def improved_answer_evaluation(example, pred, trace=None, frac=0.5):
    def normalize_text(text):
        return ' '.join(text.lower().split())

    def calculate_rouge(prediction, ground_truth):
        scores = rouge.get_scores(prediction, ground_truth)
        return scores[0]['rouge-l']['f']

    def calculate_semantic_similarity(prediction, ground_truth):
        embeddings1 = similarity_model.encode([prediction], convert_to_tensor=True)
        embeddings2 = similarity_model.encode([ground_truth], convert_to_tensor=True)
        return util.pytorch_cos_sim(embeddings1, embeddings2).item()

    prediction = normalize_text(pred.answer)
    ground_truth = normalize_text(example.answer)

    rouge_score = calculate_rouge(prediction, ground_truth)
    semantic_similarity = calculate_semantic_similarity(prediction, ground_truth)

    # Average ROUGE-L F1 and embedding cosine similarity, then threshold
    combined_score = (rouge_score + semantic_similarity) / 2

    return combined_score >= frac
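# Worked illustration (hypothetical numbers): if ROUGE-L F1 is 0.42 and the
# cosine similarity is 0.70, the combined score is (0.42 + 0.70) / 2 = 0.56,
# which clears the default threshold frac=0.5, so the prediction counts as correct.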
async def async_evaluate(compiled_rag, devset):
    # Score each dev example with the combined ROUGE/similarity metric
    results = []
    for example in devset:
        pred = await compiled_rag.forward(example.question)
        score = improved_answer_evaluation(example, pred)
        results.append(score)
    return sum(results) / len(results)


def evaluate(compiled_rag, devset):
    return asyncio.run(async_evaluate(compiled_rag, devset))


# Run the main event loop
if __name__ == "__main__":
    # Setup and compilation
    dataset = generate_and_load_trainset()
    trainset = dataset[:-5]  # Use all but the last 5 examples as the train set
    devset = dataset[-5:]    # Use the last 5 examples as the dev set

    # Define the modules
    modules = [
        ("identify_jargon", dspy.Predict("question -> jargon_terms")),
        ("identify_context", dspy.Predict("question -> context")),
        ("augment_question", dspy.ChainOfThought("question, jargon_definitions, context -> augmented_question")),
        ("generate_answer", ImprovedAnswerGenerator())
    ]

    # Create a new GoldenRetrieverRAG instance
    rag_instance = GoldenRetrieverRAG()

    # Set the modules
    for name, module in modules:
        setattr(rag_instance, name, module)

    # Set instructions separately
    rag_instance.identify_jargon.instructions = "Identify technical jargon or abbreviations in the following question. Output only individual terms or short phrases, separated by commas."
    rag_instance.identify_context.instructions = "Identify the relevant context or domain for the given question."
    rag_instance.augment_question.instructions = "Given the original question, jargon definitions, and context, create an augmented version of the question that incorporates this additional information."
    rag_instance.generate_answer.generate_answer.instructions = """
    Given the original question, augmented question, jargon definitions, context, and retrieved passages:
    1. Analyze the question and identify the key concepts and requirements.
    2. Review the jargon definitions and context to understand the specific domain knowledge needed.
    3. Examine the retrieved passages and extract relevant information.
    4. Reason step-by-step about how to construct a comprehensive answer.
    5. Synthesize the information into a clear, concise, and accurate answer.
    6. Ensure the answer directly addresses the original question and incorporates relevant jargon and context.
    7. Provide your step-by-step reasoning in the 'reasoning' output.
    8. Provide your final comprehensive answer in the 'comprehensive_answer' output.
    """
293 | """ 294 | 295 | teleprompter = BootstrapFewShotWithRandomSearch( 296 | metric=improved_answer_evaluation, 297 | num_candidate_programs=10, 298 | max_bootstrapped_demos=4, 299 | max_labeled_demos=16, 300 | max_rounds=2, 301 | num_threads=1, # Set this to 1 to avoid multi-threading issues 302 | max_errors=10 303 | ) 304 | 305 | try: 306 | compiled_rag = teleprompter.compile(rag_instance, trainset=trainset, valset=devset) 307 | except Exception as e: 308 | logging.error(f"Error during compilation: {e}") 309 | compiled_rag = rag_instance 310 | 311 | # Save the compiled program 312 | compiled_program_json = compiled_rag.save("compiled_goldenretriever_rag.json") 313 | print("Program saved to compiled_goldenretriever_rag.json") 314 | 315 | # Evaluate the compiled program 316 | try: 317 | results = evaluate(compiled_rag, devset) 318 | print("Evaluation Results:") 319 | print(results) 320 | except Exception as e: 321 | logging.error(f"Error during evaluation: {e}") 322 | print("An error occurred during evaluation. Please check the logs for details.") 323 | 324 | # Interactive loop 325 | while True: 326 | question = input("Enter a question (or 'quit' to exit): ") 327 | if question.lower() == 'quit': 328 | break 329 | try: 330 | prediction = asyncio.run(compiled_rag.forward(question)) 331 | print(f"Original Question: {prediction.original_question}") 332 | print(f"Augmented Question: {prediction.augmented_question}") 333 | print(f"Identified Jargon Terms:") 334 | for term, definitions in prediction.jargon_definitions.items(): 335 | print(f" - {term}:") 336 | for source, definition in definitions.items(): 337 | print(f" {source}: {definition}") 338 | print(f"Identified Context: {prediction.context}") 339 | print(f"Reasoning:") 340 | print(prediction.reasoning) 341 | print(f"Answer: {prediction.answer}") 342 | print("Retrieved Passages:") 343 | for i, passage in enumerate(prediction.retrieved_passages, 1): 344 | print(f"Passage {i}: {passage[:200]}...") # Print first 200 characters of each passage 345 | except Exception as e: 346 | logging.error(f"Error during prediction: {e}") 347 | print("An error occurred while processing the question. Please try again.") 348 | 349 | print("Thank you for using GoldenRetrieverRAG. Goodbye!") 350 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | dspy 2 | python-dotenv 3 | aiohttp 4 | cachetools 5 | wikipedia 6 | backoff 7 | nest_asyncio 8 | rouge 9 | sentence-transformers 10 | torch 11 | transformers 12 | openai 13 | --------------------------------------------------------------------------------