├── .gitignore
├── README.md
├── cache
│   ├── embeddings_11656.pt
│   ├── embeddings_6119.pt
│   └── embeddings_9811.pt
├── config
│   └── config.py
├── dataset
│   ├── 2wikimultihopqa.json
│   ├── 2wikimultihopqa_corpus.json
│   ├── hotpotqa.json
│   ├── hotpotqa_corpus.json
│   ├── musique.json
│   └── musique_corpus.json
├── requirements.txt
├── run.py
├── scripts
│   ├── README.md
│   └── run.sh
├── setup.py
└── src
    ├── __init__.py
    ├── evaluation
    │   ├── __init__.py
    │   └── evaluation.py
    ├── main.py
    ├── models
    │   ├── __init__.py
    │   ├── agentic_rag.py
    │   ├── base_rag.py
    │   ├── light_agentic_rag.py
    │   └── vanilla_rag.py
    └── utils
        ├── __init__.py
        └── utils.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Python cache files
2 | __pycache__/
3 | */__pycache__/
4 | */*/__pycache__/
5 | *.py[cod]
6 | *$py.class
7 | *.so
8 | .Python
9 |
10 | # Virtual environments
11 | env/
12 | venv/
13 | ENV/
14 | env.bak/
15 | venv.bak/
16 |
17 | # Project-specific directories
18 | result/
19 | cache/
20 |
21 | # Environment variables
22 | .env
23 | .env.*
24 |
25 | # Sensitive config backups
26 | *config_backup*
27 |
28 | # Editor and IDE files
29 | .vscode/
30 | .idea/
31 | *.swp
32 | *.swo
33 | .DS_Store
34 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Agentic RAG
2 |
3 |
9 |
10 |
11 | A modular and extensible Retrieval-Augmented Generation (RAG) system, with a vanilla RAG baseline and several prototype agentic RAG algorithms. Key features include:
12 |
13 | - **Research Prototype** - A minimal viable implementation for experimentation and academic exploration
14 | - **Modular Architecture** - Clean, decoupled codebase designed for easy customization and extension
15 | - **Extensible** - Easily add new RAG algorithms to the framework
16 |
17 |
18 | ## Installation and Configuration
19 |
20 | - Install dependencies:
21 | ```bash
22 | pip install -r requirements.txt
23 | ```
24 | - Set your OpenAI API key:
25 | ```bash
26 | # Create a .env file in the root directory with:
27 | OPENAI_API_KEY=your_api_key_here
28 | ```
29 |
30 | - Other configuration options can be modified in `config/config.py`
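
For reference, the key defaults in `config/config.py` look like this (values copied from the shipped config; adjust them to your setup):

```python
# config/config.py (excerpt)
DEFAULT_MODEL = "gpt-4o-mini"        # LLM used for generation and LLM-based evaluation
DEFAULT_MAX_TOKENS = 150             # Max tokens per completion
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # Embedding model for retrieval
CALLS_PER_MINUTE = 20                # API rate limit (with PERIOD = 60 seconds)
CACHE_DIR = "cache"                  # Where corpus embeddings are cached
RESULT_DIR = "result"                # Where evaluation results are written
```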
31 |
32 |
33 | ## Quick Start
34 |
35 | ### Running Evaluation on a Dataset
36 |
37 |
38 | ```bash
39 | python run.py --model MODEL_NAME --dataset path/to/dataset.json --corpus path/to/corpus.json
40 | ```
41 |
42 | Options:
43 | - `--max-rounds`: Maximum number of agent retrieval rounds (default: 3)
44 | - `--top-k`: Number of top contexts to retrieve (default: 5)
45 | - `--limit`: Number of questions to evaluate (default: 20)
46 |
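For example, to evaluate the agentic model on the bundled HotpotQA split with these options spelled out explicitly:

```bash
python run.py --model agentic \
    --dataset dataset/hotpotqa.json \
    --corpus dataset/hotpotqa_corpus.json \
    --max-rounds 3 --top-k 5 --limit 20
```
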
47 | ### Running a Single Question
48 |
49 | ```bash
50 | python run.py --model MODEL_NAME --question "Your question here" --corpus path/to/corpus.json
51 | ```
52 |
53 | ### Using the Script
54 |
55 | For convenience, you can use the provided script to run evaluations with specific RAG models on all datasets:
56 |
57 | ```bash
58 | # Run evaluations with the default model (LightAgenticRAG) on all datasets
59 | ./scripts/run.sh
60 |
61 | # Run evaluations with vanilla RAG on all datasets
62 | ./scripts/run.sh --model vanilla
63 | ```
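
The script also accepts `--max-rounds`, `--top-k`, and `--limit`, which it forwards to `run.py`. For example (the values shown are the script's current defaults):

```bash
./scripts/run.sh --model agentic --max-rounds 5 --top-k 3 --limit 50
```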
64 |
65 |
66 |
67 |
68 | ## Components
69 |
70 | | Component | Features/Description |
71 | |-----------|---------------------|
72 | | **BaseRAG** | • Loading and processing document corpus<br>• Computing and caching document embeddings<br>• Basic retrieval functionality |
73 | | **VanillaRAG** | • Single retrieval step for relevant contexts<br>• Direct answer generation from retrieved contexts |
74 | | **AgenticRAG** | • Multiple retrieval rounds with iterative refinement<br>• Reflection on retrieved information to identify missing details<br>• Generation of focused sub-queries for additional retrieval |
75 | | **LightAgenticRAG** | • Memory-efficient implementation of AgenticRAG<br>• Optimized for running on systems with limited resources |
76 | | **Evaluation** | • Answer accuracy (LLM evaluated)<br>• Retrieval metrics<br>• Performance efficiency<br>• String-based evaluation metrics |
77 |
78 | ## Adding a New RAG Model
79 |
80 | To add a new RAG algorithm:
81 |
82 | 1. Create a new class that extends `BaseRAG` in the `src/models` directory
83 | 2. Implement the required methods (`answer_question` at minimum)
84 | 3. Add your model to the `RAG_MODELS` dictionary in `src/main.py`
85 |
86 | ```python
87 | # Example of a new RAG model
88 | from src.models.base_rag import BaseRAG
89 |
90 | class MyNewRAG(BaseRAG):
91 | def answer_question(self, question: str):
92 | # Implementation here
93 | return answer, contexts
94 | ```
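
For step 3, registration is a one-line addition to the `RAG_MODELS` dictionary in `src/main.py` (and in `src/evaluation/evaluation.py`, which keeps its own copy). A sketch, assuming your class lives in a hypothetical `src/models/my_new_rag.py`:

```python
# src/main.py
from src.models.my_new_rag import MyNewRAG  # hypothetical module path for this example

RAG_MODELS = {
    "vanilla": VanillaRAG,
    "agentic": AgenticRAG,
    "light": LightAgenticRAG,
    "mynew": MyNewRAG,  # now selectable with --model mynew
}
```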
95 |
96 | ## Example Usage
97 |
98 | ```python
99 | from src.models.agentic_rag import AgenticRAG
100 |
101 | # Initialize RAG system
102 | rag = AgenticRAG('path/to/corpus.json')
103 | rag.set_max_rounds(3)
104 | rag.set_top_k(5)
105 |
106 | # Ask a question
107 | answer, contexts, rounds = rag.answer_question("What is the capital of France?")
108 | print(f"Answer: {answer}")
109 | print(f"Retrieved in {rounds} rounds")
110 | ```
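
`VanillaRAG` follows the same pattern but, having no retrieval rounds, returns only `(answer, contexts)`:

```python
from src.models.vanilla_rag import VanillaRAG

rag = VanillaRAG('path/to/corpus.json')
rag.set_top_k(5)
answer, contexts = rag.answer_question("What is the capital of France?")
```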
111 |
112 | ## Citation
113 | If you find this work helpful, please cite our recent paper:
114 |
115 | ```
116 | @article{zhang2025survey,
117 | title={A Survey of Graph Retrieval-Augmented Generation for Customized Large Language Models},
118 | author={Zhang, Qinggang and Chen, Shengyuan and Bei, Yuanchen and Yuan, Zheng and Zhou, Huachi and Hong, Zijin and Dong, Junnan and Chen, Hao and Chang, Yi and Huang, Xiao},
119 | journal={arXiv preprint arXiv:2501.13958},
120 | year={2025}
121 | }
122 | ```
--------------------------------------------------------------------------------
/cache/embeddings_11656.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_11656.pt
--------------------------------------------------------------------------------
/cache/embeddings_6119.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_6119.pt
--------------------------------------------------------------------------------
/cache/embeddings_9811.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/cache/embeddings_9811.pt
--------------------------------------------------------------------------------
/config/config.py:
--------------------------------------------------------------------------------
1 | """
2 | Configuration file for API keys and other settings.
3 | """
4 | import os
5 | from dotenv import load_dotenv
6 |
7 | # Load environment variables from .env file
8 | load_dotenv()
9 |
10 | # OpenAI API Configuration
11 | OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
12 |
13 | # API Rate Limiting Configuration
14 | CALLS_PER_MINUTE = 20
15 | PERIOD = 60
16 | MAX_RETRIES = 3
17 | RETRY_DELAY = 120
18 |
19 | # Model Configuration
20 | DEFAULT_MODEL = "gpt-4o-mini" # please specify your preferred LLM model
21 | DEFAULT_MAX_TOKENS = 150
22 |
23 | # Embedding Configuration
24 | EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2" # please specify your preferred embedding model
25 | EMBEDDING_BATCH_SIZE = 32
26 |
27 | # Cache Configuration
28 | CACHE_DIR = "cache"
29 | RESULT_DIR = "result"
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.9.0
2 | sentence-transformers>=2.2.0
3 | openai>=1.0.0
4 | tqdm>=4.62.0
5 | numpy>=1.21.0
6 | backoff>=1.11.0
7 | ratelimit>=2.2.1
8 | python-dotenv>=0.19.0
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | """
3 | Entry point for Agentic-RAG.
4 | Run this script to start the program.
5 | """
6 |
7 | from src.main import main
8 |
9 | if __name__ == "__main__":
10 | main()
--------------------------------------------------------------------------------
/scripts/README.md:
--------------------------------------------------------------------------------
1 | # Agentic-RAG Scripts
2 |
3 | This directory contains scripts for running and evaluating the Agentic-RAG system on various datasets.
4 |
5 | ## Available Scripts
6 |
7 | ### run.sh
8 |
9 | Runs evaluations on all available datasets in the `dataset` directory.
10 |
11 | ```bash
12 | # Run with default parameters (LightAgenticRAG)
13 | ./scripts/run.sh
14 |
15 | # Run with standard AgenticRAG
16 | ./scripts/run.sh --model agentic
17 |
18 | # Run with vanilla RAG
19 | ./scripts/run.sh --model vanilla
20 |
21 | # You can also adjust other parameters (via flags or by editing the script variables):
22 | # - --max-rounds / MAX_ROUNDS: Number of agent rounds (default: 5)
23 | # - --top-k / TOP_K: Number of top contexts to retrieve (default: 3)
24 | # - EVAL_TOP_KS: List of k values for evaluation (default: "5 10"; set in the script)
25 | # - --limit / LIMIT: Number of questions to evaluate per dataset (default: 50)
26 | ```
27 |
28 | #### What it does:
29 |
30 | 1. Runs evaluations on HotpotQA, MuSiQue, and 2WikiMultihopQA datasets
31 | 2. Saves results to separate files in the `result` directory
32 | 3. Shows progress and completion messages
33 |
34 | #### Output:
35 |
36 | The script will generate output files prefixed with the model name, for example:
37 | - `result/agentic_hotpotqa_evaluation.json` (for standard AgenticRAG)
38 | - `result/light_hotpotqa_evaluation.json` (for LightAgenticRAG)
39 |
40 | ### run_light_agentic.sh
41 |
42 | Runs evaluations using only the LightAgenticRAG algorithm on all available datasets.
43 |
44 | ```bash
45 | # Run with default parameters
46 | ./scripts/run_light_agentic.sh
47 | ```
48 |
49 | #### What it does:
50 |
51 | 1. Runs LightAgenticRAG evaluations on all datasets
52 | 2. Saves results to the `result/light_agentic` directory
53 | 3. Shows progress and completion messages
54 |
55 | #### Output:
56 |
57 | The script will generate the following output files:
58 | - `result/light_agentic/hotpotqa_light_evaluation.json`
59 | - `result/light_agentic/musique_light_evaluation.json`
60 | - `result/light_agentic/2wikimultihopqa_light_evaluation.json`
61 |
62 | ### test_light_agentic_rag.py
63 |
64 | Tests and compares LightAgenticRAG with standard AgenticRAG using a single question.
65 |
66 | ```bash
67 | # Test only LightAgenticRAG
68 | python scripts/test_light_agentic_rag.py --question "Your question here"
69 |
70 | # Compare LightAgenticRAG with standard AgenticRAG
71 | python scripts/test_light_agentic_rag.py --question "Your question here" --compare
72 | ```
73 |
74 | Additional options:
75 | - `--corpus`: Path to corpus file (default: dataset/hotpotqa_corpus.json)
76 | - `--max-rounds`: Maximum number of rounds (default: 3)
77 | - `--top-k`: Number of contexts to retrieve (default: 5)
78 |
79 | #### Output:
80 |
81 | Results are saved to `result/light_agentic_test_results.json`
82 |
83 | ## Adding New Scripts
84 |
85 | To add a new script:
86 | 1. Create your script in this directory
87 | 2. Make it executable with `chmod +x scripts/your_script.sh`
88 | 3. Add documentation in this README.md file
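
A new script can be as small as a wrapper around `run.py` with fixed parameters. A minimal sketch (the script name and output file are made up for the example):

```bash
#!/bin/bash
# scripts/quick_eval.sh -- hypothetical example wrapper
python run.py \
    --model vanilla \
    --dataset dataset/hotpotqa.json \
    --corpus dataset/hotpotqa_corpus.json \
    --limit 10 \
    --output quick_eval_results.json
```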
--------------------------------------------------------------------------------
/scripts/run.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # Agentic-RAG evaluation script
4 | # This script runs evaluations on all available datasets with configurable RAG model
5 |
6 | # Terminal colors
7 | GREEN='\033[0;32m'
8 | YELLOW='\033[1;33m'
9 | BLUE='\033[0;34m'
10 | NC='\033[0m' # No Color
11 |
12 | # Configuration parameters
13 | MAX_ROUNDS=5
14 | TOP_K=3
15 | EVAL_TOP_KS="5 10"
16 | LIMIT=50 # Number of questions to evaluate per dataset
17 | MODEL="light" # Default model: vanilla, agentic, light
18 |
19 | # Output directory for results
20 | RESULTS_DIR="result"
21 | mkdir -p $RESULTS_DIR
22 |
23 | # Show usage information
24 | usage() {
25 | echo "Usage: $0 [options]"
26 | echo ""
27 | echo "Options:"
28 | echo " -h, --help Show this help message"
29 | echo " -m, --max-rounds NUM Set maximum rounds (default: $MAX_ROUNDS)"
30 | echo " -k, --top-k NUM Set top-k contexts (default: $TOP_K)"
31 | echo " -l, --limit NUM Set number of questions per dataset (default: $LIMIT)"
32 | echo " -r, --model MODEL Set RAG model: vanilla, agentic, light (default: $MODEL)"
33 | echo ""
34 | exit 1
35 | }
36 |
37 | # Parse command line arguments
38 | while [[ $# -gt 0 ]]; do
39 | case $1 in
40 | -h|--help)
41 | usage
42 | ;;
43 | -m|--max-rounds)
44 | MAX_ROUNDS="$2"
45 | shift 2
46 | ;;
47 | -k|--top-k)
48 | TOP_K="$2"
49 | shift 2
50 | ;;
51 | -l|--limit)
52 | LIMIT="$2"
53 | shift 2
54 | ;;
55 | -r|--model)
56 | MODEL="$2"
57 | shift 2
58 | ;;
59 | *)
60 | echo "Unknown option: $1"
61 | usage
62 | ;;
63 | esac
64 | done
65 |
66 | # Validate RAG model
67 | if [[ "$MODEL" != "vanilla" && "$MODEL" != "agentic" && "$MODEL" != "light" ]]; then
68 | echo "Invalid RAG model: $MODEL. Must be 'vanilla', 'agentic', or 'light'."
69 | usage
70 | fi
71 |
72 | echo -e "${GREEN}Starting Agentic-RAG evaluations...${NC}"
73 | echo -e "Model: $MODEL, Max rounds: $MAX_ROUNDS, Top-K: $TOP_K, Eval Top-Ks: $EVAL_TOP_KS, Questions per dataset: $LIMIT"
74 | echo ""
75 |
76 | # Function to run evaluation on a dataset
77 | run_evaluation() {
78 | local dataset=$1
79 | local corpus=$2
80 | local output=$3
81 | local dataset_name=$(basename "$dataset" .json)
82 |
83 | echo -e "${BLUE}Evaluating ${dataset_name} with ${MODEL} model${NC}"
84 | echo "Dataset: $dataset"
85 | echo "Corpus: $corpus"
86 | echo "Output: $output"
87 |
88 | python run.py \
89 | --dataset "$dataset" \
90 | --corpus "$corpus" \
91 | --model "$MODEL" \
92 | --max-rounds $MAX_ROUNDS \
93 | --top-k $TOP_K \
94 | --eval-top-ks $EVAL_TOP_KS \
95 | --limit $LIMIT \
96 | --output "$output"
97 |
98 | echo -e "${GREEN}Evaluation complete for $dataset_name${NC}"
99 | echo ""
100 | }
101 |
102 | # HotpotQA dataset
103 | echo -e "${YELLOW}=== HotpotQA Dataset ===${NC}"
104 | run_evaluation \
105 | "dataset/hotpotqa.json" \
106 | "dataset/hotpotqa_corpus.json" \
107 | "${MODEL}_hotpotqa_evaluation.json"
108 |
109 | # MuSiQue dataset
110 | echo -e "${YELLOW}=== MuSiQue Dataset ===${NC}"
111 | run_evaluation \
112 | "dataset/musique.json" \
113 | "dataset/musique_corpus.json" \
114 | "${MODEL}_musique_evaluation.json"
115 |
116 | # 2WikiMultihopQA dataset
117 | echo -e "${YELLOW}=== 2WikiMultihopQA Dataset ===${NC}"
118 | run_evaluation \
119 | "dataset/2wikimultihopqa.json" \
120 | "dataset/2wikimultihopqa_corpus.json" \
121 | "${MODEL}_2wikimultihopqa_evaluation.json"
122 |
123 | echo -e "${GREEN}All evaluations complete!${NC}"
124 | echo -e "Results saved with prefix '${MODEL}_' in the $RESULTS_DIR directory."
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="agentic-rag",
5 | version="0.1.0",
6 | description="Agentic Retrieval-Augmented Generation with iterative retrieval",
7 | author="Your Name",
8 | packages=find_packages(),
9 | install_requires=[
10 | "torch",
11 | "sentence-transformers",
12 | "openai",
13 | "tqdm",
14 | "numpy",
15 | "backoff",
16 | "ratelimit",
17 | ],
18 | python_requires=">=3.7",
19 | )
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/evaluation/__init__.py
--------------------------------------------------------------------------------
/src/evaluation/evaluation.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import time
3 | from typing import Dict, List, Tuple, Any
4 | import json
5 | import os
6 | from tqdm import tqdm
7 |
8 | from src.utils.utils import (
9 | normalize_answer,
10 | evaluate_with_llm,
11 | string_based_evaluation,
12 | save_results
13 | )
14 | from src.models.vanilla_rag import VanillaRAG
15 | from src.models.agentic_rag import AgenticRAG
16 | from src.models.light_agentic_rag import LightAgenticRAG
17 | from config.config import RESULT_DIR
18 |
19 | # Configure logging
20 | logging.basicConfig(level=logging.INFO)
21 | logger = logging.getLogger(__name__)
22 |
23 | # Dictionary of available RAG models
24 | RAG_MODELS = {
25 | "vanilla": VanillaRAG,
26 | "agentic": AgenticRAG,
27 | "light": LightAgenticRAG
28 | }
29 |
30 | class RAGEvaluator:
31 | """Evaluator for RAG models."""
32 |
33 | def __init__(self, model_name: str, corpus_path: str, max_rounds: int = 3, top_k: int = 5,
34 | eval_top_ks: List[int] = [5, 10]):
35 | """Initialize the evaluator with corpus path and parameters.
36 |
37 | Args:
38 | model_name: Name of the RAG model to evaluate
39 | corpus_path: Path to the corpus file
40 | max_rounds: Maximum number of rounds for agentic RAG
41 | top_k: Number of contexts to retrieve
42 | eval_top_ks: List of k values for top-k accuracy evaluation
43 | """
44 | self.model_name = model_name
45 | self.corpus_path = corpus_path
46 | self.max_rounds = max_rounds
47 | self.top_k = top_k
48 | self.eval_top_ks = sorted(eval_top_ks) # Sort to ensure consistent processing
49 |
50 | # Create result directory if it doesn't exist
51 | os.makedirs(RESULT_DIR, exist_ok=True)
52 |
53 | # Initialize the RAG model
54 | self._initialize_model()
55 |
56 | def _initialize_model(self):
57 | """Initialize the specified RAG model."""
58 | if self.model_name not in RAG_MODELS:
59 | raise ValueError(f"Unknown RAG model: {self.model_name}")
60 |
61 | # Create model instance
62 | model_class = RAG_MODELS[self.model_name]
63 | self.model = model_class(self.corpus_path)
64 |
65 | # Configure model
66 | self.model.set_top_k(self.top_k)
67 |
68 | # Set max rounds for agentic models
69 | if hasattr(self.model, 'set_max_rounds'):
70 | self.model.set_max_rounds(self.max_rounds)
71 |
72 | logger.info(f"Initialized {self.model_name} model")
73 |
74 | def evaluate_question(self, question: str, gold_answer: str) -> Dict:
75 | """Evaluate the model on a single question."""
76 | # Run the model on the question
77 | start_time = time.time()
78 |
79 | # Handle different return signatures
80 | if self.model_name in ["agentic", "light"]:
81 | answer, contexts, rounds = self.model.answer_question(question)
82 | elapsed_time = time.time() - start_time
83 |
84 | # Evaluate answer with LLM
85 | is_correct = evaluate_with_llm(answer, gold_answer)
86 |
87 | return {
88 | "question": question,
89 | "gold_answer": gold_answer,
90 | "answer": answer,
91 | "contexts": contexts,
92 | "time": elapsed_time,
93 | "rounds": rounds,
94 | "is_correct": is_correct
95 | }
96 | else:
97 | answer, contexts = self.model.answer_question(question)
98 | elapsed_time = time.time() - start_time
99 |
100 | # Evaluate answer with LLM
101 | is_correct = evaluate_with_llm(answer, gold_answer)
102 |
103 | return {
104 | "question": question,
105 | "gold_answer": gold_answer,
106 | "answer": answer,
107 | "contexts": contexts,
108 | "time": elapsed_time,
109 | "is_correct": is_correct
110 | }
111 |
112 | def calculate_retrieval_metrics(self, retrieved_contexts: List[List[str]], answers: List[str]) -> Dict[str, float]:
113 | """Calculate retrieval-based metrics."""
114 | total = len(answers)
115 | found_in_context = 0
116 |
117 | # Initialize answer_in_top_k counters for each k in eval_top_ks
118 | answer_in_top_k = {k: 0 for k in self.eval_top_ks}
119 |
120 | for contexts, answer in zip(retrieved_contexts, answers):
121 | normalized_answer = normalize_answer(answer)
122 |
123 | # Check if answer is in any context
124 | for i, context in enumerate(contexts):
125 | if normalized_answer in normalize_answer(context):
126 | found_in_context += 1
127 | # Update counters for each k value
128 | for k in self.eval_top_ks:
129 | if i < k:
130 | answer_in_top_k[k] += 1
131 | break
132 |
133 | # Prepare result dictionary
134 | result = {
135 | "answer_found_in_context": found_in_context / total,
136 | "total_questions": total
137 | }
138 |
139 | # Add top-k metrics to result
140 | for k in self.eval_top_ks:
141 | result[f"answer_in_top{k}"] = answer_in_top_k[k] / total
142 |
143 | return result
144 |
145 | def run_single_model_evaluation(self, eval_data: List[Dict], output_file: str = "evaluation_results.json"):
146 | """Run evaluation of a single model on the given evaluation data."""
147 | results = []
148 |
149 | # Evaluation metrics
150 | total_questions = len(eval_data)
151 |
152 | # Initialize metrics dictionary with dynamic top-k keys
153 | metrics = {
154 | "total_time": 0,
155 | "answer_coverage": 0,
156 | "answer_accuracy": 0,
157 | "string_accuracy": 0,
158 | "string_precision": 0,
159 | "string_recall": 0
160 | }
161 |
162 | # Add top-k hits for each k in eval_top_ks
163 | for k in self.eval_top_ks:
164 | metrics[f"top{k}_hits"] = 0
165 |
166 | # Add rounds tracking for agentic models
167 | if self.model_name in ["agentic", "light"]:
168 | metrics["total_rounds"] = 0
169 |
170 | for item in tqdm(eval_data, desc=f"Evaluating {self.model_name}"):
171 | question = item['question']
172 | gold_answer = item['answer']
173 |
174 | # Evaluate the model on this question
175 | result = self.evaluate_question(
176 | question=question,
177 | gold_answer=gold_answer
178 | )
179 | results.append(result)
180 |
181 | # Update metrics
182 | metrics["total_time"] += result["time"]
183 | normalized_gold = normalize_answer(gold_answer)
184 |
185 | # String-based evaluation
186 | string_metrics = string_based_evaluation(
187 | result["answer"],
188 | gold_answer
189 | )
190 | metrics["string_accuracy"] += string_metrics["accuracy"]
191 | metrics["string_precision"] += string_metrics["precision"]
192 | metrics["string_recall"] += string_metrics["recall"]
193 |
194 | # Check retrieval coverage
195 | for i, ctx in enumerate(result["contexts"]):
196 | if normalized_gold in normalize_answer(ctx):
197 | metrics["answer_coverage"] += 1
198 | # Update counters for each k value
199 | for k in self.eval_top_ks:
200 | if i < k:
201 | metrics[f"top{k}_hits"] += 1
202 | break
203 |
204 | # Update rounds for agentic models
205 | if self.model_name in ["agentic", "light"] and "rounds" in result:
206 | metrics["total_rounds"] += result["rounds"]
207 |
208 | # Evaluate answer using LLM
209 | if result["is_correct"]:
210 | metrics["answer_accuracy"] += 1
211 |
212 | # Calculate average metrics
213 | avg_metrics = {
214 | "avg_time": metrics["total_time"] / total_questions,
215 | "answer_coverage": metrics["answer_coverage"] / total_questions * 100,
216 | "answer_accuracy": metrics["answer_accuracy"] / total_questions * 100,
217 | "string_accuracy": metrics["string_accuracy"] / total_questions * 100,
218 | "string_precision": metrics["string_precision"] / total_questions * 100,
219 | "string_recall": metrics["string_recall"] / total_questions * 100
220 | }
221 |
222 | # Add top-k coverage (renamed from accuracy) for each k in eval_top_ks
223 | for k in self.eval_top_ks:
224 | avg_metrics[f"top{k}_coverage"] = metrics[f"top{k}_hits"] / total_questions * 100
225 |
226 | # Add average rounds for agentic models
227 | if self.model_name in ["agentic", "light"]:
228 | avg_metrics["avg_rounds"] = metrics["total_rounds"] / total_questions
229 |
230 | # Organize metrics by category
231 | organized_metrics = {
232 | "performance": {
233 | "avg_time": avg_metrics["avg_time"]
234 | },
235 | "string_based": {
236 | "accuracy": avg_metrics["string_accuracy"],
237 | "precision": avg_metrics["string_precision"],
238 | "recall": avg_metrics["string_recall"]
239 | },
240 | "llm_evaluated": {
241 | "answer_accuracy": avg_metrics["answer_accuracy"]
242 | },
243 | "retrieval": {
244 | "answer_coverage": avg_metrics["answer_coverage"]
245 | }
246 | }
247 |
248 | # Add rounds for agentic models
249 | if self.model_name in ["agentic", "light"]:
250 | organized_metrics["performance"]["avg_rounds"] = avg_metrics["avg_rounds"]
251 |
252 | # Add top-k coverage metrics
253 | for k in self.eval_top_ks:
254 | organized_metrics["retrieval"][f"top{k}_coverage"] = avg_metrics[f"top{k}_coverage"]
255 |
256 | # Add raw metrics for backwards compatibility
257 | organized_metrics["raw"] = metrics
258 |
259 | # Prepare final evaluation summary
260 | evaluation_summary = {
261 | "model": self.model_name,
262 | "metrics": organized_metrics,
263 | "results": results
264 | }
265 |
266 | # Save results
267 | save_results(
268 | results=evaluation_summary,
269 | output_file=output_file,
270 | results_dir=RESULT_DIR
271 | )
272 |
273 | # Log results in three sections
274 | logger.info(f"\nEvaluation Summary for {self.model_name}:")
275 |
276 | # Performance metrics
277 | if self.model_name in ["agentic", "light"]:
278 | logger.info(f"Average time per question: {avg_metrics['avg_time']:.2f} seconds")
279 | logger.info(f"Average rounds per question: {avg_metrics['avg_rounds']:.2f}")
280 | else:
281 | logger.info(f"Average time per question: {avg_metrics['avg_time']:.2f} seconds")
282 |
283 | # 1. String-based metrics
284 | logger.info("\n1. String-based Metrics:")
285 | logger.info(f" • Accuracy: {avg_metrics['string_accuracy']:.2f}%")
286 | logger.info(f" • Precision: {avg_metrics['string_precision']:.2f}%")
287 | logger.info(f" • Recall: {avg_metrics['string_recall']:.2f}%")
288 |
289 | # 2. LLM evaluated metrics
290 | logger.info("\n2. LLM Evaluated Metrics:")
291 | logger.info(f" • Answer Accuracy: {avg_metrics['answer_accuracy']:.2f}%")
292 |
293 | # 3. Retrieval performance
294 | logger.info("\n3. Retrieval Performance:")
295 | logger.info(f" • Answer Coverage: {avg_metrics['answer_coverage']:.2f}%")
296 |
297 | # Log top-k coverage metrics
298 | for k in self.eval_top_ks:
299 | logger.info(f" • Top-{k} Coverage: {avg_metrics[f'top{k}_coverage']:.2f}%")
300 |
301 | return evaluation_summary
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | import argparse
3 | import json
4 | import logging
5 | import importlib
6 | from typing import Dict, List, Type
7 |
8 | from src.evaluation.evaluation import RAGEvaluator
9 | from src.models.base_rag import BaseRAG
10 | from src.models.vanilla_rag import VanillaRAG
11 | from src.models.agentic_rag import AgenticRAG
12 | from src.models.light_agentic_rag import LightAgenticRAG
13 |
14 | # Configure logging
15 | logging.basicConfig(level=logging.INFO,
16 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
17 | logger = logging.getLogger(__name__)
18 |
19 | # Dictionary of available RAG models
20 | RAG_MODELS = {
21 | "vanilla": VanillaRAG,
22 | "agentic": AgenticRAG,
23 | "light": LightAgenticRAG
24 | }
25 |
26 | def parse_arguments():
27 | """Parse command line arguments."""
28 | parser = argparse.ArgumentParser(description='Run RAG models')
29 |
30 | # Dataset arguments
31 | parser.add_argument('--dataset', type=str, default='dataset/hotpotqa.json',
32 | help='Path to the dataset file')
33 | parser.add_argument('--corpus', type=str, default='dataset/hotpotqa_corpus.json',
34 | help='Path to the corpus file')
35 | parser.add_argument('--limit', type=int, default=20,
36 | help='Number of questions to evaluate (default: 20)')
37 |
38 | # RAG configuration
39 | parser.add_argument('--max-rounds', type=int, default=3,
40 | help='Maximum number of agent rounds')
41 | parser.add_argument('--top-k', type=int, default=5,
42 | help='Number of top contexts to retrieve')
43 | parser.add_argument('--eval-top-ks', type=int, nargs='+', default=[5, 10],
44 | help='List of k values for top-k accuracy evaluation (default: [5, 10])')
45 |
46 | # Single question (optional)
47 | parser.add_argument('--question', type=str,
48 | help='Optional: Single question to answer')
49 |
50 | # RAG model selection
51 | parser.add_argument('--model', type=str, choices=list(RAG_MODELS.keys()),
52 | default='vanilla',
53 | help='Which RAG model to use')
54 |
55 | # Evaluation options
56 | parser.add_argument('--output', type=str, default='evaluation_results.json',
57 | help='Output file name')
58 |
59 | return parser.parse_args()
60 |
61 | def load_evaluation_data(dataset_path: str, limit: int) -> List[Dict]:
62 | """Load and limit the evaluation dataset."""
63 | try:
64 | with open(dataset_path, 'r') as f:
65 | eval_data = json.load(f)
66 |
67 | # Limit the number of questions if needed
68 | if limit and limit > 0:
69 | eval_data = eval_data[:limit]
70 |
71 | return eval_data
72 | except Exception as e:
73 | logger.error(f"Error loading dataset: {e}")
74 | return []
75 |
76 | def create_rag_model(model_name: str, corpus_path: str, max_rounds: int = 3, top_k: int = 5) -> BaseRAG:
77 | """Create and configure a RAG model instance.
78 |
79 | Args:
80 | model_name: Name of the RAG model to create
81 | corpus_path: Path to the corpus file
82 | max_rounds: Maximum number of rounds for agentic models
83 | top_k: Number of contexts to retrieve
84 |
85 | Returns:
86 | A configured RAG model instance
87 | """
88 | # Check if model exists
89 | if model_name not in RAG_MODELS:
90 | raise ValueError(f"Unknown RAG model: {model_name}")
91 |
92 | # Create model instance
93 | model_class = RAG_MODELS[model_name]
94 | model = model_class(corpus_path)
95 |
96 | # Configure model
97 | model.set_top_k(top_k)
98 |
99 | # Set max rounds for agentic models
100 | if hasattr(model, 'set_max_rounds'):
101 | model.set_max_rounds(max_rounds)
102 |
103 | return model
104 |
105 | def run_single_question(model_name: str, question: str, corpus_path: str, max_rounds: int, top_k: int):
106 | """Run a single question through the specified RAG model."""
107 | # Create and configure the model
108 | model = create_rag_model(model_name, corpus_path, max_rounds, top_k)
109 |
110 | # Get answer
111 | logger.info(f"\nQuestion: {question}")
112 | logger.info(f"Using {model_name} RAG model")
113 |
114 | # Handle different return signatures
115 | if model_name in ["agentic", "light"]:
116 | answer, contexts, rounds = model.answer_question(question)
117 | logger.info(f"\nAnswer: {answer}")
118 | logger.info(f"Retrieved in {rounds} rounds")
119 | else:
120 | answer, contexts = model.answer_question(question)
121 | logger.info(f"\nAnswer: {answer}")
122 |
123 | # Log contexts
124 | logger.info("\nContexts used:")
125 | for i, ctx in enumerate(contexts):
126 | logger.info(f"{i+1}. {ctx[:100]}...")
127 |
128 | return answer, contexts
129 |
130 | def main():
131 | """Main function to run the RAG model."""
132 | args = parse_arguments()
133 |
134 | # If a question is provided, run in single question mode
135 | if args.question:
136 | run_single_question(
137 | model_name=args.model,
138 | question=args.question,
139 | corpus_path=args.corpus,
140 | max_rounds=args.max_rounds,
141 | top_k=args.top_k
142 | )
143 | return
144 |
145 | # Otherwise run in evaluation mode
146 | logger.info(f"Starting evaluation of {args.model} RAG model")
147 | logger.info(f"Max rounds: {args.max_rounds}, Top-k: {args.top_k}")
148 | logger.info(f"Evaluating top-k accuracy for k values: {args.eval_top_ks}")
149 |
150 | # Load evaluation data
151 | eval_data = load_evaluation_data(args.dataset, args.limit)
152 | if not eval_data:
153 | logger.error("No evaluation data available. Exiting.")
154 | return
155 |
156 | logger.info(f"Loaded {len(eval_data)} questions for evaluation")
157 |
158 | # Initialize evaluator for single model
159 | evaluator = RAGEvaluator(
160 | model_name=args.model,
161 | corpus_path=args.corpus,
162 | max_rounds=args.max_rounds,
163 | top_k=args.top_k,
164 | eval_top_ks=args.eval_top_ks
165 | )
166 |
167 | # Run evaluation
168 | evaluation_summary = evaluator.run_single_model_evaluation(
169 | eval_data=eval_data,
170 | output_file=args.output
171 | )
172 |
173 | logger.info("Evaluation complete!")
174 |
175 | if __name__ == "__main__":
176 | main()
--------------------------------------------------------------------------------
/src/models/__init__.py:
--------------------------------------------------------------------------------
1 | from src.models.base_rag import BaseRAG
2 | from src.models.vanilla_rag import VanillaRAG
3 | from src.models.agentic_rag import AgenticRAG
4 | from src.models.light_agentic_rag import LightAgenticRAG
5 |
6 | __all__ = ['BaseRAG', 'VanillaRAG', 'AgenticRAG', 'LightAgenticRAG']
7 |
--------------------------------------------------------------------------------
/src/models/agentic_rag.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from typing import List, Dict, Tuple, Any
5 | from src.models.base_rag import BaseRAG
6 | from src.utils.utils import get_response_with_retry, REFLECTION_PROMPT
7 |
8 | # Configure logging
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 | class AgenticRAG(BaseRAG):
13 | """
14 | AgenticRAG implements an agentic approach to retrieval-augmented generation.
15 | It uses iterative reflection and retrieval to improve answer quality.
16 | """
17 |
18 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"):
19 | """Initialize the AgenticRAG system."""
20 | super().__init__(corpus_path, cache_dir)
21 | self.max_rounds = 3 # Default max rounds for iterative retrieval
22 |
23 | def set_max_rounds(self, max_rounds: int):
24 | """Set the maximum number of retrieval rounds."""
25 | self.max_rounds = max_rounds
26 |
27 | def analyze_completeness(self, question: str, context: List[str]) -> Dict:
28 | """Analyze if the retrieved context is sufficient to answer the question."""
29 | try:
30 | context_text = "\n".join(context)
31 | prompt = f"""Question: {question}
32 |
33 | Retrieved Context:
34 | {context_text}
35 |
36 | {REFLECTION_PROMPT}"""
37 |
38 | try:
39 | response = get_response_with_retry(prompt)
40 |
41 | # Clean up response to ensure it's valid JSON
42 | response = response.strip()
43 |
44 | # Remove any markdown code block markers
45 | response = response.replace('```json', '').replace('```', '')
46 |
47 | # Try to find JSON-like content within the response
48 | import re
49 | json_match = re.search(r'\{.*\}', response, re.DOTALL)
50 | if json_match:
51 | response = json_match.group()
52 |
53 | # Parse the cleaned response
54 | result = json.loads(response)
55 |
56 | # Validate required fields
57 | required_fields = ["can_answer", "missing_info", "subquery", "current_understanding"]
58 | if not all(field in result for field in required_fields):
59 | raise ValueError("Missing required fields")
60 |
61 | # Ensure boolean type for can_answer
62 | result["can_answer"] = bool(result["can_answer"])
63 |
64 | # Ensure non-empty subquery
65 | if not result["subquery"]:
66 | result["subquery"] = question
67 |
68 | return result
69 |
70 | except json.JSONDecodeError as e:
71 | logger.error(f"JSON parsing error: {e}")
72 | logger.error(f"Raw response: {response}")
73 | return {
74 | "can_answer": True,
75 | "missing_info": "",
76 | "subquery": question,
77 | "current_understanding": "Failed to parse reflection response."
78 | }
79 |
80 | except Exception as e:
81 | logger.error(f"Error in analyze_completeness: {e}")
82 | return {
83 | "can_answer": True,
84 | "missing_info": "",
85 | "subquery": question,
86 | "current_understanding": f"Error during analysis: {str(e)}"
87 | }
88 |
89 | def generate_answer(self, question: str, context: List[str],
90 | current_understanding: str = "") -> str:
91 | """Generate final answer based on all retrieved context."""
92 | try:
93 | context_text = "\n".join(context)
94 | current_understanding_text = f"\nCurrent Understanding: {current_understanding}" if current_understanding else ""
95 |
96 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context.
97 | If the answer is a simple yes/no, just say "Yes." or "No."
98 | If the answer is a name, just give the name.
99 | If the answer is a date, just give the date.
100 | If the answer is a number, just give the number.
101 | If the answer requires a brief phrase, make it as concise as possible.
102 |
103 | Question: {question}{current_understanding_text}
104 |
105 | Context:
106 | {context_text}
107 |
108 | Remember: Be as concise as vanilla RAG - give ONLY the essential answer, nothing more.
109 | Ans: """
110 |
111 | return get_response_with_retry(prompt)
112 | except Exception as e:
113 | logger.error(f"Error generating answer: {e}")
114 | return ""
115 |
116 | def answer_question(self, question: str) -> Tuple[str, List[str], int]:
117 | """Answer question with iterative retrieval and reflection."""
118 | all_contexts = []
119 | round_count = 0
120 | current_query = question
121 | retrieval_history = []
122 |
123 | while round_count < self.max_rounds:
124 | round_count += 1
125 | logger.info(f"Retrieval round {round_count}")
126 |
127 | # Retrieve relevant contexts
128 | new_contexts = self.retrieve(current_query)
129 | all_contexts.extend(new_contexts)
130 |
131 | # Remove duplicates while preserving order
132 | seen = set()
133 | all_contexts = [x for x in all_contexts if not (x in seen or seen.add(x))]
134 |
135 | # Record retrieval history
136 | retrieval_history.append({
137 | "round": round_count,
138 | "query": current_query,
139 | "contexts": new_contexts
140 | })
141 |
142 | # Analyze completeness
143 | analysis = self.analyze_completeness(question, all_contexts)
144 |
145 | if analysis["can_answer"]:
146 | # Generate and return final answer
147 | answer = self.generate_answer(
148 | question,
149 | all_contexts,
150 | analysis["current_understanding"]
151 | )
152 | return answer, all_contexts, round_count
153 |
154 | # Update query for next round
155 | current_query = analysis["subquery"]
156 | logger.info(f"Generated subquery: {current_query}")
157 |
158 | # If max rounds reached, generate best possible answer
159 | answer = self.generate_answer(
160 | question,
161 | all_contexts,
162 | "Note: Maximum retrieval rounds reached. Providing best possible answer."
163 | )
164 | return answer, all_contexts, round_count
--------------------------------------------------------------------------------
/src/models/base_rag.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import pickle
5 | import torch
6 | import logging
7 | from typing import List, Dict, Any
8 | from tqdm import tqdm
9 | from sentence_transformers import SentenceTransformer
10 | import numpy as np
11 | from config.config import (
12 | CACHE_DIR,
13 | RESULT_DIR,
14 | EMBEDDING_MODEL,
15 | EMBEDDING_BATCH_SIZE
16 | )
17 |
18 | # Configure logging
19 | logging.basicConfig(level=logging.INFO)
20 | logger = logging.getLogger(__name__)
21 |
22 | class BaseRAG:
23 | def __init__(self, corpus_path: str = None, cache_dir: str = CACHE_DIR):
24 | """Initialize the BaseRAG system."""
25 | self.cache_dir = cache_dir
26 | os.makedirs(cache_dir, exist_ok=True)
27 | os.makedirs(RESULT_DIR, exist_ok=True)
28 |
29 | self.model = SentenceTransformer(EMBEDDING_MODEL)
30 | self.corpus = {}
31 | self.corpus_embeddings = None
32 | self.embeddings = None # For compatibility with vanilla_retrieve
33 | self.sentences = None # For compatibility with vanilla_retrieve
34 | self.retrieval_cache = {}
35 | self.top_k = 5 # Default retrieval count
36 |
37 | if corpus_path:
38 | self.load_corpus(corpus_path)
39 |
40 | def load_corpus(self, corpus_path: str):
41 | """Load and process the document corpus."""
42 | logger.info("Loading corpus...")
43 | with open(corpus_path, 'r') as f:
44 | documents = json.load(f)
45 |
46 | # Process documents into chunks
47 | self.corpus = {
48 | i: f"title: {doc['title']} content: {doc['text']}"
49 | for i, doc in enumerate(documents)
50 | }
51 |
52 | # Store sentences for vanilla retrieval
53 | self.sentences = list(self.corpus.values())
54 |
55 | # Try to load cached embeddings
56 | cache_file = os.path.join(self.cache_dir, f'embeddings_{len(self.corpus)}.pt')
57 |
58 | if os.path.exists(cache_file):
59 | logger.info("Loading cached embeddings...")
60 | self.corpus_embeddings = torch.load(cache_file)
61 | self.embeddings = self.corpus_embeddings # For compatibility with vanilla_retrieve
62 | else:
63 | logger.info("Computing embeddings...")
64 | texts = list(self.corpus.values())
65 | self.corpus_embeddings = self.encode_sentences_batch(texts)
66 | self.embeddings = self.corpus_embeddings # For compatibility with vanilla_retrieve
67 | torch.save(self.corpus_embeddings, cache_file)
68 |
69 | def encode_batch(self, texts: List[str], batch_size: int = EMBEDDING_BATCH_SIZE) -> np.ndarray:
70 | """Encode texts in batches."""
71 | all_embeddings = []
72 | for i in range(0, len(texts), batch_size):
73 | batch = texts[i:i + batch_size]
74 | embeddings = self.model.encode(batch, convert_to_tensor=True)
75 | all_embeddings.append(embeddings)
76 | return torch.cat(all_embeddings)
77 |
78 | def encode_sentences_batch(self, sentences: List[str], batch_size: int = 32) -> torch.Tensor:
79 | """Encode sentences in batches with memory management."""
80 | all_embeddings = []
81 |
82 | for i in tqdm(range(0, len(sentences), batch_size), desc="Encoding sentences"):
83 | batch = sentences[i:i + batch_size]
84 |
85 | gc.collect()
86 | if torch.cuda.is_available():
87 | torch.cuda.empty_cache()
88 |
89 | with torch.no_grad():
90 | embeddings = self.model.encode(
91 | batch,
92 | convert_to_tensor=True,
93 | show_progress_bar=False
94 | )
95 | embeddings = embeddings.cpu()
96 | all_embeddings.append(embeddings)
97 |
98 | final_embeddings = torch.cat(all_embeddings, dim=0)
99 | del all_embeddings
100 | gc.collect()
101 |
102 | return final_embeddings
103 |
104 | def build_index(self, sentences: List[str], batch_size: int = 32):
105 | """Build the embedding index for the sentences."""
106 | self.sentences = sentences
107 |
108 | # Try to load existing embeddings
109 | embedding_file = f'cache/embeddings_{len(sentences)}.pkl'
110 | if os.path.exists(embedding_file):
111 | try:
112 | with open(embedding_file, 'rb') as f:
113 | self.embeddings = pickle.load(f)
114 | logger.info(f"Embeddings loaded from {embedding_file}")
115 | return
116 | except Exception as e:
117 | logger.error(f"Error loading embeddings: {e}")
118 |
119 | # Build new embeddings
120 | self.embeddings = self.encode_sentences_batch(sentences, batch_size)
121 |
122 | # Save embeddings
123 | try:
124 | os.makedirs('cache', exist_ok=True)
125 | with open(embedding_file, 'wb') as f:
126 | pickle.dump(self.embeddings, f)
127 | except Exception as e:
128 | logger.error(f"Error saving embeddings: {e}")
129 |
130 | def retrieve(self, query: str) -> List[str]:
131 | """Retrieve similar sentences using query embedding."""
132 | # Check cache first
133 | if query in self.retrieval_cache:
134 | return self.retrieval_cache[query]
135 |
136 | if self.corpus_embeddings is None or not self.corpus:
137 | return []
138 |
139 | try:
140 | # Encode query
141 | with torch.no_grad():
142 | query_embedding = self.model.encode([query], convert_to_tensor=True)[0]
143 | query_embedding = query_embedding.cpu()
144 |
145 | # Calculate similarities
146 | similarities = torch.nn.functional.cosine_similarity(
147 | query_embedding.unsqueeze(0),
148 | self.corpus_embeddings
149 | )
150 |
151 | # Convert indices to list before using them
152 | top_k_scores, top_k_indices = similarities.topk(self.top_k)
153 | indices = top_k_indices.tolist()
154 |
155 | # Get results using integer indices
156 | results = [self.corpus[idx] for idx in indices]
157 |
158 | # Cache results
159 | self.retrieval_cache[query] = results
160 | return results
161 |
162 | except Exception as e:
163 | logger.error(f"Error in retrieve: {e}")
164 | return []
165 |
166 | def set_top_k(self, top_k: int):
167 | """Set the number of top contexts to retrieve."""
168 | self.top_k = top_k
--------------------------------------------------------------------------------
/src/models/light_agentic_rag.py:
--------------------------------------------------------------------------------
1 | import json
2 | import logging
3 | import time
4 | from typing import List, Dict, Tuple, Any
5 | from src.models.base_rag import BaseRAG
6 | from src.utils.utils import get_response_with_retry
7 |
8 | # Configure logging
9 | logging.basicConfig(level=logging.INFO)
10 | logger = logging.getLogger(__name__)
11 |
12 | class LightAgenticRAG(BaseRAG):
13 | """
14 | LightAgenticRAG implements a memory-efficient agentic approach to retrieval-augmented generation.
15 | Instead of accumulating all retrieved contexts across iterations, it maintains a concise
16 | information summary that is continually refined with new retrieved information.
17 | """
18 |
19 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"):
20 | """Initialize the LightAgenticRAG system."""
21 | super().__init__(corpus_path, cache_dir)
22 | self.max_rounds = 3 # Default max rounds for iterative retrieval
23 |
24 | def set_max_rounds(self, max_rounds: int):
25 | """Set the maximum number of retrieval rounds."""
26 | self.max_rounds = max_rounds
27 |
28 | def generate_or_refine_summary(self, question: str, new_contexts: List[str],
29 | current_summary: str = "") -> str:
30 | """
31 | Generate a new summary or refine an existing one based on newly retrieved contexts.
32 |
33 | Args:
34 | question: The original question
35 | new_contexts: Newly retrieved context chunks
36 | current_summary: Current information summary (if any)
37 |
38 | Returns:
39 | A concise summary of all relevant information so far
40 | """
41 | try:
42 | context_text = "\n".join(new_contexts)
43 |
44 | if not current_summary:
45 | # Generate initial summary
46 | prompt = f"""Please create a concise summary of the following information as it relates to answering this question:
47 |
48 | Question: {question}
49 |
50 | Information:
51 | {context_text}
52 |
53 | Your summary should:
54 | 1. Include all relevant facts that might help answer the question
55 | 2. Exclude irrelevant information
56 | 3. Be clear and concise
57 | 4. Preserve specific details, dates, numbers, and names that may be relevant
58 |
59 | Summary:"""
60 | else:
61 | # Refine existing summary with new information
62 | prompt = f"""Please refine the following information summary using newly retrieved information.
63 |
64 | Question: {question}
65 |
66 | Current summary:
67 | {current_summary}
68 |
69 | New information:
70 | {context_text}
71 |
72 | Your refined summary should:
73 | 1. Integrate new relevant facts with the existing summary
74 | 2. Remove redundancies
75 | 3. Remain concise while preserving all important information
76 | 4. Prioritize information that helps answer the question
77 | 5. Maintain specific details, dates, numbers, and names that may be relevant
78 |
79 | Refined summary:"""
80 |
81 | summary = get_response_with_retry(prompt)
82 | return summary
83 |
84 | except Exception as e:
85 | logger.error(f"Error generating/refining summary: {e}")
86 | # If error occurs, concatenate current summary with new contexts as fallback
87 | if current_summary:
88 | return f"{current_summary}\n\nNew information:\n{context_text}"
89 | return context_text
90 |
91 | def analyze_completeness(self, question: str, info_summary: str) -> Dict:
92 | """
93 | Analyze if the current information summary is sufficient to answer the question.
94 |
95 | Args:
96 | question: The original question
97 | info_summary: Current information summary
98 |
99 | Returns:
100 | Dictionary with analysis results
101 | """
102 | try:
103 | prompt = f"""Question: {question}
104 |
105 | Available Information:
106 | {info_summary}
107 |
108 | Based on the information provided, please analyze:
109 | 1. Can the question be answered completely with this information? (Yes/No)
110 | 2. What specific information is missing, if any?
111 | 3. What specific question should we ask to find the missing information?
112 | 4. Summarize our current understanding based on available information.
113 |
114 | Please format your response as a JSON object with these keys:
115 | - "can_answer": boolean
116 | - "missing_info": string
117 | - "subquery": string
118 | - "current_understanding": string"""
119 |
120 | try:
121 | response = get_response_with_retry(prompt)
122 |
123 | # Clean up response to ensure it's valid JSON
124 | response = response.strip()
125 |
126 | # Remove any markdown code block markers
127 | response = response.replace('```json', '').replace('```', '')
128 |
129 | # Try to find JSON-like content within the response
130 | import re
131 | json_match = re.search(r'\{.*\}', response, re.DOTALL)
132 | if json_match:
133 | response = json_match.group()
134 |
135 | # Parse the cleaned response
136 | result = json.loads(response)
137 |
138 | # Validate required fields
139 | required_fields = ["can_answer", "missing_info", "subquery", "current_understanding"]
140 | if not all(field in result for field in required_fields):
141 | raise ValueError("Missing required fields")
142 |
143 | # Ensure boolean type for can_answer
144 | result["can_answer"] = bool(result["can_answer"])
145 |
146 | # Ensure non-empty subquery
147 | if not result["subquery"]:
148 | result["subquery"] = question
149 |
150 | return result
151 |
152 | except json.JSONDecodeError as e:
153 | logger.error(f"JSON parsing error: {e}")
154 | logger.error(f"Raw response: {response}")
155 | return {
156 | "can_answer": True,
157 | "missing_info": "",
158 | "subquery": question,
159 | "current_understanding": "Failed to parse reflection response."
160 | }
161 |
162 | except Exception as e:
163 | logger.error(f"Error in analyze_completeness: {e}")
164 | return {
165 | "can_answer": True,
166 | "missing_info": "",
167 | "subquery": question,
168 | "current_understanding": f"Error during analysis: {str(e)}"
169 | }
170 |
171 | def generate_answer(self, question: str, info_summary: str) -> str:
172 | """Generate final answer based on the information summary."""
173 | try:
174 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context.
175 | If the answer is a simple yes/no, just say "Yes." or "No."
176 | If the answer is a name, just give the name.
177 | If the answer is a date, just give the date.
178 | If the answer is a number, just give the number.
179 | If the answer requires a brief phrase, make it as concise as possible.
180 |
181 | Question: {question}
182 |
183 | Information Summary:
184 | {info_summary}
185 |
186 | Remember: Be concise - give ONLY the essential answer, nothing more.
187 | Ans: """
188 |
189 | return get_response_with_retry(prompt)
190 | except Exception as e:
191 | logger.error(f"Error generating answer: {e}")
192 | return ""
193 |
194 | def answer_question(self, question: str) -> Tuple[str, List[str], int]:
195 | """
196 | Answer question with iterative retrieval and information summary refinement.
197 |
198 | Returns:
199 | Tuple of (answer, last_retrieved_contexts, round_count)
200 | """
201 | info_summary = "" # Start with empty summary
202 | round_count = 0
203 | current_query = question
204 | retrieval_history = []
205 | last_contexts = [] # Store only the last retrieved contexts
206 |
207 | logger.info(f"LightAgenticRAG answering: {question}")
208 |
209 | while round_count < self.max_rounds:
210 | round_count += 1
211 | logger.info(f"Retrieval round {round_count}")
212 |
213 | # Retrieve relevant contexts for the current query
214 | new_contexts = self.retrieve(current_query)
215 | last_contexts = new_contexts # Save current contexts
216 |
217 | # Record retrieval history
218 | retrieval_history.append({
219 | "round": round_count,
220 | "query": current_query,
221 | "contexts": new_contexts
222 | })
223 |
224 | # Generate or refine information summary with new contexts
225 | info_summary = self.generate_or_refine_summary(
226 | question,
227 | new_contexts,
228 | info_summary
229 | )
230 |
231 | logger.info(f"Information summary after round {round_count} (length: {len(info_summary)})")
232 |
233 | # Analyze if we can answer the question with current summary
234 | analysis = self.analyze_completeness(question, info_summary)
235 |
236 | if analysis["can_answer"]:
237 | # Generate and return final answer
238 | answer = self.generate_answer(question, info_summary)
239 | # We return the last retrieved contexts for evaluation purposes
240 | return answer, last_contexts, round_count
241 |
242 | # Update query for next round
243 | current_query = analysis["subquery"]
244 | logger.info(f"Generated subquery: {current_query}")
245 |
246 | # If max rounds reached, generate best possible answer
247 | logger.info(f"Reached maximum rounds ({self.max_rounds}). Generating final answer...")
248 | answer = self.generate_answer(question, info_summary)
249 | return answer, last_contexts, round_count
--------------------------------------------------------------------------------
/src/models/vanilla_rag.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Dict, Tuple
3 | from src.models.base_rag import BaseRAG
4 | from src.utils.utils import get_response_with_retry
5 |
6 | # Configure logging
7 | logging.basicConfig(level=logging.INFO)
8 | logger = logging.getLogger(__name__)
9 |
10 | class VanillaRAG(BaseRAG):
11 | """
12 | VanillaRAG performs basic retrieval-augmented generation without iterative refinement.
13 | It inherits basic retrieval and embedding functionality from BaseRAG.
14 | """
15 |
16 | def __init__(self, corpus_path: str = None, cache_dir: str = "./cache"):
17 | """Initialize the VanillaRAG system."""
18 | super().__init__(corpus_path, cache_dir)
19 |
20 | def retrieve(self, query: str) -> List[str]:
21 | """Retrieve documents for the given query using vector similarity."""
22 | return super().retrieve(query)
23 |
24 | def answer_question(self, question: str) -> Tuple[str, List[str]]:
25 | """
26 | Answer a question using vanilla RAG approach:
27 | 1. Retrieve relevant contexts
28 | 2. Pass the question and contexts to the LLM to generate an answer
29 | """
30 | # Retrieve relevant contexts
31 | contexts = self.retrieve(question)
32 |
33 | # Generate answer using retrieved contexts
34 | answer = self.generate_answer(question, contexts)
35 |
36 | return answer, contexts
37 |
38 | def generate_answer(self, question: str, contexts: List[str]) -> str:
39 | """Generate an answer based on the retrieved contexts."""
40 | try:
41 | context_text = "\n".join(contexts)
42 |
43 | prompt = f"""You must give ONLY the direct answer in the most concise way possible. DO NOT explain or provide any additional context.
44 | If the answer is a simple yes/no, just say "Yes." or "No."
45 | If the answer is a name, just give the name.
46 | If the answer is a date, just give the date.
47 | If the answer is a number, just give the number.
48 | If the answer requires a brief phrase, make it as concise as possible.
49 |
50 | Question: {question}
51 |
52 | Context:
53 | {context_text}
54 |
55 | Remember: Be concise - give ONLY the essential answer, nothing more.
56 | Ans: """
57 |
58 | return get_response_with_retry(prompt)
59 | except Exception as e:
60 | logger.error(f"Error generating answer: {e}")
61 | return ""
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/chensyCN/Agentic-RAG/a17bdc8c8eaf231db5983a54703e4454f882ed9a/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import re
4 | import json
5 | import time
6 | import backoff
7 | from openai import OpenAI
8 | from ratelimit import limits, sleep_and_retry
9 | from collections import Counter
10 | from typing import List, Dict, Any
11 | from config.config import (
12 | OPENAI_API_KEY,
13 | DEFAULT_MODEL,
14 | DEFAULT_MAX_TOKENS,
15 | CALLS_PER_MINUTE,
16 | PERIOD,
17 | MAX_RETRIES,
18 | RETRY_DELAY
19 | )
20 |
21 | # Configure logging
22 | logging.basicConfig(level=logging.INFO)
23 | logger = logging.getLogger(__name__)
24 |
25 | # Configure OpenAI
26 | client = OpenAI(api_key=OPENAI_API_KEY)
27 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
28 |
29 | REFLECTION_PROMPT = """Based on the question and the retrieved context, analyze:
30 | 1. Can you confidently answer the question with the given context and your knowledge?
31 | 2. If not, what specific information is missing?
32 | 3. Generate a focused search query to find the missing information.
33 |
34 | Format your response as:
35 | {
36 | "can_answer": true/false,
37 | "missing_info": "description of what information is missing",
38 | "subquery": "specific search query for missing information",
39 | "current_understanding": "brief summary of current understanding"
40 | }
41 | """
42 |
43 | @sleep_and_retry
44 | @limits(calls=CALLS_PER_MINUTE, period=PERIOD)
45 | @backoff.on_exception(
46 | backoff.expo,
47 | (Exception),
48 | max_tries=MAX_RETRIES,
49 | max_time=300
50 | )
51 | def get_response_with_retry(prompt: str, temperature: float = 0.0) -> str:
52 | """Get response from OpenAI API with retry logic."""
53 | try:
54 | messages = [
55 | {"role": "system", "content": "You are a helpful assistant."},
56 | {"role": "user", "content": prompt}
57 | ]
58 | response = client.chat.completions.create(
59 | model=DEFAULT_MODEL,
60 | messages=messages,
61 | temperature=temperature,
62 | max_tokens=DEFAULT_MAX_TOKENS
63 | )
64 | return response.choices[0].message.content.strip()
65 | except Exception as e:
66 | logger.error(f"Error in get_response_with_retry: {str(e)}")
67 | return ""
68 |
69 | def normalize_answer(text: str) -> str:
70 | """Normalize answer text for comparison."""
71 | if not isinstance(text, str):
72 | return ""
73 |
74 | # Convert to lowercase
75 | text = text.lower()
76 |
77 | # Replace hyphen with space
78 | text = text.replace('-', ' ')
79 |
80 | # Remove punctuation
81 | text = re.sub(r'[^\w\s]', '', text)
82 |
83 | # Remove extra whitespace
84 | text = ' '.join(text.split())
85 |
86 | return text
87 |
88 | def save_results(results: Dict, output_file: str, results_dir: str = 'result'):
89 | """Save evaluation results to file.
90 |
91 | Args:
92 | results: Dictionary containing results to save
93 | output_file: Filename for the results
94 | results_dir: Directory to save results in
95 | """
96 | os.makedirs(results_dir, exist_ok=True)
97 | output_path = os.path.join(results_dir, output_file)
98 | with open(output_path, 'w', encoding='utf-8') as f:
99 | json.dump(results, f, ensure_ascii=False, indent=2)
100 | logger.info(f"Results saved to {output_path}")
101 |
102 | def evaluate_with_llm(generated: str, gold: str) -> bool:
103 | """Use LLM to evaluate if the generated answer correctly answers the question."""
104 | if not isinstance(generated, str) or not isinstance(gold, str):
105 | return False
106 |
107 | prompt = f"""You are an expert evaluator. Please evaluate if the generated answer is correct by comparing it with the gold answer.
108 |
109 | Generated answer: {generated}
110 | Gold answer: {gold}
111 |
112 | The generated answer should be considered correct if it:
113 | 1. Contains the key information from the gold answer
114 | 2. Is factually accurate and consistent with the gold answer
115 | 3. Does not contain any contradicting information
116 |
117 | Respond with ONLY 'correct' or 'incorrect'.
118 | Response:"""
119 |
120 | try:
121 | response = get_response_with_retry(prompt, temperature=0.0)
122 | return response.strip().lower() == "correct"
123 | except Exception as e:
124 | logger.error(f"Error in LLM evaluation: {e}")
125 | return False
126 |
127 | def string_based_evaluation(generated: str, gold: str) -> dict:
128 | """Evaluate string similarity between generated and gold answers.
129 |
130 | Args:
131 | generated: Generated answer string
132 | gold: Gold/ground truth answer string
133 |
134 | Returns:
135 | Dictionary containing accuracy, precision, recall metrics
136 | """
137 | # Normalize answers
138 | normalized_prediction = normalize_answer(generated)
139 | normalized_ground_truth = normalize_answer(gold)
140 |
141 | # Calculate accuracy
142 | accuracy = 1 if normalized_ground_truth in normalized_prediction else 0
143 |
144 | # Calculate precision and recall
145 | prediction_tokens = normalized_prediction.split()
146 | ground_truth_tokens = normalized_ground_truth.split()
147 |
148 | # Handle yes/no/noanswer cases
149 | if (normalized_prediction in ["yes", "no", "noanswer"] and
150 | normalized_prediction != normalized_ground_truth) or \
151 | (normalized_ground_truth in ["yes", "no", "noanswer"] and
152 | normalized_prediction != normalized_ground_truth):
153 | return {
154 | "accuracy": accuracy,
155 | "precision": 0,
156 | "recall": 0
157 | }
158 |
159 | # Calculate token overlap
160 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
161 | num_same = sum(common.values())
162 |
163 | # Calculate precision and recall
164 | precision = 1.0 * num_same / len(prediction_tokens) if prediction_tokens else 0
165 | recall = 1.0 * num_same / len(ground_truth_tokens) if ground_truth_tokens else 0
166 |
167 | return {
168 | "accuracy": accuracy,
169 | "precision": precision,
170 | "recall": recall
171 | }
--------------------------------------------------------------------------------