├── .gitignore ├── README.md ├── cache ├── embeddings_11656.pt ├── embeddings_6119.pt └── embeddings_9811.pt ├── config └── config.py ├── dataset ├── 2wikimultihopqa.json ├── 2wikimultihopqa_corpus.json ├── hotpotqa.json ├── hotpotqa_corpus.json ├── musique.json └── musique_corpus.json ├── requirements.txt ├── run.py ├── scripts ├── README.md └── run.sh ├── setup.py └── src ├── __init__.py ├── evaluation ├── __init__.py └── evaluation.py ├── main.py ├── models ├── __init__.py ├── agentic_rag.py ├── base_rag.py ├── light_agentic_rag.py └── vanilla_rag.py └── utils ├── __init__.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Python cache files 2 | __pycache__/ 3 | */__pycache__/ 4 | */*/__pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | *.so 8 | .Python 9 | 10 | # Virtual environments 11 | env/ 12 | venv/ 13 | ENV/ 14 | env.bak/ 15 | venv.bak/ 16 | 17 | # Project-specific directories 18 | result/ 19 | cache/ 20 | 21 | # Environment variables 22 | .env 23 | .env.* 24 | 25 | # Sensitive configuration backups 26 | *config_backup* 27 | 28 | # Editor and IDE files 29 | .vscode/ 30 | .idea/ 31 | *.swp 32 | *.swo 33 | .DS_Store 34 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Agentic RAG 2 | 3 |
4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | A modular and extensible Retrieval-Augmented Generation (RAG) system, including vanilla RAG and several prototype agentic RAG algorithms. Key features include: 12 | 13 | - **Research Prototype** - A minimal viable implementation for experimentation and academic exploration 14 | - **Modular Architecture** - Clean, decoupled codebase designed for easy customization and extension 15 | - **Extensible** - Easily add new RAG algorithms to the framework 16 | 17 | 18 | ## Installation and Configuration 19 | 20 | - Install dependencies: 21 | ```bash 22 | pip install -r requirements.txt 23 | ``` 24 | - Set your OpenAI API key: 25 | ```bash 26 | # Create a .env file in the root directory with: 27 | OPENAI_API_KEY=your_api_key_here 28 | ``` 29 | 30 | - Other configuration options can be modified in `config/config.py` 31 | 32 | 33 | ## Quick Start 34 | 35 | ### Running Evaluation on a Dataset 36 | 37 | 38 | ```bash 39 | python run.py --model MODEL_NAME --dataset path/to/dataset.json --corpus path/to/corpus.json 40 | ``` 41 | 42 | Options (`MODEL_NAME` is one of `vanilla`, `agentic`, or `light`): 43 | - `--max-rounds`: Maximum number of agent retrieval rounds (default: 3) 44 | - `--top-k`: Number of top contexts to retrieve (default: 5) 45 | - `--limit`: Number of questions to evaluate (default: 20) 46 | 47 | ### Running a Single Question 48 | 49 | ```bash 50 | python run.py --model MODEL_NAME --question "Your question here" --corpus path/to/corpus.json 51 | ``` 52 | 53 | ### Using the Script 54 | 55 | For convenience, you can use the provided script to run evaluations with a specific RAG model on all datasets: 56 | 57 | ```bash 58 | # Run evaluations with vanilla RAG on all datasets 59 | ./scripts/run.sh --model vanilla 60 | 61 | # Run evaluations with LightAgenticRAG on all datasets 62 | ./scripts/run.sh --model light 63 | ``` 64 | 65 | 66 | 67 | 68 | ## Components 69 | 70 | | Component | Features/Description | 71 | |-----------|---------------------| 72 | | **BaseRAG** | • Loading and processing document corpus<br>
• Computing and caching document embeddings
• Basic retrieval functionality | 73 | | **VanillaRAG** | • Single retrieval step for relevant contexts
• Direct answer generation from retrieved contexts | 74 | | **AgenticRAG** | • Multiple retrieval rounds with iterative refinement
• Reflection on retrieved information to identify missing details
• Generation of focused sub-queries for additional retrieval | 75 | | **LightAgenticRAG** | • Memory-efficient implementation of AgenticRAG
• Optimized for running on systems with limited resources | 76 | | **Evaluation** | • Answer accuracy (LLM evaluated)
• Retrieval metrics
• Performance efficiency
• String-based evaluation metrics | 77 | 78 | ## Adding a New RAG Model 79 | 80 | To add a new RAG algorithm: 81 | 82 | 1. Create a new class that extends `BaseRAG` in the `src/models` directory 83 | 2. Implement the required methods (`answer_question` at minimum) 84 | 3. Add your model to the `RAG_MODELS` dictionaries in `src/main.py` and `src/evaluation/evaluation.py` 85 | 86 | ```python 87 | # Example of a new RAG model 88 | from src.models.base_rag import BaseRAG 89 | 90 | class MyNewRAG(BaseRAG): 91 | def answer_question(self, question: str): 92 | contexts = self.retrieve(question)  # reuse BaseRAG retrieval, then generate your answer 93 | return "your answer here", contexts 94 | ``` 95 | 96 | ## Example Usage 97 | 98 | ```python 99 | from src.models.agentic_rag import AgenticRAG 100 | 101 | # Initialize RAG system 102 | rag = AgenticRAG('path/to/corpus.json') 103 | rag.set_max_rounds(3) 104 | rag.set_top_k(5) 105 | 106 | # Ask a question 107 | answer, contexts, rounds = rag.answer_question("What is the capital of France?") 108 | print(f"Answer: {answer}") 109 | print(f"Retrieved in {rounds} rounds") 110 | ``` 111 | 112 | ## Citation 113 | If you find this work helpful, please cite our recent paper: 114 | 115 | ``` 116 | @article{zhang2025survey, 117 | title={A Survey of Graph Retrieval-Augmented Generation for Customized Large Language Models}, 118 | author={Zhang, Qinggang and Chen, Shengyuan and Bei, Yuanchen and Yuan, Zheng and Zhou, Huachi and Hong, Zijin and Dong, Junnan and Chen, Hao and Chang, Yi and Huang, Xiao}, 119 | journal={arXiv preprint arXiv:2501.13958}, 120 | year={2025} 121 | } 122 | ``` -------------------------------------------------------------------------------- /cache/embeddings_11656.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_11656.pt -------------------------------------------------------------------------------- /cache/embeddings_6119.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_6119.pt -------------------------------------------------------------------------------- /cache/embeddings_9811.pt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_9811.pt -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | """ 2 | Configuration file for API keys and other settings.
3 | """ 4 | import os 5 | from dotenv import load_dotenv 6 | 7 | # Load environment variables from .env file 8 | load_dotenv() 9 | 10 | # OpenAI API Configuration 11 | OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY") 12 | 13 | # API Rate Limiting Configuration 14 | CALLS_PER_MINUTE = 20 15 | PERIOD = 60 16 | MAX_RETRIES = 3 17 | RETRY_DELAY = 120 18 | 19 | # Model Configuration 20 | DEFAULT_MODEL = "gpt-4o-mini" # please specify your preferred LLM model 21 | DEFAULT_MAX_TOKENS = 150 22 | 23 | # Embedding Configuration 24 | EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # please specify your preferred embedding model 25 | EMBEDDING_BATCH_SIZE = 32 26 | 27 | # Cache Configuration 28 | CACHE_DIR = "cache" 29 | RESULT_DIR = "result" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.9.0 2 | sentence-transformers>=2.2.0 3 | openai>=1.0.0 4 | tqdm>=4.62.0 5 | numpy>=1.21.0 6 | backoff>=1.11.0 7 | ratelimit>=2.2.1 8 | python-dotenv>=0.19.0 -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Entry point for Agentic-RAG. 4 | Run this script to start the program. 5 | """ 6 | 7 | from src.main import main 8 | 9 | if __name__ == "__main__": 10 | main() -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # Agentic-RAG Scripts 2 | 3 | This directory contains scripts for running and evaluating the Agentic-RAG system on various datasets. 4 | 5 | ## Available Scripts 6 | 7 | ### run.sh 8 | 9 | Runs evaluations on all available datasets in the `dataset` directory. 10 | 11 | ```bash 12 | # Run with default parameters (standard AgenticRAG) 13 | ./scripts/run.sh 14 | 15 | # Run with LightAgenticRAG 16 | ./scripts/run.sh --rag-type light 17 | 18 | # Run both algorithms for comparison 19 | ./scripts/run.sh --rag-type both 20 | 21 | # You can also modify other parameters: 22 | # - MAX_ROUNDS: Number of agent rounds (default: 3) 23 | # - TOP_K: Number of top contexts to retrieve (default: 5) 24 | # - EVAL_TOP_KS: List of k values for evaluation (default: "5 10 20") 25 | # - LIMIT: Number of questions to evaluate per dataset (default: 20) 26 | ``` 27 | 28 | #### What it does: 29 | 30 | 1. Runs evaluations on HotpotQA, MuSiQue, and 2WikiMultihopQA datasets 31 | 2. Saves results to separate files in the `result` directory 32 | 3. Shows progress and completion messages 33 | 34 | #### Output: 35 | 36 | The script will generate output files like: 37 | - `result/hotpotqa_evaluation_standard.json` (for standard AgenticRAG) 38 | - `result/hotpotqa_evaluation_light.json` (for LightAgenticRAG) 39 | 40 | ### run_light_agentic.sh 41 | 42 | Runs evaluations using only the LightAgenticRAG algorithm on all available datasets. 43 | 44 | ```bash 45 | # Run with default parameters 46 | ./scripts/run_light_agentic.sh 47 | ``` 48 | 49 | #### What it does: 50 | 51 | 1. Runs LightAgenticRAG evaluations on all datasets 52 | 2. Saves results to the `result/light_agentic` directory 53 | 3. 
Shows progress and completion messages 54 | 55 | #### Output: 56 | 57 | The script will generate the following output files: 58 | - `result/light_agentic/hotpotqa_light_evaluation.json` 59 | - `result/light_agentic/musique_light_evaluation.json` 60 | - `result/light_agentic/2wikimultihopqa_light_evaluation.json` 61 | 62 | ### test_light_agentic_rag.py 63 | 64 | Tests and compares LightAgenticRAG with standard AgenticRAG using a single question. 65 | 66 | ```bash 67 | # Test only LightAgenticRAG 68 | python scripts/test_light_agentic_rag.py --question "Your question here" 69 | 70 | # Compare LightAgenticRAG with standard AgenticRAG 71 | python scripts/test_light_agentic_rag.py --question "Your question here" --compare 72 | ``` 73 | 74 | Additional options: 75 | - `--corpus`: Path to corpus file (default: dataset/hotpotqa_corpus.json) 76 | - `--max-rounds`: Maximum number of rounds (default: 3) 77 | - `--top-k`: Number of contexts to retrieve (default: 5) 78 | 79 | #### Output: 80 | 81 | Results are saved to `result/light_agentic_test_results.json` 82 | 83 | ## Adding New Scripts 84 | 85 | To add a new script: 86 | 1. Create your script in this directory 87 | 2. Make it executable with `chmod +x scripts/your_script.sh` 88 | 3. Add documentation in this README.md file -------------------------------------------------------------------------------- /scripts/run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Agentic-RAG evaluation script 4 | # This script runs evaluations on all available datasets with configurable RAG model 5 | 6 | # Terminal colors 7 | GREEN='\033[0;32m' 8 | YELLOW='\033[1;33m' 9 | BLUE='\033[0;34m' 10 | NC='\033[0m' # No Color 11 | 12 | # Configuration parameters 13 | MAX_ROUNDS=5 14 | TOP_K=3 15 | EVAL_TOP_KS="5 10" 16 | LIMIT=50 # Number of questions to evaluate per dataset 17 | MODEL="light" # Default model: vanilla, agentic, light 18 | 19 | # Output directory for results 20 | RESULTS_DIR="result" 21 | mkdir -p $RESULTS_DIR 22 | 23 | # Show usage information 24 | usage() { 25 | echo "Usage: $0 [options]" 26 | echo "" 27 | echo "Options:" 28 | echo " -h, --help Show this help message" 29 | echo " -m, --max-rounds NUM Set maximum rounds (default: $MAX_ROUNDS)" 30 | echo " -k, --top-k NUM Set top-k contexts (default: $TOP_K)" 31 | echo " -l, --limit NUM Set number of questions per dataset (default: $LIMIT)" 32 | echo " -r, --model MODEL Set RAG model: vanilla, agentic, light (default: $MODEL)" 33 | echo "" 34 | exit 1 35 | } 36 | 37 | # Parse command line arguments 38 | while [[ $# -gt 0 ]]; do 39 | case $1 in 40 | -h|--help) 41 | usage 42 | ;; 43 | -m|--max-rounds) 44 | MAX_ROUNDS="$2" 45 | shift 2 46 | ;; 47 | -k|--top-k) 48 | TOP_K="$2" 49 | shift 2 50 | ;; 51 | -l|--limit) 52 | LIMIT="$2" 53 | shift 2 54 | ;; 55 | -r|--model) 56 | MODEL="$2" 57 | shift 2 58 | ;; 59 | *) 60 | echo "Unknown option: $1" 61 | usage 62 | ;; 63 | esac 64 | done 65 | 66 | # Validate RAG model 67 | if [[ "$MODEL" != "vanilla" && "$MODEL" != "agentic" && "$MODEL" != "light" ]]; then 68 | echo "Invalid RAG model: $MODEL. Must be 'vanilla', 'agentic', or 'light'." 
69 | usage 70 | fi 71 | 72 | echo -e "${GREEN}Starting Agentic-RAG evaluations...${NC}" 73 | echo -e "Model: $MODEL, Max rounds: $MAX_ROUNDS, Top-K: $TOP_K, Eval Top-Ks: $EVAL_TOP_KS, Questions per dataset: $LIMIT" 74 | echo "" 75 | 76 | # Function to run evaluation on a dataset 77 | run_evaluation() { 78 | local dataset=$1 79 | local corpus=$2 80 | local output=$3 81 | local dataset_name=$(basename "$dataset" .json) 82 | 83 | echo -e "${BLUE}Evaluating ${dataset_name} with ${MODEL} model${NC}" 84 | echo "Dataset: $dataset" 85 | echo "Corpus: $corpus" 86 | echo "Output: $output" 87 | 88 | python run.py \ 89 | --dataset "$dataset" \ 90 | --corpus "$corpus" \ 91 | --model "$MODEL" \ 92 | --max-rounds $MAX_ROUNDS \ 93 | --top-k $TOP_K \ 94 | --eval-top-ks $EVAL_TOP_KS \ 95 | --limit $LIMIT \ 96 | --output "$output" 97 | 98 | echo -e "${GREEN}Evaluation complete for $dataset_name${NC}" 99 | echo "" 100 | } 101 | 102 | # HotpotQA dataset 103 | echo -e "${YELLOW}=== HotpotQA Dataset ===${NC}" 104 | run_evaluation \ 105 | "dataset/hotpotqa.json" \ 106 | "dataset/hotpotqa_corpus.json" \ 107 | "${MODEL}_hotpotqa_evaluation.json" 108 | 109 | # MuSiQue dataset 110 | echo -e "${YELLOW}=== MuSiQue Dataset ===${NC}" 111 | run_evaluation \ 112 | "dataset/musique.json" \ 113 | "dataset/musique_corpus.json" \ 114 | "${MODEL}_musique_evaluation.json" 115 | 116 | # 2WikiMultihopQA dataset 117 | echo -e "${YELLOW}=== 2WikiMultihopQA Dataset ===${NC}" 118 | run_evaluation \ 119 | "dataset/2wikimultihopqa.json" \ 120 | "dataset/2wikimultihopqa_corpus.json" \ 121 | "${MODEL}_2wikimultihopqa_evaluation.json" 122 | 123 | echo -e "${GREEN}All evaluations complete!${NC}" 124 | echo -e "Results saved with prefix '${MODEL}_' in the $RESULTS_DIR directory." -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="agentic-rag", 5 | version="0.1.0", 6 | description="Agentic Retrieval-Augmented Generation with iterative retrieval", 7 | author="Your Name", 8 | packages=find_packages(), 9 | install_requires=[ 10 | "torch", 11 | "sentence-transformers", 12 | "openai", 13 | "tqdm", 14 | "numpy", 15 | "backoff", 16 | "ratelimit", 17 | ], 18 | python_requires=">=3.7", 19 | ) -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/__init__.py -------------------------------------------------------------------------------- /src/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/evaluation/__init__.py -------------------------------------------------------------------------------- /src/evaluation/evaluation.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import time 3 | from typing import Dict, List, Tuple, Any 4 | import json 5 | import os 6 | from tqdm import tqdm 7 | 8 | from src.utils.utils import ( 9 | normalize_answer, 10 | evaluate_with_llm, 11 | string_based_evaluation, 12 | save_results 13 | ) 14 | from src.models.vanilla_rag import VanillaRAG 15 | from src.models.agentic_rag import AgenticRAG 16 | from 
src.models.light_agentic_rag import LightAgenticRAG 17 | from config.config import RESULT_DIR 18 | 19 | # Configure logging 20 | logging.basicConfig(level=logging.INFO) 21 | logger = logging.getLogger(__name__) 22 | 23 | # Dictionary of available RAG models 24 | RAG_MODELS = { 25 | "vanilla": VanillaRAG, 26 | "agentic": AgenticRAG, 27 | "light": LightAgenticRAG 28 | } 29 | 30 | class RAGEvaluator: 31 | """Evaluator for RAG models.""" 32 | 33 | def __init__(self, model_name: str, corpus_path: str, max_rounds: int = 3, top_k: int = 5, 34 | eval_top_ks: List[int] = [5, 10]): 35 | """Initialize the evaluator with corpus path and parameters. 36 | 37 | Args: 38 | model_name: Name of the RAG model to evaluate 39 | corpus_path: Path to the corpus file 40 | max_rounds: Maximum number of rounds for agentic RAG 41 | top_k: Number of contexts to retrieve 42 | eval_top_ks: List of k values for top-k accuracy evaluation 43 | """ 44 | self.model_name = model_name 45 | self.corpus_path = corpus_path 46 | self.max_rounds = max_rounds 47 | self.top_k = top_k 48 | self.eval_top_ks = sorted(eval_top_ks) # Sort to ensure consistent processing 49 | 50 | # Create result directory if it doesn't exist 51 | os.makedirs(RESULT_DIR, exist_ok=True) 52 | 53 | # Initialize the RAG model 54 | self._initialize_model() 55 | 56 | def _initialize_model(self): 57 | """Initialize the specified RAG model.""" 58 | if self.model_name not in RAG_MODELS: 59 | raise ValueError(f"Unknown RAG model: {self.model_name}") 60 | 61 | # Create model instance 62 | model_class = RAG_MODELS[self.model_name] 63 | self.model = model_class(self.corpus_path) 64 | 65 | # Configure model 66 | self.model.set_top_k(self.top_k) 67 | 68 | # Set max rounds for agentic models 69 | if hasattr(self.model, 'set_max_rounds'): 70 | self.model.set_max_rounds(self.max_rounds) 71 | 72 | logger.info(f"Initialized {self.model_name} model") 73 | 74 | def evaluate_question(self, question: str, gold_answer: str) -> Dict: 75 | """Evaluate the model on a single question.""" 76 | # Run the model on the question 77 | start_time = time.time() 78 | 79 | # Handle different return signatures 80 | if self.model_name in ["agentic", "light"]: 81 | answer, contexts, rounds = self.model.answer_question(question) 82 | elapsed_time = time.time() - start_time 83 | 84 | # Evaluate answer with LLM 85 | is_correct = evaluate_with_llm(answer, gold_answer) 86 | 87 | return { 88 | "question": question, 89 | "gold_answer": gold_answer, 90 | "answer": answer, 91 | "contexts": contexts, 92 | "time": elapsed_time, 93 | "rounds": rounds, 94 | "is_correct": is_correct 95 | } 96 | else: 97 | answer, contexts = self.model.answer_question(question) 98 | elapsed_time = time.time() - start_time 99 | 100 | # Evaluate answer with LLM 101 | is_correct = evaluate_with_llm(answer, gold_answer) 102 | 103 | return { 104 | "question": question, 105 | "gold_answer": gold_answer, 106 | "answer": answer, 107 | "contexts": contexts, 108 | "time": elapsed_time, 109 | "is_correct": is_correct 110 | } 111 | 112 | def calculate_retrieval_metrics(self, retrieved_contexts: List[List[str]], answers: List[str]) -> Dict[str, float]: 113 | """Calculate retrieval-based metrics.""" 114 | total = len(answers) 115 | found_in_context = 0 116 | 117 | # Initialize answer_in_top_k counters for each k in eval_top_ks 118 | answer_in_top_k = {k: 0 for k in self.eval_top_ks} 119 | 120 | for contexts, answer in zip(retrieved_contexts, answers): 121 | normalized_answer = normalize_answer(answer) 122 | 123 | # Check if answer is in any 
context 124 | for i, context in enumerate(contexts): 125 | if normalized_answer in normalize_answer(context): 126 | found_in_context += 1 127 | # Update counters for each k value 128 | for k in self.eval_top_ks: 129 | if i < k: 130 | answer_in_top_k[k] += 1 131 | break 132 | 133 | # Prepare result dictionary 134 | result = { 135 | "answer_found_in_context": found_in_context / total, 136 | "total_questions": total 137 | } 138 | 139 | # Add top-k metrics to result 140 | for k in self.eval_top_ks: 141 | result[f"answer_in_top{k}"] = answer_in_top_k[k] / total 142 | 143 | return result 144 | 145 | def run_single_model_evaluation(self, eval_data: List[Dict], output_file: str = "evaluation_results.json"): 146 | """Run evaluation of a single model on the given evaluation data.""" 147 | results = [] 148 | 149 | # Evaluation metrics 150 | total_questions = len(eval_data) 151 | 152 | # Initialize metrics dictionary with dynamic top-k keys 153 | metrics = { 154 | "total_time": 0, 155 | "answer_coverage": 0, 156 | "answer_accuracy": 0, 157 | "string_accuracy": 0, 158 | "string_precision": 0, 159 | "string_recall": 0 160 | } 161 | 162 | # Add top-k hits for each k in eval_top_ks 163 | for k in self.eval_top_ks: 164 | metrics[f"top{k}_hits"] = 0 165 | 166 | # Add rounds tracking for agentic models 167 | if self.model_name in ["agentic", "light"]: 168 | metrics["total_rounds"] = 0 169 | 170 | for item in tqdm(eval_data, desc=f"Evaluating {self.model_name}"): 171 | question = item['question'] 172 | gold_answer = item['answer'] 173 | 174 | # Evaluate the model on this question 175 | result = self.evaluate_question( 176 | question=question, 177 | gold_answer=gold_answer 178 | ) 179 | results.append(result) 180 | 181 | # Update metrics 182 | metrics["total_time"] += result["time"] 183 | normalized_gold = normalize_answer(gold_answer) 184 | 185 | # String-based evaluation 186 | string_metrics = string_based_evaluation( 187 | result["answer"], 188 | gold_answer 189 | ) 190 | metrics["string_accuracy"] += string_metrics["accuracy"] 191 | metrics["string_precision"] += string_metrics["precision"] 192 | metrics["string_recall"] += string_metrics["recall"] 193 | 194 | # Check retrieval coverage 195 | for i, ctx in enumerate(result["contexts"]): 196 | if normalized_gold in normalize_answer(ctx): 197 | metrics["answer_coverage"] += 1 198 | # Update counters for each k value 199 | for k in self.eval_top_ks: 200 | if i < k: 201 | metrics[f"top{k}_hits"] += 1 202 | break 203 | 204 | # Update rounds for agentic models 205 | if self.model_name in ["agentic", "light"] and "rounds" in result: 206 | metrics["total_rounds"] += result["rounds"] 207 | 208 | # Evaluate answer using LLM 209 | if result["is_correct"]: 210 | metrics["answer_accuracy"] += 1 211 | 212 | # Calculate average metrics 213 | avg_metrics = { 214 | "avg_time": metrics["total_time"] / total_questions, 215 | "answer_coverage": metrics["answer_coverage"] / total_questions * 100, 216 | "answer_accuracy": metrics["answer_accuracy"] / total_questions * 100, 217 | "string_accuracy": metrics["string_accuracy"] / total_questions * 100, 218 | "string_precision": metrics["string_precision"] / total_questions * 100, 219 | "string_recall": metrics["string_recall"] / total_questions * 100 220 | } 221 | 222 | # Add top-k coverage (renamed from accuracy) for each k in eval_top_ks 223 | for k in self.eval_top_ks: 224 | avg_metrics[f"top{k}_coverage"] = metrics[f"top{k}_hits"] / total_questions * 100 225 | 226 | # Add average rounds for agentic models 227 | if 
self.model_name in ["agentic", "light"]: 228 | avg_metrics["avg_rounds"] = metrics["total_rounds"] / total_questions 229 | 230 | # Organize metrics by category 231 | organized_metrics = { 232 | "performance": { 233 | "avg_time": avg_metrics["avg_time"] 234 | }, 235 | "string_based": { 236 | "accuracy": avg_metrics["string_accuracy"], 237 | "precision": avg_metrics["string_precision"], 238 | "recall": avg_metrics["string_recall"] 239 | }, 240 | "llm_evaluated": { 241 | "answer_accuracy": avg_metrics["answer_accuracy"] 242 | }, 243 | "retrieval": { 244 | "answer_coverage": avg_metrics["answer_coverage"] 245 | } 246 | } 247 | 248 | # Add rounds for agentic models 249 | if self.model_name in ["agentic", "light"]: 250 | organized_metrics["performance"]["avg_rounds"] = avg_metrics["avg_rounds"] 251 | 252 | # Add top-k coverage metrics 253 | for k in self.eval_top_ks: 254 | organized_metrics["retrieval"][f"top{k}_coverage"] = avg_metrics[f"top{k}_coverage"] 255 | 256 | # Add raw metrics for backwards compatibility 257 | organized_metrics["raw"] = metrics 258 | 259 | # Prepare final evaluation summary 260 | evaluation_summary = { 261 | "model": self.model_name, 262 | "metrics": organized_metrics, 263 | "results": results 264 | } 265 | 266 | # Save results 267 | save_results( 268 | results=evaluation_summary, 269 | output_file=output_file, 270 | results_dir=RESULT_DIR 271 | ) 272 | 273 | # Log results in three sections 274 | logger.info(f"\nEvaluation Summary for {self.model_name}:") 275 | 276 | # Performance metrics 277 | if self.model_name in ["agentic", "light"]: 278 | logger.info(f"Average time per question: {avg_metrics['avg_time']:.2f} seconds") 279 | logger.info(f"Average rounds per question: {avg_metrics['avg_rounds']:.2f}") 280 | else: 281 | logger.info(f"Average time per question: {avg_metrics['avg_time']:.2f} seconds") 282 | 283 | # 1. String-based metrics 284 | logger.info("\n1. String-based Metrics:") 285 | logger.info(f" • Accuracy: {avg_metrics['string_accuracy']:.2f}%") 286 | logger.info(f" • Precision: {avg_metrics['string_precision']:.2f}%") 287 | logger.info(f" • Recall: {avg_metrics['string_recall']:.2f}%") 288 | 289 | # 2. LLM evaluated metrics 290 | logger.info("\n2. LLM Evaluated Metrics:") 291 | logger.info(f" • Answer Accuracy: {avg_metrics['answer_accuracy']:.2f}%") 292 | 293 | # 3. Retrieval performance 294 | logger.info("\n3. 
Retrieval Performance:") 295 | logger.info(f" • Answer Coverage: {avg_metrics['answer_coverage']:.2f}%") 296 | 297 | # Log top-k coverage metrics 298 | for k in self.eval_top_ks: 299 | logger.info(f" • Top-{k} Coverage: {avg_metrics[f'top{k}_coverage']:.2f}%") 300 | 301 | return evaluation_summary -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import json 4 | import logging 5 | import importlib 6 | from typing import Dict, List, Type 7 | 8 | from src.evaluation.evaluation import RAGEvaluator 9 | from src.models.base_rag import BaseRAG 10 | from src.models.vanilla_rag import VanillaRAG 11 | from src.models.agentic_rag import AgenticRAG 12 | from src.models.light_agentic_rag import LightAgenticRAG 13 | 14 | # Configure logging 15 | logging.basicConfig(level=logging.INFO, 16 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') 17 | logger = logging.getLogger(__name__) 18 | 19 | # Dictionary of available RAG models 20 | RAG_MODELS = { 21 | "vanilla": VanillaRAG, 22 | "agentic": AgenticRAG, 23 | "light": LightAgenticRAG 24 | } 25 | 26 | def parse_arguments(): 27 | """Parse command line arguments.""" 28 | parser = argparse.ArgumentParser(description='Run RAG models') 29 | 30 | # Dataset arguments 31 | parser.add_argument('--dataset', type=str, default='dataset/hotpotqa.json', 32 | help='Path to the dataset file') 33 | parser.add_argument('--corpus', type=str, default='dataset/hotpotqa_corpus.json', 34 | help='Path to the corpus file') 35 | parser.add_argument('--limit', type=int, default=20, 36 | help='Number of questions to evaluate (default: 20)') 37 | 38 | # RAG configuration 39 | parser.add_argument('--max-rounds', type=int, default=3, 40 | help='Maximum number of agent rounds') 41 | parser.add_argument('--top-k', type=int, default=5, 42 | help='Number of top contexts to retrieve') 43 | parser.add_argument('--eval-top-ks', type=int, nargs='+', default=[5, 10], 44 | help='List of k values for top-k accuracy evaluation (default: [5, 10])') 45 | 46 | # Single question (optional) 47 | parser.add_argument('--question', type=str, 48 | help='Optional: Single question to answer') 49 | 50 | # RAG model selection 51 | parser.add_argument('--model', type=str, choices=list(RAG_MODELS.keys()), 52 | default='vanilla', 53 | help='Which RAG model to use') 54 | 55 | # Evaluation options 56 | parser.add_argument('--output', type=str, default='evaluation_results.json', 57 | help='Output file name') 58 | 59 | return parser.parse_args() 60 | 61 | def load_evaluation_data(dataset_path: str, limit: int) -> List[Dict]: 62 | """Load and limit the evaluation dataset.""" 63 | try: 64 | with open(dataset_path, 'r') as f: 65 | eval_data = json.load(f) 66 | 67 | # Limit the number of questions if needed 68 | if limit and limit > 0: 69 | eval_data = eval_data[:limit] 70 | 71 | return eval_data 72 | except Exception as e: 73 | logger.error(f"Error loading dataset: {e}") 74 | return [] 75 | 76 | def create_rag_model(model_name: str, corpus_path: str, max_rounds: int = 3, top_k: int = 5) -> BaseRAG: 77 | """Create and configure a RAG model instance. 
78 | 79 | Args: 80 | model_name: Name of the RAG model to create 81 | corpus_path: Path to the corpus file 82 | max_rounds: Maximum number of rounds for agentic models 83 | top_k: Number of contexts to retrieve 84 | 85 | Returns: 86 | A configured RAG model instance 87 | """ 88 | # Check if model exists 89 | if model_name not in RAG_MODELS: 90 | raise ValueError(f"Unknown RAG model: {model_name}") 91 | 92 | # Create model instance 93 | model_class = RAG_MODELS[model_name] 94 | model = model_class(corpus_path) 95 | 96 | # Configure model 97 | model.set_top_k(top_k) 98 | 99 | # Set max rounds for agentic models 100 | if hasattr(model, 'set_max_rounds'): 101 | model.set_max_rounds(max_rounds) 102 | 103 | return model 104 | 105 | def run_single_question(model_name: str, question: str, corpus_path: str, max_rounds: int, top_k: int): 106 | """Run a single question through the specified RAG model.""" 107 | # Create and configure the model 108 | model = create_rag_model(model_name, corpus_path, max_rounds, top_k) 109 | 110 | # Get answer 111 | logger.info(f"\nQuestion: {question}") 112 | logger.info(f"Using {model_name} RAG model") 113 | 114 | # Handle different return signatures 115 | if model_name in ["agentic", "light"]: 116 | answer, contexts, rounds = model.answer_question(question) 117 | logger.info(f"\nAnswer: {answer}") 118 | logger.info(f"Retrieved in {rounds} rounds") 119 | else: 120 | answer, contexts = model.answer_question(question) 121 | logger.info(f"\nAnswer: {answer}") 122 | 123 | # Log contexts 124 | logger.info("\nContexts used:") 125 | for i, ctx in enumerate(contexts): 126 | logger.info(f"{i+1}. {ctx[:100]}...") 127 | 128 | return answer, contexts 129 | 130 | def main(): 131 | """Main function to run the RAG model.""" 132 | args = parse_arguments() 133 | 134 | # If a question is provided, run in single question mode 135 | if args.question: 136 | run_single_question( 137 | model_name=args.model, 138 | question=args.question, 139 | corpus_path=args.corpus, 140 | max_rounds=args.max_rounds, 141 | top_k=args.top_k 142 | ) 143 | return 144 | 145 | # Otherwise run in evaluation mode 146 | logger.info(f"Starting evaluation of {args.model} RAG model") 147 | logger.info(f"Max rounds: {args.max_rounds}, Top-k: {args.top_k}") 148 | logger.info(f"Evaluating top-k accuracy for k values: {args.eval_top_ks}") 149 | 150 | # Load evaluation data 151 | eval_data = load_evaluation_data(args.dataset, args.limit) 152 | if not eval_data: 153 | logger.error("No evaluation data available. 
Exiting.") 154 | return 155 | 156 | logger.info(f"Loaded {len(eval_data)} questions for evaluation") 157 | 158 | # Initialize evaluator for single model 159 | evaluator = RAGEvaluator( 160 | model_name=args.model, 161 | corpus_path=args.corpus, 162 | max_rounds=args.max_rounds, 163 | top_k=args.top_k, 164 | eval_top_ks=args.eval_top_ks 165 | ) 166 | 167 | # Run evaluation 168 | evaluation_summary = evaluator.run_single_model_evaluation( 169 | eval_data=eval_data, 170 | output_file=args.output 171 | ) 172 | 173 | logger.info("Evaluation complete!") 174 | 175 | if __name__ == "__main__": 176 | main() -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- 1 | from src.models.base_rag import BaseRAG 2 | from src.models.vanilla_rag import VanillaRAG 3 | from src.models.agentic_rag import AgenticRAG 4 | from src.models.light_agentic_rag import LightAgenticRAG 5 | 6 | __all__ = ['BaseRAG', 'VanillaRAG', 'AgenticRAG', 'LightAgenticRAG'] 7 | -------------------------------------------------------------------------------- /src/models/agentic_rag.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from typing import List, Dict, Tuple, Any 5 | from src.models.base_rag import BaseRAG 6 | from src.utils.utils import get_response_with_retry, REFLECTION_PROMPT 7 | 8 | # Configure logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | class AgenticRAG(BaseRAG): 13 | """ 14 | AgenticRAG implements an agentic approach to retrieval-augmented generation. 15 | It uses iterative reflection and retrieval to improve answer quality. 
16 | """ 17 | 18 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"): 19 | """Initialize the AgenticRAG system.""" 20 | super().__init__(corpus_path, cache_dir) 21 | self.max_rounds = 3 # Default max rounds for iterative retrieval 22 | 23 | def set_max_rounds(self, max_rounds: int): 24 | """Set the maximum number of retrieval rounds.""" 25 | self.max_rounds = max_rounds 26 | 27 | def analyze_completeness(self, question: str, context: List[str]) -> Dict: 28 | """Analyze if the retrieved context is sufficient to answer the question.""" 29 | try: 30 | context_text = "\n".join(context) 31 | prompt = f"""Question: {question} 32 | 33 | Retrieved Context: 34 | {context_text} 35 | 36 | {REFLECTION_PROMPT}""" 37 | 38 | try: 39 | response = get_response_with_retry(prompt) 40 | 41 | # Clean up response to ensure it's valid JSON 42 | response = response.strip() 43 | 44 | # Remove any markdown code block markers 45 | response = response.replace('```json', '').replace('```', '') 46 | 47 | # Try to find JSON-like content within the response 48 | import re 49 | json_match = re.search(r'\{.*\}', response, re.DOTALL) 50 | if json_match: 51 | response = json_match.group() 52 | 53 | # Parse the cleaned response 54 | result = json.loads(response) 55 | 56 | # Validate required fields 57 | required_fields = ["can_answer", "missing_info", "subquery", "current_understanding"] 58 | if not all(field in result for field in required_fields): 59 | raise ValueError("Missing required fields") 60 | 61 | # Ensure boolean type for can_answer 62 | result["can_answer"] = bool(result["can_answer"]) 63 | 64 | # Ensure non-empty subquery 65 | if not result["subquery"]: 66 | result["subquery"] = question 67 | 68 | return result 69 | 70 | except json.JSONDecodeError as e: 71 | logger.error(f"JSON parsing error: {e}") 72 | logger.error(f"Raw response: {response}") 73 | return { 74 | "can_answer": True, 75 | "missing_info": "", 76 | "subquery": question, 77 | "current_understanding": "Failed to parse reflection response." 78 | } 79 | 80 | except Exception as e: 81 | logger.error(f"Error in analyze_completeness: {e}") 82 | return { 83 | "can_answer": True, 84 | "missing_info": "", 85 | "subquery": question, 86 | "current_understanding": f"Error during analysis: {str(e)}" 87 | } 88 | 89 | def generate_answer(self, question: str, context: List[str], 90 | current_understanding: str = "") -> str: 91 | """Generate final answer based on all retrieved context.""" 92 | try: 93 | context_text = "\n".join(context) 94 | current_understanding_text = f"\nCurrent Understanding: {current_understanding}" if current_understanding else "" 95 | 96 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context. 97 | If the answer is a simple yes/no, just say "Yes." or "No." 98 | If the answer is a name, just give the name. 99 | If the answer is a date, just give the date. 100 | If the answer is a number, just give the number. 101 | If the answer requires a brief phrase, make it as concise as possible. 102 | 103 | Question: {question}{current_understanding_text} 104 | 105 | Context: 106 | {context_text} 107 | 108 | Remember: Be as concise as vanilla RAG - give ONLY the essential answer, nothing more. 
109 | Ans: """ 110 | 111 | return get_response_with_retry(prompt) 112 | except Exception as e: 113 | logger.error(f"Error generating answer: {e}") 114 | return "" 115 | 116 | def answer_question(self, question: str) -> Tuple[str, List[str], int]: 117 | """Answer question with iterative retrieval and reflection.""" 118 | all_contexts = [] 119 | round_count = 0 120 | current_query = question 121 | retrieval_history = [] 122 | 123 | while round_count < self.max_rounds: 124 | round_count += 1 125 | logger.info(f"Retrieval round {round_count}") 126 | 127 | # Retrieve relevant contexts 128 | new_contexts = self.retrieve(current_query) 129 | all_contexts.extend(new_contexts) 130 | 131 | # Remove duplicates while preserving order 132 | seen = set() 133 | all_contexts = [x for x in all_contexts if not (x in seen or seen.add(x))] 134 | 135 | # Record retrieval history 136 | retrieval_history.append({ 137 | "round": round_count, 138 | "query": current_query, 139 | "contexts": new_contexts 140 | }) 141 | 142 | # Analyze completeness 143 | analysis = self.analyze_completeness(question, all_contexts) 144 | 145 | if analysis["can_answer"]: 146 | # Generate and return final answer 147 | answer = self.generate_answer( 148 | question, 149 | all_contexts, 150 | analysis["current_understanding"] 151 | ) 152 | return answer, all_contexts, round_count 153 | 154 | # Update query for next round 155 | current_query = analysis["subquery"] 156 | logger.info(f"Generated subquery: {current_query}") 157 | 158 | # If max rounds reached, generate best possible answer 159 | answer = self.generate_answer( 160 | question, 161 | all_contexts, 162 | "Note: Maximum retrieval rounds reached. Providing best possible answer." 163 | ) 164 | return answer, all_contexts, round_count -------------------------------------------------------------------------------- /src/models/base_rag.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import pickle 5 | import torch 6 | import logging 7 | from typing import List, Dict, Any 8 | from tqdm import tqdm 9 | from sentence_transformers import SentenceTransformer 10 | import numpy as np 11 | from config.config import ( 12 | CACHE_DIR, 13 | RESULT_DIR, 14 | EMBEDDING_MODEL, 15 | EMBEDDING_BATCH_SIZE 16 | ) 17 | 18 | # Configure logging 19 | logging.basicConfig(level=logging.INFO) 20 | logger = logging.getLogger(__name__) 21 | 22 | class BaseRAG: 23 | def __init__(self, corpus_path: str = None, cache_dir: str = CACHE_DIR): 24 | """Initialize the BaseRAG system.""" 25 | self.cache_dir = cache_dir 26 | os.makedirs(cache_dir, exist_ok=True) 27 | os.makedirs(RESULT_DIR, exist_ok=True) 28 | 29 | self.model = SentenceTransformer(EMBEDDING_MODEL) 30 | self.corpus = {} 31 | self.corpus_embeddings = None 32 | self.embeddings = None # For compatibility with vanilla_retrieve 33 | self.sentences = None # For compatibility with vanilla_retrieve 34 | self.retrieval_cache = {} 35 | self.top_k = 5 # Default retrieval count 36 | 37 | if corpus_path: 38 | self.load_corpus(corpus_path) 39 | 40 | def load_corpus(self, corpus_path: str): 41 | """Load and process the document corpus.""" 42 | logger.info("Loading corpus...") 43 | with open(corpus_path, 'r') as f: 44 | documents = json.load(f) 45 | 46 | # Process documents into chunks 47 | self.corpus = { 48 | i: f"title: {doc['title']} content: {doc['text']}" 49 | for i, doc in enumerate(documents) 50 | } 51 | 52 | # Store sentences for vanilla retrieval 53 | self.sentences = 
list(self.corpus.values()) 54 | 55 | # Try to load cached embeddings 56 | cache_file = os.path.join(self.cache_dir, f'embeddings_{len(self.corpus)}.pt') 57 | 58 | if os.path.exists(cache_file): 59 | logger.info("Loading cached embeddings...") 60 | self.corpus_embeddings = torch.load(cache_file) 61 | self.embeddings = self.corpus_embeddings # For compatibility with vanilla_retrieve 62 | else: 63 | logger.info("Computing embeddings...") 64 | texts = list(self.corpus.values()) 65 | self.corpus_embeddings = self.encode_sentences_batch(texts) 66 | self.embeddings = self.corpus_embeddings # For compatibility with vanilla_retrieve 67 | torch.save(self.corpus_embeddings, cache_file) 68 | 69 | def encode_batch(self, texts: List[str], batch_size: int = EMBEDDING_BATCH_SIZE) -> np.ndarray: 70 | """Encode texts in batches.""" 71 | all_embeddings = [] 72 | for i in range(0, len(texts), batch_size): 73 | batch = texts[i:i + batch_size] 74 | embeddings = self.model.encode(batch, convert_to_tensor=True) 75 | all_embeddings.append(embeddings) 76 | return torch.cat(all_embeddings) 77 | 78 | def encode_sentences_batch(self, sentences: List[str], batch_size: int = 32) -> torch.Tensor: 79 | """Encode sentences in batches with memory management.""" 80 | all_embeddings = [] 81 | 82 | for i in tqdm(range(0, len(sentences), batch_size), desc="Encoding sentences"): 83 | batch = sentences[i:i + batch_size] 84 | 85 | gc.collect() 86 | if torch.cuda.is_available(): 87 | torch.cuda.empty_cache() 88 | 89 | with torch.no_grad(): 90 | embeddings = self.model.encode( 91 | batch, 92 | convert_to_tensor=True, 93 | show_progress_bar=False 94 | ) 95 | embeddings = embeddings.cpu() 96 | all_embeddings.append(embeddings) 97 | 98 | final_embeddings = torch.cat(all_embeddings, dim=0) 99 | del all_embeddings 100 | gc.collect() 101 | 102 | return final_embeddings 103 | 104 | def build_index(self, sentences: List[str], batch_size: int = 32): 105 | """Build the embedding index for the sentences.""" 106 | self.sentences = sentences 107 | 108 | # Try to load existing embeddings 109 | embedding_file = f'cache/embeddings_{len(sentences)}.pkl' 110 | if os.path.exists(embedding_file): 111 | try: 112 | with open(embedding_file, 'rb') as f: 113 | self.embeddings = pickle.load(f) 114 | logger.info(f"Embeddings loaded from {embedding_file}") 115 | return 116 | except Exception as e: 117 | logger.error(f"Error loading embeddings: {e}") 118 | 119 | # Build new embeddings 120 | self.embeddings = self.encode_sentences_batch(sentences, batch_size) 121 | 122 | # Save embeddings 123 | try: 124 | os.makedirs('cache', exist_ok=True) 125 | with open(embedding_file, 'wb') as f: 126 | pickle.dump(self.embeddings, f) 127 | except Exception as e: 128 | logger.error(f"Error saving embeddings: {e}") 129 | 130 | def retrieve(self, query: str) -> List[str]: 131 | """Retrieve similar sentences using query embedding.""" 132 | # Check cache first 133 | if query in self.retrieval_cache: 134 | return self.retrieval_cache[query] 135 | 136 | if self.corpus_embeddings is None or not self.corpus: 137 | return [] 138 | 139 | try: 140 | # Encode query 141 | with torch.no_grad(): 142 | query_embedding = self.model.encode([query], convert_to_tensor=True)[0] 143 | query_embedding = query_embedding.cpu() 144 | 145 | # Calculate similarities 146 | similarities = torch.nn.functional.cosine_similarity( 147 | query_embedding.unsqueeze(0), 148 | self.corpus_embeddings 149 | ) 150 | 151 | # Convert indices to list before using them 152 | top_k_scores, top_k_indices = 
similarities.topk(self.top_k) 153 | indices = top_k_indices.tolist() 154 | 155 | # Get results using integer indices 156 | results = [self.corpus[idx] for idx in indices] 157 | 158 | # Cache results 159 | self.retrieval_cache[query] = results 160 | return results 161 | 162 | except Exception as e: 163 | logger.error(f"Error in retrieve: {e}") 164 | return [] 165 | 166 | def set_top_k(self, top_k: int): 167 | """Set the number of top contexts to retrieve.""" 168 | self.top_k = top_k -------------------------------------------------------------------------------- /src/models/light_agentic_rag.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import time 4 | from typing import List, Dict, Tuple, Any 5 | from src.models.base_rag import BaseRAG 6 | from src.utils.utils import get_response_with_retry 7 | 8 | # Configure logging 9 | logging.basicConfig(level=logging.INFO) 10 | logger = logging.getLogger(__name__) 11 | 12 | class LightAgenticRAG(BaseRAG): 13 | """ 14 | LightAgenticRAG implements a memory-efficient agentic approach to retrieval-augmented generation. 15 | Instead of accumulating all retrieved contexts across iterations, it maintains a concise 16 | information summary that is continually refined with new retrieved information. 17 | """ 18 | 19 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"): 20 | """Initialize the LightAgenticRAG system.""" 21 | super().__init__(corpus_path, cache_dir) 22 | self.max_rounds = 3 # Default max rounds for iterative retrieval 23 | 24 | def set_max_rounds(self, max_rounds: int): 25 | """Set the maximum number of retrieval rounds.""" 26 | self.max_rounds = max_rounds 27 | 28 | def generate_or_refine_summary(self, question: str, new_contexts: List[str], 29 | current_summary: str = "") -> str: 30 | """ 31 | Generate a new summary or refine an existing one based on newly retrieved contexts. 32 | 33 | Args: 34 | question: The original question 35 | new_contexts: Newly retrieved context chunks 36 | current_summary: Current information summary (if any) 37 | 38 | Returns: 39 | A concise summary of all relevant information so far 40 | """ 41 | try: 42 | context_text = "\n".join(new_contexts) 43 | 44 | if not current_summary: 45 | # Generate initial summary 46 | prompt = f"""Please create a concise summary of the following information as it relates to answering this question: 47 | 48 | Question: {question} 49 | 50 | Information: 51 | {context_text} 52 | 53 | Your summary should: 54 | 1. Include all relevant facts that might help answer the question 55 | 2. Exclude irrelevant information 56 | 3. Be clear and concise 57 | 4. Preserve specific details, dates, numbers, and names that may be relevant 58 | 59 | Summary:""" 60 | else: 61 | # Refine existing summary with new information 62 | prompt = f"""Please refine the following information summary using newly retrieved information. 63 | 64 | Question: {question} 65 | 66 | Current summary: 67 | {current_summary} 68 | 69 | New information: 70 | {context_text} 71 | 72 | Your refined summary should: 73 | 1. Integrate new relevant facts with the existing summary 74 | 2. Remove redundancies 75 | 3. Remain concise while preserving all important information 76 | 4. Prioritize information that helps answer the question 77 | 5. 
Maintain specific details, dates, numbers, and names that may be relevant 78 | 79 | Refined summary:""" 80 | 81 | summary = get_response_with_retry(prompt) 82 | return summary 83 | 84 | except Exception as e: 85 | logger.error(f"Error generating/refining summary: {e}") 86 | # If error occurs, concatenate current summary with new contexts as fallback 87 | if current_summary: 88 | return f"{current_summary}\n\nNew information:\n{context_text}" 89 | return context_text 90 | 91 | def analyze_completeness(self, question: str, info_summary: str) -> Dict: 92 | """ 93 | Analyze if the current information summary is sufficient to answer the question. 94 | 95 | Args: 96 | question: The original question 97 | info_summary: Current information summary 98 | 99 | Returns: 100 | Dictionary with analysis results 101 | """ 102 | try: 103 | prompt = f"""Question: {question} 104 | 105 | Available Information: 106 | {info_summary} 107 | 108 | Based on the information provided, please analyze: 109 | 1. Can the question be answered completely with this information? (Yes/No) 110 | 2. What specific information is missing, if any? 111 | 3. What specific question should we ask to find the missing information? 112 | 4. Summarize our current understanding based on available information. 113 | 114 | Please format your response as a JSON object with these keys: 115 | - "can_answer": boolean 116 | - "missing_info": string 117 | - "subquery": string 118 | - "current_understanding": string""" 119 | 120 | try: 121 | response = get_response_with_retry(prompt) 122 | 123 | # Clean up response to ensure it's valid JSON 124 | response = response.strip() 125 | 126 | # Remove any markdown code block markers 127 | response = response.replace('```json', '').replace('```', '') 128 | 129 | # Try to find JSON-like content within the response 130 | import re 131 | json_match = re.search(r'\{.*\}', response, re.DOTALL) 132 | if json_match: 133 | response = json_match.group() 134 | 135 | # Parse the cleaned response 136 | result = json.loads(response) 137 | 138 | # Validate required fields 139 | required_fields = ["can_answer", "missing_info", "subquery", "current_understanding"] 140 | if not all(field in result for field in required_fields): 141 | raise ValueError("Missing required fields") 142 | 143 | # Ensure boolean type for can_answer 144 | result["can_answer"] = bool(result["can_answer"]) 145 | 146 | # Ensure non-empty subquery 147 | if not result["subquery"]: 148 | result["subquery"] = question 149 | 150 | return result 151 | 152 | except json.JSONDecodeError as e: 153 | logger.error(f"JSON parsing error: {e}") 154 | logger.error(f"Raw response: {response}") 155 | return { 156 | "can_answer": True, 157 | "missing_info": "", 158 | "subquery": question, 159 | "current_understanding": "Failed to parse reflection response." 160 | } 161 | 162 | except Exception as e: 163 | logger.error(f"Error in analyze_completeness: {e}") 164 | return { 165 | "can_answer": True, 166 | "missing_info": "", 167 | "subquery": question, 168 | "current_understanding": f"Error during analysis: {str(e)}" 169 | } 170 | 171 | def generate_answer(self, question: str, info_summary: str) -> str: 172 | """Generate final answer based on the information summary.""" 173 | try: 174 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context. 175 | If the answer is a simple yes/no, just say "Yes." or "No." 176 | If the answer is a name, just give the name. 
177 | If the answer is a date, just give the date. 178 | If the answer is a number, just give the number. 179 | If the answer requires a brief phrase, make it as concise as possible. 180 | 181 | Question: {question} 182 | 183 | Information Summary: 184 | {info_summary} 185 | 186 | Remember: Be concise - give ONLY the essential answer, nothing more. 187 | Ans: """ 188 | 189 | return get_response_with_retry(prompt) 190 | except Exception as e: 191 | logger.error(f"Error generating answer: {e}") 192 | return "" 193 | 194 | def answer_question(self, question: str) -> Tuple[str, List[str], int]: 195 | """ 196 | Answer question with iterative retrieval and information summary refinement. 197 | 198 | Returns: 199 | Tuple of (answer, last_retrieved_contexts, round_count) 200 | """ 201 | info_summary = "" # Start with empty summary 202 | round_count = 0 203 | current_query = question 204 | retrieval_history = [] 205 | last_contexts = [] # Store only the last retrieved contexts 206 | 207 | logger.info(f"LightAgenticRAG answering: {question}") 208 | 209 | while round_count < self.max_rounds: 210 | round_count += 1 211 | logger.info(f"Retrieval round {round_count}") 212 | 213 | # Retrieve relevant contexts for the current query 214 | new_contexts = self.retrieve(current_query) 215 | last_contexts = new_contexts # Save current contexts 216 | 217 | # Record retrieval history 218 | retrieval_history.append({ 219 | "round": round_count, 220 | "query": current_query, 221 | "contexts": new_contexts 222 | }) 223 | 224 | # Generate or refine information summary with new contexts 225 | info_summary = self.generate_or_refine_summary( 226 | question, 227 | new_contexts, 228 | info_summary 229 | ) 230 | 231 | logger.info(f"Information summary after round {round_count} (length: {len(info_summary)})") 232 | 233 | # Analyze if we can answer the question with current summary 234 | analysis = self.analyze_completeness(question, info_summary) 235 | 236 | if analysis["can_answer"]: 237 | # Generate and return final answer 238 | answer = self.generate_answer(question, info_summary) 239 | # We return the last retrieved contexts for evaluation purposes 240 | return answer, last_contexts, round_count 241 | 242 | # Update query for next round 243 | current_query = analysis["subquery"] 244 | logger.info(f"Generated subquery: {current_query}") 245 | 246 | # If max rounds reached, generate best possible answer 247 | logger.info(f"Reached maximum rounds ({self.max_rounds}). Generating final answer...") 248 | answer = self.generate_answer(question, info_summary) 249 | return answer, last_contexts, round_count -------------------------------------------------------------------------------- /src/models/vanilla_rag.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Tuple 3 | from src.models.base_rag import BaseRAG 4 | from src.utils.utils import get_response_with_retry 5 | 6 | # Configure logging 7 | logging.basicConfig(level=logging.INFO) 8 | logger = logging.getLogger(__name__) 9 | 10 | class VanillaRAG(BaseRAG): 11 | """ 12 | VanillaRAG performs basic retrieval-augmented generation without iterative refinement. 13 | It inherits basic retrieval and embedding functionality from BaseRAG. 
14 | """ 15 | 16 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"): 17 | """Initialize the VanillaRAG system.""" 18 | super().__init__(corpus_path, cache_dir) 19 | 20 | def retrieve(self, query: str) -> List[str]: 21 | """Retrieve documents for the given query using vector similarity.""" 22 | return super().retrieve(query) 23 | 24 | def answer_question(self, question: str) -> Tuple[str, List[str]]: 25 | """ 26 | Answer a question using vanilla RAG approach: 27 | 1. Retrieve relevant contexts 28 | 2. Pass the question and contexts to the LLM to generate an answer 29 | """ 30 | # Retrieve relevant contexts 31 | contexts = self.retrieve(question) 32 | 33 | # Generate answer using retrieved contexts 34 | answer = self.generate_answer(question, contexts) 35 | 36 | return answer, contexts 37 | 38 | def generate_answer(self, question: str, contexts: List[str]) -> str: 39 | """Generate an answer based on the retrieved contexts.""" 40 | try: 41 | context_text = "\n".join(contexts) 42 | 43 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context. 44 | If the answer is a simple yes/no, just say "Yes." or "No." 45 | If the answer is a name, just give the name. 46 | If the answer is a date, just give the date. 47 | If the answer is a number, just give the number. 48 | If the answer requires a brief phrase, make it as concise as possible. 49 | 50 | Question: {question} 51 | 52 | Context: 53 | {context_text} 54 | 55 | Remember: Be concise - give ONLY the essential answer, nothing more. 56 | Ans: """ 57 | 58 | return get_response_with_retry(prompt) 59 | except Exception as e: 60 | logger.error(f"Error generating answer: {e}") 61 | return "" -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import re 4 | import json 5 | import time 6 | import backoff 7 | from openai import OpenAI 8 | from ratelimit import limits, sleep_and_retry 9 | from collections import Counter 10 | from typing import List, Dict, Any 11 | from config.config import ( 12 | OPENAI_API_KEY, 13 | DEFAULT_MODEL, 14 | DEFAULT_MAX_TOKENS, 15 | CALLS_PER_MINUTE, 16 | PERIOD, 17 | MAX_RETRIES, 18 | RETRY_DELAY 19 | ) 20 | 21 | # Configure logging 22 | logging.basicConfig(level=logging.INFO) 23 | logger = logging.getLogger(__name__) 24 | 25 | # Configure OpenAI 26 | client = OpenAI(api_key=OPENAI_API_KEY) 27 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 28 | 29 | REFLECTION_PROMPT = """Based on the question and the retrieved context, analyze: 30 | 1. Can you confidently answer the question with the given context and your knowledge? 31 | 2. If not, what specific information is missing? 32 | 3. Generate a focused search query to find the missing information. 
33 | 34 | Format your response as: 35 | { 36 | "can_answer": true/false, 37 | "missing_info": "description of what information is missing", 38 | "subquery": "specific search query for missing information", 39 | "current_understanding": "brief summary of current understanding" 40 | } 41 | """ 42 | 43 | @sleep_and_retry 44 | @limits(calls=CALLS_PER_MINUTE, period=PERIOD) 45 | @backoff.on_exception( 46 | backoff.expo, 47 | (Exception), 48 | max_tries=MAX_RETRIES, 49 | max_time=300 50 | ) 51 | def get_response_with_retry(prompt: str, temperature: float = 0.0) -> str: 52 | """Get response from OpenAI API with retry logic.""" 53 | try: 54 | messages = [ 55 | {"role": "system", "content": "You are a helpful assistant."}, 56 | {"role": "user", "content": prompt} 57 | ] 58 | response = client.chat.completions.create( 59 | model=DEFAULT_MODEL, 60 | messages=messages, 61 | temperature=temperature, 62 | max_tokens=DEFAULT_MAX_TOKENS 63 | ) 64 | return response.choices[0].message.content.strip() 65 | except Exception as e: 66 | logger.error(f"Error in get_response_with_retry: {str(e)}") 67 | return "" 68 | 69 | def normalize_answer(text: str) -> str: 70 | """Normalize answer text for comparison.""" 71 | if not isinstance(text, str): 72 | return "" 73 | 74 | # Convert to lowercase 75 | text = text.lower() 76 | 77 | # Replace hyphen with space 78 | text = text.replace('-', ' ') 79 | 80 | # Remove punctuation 81 | text = re.sub(r'[^\w\s]', '', text) 82 | 83 | # Remove extra whitespace 84 | text = ' '.join(text.split()) 85 | 86 | return text 87 | 88 | def save_results(results: Dict, output_file: str, results_dir: str = 'result'): 89 | """Save evaluation results to file. 90 | 91 | Args: 92 | results: Dictionary containing results to save 93 | output_file: Filename for the results 94 | results_dir: Directory to save results in 95 | """ 96 | os.makedirs(results_dir, exist_ok=True) 97 | output_path = os.path.join(results_dir, output_file) 98 | with open(output_path, 'w', encoding='utf-8') as f: 99 | json.dump(results, f, ensure_ascii=False, indent=2) 100 | logger.info(f"Results saved to {output_path}") 101 | 102 | def evaluate_with_llm(generated: str, gold: str) -> bool: 103 | """Use LLM to evaluate if the generated answer correctly answers the question.""" 104 | if not isinstance(generated, str) or not isinstance(gold, str): 105 | return False 106 | 107 | prompt = f"""You are an expert evaluator. Please evaluate if the generated answer is correct by comparing it with the gold answer. 108 | 109 | Generated answer: {generated} 110 | Gold answer: {gold} 111 | 112 | The generated answer should be considered correct if it: 113 | 1. Contains the key information from the gold answer 114 | 2. Is factually accurate and consistent with the gold answer 115 | 3. Does not contain any contradicting information 116 | 117 | Respond with ONLY 'correct' or 'incorrect'. 118 | Response:""" 119 | 120 | try: 121 | response = get_response_with_retry(prompt, temperature=0.0) 122 | return response.strip().lower() == "correct" 123 | except Exception as e: 124 | logger.error(f"Error in LLM evaluation: {e}") 125 | return False 126 | 127 | def string_based_evaluation(generated: str, gold: str) -> dict: 128 | """Evaluate string similarity between generated and gold answers. 
129 | 130 | Args: 131 | generated: Generated answer string 132 | gold: Gold/ground truth answer string 133 | 134 | Returns: 135 | Dictionary containing accuracy, precision, recall metrics 136 | """ 137 | # Normalize answers 138 | normalized_prediction = normalize_answer(generated) 139 | normalized_ground_truth = normalize_answer(gold) 140 | 141 | # Calculate accuracy 142 | accuracy = 1 if normalized_ground_truth in normalized_prediction else 0 143 | 144 | # Calculate precision and recall 145 | prediction_tokens = normalized_prediction.split() 146 | ground_truth_tokens = normalized_ground_truth.split() 147 | 148 | # Handle yes/no/noanswer cases 149 | if (normalized_prediction in ["yes", "no", "noanswer"] and 150 | normalized_prediction != normalized_ground_truth) or \ 151 | (normalized_ground_truth in ["yes", "no", "noanswer"] and 152 | normalized_prediction != normalized_ground_truth): 153 | return { 154 | "accuracy": accuracy, 155 | "precision": 0, 156 | "recall": 0 157 | } 158 | 159 | # Calculate token overlap 160 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 161 | num_same = sum(common.values()) 162 | 163 | # Calculate precision and recall 164 | precision = 1.0 * num_same / len(prediction_tokens) if prediction_tokens else 0 165 | recall = 1.0 * num_same / len(ground_truth_tokens) if ground_truth_tokens else 0 166 | 167 | return { 168 | "accuracy": accuracy, 169 | "precision": precision, 170 | "recall": recall 171 | } --------------------------------------------------------------------------------