├── flashrag ├── __init__.py ├── version.py ├── judger │ ├── __init__.py │ └── judger.py ├── prompt │ ├── __init__.py │ ├── base_prompt.py │ └── selfask_examplars.py ├── config │ ├── __init__.py │ ├── basic_config.yaml │ └── config.py ├── utils │ ├── __init__.py │ ├── pred_parse.py │ ├── constants.py │ └── utils.py ├── dataset │ ├── __init__.py │ ├── utils.py │ └── dataset.py ├── refiner │ ├── __init__.py │ └── refiner.py ├── evaluator │ ├── __init__.py │ ├── utils.py │ ├── evaluator.py │ └── _bleu.py ├── retriever │ ├── __main__.py │ ├── __init__.py │ ├── utils.py │ ├── encoder.py │ └── reranker.py ├── generator │ ├── __init__.py │ ├── utils.py │ ├── stop_word_criteria.py │ ├── fid.py │ └── openai_generator.py └── pipeline │ ├── __init__.py │ ├── pipeline.py │ ├── branching_pipeline.py │ └── replug_utils.py ├── hopweaver ├── components │ ├── __init__.py │ ├── utils │ │ └── __init__.py │ ├── bridge │ │ └── __init__.py │ └── compare │ │ └── __init__.py ├── train_reranker │ ├── ds_stage0.json │ ├── README_EN.md │ └── train_reranker.py ├── config_lib │ └── example_config.yaml └── evaluation_system │ └── check_and_complete_evaluations.py ├── fig ├── intro.png ├── reranker.png ├── bridge_case.png ├── framework.png └── comparsion_case.png ├── requirements.txt ├── .gitignore ├── LICENSE ├── CORE_MODULES_CN.md ├── CORE_MODULES.md ├── datasets └── README.md └── ENVIRONMENT_SETUP_CN.md /flashrag/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hopweaver/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hopweaver/components/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hopweaver/components/bridge/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /hopweaver/components/compare/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /flashrag/version.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.4dev0" 2 | -------------------------------------------------------------------------------- /flashrag/judger/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.judger.judger import * -------------------------------------------------------------------------------- /flashrag/prompt/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.prompt.base_prompt import * -------------------------------------------------------------------------------- /flashrag/config/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.config.config import Config 2 | 3 | -------------------------------------------------------------------------------- /fig/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zh1yuShen/HopWeaver/HEAD/fig/intro.png 
-------------------------------------------------------------------------------- /fig/reranker.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zh1yuShen/HopWeaver/HEAD/fig/reranker.png -------------------------------------------------------------------------------- /fig/bridge_case.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zh1yuShen/HopWeaver/HEAD/fig/bridge_case.png -------------------------------------------------------------------------------- /fig/framework.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zh1yuShen/HopWeaver/HEAD/fig/framework.png -------------------------------------------------------------------------------- /fig/comparsion_case.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Zh1yuShen/HopWeaver/HEAD/fig/comparsion_case.png -------------------------------------------------------------------------------- /flashrag/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.utils.utils import * 2 | from flashrag.utils.pred_parse import * -------------------------------------------------------------------------------- /flashrag/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.dataset.dataset import * 2 | from flashrag.dataset.utils import * 3 | -------------------------------------------------------------------------------- /flashrag/refiner/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.refiner.refiner import * 2 | from flashrag.refiner.kg_refiner import * -------------------------------------------------------------------------------- /flashrag/evaluator/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.evaluator.evaluator import * 2 | from flashrag.evaluator.metrics import * 3 | -------------------------------------------------------------------------------- /flashrag/retriever/__main__.py: -------------------------------------------------------------------------------- 1 | from . 
import index_builder 2 | 3 | if __name__ == "__main__": 4 | index_builder.main() 5 | -------------------------------------------------------------------------------- /flashrag/retriever/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.retriever.retriever import * 2 | from flashrag.retriever.reranker import * 3 | from flashrag.retriever.utils import * -------------------------------------------------------------------------------- /flashrag/generator/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.generator.generator import * 2 | from flashrag.generator.openai_generator import * 3 | from flashrag.generator.utils import * 4 | -------------------------------------------------------------------------------- /flashrag/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | from flashrag.pipeline.pipeline import * 2 | from flashrag.pipeline.branching_pipeline import REPLUGPipeline, SuRePipeline 3 | from flashrag.pipeline.active_pipeline import IterativePipeline, SelfRAGPipeline, FLAREPipeline, SelfAskPipeline, IRCOTPipeline, RQRAGPipeline -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | transformers>=4.40.0 3 | openai 4 | datasets 5 | pandas 6 | numpy 7 | json5 8 | spacy 9 | regex 10 | faiss-cpu 11 | FlagEmbedding 12 | rank_bm25 13 | bm25s[core]==0.2.0 14 | rouge 15 | rouge-chinese 16 | evaluate 17 | tiktoken 18 | peft 19 | sentencepiece 20 | fschat 21 | protobuf 22 | PyYAML 23 | tqdm 24 | deepspeed 25 | base58 26 | matplotlib 27 | seaborn 28 | scipy 29 | scikit-learn 30 | pyarrow 31 | streamlit 32 | chonkie>=1.0.2, <1.1.0 33 | -------------------------------------------------------------------------------- /flashrag/evaluator/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import string 3 | 4 | 5 | def normalize_answer(s): 6 | def remove_articles(text): 7 | return re.sub(r"\b(a|an|the)\b", " ", text) 8 | 9 | def white_space_fix(text): 10 | return " ".join(text.split()) 11 | 12 | def remove_punc(text): 13 | exclude = set(string.punctuation) 14 | return "".join(ch for ch in text if ch not in exclude) 15 | 16 | def lower(text): 17 | return text.lower() 18 | 19 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Python 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.so 6 | .Python 7 | build/ 8 | develop-eggs/ 9 | dist/ 10 | downloads/ 11 | eggs/ 12 | .eggs/ 13 | lib/ 14 | lib64/ 15 | parts/ 16 | sdist/ 17 | var/ 18 | wheels/ 19 | *.egg-info/ 20 | .installed.cfg 21 | *.egg 22 | 23 | # 虚拟环境 24 | venv/ 25 | env/ 26 | ENV/ 27 | .env 28 | 29 | # 日志文件 30 | *.log 31 | 32 | # 本地配置文件 33 | *.local.yaml 34 | *.local.yml 35 | local_config.yaml 36 | 37 | # IDE配置 38 | .idea/ 39 | .vscode/ 40 | *.swp 41 | *.swo 42 | 43 | # 系统文件 44 | .DS_Store 45 | Thumbs.db 46 | 47 | # 数据集和大型文件(可根据需要修改) 48 | # datasets/ 49 | # models/ 50 | *.model 51 | *.bin 52 | *.pt 53 | *.pth 54 | 55 | # API密钥和敏感信息 56 | .env 57 | secrets.yaml 58 | 59 | # 缓存 60 | .cache/ 61 | -------------------------------------------------------------------------------- 
/flashrag/utils/pred_parse.py: -------------------------------------------------------------------------------- 1 | def selfask_pred_parse(pred): 2 | """Parsing the prediction results of self-ask format.""" 3 | FINAL_ANSWER_PREFIX = "So the final answer is: " 4 | 5 | lines = pred.split("\n") 6 | answer = "" 7 | for line in lines: 8 | if FINAL_ANSWER_PREFIX in line: 9 | answer = line.split(FINAL_ANSWER_PREFIX)[1].strip() 10 | break 11 | 12 | return answer 13 | 14 | 15 | def ircot_pred_parse(pred): 16 | FINAL_ANSWER_PREFIX = "So the answer is:" 17 | if FINAL_ANSWER_PREFIX in pred: 18 | answer = pred.split(FINAL_ANSWER_PREFIX)[1].strip() 19 | else: 20 | answer = pred 21 | return answer 22 | 23 | 24 | def basic_pred_parse(pred): 25 | return pred.split("\n")[0].strip() 26 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 HopWeaver Contributors 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /hopweaver/train_reranker/ds_stage0.json: -------------------------------------------------------------------------------- 1 | { 2 | "zero_optimization": { 3 | "stage": 0 4 | }, 5 | 6 | "fp16": { 7 | "enabled": "auto", 8 | "loss_scale": 0, 9 | "initial_scale_power": 10, 10 | "loss_scale_window": 1000, 11 | "hysteresis": 2, 12 | "min_loss_scale": 1 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": "auto", 22 | "torch_adam": true 23 | } 24 | }, 25 | 26 | "scheduler": { 27 | "type": "WarmupDecayLR", 28 | "params": { 29 | "warmup_min_lr": "auto", 30 | "warmup_max_lr": "auto", 31 | "warmup_num_steps": "auto", 32 | "total_num_steps": "auto" 33 | } 34 | }, 35 | 36 | "gradient_accumulation_steps": "auto", 37 | "gradient_clipping": "auto", 38 | "steps_per_print": 1000, 39 | "train_batch_size": "auto", 40 | "train_micro_batch_size_per_gpu": "auto", 41 | "wall_clock_breakdown": false 42 | } 43 | -------------------------------------------------------------------------------- /flashrag/utils/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_MODEL_DICT = { 2 | # chat 3 | "gpt-4o": "o200k_base", 4 | "gpt-4": "cl100k_base", 5 | "gpt-3.5-turbo": "cl100k_base", 6 | "gpt-3.5": "cl100k_base", # Common shorthand 7 | "gpt-35-turbo": "cl100k_base", # Azure deployment name 8 | # base 9 | "davinci-002": "cl100k_base", 10 | "babbage-002": "cl100k_base", 11 | # embeddings 12 | "text-embedding-ada-002": "cl100k_base", 13 | "text-embedding-3-small": "cl100k_base", 14 | "text-embedding-3-large": "cl100k_base", 15 | # DEPRECATED MODELS 16 | # text (DEPRECATED) 17 | "text-davinci-003": "p50k_base", 18 | "text-davinci-002": "p50k_base", 19 | "text-davinci-001": "r50k_base", 20 | "text-curie-001": "r50k_base", 21 | "text-babbage-001": "r50k_base", 22 | "text-ada-001": "r50k_base", 23 | "davinci": "r50k_base", 24 | "curie": "r50k_base", 25 | "babbage": "r50k_base", 26 | "ada": "r50k_base", 27 | # code (DEPRECATED) 28 | "code-davinci-002": "p50k_base", 29 | "code-davinci-001": "p50k_base", 30 | "code-cushman-002": "p50k_base", 31 | "code-cushman-001": "p50k_base", 32 | "davinci-codex": "p50k_base", 33 | "cushman-codex": "p50k_base", 34 | # edit (DEPRECATED) 35 | "text-davinci-edit-001": "p50k_edit", 36 | "code-davinci-edit-001": "p50k_edit", 37 | # old embeddings (DEPRECATED) 38 | "text-similarity-davinci-001": "r50k_base", 39 | "text-similarity-curie-001": "r50k_base", 40 | "text-similarity-babbage-001": "r50k_base", 41 | "text-similarity-ada-001": "r50k_base", 42 | "text-search-davinci-doc-001": "r50k_base", 43 | "text-search-curie-doc-001": "r50k_base", 44 | "text-search-babbage-doc-001": "r50k_base", 45 | "text-search-ada-doc-001": "r50k_base", 46 | "code-search-babbage-code-001": "r50k_base", 47 | "code-search-ada-code-001": "r50k_base", 48 | # open source 49 | "gpt2": "gpt2", 50 | "gpt-2": "gpt2", # Maintains consistency with gpt-4 51 | } 52 | -------------------------------------------------------------------------------- /CORE_MODULES_CN.md: -------------------------------------------------------------------------------- 1 | ### 1. 
桥接问题合成流程 2 | 3 | 桥接问题合成包含以下关键步骤: 4 | 5 | - **桥接实体识别**:从随机选取的源文档中,系统识别可以连接不同信息上下文的桥接实体,为多跳推理提供关键枢纽 6 | 7 | - **两阶段粗到细检索**: 8 | - 粗粒度检索:使用修改版最大边际相关性算法,平衡查询相关性、与源文档的差异性和已选文档间的多样性 9 | 10 | **多样性检索评分函数:** 11 | 12 | 多样性检索使用修改版最大边际相关性(MMR)算法: 13 | 14 | $$\text{Score}(d_i) = \lambda_1 \cdot \text{sim}(q, d_i) - \lambda_2 \cdot \text{sim}(d_i, d_s) - \lambda_3 \cdot \max_{d_j \in S} \text{sim}(d_i, d_j)$$ 15 | 16 | 其中: 17 | - $q$ 是查询 18 | - $d_i$ 是候选文档 19 | - $d_s$ 是源文档 20 | - $S$ 是已选文档集合 21 | - $\text{sim}(\cdot, \cdot)$ 表示余弦相似度 22 | - $\lambda_1, \lambda_2, \lambda_3$ 为权重参数,满足 $\lambda_1 + \lambda_2 + \lambda_3 = 1$ 23 | 24 | 此公式被 **diverse** 和 **rerank** 检索方法在粗检索阶段共同使用。 25 | 26 | - 细粒度重排序:使用经过对比学习微调的重排模型,进一步优化候选文档的排序 27 | 28 | - **多跳问题构建**: 29 | - 子问题合成:分别从源文档和补充文档合成子问题,以桥接实体为中心 30 | - 问题合成:将子问题融合为单一连贯的多跳问题,隐含推理路径而不直接暴露桥接实体 31 | - 验证与迭代:确保问题满足可回答性、多跳性和无捷径约束 32 | 33 | ### 2. 比较问题合成流程 34 | 35 | 比较问题合成遵循以下步骤: 36 | 37 | - **实体与属性识别**:从文档中识别主要实体及其3-5个简洁的事实属性值对,筛选出适合比较的属性 38 | 39 | - **筛选与查询合成**: 40 | - 确保实体和属性的具体性与可比性 41 | - 根据源实体合成检索查询,采用直接推荐或多样化搜索策略 42 | 43 | - **问题构建**: 44 | - 引导式比较:针对特定实体和属性进行精确比较 45 | - 开放式发现:在多个属性中寻找第一个有效的可比对 46 | - 合成包含两个实体信息的比较问题,如"哪个实体的属性值更高/更早/更大?" 47 | 48 | ### 3. 问题润色与质量保证 49 | 50 | 在桥接和比较问题合成过程中,系统实施严格的质量控制机制: 51 | 52 | - **问题润色与验证模块**: 53 | - 评估问题的可回答性、多跳性和语言质量 54 | - 根据评估结果分类为通过、调整、重构或拒绝四种结果 55 | - 确保每个问题涉及跨文档推理并隐藏桥接实体 56 | - 维持流畅性,不暴露中间推理步骤 57 | 58 | ### 4. 重排模型训练与优化 59 | 60 | 系统通过模拟关键步骤合成监督信号,提高检索质量: 61 | 62 | - **模拟反馈合成**: 63 | - 从桥接问题合成过程中提取成功和失败的文档样例 64 | - 构建对比训练三元组(查询、正例文档、负例文档) 65 | 66 | - **对比学习优化**: 67 | - 使用交叉熵损失函数指导模型区分互补文档 68 | - 直接从下游任务成功率中获取监督信号 69 | 70 | ### 5. 多维度评估系统 71 | 72 | 系统采用全面的评估框架,确保合成问题的质量: 73 | 74 | - **LLM-as-Judge评估**: 75 | - 使用大型语言模型作为评判,采用李克特量表评估每个问题 76 | - 实现自一致性评估方法,确保评估结果的稳定性和可重现性 77 | - 通过多次重复评估同一输入,分析评估结果的一致性 78 | 79 | - **可回答性和难度评估**: 80 | - **Q-Only条件**:求解器仅接收问题,测试问题的基线可回答性,主要依赖求解器的内部知识和推理能力 81 | - **Q+Docs条件**:求解器接收问题及所有支撑文档,模拟黄金检索场景,评估问题在获得必要证据时的可回答性 82 | - **性能差异分析**:通过Q-Only到Q+Docs的性能提升来判断问题是否具有挑战性,需要跨文档推理而非仅依赖预训练知识 83 | 84 | - **证据可获取性评估**: 85 | - **检索质量评估**:使用多种检索方法获取top-k文档,评估合成问题的证据在语料库中的可获取程度 86 | - **多维检索指标**:记录MAP(平均精度)、RECALL@k(前k召回率)、NDCG@k(归一化折扣累积增益)和Support F1等指标 87 | - **证据完整性验证**:确保合成的问题具有完整的证据支撑,避免无法回答的问题进入最终数据集 -------------------------------------------------------------------------------- /flashrag/dataset/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Any, Union 2 | import numpy as np 3 | from flashrag.dataset import Dataset 4 | 5 | 6 | def convert_numpy(obj: Union[Dict, list, np.ndarray, np.generic]) -> Any: 7 | """Recursively convert numpy objects in nested dictionaries or lists to native Python types.""" 8 | if isinstance(obj, dict): 9 | return {k: convert_numpy(v) for k, v in obj.items()} 10 | elif isinstance(obj, list): 11 | return [convert_numpy(i) for i in obj] 12 | elif isinstance(obj, np.ndarray): 13 | return obj.tolist() # Convert numpy arrays to lists 14 | elif isinstance(obj, (np.integer, np.floating)): 15 | return obj.item() # Convert numpy scalars to native Python scalars 16 | elif isinstance(obj, np.float32): 17 | return float(obj) 18 | else: 19 | return obj # Return the object as-is if it's neither a dict, list, nor numpy type 20 | 21 | 22 | def filter_dataset(dataset: Dataset, filter_func=None): 23 | if filter_func is None: 24 | return dataset 25 | data = dataset.data 26 | for item in data: 27 | if not filter_func(item): 28 | data.remove(item) 29 | return Dataset(config=dataset.config, data=data) 30 | 31 | 32 | def 
split_dataset(dataset: Dataset, split_symbol: list): 33 | assert len(split_symbol) == len(dataset) 34 | 35 | data = dataset.data 36 | data_split = {symbol: [] for symbol in set(split_symbol)} 37 | for symbol in set(split_symbol): 38 | symbol_data = [x for x, x_symbol in zip(data, split_symbol) if x_symbol == symbol] 39 | data_split[symbol] = Dataset(config=dataset.config, data=symbol_data) 40 | 41 | return data_split 42 | 43 | 44 | def merge_dataset(dataset_split: dict, split_symbol: list): 45 | assert len(split_symbol) == sum([len(data) for data in dataset_split.values()]) 46 | dataset_split_iter = {symbol: iter(dataset.data) for symbol, dataset in dataset_split.items()} 47 | 48 | final_data = [] 49 | for item_symbol in split_symbol: 50 | final_data.append(next(dataset_split_iter[item_symbol])) 51 | final_dataset = Dataset(config=list(dataset_split.values())[0].config, data=final_data) 52 | 53 | return final_dataset 54 | 55 | 56 | def get_batch_dataset(dataset: Dataset, batch_size=16): 57 | data = dataset.data 58 | for idx in range(0, len(data), batch_size): 59 | batched_data = data[idx : idx + batch_size] 60 | batch_dataset = Dataset(config=dataset.config, data=batched_data) 61 | yield batch_dataset 62 | 63 | 64 | def merge_batch_dataset(dataset_list: Dataset): 65 | dataset = dataset_list[0] 66 | total_data = [] 67 | for batch_dataset in dataset_list: 68 | total_data.extend(batch_dataset.data) 69 | dataset = Dataset(config=dataset.config, data=total_data) 70 | return dataset 71 | -------------------------------------------------------------------------------- /flashrag/generator/utils.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | 4 | def resolve_max_tokens(params: dict, generation_params: dict, prioritize_new_tokens: bool = False) -> dict: 5 | """ 6 | Resolve and validate max_tokens parameters from both params and generation_params. 7 | 8 | Args: 9 | params: Dictionary containing user-provided parameters 10 | generation_params: Dictionary containing generation-specific parameters 11 | prioritize_new_tokens: If True, max_new_tokens takes precedence over max_tokens 12 | If False, max_tokens takes precedence (default behavior) 13 | 14 | Returns: 15 | Updated generation_params dictionary 16 | """ 17 | 18 | def get_token_params(param_dict: dict) -> tuple: 19 | """Extract max_tokens and max_new_tokens from a parameter dictionary.""" 20 | return (param_dict.pop("max_tokens", None), param_dict.pop("max_new_tokens", None)) 21 | 22 | def resolve_tokens(max_tokens: int, max_new_tokens: int) -> int: 23 | """ 24 | Resolve between max_tokens and max_new_tokens values based on priority. 25 | Returns the resolved token value or None if no valid value found. 26 | """ 27 | # If either value is None, return the non-None value 28 | if max_tokens is None: 29 | return max_new_tokens 30 | if max_new_tokens is None: 31 | return max_tokens 32 | 33 | # Both values exist but are different 34 | if max_tokens != max_new_tokens: 35 | if prioritize_new_tokens: 36 | warnings.warn( 37 | f"max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens}) " 38 | f"are different. Using max_new_tokens value as it has priority." 39 | ) 40 | return max_new_tokens 41 | else: 42 | warnings.warn( 43 | f"max_tokens ({max_tokens}) and max_new_tokens ({max_new_tokens}) " 44 | f"are different. Using max_tokens value as it has priority." 
45 | ) 46 | return max_tokens 47 | 48 | # Both values are equal 49 | return max_tokens 50 | 51 | # Try to resolve from params first, then fall back to generation_params 52 | max_tokens, max_new_tokens = get_token_params(params) 53 | final_max_tokens = resolve_tokens(max_tokens, max_new_tokens) 54 | 55 | # If no valid tokens found in params, try generation_params 56 | if final_max_tokens is None: 57 | max_tokens, max_new_tokens = get_token_params(generation_params) 58 | final_max_tokens = resolve_tokens(max_tokens, max_new_tokens) 59 | 60 | generation_params.pop("max_new_tokens", None) 61 | generation_params.pop("max_tokens", None) 62 | if final_max_tokens is not None: 63 | if prioritize_new_tokens: 64 | generation_params["max_new_tokens"] = final_max_tokens 65 | else: 66 | generation_params["max_tokens"] = final_max_tokens 67 | return generation_params 68 | -------------------------------------------------------------------------------- /flashrag/evaluator/evaluator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from flashrag.evaluator.metrics import BaseMetric 3 | 4 | 5 | class Evaluator: 6 | """Evaluator is used to summarize the results of all metrics.""" 7 | 8 | def __init__(self, config): 9 | self.config = config 10 | self.save_dir = config["save_dir"] 11 | 12 | self.save_metric_flag = config["save_metric_score"] 13 | self.save_data_flag = config["save_intermediate_data"] 14 | self.metrics = [metric.lower() for metric in self.config["metrics"]] 15 | 16 | self.avaliable_metrics = self._collect_metrics() 17 | 18 | self.metric_class = {} 19 | for metric in self.metrics: 20 | if metric in self.avaliable_metrics: 21 | self.metric_class[metric] = self.avaliable_metrics[metric](self.config) 22 | else: 23 | print(f"{metric} has not been implemented!") 24 | raise NotImplementedError 25 | 26 | def _collect_metrics(self): 27 | """Collect all classes based on ```BaseMetric```.""" 28 | 29 | def find_descendants(base_class, subclasses=None): 30 | if subclasses is None: 31 | subclasses = set() 32 | 33 | direct_subclasses = base_class.__subclasses__() 34 | for subclass in direct_subclasses: 35 | if subclass not in subclasses: 36 | subclasses.add(subclass) 37 | find_descendants(subclass, subclasses) 38 | return subclasses 39 | 40 | avaliable_metrics = {} 41 | for cls in find_descendants(BaseMetric): 42 | metric_name = cls.metric_name 43 | avaliable_metrics[metric_name] = cls 44 | return avaliable_metrics 45 | 46 | def evaluate(self, data): 47 | """Calculate all metric indicators and summarize them.""" 48 | 49 | result_dict = {} 50 | for metric in self.metrics: 51 | try: 52 | metric_result, metric_scores = self.metric_class[metric].calculate_metric(data) 53 | result_dict.update(metric_result) 54 | 55 | for metric_score, item in zip(metric_scores, data): 56 | item.update_evaluation_score(metric, metric_score) 57 | except Exception as e: 58 | print(f"Error in {metric}!") 59 | print(e) 60 | continue 61 | 62 | if self.save_metric_flag: 63 | self.save_metric_score(result_dict) 64 | 65 | if self.save_data_flag: 66 | self.save_data(data) 67 | 68 | return result_dict 69 | 70 | def save_metric_score(self, result_dict, file_name="metric_score.txt"): 71 | save_path = os.path.join(self.save_dir, file_name) 72 | with open(save_path, "w", encoding="utf-8") as f: 73 | for k, v in result_dict.items(): 74 | f.write(f"{k}: {v}\n") 75 | 76 | def save_data(self, data, file_name="intermediate_data.json"): 77 | """Save the evaluated data, including the raw data and the 
score of each data 78 | sample on each metric.""" 79 | 80 | save_path = os.path.join(self.save_dir, file_name) 81 | 82 | data.save(save_path) 83 | -------------------------------------------------------------------------------- /hopweaver/config_lib/example_config.yaml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------Global Paths------------------------------------------------# 2 | # Paths to models 3 | model2path: 4 | e5: "ret_model/e5-base-v2" 5 | gte: "ret_model/iic/gte_sentence-embedding_multilingual-base" 6 | e5_instruct: "ret_model/Ceceliachenen/multilingual-e5-large-instruct" 7 | bge: "BAAI/bge-base-en-v1.5" 8 | contriever: "facebook/contriever" 9 | llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf" 10 | llama2-7B: "meta-llama/Llama-2-7b-hf" 11 | llama2-13B: "meta-llama/Llama-2-13b-hf" 12 | llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf" 13 | llama3-8B-instruct: "model/Llama-3.1-8B-Instruct" 14 | 15 | openai_setting: 16 | # Pooling methods for each embedding model 17 | model2pooling: 18 | e5: "mean" 19 | gte: "cls" 20 | e5_instruct: "mean" 21 | bge: "cls" 22 | contriever: "mean" 23 | jina: 'mean' 24 | dpr: cls 25 | 26 | # Indexes path for retrieval models 27 | method2index: 28 | e5: 'index/e5_Flat_fulldoc.index' 29 | gte: 'index/gte_Flat.index' 30 | e5_instruct: 'index/e5-in_Flat.index' 31 | bm25: '' 32 | contriever: ~ 33 | 34 | method2corpus: 35 | e5: 'wiki18_fulldoc.jsonl' 36 | gte: 'wiki18_fulldoc_trimmed_4096.jsonl' 37 | e5_instruct: 'wiki18_fulldoc_trimmed_4096_chunk.jsonl' 38 | bm25: 'wiki18_fulldoc_trimmed_4096.jsonl' 39 | 40 | 41 | # ------------------------------------------------Environment Settings------------------------------------------------# 42 | # Directory paths for data and outputs 43 | data_dir: "dataset/" 44 | save_dir: "output/" 45 | output_dir: "output_wiki" 46 | gpu_id: "0" 47 | dataset_name: "nq" # name of the dataset in data_dir 48 | split: [ "test" ] # dataset split to load (e.g. train,dev,test) 49 | 50 | # Sampling configurations for testing 51 | test_sample_num: ~ # number of samples to test (only work in dev/test split), if None, test all samples 52 | random_sample: False # whether to randomly sample the test samples 53 | 54 | # Seed for reproducibility 55 | seed: 2025 56 | 57 | # Whether save intermediate data 58 | save_intermediate_data: True 59 | save_note: 'experiment' 60 | 61 | # -------------------------------------------------Retrieval Settings------------------------------------------------# 62 | # If set the name, the model path will be find in global paths 63 | retrieval_method: "gte" # name or path of the retrieval model. 
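# Illustrative alternative (a commented-out sketch, not part of the original file): switching to the e5
# retriever named in model2path above would pair it with its entries in method2index / method2corpus, e.g.:
# retrieval_method: "e5"
# corpus_path: 'wiki18_fulldoc.jsonl'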
64 | faiss_gpu: False # whether use gpu to hold index 65 | corpus_path: 'wiki18_fulldoc_trimmed_4096.jsonl' # path to corpus in '.jsonl' format that store the documents 66 | 67 | instruction: ~ # instruction for retrieval model 68 | retrieval_topk: 5 # number of retrieved documents 69 | retrieval_batch_size: 256 # batch size for retrieval 70 | retrieval_use_fp16: True # whether to use fp16 for retrieval model 71 | retrieval_query_max_length: 4096 # max length of the query 72 | save_retrieval_cache: False # whether to save the retrieval cache 73 | use_retrieval_cache: False # whether to use the retrieval cache 74 | retrieval_cache_path: ~ # path to the retrieval cache 75 | retrieval_pooling_method: ~ # set automatically if not provided 76 | 77 | use_reranker: False # whether to use reranker 78 | rerank_model_name: ~ # same as retrieval_method 79 | rerank_model_path: ~ # path to reranker model, path will be automatically find in `retriever_model2path` 80 | rerank_pooling_method: ~ 81 | rerank_topk: 5 # number of remain documents after reranking 82 | rerank_max_length: 4096 83 | rerank_batch_size: 256 # batch size for reranker 84 | rerank_use_fp16: True 85 | 86 | # -------------------------------------------------Generator Settings------------------------------------------------# 87 | generator_max_input_len: 10000 # max length of the input 88 | generator_batch_size: 2 # batch size for generation, invalid for vllm 89 | generation_params: 90 | do_sample: False 91 | temperature: 0 92 | top_p: 0.9 93 | max_tokens: 4096 94 | use_fid: False # whether to use FID, only valid in encoder-decoder model 95 | 96 | generator_model: "gpt-4" 97 | entity_extractor_model: "gpt-4" 98 | question_generator_model: "gpt-4" 99 | polisher_model: "gpt-4" 100 | filter_model: "gpt-4" 101 | 102 | # API Settings - Replace with your own API keys or remove if not needed 103 | # These are just placeholders and should be replaced with your own keys 104 | openai_setting: 105 | api_key: "YOUR_API_KEY_HERE" 106 | base_url: "https://api.openai.com/v1" 107 | -------------------------------------------------------------------------------- /flashrag/config/basic_config.yaml: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------Global Paths------------------------------------------------# 2 | # Paths to various models 3 | model2path: 4 | e5: "intfloat/e5-base-v2" 5 | bge: "BAAI/bge-base-en-v1.5" 6 | contriever: "facebook/contriever" 7 | llama2-7B-chat: "meta-llama/Llama-2-7b-chat-hf" 8 | llama2-7B: "meta-llama/Llama-2-7b-hf" 9 | llama2-13B: "meta-llama/Llama-2-13b-hf" 10 | llama2-13B-chat: "meta-llama/Llama-2-13b-chat-hf" 11 | 12 | # Pooling methods for each embedding model 13 | model2pooling: 14 | e5: "mean" 15 | bge: "cls" 16 | contriever: "mean" 17 | jina: 'mean' 18 | dpr: cls 19 | 20 | # Indexes path for retrieval models 21 | method2index: 22 | e5: ~ 23 | bm25: ~ 24 | contriever: ~ 25 | 26 | # ------------------------------------------------Environment Settings------------------------------------------------# 27 | # Directory paths for data and outputs 28 | data_dir: "dataset/" 29 | save_dir: "output/" 30 | 31 | gpu_id: "0,1,2,3" 32 | dataset_name: "nq" # name of the dataset in data_dir 33 | split: ["test"] # dataset split to load (e.g. 
train,dev,test) 34 | 35 | # Sampling configurations for testing 36 | test_sample_num: ~ # number of samples to test (only work in dev/test split), if None, test all samples 37 | random_sample: False # whether to randomly sample the test samples 38 | 39 | # Seed for reproducibility 40 | seed: 2024 41 | 42 | # Whether save intermediate data 43 | save_intermediate_data: True 44 | save_note: 'experiment' 45 | 46 | # -------------------------------------------------Retrieval Settings------------------------------------------------# 47 | # If set the name, the model path will be find in global paths 48 | retrieval_method: "e5" # name or path of the retrieval model. 49 | retrieval_model_path: ~ # path to the retrieval model 50 | index_path: ~ # set automatically if not provided. 51 | faiss_gpu: False # whether use gpu to hold index 52 | corpus_path: ~ # path to corpus in '.jsonl' format that store the documents 53 | 54 | instruction: ~ # instruction for the retrieval model 55 | retrieval_topk: 5 # number of retrieved documents 56 | retrieval_batch_size: 256 # batch size for retrieval 57 | retrieval_use_fp16: True # whether to use fp16 for retrieval model 58 | retrieval_query_max_length: 128 # max length of the query 59 | save_retrieval_cache: False # whether to save the retrieval cache 60 | use_retrieval_cache: False # whether to use the retrieval cache 61 | retrieval_cache_path: ~ # path to the retrieval cache 62 | retrieval_pooling_method: ~ # set automatically if not provided 63 | bm25_backend: bm25s # pyserini, bm25s 64 | use_sentence_transformer: False 65 | 66 | use_reranker: False # whether to use reranker 67 | rerank_model_name: ~ # same as retrieval_method 68 | rerank_model_path: ~ # path to reranker model, path will be automatically find in `model2path` 69 | rerank_pooling_method: ~ 70 | rerank_topk: 5 # number of remain documents after reranking 71 | rerank_max_length: 512 72 | rerank_batch_size: 256 # batch size for reranker 73 | rerank_use_fp16: True 74 | 75 | # -------------------------------------------------Generator Settings------------------------------------------------# 76 | framework: fschat # inference frame work of LLM, supporting: 'hf','vllm','fschat', 'openai' 77 | generator_model: "llama3-8B-instruct" # name or path of the generator model 78 | # setting for openai model, only valid in openai framework 79 | openai_setting: 80 | api_key: ~ 81 | base_url: ~ 82 | 83 | generator_model_path: ~ 84 | generator_max_input_len: 1024 # max length of the input 85 | generator_batch_size: 4 # batch size for generation, invalid for vllm 86 | generation_params: 87 | #do_sample: false 88 | max_tokens: 32 89 | #temperature: 1.0 90 | #top_p: 1.0 91 | use_fid: False # whether to use FID, only valid in encoder-decoder model 92 | gpu_memory_utilization: 0.85 # ratio of gpu's memory usage for generator 93 | 94 | # -------------------------------------------------Evaluation Settings------------------------------------------------# 95 | # Metrics to evaluate the result 96 | metrics: ['em','f1','acc','precision','recall','input_tokens'] 97 | # Specify setting for metric, will be called within certain metrics 98 | metric_setting: 99 | retrieval_recall_topk: 5 100 | tokenizer_name: 'gpt-4' 101 | save_metric_score: True # whether to save the metric score into txt file 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /flashrag/retriever/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 
import warnings 3 | from typing import Dict, Any, Union, List, Dict 4 | import numpy as np 5 | import datasets 6 | from transformers import AutoTokenizer, AutoModel, AutoConfig 7 | 8 | def convert_numpy(obj: Union[Dict, list, np.ndarray, np.generic]) -> Any: 9 | """Recursively convert numpy objects in nested dictionaries or lists to native Python types.""" 10 | if isinstance(obj, dict): 11 | return {k: convert_numpy(v) for k, v in obj.items()} 12 | elif isinstance(obj, list): 13 | return [convert_numpy(i) for i in obj] 14 | elif isinstance(obj, np.ndarray): 15 | return obj.tolist() # Convert numpy arrays to lists 16 | elif isinstance(obj, (np.integer, np.floating)): 17 | return obj.item() # Convert numpy scalars to native Python scalars 18 | elif isinstance(obj, np.float32): 19 | return float(obj) 20 | else: 21 | return obj # Return the object as-is if it's neither a dict, list, nor numpy type 22 | 23 | 24 | def load_model(model_path: str, use_fp16: bool = False): 25 | model_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) 26 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True) 27 | model.eval() 28 | model.cuda() 29 | if use_fp16: 30 | model = model.half() 31 | tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True, trust_remote_code=True) 32 | 33 | return model, tokenizer 34 | 35 | 36 | def pooling(pooler_output, last_hidden_state, attention_mask=None, pooling_method="mean"): 37 | if pooling_method == "mean": 38 | last_hidden = last_hidden_state.masked_fill(~attention_mask[..., None].bool(), 0.0) 39 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 40 | elif pooling_method == "cls": 41 | return last_hidden_state[:, 0] 42 | elif pooling_method == "pooler": 43 | return pooler_output 44 | else: 45 | raise NotImplementedError("Pooling method not implemented!") 46 | 47 | 48 | def set_default_instruction(model_name, is_query=True, is_zh=False): 49 | instruction = "" 50 | if "e5" in model_name.lower(): 51 | if is_query: 52 | instruction = "query: " 53 | else: 54 | instruction = "passage: " 55 | 56 | if "bge" in model_name.lower(): 57 | if is_query: 58 | if "zh" in model_name.lower() or is_zh: 59 | instruction = "为这个句子生成表示以用于检索相关文章:" 60 | else: 61 | instruction = "Represent this sentence for searching relevant passages: " 62 | 63 | return instruction 64 | 65 | 66 | def parse_query(model_name, query_list, instruction=None): 67 | """ 68 | processing query for different encoders 69 | """ 70 | 71 | def is_zh(str): 72 | import unicodedata 73 | 74 | zh_char = 0 75 | for c in str: 76 | try: 77 | if "CJK" in unicodedata.name(c): 78 | zh_char += 1 79 | except: 80 | continue 81 | if len(str) == 0: 82 | return False 83 | if zh_char / len(str) > 0.2: 84 | return True 85 | else: 86 | return False 87 | 88 | if isinstance(query_list, str): 89 | query_list = [query_list] 90 | 91 | if instruction is not None: 92 | instruction = instruction.strip() + " " 93 | else: 94 | instruction = set_default_instruction(model_name, is_query=True, is_zh=is_zh(query_list[0])) 95 | print(f"Use `{instruction}` as retreival instruction") 96 | 97 | query_list = [instruction + query for query in query_list] 98 | 99 | return query_list 100 | 101 | 102 | def load_corpus(corpus_path: str): 103 | corpus = datasets.load_dataset("json", data_files=corpus_path, split="train") 104 | return corpus 105 | 106 | 107 | def read_jsonl(file_path): 108 | with open(file_path, "r") as f: 109 | while True: 110 | new_line = f.readline() 111 | if not new_line: 112 | return 113 | 
new_item = json.loads(new_line) 114 | 115 | yield new_item 116 | 117 | 118 | def load_docs(corpus, doc_idxs): 119 | results = [corpus[int(idx)] for idx in doc_idxs] 120 | 121 | return results 122 | 123 | 124 | def parse_image(image): 125 | from PIL import Image 126 | 127 | if isinstance(image, str): 128 | if image.startswith("http"): 129 | import requests 130 | 131 | image = Image.open(requests.get(image, stream=True).raw) 132 | else: 133 | image = Image.open(image) 134 | return image 135 | -------------------------------------------------------------------------------- /flashrag/generator/stop_word_criteria.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created by Nestor Demeure. 3 | This software is released under the Apache License 2.0. 4 | """ 5 | 6 | from typing import List 7 | import torch 8 | from transformers import StoppingCriteria, AutoTokenizer 9 | 10 | 11 | class StopWordCriteria(StoppingCriteria): 12 | """ 13 | A stopping criteria that halts the text generation process if any specified stop word is encountered. 14 | 15 | Inspired by https://discuss.huggingface.co/t/implimentation-of-stopping-criteria-list/20040/9 16 | And: https://github.com/outlines-dev/outlines/blob/main/outlines/generate/api.py 17 | """ 18 | 19 | def __init__(self, tokenizer: AutoTokenizer, prompts: List[str], stop_words: List[str] = [], check_every: int = 1): 20 | """ 21 | Initializes the StopWordCriteria with the necessary parameters for checking stop words during text generation. 22 | 23 | Parameters: 24 | tokenizer (AutoTokenizer): The tokenizer for encoding prompts and stop words. 25 | prompts (List[str]): Initial prompts used for generation, needed to determine where generated text begins. 26 | stop_words (List[str]): Words that trigger the stopping of generation when detected. 27 | check_every (int): Frequency of checking for stop words in the token stream (a performance optimization, use 1 to cut it out). 28 | """ 29 | super().__init__() 30 | self.tokenizer = tokenizer 31 | self.input_sizes = [self.tokenizer.encode(prompt, return_tensors="pt").size(-1) for prompt in prompts] 32 | self.stop_words = stop_words 33 | self.max_stop_word_size = max( 34 | (self.tokenizer.encode(word, return_tensors="pt").size(-1) for word in stop_words), default=0 35 | ) 36 | self.check_every = check_every 37 | 38 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool: 39 | """ 40 | Determines whether to stop generation based on the presence of stop words. 41 | 42 | Stops if a stop word is found in *all* batch elements *and* the sequence length is a multiple of `check_every`. 43 | Note: Delay in stopping may occur if `check_every > 1`. 44 | 45 | Parameters: 46 | input_ids (torch.LongTensor): Generated token IDs. 47 | scores (torch.FloatTensor): Generation scores for each token. Not used here. 48 | 49 | Returns: 50 | bool: True to stop generation, False to continue. 
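Example (illustrative): with check_every=4 the stop-word scan only runs when seq_len is a multiple of 4, so generation may overshoot a stop word by up to 3 tokens before halting; check_every=1 checks after every generated token.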
51 | """ 52 | batch_size, seq_len = input_ids.shape 53 | 54 | # Skip check if no stop words are defined or it is not yet time to check 55 | if (len(self.stop_words) == 0) or (seq_len % self.check_every != 0): 56 | return False 57 | 58 | for i in range(batch_size): 59 | # Calculate starting index for new tokens 60 | prompt_size = self.input_sizes[i] 61 | max_new_tokens = (2 * self.max_stop_word_size) + self.check_every 62 | latest_tokens = input_ids[i, prompt_size:][-max_new_tokens:] 63 | 64 | # Check for stop words in the decoded text 65 | if not any( 66 | word in self.tokenizer.decode(latest_tokens, skip_special_tokens=True) for word in self.stop_words 67 | ): 68 | return False # Continue generation if any batch item lacks stop words 69 | 70 | return True # Stop generation if all conditions are met 71 | 72 | def extract_answers(self, input_ids: torch.LongTensor, strip_stopword: bool = True) -> List[str]: 73 | """ 74 | Extracts generated answers by removing prompts and optionally stopping at the first stop word. 75 | 76 | Parameters: 77 | input_ids (torch.LongTensor): Generated token IDs. 78 | strip_stopword (bool): Determines whether the stop word is removed from the output. 79 | 80 | Returns: 81 | List[str]: Extracted answers, with or without stop words. 82 | """ 83 | batch_size, _ = input_ids.shape 84 | result = [] 85 | 86 | for i in range(batch_size): 87 | # Decode generated tokens to text, excluding the prompt 88 | prompt_size = self.input_sizes[i] 89 | answer_tokens = input_ids[i, prompt_size:] 90 | answer_text = self.tokenizer.decode(answer_tokens, skip_special_tokens=True) 91 | 92 | # Find the first occurrence of any stop word 93 | lower_stop_index = len(answer_text) # Default to end of text 94 | for word in self.stop_words: 95 | stop_index = answer_text.find(word) 96 | if stop_index != -1: 97 | # Adjust stop index based on whether we're stripping the stop word 98 | stop_index += 0 if strip_stopword else len(word) 99 | lower_stop_index = min(stop_index, lower_stop_index) 100 | 101 | # Cut the text at the first stop word found (if any) 102 | answer_text = answer_text[:lower_stop_index] 103 | result.append(answer_text) 104 | 105 | return result 106 | -------------------------------------------------------------------------------- /flashrag/retriever/encoder.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import numpy as np 4 | from flashrag.retriever.utils import load_model, pooling, parse_query 5 | 6 | 7 | class Encoder: 8 | """ 9 | Encoder class for encoding queries using a specified model. 10 | 11 | Attributes: 12 | model_name (str): The name of the model. 13 | model_path (str): The path to the model. 14 | pooling_method (str): The method used for pooling. 15 | max_length (int): The maximum length of the input sequences. 16 | use_fp16 (bool): Whether to use FP16 precision. 17 | instruction (str): Additional instructions for parsing queries. 18 | 19 | Methods: 20 | encode(query_list: List[str], is_query=True) -> np.ndarray: 21 | Encodes a list of queries into embeddings. 
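Example (a sketch only, not from the original source; the model name/path are placeholders taken from the repository configs, and a CUDA device is assumed because the model is moved to GPU on load):
    encoder = Encoder(model_name="e5", model_path="intfloat/e5-base-v2", pooling_method="mean",
                      max_length=256, use_fp16=False, instruction=None)
    emb = encoder.encode(["which entity links the two documents?"])  # np.ndarray of shape (num_queries, hidden_dim)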
22 | """ 23 | 24 | def __init__(self, model_name, model_path, pooling_method, max_length, use_fp16, instruction): 25 | self.model_name = model_name 26 | self.model_path = model_path 27 | self.pooling_method = pooling_method 28 | self.max_length = max_length 29 | self.use_fp16 = use_fp16 30 | self.instruction = instruction 31 | 32 | self.model, self.tokenizer = load_model(model_path=model_path, use_fp16=use_fp16) 33 | 34 | @torch.inference_mode() 35 | def encode(self, query_list: List[str], is_query=True) -> np.ndarray: 36 | query_list = parse_query(self.model_name, query_list, self.instruction) 37 | 38 | inputs = self.tokenizer( 39 | query_list, max_length=self.max_length, padding=True, truncation=True, return_tensors="pt" 40 | ) 41 | inputs = {k: v.cuda() for k, v in inputs.items()} 42 | 43 | if "T5" in type(self.model).__name__: 44 | # T5-based retrieval model 45 | decoder_input_ids = torch.zeros((inputs["input_ids"].shape[0], 1), dtype=torch.long).to( 46 | inputs["input_ids"].device 47 | ) 48 | output = self.model(**inputs, decoder_input_ids=decoder_input_ids, return_dict=True) 49 | query_emb = output.last_hidden_state[:, 0, :] 50 | 51 | else: 52 | output = self.model(**inputs, return_dict=True) 53 | query_emb = pooling( 54 | output.pooler_output, output.last_hidden_state, inputs["attention_mask"], self.pooling_method 55 | ) 56 | query_emb = torch.nn.functional.normalize(query_emb, dim=-1) 57 | query_emb = query_emb.detach().cpu().numpy() 58 | query_emb = query_emb.astype(np.float32, order="C") 59 | return query_emb 60 | 61 | 62 | class STEncoder: 63 | """ 64 | STEncoder class for encoding queries using SentenceTransformers. 65 | 66 | Attributes: 67 | model_name (str): The name of the model. 68 | model_path (str): The path to the model. 69 | max_length (int): The maximum length of the input sequences. 70 | use_fp16 (bool): Whether to use FP16 precision. 71 | instruction (str): Additional instructions for parsing queries. 72 | 73 | Methods: 74 | encode(query_list: List[str], batch_size=64, is_query=True) -> np.ndarray: 75 | Encodes a list of queries into embeddings. 76 | multi_gpu_encode(query_list: List[str], is_query=True, batch_size=None) -> np.ndarray: 77 | Encodes a list of queries into embeddings using multiple GPUs. 
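Example (a sketch only, not from the original source; the path is the gte entry from example_config.yaml and assumes a local copy of that model):
    st_encoder = STEncoder(model_name="gte", model_path="ret_model/iic/gte_sentence-embedding_multilingual-base",
                           max_length=4096, use_fp16=True, instruction=None)
    emb = st_encoder.encode(["query text"], batch_size=64)  # L2-normalized float32 embeddings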
78 | """ 79 | 80 | def __init__(self, model_name, model_path, max_length, use_fp16, instruction): 81 | import torch 82 | from sentence_transformers import SentenceTransformer 83 | 84 | self.model_name = model_name 85 | self.model_path = model_path 86 | self.max_length = max_length 87 | self.use_fp16 = use_fp16 88 | self.instruction = instruction 89 | self.model = SentenceTransformer( 90 | model_path, trust_remote_code=True, model_kwargs={"torch_dtype": torch.float16 if use_fp16 else torch.float} 91 | ) 92 | 93 | @torch.inference_mode() 94 | def encode(self, query_list: List[str], batch_size=64, is_query=True) -> np.ndarray: 95 | query_list = parse_query(self.model_name, query_list, self.instruction) 96 | query_emb = self.model.encode( 97 | query_list, batch_size=batch_size, convert_to_numpy=True, normalize_embeddings=True, show_progress_bar=True 98 | ) 99 | query_emb = query_emb.astype(np.float32, order="C") 100 | 101 | return query_emb 102 | 103 | @torch.inference_mode() 104 | def multi_gpu_encode(self, query_list: List[str], is_query=True, batch_size=None) -> np.ndarray: 105 | query_list = parse_query(self.model_name, query_list, self.instruction) 106 | pool = self.model.start_multi_process_pool() 107 | query_emb = self.model.encode_multi_process( 108 | query_list, 109 | pool, 110 | convert_to_numpy=True, 111 | normalize_embeddings=True, 112 | batch_size=batch_size, 113 | show_progress_bar=True, 114 | ) 115 | self.model.stop_multi_process_pool(pool) 116 | query_emb = query_emb.astype(np.float32, order="C") 117 | 118 | return query_emb 119 | -------------------------------------------------------------------------------- /flashrag/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | from transformers import AutoConfig 4 | from flashrag.dataset.dataset import Dataset 5 | 6 | 7 | def get_dataset(config): 8 | """Load dataset from config.""" 9 | 10 | dataset_path = config["dataset_path"] 11 | all_split = config["split"] 12 | 13 | split_dict = {split: None for split in all_split} 14 | 15 | for split in all_split: 16 | split_path = os.path.join(dataset_path, f"{split}.jsonl") 17 | if not os.path.exists(split_path): 18 | print(f"{split} file not exists!") 19 | continue 20 | if split in ["test", "val", "dev"]: 21 | split_dict[split] = Dataset( 22 | config, split_path, sample_num=config["test_sample_num"], random_sample=config["random_sample"] 23 | ) 24 | else: 25 | split_dict[split] = Dataset(config, split_path) 26 | 27 | return split_dict 28 | 29 | 30 | def get_generator(config, **params): 31 | """Automatically select generator class based on config.""" 32 | if config["framework"] == "vllm": 33 | return getattr(importlib.import_module("flashrag.generator"), "VLLMGenerator")(config, **params) 34 | elif config["framework"] == "fschat": 35 | return getattr(importlib.import_module("flashrag.generator"), "FastChatGenerator")(config, **params) 36 | elif config["framework"] == "hf": 37 | model_config = AutoConfig.from_pretrained(config["generator_model_path"]) 38 | arch = model_config.architectures[0] 39 | if "t5" in arch.lower() or "bart" in arch.lower() or 'fusionindecoder' in arch.lower(): 40 | return getattr(importlib.import_module("flashrag.generator"), "EncoderDecoderGenerator")(config, **params) 41 | else: 42 | return getattr(importlib.import_module("flashrag.generator"), "HFCausalLMGenerator")(config, **params) 43 | elif config["framework"] == "openai": 44 | return 
getattr(importlib.import_module("flashrag.generator"), "OpenaiGenerator")(config, **params) 45 | else: 46 | raise NotImplementedError 47 | 48 | 49 | def get_retriever(config): 50 | r"""Automatically select retriever class based on config's retrieval method 51 | 52 | Args: 53 | config (dict): configuration with 'retrieval_method' key 54 | 55 | Returns: 56 | Retriever: retriever instance 57 | """ 58 | if config["retrieval_method"] == "bm25": 59 | return getattr(importlib.import_module("flashrag.retriever"), "BM25Retriever")(config) 60 | else: 61 | return getattr(importlib.import_module("flashrag.retriever"), "DenseRetriever")(config) 62 | 63 | 64 | def get_reranker(config): 65 | model_path = config["rerank_model_path"] 66 | # get model config 67 | model_config = AutoConfig.from_pretrained(model_path) 68 | arch = model_config.architectures[0] 69 | if "forsequenceclassification" in arch.lower(): 70 | return getattr(importlib.import_module("flashrag.retriever"), "CrossReranker")(config) 71 | else: 72 | return getattr(importlib.import_module("flashrag.retriever"), "BiReranker")(config) 73 | 74 | 75 | def get_judger(config): 76 | judger_name = config["judger_name"] 77 | if "skr" in judger_name.lower(): 78 | return getattr(importlib.import_module("flashrag.judger"), "SKRJudger")(config) 79 | elif "adaptive" in judger_name.lower(): 80 | return getattr(importlib.import_module("flashrag.judger"), "AdaptiveJudger")(config) 81 | else: 82 | assert False, "No implementation!" 83 | 84 | 85 | def get_refiner(config, retriever=None, generator=None): 86 | # 预定义默认路径字典 87 | DEFAULT_PATH_DICT = { 88 | "recomp_abstractive_nq": "fangyuan/nq_abstractive_compressor", 89 | "recomp:abstractive_tqa": "fangyuan/tqa_abstractive_compressor", 90 | "recomp:abstractive_hotpotqa": "fangyuan/hotpotqa_abstractive", 91 | } 92 | REFINER_MODULE = importlib.import_module("flashrag.refiner") 93 | 94 | refiner_name = config["refiner_name"] 95 | refiner_path = ( 96 | config["refiner_model_path"] 97 | if config["refiner_model_path"] is not None 98 | else DEFAULT_PATH_DICT.get(refiner_name, None) 99 | ) 100 | 101 | try: 102 | model_config = AutoConfig.from_pretrained(refiner_path) 103 | arch = model_config.architectures[0].lower() 104 | print(arch) 105 | except Exception as e: 106 | print("Warning", e) 107 | model_config, arch = "", "" 108 | 109 | if "recomp" in refiner_name or "bert" in arch: 110 | if model_config.model_type == "t5": 111 | refiner_class = "AbstractiveRecompRefiner" 112 | else: 113 | refiner_class = "ExtractiveRefiner" 114 | elif "lingua" in refiner_name: 115 | refiner_class = "LLMLinguaRefiner" 116 | elif "selective-context" in refiner_name or "sc" in refiner_name: 117 | refiner_class = "SelectiveContextRefiner" 118 | elif "kg-trace" in refiner_name: 119 | return getattr(REFINER_MODULE, "KGTraceRefiner")(config, retriever, generator) 120 | else: 121 | raise ValueError("No implementation!") 122 | 123 | return getattr(REFINER_MODULE, refiner_class)(config) 124 | 125 | 126 | def hash_object(o) -> str: 127 | """Returns a character hash code of arbitrary Python objects.""" 128 | import hashlib 129 | import io 130 | import dill 131 | import base58 132 | 133 | m = hashlib.blake2b() 134 | with io.BytesIO() as buffer: 135 | dill.dump(o, buffer) 136 | m.update(buffer.getbuffer()) 137 | return base58.b58encode(m.digest()).decode() 138 | -------------------------------------------------------------------------------- /CORE_MODULES.md: -------------------------------------------------------------------------------- 1 | ### 1. 
Bridge Question Synthesis Process 2 | 3 | The bridge question synthesis includes the following key steps: 4 | 5 | - **Bridge Entity Identification**: From randomly selected source documents, the system identifies bridge entities that can connect different information contexts, providing key pivots for multi-hop reasoning 6 | 7 | - **Two-stage Coarse-to-Fine Retrieval**: 8 | - Coarse-grained Retrieval: Using a modified maximum marginal relevance algorithm to balance query relevance, dissimilarity from the source document, and diversity among already selected documents 9 | 10 | **Diverse Retrieval Scoring Function:** 11 | 12 | The diverse retrieval uses a modified Maximum Marginal Relevance (MMR) algorithm: 13 | 14 | $$\text{Score}(d_i) = \lambda_1 \cdot \text{sim}(q, d_i) - \lambda_2 \cdot \text{sim}(d_i, d_s) - \lambda_3 \cdot \max_{d_j \in S} \text{sim}(d_i, d_j)$$ 15 | 16 | Where: 17 | - $q$ is the query 18 | - $d_i$ is the candidate document 19 | - $d_s$ is the source document 20 | - $S$ is the set of already selected documents 21 | - $\text{sim}(\cdot, \cdot)$ represents cosine similarity 22 | - $\lambda_1, \lambda_2, \lambda_3$ are weighting parameters with $\lambda_1 + \lambda_2 + \lambda_3 = 1$ 23 | 24 | This formula is used by both **diverse** and **rerank** retrieval methods in their coarse retrieval stage. 25 | 26 | - Fine-grained Reranking: Using a reranking model fine-tuned through contrastive learning to further optimize the ranking of candidate documents 27 | 28 | - **Multi-hop Question Construction**: 29 | - Sub-question Synthesis: Synthesize sub-questions from source and supplementary documents respectively, centered around the bridge entity 30 | - Question Synthesis: Merge sub-questions into a single coherent multi-hop question, implying the reasoning path without directly exposing the bridge entity 31 | - Validation and Iteration: Ensure questions meet answerability, multi-hop nature, and no-shortcut constraints 32 | 33 | ### 2. Comparison Question Synthesis Process 34 | 35 | Comparison question synthesis follows these steps: 36 | 37 | - **Entity and Attribute Identification**: Identify main entities from documents and their 3-5 concise factual attribute-value pairs, selecting attributes suitable for comparison 38 | 39 | - **Filtering and Query Synthesis**: 40 | - Ensure specificity and comparability of entities and attributes 41 | - Synthesize retrieval queries based on source entities, using direct recommendation or diversified search strategies 42 | 43 | - **Question Construction**: 44 | - Guided Comparison: Precise comparison for specific entities and attributes 45 | - Open Discovery: Find the first valid comparable pair among multiple attributes 46 | - Synthesize comparison questions containing information about two entities, such as "Which entity has a higher/earlier/larger attribute value?" 47 | 48 | ### 3. Question Refinement and Quality Assurance 49 | 50 | During the bridge and comparison question synthesis process, the system implements strict quality control mechanisms: 51 | 52 | - **Question Refinement and Validation Module**: 53 | - Evaluate questions for answerability, multi-hop nature, and language quality 54 | - Classify evaluation results into four categories: pass, adjust, reconstruct, or reject 55 | - Ensure each question involves cross-document reasoning and hides bridge entities 56 | - Maintain fluency without exposing intermediate reasoning steps 57 | 58 | ### 4.
Reranker Model Training and Optimization 59 | 60 | The system synthesizes supervision signals by simulating key steps to improve retrieval quality: 61 | 62 | - **Simulated Feedback Synthesis**: 63 | - Extract successful and failed document examples from the bridge question synthesis process 64 | - Construct contrastive training triplets (query, positive document, negative document) 65 | 66 | - **Contrastive Learning Optimization**: 67 | - Use cross-entropy loss function to guide the model in distinguishing complementary documents 68 | - Obtain supervision signals directly from downstream task success rates 69 | 70 | ### 5. Multi-dimensional Evaluation System 71 | 72 | The system employs a comprehensive evaluation framework to ensure question quality: 73 | 74 | - **LLM-as-Judge Evaluation**: 75 | - Use large language models as judges, employing Likert scales to evaluate each question 76 | - Implement self-consistency evaluation methods to ensure stability and reproducibility of evaluation results 77 | - Analyze consistency of evaluation results by repeatedly evaluating the same input 78 | 79 | - **Answerability and Difficulty Evaluation**: 80 | - **Q-Only Condition**: Solver receives only the question, testing baseline answerability using the solver's internal knowledge and reasoning capabilities 81 | - **Q+Docs Condition**: Solver receives the question and all supporting documents, simulating a golden retrieval scenario to evaluate answerability when necessary evidence is available 82 | - **Performance Gap Analysis**: Performance improvement from Q-Only to Q+Docs indicates whether the question is challenging and requires cross-document reasoning rather than relying solely on pre-trained knowledge 83 | 84 | - **Evidence-Accessibility Evaluation**: 85 | - **Retrieval Quality Assessment**: Use multiple retrieval methods to fetch top-k documents and evaluate the accessibility of synthesized question evidence in the corpus 86 | - **Multi-dimensional Retrieval Metrics**: Record MAP (Mean Average Precision), RECALL@k, NDCG@k (Normalized Discounted Cumulative Gain), and Support F1 metrics 87 | - **Evidence Completeness Verification**: Ensure synthesized questions have complete evidence support, preventing unanswerable questions from entering the final dataset -------------------------------------------------------------------------------- /flashrag/retriever/reranker.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import torch 3 | import warnings 4 | import numpy as np 5 | from tqdm import tqdm 6 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 7 | from flashrag.retriever.encoder import Encoder 8 | 9 | 10 | class BaseReranker: 11 | r"""Base object for all rerankers.""" 12 | 13 | def __init__(self, config): 14 | self.config = config 15 | self.reranker_model_name = config["rerank_model_name"] 16 | self.reranker_model_path = config["rerank_model_path"] 17 | self.topk = config["rerank_topk"] 18 | self.max_length = config["rerank_max_length"] 19 | self.batch_size = config["rerank_batch_size"] 20 | self.device = config["device"] 21 | 22 | def get_rerank_scores(self, query_list: List[str], doc_list: List[str], batch_size): 23 | """Return flatten list of scores for each (query,doc) pair 24 | Args: 25 | query_list: List of N queries 26 | doc_list: Nested list of length N, each element corresponds to K documents of a query 27 | 28 | Return: 29 | [score(q1,d1), score(q1,d2),... score(q2,d1),...] 
30 | """ 31 | all_scores = [] 32 | return all_scores 33 | 34 | @torch.inference_mode(mode=True) 35 | def rerank(self, query_list, doc_list, batch_size=None, topk=None): 36 | r"""Rerank doc_list.""" 37 | if batch_size is None: 38 | batch_size = self.batch_size 39 | if topk is None: 40 | topk = self.topk 41 | if isinstance(query_list, str): 42 | query_list = [query_list] 43 | if not isinstance(doc_list[0], list): 44 | doc_list = [doc_list] 45 | 46 | assert len(query_list) == len(doc_list) 47 | if topk < min([len(docs) for docs in doc_list]): 48 | warnings.warn("The number of doc returned by the retriever is less than the topk.") 49 | 50 | # get doc contents 51 | doc_contents = [] 52 | for docs in doc_list: 53 | if all([isinstance(doc, str) for doc in docs]): 54 | doc_contents.append([doc for doc in docs]) 55 | else: 56 | doc_contents.append([doc["contents"] for doc in docs]) 57 | 58 | all_scores = self.get_rerank_scores(query_list, doc_contents, batch_size) 59 | assert len(all_scores) == sum([len(docs) for docs in doc_list]) 60 | 61 | # sort docs 62 | start_idx = 0 63 | final_scores = [] 64 | final_docs = [] 65 | for docs in doc_list: 66 | doc_scores = all_scores[start_idx : start_idx + len(docs)] 67 | doc_scores = [float(score) for score in doc_scores] 68 | sort_idxs = np.argsort(doc_scores)[::-1][:topk] 69 | start_idx += len(docs) 70 | 71 | final_docs.append([docs[idx] for idx in sort_idxs]) 72 | final_scores.append([doc_scores[idx] for idx in sort_idxs]) 73 | 74 | return final_docs, final_scores 75 | 76 | 77 | class CrossReranker(BaseReranker): 78 | def __init__(self, config): 79 | super().__init__(config) 80 | self.tokenizer = AutoTokenizer.from_pretrained(self.reranker_model_path) 81 | self.ranker = AutoModelForSequenceClassification.from_pretrained(self.reranker_model_path, num_labels=1) 82 | self.ranker.eval() 83 | self.ranker.to(self.device) 84 | 85 | @torch.inference_mode(mode=True) 86 | def get_rerank_scores(self, query_list, doc_list, batch_size): 87 | # flatten all pairs 88 | all_pairs = [] 89 | for query, docs in zip(query_list, doc_list): 90 | all_pairs.extend([[query, doc] for doc in docs]) 91 | all_scores = [] 92 | for start_idx in tqdm(range(0, len(all_pairs), batch_size), desc="Reranking process: "): 93 | pair_batch = all_pairs[start_idx : start_idx + batch_size] 94 | 95 | inputs = self.tokenizer( 96 | pair_batch, padding=True, truncation=True, return_tensors="pt", max_length=self.max_length 97 | ).to(self.device) 98 | batch_scores = ( 99 | self.ranker(**inputs, return_dict=True) 100 | .logits.view( 101 | -1, 102 | ) 103 | .float() 104 | .cpu() 105 | ) 106 | all_scores.extend(batch_scores) 107 | 108 | return all_scores 109 | 110 | 111 | class BiReranker(BaseReranker): 112 | def __init__(self, config): 113 | super().__init__(config) 114 | self.encoder = Encoder( 115 | model_name=self.reranker_model_name, 116 | model_path=self.reranker_model_path, 117 | pooling_method=config["rerank_pooling_method"], 118 | max_length=self.max_length, 119 | use_fp16=config["rerank_use_fp16"], 120 | ) 121 | 122 | def get_rerank_scores(self, query_list, doc_list, batch_size): 123 | query_emb = [] 124 | for start_idx in range(0, len(query_list), batch_size): 125 | query_batch = query_list[start_idx : start_idx + batch_size] 126 | batch_emb = self.encoder.encode(query_batch, is_query=True) 127 | query_emb.append(batch_emb) 128 | query_emb = np.concatenate(query_emb, axis=0) 129 | 130 | flat_doc_list = sum(doc_list, []) 131 | doc_emb = [] 132 | for start_idx in range(0, len(flat_doc_list), 
batch_size): 133 | doc_batch = flat_doc_list[start_idx : start_idx + batch_size] 134 | batch_emb = self.encoder.encode(doc_batch, is_query=False) 135 | doc_emb.append(batch_emb) 136 | doc_emb = np.concatenate(doc_emb, axis=0) 137 | 138 | scores = query_emb @ doc_emb.T # K*L 139 | all_scores = [] 140 | score_idx = 0 141 | for idx, doc in enumerate(doc_list): 142 | all_scores.extend(scores[idx, score_idx : score_idx + len(doc)]) 143 | score_idx += len(doc) 144 | 145 | return all_scores 146 | -------------------------------------------------------------------------------- /hopweaver/train_reranker/README_EN.md: -------------------------------------------------------------------------------- 1 | # Reranker Model Training Guide 2 | 3 | This directory contains tools and datasets for training and evaluating retrieval reranker models. Reranker models can improve the quality of document ranking in retrieval systems, which is crucial for multi-hop question answering systems. 4 | 5 | ## Directory Structure 6 | 7 | ``` 8 | train_reranker/ 9 | ├── data/ # Training and test data 10 | │ ├── test_data.jsonl # Test dataset 11 | │ ├── train_data.jsonl # Training dataset 12 | │ └── test_data_sample.jsonl # Test data sample 13 | ├── contrastive_data_generator.py # Contrastive learning data generator 14 | ├── ds_stage0.json # DeepSpeed configuration file (optimizes training efficiency) 15 | ├── rerank_ablation_test.py # Reranker model ablation study tool 16 | └── train_reranker.py # Reranker model training script 17 | ``` 18 | 19 | ## Data Preparation 20 | 21 | ### Using Pre-generated Data 22 | 23 | The project already includes a complete training dataset (`data/train_data.jsonl`) and test dataset (`data/test_data.jsonl`), which can be directly used for model training. 24 | 25 | ### Generating New Training Data 26 | 27 | If you need to generate new training data, you can use the `contrastive_data_generator.py` script: 28 | 29 | ```bash 30 | conda activate llm 31 | python contrastive_data_generator.py 32 | ``` 33 | 34 | This script will: 35 | 1. Read documents from the corpus 36 | 2. Extract entities and generate potential queries 37 | 3. Retrieve relevant documents as positive samples 38 | 4. Generate negative samples 39 | 5. Save the results to the `./data/` directory 40 | 41 | ### Data Format for Fine-tuning 42 | 43 | The common data format for fine-tuning reranker models within the FlagEmbedding framework is a JSONL (JSON Lines) file. Each line in this file represents a training sample, formatted as a JSON object. Each JSON object should contain the following necessary keys: 44 | 45 | - `"query"`: Represents the search query, type string. 46 | - `"pos"`: Contains a list of positive example documents, where each document is a string relevant to the given query. 47 | - `"neg"`: Contains a list of negative example documents, where each document is a string irrelevant to the query. Although optional, including negative samples is highly recommended as it can significantly improve the model's ability to distinguish between relevant and irrelevant documents. 
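Before launching a fine-tuning run, it can be helpful to sanity-check a generated file against this format. The snippet below is a minimal sketch (the helper name and the specific checks are illustrative assumptions based on this README, not code shipped with the project):

```python
import json

def check_training_file(path="./data/train_data.jsonl"):
    """Verify that every line is a JSON object with the keys described above.

    Illustrative sketch only: which keys are required vs. optional is an
    assumption taken from this README.
    """
    with open(path, encoding="utf-8") as f:
        for line_no, line in enumerate(f, 1):
            record = json.loads(line)
            assert isinstance(record.get("query"), str), f"line {line_no}: 'query' must be a string"
            assert record.get("pos"), f"line {line_no}: 'pos' needs at least one passage"
            assert isinstance(record.get("neg", []), list), f"line {line_no}: 'neg' must be a list"

# Example usage: check_training_file("./data/train_data.jsonl")
```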
48 | 49 | Here is a specific example of a JSON object in a JSONL file: 50 | 51 | ```json 52 | {"query": "Explain theory of relativity.", "pos": ["The theory of relativity, proposed by Albert Einstein, describes the relationship between space and time."], "neg": ["Quantum mechanics is a fundamental theory in physics."]} 53 | ``` 54 | 55 | An example command to generate data: 56 | 57 | ```bash 58 | python ./contrastive_data_generator.py --config ../config_lib/extract_config_wikifulldoc.yaml --num_examples 1000 --max_doc_candidates 5 --lambda1 0.85 --lambda2 0.05 --lambda3 0.1 59 | ``` 60 | 61 | ## Model Training 62 | 63 | ### Full Model Training 64 | 65 | Train using the full training dataset: 66 | 67 | ```bash 68 | python train_reranker.py --train_data ./data/train_data.jsonl --output_dir ./output --epochs 2 --batch_size 16 --gradient_accumulation_steps 4 --learning_rate 5e-6 --use_deepspeed --ds_config ./ds_stage0.json 69 | ``` 70 | 71 | **Parameter Description**: 72 | - `--model_path`: Base model path, defaults to `./models/bge-reranker-v2-m3` 73 | - `--train_data`: Training data file 74 | - `--output_dir`: Output model directory 75 | - `--epochs`: Number of training epochs, 2-3 epochs are recommended 76 | - `--batch_size`: Batch size, adjust according to GPU memory 77 | - `--learning_rate`: Learning rate, 5e-6 to 7e-6 is recommended 78 | - `--gradient_accumulation_steps`: Gradient accumulation steps, used to increase effective batch size 79 | - `--use_deepspeed`: Enable DeepSpeed for training optimization 80 | - `--ds_config`: DeepSpeed configuration file, use `./ds_stage0.json` 81 | 82 | ## Model Evaluation 83 | 84 | ### Ablation Study 85 | 86 | Use `rerank_ablation_test.py` for comparative experiments of different retrieval strategies: 87 | 88 | ```bash 89 | conda activate llm 90 | python rerank_ablation_test.py --config_path ./config_lib/example_config.yaml --output_file ./ablation_results/results.json 91 | ``` 92 | 93 | This will test the performance of the following retrievers: 94 | 1. Base Retriever 95 | 2. Diverse Retriever 96 | 3. Zero-shot Reranker Retriever 97 | 4. Diverse + Zero-shot Reranker Retriever 98 | 5. Diverse + Fine-tuned Reranker Retriever 99 | 100 | ### Specific Test Commands 101 | 102 | Test fine-tuned model: 103 | ```bash 104 | conda activate llm && python test_reranker.py --model_path=./output_new --test_data=./data/test_data.jsonl --output_file=./finetune_results.json 105 | 106 | conda activate llm && python test_reranker.py --model_path=./output_new --test_data=./data/test_data.jsonl --output_file=./new_model_results.json 107 | ``` 108 | 109 | Ablation test examples: 110 | ```bash 111 | conda activate llm && python ./rerank_ablation_test.py --config ../config_lib/extract_config_wikifulldoc.yaml --num 50 --candidates 5 112 | 113 | conda activate llm && CUDA_VISIBLE_DEVICES=3 python ./rerank_ablation_test.py --config ../config_lib/extract_config_wikifulldoc.yaml --num 50 --candidates 5 --single diverse 114 | 115 | conda activate llm && CUDA_VISIBLE_DEVICES=3 python ./rerank_ablation_test.py --config ../config_lib/extract_config_wikifulldoc.yaml --num 50 --candidates 5 --doc_ids ./sampled_doc_ids.txt 116 | ``` 117 | Supported ablation types: `standard`, `diverse`, `diverse_zs`, `diverse_ft` 118 | 119 | ## Notes 120 | 121 | 1. Before running for the first time, please ensure that the base model has been downloaded in the `./models/` directory. 122 | 2. The training process may require a large amount of GPU memory. 
It is recommended to use a GPU with at least 16GB of video memory. 123 | 3. Using DeepSpeed can reduce video memory requirements and improve training efficiency. 124 | 4. Training results will be saved in the specified output directory and can be directly used in the retrieval system. 125 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | # Dataset Processing and Sampling Script (process_and_sample_datasets.py) 2 | 3 | ## Overview 4 | 5 | This Python script is designed to process and sample questions from various multi-hop question-answering datasets: HotpotQA, 2WikiMQA, and MuSiQue. The primary goal is to extract different types of questions (e.g., bridge, comparison) from each dataset and save them as separate files. This allows for targeted evaluation of models on specific categories of human-authored questions. 6 | 7 | ## Features 8 | 9 | - **Processes Multiple Datasets**: Handles HotpotQA, 2WikiMQA, and MuSiQue datasets. 10 | - **Type-Based Extraction**: Identifies and extracts questions based on their reasoning type: 11 | - **HotpotQA**: 12 | - `bridge`: Questions requiring reasoning over multiple documents. 13 | - `comparison`: Questions requiring comparison between entities. 14 | - **2WikiMQA**: 15 | - `bridge`: Similar to HotpotQA bridge. 16 | - `comparison`: Similar to HotpotQA comparison. 17 | - `bridge_comparison`: Questions involving both bridge and comparison reasoning. 18 | - **MuSiQue**: 19 | - `musique-2-steps`: Bridge questions requiring 2 reasoning steps. 20 | - `musique-3-steps`: Bridge questions requiring 3 reasoning steps. 21 | - `musique-4-steps`: Bridge questions requiring 4 reasoning steps. 22 | - **Sampling**: Allows for sampling a specified number of questions from each extracted type. 23 | - **Organized Output**: Saves the processed and sampled datasets into a structured output directory, with each question type in its own JSON file. 24 | 25 | ## Usage 26 | 27 | The script is run from the command line. 28 | 29 | ### Arguments 30 | 31 | - `--hotpotqa` (str): Path to the HotpotQA dataset file. 32 | - Default: `./data_defaults/hotpotqa/hotpot_dev_distractor_v1.json` 33 | - `--twowiki` (str): Path to the 2WikiMQA dataset file. 34 | - Default: `./data_defaults/2wiki/dev.jsonl` 35 | - `--musique` (str): Path to the MuSiQue dataset file. 36 | - Default: `./data_defaults/musique/dev.jsonl` 37 | - `--output_dir` (str): Directory where the processed datasets will be saved. 38 | - Default: `./processed_datasets` 39 | - `--sample_size` (int): Number of samples to draw from each extracted question type. If not specified or set to `None`, all matching questions will be saved. 40 | - Default: `50` 41 | - `--random_seed` (int): Random seed for sampling to ensure reproducibility. 42 | - Default: `42` 43 | - `--use_separated_datasets` (bool, flag): If present, use datasets from the directory specified by `--separated_datasets_dir` instead of the paths provided by `--hotpotqa`, `--twowiki`, and `--musique` or their defaults. 44 | - Default: Not present. 45 | - `--separated_datasets_dir` (str): Directory containing the separated dataset files (`hotpot_dev_distractor_v1.json`, `dev_2wiki.jsonl`, `dev_musique.jsonl`). Only used if `--use_separated_datasets` is active. 46 | - Default: `./hopweaver_dataset_files` 47 | - `--log_level` (str): Logging level (DEBUG, INFO, WARNING, ERROR). 
48 | - Default: `INFO` 49 | - `--only_analyze_types` (bool, flag): If present, only analyze and log question type statistics without saving sampled data files. 50 | - Default: Not present. 51 | - `--processed_musique_file` (str): Path to the pre-processed MuSiQue data file (`musique_data.json`) required by the MuSiQue processing logic if the default path is not suitable. This file contains pre-computed information that speeds up the analysis of MuSiQue questions. 52 | - Default: `./data_defaults/dataset_mhqa/musique_data.json` 53 | 54 | ### Example Run Command 55 | 56 | To process the datasets using custom paths (or if default paths are not set up at `./data_defaults/`), sample 100 questions of each type, and save them to a custom output directory `./processed_data`: 57 | 58 | ```bash 59 | python process_and_sample_datasets.py \ 60 | --hotpotqa path/to/your/hotpotqa_dev_distractor_v1.json \ 61 | --twowiki path/to/your/dev.jsonl \ 62 | --musique path/to/your/musique_dev.jsonl \ 63 | --processed_musique_file path/to/your/musique_data.json \ 64 | --output_dir ./processed_data \ 65 | --sample_size 100 \ 66 | --random_seed 123 67 | ``` 68 | 69 | If you have dataset files in the default locations (e.g., `./data_defaults/hotpotqa/...`), you can omit the specific dataset path arguments: 70 | 71 | ```bash 72 | python process_and_sample_datasets.py \ 73 | --output_dir ./processed_data \ 74 | --sample_size 100 75 | # This assumes default --hotpotqa, --twowiki, --musique, and --processed_musique_file paths are valid 76 | ``` 77 | 78 | To use separated dataset files (e.g., prepared in a specific directory like `./my_hopweaver_datasets/`) and save results to `./processed_data_separated` with a sample size of 50: 79 | 80 | ```bash 81 | # Ensure ./my_hopweaver_datasets/ contains hotpot_dev_distractor_v1.json, dev_2wiki.jsonl, dev_musique.jsonl 82 | python process_and_sample_datasets.py \ 83 | --output_dir ./processed_data_separated \ 84 | --sample_size 50 \ 85 | --use_separated_datasets \ 86 | --separated_datasets_dir ./my_hopweaver_datasets 87 | ``` 88 | 89 | If your separated datasets are in the default `./hopweaver_dataset_files/` directory, you can simplify the command: 90 | 91 | ```bash 92 | # Ensure ./hopweaver_dataset_files/ contains the necessary files 93 | python process_and_sample_datasets.py \ 94 | --output_dir ./processed_data_separated \ 95 | --sample_size 50 \ 96 | --use_separated_datasets 97 | ``` 98 | 99 | **Note**: Ensure that the Python environment has necessary libraries installed (e.g., `json`, `argparse`, `logging`, `tqdm`, `collections`). The script primarily uses standard Python libraries. 100 | 101 | ## Output Structure 102 | 103 | The script will create JSON files in the specified `output_dir`. For example: 104 | 105 | - `hotpotqa_bridge.json` 106 | - `hotpotqa_comparison.json` 107 | - `twowiki_bridge.json` 108 | - `twowiki_comparison.json` 109 | - `twowiki_bridge_comparison.json` 110 | - `musique_bridge.json` 111 | 112 | Each file will contain a list of question objects of the corresponding type. 
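As a minimal usage sketch, an output file can be loaded back with the standard `json` module. The field name accessed below (`question`) is an assumption about the upstream dataset objects rather than something this script guarantees:

```python
import json

# Load one sampled output file and report basic statistics.
with open("./processed_datasets/hotpotqa_bridge.json", encoding="utf-8") as f:
    questions = json.load(f)  # a list of question objects

print(f"Loaded {len(questions)} sampled bridge questions")
if questions:
    # 'question' is assumed to be present in the source dataset objects.
    print("First question:", questions[0].get("question", "<no question field>"))
```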
113 | -------------------------------------------------------------------------------- /hopweaver/train_reranker/train_reranker.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | """ 5 | Script for training BGE Reranker model 6 | Fine-tune BAAI/bge-reranker-v2-m3 model using FlagEmbedding framework 7 | """ 8 | 9 | import os 10 | import argparse 11 | import subprocess 12 | 13 | def parse_args(): 14 | parser = argparse.ArgumentParser(description="Train BGE Reranker model") 15 | parser.add_argument("--model_path", type=str, 16 | default="./models/bge-reranker-v2-m3", 17 | help="Model path") 18 | parser.add_argument("--train_data", type=str, 19 | default="./toy_finetune_data.jsonl", 20 | help="Training data path") 21 | parser.add_argument("--output_dir", type=str, 22 | default="./output", 23 | help="Output directory") 24 | parser.add_argument("--cache_dir", type=str, 25 | default="./cache/model", 26 | help="Model cache path") 27 | parser.add_argument("--cache_path", type=str, 28 | default="./cache/data", 29 | help="Data cache path") 30 | parser.add_argument("--num_gpus", type=int, 31 | default=2, 32 | help="Number of GPUs to use") 33 | parser.add_argument("--learning_rate", type=float, 34 | default=6e-5, 35 | help="Learning rate") 36 | parser.add_argument("--num_train_epochs", type=int, 37 | default=10, 38 | help="Number of training epochs") 39 | parser.add_argument("--per_device_train_batch_size", type=int, 40 | default=4, 41 | help="Training batch size per device") 42 | parser.add_argument("--gradient_accumulation_steps", type=int, 43 | default=1, 44 | help="Gradient accumulation steps") 45 | parser.add_argument("--query_max_len", type=int, 46 | default=512, 47 | help="Maximum query length") 48 | parser.add_argument("--passage_max_len", type=int, 49 | default=8196, 50 | help="Maximum passage length") 51 | parser.add_argument("--train_group_size", type=int, 52 | default=4, 53 | help="Training group size (1 positive sample + 3 negative samples)") 54 | parser.add_argument("--pad_to_multiple_of", type=int, 55 | default=8, 56 | help="Pad to multiple of") 57 | parser.add_argument("--deepspeed_config", type=str, 58 | default="./ds_stage0.json", 59 | help="DeepSpeed configuration file path") 60 | parser.add_argument("--warmup_ratio", type=float, 61 | default=0.1, 62 | help="Warmup ratio") 63 | parser.add_argument("--weight_decay", type=float, 64 | default=0.01, 65 | help="Weight decay") 66 | parser.add_argument("--save_steps", type=int, 67 | default=100, 68 | help="Save steps") 69 | parser.add_argument("--save_total_limit", type=int, 70 | default=50, 71 | help="Save total limit") 72 | parser.add_argument("--fp16", action="store_true", 73 | default=True, 74 | help="Use fp16") 75 | parser.add_argument("--overwrite_output_dir", action="store_true", 76 | default=True, 77 | help="Overwrite output directory") 78 | parser.add_argument("--knowledge_distillation", type=str, 79 | default="False", 80 | help="Knowledge distillation") 81 | parser.add_argument("--report_to", type=str, 82 | default="tensorboard", 83 | help="Report training metrics in real-time, supports tensorboard") 84 | return parser.parse_args() 85 | 86 | def main(): 87 | args = parse_args() 88 | 89 | # Create output directory and cache directory 90 | os.makedirs(args.output_dir, exist_ok=True) 91 | os.makedirs(args.cache_dir, exist_ok=True) 92 | os.makedirs(args.cache_path, exist_ok=True) 93 | 94 | # Build training command 95 | # Control visible 
GPUs via CUDA_VISIBLE_DEVICES environment variable 96 | cmd = [ 97 | "torchrun", 98 | f"--nproc_per_node={args.num_gpus}", 99 | "-m", "FlagEmbedding.finetune.reranker.encoder_only.base", 100 | f"--model_name_or_path={args.model_path}", 101 | f"--cache_dir={args.cache_dir}", 102 | f"--train_data={args.train_data}", 103 | f"--cache_path={args.cache_path}", 104 | f"--train_group_size={args.train_group_size}", 105 | f"--query_max_len={args.query_max_len}", 106 | f"--passage_max_len={args.passage_max_len}", 107 | f"--pad_to_multiple_of={args.pad_to_multiple_of}", 108 | f"--knowledge_distillation={args.knowledge_distillation}", 109 | f"--output_dir={args.output_dir}", 110 | f"--report_to={args.report_to}" 111 | ] 112 | 113 | # Add optional parameters 114 | if args.overwrite_output_dir: 115 | cmd.append("--overwrite_output_dir") 116 | 117 | # Add other parameters 118 | cmd.extend([ 119 | f"--learning_rate={args.learning_rate}", 120 | f"--num_train_epochs={args.num_train_epochs}", 121 | f"--per_device_train_batch_size={args.per_device_train_batch_size}", 122 | f"--gradient_accumulation_steps={args.gradient_accumulation_steps}", 123 | "--dataloader_drop_last=True", 124 | "--logging_steps=1", 125 | f"--save_steps={args.save_steps}", 126 | f"--save_total_limit={args.save_total_limit}", 127 | "--ddp_find_unused_parameters=False", 128 | "--gradient_checkpointing", 129 | f"--weight_decay={args.weight_decay}", 130 | f"--deepspeed={args.deepspeed_config}", 131 | f"--warmup_ratio={args.warmup_ratio}" 132 | ]) 133 | 134 | # Add fp16 or bf16 based on parameters 135 | if args.fp16: 136 | cmd.append("--fp16") 137 | else: 138 | cmd.append("--bf16") 139 | 140 | cmd_str = " ".join(cmd) 141 | print(f"Executing command: {cmd_str}") 142 | 143 | try: 144 | subprocess.run(cmd, check=True) 145 | print("Training complete!") 146 | except subprocess.CalledProcessError as e: 147 | print(f"Training failed: {e}") 148 | except Exception as e: 149 | print(f"An error occurred: {e}") 150 | 151 | if __name__ == "__main__": 152 | main() 153 | -------------------------------------------------------------------------------- /flashrag/judger/judger.py: -------------------------------------------------------------------------------- 1 | from typing import cast, List 2 | import json 3 | from tqdm.auto import trange 4 | from collections import Counter 5 | import numpy as np 6 | import torch 7 | import faiss 8 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 9 | from flashrag.retriever.utils import load_model, pooling 10 | 11 | 12 | class BaseJudger: 13 | """Base object of Judger, used for judging whether to retrieve""" 14 | 15 | def __init__(self, config): 16 | self.config = config 17 | self.name = config["judger_name"] if "judger_name" in config else None 18 | self.judger_config = config["judger_config"] if "judger_config" in config else {} 19 | self.device = config["device"] 20 | 21 | def run(self, item) -> str: 22 | """Get judgement result. 23 | 24 | Args: 25 | item: dataset item, contains question, retrieval result... 
26 | 27 | Returns: 28 | judgement: bool, whether to retreive 29 | """ 30 | pass 31 | 32 | def batch_run(self, dataset, batch_size=None) -> List[str]: 33 | return [self.run(item) for item in dataset] 34 | 35 | 36 | class SKRJudger(BaseJudger): 37 | """Implementation for SKR-knn 38 | Paper link: https://aclanthology.org/2023.findings-emnlp.691.pdf 39 | """ 40 | 41 | def __init__(self, config): 42 | super().__init__(config) 43 | self.model_path = self.judger_config["model_path"] 44 | self.training_data_path = self.judger_config["training_data_path"] 45 | self.encoder, self.tokenizer = load_model(model_path=self.model_path, use_fp16=False) 46 | self.topk = self.judger_config["topk"] if "topk" in self.judger_config else 5 47 | self.batch_size = self.judger_config["batch_size"] if "batch_size" in config else 64 48 | self.max_length = self.judger_config["max_length"] if "max_length" in config else 128 49 | 50 | with open(self.training_data_path, "r") as f: 51 | self.training_data = json.load(f) 52 | # count number of pos & neg samples in training data 53 | self.training_data_counter = Counter([item["judgement"].strip() for item in self.training_data]) 54 | self.training_pos_num = self.training_data_counter["ir_better"] 55 | self.training_neg_num = self.training_data_counter["ir_worse"] 56 | self.training_data_num = sum(self.training_data_counter.values()) 57 | 58 | # encode training question into faiss 59 | training_questions = [item["question"] for item in self.training_data] 60 | all_embeddings = self.encode(training_questions) 61 | faiss_index = faiss.index_factory(all_embeddings.shape[-1], "Flat", faiss.METRIC_L2) 62 | faiss_index.add(all_embeddings) 63 | self.faiss = faiss_index 64 | 65 | @torch.inference_mode(mode=True) 66 | def encode(self, contents: list): 67 | inputs = self.tokenizer( 68 | contents, 69 | padding=True, 70 | truncation=True, 71 | return_tensors="pt", 72 | max_length=self.max_length, 73 | ).to("cuda") 74 | output = self.encoder(**inputs, return_dict=True) 75 | embeddings = pooling(output.pooler_output, output.last_hidden_state, inputs["attention_mask"], "pooler") 76 | 77 | embeddings = cast(torch.Tensor, embeddings) 78 | embeddings = torch.nn.functional.normalize(embeddings, dim=-1).detach() 79 | 80 | all_embeddings = embeddings.cpu().numpy() 81 | # all_embeddings = np.concatenate(all_embeddings, axis=0) 82 | all_embeddings = all_embeddings.astype(np.float32) 83 | 84 | return all_embeddings 85 | 86 | def judge(self, dataset): 87 | questions = dataset.question 88 | 89 | all_judgements = [] 90 | for start_idx in range(0, len(questions), self.batch_size): 91 | batch_question = questions[start_idx : start_idx + self.batch_size] 92 | batch_emb = self.encode(batch_question) 93 | scores, batch_idxs = self.faiss.search(batch_emb, k=self.topk) 94 | 95 | for idxs in batch_idxs: 96 | topk_samples = [self.training_data[idx]["judgement"].strip() for idx in idxs] 97 | topk_counter = Counter(topk_samples) 98 | 99 | # count number of pos & neg samples in topk 100 | ir_better_num = topk_counter["ir_better"] 101 | ir_worse_num = topk_counter["ir_worse"] 102 | topk_delta = ir_better_num - ir_worse_num 103 | 104 | training_data_delta = self.training_pos_num - self.training_neg_num 105 | 106 | # provide judgments based on the formula in the paper 107 | if training_data_delta < 0: 108 | if topk_delta < 0 and topk_delta <= int(training_data_delta * self.topk / self.training_data_num): 109 | judgement = False 110 | else: 111 | judgement = True 112 | else: 113 | if topk_delta > 0 and topk_delta >= 
int(training_data_delta * self.topk / self.training_data_num): 114 | judgement = True 115 | else: 116 | judgement = False 117 | 118 | all_judgements.append(judgement) 119 | 120 | return all_judgements 121 | 122 | 123 | class AdaptiveJudger(BaseJudger): 124 | """Implementation for Adaptive-RAG 125 | Paper link: https://aclanthology.org/2024.naacl-long.389.pdf 126 | """ 127 | 128 | def __init__(self, config): 129 | super().__init__(config) 130 | self.model_path = self.judger_config["model_path"] 131 | self.batch_size = self.judger_config.get("batch_size", 16) 132 | self.max_length = self.judger_config.get("max_length", 512) 133 | 134 | self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path) 135 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) 136 | self.model.eval() 137 | self.model.cuda() 138 | 139 | @torch.inference_mode(mode=True) 140 | def judge(self, dataset): 141 | questions = dataset.question 142 | questions = [q.strip() for q in questions] 143 | 144 | all_preds = [] 145 | for idx in trange(0, len(questions), self.batch_size, desc="Judger process: "): 146 | batch_input = questions[idx : idx + self.batch_size] 147 | batch_input = self.tokenizer( 148 | batch_input, 149 | truncation=True, 150 | padding=True, 151 | max_length=512, 152 | return_tensors="pt", 153 | ).to(self.model.device) 154 | 155 | scores = self.model.generate( 156 | **batch_input, return_dict_in_generate=True, output_scores=True, max_length=self.max_length 157 | ).scores[0] 158 | 159 | probs = ( 160 | torch.nn.functional.softmax( 161 | torch.stack( 162 | [ 163 | scores[:, self.tokenizer("A").input_ids[0]], 164 | scores[:, self.tokenizer("B").input_ids[0]], 165 | scores[:, self.tokenizer("C").input_ids[0]], 166 | ] 167 | ), 168 | dim=0, 169 | ) 170 | .detach() 171 | .cpu() 172 | .numpy() 173 | ) 174 | 175 | preds_labels = np.argmax(probs, 0) 176 | label_to_option = { 177 | 0: "A", 178 | 1: "B", 179 | 2: "C", 180 | } 181 | preds = [label_to_option[pred] for pred in preds_labels] 182 | all_preds.extend(preds) 183 | 184 | return all_preds 185 | -------------------------------------------------------------------------------- /flashrag/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import random 4 | import warnings 5 | from typing import List, Dict, Any, Optional, Generator 6 | import numpy as np 7 | 8 | 9 | class Item: 10 | """A container class used to store and manipulate a sample within a dataset. 11 | Information related to this sample during training/inference will be stored in `self.output`. 12 | Each attribute of this class can be used like a dict key (also for key in `self.output`). 
13 | """ 14 | 15 | def __init__(self, item_dict: Dict[str, Any]) -> None: 16 | self.id: Optional[str] = item_dict.get("id", None) 17 | self.question: Optional[str] = item_dict.get("question", None) 18 | self.golden_answers: List[str] = item_dict.get("golden_answers", []) 19 | self.choices: List[str] = item_dict.get("choices", []) 20 | self.metadata: Dict[str, Any] = item_dict.get("metadata", {}) 21 | self.output: Dict[str, Any] = item_dict.get("output", {}) 22 | self.data: Dict[str, Any] = item_dict 23 | 24 | def update_output(self, key: str, value: Any) -> None: 25 | """Update the output dict and keep a key in self.output can be used as an attribute.""" 26 | if key in ["id", "question", "golden_answers", "output", "choices"]: 27 | raise AttributeError(f"{key} should not be changed") 28 | else: 29 | self.output[key] = value 30 | 31 | def update_evaluation_score(self, metric_name: str, metric_score: float) -> None: 32 | """Update the evaluation score of this sample for a metric.""" 33 | if "metric_score" not in self.output: 34 | self.output["metric_score"] = {} 35 | self.output["metric_score"][metric_name] = metric_score 36 | 37 | def __getattr__(self, attr_name: str) -> Any: 38 | predefined_attrs = ["id", "question", "golden_answers", "metadata", "output", "choices"] 39 | if attr_name in predefined_attrs: 40 | return super().__getattribute__(attr_name) 41 | else: 42 | output = self.output 43 | if attr_name in output: 44 | return output[attr_name] 45 | else: 46 | try: 47 | return self.data[attr_name] 48 | except AttributeError: 49 | raise AttributeError(f"Attribute `{attr_name}` not found") 50 | 51 | def to_dict(self) -> Dict[str, Any]: 52 | """Convert all information within the data sample into a dict. Information generated 53 | during the inference will be saved into output field. 54 | """ 55 | from flashrag.dataset.utils import convert_numpy 56 | 57 | output = { 58 | "id": self.id, 59 | "question": self.question, 60 | "golden_answers": self.golden_answers, 61 | "output": convert_numpy(self.output), 62 | } 63 | if self.metadata: 64 | output["metadata"] = self.metadata 65 | 66 | return output 67 | 68 | def __str__(self) -> str: 69 | """Return a string representation of the item with its main attributes.""" 70 | return json.dumps(self.to_dict(), indent=4) 71 | 72 | 73 | class Dataset: 74 | """A container class used to store the whole dataset. Inside the class, each data sample will be stored 75 | in `Item` class. The properties of the dataset represent the list of attributes corresponding to each item in the dataset. 
76 | """ 77 | 78 | def __init__( 79 | self, 80 | config: Optional[Dict[str, Any]] = None, 81 | dataset_path: Optional[str] = None, 82 | data: Optional[List[Dict[str, Any]]] = None, 83 | sample_num: Optional[int] = None, 84 | random_sample: bool = False, 85 | ) -> None: 86 | if config is not None: 87 | self.config = config 88 | dataset_name = config['dataset_name'] if 'dataset_name' in config else 'defalut_dataset' 89 | else: 90 | self.config = None 91 | warnings.warn("dataset_name is not in config, set it as default.") 92 | dataset_name = "default_dataset" 93 | self.dataset_name = dataset_name 94 | self.dataset_path = dataset_path 95 | 96 | self.sample_num = sample_num 97 | self.random_sample = random_sample 98 | 99 | if data is None: 100 | self.data = self._load_data(self.dataset_name, self.dataset_path) 101 | else: 102 | print("Load data from provided data") 103 | if isinstance(data[0], dict): 104 | self.data = [Item(item_dict) for item_dict in data] 105 | else: 106 | assert isinstance(data[0], Item) 107 | self.data = data 108 | 109 | def _load_data(self, dataset_name: str, dataset_path: str) -> List[Item]: 110 | """Load data from the provided dataset_path or directly download the file(TODO).""" 111 | if not os.path.exists(dataset_path): 112 | # TODO: auto download: self._download(self.dataset_name, dataset_path) 113 | raise FileNotFoundError(f"Dataset file {dataset_path} not found.") 114 | 115 | data = [] 116 | with open(dataset_path, "r", encoding="utf-8") as f: 117 | for line in f: 118 | item_dict = json.loads(line) 119 | item = Item(item_dict) 120 | data.append(item) 121 | if self.sample_num is not None: 122 | if self.random_sample: 123 | print(f"Random sample {self.sample_num} items in test set.") 124 | data = random.sample(data, self.sample_num) 125 | else: 126 | data = data[: self.sample_num] 127 | 128 | return data 129 | 130 | def update_output(self, key: str, value_list: List[Any]) -> None: 131 | """Update the overall output field for each sample in the dataset.""" 132 | assert len(self.data) == len(value_list) 133 | for item, value in zip(self.data, value_list): 134 | item.update_output(key, value) 135 | 136 | @property 137 | def question(self) -> List[Optional[str]]: 138 | return [item.question for item in self.data] 139 | 140 | @property 141 | def golden_answers(self) -> List[List[str]]: 142 | return [item.golden_answers for item in self.data] 143 | 144 | @property 145 | def id(self) -> List[Optional[str]]: 146 | return [item.id for item in self.data] 147 | 148 | @property 149 | def output(self) -> List[Dict[str, Any]]: 150 | return [item.output for item in self.data] 151 | 152 | def get_batch_data(self, attr_name: str, batch_size: int) -> Generator[List[Any], None, None]: 153 | """Get an attribute of dataset items in batch.""" 154 | for i in range(0, len(self.data), batch_size): 155 | batch_items = self.data[i : i + batch_size] 156 | yield [item[attr_name] for item in batch_items] 157 | 158 | def __getattr__(self, attr_name: str) -> List[Any]: 159 | return [item.__getattr__(attr_name) for item in self.data] 160 | 161 | def get_attr_data(self, attr_name: str) -> List[Any]: 162 | """For the attributes constructed later (not implemented using property), 163 | obtain a list of this attribute in the entire dataset. 
164 | """ 165 | return [item[attr_name] for item in self.data] 166 | 167 | def __getitem__(self, index: int) -> Item: 168 | return self.data[index] 169 | 170 | def __len__(self) -> int: 171 | return len(self.data) 172 | 173 | def save(self, save_path: str) -> None: 174 | """Save the dataset into the original format.""" 175 | 176 | save_data = [item.to_dict() for item in self.data] 177 | def custom_serializer(obj): 178 | if isinstance(obj, np.float32): 179 | return float(obj) 180 | if isinstance(obj, np.bool_): 181 | return str(obj) 182 | raise TypeError(f"Type {type(obj)} not serializable") 183 | with open(save_path, "w", encoding="utf-8") as f: 184 | json.dump(save_data, f, indent=4, default=custom_serializer) 185 | 186 | def __str__(self) -> str: 187 | """Return a string representation of the dataset with a summary of items.""" 188 | return f"Dataset '{self.dataset_name}' with {len(self)} items" 189 | -------------------------------------------------------------------------------- /flashrag/generator/fid.py: -------------------------------------------------------------------------------- 1 | # Source: FiD official repo: https://github.com/facebookresearch/FiD 2 | # This software is released under Creative Commons public licenses. 3 | 4 | import torch 5 | import torch.nn as nn 6 | import transformers 7 | import types 8 | import torch.nn.functional as F 9 | from torch.nn import CrossEntropyLoss 10 | import numpy as np 11 | 12 | class FiDT5(transformers.T5ForConditionalGeneration): 13 | def __init__(self, config): 14 | super().__init__(config) 15 | self.wrap_encoder() 16 | 17 | def forward_(self, **kwargs): 18 | if 'input_ids' in kwargs: 19 | kwargs['input_ids'] = kwargs['input_ids'].view(kwargs['input_ids'].size(0), -1) 20 | if 'attention_mask' in kwargs: 21 | kwargs['attention_mask'] = kwargs['attention_mask'].view(kwargs['attention_mask'].size(0), -1) 22 | 23 | return super(FiDT5, self).forward( 24 | **kwargs 25 | ) 26 | 27 | # input ids : bs, n, seq_len -> bs, n*seq_len 28 | # We need to resize as B x (N * L) instead of (B * N) x L here 29 | # because the T5 forward method uses the input tensors to infer 30 | # dimensions used in the decoder. 31 | # EncoderWrapper resizes the inputs as (B * N) x L. 32 | def forward(self, input_ids=None, attention_mask=None, **kwargs): 33 | if input_ids != None: 34 | # (bs, n, seq_len) -> (bs, n*seq_len) 35 | # inputs might have already be resized in the generate method 36 | if input_ids.dim() == 3: 37 | self.encoder.n_passages = input_ids.size(1) 38 | input_ids = input_ids.view(input_ids.size(0), -1) 39 | if attention_mask != None: 40 | attention_mask = attention_mask.view(attention_mask.size(0), -1) 41 | #print(input_ids.shape) 42 | return super().forward( 43 | input_ids=input_ids, 44 | attention_mask=attention_mask, 45 | **kwargs 46 | ) 47 | 48 | def generate(self, input_ids, attention_mask, **kwargs): 49 | # input ids - bs, n, seq_len -> bs, n*seq_len 50 | self.encoder.n_passages = input_ids.size(1) 51 | return super().generate( 52 | input_ids=input_ids.view(input_ids.size(0), -1), 53 | attention_mask=attention_mask.view(attention_mask.size(0), -1), 54 | **kwargs, 55 | ) 56 | 57 | def wrap_encoder(self, use_checkpoint=False): 58 | """ 59 | Wrap T5 encoder to obtain a Fusion-in-Decoder model. 60 | """ 61 | self.encoder = EncoderWrapper(self.encoder) 62 | 63 | 64 | def unwrap_encoder(self): 65 | """ 66 | Unwrap Fusion-in-Decoder encoder, useful to load T5 weights. 
67 | """ 68 | self.encoder = self.encoder.encoder 69 | block = [] 70 | for mod in self.encoder.block: 71 | block.append(mod.module) 72 | block = nn.ModuleList(block) 73 | self.encoder.block = block 74 | 75 | def load_t5(self, state_dict): 76 | self.unwrap_encoder() 77 | self.load_state_dict(state_dict) 78 | self.wrap_encoder() 79 | 80 | def set_checkpoint(self, use_checkpoint): 81 | """ 82 | Enable or disable checkpointing in the encoder. 83 | See https://pytorch.org/docs/stable/checkpoint.html 84 | """ 85 | for mod in self.encoder.encoder.block: 86 | mod.use_checkpoint = use_checkpoint 87 | 88 | def tie_weights(self): 89 | pass 90 | 91 | class CheckpointWrapper(torch.nn.Module): 92 | """ 93 | Wrapper replacing None outputs by empty tensors, which allows the use of 94 | checkpointing. 95 | """ 96 | def __init__(self, module, use_checkpoint=False): 97 | super().__init__() 98 | self.module = module 99 | self.use_checkpoint = use_checkpoint 100 | 101 | def forward(self, hidden_states, attention_mask, position_bias, **kwargs): 102 | if self.use_checkpoint and self.training: 103 | kwargs = {k: v for k, v in kwargs.items() if v is not None} 104 | def custom_forward(*inputs): 105 | output = self.module(*inputs, **kwargs) 106 | empty = torch.tensor( 107 | [], 108 | dtype=torch.float, 109 | device=output[0].device, 110 | requires_grad=True) 111 | output = tuple(x if x is not None else empty for x in output) 112 | return output 113 | 114 | output = torch.utils.checkpoint.checkpoint( 115 | custom_forward, 116 | hidden_states, 117 | attention_mask, 118 | position_bias 119 | ) 120 | output = tuple(x if x.size() != 0 else None for x in output) 121 | else: 122 | output = self.module(hidden_states, attention_mask, position_bias, **kwargs) 123 | return output 124 | 125 | def apply_checkpoint_wrapper(t5stack, use_checkpoint): 126 | """ 127 | Wrap each block of the encoder to enable checkpointing. 128 | """ 129 | block = [] 130 | for mod in t5stack.block: 131 | wrapped_mod = CheckpointWrapper(mod, use_checkpoint) 132 | block.append(wrapped_mod) 133 | block = nn.ModuleList(block) 134 | t5stack.block = block 135 | 136 | 137 | class FiDBart(transformers.BartForConditionalGeneration): 138 | def __init__(self, config): 139 | super().__init__(config) 140 | self.wrap_encoder() 141 | 142 | def forward(self, input_ids=None, attention_mask=None, **kwargs): 143 | 144 | if input_ids != None: 145 | # (bs, n, seq_len) -> (bs, n*seq_len) 146 | # inputs might have already be resized in the generate method 147 | if input_ids.dim() == 3: 148 | self.model.encoder.n_passages = input_ids.size(1) 149 | input_ids = input_ids.view(input_ids.size(0), -1) 150 | 151 | if attention_mask != None: 152 | attention_mask = attention_mask.view(attention_mask.size(0), -1) 153 | 154 | return super().forward(input_ids=input_ids, attention_mask=attention_mask, **kwargs) 155 | 156 | def generate(self, input_ids, attention_mask, **kwargs): 157 | self.model.encoder.n_passages = input_ids.size(1) 158 | return super().generate( 159 | input_ids=input_ids.view(input_ids.size(0), -1), 160 | attention_mask=attention_mask.view(attention_mask.size(0),-1), 161 | **kwargs) 162 | 163 | def wrap_encoder(self): 164 | """ 165 | Wrap T5 encoder to obtain a Fusion-in-Decoder model. 166 | """ 167 | self.model.encoder = EncoderWrapper(self.model.encoder) 168 | 169 | def unwrap_encoder(self): 170 | """ 171 | Unwrap Fusion-in-Decoder encoder, useful to load bart weights. 
172 | """ 173 | self.model.encoder = self.model.encoder.encoder 174 | block = [] 175 | for mod in self.model.encoder.layers: 176 | block.append(mod) 177 | block = nn.ModuleList(block) 178 | self.model.encoder.layers = block 179 | 180 | def load_pretrained_model(self, state_dict): 181 | self.unwrap_encoder() 182 | self.load_state_dict(state_dict) 183 | self.wrap_encoder() 184 | def tie_weights(self): 185 | pass 186 | 187 | class EncoderWrapper(torch.nn.Module): 188 | def __init__(self, encoder,use_checkpoint=False): 189 | super().__init__() 190 | self.encoder = encoder 191 | 192 | try: 193 | self.main_input_name = encoder.main_input_name 194 | except: 195 | pass 196 | apply_checkpoint_wrapper(self.encoder, use_checkpoint) 197 | 198 | def forward(self, input_ids=None, attention_mask=None,**kwargs): 199 | bsz, total_length = input_ids.shape 200 | passage_length = total_length // self.n_passages 201 | # total_input 202 | input_ids = input_ids.view(bsz*self.n_passages, passage_length) 203 | attention_mask = attention_mask.view(bsz*self.n_passages, passage_length) 204 | outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **kwargs) 205 | outputs.last_hidden_state = outputs.last_hidden_state.view(bsz, self.n_passages*passage_length, -1) 206 | return outputs 207 | -------------------------------------------------------------------------------- /flashrag/evaluator/_bleu.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/mjpost/sacrebleu/blob/master/sacrebleu/tokenizers/tokenizer_13a.py 2 | # Copyright 2020 SacreBLEU Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import re 17 | from functools import lru_cache 18 | 19 | 20 | class BaseTokenizer: 21 | """A base dummy tokenizer to derive from.""" 22 | 23 | def signature(self): 24 | """ 25 | Returns a signature for the tokenizer. 26 | :return: signature string 27 | """ 28 | return "none" 29 | 30 | def __call__(self, line): 31 | """ 32 | Tokenizes an input line with the tokenizer. 33 | :param line: a segment to tokenize 34 | :return: the tokenized line 35 | """ 36 | return line 37 | 38 | 39 | class TokenizerRegexp(BaseTokenizer): 40 | def signature(self): 41 | return "re" 42 | 43 | def __init__(self): 44 | self._re = [ 45 | # language-dependent part (assuming Western languages) 46 | (re.compile(r"([\{-\~\[-\` -\&\(-\+\:-\@\/])"), r" \1 "), 47 | # tokenize period and comma unless preceded by a digit 48 | (re.compile(r"([^0-9])([\.,])"), r"\1 \2 "), 49 | # tokenize period and comma unless followed by a digit 50 | (re.compile(r"([\.,])([^0-9])"), r" \1 \2"), 51 | # tokenize dash when preceded by a digit 52 | (re.compile(r"([0-9])(-)"), r"\1 \2 "), 53 | # one space only between words 54 | # NOTE: Doing this in Python (below) is faster 55 | # (re.compile(r'\s+'), r' '), 56 | ] 57 | 58 | @lru_cache(maxsize=2**16) 59 | def __call__(self, line): 60 | """Common post-processing tokenizer for `13a` and `zh` tokenizers. 
61 | :param line: a segment to tokenize 62 | :return: the tokenized line 63 | """ 64 | for _re, repl in self._re: 65 | line = _re.sub(repl, line) 66 | 67 | # no leading or trailing spaces, single space within words 68 | # return ' '.join(line.split()) 69 | # This line is changed with regards to the original tokenizer (seen above) to return individual words 70 | return line.split() 71 | 72 | 73 | class Tokenizer13a(BaseTokenizer): 74 | def signature(self): 75 | return "13a" 76 | 77 | def __init__(self): 78 | self._post_tokenizer = TokenizerRegexp() 79 | 80 | @lru_cache(maxsize=2**16) 81 | def __call__(self, line): 82 | """Tokenizes an input line using a relatively minimal tokenization 83 | that is however equivalent to mteval-v13a, used by WMT. 84 | :param line: a segment to tokenize 85 | :return: the tokenized line 86 | """ 87 | 88 | # language-independent part: 89 | line = line.replace("", "") 90 | line = line.replace("-\n", "") 91 | line = line.replace("\n", " ") 92 | 93 | if "&" in line: 94 | line = line.replace(""", '"') 95 | line = line.replace("&", "&") 96 | line = line.replace("<", "<") 97 | line = line.replace(">", ">") 98 | 99 | return self._post_tokenizer(f" {line} ") 100 | 101 | 102 | # Copyright 2017 Google Inc. All Rights Reserved. 103 | # 104 | # Licensed under the Apache License, Version 2.0 (the "License"); 105 | # you may not use this file except in compliance with the License. 106 | # You may obtain a copy of the License at 107 | # 108 | # http://www.apache.org/licenses/LICENSE-2.0 109 | # 110 | # Unless required by applicable law or agreed to in writing, software 111 | # distributed under the License is distributed on an "AS IS" BASIS, 112 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 113 | # See the License for the specific language governing permissions and 114 | # limitations under the License. 115 | # ============================================================================== 116 | 117 | """Python implementation of BLEU and smooth-BLEU. 118 | 119 | This module provides a Python implementation of BLEU and smooth-BLEU. 120 | Smooth BLEU is computed following the method outlined in the paper: 121 | Chin-Yew Lin, Franz Josef Och. ORANGE: a method for evaluating automatic 122 | evaluation metrics for machine translation. COLING 2004. 123 | """ 124 | 125 | import collections 126 | import math 127 | 128 | 129 | def _get_ngrams(segment, max_order): 130 | """Extracts all n-grams upto a given maximum order from an input segment. 131 | 132 | Args: 133 | segment: text segment from which n-grams will be extracted. 134 | max_order: maximum length in tokens of the n-grams returned by this 135 | methods. 136 | 137 | Returns: 138 | The Counter containing all n-grams upto max_order in segment 139 | with a count of how many times each n-gram occurred. 140 | """ 141 | ngram_counts = collections.Counter() 142 | for order in range(1, max_order + 1): 143 | for i in range(0, len(segment) - order + 1): 144 | ngram = tuple(segment[i : i + order]) 145 | ngram_counts[ngram] += 1 146 | return ngram_counts 147 | 148 | 149 | def compute_bleu(reference_corpus, translation_corpus, max_order=4, smooth=False): 150 | """Computes BLEU score of translated segments against one or more references. 151 | 152 | Args: 153 | reference_corpus: list of lists of references for each translation. Each 154 | reference should be tokenized into a list of tokens. 155 | translation_corpus: list of translations to score. Each translation 156 | should be tokenized into a list of tokens. 
157 | max_order: Maximum n-gram order to use when computing BLEU score. 158 | smooth: Whether or not to apply Lin et al. 2004 smoothing. 159 | 160 | Returns: 161 | 3-Tuple with the BLEU score, n-gram precisions, geometric mean of n-gram 162 | precisions and brevity penalty. 163 | """ 164 | matches_by_order = [0] * max_order 165 | possible_matches_by_order = [0] * max_order 166 | reference_length = 0 167 | translation_length = 0 168 | for references, translation in zip(reference_corpus, translation_corpus): 169 | reference_length += min(len(r) for r in references) 170 | translation_length += len(translation) 171 | 172 | merged_ref_ngram_counts = collections.Counter() 173 | for reference in references: 174 | merged_ref_ngram_counts |= _get_ngrams(reference, max_order) 175 | translation_ngram_counts = _get_ngrams(translation, max_order) 176 | overlap = translation_ngram_counts & merged_ref_ngram_counts 177 | for ngram in overlap: 178 | matches_by_order[len(ngram) - 1] += overlap[ngram] 179 | for order in range(1, max_order + 1): 180 | possible_matches = len(translation) - order + 1 181 | if possible_matches > 0: 182 | possible_matches_by_order[order - 1] += possible_matches 183 | 184 | precisions = [0] * max_order 185 | for i in range(0, max_order): 186 | if smooth: 187 | precisions[i] = (matches_by_order[i] + 1.0) / (possible_matches_by_order[i] + 1.0) 188 | else: 189 | if possible_matches_by_order[i] > 0: 190 | precisions[i] = float(matches_by_order[i]) / possible_matches_by_order[i] 191 | else: 192 | precisions[i] = 0.0 193 | 194 | if min(precisions) > 0: 195 | p_log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions) 196 | geo_mean = math.exp(p_log_sum) 197 | else: 198 | geo_mean = 0 199 | 200 | ratio = float(translation_length) / reference_length 201 | 202 | if ratio > 1.0: 203 | bp = 1.0 204 | else: 205 | bp = math.exp(1 - 1.0 / ratio) 206 | 207 | bleu = geo_mean * bp 208 | 209 | return (bleu, precisions, bp, ratio, translation_length, reference_length) 210 | -------------------------------------------------------------------------------- /flashrag/prompt/base_prompt.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoConfig 2 | import tiktoken 3 | import warnings 4 | 5 | class PromptTemplate: 6 | placeholders = ["reference", "question"] 7 | base_system_prompt = ( 8 | "Answer the question based on the given document." 9 | "Only give me the answer and do not output any other words." 
10 | "\nThe following are given documents.\n\n{reference}" 11 | ) 12 | base_user_prompt = "Question: {question}" 13 | 14 | def __init__(self, config, system_prompt="", user_prompt="", reference_template=None, enable_chat=True): 15 | 16 | self.config = config 17 | self.is_openai = config["framework"] == "openai" 18 | self.max_input_len = config['generator_max_input_len'] 19 | if not self.is_openai: 20 | self.generator_path = config["generator_model_path"] 21 | model_config = AutoConfig.from_pretrained(self.generator_path, trust_remote_code=True) 22 | model_name = model_config._name_or_path.lower() 23 | self.is_chat = False 24 | if "chat" in model_name or "instruct" in model_name: 25 | self.is_chat = True 26 | self.tokenizer = AutoTokenizer.from_pretrained(self.generator_path, trust_remote_code=True) 27 | else: 28 | self.is_chat = True 29 | self.enable_chat = True 30 | try: 31 | self.tokenizer = tiktoken.encoding_for_model(config['generator_model']) 32 | except Exception as e: 33 | print("Error: ", e) 34 | warnings.warn("This model is not supported by tiktoken. Use gpt-3.5-turbo instead.") 35 | self.tokenizer = tiktoken.encoding_for_model('gpt-3.5-turbo') 36 | 37 | if len(system_prompt) == 0 and len(user_prompt) == 0: 38 | system_prompt = self.base_system_prompt 39 | user_prompt = self.base_user_prompt 40 | self.system_prompt = system_prompt 41 | self.user_prompt = user_prompt 42 | self.enable_chat = enable_chat 43 | self.reference_template = reference_template 44 | 45 | # self._check_placeholder() 46 | 47 | def _check_placeholder(self): 48 | # check placeholder in prompt 49 | for holder in self.placeholders: 50 | flag = False 51 | for prompt in [self.system_prompt, self.user_prompt]: 52 | if f"{holder}" in prompt: 53 | print(f"Find `{holder}` in template") 54 | flag = True 55 | break 56 | if not flag and holder != "reference": 57 | assert False 58 | 59 | def truncate_prompt(self, prompt): 60 | if self.is_openai: 61 | truncated_messages = [] 62 | total_tokens = 0 63 | assert isinstance(prompt, list) 64 | for message in prompt: 65 | role_content = message['content'] 66 | encoded_message = self.tokenizer.encode(role_content) 67 | 68 | if total_tokens + len(encoded_message) <= self.max_input_len: 69 | truncated_messages.append(message) 70 | total_tokens += len(encoded_message) 71 | else: 72 | print(f"The input text length is greater than the maximum length ({total_tokens + len(encoded_message)} > {self.max_input_len}) and has been truncated!") 73 | remaining_tokens = self.max_input_len - total_tokens 74 | truncated_message = self.encoding.decode(encoded_message[:remaining_tokens]) 75 | message['content'] = truncated_message 76 | truncated_messages.append(message) 77 | break 78 | 79 | return truncated_messages 80 | 81 | else: 82 | if self.tokenizer is None: 83 | self.tokenizer = AutoTokenizer.from_pretrained(self.generator_path, trust_remote_code=True) 84 | assert isinstance(prompt, str) 85 | tokenized_prompt = self.tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 86 | 87 | if len(tokenized_prompt) > self.max_input_len: 88 | print(f"The input text length is greater than the maximum length ({len(tokenized_prompt)} > {self.max_input_len}) and has been truncated!") 89 | half = int(self.max_input_len / 2) 90 | prompt = self.tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True) + \ 91 | self.tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 92 | return prompt 93 | 94 | 95 | 96 | def get_string(self, question=None, retrieval_result=None, 
formatted_reference=None, previous_gen=None, messages=None, **params): 97 | if messages is not None: 98 | if isinstance(messages, str): 99 | return self.truncate_prompt(messages) 100 | if self.is_chat and self.enable_chat: 101 | if self.is_openai: 102 | self.truncate_prompt(messages) 103 | else: 104 | prompt = self.tokenizer.apply_chat_template( 105 | messages, tokenize=False, add_generation_prompt=True 106 | ) 107 | return self.truncate_prompt(prompt) 108 | else: 109 | prompt = "\n\n".join( 110 | [message['content'] for message in messages if message['content']] 111 | ) 112 | return self.truncate_prompt(prompt) 113 | 114 | if formatted_reference is None: 115 | if retrieval_result is not None: 116 | formatted_reference = self.format_reference(retrieval_result) 117 | else: 118 | formatted_reference = "" 119 | 120 | input_params = {"question": question, "reference": formatted_reference} 121 | input_params.update(**params) 122 | 123 | system_prompt = self.system_prompt.format(**input_params) 124 | user_prompt = self.user_prompt.format(**input_params) 125 | 126 | if self.is_chat and self.enable_chat: 127 | input = [] 128 | if system_prompt != "": 129 | input.append({"role": "system", "content": system_prompt}) 130 | if user_prompt != "": 131 | input.append({"role": "user", "content": user_prompt}) 132 | if not self.is_openai: 133 | input = self.tokenizer.apply_chat_template(input, tokenize=False, add_generation_prompt=True) 134 | else: 135 | input = "\n\n".join([prompt for prompt in [system_prompt, user_prompt] if prompt != ""]) 136 | 137 | if previous_gen is not None and previous_gen not in ["", " "] and self.is_openai is False: 138 | input += previous_gen 139 | 140 | return self.truncate_prompt(input) 141 | 142 | def get_string_with_varying_examplars( 143 | self, 144 | question, 145 | retrieval_result=None, 146 | formatted_reference=None, 147 | previous_gen=None, 148 | examplars=[], 149 | tokenizer=None, 150 | max_length=2048, 151 | **params, 152 | ): 153 | """ 154 | Select the maximum number of examplars that can be placed in the prompt 155 | """ 156 | 157 | final_examplars = None 158 | num = len(examplars) 159 | while len(examplars) > 0: 160 | for num in range(len(examplars), 0, -1): 161 | possible_prompt = self.get_string( 162 | question=question, 163 | retrieval_result=retrieval_result, 164 | formatted_reference=formatted_reference, 165 | previous_gen=previous_gen, 166 | examplars="\n\n".join(examplars[:num]), 167 | **params, 168 | ) 169 | 170 | possible_prompt_tokens = tokenizer.encode(possible_prompt) 171 | if len(possible_prompt_tokens) <= max_length: 172 | final_examplars = examplars[:num] 173 | break 174 | if final_examplars is None: 175 | examplars = examplars[1:] 176 | else: 177 | break 178 | if final_examplars is None: 179 | final_examplars = [] 180 | 181 | final_prompt = self.get_string( 182 | question=question, 183 | retrieval_result=retrieval_result, 184 | formatted_reference=formatted_reference, 185 | previous_gen=previous_gen, 186 | examplars="\n\n".join(final_examplars[:num]), 187 | **params, 188 | ) 189 | 190 | return final_prompt 191 | 192 | def format_reference(self, retrieval_result): 193 | format_reference = "" 194 | for idx, doc_item in enumerate(retrieval_result): 195 | content = doc_item["contents"] 196 | title = content.split("\n")[0] 197 | text = "\n".join(content.split("\n")[1:]) 198 | if self.reference_template is not None: 199 | format_reference += self.reference_template.format(idx=idx, title=title, text=text) 200 | else: 201 | format_reference += f"Doc 
{idx+1}(Title: {title}) {text}\n" 202 | 203 | return format_reference 204 | -------------------------------------------------------------------------------- /flashrag/prompt/selfask_examplars.py: -------------------------------------------------------------------------------- 1 | SELF_ASK_PROMPT_SINGLE_HOP = """Given the following question, answer it by providing follow up questions and intermediate answers. If intermediate questions are not necessarry, answer the question directly. You are provided with evidence that can help you arrive at the answer before the question. 2 | # 3 | Context1: The Big Red One: Fuller was a World War II veteran and served with the 1st Infantry Division, which is nicknamed "The Big Red One" for the red numeral "1" on the division's shoulder patch. He received the Silver Star, Bronze Star, and Purple Heart during his service. 4 | Question: how did the big red one get its name 5 | Are follow up questions needed here: No. 6 | So the final answer is: its shoulder patch 7 | # 8 | Context1: Module:Location map/data/Cayman Islands: Module:Location map/data/Cayman Islands is a location map definition used to overlay markers and labels on an equirectangular projection map of Cayman 9 | Question: where are the cayman islands on the map 10 | Are follow up questions needed here: No. 11 | So the final answer is: western Caribbean Sea 12 | # 13 | Context1: Korean War | Combatants, Summary, Years, Map ... - Britannica: After more than a million combat casualties had been suffered on both sides, the fighting ended in July 1953 with Korea still divided into two hostile states. Negotiations in 1954 produced no further agreement, and the front line has been accepted ever since as the de facto boundary between North and South Korea. 14 | Question: who won the war between north korea and south korea 15 | Are follow up questions needed here: No. 16 | So the final answer is: technically still at war 17 | # 18 | Context1: It's Always Sunny in Philadelphia (season 13): The thirteenth season of the American comedy television series It's Always Sunny in Philadelphia premiered on FXX on September 5, 2018. 19 | Question: when does it's always sunny in philadelphia season 13 start 20 | Are follow up questions needed here: No. 21 | So the final answer is: September 5, 2018 22 | # 23 | Context1: You've Got a Friend in Me: "You've Got a Friend in Me" is a song by Randy Newman. Used as the theme song for the 1995 Disney/Pixar animated film Toy Story, it has since become a major ... 24 | Question: who sang you got a friend in me from toy story 25 | Are follow up questions needed here: No. 26 | So the final answer is: Randy Newman 27 | # 28 | Context1: Timeline of space exploration: This is a timeline of space exploration which includes notable achievements, first accomplishments and milestones in humanity's exploration of outer space. 29 | Question: when was the first person sent to space 30 | Are follow up questions needed here: No. 31 | So the final answer is: 12 April 1961 32 | #""" 33 | 34 | 35 | SELF_ASK_PROMPT_MULTI_HOP = """Given the following question, answer it by providing follow up questions and intermediate answers. If intermediate questions are not necessarry, answer the question directly. You are provided with evidence that can help you arrive at the answer before the question. 
36 | # 37 | Context1: Xawery Żuławski: Polish-Russian War (Wojna polsko-ruska) is a 2009 Polish film directed by Xawery Żuławski based on the novel Polish-Russian War under the white-red flag by Dorota Masłowska. So the answer is Xawery Żuławski. 38 | Context2: Xawery Żuławski: Xawery Żuławski ; National Film School in Łódź · 1995–present · Maria Strzelecka · 2. 39 | Question: Who is the mother of the director of film Polish-Russian War (Film)? 40 | Are follow up questions needed here: Yes. 41 | Follow up: Who is the director of the film Polish-Russian War (Film)? 42 | Intermediate answer: The director of the film Polish-Russian War is Xawery Żuławski. 43 | Follow up: Who is the mother of Xawery Żuławski? 44 | Intermediate answer: The mother of Xawery Żuławski is Małgorzata Braunek. 45 | So the final answer is: Rick Scott Małgorzata Braunek. 46 | # 47 | Context1: 2003: Blind Shaft (Chinese: 盲井; pinyin: Mángjǐng) is a 2003 film about a pair of brutal con artists operating in the illegal coal mines of present-day northern China. So the answer is 2003. 48 | Context2: December 2, 1932: Release and reception. The Mask of Fu Manchu opened in New York on December 2, 1932. The film cost a total of $338,000 and had worldwide rentals of $625,000. It had a profit of $62,000. So the answer is December 2, 1932. 49 | Question: Which film came out first, Blind Shaft or The Mask Of Fu Manchu? 50 | Are follow up questions needed here: Yes. 51 | Follow up: When did Blind Shaft come out? 52 | Intermediate answer: Blind Shaft came out in 2003. 53 | Follow up: When did The Mask Of Fu Manchu come out? 54 | Intermediate answer: The Mask Of Fu Manchu came out in 1932. 55 | So the final answer is: The Mask Of Fu Manchu. 56 | # 57 | Context1: John V, Prince of Anhalt-Zerbst: John was the second (but eldest surviving) son of Ernest I, Prince of Anhalt-Dessau, by his wife Margarete, daughter of Henry I, Duke of Münsterberg-Oels, and granddaughter of George of Poděbrady, King of Bohemia. 58 | Context2: 12 June 1516: Ernest I, Prince of Anhalt-Dessau (died Dessau, 12 June 1516), was a German prince of the House of Ascania and ruler of the principality of Anhalt-Dessau. So the answer is 12 June 1516. 59 | Question: When did John V, Prince Of Anhalt-Zerbst's father die? 60 | Are follow up questions needed here: Yes. 61 | Follow up: Who is the father of John V, Prince Of Anhalt-Zerbst? 62 | Intermediate answer: The father of John V, Prince Of Anhalt-Zerbst is Ernest I, Prince of Anhalt-Dessau. 63 | Follow up: When did Ernest I, Prince of Anhalt-Dessau die? 64 | Intermediate answer: Ernest I, Prince of Anhalt-Dessau died on 12 June 1516. 65 | So the final answer is: 12 June 1516 66 | # 67 | Context1: El extraño viaje: El extraño viaje (English: The Strange Voyage) is a 1964 Spanish black drama film directed by Fernando Fernán Gómez. 68 | Context2: Love in Pawn: Love in Pawn is a 1953 British comedy film directed by Charles Saunders and starring Bernard Braden, Barbara Kelly and Jeannie Carson. 69 | Context3: 28 August 1921: Fernando Fernández Gómez (28 August 1921 – 21 November 2007) better known as Fernando Fernán Gómez was a Spanish actor, screenwriter, film director, theater director and member of the Royal Spanish Academy for seven years. So the answer is 28 August 1921. 70 | Context4: Charles Saunders (director): Charles Joel Saunders (8 April 1904 – 20 April 1997) was an English film director and screenwriter who began in the industry as a film editor, and who also contributed to television. 
71 | Question: Which film has the director who was born later, El Extraño Viaje or Love In Pawn? 72 | Are follow up questions needed here: Yes. 73 | Follow up: Who is the director of El Extraño Viaje? 74 | Intermediate answer: The director of El Extraño Viaje is Fernando Fernán Gómez. 75 | Follow up: Who is the director of Love in Pawn? 76 | Intermediate answer: The director of Love in Pawn is Charles Saunders. 77 | Follow up: When was Fernando Fernán Gómez born? 78 | Intermediate answer: Fernando Fernán Gómez was born on 28 August 1921. 79 | Follow up: When was Charles Saunders (director) born? 80 | Intermediate answer: Charles Saunders was born on 8 April 1904. 81 | So the final answer is: El Extraño Viaje. 82 | # 83 | Context1: John, Count Palatine of Neumarkt: John (Johann von Pfalz-Neumarkt; 1383 – 14 March 1443) was the Count Palatine of Neumarkt from 1410 to his death. The son of Rupert III of the Palatinate, he married Catherine of Pomerania in 1407. 84 | Context2: John, Count Palatine of Neumarkt: John (Johann von Pfalz-Neumarkt; 1383 – 14 March 1443) was the Count Palatine of Neumarkt from 1410 to his death. The son of Rupert III of the Palatinate, he married Catherine of Pomerania in 1407. 85 | Question: Who is Catherine Of Pomerania, Countess Palatine Of Neumarkt's father-in-law? 86 | Are follow up questions needed here: Yes. 87 | Follow up: Who is the husband of Catherine of Pomerania, Countess Palatine of Neumarkt? 88 | Intermediate answer: The husband of Catherine of Pomerania, Countess Palatine of Neumarkt is John, Count Palatine of Neumarkt. 89 | Follow up: Who is the father of John, Count Palatine of Neumarkt? 90 | Intermediate answer: The father of John, Count Palatine of Neumarkt is Rupert III of the Palatinate. 91 | So the final answer is: Rupert III of the Palatinate. 92 | # 93 | Context1: Crimen a las tres: Crimen a las tres is a 1935 Argentine crime film directed and written by Luis Saslavsky. Crimen a las tres. Directed by, Luis Saslavsky. 94 | Context2: Elio Petri: The Working Class Goes to Heaven (Italian: La classe operaia va in paradiso), released in the US as Lulu the Tool, is a 1971 political drama film directed by Elio Petri. So the answer is Elio Petri. 95 | Context3: March 20, 1995: Luis Saslavsky (April 21, 1903 – March 20, 1995) was an Argentine film director, screenwriter and film producer, and one of the influential directors in the Cinema of Argentina of the classic era. So the answer is March 20, 1995. 96 | Context4: Elio Petri: Final years. In 1981, Petri visited Geneva to direct Arthur Miller\'s new play The American Clock, with Marcello Mastroianni playing the lead role. Petri died of cancer on 10 November 1982. He was 53 years old. 97 | Question: Which film has the director died first, Crimen A Las Tres or The Working Class Goes To Heaven? 98 | Are follow up questions needed here: Yes. 99 | Follow up: Who is the director of Crimen a las tres? 100 | Intermediate answer: The director of Crimen a las tres is Luis Saslavsky. 101 | Follow up: Who is the director of The Working Class Goes to Heaven? 102 | Intermediate answer: The director of The Working Class Goes to Heaven is Elio Petri. 103 | Follow up: When did Luis Saslavsky die? 104 | Intermediate answer: Luis Saslavsky died on March 20, 1995. 105 | Follow up: When did Elio Petri die? 106 | Intermediate answer: Elio Petri died on 10 November 1982. 
107 | So the final answer is: The Working Class Goes to Heaven 108 | #""" 109 | -------------------------------------------------------------------------------- /ENVIRONMENT_SETUP_CN.md: -------------------------------------------------------------------------------- 1 | ### 🛠️ 环境与数据准备 2 | 3 | 在开始使用 HopWeaver 之前,您需要完成以下准备工作: 4 | 5 | #### 1. 克隆代码库并安装依赖 6 | 7 | ```bash 8 | git clone https://github.com/Zh1yuShen/HopWeaver.git 9 | cd HopWeaver 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | #### 2. 配置 LLM API 14 | 15 | 要使用本系统,请在您的 `config_lib/example_config.yaml`(或其副本)中配置LLM API。关键在于 `flashrag/generator/openai_generator.py` 脚本会基于您在 `generator_model` 中指定的模型名称中的**关键字**,从您的YAML文件中选择一个API配置块(例如 `openai_setting`, `google_setting`, `anthropic_setting`, `deepseek_setting`)。 16 | 17 | - 如果 `generator_model` 包含 `"gemini"` (例如, `"gemini-1.5-pro"`),脚本将尝试使用您YAML中的 `google_setting` 配置块。 18 | - 如果 `generator_model` 包含 `"claude"` (例如, `"claude-3-sonnet-20240229"`),它将使用 `anthropic_setting` 配置块。 19 | - 如果 `generator_model` 包含 `"deepseek"` (例如, `"deepseek-chat"`),它将使用 `deepseek_setting` 配置块。 20 | - 如果没有找到这些关键字(或其他内部定义的关键字),则默认使用 `openai_setting` 配置块 (例如, 对于像 `"gpt-4o"` 这样的模型)。 21 | 22 | 您**必须**确保配置文件中包含正确命名的设置块 (例如, `google_setting`),并且该块包含所选模型提供商所需的API密钥和任何其他必要参数。 23 | 24 | ```yaml 25 | # 示例:OpenAI 设置 (如果 generator_model 为例如 "gpt-4o", 26 | # 或在 generator_model 名称中未找到如 "gemini", "claude", "deepseek" 等关键字时,将使用此配置块) 27 | openai_setting: 28 | api_keys: 29 | - "your-openai-api-key-1" 30 | - "your-openai-api-key-2" 31 | - "your-openai-api-key-3" 32 | base_url: "https://api.openai.com/v1" 33 | 34 | # 示例:Google 设置 (如果 generator_model 包含 "gemini",将使用此配置块) 35 | # 如果您将 generator_model 设置为 Gemini 模型,请确保此块存在且已正确填写。 36 | # google_setting: 37 | # api_key: "YOUR_GOOGLE_API_KEY" 38 | # base_url: "YOUR_GOOGLE_BASE_URL" # 例如 https://generativelanguage.googleapis.com/v1 39 | 40 | # 示例:Anthropic 设置 (如果 generator_model 包含 "claude",将使用此配置块) 41 | # anthropic_setting: 42 | # api_key: "YOUR_ANTHROPIC_API_KEY" 43 | # base_url: "YOUR_ANTHROPIC_BASE_URL" # 例如 https://api.anthropic.com 44 | 45 | # 示例:DeepSeek 设置 (如果 generator_model 包含 "deepseek",将使用此配置块) 46 | # deepseek_setting: 47 | # api_key: "YOUR_DEEPSEEK_API_KEY" 48 | # base_url: "YOUR_DEEPSEEK_BASE_URL" # 例如 https://api.deepseek.com/v1 49 | 50 | # 不同组件的模型选择。 51 | # 您为 generator_model (以及其他也使用 openai_generator.py 的 *_model 字段) 52 | # 提供的名称决定了上方哪个 _setting 配置块必须存在且已正确配置。 53 | generator_model: "gpt-4o" 54 | entity_extractor_model: "gpt-4o" # 假设此模型名称也映射到 openai_setting 或有其自己的逻辑 55 | question_generator_model: "gpt-4o" 56 | polisher_model: "gpt-4o" 57 | filter_model: "gpt-4o" 58 | ``` 59 | 60 | > **代码参考**:具体的关键字到设置块的映射逻辑在 `HopWeaver/flashrag/generator/openai_generator.py` 中实现。请查看此文件以了解模型名称是如何被解析以选择像 `config["google_setting"]`、`config["anthropic_setting"]` 等配置节的。如果未匹配到其他特定关键字,则默认为 `config["openai_setting"]`。 61 | 62 | 同时,不同模型(如 GPT-4、Claude、Qwen、DeepSeek 等)可能需要不同的生成参数(例如 temperature、top_p、max_tokens)。请根据您模型的特性,在所选的 `*_setting` 配置块或配置文件的 `generation_params` 部分进行适当设置。 63 | 64 | ##### 🤖 模型选择建议 65 | 66 | HopWeaver 由几个可以使用不同模型的组件组成。以下是基于我们实验的建议: 67 | 68 | - **polisher_model**:我们建议为语言润色组件使用 DeepSeek-R1 或更高级的模型,因为它需要强大的语言优化能力 69 | - **其他组件**:您可以为其他组件(entity_extractor, question_generator, filter 等)使用相同的模型。我们建议为所有合成组件选择同一个模型。在我们的论文中,我们成功测试了各种模型,包括: 70 | - QwQ-32B 71 | - Qwen3-14B 72 | - GLM-9B-0414 73 | 74 | 为了获得最佳性能,我们建议使用至少 7B 参数的模型。较小的模型可能难以处理多跳问题合成所需的复杂推理。 75 | 76 | ##### 💻 本地模型配置 77 | 78 | 您可以使用[FlashRAG](https://github.com/RUC-NLPIR/FlashRAG)提供的本地模型支持,它支持多种本地模型部署方式 79 | 80 | ##### ⚡ 
API调用优化 81 | 82 | HopWeaver实现了如下优化机制,提高了API调用的稳定性和效率: 83 | 84 | 1. **🔄 多个API Key轮询**: 当配置多个API Key时,系统会自动轮询使用,分散请求率限制 85 | 86 | ```yaml 87 | openai_setting: 88 | api_keys: 89 | - "key1" 90 | - "key2" 91 | - "key3" # 多个API Key列表 92 | ``` 93 | 94 | 2. **🔄 错误自动重试**: 当遇到常见API错误(如速率限制、服务器错误)时,系统会自动重试 95 | 96 | 3. **⚡ 异步请求处理**: 支持批量异步请求,最大化利用API调用频率 97 | 98 | 这些机制使得HopWeaver在面对大量多跳问题合成时,能更高效地利用LLM API资源。 99 | 100 | #### 3. 多API提供商支持 101 | 102 | HopWeaver支持多种API提供商,提供更强的灵活性和冗余能力。您可以在配置文件中配置不同的提供商: 103 | 104 | ```yaml 105 | # 多API提供商配置 106 | api_type: "openai" # 主要API类型 107 | 108 | # OpenAI 配置 109 | openai_setting: 110 | api_keys: 111 | - "your-openai-api-key-1" 112 | - "your-openai-api-key-2" 113 | base_url: "https://api.openai.com/v1" 114 | 115 | # Google Gemini 配置 116 | gemini_setting: 117 | api_keys: 118 | - "your-gemini-api-key-1" 119 | - "your-gemini-api-key-2" 120 | base_url: "https://generativelanguage.googleapis.com/v1" 121 | 122 | # DeepSeek 配置 123 | deepseek_setting: 124 | api_key: "your-deepseek-api-key" 125 | base_url: "https://api.deepseek.com/v1" 126 | 127 | # Claude (Anthropic) 配置 128 | claude_setting: 129 | api_key: "your-claude-api-key" 130 | base_url: "https://api.anthropic.com" 131 | 132 | # OpenRouter 配置(支持多种模型) 133 | openrouter_setting: 134 | api_keys: 135 | - "your-openrouter-key-1" 136 | - "your-openrouter-key-2" 137 | base_url: "https://openrouter.ai/api/v1" 138 | 139 | # GLM (SiliconFlow) 配置 140 | GLM_setting: 141 | api_keys: "your-glm-api-key" 142 | base_url: "https://api.siliconflow.cn/v1" 143 | ``` 144 | 145 | **各提供商支持的模型:** 146 | - **OpenAI**: GPT-4o, GPT-4-turbo, GPT-3.5-turbo 等 147 | - **Google**: Gemini-2.0-flash, Gemini-2.5-flash-preview 等 148 | - **DeepSeek**: DeepSeek-R1, DeepSeek-V3 等 149 | - **Claude**: Claude-3.5-Sonnet 等 150 | - **OpenRouter**: 可访问 QwQ-32B, Gemma-3-27B 等模型 151 | - **GLM**: GLM-4-9B 和其他 SiliconFlow 支持的模型 152 | 153 | #### 4. 全局路径映射配置 154 | 155 | HopWeaver使用全局路径映射来高效管理模型路径、索引和语料库: 156 | 157 | ```yaml 158 | # 全局路径映射 159 | model2path: 160 | e5: "/path/to/e5-base-v2" 161 | gte: "/path/to/gte_sentence-embedding_multilingual-base" 162 | 163 | # 各嵌入模型的池化方法 164 | model2pooling: 165 | e5: "mean" 166 | gte: "cls" 167 | 168 | # 检索模型的索引路径 169 | method2index: 170 | e5: '/path/to/e5_Flat.index' 171 | gte: '/path/to/gte_Flat.index' 172 | bm25: ~ 173 | contriever: ~ 174 | 175 | # 不同方法的语料库路径 176 | method2corpus: 177 | e5: '/path/to/wiki18_fulldoc_trimmed_4096.jsonl' 178 | gte: '/path/to/wiki18_fulldoc_trimmed_4096.jsonl' 179 | ``` 180 | 181 | **配置优势:** 182 | - **集中管理**:所有模型和数据路径集中在一个位置 183 | - **便捷切换**:通过修改 `retrieval_method` 参数即可切换检索方法 184 | - **自动解析**:系统根据方法选择自动解析对应路径 185 | - **可扩展性**:易于添加新的模型和语料库 186 | 187 | #### 5. 
检索器高级参数 188 | 189 | 为了精细控制检索过程,可配置以下高级参数: 190 | 191 | **📊 检索方法选择:** 192 | 193 | HopWeaver支持三种检索方法: 194 | - **standard**:标准检索,仅基于查询相关性排序 195 | - **diverse**:多样性检索,使用MMR算法平衡相关性和多样性 196 | - **rerank**:两阶段检索,先进行多样性检索,再使用训练好的重排模型精细排序 197 | 198 | **🔄 重排器模型配置:** 199 | 200 | HopWeaver支持使用重排器模型进一步优化检索结果的排序。您可以选择以下重排器模型: 201 | 202 | **开源重排器模型:** 203 | - **[BAAI/bge-reranker-v2-m3](https://huggingface.co/BAAI/bge-reranker-v2-m3)**: 这是一个基于bge-m3的轻量级多语言重排器模型,具有强大的多语言能力,易于部署,推理速度快。该模型支持中文和英文,在多种基准测试中表现出色。 204 | 205 | **自定义微调重排器:** 206 | - **HopWeaver微调重排器**: 我们基于特定的多跳问答数据进行了微调,专门针对多跳问题的文档检索进行了优化。该模型在多跳合成检索互补文档时表现更佳。(模型链接即将在[HuggingFace]()、[ModelScope]()上发布) 207 | 208 | **重排器配置示例:** 209 | 210 | ```yaml 211 | # 检索器配置 212 | retriever_type: "rerank" # 检索方法选择,选项:"standard"、"diverse" 或 "rerank" 213 | reranker_path: "/path/to/trained/reranker/model" # 重排模型路径(仅在rerank方法时需要) 214 | 215 | # 检索器多样性权重参数(适用于diverse和rerank方法的粗检索阶段) 216 | lambda1: 0.87 # 查询相关性权重 (0-1) 217 | lambda2: 0.03 # 原始文档多样性权重 (0-1) 218 | lambda3: 0.1 # 已选文档多样性权重 (0-1) 219 | 220 | # 重排器性能参数 221 | use_reranker: true # 启用重排器 222 | rerank_topk: 5 # 重排后保留的文档数量 223 | rerank_max_length: 4096 # 重排器输入的最大长度 224 | rerank_batch_size: 256 # 重排批处理大小 225 | rerank_use_fp16: true # 使用FP16加速重排器推理 226 | 227 | # 检索缓存(性能优化) 228 | save_retrieval_cache: false # 保存检索结果到缓存 229 | use_retrieval_cache: false # 使用缓存的检索结果 230 | retrieval_cache_path: ~ # 检索缓存文件路径 231 | ``` 232 | 233 | **参数调优指南:** 234 | - **lambda1 (0.8-0.9)**:更高的值优先考虑查询-文档相关性 235 | - **lambda2 (0.05-0.15)**:控制与源文档的多样性 236 | - **lambda3 (0.05-0.15)**:控制已选文档间的多样性 237 | - **lambda1+lambda2+lambda3 的和应等于 1.0** 238 | 239 | **性能提示:** 240 | - 使用 `use_fp16: true` 可以获得更快的推理速度,质量损失极小 241 | - 根据GPU内存调整 `reranker_batch_size` 242 | - 对于重复实验相同查询,启用缓存可提高效率 243 | 244 | #### 6. 下载Wiki数据集 245 | 246 | 您需要下载`wiki18_fulldoc_trimmed_4096.jsonl`数据文件,这是我们预处理好的Wiki数据集,包含截取了文档长度小于4096的Wiki文章。 247 | 248 | 数据集下载链接: [huggingface](https://huggingface.co/datasets/Shenzy2/HopWeaver_Data) or [modelscope](https://www.modelscope.cn/datasets/szyszy/HopWeaver_Data) 249 | 250 | 对于我们论文中比较的HotpotQA、2wiki、musique的步骤,可以将下载的数据集放入 datasets 文件夹中,并且用 datasets/process_and_sample_datasets.py 处理这些采样出任意样本,用于后续比较。 251 | 252 | **数据格式说明**: 253 | `wiki18_fulldoc_trimmed_4096.jsonl`是JSONL格式文件,每行包含一个JSON对象,结构如下: 254 | ```json 255 | { 256 | "id": "591775", 257 | "title": "Los Ramones", 258 | "doc_size": 1250, 259 | "contents": "Los Ramones\nLos Ramones Los Ramones is the name of a municipality..." 260 | } 261 | ``` 262 | 263 | **字段说明**: 264 | - `id`: 文档的唯一标识符 265 | - `title`: 文档标题 266 | - `doc_size`: 文档内容的字符长度 267 | - `contents`: 文档的完整正文内容 268 | 269 | #### 7. 下载GTE嵌入模型 270 | 271 | HopWeaver使用[GTE](https://huggingface.co/iic/gte_sentence-embedding_multilingual-base)多语言模型进行检索。您可以直接从Hugging Face下载该模型,并在配置文件中指定路径: 272 | 273 | 修改配置文件`config_lib/example_config.yaml`中的模型路径: 274 | ```yaml 275 | model2path: 276 | gte: "您下载的GTE模型路径" 277 | ``` 278 | 279 | #### 8. 
下载或构建索引 280 | 281 | 您可以选择下载我们预构建好的索引文件( 282 | [huggingface](https://huggingface.co/datasets/Shenzy2/HopWeaver_Data) or [modelscope](https://www.modelscope.cn/datasets/szyszy/HopWeaver_Data)),或自行构建: 283 | 284 | ```bash 285 | # 创建索引保存目录 286 | mkdir -p index 287 | 288 | # 下载预构建索引(推荐) 289 | # 索引下载链接: [INDEX_DOWNLOAD_LINK_PLACEHOLDER] 290 | 291 | # 或者使用FlashRAG构建索引 292 | python -m flashrag.build_index \ 293 | --model_name_or_path 您下载的GTE模型路径 \ 294 | --corpus_path dataset/wiki18_fulldoc_trimmed_4096.jsonl \ 295 | --index_path index/gte_Flat.index \ 296 | --batch_size 32 \ 297 | --model_type gte \ 298 | --pooling_method cls \ 299 | --use_fp16 300 | ``` 301 | 302 | 参数说明: 303 | - `--model_name_or_path`: GTE模型路径 304 | - `--corpus_path`: Wiki语料库文件路径 305 | - `--index_path`: 生成的索引保存路径 306 | - `--batch_size`: 批处理大小,可根据您的GPU内存调整 307 | - `--model_type`: 模型类型,这里是gte 308 | - `--pooling_method`: 池化方法,GTE使用cls 309 | - `--use_fp16`: 使用FP16以加速索引构建 310 | 311 | 完成上述准备工作后,您就可以开始使用HopWeaver合成多跳问题了。 -------------------------------------------------------------------------------- /flashrag/config/config.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 4 | import yaml 5 | import random 6 | import datetime 7 | 8 | 9 | class Config: 10 | def __init__(self, config_file_path=None, config_dict={}): 11 | 12 | self.yaml_loader = self._build_yaml_loader() 13 | self.file_config = self._load_file_config(config_file_path) 14 | self.variable_config = config_dict 15 | 16 | self.external_config = self._merge_external_config() 17 | 18 | self.internal_config = self._get_internal_config() 19 | 20 | self.final_config = self._get_final_config() 21 | 22 | self._check_final_config() 23 | self._set_additional_key() 24 | 25 | self._init_device() 26 | self._set_seed() 27 | self._prepare_dir() 28 | 29 | def _build_yaml_loader(self): 30 | loader = yaml.FullLoader 31 | loader.add_implicit_resolver( 32 | "tag:yaml.org,2002:float", 33 | re.compile( 34 | """^(?: 35 | [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)? 36 | |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+) 37 | |\\.[0-9_]+(?:[eE][-+][0-9]+)? 
38 | |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]* 39 | |[-+]?\\.(?:inf|Inf|INF) 40 | |\\.(?:nan|NaN|NAN))$""", 41 | re.X, 42 | ), 43 | list("-+0123456789."), 44 | ) 45 | return loader 46 | 47 | def _load_file_config(self, config_file_path: str): 48 | file_config = dict() 49 | if config_file_path: 50 | with open(config_file_path, "r", encoding="utf-8") as f: 51 | file_config.update(yaml.load(f.read(), Loader=self.yaml_loader)) 52 | return file_config 53 | 54 | @staticmethod 55 | def _update_dict(old_dict: dict, new_dict: dict): 56 | # Update the original update method of the dictionary: 57 | # If there is the same key in `old_dict` and `new_dict`, and value is of type dict, update the key in dict 58 | 59 | same_keys = [] 60 | for key, value in new_dict.items(): 61 | if key in old_dict and isinstance(value, dict): 62 | same_keys.append(key) 63 | for key in same_keys: 64 | old_item = old_dict[key] 65 | new_item = new_dict[key] 66 | old_item.update(new_item) 67 | new_dict[key] = old_item 68 | 69 | old_dict.update(new_dict) 70 | return old_dict 71 | 72 | def _merge_external_config(self): 73 | external_config = dict() 74 | external_config = self._update_dict(external_config, self.file_config) 75 | external_config = self._update_dict(external_config, self.variable_config) 76 | 77 | return external_config 78 | 79 | def _get_internal_config(self): 80 | current_path = os.path.dirname(os.path.realpath(__file__)) 81 | init_config_path = os.path.join(current_path, "basic_config.yaml") 82 | internal_config = self._load_file_config(init_config_path) 83 | 84 | return internal_config 85 | 86 | def _get_final_config(self): 87 | final_config = dict() 88 | final_config = self._update_dict(final_config, self.internal_config) 89 | final_config = self._update_dict(final_config, self.external_config) 90 | 91 | return final_config 92 | 93 | def _check_final_config(self): 94 | # check split 95 | split = self.final_config["split"] 96 | if split is None: 97 | split = ["train", "dev", "test"] 98 | if isinstance(split, str): 99 | split = [split] 100 | self.final_config["split"] = split 101 | 102 | def _init_device(self): 103 | gpu_id = self.final_config["gpu_id"] 104 | if gpu_id is not None: 105 | os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id) 106 | try: 107 | # import pynvml 108 | # pynvml.nvmlInit() 109 | # gpu_num = pynvml.nvmlDeviceGetCount() 110 | import torch 111 | gpu_num = torch.cuda.device_count() 112 | except: 113 | gpu_num = 0 114 | self.final_config['gpu_num'] = gpu_num 115 | if gpu_num > 0: 116 | self.final_config["device"] = "cuda" 117 | else: 118 | self.final_config['device'] = 'cpu' 119 | 120 | def _set_additional_key(self): 121 | # set dataset 122 | dataset_name = self.final_config["dataset_name"] 123 | data_dir = self.final_config["data_dir"] 124 | self.final_config["dataset_path"] = os.path.join(data_dir, dataset_name) 125 | 126 | # set model path 127 | retrieval_method = self.final_config["retrieval_method"] 128 | model2path = self.final_config["model2path"] 129 | model2pooling = self.final_config["model2pooling"] 130 | method2index = self.final_config["method2index"] 131 | 132 | generator_model = self.final_config["generator_model"] 133 | 134 | if self.final_config["index_path"] is None: 135 | try: 136 | self.final_config["index_path"] = method2index[retrieval_method] 137 | except: 138 | print("Index is empty!!") 139 | assert False 140 | 141 | if self.final_config.get("retrieval_model_path") is None: 142 | self.final_config["retrieval_model_path"] = model2path.get(retrieval_method, 
retrieval_method) 143 | # TODO: not support when `retrieval_model` is path 144 | 145 | def set_pooling_method(method, model2pooling): 146 | for key, value in model2pooling.items(): 147 | if key.lower() in method.lower(): 148 | return value 149 | return "mean" 150 | 151 | if self.final_config.get("retrieval_pooling_method") is None: 152 | self.final_config["retrieval_pooling_method"] = set_pooling_method(retrieval_method, model2pooling) 153 | 154 | rerank_model_name = self.final_config["rerank_model_name"] 155 | if self.final_config.get("rerank_model_path") is None: 156 | if rerank_model_name is not None: 157 | self.final_config["rerank_model_path"] = model2path.get(rerank_model_name, rerank_model_name) 158 | if self.final_config["rerank_pooling_method"] is None: 159 | if rerank_model_name is not None: 160 | self.final_config["rerank_pooling_method"] = set_pooling_method(rerank_model_name, model2pooling) 161 | 162 | if self.final_config.get("generator_model_path") is None: 163 | self.final_config["generator_model_path"] = model2path.get(generator_model, generator_model) 164 | 165 | if "refiner_name" in self.final_config: 166 | refiner_model = self.final_config["refiner_name"] 167 | if "refiner_model_path" not in self.final_config or self.final_config["refiner_model_path"] is None: 168 | self.final_config["refiner_model_path"] = model2path.get(refiner_model, None) 169 | if 'instruction' not in self.final_config: 170 | self.final_config['instruction'] = None 171 | 172 | # set model path in metric setting 173 | metric_setting = self.final_config["metric_setting"] 174 | metric_tokenizer_name = metric_setting.get("tokenizer_name", None) 175 | from flashrag.utils.constants import OPENAI_MODEL_DICT 176 | 177 | if metric_tokenizer_name not in OPENAI_MODEL_DICT: 178 | metric_tokenizer_name = model2path.get(metric_tokenizer_name, metric_tokenizer_name) 179 | metric_setting["tokenizer_name"] = metric_tokenizer_name 180 | self.final_config["metric_setting"] = metric_setting 181 | 182 | def _prepare_dir(self): 183 | save_note = self.final_config["save_note"] 184 | 185 | # 检查save_note是否为experiment,如果是则跳过目录创建 186 | if save_note == "experiment": 187 | # 仅在需要时创建基本目录 188 | os.makedirs(self.final_config["save_dir"], exist_ok=True) 189 | 190 | # 不添加时间戳和save_note到目录名 191 | # 不保存config.yaml文件 192 | return 193 | 194 | # 原始逻辑已被注释,防止生成实验输出目录 195 | # current_time = datetime.datetime.now() 196 | # self.final_config["save_dir"] = os.path.join( 197 | # self.final_config["save_dir"], 198 | # f"{self.final_config['dataset_name']}_{current_time.strftime('%Y_%m_%d_%H_%M')}_{save_note}", 199 | # ) 200 | os.makedirs(self.final_config["save_dir"], exist_ok=True) 201 | # # save config parameters 202 | # config_save_path = os.path.join(self.final_config["save_dir"], "config.yaml") 203 | # with open(config_save_path, "w") as f: 204 | # yaml.dump(self.final_config, f) 205 | 206 | def _set_seed(self): 207 | import torch 208 | import numpy as np 209 | 210 | seed = self.final_config["seed"] 211 | random.seed(seed) 212 | np.random.seed(seed) 213 | torch.manual_seed(seed) 214 | torch.cuda.manual_seed(seed) 215 | torch.cuda.manual_seed_all(seed) 216 | torch.backends.cudnn.benchmark = False 217 | torch.backends.cudnn.deterministic = True 218 | 219 | def __setitem__(self, key, value): 220 | if not isinstance(key, str): 221 | raise TypeError("index must be a str.") 222 | self.final_config[key] = value 223 | 224 | def __getattr__(self, item): 225 | if "final_config" not in self.__dict__: 226 | raise AttributeError("'Config' object has no 
attribute 'final_config'") 227 | if item in self.final_config: 228 | return self.final_config[item] 229 | raise AttributeError(f"'Config' object has no attribute '{item}'") 230 | 231 | def __getitem__(self, item): 232 | return self.final_config.get(item) 233 | 234 | def __contains__(self, key): 235 | if not isinstance(key, str): 236 | raise TypeError("index must be a str.") 237 | return key in self.final_config 238 | 239 | def __repr__(self): 240 | return self.final_config.__str__() 241 | 242 | def update(self, new_config: dict): 243 | """ 244 | 更新配置信息 245 | Args: 246 | new_config (dict): 新的配置字典 247 | """ 248 | self.final_config = self._update_dict(self.final_config, new_config) -------------------------------------------------------------------------------- /flashrag/pipeline/pipeline.py: -------------------------------------------------------------------------------- 1 | from flashrag.evaluator import Evaluator 2 | from flashrag.dataset.utils import split_dataset, merge_dataset 3 | from flashrag.utils import get_retriever, get_generator, get_refiner, get_judger 4 | from flashrag.prompt import PromptTemplate 5 | 6 | 7 | class BasicPipeline: 8 | """Base object of all pipelines. A pipeline includes the overall process of RAG. 9 | If you want to implement a pipeline, you should inherit this class. 10 | """ 11 | 12 | def __init__(self, config, prompt_template=None): 13 | self.config = config 14 | self.device = config["device"] 15 | self.retriever = None 16 | self.evaluator = Evaluator(config) 17 | self.save_retrieval_cache = config["save_retrieval_cache"] 18 | if prompt_template is None: 19 | prompt_template = PromptTemplate(config) 20 | self.prompt_template = prompt_template 21 | 22 | def run(self, dataset): 23 | """The overall inference process of a RAG framework.""" 24 | pass 25 | 26 | def evaluate(self, dataset, do_eval=True, pred_process_fun=None): 27 | """The evaluation process after finishing overall generation""" 28 | 29 | if pred_process_fun is not None: 30 | raw_pred = dataset.pred 31 | processed_pred = [pred_process_fun(pred) for pred in raw_pred] 32 | dataset.update_output("raw_pred", raw_pred) 33 | dataset.update_output("pred", processed_pred) 34 | 35 | if do_eval: 36 | # evaluate & save result 37 | eval_result = self.evaluator.evaluate(dataset) 38 | print(eval_result) 39 | 40 | # save retrieval cache 41 | if self.save_retrieval_cache: 42 | self.retriever._save_cache() 43 | 44 | return dataset 45 | 46 | 47 | class SequentialPipeline(BasicPipeline): 48 | def __init__(self, config, prompt_template=None, retriever=None, generator=None): 49 | """ 50 | inference stage: 51 | query -> pre-retrieval -> retriever -> post-retrieval -> generator 52 | """ 53 | 54 | super().__init__(config, prompt_template) 55 | if generator is None: 56 | self.generator = get_generator(config) 57 | else: 58 | self.generator = generator 59 | 60 | if retriever is None: 61 | self.retriever = get_retriever(config) 62 | else: 63 | self.retriever = retriever 64 | 65 | # TODO: add rewriter module 66 | 67 | self.use_fid = config["use_fid"] 68 | 69 | if config["refiner_name"] is not None: 70 | self.refiner = get_refiner(config, self.retriever, self.generator) 71 | else: 72 | self.refiner = None 73 | 74 | def naive_run(self, dataset, do_eval=True, pred_process_fun=None): 75 | # direct generation without RAG 76 | input_prompts = [self.prompt_template.get_string(question=q) for q in dataset.question] 77 | dataset.update_output("prompt", input_prompts) 78 | 79 | pred_answer_list = self.generator.generate(input_prompts) 80 
| dataset.update_output("pred", pred_answer_list) 81 | 82 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 83 | return dataset 84 | 85 | def run(self, dataset, do_eval=True, pred_process_fun=None): 86 | input_query = dataset.question 87 | retrieval_results = self.retriever.batch_search(input_query) 88 | dataset.update_output("retrieval_result", retrieval_results) 89 | 90 | if self.refiner: 91 | input_prompt_flag = self.refiner.input_prompt_flag 92 | if "llmlingua" in self.refiner.name and input_prompt_flag: 93 | # input prompt 94 | input_prompts = [ 95 | self.prompt_template.get_string(question=q, retrieval_result=r) 96 | for q, r in zip(dataset.question, dataset.retrieval_result) 97 | ] 98 | dataset.update_output("prompt", input_prompts) 99 | input_prompts = self.refiner.batch_run(dataset) 100 | else: 101 | # input retrieval docs 102 | refine_results = self.refiner.batch_run(dataset) 103 | dataset.update_output("refine_result", refine_results) 104 | input_prompts = [ 105 | self.prompt_template.get_string(question=q, formatted_reference=r) 106 | for q, r in zip(dataset.question, refine_results) 107 | ] 108 | 109 | else: 110 | if not self.use_fid: 111 | input_prompts = [ 112 | self.prompt_template.get_string(question=q, retrieval_result=r) 113 | for q, r in zip(dataset.question, dataset.retrieval_result) 114 | ] 115 | 116 | if self.use_fid: 117 | print("Use FiD generation") 118 | input_prompts = [] 119 | for item in dataset: 120 | q = item.question 121 | docs = item.retrieval_result 122 | input_prompts.append([q + " " + doc['contents'] for doc in docs]) 123 | dataset.update_output("prompt", input_prompts) 124 | 125 | # delete used refiner to release memory 126 | if self.refiner: 127 | del self.refiner 128 | pred_answer_list = self.generator.generate(input_prompts) 129 | dataset.update_output("pred", pred_answer_list) 130 | 131 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 132 | 133 | return dataset 134 | 135 | 136 | class ConditionalPipeline(BasicPipeline): 137 | def __init__(self, config, prompt_template=None): 138 | """ 139 | inference stage: 140 | query -> judger -> sequential pipeline or naive generate 141 | """ 142 | 143 | super().__init__(config, prompt_template) 144 | self.generator = get_generator(config) 145 | self.judger = get_judger(config) 146 | self.retriever = get_retriever(config) 147 | 148 | self.sequential_pipeline = SequentialPipeline( 149 | config, prompt_template, retriever=self.retriever, generator=self.generator 150 | ) 151 | 152 | self.zero_shot_templete = PromptTemplate( 153 | config=config, 154 | system_prompt="Answer the question based on your own knowledge. 
\ 155 | Only give me the answer and do not output any other words.", 156 | user_prompt="Question: {question}", 157 | ) 158 | 159 | def run(self, dataset, do_eval=True, pred_process_fun=None): 160 | # judge_result: list of bool element, representing whether to use retrieval 161 | judge_result = self.judger.judge(dataset) 162 | dataset.update_output("judge_result", judge_result) 163 | 164 | # split dataset based on judge_result 165 | dataset_split = split_dataset(dataset, judge_result) 166 | pos_dataset, neg_dataset = dataset_split[True], dataset_split[False] 167 | 168 | pos_dataset = self.sequential_pipeline.run(pos_dataset, do_eval=False) 169 | self.sequential_pipeline.prompt_template = self.zero_shot_templete 170 | neg_dataset = self.sequential_pipeline.naive_run(neg_dataset, do_eval=False) 171 | 172 | # merge datasets into original format 173 | dataset = merge_dataset(dataset_split, judge_result) 174 | 175 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 176 | 177 | return dataset 178 | 179 | 180 | class AdaptivePipeline(BasicPipeline): 181 | def __init__( 182 | self, 183 | config, 184 | norag_template=None, 185 | single_hop_prompt_template=None, 186 | multi_hop_prompt_template=None, 187 | ): 188 | super().__init__(config) 189 | # load adaptive classifier as judger 190 | generator = get_generator(config) 191 | retriever = get_retriever(config) 192 | self.judger = get_judger(config) 193 | self.generator = generator 194 | self.retriever = retriever 195 | 196 | # Load three pipeline for three types of query: naive/single-hop/multi-hop 197 | from flashrag.pipeline import IRCOTPipeline 198 | 199 | if norag_template is None: 200 | norag_templete = PromptTemplate( 201 | config=config, 202 | system_prompt="Answer the question based on your own knowledge. Only give me the answer and do not output any other words.", 203 | user_prompt="Question: {question}", 204 | ) 205 | self.norag_pipeline = SequentialPipeline( 206 | config, 207 | prompt_template=norag_templete, 208 | retriever=retriever, 209 | generator=generator, 210 | ) 211 | 212 | self.single_hop_pipeline = SequentialPipeline( 213 | config, 214 | prompt_template=single_hop_prompt_template, 215 | retriever=retriever, 216 | generator=generator, 217 | ) 218 | 219 | self.multi_hop_pipeline = IRCOTPipeline( 220 | config, prompt_template=multi_hop_prompt_template, retriever=retriever, generator=generator, max_iter=5 221 | ) 222 | 223 | def run(self, dataset, do_eval=True, pred_process_fun=None): 224 | # judge_result: choice result representing which pipeline to use(e.g. A, B, C) 225 | judge_result = self.judger.judge(dataset) 226 | dataset.update_output("judge_result", judge_result) 227 | 228 | # split dataset based on judge_result 229 | dataset_split = split_dataset(dataset, judge_result) 230 | for symbol, symbol_dataset in dataset_split.items(): 231 | if symbol == "A": 232 | symbol_dataset = self.norag_pipeline.naive_run(symbol_dataset, do_eval=False) 233 | elif symbol == "B": 234 | symbol_dataset = self.single_hop_pipeline.run(symbol_dataset, do_eval=False) 235 | elif symbol == "C": 236 | symbol_dataset = self.multi_hop_pipeline.run(symbol_dataset, do_eval=False) 237 | else: 238 | assert False, "Unknown symbol!" 
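# Note: the judger is expected to label each query "A" (answer directly without retrieval),
# "B" (single-hop retrieval) or "C" (multi-hop IRCoT, up to 5 iterations); any other label
# falls through to the assertion above.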
239 | 240 | # merge datasets into original format 241 | dataset = merge_dataset(dataset_split, judge_result) 242 | 243 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 244 | 245 | return dataset 246 | -------------------------------------------------------------------------------- /flashrag/refiner/refiner.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 3 | from flashrag.retriever.encoder import Encoder 4 | from tqdm import tqdm 5 | import re 6 | import torch 7 | import numpy as np 8 | 9 | class BaseRefiner: 10 | r"""Base object of Refiner method""" 11 | 12 | def __init__(self, config): 13 | self.config = config 14 | self.name = config["refiner_name"] 15 | self.model_path = config["refiner_model_path"] 16 | self.device = config["device"] 17 | self.input_prompt_flag = config["refiner_input_prompt_flag"] if "refiner_input_prompt_flag" in config else False 18 | 19 | def run(self, item) -> str: 20 | r"""Get refining result. 21 | 22 | Args: 23 | item: dataset item, contains question, retrieval result... 24 | 25 | Returns: 26 | str: refining result of this item 27 | """ 28 | pass 29 | 30 | def batch_run(self, dataset, batch_size=None) -> List[str]: 31 | return [self.run(item) for item in dataset] 32 | 33 | 34 | class LLMLinguaRefiner(BaseRefiner): 35 | """Implementation for (Long)LLMLingua.""" 36 | 37 | def __init__(self, config): 38 | super().__init__(config) 39 | default_config = { 40 | 'use_llmlingua2': False, 41 | "rate": 0.55, 42 | "condition_in_question": "after_condition", 43 | "reorder_context": "sort", 44 | "dynamic_context_compression_ratio": 0.3, 45 | "condition_compare": True, 46 | "context_budget": "+100", 47 | "rank_method": "longllmlingua", 48 | } 49 | if "llmlingua_config" in config and config["llmlingua_config"] is not None: 50 | self.compress_config = config["llmlingua_config"] 51 | else: 52 | self.compress_config = default_config 53 | 54 | from flashrag.refiner.llmlingua_compressor import PromptCompressor 55 | 56 | if 'use_llmlingua2' in self.compress_config: 57 | use_llmlingua2 = self.compress_config.pop('use_llmlingua2') 58 | else: 59 | use_llmlingua2 = False 60 | self.refiner = PromptCompressor(model_name=self.model_path, use_llmlingua2=use_llmlingua2) 61 | 62 | def format_reference(self, retrieval_result): 63 | format_reference = "" 64 | for idx, doc_item in enumerate(retrieval_result): 65 | content = doc_item["contents"] 66 | title = content.split("\n")[0] 67 | text = "\n".join(content.split("\n")[1:]) 68 | format_reference += f"Doc {idx+1}(Title: {title}) {text}\n" 69 | 70 | return format_reference 71 | 72 | def batch_run(self, dataset): 73 | output = [] 74 | for item in tqdm(dataset, desc="Refining process: "): 75 | question = item.question 76 | retrieval_result = item.retrieval_result 77 | # TODO: suit more cases 78 | if self.input_prompt_flag: 79 | input_prompt = item.prompt 80 | prompt_split = input_prompt.split("\n\n") 81 | # need fixed format prompt: instr + demon(retrieval results) + question 82 | instruction, question = prompt_split[0], prompt_split[-1] 83 | demonstration = "\n".join(prompt_split[1:-1]) 84 | item_output = self.refiner.compress_prompt( 85 | [i for i in demonstration.split("\n") if i != ""], 86 | instruction=instruction, 87 | question=question, 88 | **self.compress_config, 89 | ) 90 | else: 91 | docs = self.format_reference(retrieval_result).split("\n") 92 | docs = [i for i in docs 
if i != ""] 93 | item_output = self.refiner.compress_prompt( 94 | docs, instruction="", question=question, **self.compress_config 95 | ) 96 | output.append(item_output["compressed_prompt"]) 97 | return output 98 | 99 | 100 | class SelectiveContextRefiner(BaseRefiner): 101 | """Implementation for Selective Context""" 102 | 103 | def __init__(self, config): 104 | super().__init__(config) 105 | from flashrag.refiner.selective_context_compressor import SelectiveContext 106 | 107 | default_config = {"reduce_ratio": 0.5} 108 | 109 | self.refiner = SelectiveContext(model_type="gpt2", model_path=self.model_path, lang="en") 110 | if "sc_config" in config and config["sc_config"] is not None: 111 | self.compress_config = config["sc_config"] 112 | else: 113 | self.compress_config = default_config 114 | 115 | def format_reference(self, retrieval_result): 116 | format_reference = "" 117 | for idx, doc_item in enumerate(retrieval_result): 118 | content = doc_item["contents"] 119 | title = content.split("\n")[0] 120 | text = "\n".join(content.split("\n")[1:]) 121 | format_reference += f"Doc {idx+1}(Title: {title}) {text}\n" 122 | 123 | return format_reference 124 | 125 | def batch_run(self, dataset): 126 | # only use text 127 | all_inputs = [] 128 | for item in dataset: 129 | retrieval_result = item.retrieval_result 130 | all_inputs.append(self.format_reference(retrieval_result)) 131 | 132 | output = [] 133 | for text in tqdm(all_inputs, desc="Refining process: "): 134 | compress_text, _ = self.refiner(text, **self.compress_config) 135 | output.append(compress_text) 136 | return output 137 | 138 | 139 | class ExtractiveRefiner(BaseRefiner): 140 | """Implementation for Extractive compressor. 141 | Using retrieval method to select sentences or other granularity data. 
142 | """ 143 | 144 | def __init__(self, config): 145 | super().__init__(config) 146 | # number of keeping sentences 147 | self.topk = config["refiner_topk"] 148 | self.pooling_method = config["refiner_pooling_method"] 149 | self.encode_max_length = config["refiner_encode_max_length"] 150 | self.mini_batch_size = config['refiner_mini_batch_size'] if 'refiner_mini_batch_size' in config else 256 151 | # load model 152 | self.encoder = Encoder( 153 | model_name=self.name, 154 | model_path=self.model_path, 155 | pooling_method=self.pooling_method, 156 | max_length=self.encode_max_length, 157 | use_fp16=True 158 | ) 159 | 160 | def batch_run(self, dataset, batch_size=16): 161 | questions = dataset.question 162 | # only use text 163 | retrieval_results = dataset.retrieval_result 164 | retrieval_results = [ 165 | ["\n".join(doc_item["contents"].split("\n")[1:]) for doc_item in item_result] 166 | for item_result in retrieval_results 167 | ] 168 | 169 | # split into sentences: [[sent1, sent2,...], [...]] 170 | sent_lists = [ 171 | [i.strip() for i in re.split(r"(?<=[.!?])\s+", " ".join(res)) if len(i.strip()) > 5] 172 | for res in retrieval_results 173 | ] 174 | score_lists = [] # matching scores, size == sent_lists 175 | for idx in tqdm(range(0, len(questions), batch_size), desc="Refining process: "): 176 | batch_questions = questions[idx : idx + batch_size] 177 | batch_sents = sent_lists[idx : idx + batch_size] 178 | question_embs = self.encoder.encode(batch_questions, is_query=True) 179 | 180 | flatten_batch_sents = sum(batch_sents, []) 181 | sent_embs = [] 182 | for s_index in tqdm(range(0, len(flatten_batch_sents), self.mini_batch_size), desc='Sentence encoding..,'): 183 | mini_batch_sents = flatten_batch_sents[s_index:s_index+self.mini_batch_size] 184 | mini_sent_embs = self.encoder.encode(mini_batch_sents, is_query=False) 185 | sent_embs.append(mini_sent_embs) 186 | sent_embs = np.concatenate(sent_embs, axis=0) 187 | 188 | scores = question_embs @ sent_embs.T 189 | start_idx = 0 190 | for row_score, single_list in zip(scores, batch_sents): 191 | row_score = row_score.tolist() 192 | score_lists.append(row_score[start_idx : start_idx + len(single_list)]) 193 | start_idx += len(single_list) 194 | 195 | # select topk sents 196 | retain_lists = [] 197 | for sent_scores, sent_list in zip(score_lists, sent_lists): 198 | assert len(sent_scores) == len(sent_list) 199 | if len(sent_scores) < self.topk: 200 | retain_lists.append(sent_list) 201 | continue 202 | 203 | topk_idxs = torch.topk(torch.Tensor(sent_scores), min(self.topk, len(sent_scores))).indices.tolist() 204 | retain_lists.append([sent_list[idx] for idx in sorted(topk_idxs) if idx < len(sent_list)]) 205 | 206 | return [" ".join(sents) for sents in retain_lists] 207 | 208 | 209 | class AbstractiveRecompRefiner(BaseRefiner): 210 | """Implementation for Abstractive RECOMP compressor: 211 | RECOMP: Improving Retrieval-Augmented LMs with Compression and Selective Augmentation. 
212 | """ 213 | 214 | def __init__(self, config): 215 | super().__init__(config) 216 | 217 | self.max_input_length = config["refiner_max_input_length"] 218 | self.max_output_length = config["refiner_max_output_length"] 219 | 220 | # load model 221 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_path) 222 | self.model = AutoModelForSeq2SeqLM.from_pretrained(self.model_path) 223 | self.model.cuda() 224 | self.model.eval() 225 | 226 | def batch_run(self, dataset, batch_size=2): 227 | # only use text 228 | retrieval_results = dataset.retrieval_result 229 | retrieval_results = [ 230 | ["\n".join(doc_item["contents"].split("\n")[1:]) for doc_item in item_result] 231 | for item_result in retrieval_results 232 | ] 233 | 234 | # input processing in recomp training format 235 | format_inputs = [ 236 | "Question: {question}\n Document: {document}\n Summary: ".format( 237 | question=item.question, document="\n".join(docs) 238 | ) 239 | for item, docs in zip(dataset, retrieval_results) 240 | ] 241 | 242 | results = [] 243 | for idx in tqdm(range(0, len(format_inputs), batch_size), desc="Refining process: "): 244 | batch_inputs = format_inputs[idx : idx + batch_size] 245 | batch_inputs = self.tokenizer( 246 | batch_inputs, return_tensors="pt", padding=True, truncation=True, max_length=self.max_input_length 247 | ).to(self.device) 248 | 249 | batch_outputs = self.model.generate(**batch_inputs, max_length=self.max_output_length) 250 | 251 | batch_outputs = self.tokenizer.batch_decode( 252 | batch_outputs, skip_special_tokens=True, clean_up_tokenization_spaces=False 253 | ) 254 | 255 | results.extend(batch_outputs) 256 | 257 | return results 258 | -------------------------------------------------------------------------------- /flashrag/generator/openai_generator.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | from copy import deepcopy 4 | import warnings 5 | from tqdm import tqdm 6 | import numpy as np 7 | import random 8 | import time 9 | 10 | import asyncio 11 | from openai import AsyncOpenAI, AsyncAzureOpenAI 12 | 13 | class OpenaiGenerator: 14 | """Class for api-based openai models""" 15 | 16 | def __init__(self, config): 17 | self.model_name = config["generator_model"] 18 | self.batch_size = config["generator_batch_size"] 19 | self.generation_params = config["generation_params"] 20 | 21 | # 通用模型配置 22 | # 用户可在配置文件设置自己需要的模型 23 | 24 | # 检测模型类型并选择对应的配置 25 | 26 | # OpenRouter模型检测 27 | 28 | # 检测模型类型并选择对应的配置 29 | if "gemini" in self.model_name.lower(): 30 | self.openai_setting = config["google_setting"] 31 | elif "claude" in self.model_name.lower(): 32 | self.openai_setting = config["anthropic_setting"] 33 | elif "deepseek" in self.model_name.lower(): 34 | self.openai_setting = config["deepseek_setting"] 35 | # 其他模型类型判断... 
36 | else: 37 | self.openai_setting = config["openai_setting"] 38 | 39 | # 统一处理所有模型的API密钥配置 40 | # 处理多API密钥配置(字符串或列表形式) 41 | if "api_keys" in self.openai_setting and self.openai_setting["api_keys"]: 42 | self.api_keys = self.openai_setting["api_keys"].split(",") if isinstance(self.openai_setting["api_keys"], str) else self.openai_setting["api_keys"] 43 | # 移除空字符串或空白项 44 | self.api_keys = [key.strip() for key in self.api_keys if key.strip()] 45 | if self.api_keys: 46 | # 创建独立的随机数生成器 47 | self.api_key_random = random.Random(time.time()) 48 | # 为保持兼容性,设置第一个密钥为当前api_key 49 | self.openai_setting["api_key"] = self.api_keys[0] 50 | # 创建一个client字典以存储每个密钥对应的client 51 | self.clients = {} 52 | else: 53 | self.api_keys = None 54 | self.clients = None 55 | else: 56 | # 单个API密钥情况 57 | self.api_keys = None 58 | self.clients = None 59 | # print(self.model_name) 60 | # print(self.openai_setting) 61 | if self.openai_setting.get("api_key") is None: 62 | self.openai_setting["api_key"] = os.getenv("OPENAI_API_KEY") 63 | 64 | # 添加 base_url 支持 65 | if "api_base" in self.openai_setting: 66 | self.base_url = self.openai_setting["api_base"] 67 | else: 68 | self.base_url = None 69 | 70 | # 处理单个client或多个client的情况 71 | client_settings = deepcopy(self.openai_setting) 72 | 73 | # 设置较长的超时时间,避免大型请求超时 74 | client_settings["timeout"] = 60.0 # 设置为60秒 75 | 76 | # 从客户端设置中移除非客户端初始化参数 77 | if "api_keys" in client_settings: 78 | del client_settings["api_keys"] 79 | 80 | if "api_type" in client_settings and client_settings["api_type"] == "azure": 81 | del client_settings["api_type"] 82 | # 确保 base_url 被正确传递给 AsyncAzureOpenAI 83 | if self.base_url: 84 | client_settings["base_url"] = self.base_url 85 | 86 | # 创建客户端 - 对所有模型类型适用相同的逻辑 87 | if self.api_keys: 88 | # 为每个API密钥创建一个client 89 | for api_key in self.api_keys: 90 | key_settings = deepcopy(client_settings) 91 | key_settings["api_key"] = api_key 92 | self.clients[api_key] = AsyncAzureOpenAI(**key_settings) 93 | 94 | # 同时保留单个client以保持兼容性 95 | self.client = self.clients[self.openai_setting["api_key"]] 96 | else: 97 | # 常规单一client初始化 98 | self.client = AsyncAzureOpenAI(**client_settings) 99 | else: 100 | # 确保 base_url 被正确传递给 AsyncOpenAI 101 | if self.base_url: 102 | client_settings["base_url"] = self.base_url 103 | 104 | # 创建客户端 - 对所有模型类型适用相同的逻辑 105 | if self.api_keys: 106 | # 为每个API密钥创建一个client 107 | for api_key in self.api_keys: 108 | key_settings = deepcopy(client_settings) 109 | key_settings["api_key"] = api_key 110 | self.clients[api_key] = AsyncOpenAI(**key_settings) 111 | 112 | # 同时保留单个client以保持兼容性 113 | self.client = self.clients[self.openai_setting["api_key"]] 114 | else: 115 | # 常规单一client初始化 116 | self.client = AsyncOpenAI(**client_settings) 117 | 118 | # 移除 tiktoken 相关代码,因为实际上并未使用 tokenizer 119 | 120 | def _get_next_client(self): 121 | """获取API client,实现随机选择机制""" 122 | # 如果没有多个API密钥可用,则使用默认client 123 | if not self.api_keys: 124 | return self.client 125 | 126 | # 对所有模型类型适用相同的随机选择逻辑 127 | # 使用独立的随机数生成器随机选择一个API密钥 128 | random_index = self.api_key_random.randint(0, len(self.api_keys) - 1) 129 | current_key = self.api_keys[random_index] 130 | return self.clients[current_key] 131 | 132 | async def get_response(self, input: List, **params): 133 | # 所有模型均使用统一的轮询机制 134 | retries = 3 # 最大重试次数 135 | last_error = None 136 | 137 | for attempt in range(retries): 138 | current_key = None 139 | try: 140 | # 判断是否可以使用多密钥轮询 141 | if self.api_keys: 142 | # 每次尝试都获取新的客户端,确保API密钥轮换 143 | client = self._get_next_client() 144 | # 获取当前密钥信息便于调试 145 | for key, client_obj in self.clients.items(): 
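# Reverse-lookup which API key backs the randomly selected client; the key is only used to
# print a short prefix in the debug log below.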
146 | if client_obj == client: 147 | current_key = key 148 | break 149 | # 打印当前使用的密钥前十位字符 150 | key_prefix = current_key[:10] + "..." if current_key else "unknown" 151 | print(f"[尝试 {attempt+1}/{retries}] 使用API密钥: {key_prefix}") 152 | 153 | # 构建完整的请求参数 154 | request_params = deepcopy(params) 155 | 156 | # 所有模型类型统一使用相同的API调用方式 157 | response = await client.chat.completions.create(model=self.model_name, messages=input, **request_params) 158 | else: 159 | # 单个密钥情况,使用默认client 160 | response = await self.client.chat.completions.create(model=self.model_name, messages=input, **params) 161 | 162 | return response.choices[0] 163 | except Exception as e: 164 | last_error = e 165 | # 判断是否是API限制错误,并且是否有多个API密钥可用 166 | if ("429" in str(e) or "rate" in str(e).lower() or "limit" in str(e).lower() or "quota" in str(e).lower()) and self.api_keys: 167 | key_prefix = current_key[:10] + "..." if current_key else "unknown" 168 | print(f"API密钥限制或配额耗尽: {key_prefix}, 切换到随机密钥 (尝试 {attempt+1}/{retries})") 169 | print(f"错误详情: {str(e)}") 170 | # 短暂等待,避免连续请求 171 | await asyncio.sleep(1) 172 | continue 173 | # 其他错误直接抛出 174 | raise e 175 | 176 | # 所有重试都失败 177 | print(f"[错误] 所有API密钥尝试失败") 178 | raise last_error 179 | 180 | async def get_batch_response(self, input_list: List[List], batch_size, **params): 181 | total_input = [self.get_response(input, **params) for input in input_list] 182 | all_result = [] 183 | for idx in tqdm(range(0, len(input_list), batch_size), desc="Generation process: "): 184 | batch_input = total_input[idx : idx + batch_size] 185 | batch_result = await asyncio.gather(*batch_input) 186 | all_result.extend(batch_result) 187 | 188 | return all_result 189 | 190 | def _filter_thinking_chain(self, text): 191 | """过滤掉输出中的思维链部分 192 | 193 | Args: 194 | text (str): 原始响应文本 195 | 196 | Returns: 197 | str: 过滤后的文本 198 | """ 199 | if not text: 200 | return text 201 | 202 | # 过滤 <think>...</think> 格式的思维链 203 | import re 204 | filtered_text = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL) 205 | 206 | return filtered_text.strip() 207 | 208 | def generate(self, input_list: List[List], batch_size=None, return_scores=False, **params) -> List[str]: 209 | # deal with single input 210 | if len(input_list) == 1 and isinstance(input_list[0], dict): 211 | input_list = [input_list] 212 | if batch_size is None: 213 | batch_size = self.batch_size 214 | 215 | # deal with generation params 216 | generation_params = deepcopy(self.generation_params) 217 | generation_params.update(params) 218 | if "do_sample" in generation_params: 219 | generation_params.pop("do_sample") 220 | 221 | # 原来的max_tokens设置代码(已注释) 222 | # max_tokens = params.pop("max_tokens", None) or params.pop("max_new_tokens", None) 223 | # if max_tokens is not None: 224 | # generation_params["max_tokens"] = max_tokens 225 | # else: 226 | # generation_params["max_tokens"] = generation_params.get( 227 | # "max_tokens", generation_params.pop("max_new_tokens", None) 228 | # ) 229 | 230 | # 写死max_tokens为8192,不再从参数或配置中读取 231 | generation_params["max_tokens"] = 8192 232 | # 清理可能存在的其他token限制参数 233 | params.pop("max_tokens", None) 234 | params.pop("max_new_tokens", None) 235 | generation_params.pop("max_new_tokens", None) 236 | 237 | if return_scores: 238 | if generation_params.get("logprobs") is not None: 239 | generation_params["logprobs"] = True 240 | warnings.warn("Set logprobs to True to get generation scores.") 241 | else: 242 | generation_params["logprobs"] = True 243 | 244 | if generation_params.get("n") is not None: 245 | generation_params["n"] = 1 246 | warnings.warn("Set n to 1. 
247 |         else:
248 |             generation_params["n"] = 1
249 |
250 |         loop = asyncio.get_event_loop()
251 |         result = loop.run_until_complete(self.get_batch_response(input_list, batch_size, **generation_params))
252 |
253 |         # parse result into response text and logprob
254 |         scores = []
255 |         response_text = []
256 |         for res in result:
257 |             # filter out the thinking chain
258 |             filtered_content = self._filter_thinking_chain(res.message.content)
259 |             response_text.append(filtered_content)
260 |             if return_scores:
261 |                 score = np.exp(list(map(lambda x: x.logprob, res.logprobs.content)))
262 |                 scores.append(score)
263 |         if return_scores:
264 |             return response_text, scores
265 |         else:
266 |             return response_text
--------------------------------------------------------------------------------
/flashrag/pipeline/branching_pipeline.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | from typing import List
3 | import re
4 | from tqdm import tqdm
5 | import numpy as np
6 | from transformers import LogitsProcessorList
7 | from flashrag.utils import get_retriever, get_generator
8 | from flashrag.pipeline import BasicPipeline
9 | from flashrag.prompt import PromptTemplate
10 |
11 |
12 | class REPLUGPipeline(BasicPipeline):
13 |     def __init__(self, config, prompt_template=None):
14 |         from flashrag.pipeline.replug_utils import load_replug_model
15 |
16 |         super().__init__(config, prompt_template)
17 |         # load the dedicated model for REPLUG
18 |         model = load_replug_model(config["generator_model_path"])
19 |         self.generator = get_generator(config, model=model)
20 |
21 |         self.retriever = get_retriever(config)
22 |
23 |     def build_single_doc_prompt(self, question: str, doc_list: List[str]):
24 |         return [self.prompt_template.get_string(question=question, formatted_reference=doc) for doc in doc_list]
25 |
26 |     def format_reference(self, doc_item):
27 |         content = doc_item["contents"]
28 |         title = content.split("\n")[0]
29 |         text = "\n".join(content.split("\n")[1:])
30 |         return f"Document(Title: {title}): {text}"
31 |
32 |     def run(self, dataset, do_eval=True, pred_process_fun=None):
33 |         import torch
34 |         from flashrag.pipeline.replug_utils import REPLUGLogitsProcessor
35 |
36 |         input_query = dataset.question
37 |         retrieval_results, doc_scores = self.retriever.batch_search(input_query, return_score=True)
38 |         dataset.update_output("retrieval_result", retrieval_results)
39 |         dataset.update_output("doc_scores", doc_scores)
40 |
41 |         pred_answer_list = []
42 |         # each doc has a prompt
43 |         for item in tqdm(dataset, desc="Inference: "):
44 |             docs = [self.format_reference(doc_item) for doc_item in item.retrieval_result]
45 |             prompts = self.build_single_doc_prompt(question=item.question, doc_list=docs)
46 |
47 |             scores = torch.tensor(item.doc_scores, dtype=torch.float32).to(self.device)
48 |             output = self.generator.generate(
49 |                 prompts, batch_size=len(docs), logits_processor=LogitsProcessorList([REPLUGLogitsProcessor(scores)])
50 |             )
51 |             # every output in the batch is the same
52 |             output = output[0]
53 |             pred_answer_list.append(output)
54 |
55 |         dataset.update_output("pred", pred_answer_list)
56 |
57 |         dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun)
58 |
59 |         return dataset
60 |
61 |
62 | class SuRePipeline(BasicPipeline):
63 |     def __init__(self, config, prompt_template=None):
64 |         super().__init__(config, prompt_template)
65 |         self.config = config
66 |         self.generator = get_generator(config)
67 |         self.retriever = get_retriever(config)
68 |
69 |         self.load_prompts()
70 |
71 |     def load_prompts(self):
72 |         # prompt for candidate generation
73 |         P_CAN_INSTRUCT = (
74 |             "Below are {N} passages related to the question at the end. After reading "
75 |             "the passages, provide two correct candidates for the answer to the "
76 |             "question at the end. Each answer should be in the form: (a) xx, (b) "
77 |             "yy, and should not exceed 3 words for each candidate.\n\n"
78 |             "{reference}"
79 |             "Question: {question}\n"
80 |             "Answer:"
81 |         )
82 |
83 |         # prompt for candidate-conditioned summarization
84 |         P_SUM_INSTRUCT = (
85 |             "Reference:\n{reference}\n"
86 |             "Your job is to act as a professional writer. You need to write a "
87 |             "good-quality passage that can support the given prediction about the "
88 |             "question only based on the information in the provided supporting passages.\n"
89 |             "Now, let's start. After you write, please write [DONE] to indicate you "
90 |             "are done. Do not write a prefix (e.g., 'Response:') while writing a passage.\n"
91 |             "Question: {question}\n"
92 |             "Prediction: {pred}\n"
93 |             "Passage:"
94 |         )
95 |
96 |         # prompt for instance-wise validation
97 |         P_VAL_INSTRUCT = (
98 |             "Question: {question}\n"
99 |             "Prediction: {pred}\n"
100 |             "Passage: {summary}\n"
101 |             "Does the passage correctly support the prediction? Choices: [True,False].\n"
102 |             "Answer:"
103 |         )
104 |
105 |         # prompt for pair-wise ranking
106 |         P_RANK_INSTRUCT = (
107 |             "Question: Given the following passages, determine which one provides a "
108 |             "more informative answer to the subsequent question.\n"
109 |             "Passage 1: {summary1}\n"
110 |             "Passage 2: {summary2}\n"
111 |             "Target Question: {question}\n"
112 |             "Your Task:\n"
113 |             "Identify which passage (Passage 1 or Passage 2) is more relevant and "
114 |             "informative to answer the question at hand. Choices: [Passage 1,Passage 2].\n"
115 |             "Answer:"
116 |         )
117 |
118 |         self.P_CAN_TEMPLATE = PromptTemplate(self.config, "", P_CAN_INSTRUCT)
119 |         self.P_SUM_TEMPLATE = PromptTemplate(self.config, "", P_SUM_INSTRUCT)
120 |         self.P_VAL_TEMPLATE = PromptTemplate(self.config, "", P_VAL_INSTRUCT)
121 |         self.P_RANK_TEMPLATE = PromptTemplate(self.config, "", P_RANK_INSTRUCT)
122 |
123 |     @staticmethod
124 |     def format_ref(titles, texts):
125 |         formatted_ref = ""
126 |         idx = 1
127 |         for title, text in zip(titles, texts):
128 |             formatted_ref += f"Passage #{idx} Title: {title}\n"
129 |             formatted_ref += f"Passage #{idx} Text: {text}\n"
130 |             formatted_ref += "\n"
131 |             idx += 1
132 |         return formatted_ref
133 |
134 |     @staticmethod
135 |     def parse_candidates(model_response):
136 |         """Parse candidates from model response"""
137 |         model_response = model_response.strip("\n").strip()
138 |         # r'\([a-z]\) ([^,]+)'
139 |         candidates = re.findall(r"\((\w+)\)\s*([^()]+)", model_response)
140 |         candidates = [cand[1].split("\n")[0].strip() for cand in candidates]
141 |         # post-process
142 |         candidates = [cand.replace(",", "").strip() for cand in candidates]
143 |         return candidates
144 |
145 |     @staticmethod
146 |     def parse_validation(model_response):
147 |         """Parse the model's validation result into a score, following the formula in the paper"""
148 |         model_response = model_response.strip().lower()
149 |         if "true" in model_response:
150 |             return 1
151 |         else:
152 |             return 0
153 |
154 |     @staticmethod
155 |     def parse_ranking(model_response):
156 |         """Parse the model's pairwise ranking result into a score"""
157 |         model_response = model_response.strip().lower()
158 |         if "passage 1" in model_response:
159 |             score = 1
160 |         elif "passage 2" in model_response:
161 |             score = 0
162 |         else:
163 |             score = 0.5
164 |         return score
165 |
166 |     def run(self, dataset, do_eval=True, pred_process_fun=None):
167 |         input_query = dataset.question
168 |
169 |         retrieval_results, doc_scores = self.retriever.batch_search(input_query, return_score=True)
170 |         dataset.update_output("retrieval_result", retrieval_results)
171 |
172 |         pred_answer_list = []
173 |         for item in tqdm(dataset, desc="Pipeline running: "):
174 |             retrieval_result = item.retrieval_result
175 |             doc_num = len(retrieval_result)
176 |             # format all docs
177 |             for doc_item in retrieval_result:
178 |                 if "title" not in doc_item or "text" not in doc_item:
179 |                     doc_item["title"] = doc_item["contents"].split("\n")[0]
180 |                     doc_item["text"] = "\n".join(doc_item["contents"].split("\n")[1:])
181 |             formatted_ref = self.format_ref(
182 |                 titles=[i["title"] for i in retrieval_result], texts=[i["text"] for i in retrieval_result]
183 |             )
184 |             # get candidates
185 |
186 |             input_prompt = self.P_CAN_TEMPLATE.get_string(
187 |                 N=doc_num, formatted_reference=formatted_ref, question=item.question
188 |             )
189 |             output = self.generator.generate([input_prompt])[0]
190 |             candidates = self.parse_candidates(output)
191 |             item.update_output("candidates", candidates)
192 |
193 |             if len(candidates) == 0:
194 |                 print("No valid predictions!")
195 |                 pred = ""
196 |                 pred_answer_list.append(pred)
197 |                 continue
198 |
199 |             # get summarization for each candidate
200 |             input_prompts = [
201 |                 self.P_SUM_TEMPLATE.get_string(question=item.question, pred=cand, formatted_reference=formatted_ref)
202 |                 for cand in candidates
203 |             ]
204 |
205 |             all_summary = self.generator.generate(input_prompts)
206 |             item.update_output("all_summary", all_summary)
207 |
208 |             # instance-wise validation
209 |             input_prompts = [
210 |
self.P_VAL_TEMPLATE.get_string(question=item.question, pred=cand, summary=summary) 211 | for cand, summary in zip(candidates, all_summary) 212 | ] 213 | val_results = self.generator.generate(input_prompts) 214 | val_scores = [self.parse_validation(res) for res in val_results] 215 | item.update_output("val_scores", val_scores) 216 | 217 | # pair-wise ranking 218 | summary_num = len(all_summary) 219 | score_matrix = np.zeros((summary_num, summary_num)) 220 | iter_idxs = list(itertools.permutations(range(summary_num), 2)) 221 | input_prompts = [ 222 | self.P_RANK_TEMPLATE.get_string( 223 | question=item.question, summary1=all_summary[idx_tuple[0]], summary2=all_summary[idx_tuple[1]] 224 | ) 225 | for idx_tuple in iter_idxs 226 | ] 227 | ranking_output = self.generator.generate(input_prompts) 228 | ranking_scores = [self.parse_ranking(res) for res in ranking_output] 229 | for idx_tuple, score in zip(iter_idxs, ranking_scores): 230 | score_matrix[idx_tuple[0], idx_tuple[1]] = score 231 | ranking_scores = score_matrix.sum(axis=1).squeeze().tolist() # ranking score for each summary 232 | item.update_output("ranking_scores", ranking_scores) 233 | 234 | # combine two scores as the final score for each summary 235 | if not isinstance(ranking_scores, list): 236 | ranking_scores = [ranking_scores] 237 | if not isinstance(val_scores, list): 238 | val_scores = [val_scores] 239 | total_scores = [x + y for x, y in zip(val_scores, ranking_scores)] 240 | 241 | best_idx = np.argmax(total_scores) 242 | pred = candidates[best_idx] 243 | pred_answer_list.append(pred) 244 | 245 | dataset.update_output("pred", pred_answer_list) 246 | 247 | dataset = self.evaluate(dataset, do_eval=do_eval, pred_process_fun=pred_process_fun) 248 | 249 | return dataset 250 | -------------------------------------------------------------------------------- /flashrag/pipeline/replug_utils.py: -------------------------------------------------------------------------------- 1 | # Source: https://github.com/IntelLabs/fastRAG/blob/main/fastrag/generators/replug.py 2 | # The release is licensed under the Apache License 2.0 3 | 4 | import warnings 5 | from typing import List, Optional, Union 6 | import transformers 7 | from transformers import ( 8 | MODEL_FOR_CAUSAL_LM_MAPPING, 9 | AutoConfig, 10 | GenerationMixin, 11 | LogitsProcessor, 12 | LogitsProcessorList, 13 | StoppingCriteriaList, 14 | ) 15 | import torch 16 | import torch.nn as nn 17 | import torch.distributed as dist 18 | from transformers.generation.stopping_criteria import validate_stopping_criteria 19 | from transformers.generation.utils import ( 20 | SampleDecoderOnlyOutput, 21 | SampleEncoderDecoderOutput, 22 | SampleOutput, 23 | ) 24 | 25 | 26 | class REPLUG_Generation(GenerationMixin): 27 | """Implementing REPLUG-based sampling text generation.""" 28 | 29 | def sample( 30 | self, 31 | input_ids: torch.LongTensor, 32 | logits_processor: Optional[LogitsProcessorList] = None, 33 | stopping_criteria: Optional[StoppingCriteriaList] = None, 34 | logits_warper: Optional[LogitsProcessorList] = None, 35 | max_length: Optional[int] = None, 36 | pad_token_id: Optional[int] = None, 37 | eos_token_id: Optional[Union[int, List[int]]] = None, 38 | output_attentions: Optional[bool] = None, 39 | output_hidden_states: Optional[bool] = None, 40 | output_scores: Optional[bool] = None, 41 | return_dict_in_generate: Optional[bool] = None, 42 | synced_gpus: bool = False, 43 | streamer: Optional["BaseStreamer"] = None, 44 | **model_kwargs, 45 | ) -> Union[SampleOutput, torch.LongTensor]: 46 | # 
init values 47 | logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() 48 | stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() 49 | if max_length is not None: 50 | warnings.warn( 51 | "`max_length` is deprecated in this function, use" 52 | " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", 53 | UserWarning, 54 | ) 55 | stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) 56 | logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() 57 | pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id 58 | eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id 59 | if isinstance(eos_token_id, int): 60 | eos_token_id = [eos_token_id] 61 | eos_token_id_tensor = torch.tensor(eos_token_id).to(input_ids.device) if eos_token_id is not None else None 62 | output_scores = output_scores if output_scores is not None else self.generation_config.output_scores 63 | output_attentions = ( 64 | output_attentions if output_attentions is not None else self.generation_config.output_attentions 65 | ) 66 | output_hidden_states = ( 67 | output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states 68 | ) 69 | return_dict_in_generate = ( 70 | return_dict_in_generate 71 | if return_dict_in_generate is not None 72 | else self.generation_config.return_dict_in_generate 73 | ) 74 | 75 | # init attention / hidden states / scores tuples 76 | scores = () if (return_dict_in_generate and output_scores) else None 77 | decoder_attentions = () if (return_dict_in_generate and output_attentions) else None 78 | cross_attentions = () if (return_dict_in_generate and output_attentions) else None 79 | decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None 80 | 81 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 82 | if return_dict_in_generate and self.config.is_encoder_decoder: 83 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None 84 | encoder_hidden_states = ( 85 | model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None 86 | ) 87 | 88 | # keep track of which sequences are already finished 89 | unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) 90 | 91 | this_peer_finished = False # used by synced_gpus only 92 | # auto-regressive generation 93 | while True: 94 | if synced_gpus: 95 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 96 | # The following logic allows an early break if all peers finished generating their sequence 97 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 98 | # send 0.0 if we finished, 1.0 otherwise 99 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 100 | # did all peers finish? 
the reduced sum will be 0.0 then 101 | if this_peer_finished_flag.item() == 0.0: 102 | break 103 | 104 | # prepare model inputs 105 | model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) 106 | 107 | # forward pass to get next token 108 | outputs = self( 109 | **model_inputs, 110 | return_dict=True, 111 | output_attentions=output_attentions, 112 | output_hidden_states=output_hidden_states, 113 | ) 114 | 115 | if synced_gpus and this_peer_finished: 116 | continue # don't waste resources running the code we don't need 117 | 118 | next_token_logits = outputs.logits[:, -1, :] 119 | 120 | # pre-process distribution 121 | ### REPLUG - document weighting is done via REPLUGLogitsProcessor 122 | next_token_scores = logits_processor(input_ids, next_token_logits) 123 | next_token_scores = logits_warper(input_ids, next_token_scores) 124 | 125 | # Store scores, attentions and hidden_states when required 126 | if return_dict_in_generate: 127 | if output_scores: 128 | scores += (next_token_scores,) 129 | if output_attentions: 130 | decoder_attentions += ( 131 | (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) 132 | ) 133 | if self.config.is_encoder_decoder: 134 | cross_attentions += (outputs.cross_attentions,) 135 | 136 | if output_hidden_states: 137 | decoder_hidden_states += ( 138 | (outputs.decoder_hidden_states,) if self.config.is_encoder_decoder else (outputs.hidden_states,) 139 | ) 140 | 141 | ### REPLUG 142 | # Sample from the normalized "logits", assuming the REPLUG processor was used! 143 | probs = nn.functional.softmax(next_token_scores, dim=-1) 144 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1) 145 | # Lock same next-token for all examples in batch 146 | next_tokens[:] = next_tokens[0] 147 | 148 | # finished sentences should have their next token be a padding token 149 | if eos_token_id is not None: 150 | if pad_token_id is None: 151 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 152 | next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) 153 | 154 | # update generated ids, model inputs, and length for next step 155 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 156 | if streamer is not None: 157 | streamer.put(next_tokens.cpu()) 158 | model_kwargs = self._update_model_kwargs_for_generation( 159 | outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder 160 | ) 161 | 162 | # if eos_token was found in one sentence, set sentence to finished 163 | if eos_token_id_tensor is not None: 164 | unfinished_sequences = unfinished_sequences.mul( 165 | next_tokens.tile(eos_token_id_tensor.shape[0], 1).ne(eos_token_id_tensor.unsqueeze(1)).prod(dim=0) 166 | ) 167 | 168 | # stop when each sentence is finished 169 | if unfinished_sequences.max() == 0: 170 | this_peer_finished = True 171 | 172 | # stop if we exceed the maximum length 173 | if stopping_criteria(input_ids, scores): 174 | this_peer_finished = True 175 | 176 | if this_peer_finished and not synced_gpus: 177 | break 178 | 179 | if streamer is not None: 180 | streamer.end() 181 | 182 | if return_dict_in_generate: 183 | if self.config.is_encoder_decoder: 184 | return SampleEncoderDecoderOutput( 185 | sequences=input_ids, 186 | scores=scores, 187 | encoder_attentions=encoder_attentions, 188 | encoder_hidden_states=encoder_hidden_states, 189 | decoder_attentions=decoder_attentions, 190 | cross_attentions=cross_attentions, 191 | 
decoder_hidden_states=decoder_hidden_states, 192 | ) 193 | else: 194 | return SampleDecoderOnlyOutput( 195 | sequences=input_ids, 196 | scores=scores, 197 | attentions=decoder_attentions, 198 | hidden_states=decoder_hidden_states, 199 | ) 200 | else: 201 | return input_ids 202 | 203 | 204 | class REPLUGLogitsProcessor(LogitsProcessor): 205 | """ 206 | Merge logits of different docs in one batch. 207 | 208 | Reference: fastRAG 209 | """ 210 | 211 | def __init__(self, doc_scores: torch.FloatTensor): 212 | self.num_docs = doc_scores.shape[0] 213 | # normalize 214 | doc_scores /= doc_scores.sum() 215 | self.doc_scores = torch.unsqueeze(doc_scores, 1) # k*1 216 | 217 | def __call__(self, input_ids, scores): 218 | # doc_score: k*1, scores: k*vocab_size 219 | replug_scores = self.doc_scores * scores 220 | replug_scores = replug_scores.sum(dim=0) # 1*vocab_size 221 | replug_scores = torch.tile(replug_scores, (self.num_docs, 1)) # k*vocab_size 222 | return replug_scores 223 | 224 | 225 | def load_replug_model(name_or_path): 226 | class HF_REPLUG: 227 | "Creates a HF model that inherits from REPLUG_Generation class" 228 | 229 | def __new__(cls, name_or_path, **kwargs): 230 | return factory(name_or_path).from_pretrained(name_or_path, **kwargs) 231 | 232 | def factory(name_or_path): 233 | loadedConfig = AutoConfig.from_pretrained(name_or_path) 234 | try: 235 | pretrained_class_object = getattr(transformers, loadedConfig.architectures[0]) 236 | if pretrained_class_object not in MODEL_FOR_CAUSAL_LM_MAPPING.values(): 237 | raise ValueError(f"Model {pretrained_class_object} is not used for causal LM generation.") 238 | except AttributeError: 239 | raise ValueError("Transformers architecture unknown.") 240 | 241 | class HF(pretrained_class_object, REPLUG_Generation): 242 | """Wrapper around HuggingFace transformers with REPLUG generation.""" 243 | 244 | _keys_to_ignore_on_load_unexpected = [r"cls"] 245 | _tied_weights_keys = ["lm_head.weight"] 246 | 247 | return HF 248 | 249 | return HF_REPLUG(name_or_path) 250 | -------------------------------------------------------------------------------- /hopweaver/evaluation_system/check_and_complete_evaluations.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- encoding: utf-8 -*- 3 | ''' 4 | Check all CSV files in the evaluation results folder, find models with missing scores, and use CSV data as input to retry until all evaluations are complete 5 | ''' 6 | import os 7 | import sys 8 | import pandas as pd 9 | import json 10 | import glob 11 | from tqdm import tqdm 12 | # Add parent directory to path to import flashrag 13 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 14 | # Add current directory to path as well 15 | sys.path.append(os.path.dirname(os.path.abspath(__file__))) 16 | from flashrag.config import Config 17 | from evaluator import QualityEvaluator 18 | 19 | # Evaluation results directory 20 | EVAL_RESULT_DIR = "./eval_result" 21 | # Configuration file path 22 | CONFIG_PATH = "./config_lib/example_config.yaml" 23 | # Dataset samples directory 24 | DATASET_SAMPLES_DIR = "./datasets/samples" 25 | 26 | # We now read data directly from CSV files, no longer needing a mapping relationship 27 | 28 | def load_csv_file(file_path): 29 | """Load CSV file 30 | 31 | Args: 32 | file_path: CSV file path 33 | 34 | Returns: 35 | DataFrame: DataFrame containing evaluation results 36 | """ 37 | try: 38 | if os.path.exists(file_path): 39 | df = pd.read_csv(file_path) 40 | 
print(f"Successfully loaded {len(df)} rows of data from {file_path}") 41 | return df 42 | else: 43 | print(f"File does not exist: {file_path}") 44 | return pd.DataFrame() 45 | except Exception as e: 46 | print(f"Error loading CSV file: {str(e)}") 47 | return pd.DataFrame() 48 | 49 | def load_questions(file_path): 50 | """Load question dataset 51 | 52 | Args: 53 | file_path: Question dataset file path 54 | 55 | Returns: 56 | list: List containing question data 57 | """ 58 | try: 59 | with open(file_path, 'r', encoding='utf-8') as f: 60 | data = json.load(f) 61 | print(f"Successfully loaded {len(data)} questions from {file_path}") 62 | return data 63 | except Exception as e: 64 | print(f"Error loading question data: {str(e)}") 65 | return [] 66 | 67 | def filter_rejected_questions(questions): 68 | """Filter out rejected questions 69 | 70 | Args: 71 | questions: List of question data 72 | 73 | Returns: 74 | list: List of filtered question data 75 | """ 76 | filtered = [q for q in questions if not (q.get('status') == 'REJECT' or q.get('status') == 'REJECTED')] 77 | print(f"Filtered to {len(filtered)} questions (removed {len(questions) - len(filtered)} rejected questions)") 78 | return filtered 79 | 80 | def find_missing_evaluations(df, models): 81 | """Find missing evaluations 82 | 83 | Args: 84 | df: DataFrame of evaluation results 85 | models: List of model names 86 | 87 | Returns: 88 | list: List containing missing evaluation information, each item as (question_id, model) 89 | """ 90 | missing = [] 91 | 92 | # Get all non-manual evaluation rows 93 | model_rows = df[df['model'] != 'human'] 94 | 95 | # Check each dimension for null values 96 | dimensions = [ 97 | 'multi_hop_reasoning', 'fluency', 'clarity', 'conciseness', 'relevance', 98 | 'consistency', 'question_answerability', 'answer_question_consistency', 99 | 'information_integration_ability', 'reasoning_path_guidance', 'logical_sophistication', 'overall_quality' 100 | ] 101 | 102 | # Check each question_id and model combination 103 | for question_id in df['id'].unique(): 104 | for model in models: 105 | # Get rows for the current question_id and model 106 | row = model_rows[(model_rows['id'] == question_id) & (model_rows['model'] == model)] 107 | 108 | # If rows do not exist or any dimension is empty, consider it missing 109 | if row.empty: 110 | missing.append((question_id, model)) 111 | print(f"Found missing evaluation: Question ID={question_id}, Model={model} (row does not exist)") 112 | elif row[dimensions].isna().any().any(): 113 | missing.append((question_id, model)) 114 | print(f"Found missing evaluation: Question ID={question_id}, Model={model} (NaN value exists)") 115 | elif (row[dimensions].astype(str) == '').any().any(): 116 | missing.append((question_id, model)) 117 | print(f"Found missing evaluation: Question ID={question_id}, Model={model} (empty string exists)") 118 | 119 | return missing 120 | 121 | def get_question_data_from_csv(df, question_id): 122 | """Get question data from CSV data 123 | 124 | Args: 125 | df: DataFrame of evaluation results 126 | question_id: Question ID 127 | 128 | Returns: 129 | dict: Dictionary containing question data, returns None if not found 130 | """ 131 | # Get question data 132 | question_rows = df[df['id'] == question_id] 133 | if question_rows.empty: 134 | return None 135 | 136 | # Get the first row of data (question, answer, and documents should be the same for all rows) 137 | row = question_rows.iloc[0] 138 | 139 | # Initialize basic question data 140 | question_data = { 141 | 'id': 
question_id, 142 | 'question': row['question'], 143 | 'answer': row['answer'] 144 | } 145 | 146 | # Parse document content from the 'documents' column 147 | try: 148 | if 'documents' in row and pd.notna(row['documents']): 149 | documents = json.loads(row['documents']) 150 | # Convert document list to dictionary format 151 | for i, doc in enumerate(documents, 1): 152 | question_data[f'document{i}'] = doc 153 | except Exception as e: 154 | print(f"Error parsing document list: {str(e)}") 155 | 156 | return question_data 157 | 158 | def get_dataset_path(dataset_name): 159 | """Get dataset evaluation results file path 160 | 161 | Args: 162 | dataset_name: Dataset name 163 | 164 | Returns: 165 | str: Dataset evaluation results file path 166 | """ 167 | # Directly return the CSV file path 168 | csv_file = os.path.join(EVAL_RESULT_DIR, f"{dataset_name}_evaluation.csv") 169 | if os.path.exists(csv_file): 170 | return csv_file 171 | return None 172 | 173 | def complete_evaluations_from_csv(csv_file): 174 | """Complete missing evaluations using CSV data 175 | 176 | Args: 177 | csv_file: CSV file path 178 | """ 179 | # Extract dataset name 180 | dataset_name = os.path.basename(csv_file).split('_')[0] 181 | print(f"Processing dataset: {dataset_name}") 182 | 183 | # Load CSV file 184 | df = load_csv_file(csv_file) 185 | if df.empty: 186 | print(f"Unable to process empty CSV file: {csv_file}") 187 | return 188 | 189 | # Get all model names (except human) 190 | models = df['model'].unique().tolist() 191 | if 'human' in models: 192 | models.remove('human') 193 | 194 | # Find missing evaluations 195 | missing = find_missing_evaluations(df, models) 196 | if not missing: 197 | print("No missing evaluation items found. All scores are complete!") 198 | return 199 | 200 | print(f"Found {len(missing)} missing evaluation items. 
Starting to supplement...") 201 | 202 | # Retry each missing evaluation 203 | for question_id, model_name in tqdm(missing, desc="Completing missing evaluations"): 204 | print(f"\n{'='*50}") 205 | print(f"Evaluating Question ID: {question_id}, Model: {model_name}") 206 | print(f"{'='*50}") 207 | 208 | # Get question data from CSV 209 | question = get_question_data_from_csv(df, question_id) 210 | 211 | if not question: 212 | print(f"Warning: Cannot find data for Question ID {question_id} in CSV, skipping this item.") 213 | continue 214 | 215 | # Create configuration 216 | model_config = Config(CONFIG_PATH, {}) 217 | model_config["generator_batch_size"] = 1 218 | model_config["evaluator_model"] = model_name 219 | 220 | # Initialize evaluator 221 | evaluator = QualityEvaluator(model_config) 222 | 223 | # Evaluate question 224 | result = evaluator.evaluate_question(question, max_retry=7) 225 | 226 | if result and 'evaluation' in result: 227 | # Update CSV file rows 228 | dimensions = [ 229 | 'multi_hop_reasoning', 'fluency', 'clarity', 'conciseness', 'relevance', 230 | 'consistency', 'question_answerability', 'answer_question_consistency', 231 | 'information_integration_ability', 'reasoning_path_guidance', 232 | 'logical_sophistication', 'overall_quality' 233 | ] 234 | 235 | # Get row index 236 | row_idx = df[(df['id'] == question_id) & (df['model'] == model_name)].index 237 | 238 | if len(row_idx) > 0: 239 | # Update existing row 240 | evaluation = result['evaluation'] 241 | for dim in dimensions: 242 | if dim == 'multi_hop_reasoning': 243 | df.loc[row_idx[0], dim] = 'Yes' if evaluation.get(dim, False) else 'No' 244 | else: 245 | df.loc[row_idx[0], dim] = evaluation.get(dim, '') 246 | 247 | print(f"Updated evaluation for Question {question_id} with Model {model_name}") 248 | else: 249 | # Create new row 250 | # Collect all documents 251 | documents = [] 252 | for i in range(1, 11): 253 | doc_key = f'document{i}' 254 | if doc_key in question and question[doc_key] and question[doc_key].strip(): 255 | documents.append(question[doc_key]) 256 | 257 | new_row = { 258 | 'dataset': dataset_name, 259 | 'id': question_id, 260 | 'model': model_name, 261 | 'question': question.get('question', ''), 262 | 'answer': question.get('answer', ''), 263 | 'document_count': len(documents), 264 | 'documents': json.dumps(documents) 265 | } 266 | 267 | evaluation = result['evaluation'] 268 | for dim in dimensions: 269 | if dim == 'multi_hop_reasoning': 270 | new_row[dim] = 'Yes' if evaluation.get(dim, False) else 'No' 271 | else: 272 | new_row[dim] = evaluation.get(dim, '') 273 | 274 | # Add new row 275 | df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True) 276 | print(f"Added evaluation for Question {question_id} with Model {model_name}") 277 | else: 278 | print(f"Warning: Evaluation for Question {question_id} with Model {model_name} failed") 279 | 280 | # Save updated CSV file 281 | df.to_csv(csv_file, index=False, encoding='utf-8') 282 | print(f"\nUpdated evaluation results and saved to: {csv_file}") 283 | 284 | # Check again if there are still missing evaluations 285 | missing = find_missing_evaluations(df, models) 286 | if missing: 287 | print(f"Warning: There are still {len(missing)} missing evaluations, may need to run this script again") 288 | else: 289 | print("All evaluations are complete!") 290 | 291 | def main(): 292 | """Main function""" 293 | # Get all CSV files in the evaluation results directory 294 | csv_files = glob.glob(os.path.join(EVAL_RESULT_DIR, "*_evaluation.csv")) 295 | 296 | if 
not csv_files: 297 | print(f"Warning: No evaluation result files found in directory {EVAL_RESULT_DIR}") 298 | return 299 | 300 | print(f"Found {len(csv_files)} evaluation result files") 301 | 302 | # Process each CSV file 303 | for csv_file in csv_files: 304 | print(f"\n{'='*50}") 305 | print(f"Processing file: {csv_file}") 306 | print(f"{'='*50}") 307 | 308 | complete_evaluations_from_csv(csv_file) 309 | 310 | print("\nAll evaluation result files have been processed") 311 | 312 | if __name__ == '__main__': 313 | main() 314 | --------------------------------------------------------------------------------