├── src ├── retrievers │ ├── utils │ │ ├── __init__.py │ │ └── utils.py │ ├── embeddings │ │ ├── utils │ │ │ ├── __init__.py │ │ │ └── normalize_text.py │ │ ├── __init__.py │ │ ├── base.py │ │ ├── contriever.py │ │ ├── ada_embedding.py │ │ ├── dense_embedding.py │ │ ├── embedder.py │ │ └── e5.py │ ├── __init__.py │ ├── vector_index │ │ ├── __init__.py │ │ ├── base.py │ │ └── faiss_index.py │ ├── passage_embedder.py │ ├── multihop_data_extrator.py │ └── passage_retriever.py ├── conf │ ├── __init__.py │ └── config.py ├── efficient_rag │ ├── model │ │ ├── __init__.py │ │ └── model.py │ ├── data │ │ ├── __init__.py │ │ ├── filter_dataset.py │ │ ├── label_only_dataset.py │ │ └── labeler_dataset.py │ ├── token_weight_avg.py │ ├── filter_training.py │ └── labeler_training.py ├── utils │ ├── __init__.py │ ├── utils.py │ └── model.py ├── language_models │ ├── cloudgpt │ │ ├── __init__.py │ │ └── cloudgpt_aoai.py │ ├── base.py │ ├── __init__.py │ ├── llama.py │ ├── deepseek.py │ └── aoai.py ├── data_module │ ├── __init__.py │ ├── format.py │ └── dataset.py ├── data_synthesize │ ├── prompts │ │ ├── __init__.py │ │ ├── negative_token_labeling.py │ │ ├── hotpotQA.py │ │ └── span_labeling.py │ ├── training_data_synthesize.py │ ├── negative_sampling.py │ ├── negative_sampling_labeled.py │ ├── graph_driven.py │ ├── chunk_sampling.py │ ├── gpt_query_chunk_synethesize.py │ ├── token_extraction.py │ ├── negative_token_extraction.py │ ├── next_hop_query_filtering.py │ └── span_labeling.py ├── efficientrag_retrieve.sh ├── evaluation │ ├── retrieve.py │ └── correctness.py ├── baseline │ └── retrieve │ │ ├── ceiling.py │ │ ├── decompose.py │ │ └── direct.py └── efficientrag_qa.py ├── static ├── bert_labeler.png └── bert_question.png ├── requirements.txt ├── CONTRIBUTING.md ├── LICENSE ├── SUPPORT.md ├── SECURITY.md ├── .gitignore └── README.md /src/retrievers/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/conf/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import * 2 | -------------------------------------------------------------------------------- /src/retrievers/__init__.py: -------------------------------------------------------------------------------- 1 | from .passage_retriever import Retriever 2 | -------------------------------------------------------------------------------- /src/efficient_rag/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import DebertaForSequenceTokenClassification 2 | -------------------------------------------------------------------------------- /src/retrievers/vector_index/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import BaseIndex 2 | from .faiss_index import FaissIndex 3 | -------------------------------------------------------------------------------- /static/bert_labeler.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NIL-zhuang/EfficientRAG-official/HEAD/static/bert_labeler.png -------------------------------------------------------------------------------- 
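The package tree and the __init__.py exports above define the public surface that the repo's scripts import from. As a quick orientation, here is a minimal wiring sketch; it assumes src/ is on sys.path (the scripts arrange this with sys.path.append), and the dataset name and query string are illustrative placeholders rather than values taken from the repo:

import os

from conf import CORPUS_DATA_PATH    # re-exported by src/conf/__init__.py
from retrievers import Retriever     # re-exported by src/retrievers/__init__.py

dataset = "hotpotQA"  # illustrative; musique and 2WikiMQA follow the same corpus layout
retriever = Retriever(
    passage_path=os.path.join(CORPUS_DATA_PATH, dataset, "corpus.jsonl"),
    passage_embedding_path=os.path.join(CORPUS_DATA_PATH, dataset, "contriever"),
    index_path_dir=os.path.join(CORPUS_DATA_PATH, dataset, "contriever"),
    model_type="contriever",
)
# search() takes a list of queries and returns, per query, the top-k chunks as dicts
# with "id" and "text" fields, as used by negative_sampling.py and baseline/retrieve/ceiling.py.
chunks = retriever.search(["Who directed the film produced by the studio in hop one?"], top_k=10)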
/static/bert_question.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NIL-zhuang/EfficientRAG-official/HEAD/static/bert_question.png -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import ask_model, ask_model_in_parallel 2 | from .utils import load_jsonl, write_jsonl 3 | -------------------------------------------------------------------------------- /src/language_models/cloudgpt/__init__.py: -------------------------------------------------------------------------------- 1 | from .cloudgpt_aoai import ( 2 | auto_refresh_token, 3 | cloudGPT_available_models, 4 | get_openai_token, 5 | ) 6 | -------------------------------------------------------------------------------- /src/data_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataset import ( 2 | HotpotQADataset, 3 | MultiHopDataset, 4 | MuSiQueDataset, 5 | WikiMQADataset, 6 | get_dataset, 7 | ) 8 | -------------------------------------------------------------------------------- /src/efficient_rag/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .filter_dataset import FilterDataset 2 | from .label_only_dataset import LabelOnlyDataset 3 | from .labeler_dataset import LabelerDataset 4 | -------------------------------------------------------------------------------- /src/data_synthesize/prompts/__init__.py: -------------------------------------------------------------------------------- 1 | from .hotpotQA import * 2 | from .musique import * 3 | from .negative_token_labeling import * 4 | from .query_labeling import * 5 | from .token_labeling import * 6 | from .wikimqa import * 7 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/__init__.py: -------------------------------------------------------------------------------- 1 | from .contriever import Contriever 2 | from .e5 import E5BaseV2Embedding, E5LargeV2Embedding, E5MistralInstructEmbedding 3 | from .embedder import Embedder, EmbeddingModelTypes, ModelCheckpointMapping, ModelTypes 4 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/base.py: -------------------------------------------------------------------------------- 1 | class BaseEmbedding(object): 2 | def __init__(self): ... 
3 | 4 | def embed(self, query: str): 5 | raise NotImplementedError 6 | 7 | def embed_batch(self, queries: list[str]): 8 | raise NotImplementedError 9 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | faiss-cpu 2 | msal 3 | numpy 4 | openai>1.0.0 5 | Requests 6 | tqdm 7 | transformers>=4.40 8 | rich 9 | tenacity 10 | spacy 11 | datasets>=2.18.0 12 | wandb 13 | sentencepiece>=0.2.0 14 | accelerate==0.29.1 15 | scikit-learn 16 | deepspeed>=0.14.1 17 | vllm 18 | ray[serve] 19 | black 20 | -------------------------------------------------------------------------------- /src/efficientrag_retrieve.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | CUDA_VISIBLE_DEVICES=1 \ 4 | python src/efficientrag_retrieve.py \ 5 | --dataset hotpotQA \ 6 | --retriever contriever \ 7 | --labels 2 \ 8 | --suffix label2 \ 9 | --labeler_ckpt LABELER_CKPT_PATH \ 10 | --filter_ckpt FILTER_CKPT_PATH 11 | -------------------------------------------------------------------------------- /src/retrievers/vector_index/base.py: -------------------------------------------------------------------------------- 1 | class BaseIndex: 2 | def __init__(self): 3 | pass 4 | 5 | def search(self, query, top_k): 6 | raise NotImplementedError 7 | 8 | def serialize(self, dir_path): 9 | raise NotImplementedError 10 | 11 | def deserialize(self, dir_path): 12 | raise NotImplementedError 13 | 14 | def exist_index(self, dir_path): 15 | raise NotImplementedError 16 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | 4 | def load_jsonl(file_path: str): 5 | with open(file_path, "r", encoding="utf-8") as f: 6 | return [json.loads(line) for line in f] 7 | 8 | 9 | def write_jsonl(data: list, file_path: str): 10 | with open(file_path, "w+", encoding="utf-8") as f: 11 | for sample in data: 12 | info = json.dumps(sample, ensure_ascii=False) 13 | f.write(info + "\n") 14 | -------------------------------------------------------------------------------- /src/language_models/base.py: -------------------------------------------------------------------------------- 1 | from typing import Union 2 | 3 | 4 | class LanguageModel: 5 | def __init__(self, model, *args, **kwargs): 6 | self.model = model 7 | 8 | def chat( 9 | self, 10 | messages: Union[str, list[str]], 11 | system_msg: str, 12 | json_mode: bool, 13 | **kwargs 14 | ): 15 | raise NotImplementedError 16 | 17 | def complete(self, prompts: str): 18 | raise NotImplementedError 19 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/contriever.py: -------------------------------------------------------------------------------- 1 | from .dense_embedding import DenseEmbedding 2 | 3 | 4 | class Contriever(DenseEmbedding): 5 | def __init__(self, model_name_or_path: str): 6 | if model_name_or_path is None: 7 | model_name_or_path = "facebook/contriever-msmarco" 8 | super().__init__( 9 | model_name_or_path=model_name_or_path, 10 | embedding_vector_size=768, 11 | pooling_type="average", 12 | ) 13 | -------------------------------------------------------------------------------- /src/data_module/format.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 
| 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 5 | 6 | from conf import SEP_TOKEN 7 | 8 | INFO_TEMPLATE = "Info: {info}" 9 | QUERY_INFO_SENTENCE_TEMPLATE = f"Query: {{query}} {{info_str}}" 10 | 11 | 12 | def build_query_info_sentence(info_list: list[str], query: str) -> str: 13 | infos = [INFO_TEMPLATE.format(info=info) for info in info_list] 14 | info_str = "; ".join(infos) 15 | res = QUERY_INFO_SENTENCE_TEMPLATE.format(query=query, info_str=info_str) 16 | return res 17 | -------------------------------------------------------------------------------- /src/retrievers/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import csv 4 | 5 | 6 | def load_passages(fpath): 7 | if not os.path.exists(fpath): 8 | raise FileNotFoundError(f"{fpath} does not exist") 9 | 10 | passages = [] 11 | with open(fpath) as fin: 12 | if fpath.endswith(".jsonl"): 13 | for k, line in enumerate(fin): 14 | ex = json.loads(line) 15 | passages.append(ex) 16 | else: 17 | reader = csv.reader(fin, delimiter="\t") 18 | for k, row in enumerate(reader): 19 | if not row[0] == "id": 20 | ex = {"id": row[0], "title": row[2], "text": row[1]} 21 | passages.append(ex) 22 | return passages 23 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | ## Contributing 2 | 3 | This project welcomes contributions and suggestions. Most contributions require you to agree to a Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us the rights to use your contribution. For details, visit https://cla.microsoft.com. 4 | 5 | When you submit a pull request, a CLA-bot will automatically determine whether you need to provide a CLA and decorate the PR appropriately (e.g., label, comment). Simply follow the instructions provided by the bot. You will only need to do this once across all repositories using our CLA. 6 | 7 | This project has adopted the Microsoft Open Source Code of Conduct. For more information see the Code of Conduct FAQ or contact opencode@microsoft.com with any additional questions or comments. 
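For orientation, a brief usage sketch of the two helpers shown earlier, build_query_info_sentence from src/data_module/format.py and load_passages from src/retrievers/utils/utils.py. The sample strings and the corpus path are illustrative assumptions, and src/ is assumed to be on sys.path:

from data_module.format import build_query_info_sentence
from retrievers.utils.utils import load_passages

# Wraps each piece of gathered info in the "Info: ..." template and prepends the query,
# producing e.g. "Query: Who wrote the screenplay? Info: The film was released in 1994".
sentence = build_query_info_sentence(
    ["The film was released in 1994"],  # info collected from earlier hops (illustrative)
    "Who wrote the screenplay?",
)

# load_passages reads either a .jsonl corpus (one JSON object per line) or a tab-separated
# file with id, text, title columns; a header row whose first field is "id" is skipped.
passages = load_passages("data/corpus/hotpotQA/corpus.jsonl")  # assumed path layout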
-------------------------------------------------------------------------------- /src/retrievers/embeddings/ada_embedding.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | sys.path.append(os.path.join(os.path.dirname(__file__), "../..")) 5 | from language_models import AOAI 6 | 7 | from .dense_embedding import DenseEmbedding 8 | 9 | 10 | class AdaEmbedding(DenseEmbedding): 11 | def __init__( 12 | self, 13 | embedding_model: str = "text-embedding-ada-002", 14 | api_version: str = "2024-02-15-preview", 15 | ): 16 | super().__init__(embedding_model, embedding_vector_size=1536) 17 | self.model = AOAI(embedding_model=embedding_model, api_version=api_version) 18 | 19 | def instantiate(self): 20 | return 21 | 22 | def embed(self, query: str): 23 | return self.model.embed(query) 24 | 25 | def embed_batch(self, queries: list[str], max_workers: int = 50): 26 | return self.model.embed(queries) 27 | -------------------------------------------------------------------------------- /src/language_models/__init__.py: -------------------------------------------------------------------------------- 1 | from .aoai import AOAI 2 | from .base import LanguageModel 3 | from .deepseek import DeepSeek 4 | from .llama import LlamaServer 5 | 6 | MODEL_DICT = { 7 | "gpt35": "gpt-35-turbo-1106", 8 | "gpt4": "gpt-4-0125-preview", 9 | "llama": "Meta-Llama-3-70B-Instruct", 10 | "llama-8B": "Meta-Llama-3-8B-Instruct", 11 | "deepseek": "deepseek-chat", 12 | } 13 | 14 | 15 | def get_model(model_name: str, **kwargs) -> LanguageModel: 16 | if model_name in MODEL_DICT.keys(): 17 | model_name = MODEL_DICT[model_name] 18 | 19 | if "gpt" in model_name.lower(): 20 | return AOAI(model=model_name, **kwargs) 21 | elif "deepseek" in model_name.lower(): 22 | return DeepSeek(model=model_name, **kwargs) 23 | elif "llama" in model_name.lower(): 24 | return LlamaServer(model=model_name, **kwargs) 25 | else: 26 | raise NotImplementedError(f"Model {model_name} not implemented") 27 | -------------------------------------------------------------------------------- /src/language_models/cloudgpt/cloudgpt_aoai.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | from typing import Literal 3 | 4 | 5 | def get_openai_token(token_cache_file: str = "cloudgpt-apim-token-cache.bin") -> str: 6 | # REPLACE THIS WITH YOUR OWN CODE OR API TOKEN 7 | pass 8 | 9 | 10 | cloudGPT_available_models = Literal[ 11 | "gpt-4-0125-preview", 12 | "gpt-4-1106-preview", 13 | "gpt-4-vision-preview", 14 | "gpt-4", 15 | "gpt-4-0314", 16 | "gpt-4-0613", 17 | "gpt-4-32k", 18 | "gpt-4-32k-0314", 19 | "gpt-4-32k-0613", 20 | "gpt-35-turbo-1106", 21 | "gpt-35-turbo", 22 | "gpt-35-turbo-16k", 23 | "gpt-35-turbo-0301", 24 | "gpt-35-turbo-0613", 25 | "gpt-35-turbo-16k-0613", 26 | ] 27 | 28 | 29 | def auto_refresh_token( 30 | token_cache_file: str = "cloudgpt-apim-token-cache.bin", 31 | interval: datetime.timedelta = datetime.timedelta(minutes=15), 32 | on_token_update: callable = None, 33 | ) -> callable: 34 | # REPLACE THIS WITH YOUR OWN CODE OR API TOKEN 35 | pass 36 | -------------------------------------------------------------------------------- /src/evaluation/retrieve.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import numpy as np 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 8 | from utils import load_jsonl 9 | 10 | 11 | def 
recall(oracle, chunks): 12 | match = [ 13 | chunk for chunk in chunks if set(chunk.split("//")).intersection(set(oracle)) 14 | ] 15 | match = list(set(match)) 16 | recall = len(match) / len(oracle) 17 | return recall 18 | 19 | 20 | def main(fpath: str): 21 | data = load_jsonl(fpath) 22 | 23 | result_avg_len = [len(sample["chunk_ids"]) for sample in data] 24 | print(f"Average number of chunks: {np.mean(result_avg_len)}") 25 | 26 | recall_list = [recall(sample["oracle_ids"], sample["chunk_ids"]) for sample in data] 27 | print(f"Recall: {np.mean(recall_list):.4f}") 28 | 29 | 30 | def parse_args(): 31 | parser = argparse.ArgumentParser() 32 | parser.add_argument("--fpath", type=str, required=True) 33 | return parser.parse_args() 34 | 35 | 36 | if __name__ == "__main__": 37 | options = parse_args() 38 | main(options.fpath) 39 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) Microsoft Corporation. 2 | 3 | MIT License 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED *AS IS*, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 
22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. -------------------------------------------------------------------------------- /src/data_synthesize/prompts/negative_token_labeling.py: -------------------------------------------------------------------------------- 1 | NEGATIVE_TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT = """ 2 | You have been assigned an information extraction task. 3 | Your mission is to extract the words from a given paragraph so that others (GPT-3.5) can answer a question using only your extracted words. 4 | Your extracted words should cover information from the question, including entities (e.g., people, location, film) and core relations. 5 | If the paragraph cannot answer the question, you should extract the words most related to the question; if there are none, return an empty string. 6 | Your response should be in JSON format and include the following key: 7 | - "extracted_words": a string composed of a list of words extracted from the paragraph, separated by a space. 8 | 9 | Please adhere to the following guidelines: 10 | - Do not reorder, change, miss, or add words. Keep them the same as in the original paragraph. 11 | - Identify and extract ONLY the words explicitly mentioned in either the question or its answer, and strongly related to the question or its answer. 12 | - NEVER label any words that do not contribute meaningful information to the question or answer. 13 | - Only extract words that occurred in the paragraph. 14 | - Extract as few words as possible. 15 | 16 | Question: {question} 17 | Paragraph: {paragraph} 18 | Your response: 19 | """.strip() 20 | # 21 | -------------------------------------------------------------------------------- /src/language_models/llama.py: -------------------------------------------------------------------------------- 1 | from openai import OpenAI 2 | 3 | from .base import LanguageModel 4 | 5 | LLAMA_ENDPOINT = "USE YOUR OWN VLLM ENDPOINT" 6 | LLAMA_API_KEY = "USE YOUR OWN VLLM API KEY" 7 | 8 | 9 | class LlamaServer(LanguageModel): 10 | def __init__(self, model: str = "Meta-Llama-3-70B-Instruct", *args, **kwargs): 11 | super().__init__(model, *args, **kwargs) 12 | self.client = OpenAI( 13 | base_url=LLAMA_ENDPOINT, 14 | api_key=LLAMA_API_KEY, 15 | ) 16 | 17 | def chat(self, message: str, system_msg: str = None, json_mode: bool = False): 18 | if system_msg is None: 19 | system_msg = "You are a helpful assistant." 20 | messages = [ 21 | {"role": "system", "content": system_msg}, 22 | {"role": "user", "content": message}, 23 | ] 24 | response = self.client.chat.completions.create( 25 | model=self.model, 26 | messages=messages, 27 | ) 28 | response = response.choices[0].message.content 29 | return response 30 | 31 | def complete(self, prompts: str): 32 | response = self.client.completions.create( 33 | model=self.model, prompt=prompts, echo=False, max_tokens=100 34 | ) 35 | response = response.choices[0].text 36 | return response 37 | 38 | 39 | if __name__ == "__main__": 40 | llama = LlamaServer("Meta-Llama-3-8B-Instruct") 41 | response = llama.complete( 42 | "The reason humans landed on the moon is that someone found something strange behind the moon."
43 | ) 44 | print(response) 45 | -------------------------------------------------------------------------------- /src/retrievers/passage_embedder.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import pickle 4 | import sys 5 | 6 | from embeddings import Embedder, ModelTypes 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 9 | from retrievers.utils.utils import load_passages 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--passages", 16 | type=str, 17 | required=True, 18 | help="Path to passages (tsv or jsonl file)", 19 | ) 20 | parser.add_argument("--output_dir", type=str, required=True, help="Path to output file") 21 | parser.add_argument("--model_type", type=str, default="e5-base-v2", choices=list(ModelTypes.keys())) 22 | parser.add_argument("--model_name_or_path", type=str, default=None) 23 | parser.add_argument("--batch_size", type=int, default=1024) 24 | parser.add_argument("--chunk_size", type=int, default=int(2e6), help="passages per chunk") 25 | parser.add_argument("--test_mode", action="store_true", help="Run in test mode") 26 | args = parser.parse_args() 27 | return args 28 | 29 | 30 | def main(opts): 31 | embedder = Embedder( 32 | opts.model_type, 33 | opts.model_name_or_path, 34 | batch_size=opts.batch_size, 35 | chunk_size=opts.chunk_size, 36 | text_normalize=True, 37 | ) 38 | print(f"Loading passages from {opts.passages}") 39 | data = load_passages(opts.passages) 40 | output_dir = opts.output_dir 41 | if not os.path.isdir(output_dir): 42 | os.makedirs(output_dir) 43 | for idx, (ids, embeddings) in embedder.embed_passages(data): 44 | output_file = os.path.join(output_dir, f"passages_{idx:02d}") 45 | with open(output_file, "wb") as f: 46 | pickle.dump((ids, embeddings), f) 47 | print(f"Save {len(ids)} embeddings to {output_file}") 48 | 49 | 50 | if __name__ == "__main__": 51 | options = parse_args() 52 | main(options) 53 | -------------------------------------------------------------------------------- /src/conf/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import sep 3 | 4 | DATA_BASE_PATH = "data" 5 | DATASET_PATH = os.path.join(DATA_BASE_PATH, "dataset") 6 | 7 | # Data Synthesize 8 | SYNTHESIZED_DECOMPOSED_DATA_PATH = f"data{sep}synthesized_decomposed" 9 | SYNTHESIZED_TOKEN_LABELING_DATA_PATH = f"data{sep}synthesized_token_labeling" 10 | SYNTHESIZED_TOKEN_EXTRACTED_DATA_PATH = f"data{sep}token_extracted" 11 | SYNTHESIZED_NEXT_QUERY_DATA_PATH = f"data{sep}synthesized_next_query" 12 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH = f"data{sep}next_query_extracted" 13 | SYNTHESIZED_NEGATIVE_SAMPLING_DATA_PATH = f"data{sep}negative_sampling" 14 | SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH = f"data{sep}negative_sampling_labeled" 15 | SYNTHESIZED_NEGATIVE_SAMPLING_EXTRACTED_DATA_PATH = ( 16 | f"data{sep}negative_sampling_extracted" 17 | ) 18 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH = f"data{sep}efficient_rag{sep}labeler" 19 | EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH = f"data{sep}efficient_rag{sep}filter" 20 | CORPUS_DATA_PATH = f"data{sep}corpus" 21 | 22 | SYNTHESIZED_SPAN_LABELING_DATA_PATH = f"data{sep}synthesized_span_labeling" 23 | 24 | # Results 25 | RETRIEVE_RESULT_PATH = f"results{sep}retrieve" 26 | 27 | CONTINUE_TAG = "" 28 | FINISH_TAG = "" 29 | TERMINATE_TAG = "" 30 | 31 | TAG_MAPPING = { 32 | CONTINUE_TAG: 0, 33 | TERMINATE_TAG: 1, 34 | 
FINISH_TAG: 2, 35 | } 36 | TAG_MAPPING_REV = {v: k for k, v in TAG_MAPPING.items()} 37 | 38 | TAG_MAPPING_TWO = { 39 | CONTINUE_TAG: 0, 40 | TERMINATE_TAG: 1, 41 | FINISH_TAG: 0, 42 | } 43 | TAG_MAPPING_TWO_REV = { 44 | 0: CONTINUE_TAG, 45 | 1: TERMINATE_TAG, 46 | } 47 | TERMINATE_ID = TAG_MAPPING[TERMINATE_TAG] 48 | 49 | # Special Tokens 50 | CLS_TOKEN = "[CLS]" 51 | SEP_TOKEN = "[SEP]" 52 | PAD_TOKEN = "[PAD]" 53 | 54 | MODEL_PATH = f"model_cache{sep}deberta-v3-large" 55 | 56 | MODEL_DICT = { 57 | "gpt35": "gpt-35-turbo-1106", 58 | "gpt4": "gpt-4-0125-preview", 59 | "llama": "Meta-Llama-3-70B-Instruct", 60 | "llama-8B": "Meta-Llama-3-8B-Instruct", 61 | "deepseek": "deepseek-chat", 62 | } 63 | 64 | EMBEDDING_ALIAS = { 65 | "contriever": "contriever", 66 | "e5-base-v2": "e5-base", 67 | "e5-large-v2": "e5-large", 68 | "ada-002": "ada-002", 69 | } 70 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/dense_embedding.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Union 2 | 3 | import torch 4 | from transformers import AutoModel, AutoTokenizer 5 | 6 | from .base import BaseEmbedding 7 | 8 | Pooling = Union[str, Literal["average", "cls"]] 9 | 10 | 11 | class DenseEmbedding(BaseEmbedding): 12 | def __init__( 13 | self, 14 | model_name_or_path: str, 15 | embedding_vector_size: int, 16 | no_fp16: bool = False, 17 | pooling_type: Pooling = "average", 18 | ): 19 | super().__init__() 20 | self.model_name_or_path = model_name_or_path 21 | self.embedding_vector_size = embedding_vector_size 22 | self.model = None 23 | self.tokenizer = None 24 | self.fp16 = not no_fp16 25 | self.pooling_type = pooling_type 26 | 27 | def instantiate(self): 28 | self.tokenizer = AutoTokenizer.from_pretrained(self.model_name_or_path) 29 | self.model = AutoModel.from_pretrained(self.model_name_or_path) 30 | self.model.eval() 31 | self.model = self.model.cuda() 32 | if self.fp16: 33 | self.model = self.model.half() 34 | 35 | def embed(self, query: str): 36 | if self.model is None: 37 | self.instantiate() 38 | queries = [query] 39 | return self.embed_batch(queries)[0] 40 | 41 | def embed_batch(self, queries: list[str]): 42 | if self.model is None: 43 | self.instantiate() 44 | 45 | with torch.no_grad(): 46 | inputs = self.tokenizer( 47 | queries, 48 | return_tensors="pt", 49 | padding=True, 50 | truncation=True, 51 | max_length=512, 52 | ) 53 | inputs = {k: v.cuda() for k, v in inputs.items()} 54 | outputs = self.model(**inputs) 55 | embeddings = self.pooling(outputs.last_hidden_state, inputs["attention_mask"]) 56 | embeddings = embeddings.cpu().numpy() 57 | return embeddings 58 | 59 | def pooling(self, last_hidden_states, attention_mask) -> torch.Tensor: 60 | if self.pooling_type == "average": 61 | last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) 62 | return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 63 | elif self.pooling_type == "cls": 64 | return last_hidden_states[:, 0, :] 65 | else: 66 | raise NotImplementedError 67 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), 
[Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 
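Stepping back to the embedding classes listed earlier (dense_embedding.py and contriever.py), the sketch below shows the intended call pattern. It assumes a CUDA-capable GPU, since instantiate() moves the model onto the GPU, and the query strings are made up:

from retrievers.embeddings import Contriever

# Contriever subclasses DenseEmbedding with average pooling and 768-dimensional vectors;
# passing None falls back to the "facebook/contriever-msmarco" checkpoint.
embedder = Contriever(model_name_or_path=None)

# The Hugging Face model is loaded lazily on the first call. embed() returns a single
# 768-d numpy vector; embed_batch() returns an (n, 768) numpy array.
query_vector = embedder.embed("Which country was the composer of the film's score born in?")
batch_vectors = embedder.embed_batch(["first illustrative query", "second illustrative query"])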
40 | 41 | -------------------------------------------------------------------------------- /src/retrievers/multihop_data_extrator.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import sys 6 | from collections import defaultdict 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 9 | 10 | from tqdm import tqdm 11 | 12 | from conf import CORPUS_DATA_PATH 13 | from data_module import MultiHopDataset, get_dataset 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--dataset", 20 | type=str, 21 | choices=["musique", "2WikiMQA", "hotpotQA"], 22 | default="2WikiMQA" 23 | ) 24 | parser.add_argument("--split", type=str, default=None) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def parse_chunks(dataset: MultiHopDataset): 30 | for sample in dataset: 31 | id = sample["id"] 32 | for idx, chunk in enumerate(sample["chunks"]): 33 | cid = f"{id}-{idx:02d}" 34 | yield {"id": cid, "text": chunk} 35 | 36 | 37 | def purify_text(text: str): 38 | # delete all space and punctuations of the text 39 | pattern = r"[^\w]" 40 | cleaned_text = re.sub(pattern, "", text) 41 | return cleaned_text 42 | 43 | 44 | def merge_chunks(chunks: list[dict]): 45 | chunk_mapping = defaultdict(set) 46 | pattern_title = r"(.*?)" 47 | 48 | for chunk in chunks: 49 | cid = chunk["id"] 50 | text = chunk["text"] 51 | key = purify_text(text) 52 | # chunk_mapping[text].add(cid) 53 | chunk_mapping[key].add((cid, text)) 54 | 55 | chunks = [] 56 | # for text, ids in chunk_mapping.items(): 57 | for key, id_text_pairs in tqdm(chunk_mapping.items()): 58 | id_text_pairs = list(id_text_pairs) 59 | text = id_text_pairs[0][1] 60 | ids = [pair[0] for pair in id_text_pairs] 61 | ids = "//".join(list(ids)) 62 | title = text.split(":")[0].strip() 63 | chunk_info = {"id": ids, "title": title, "text": text} 64 | chunks.append(chunk_info) 65 | return chunks 66 | 67 | 68 | def main(opt: argparse.Namespace): 69 | if opt.split is not None: 70 | split = [opt.split] 71 | else: 72 | split = ["train", "valid", "test"] 73 | chunks = [] 74 | for s in split: 75 | try: 76 | dataset = get_dataset(opt.dataset, s) 77 | for d in parse_chunks(dataset): 78 | chunks.append(d) 79 | except: 80 | continue 81 | chunks = merge_chunks(chunks) 82 | output_dir = os.path.join(CORPUS_DATA_PATH, opt.dataset, "corpus.jsonl") 83 | with open(output_dir, "w+") as f: 84 | for chunk in chunks: 85 | data = json.dumps(chunk) 86 | f.write(data + "\n") 87 | 88 | 89 | if __name__ == "__main__": 90 | options = parse_args() 91 | main(options) 92 | -------------------------------------------------------------------------------- /src/efficient_rag/data/filter_dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | 5 | class FilterDataset(Dataset): 6 | def __init__( 7 | self, 8 | texts: list[list[str]], 9 | labels: list[list[bool]] = None, 10 | max_len: int = 128, 11 | tokenizer=None, 12 | ): 13 | self.tokenizer = tokenizer 14 | self.max_len = max_len 15 | 16 | self.texts = texts 17 | self.labels = labels 18 | 19 | self.cls_token = "[CLS]" 20 | self.sep_token = "[SEP]" 21 | self.unk_token = "[UNK]" 22 | self.pad_token = "[PAD]" 23 | self.mask_token = "[MASK]" 24 | 25 | def __getitem__(self, index): 26 | text = self.texts[index] 27 | labels = self.labels[index][:] 28 | tokenized_text, labels = 
self.tokenize_and_preserve_labels( 29 | text, labels, self.tokenizer 30 | ) 31 | assert len(tokenized_text) == len(labels) 32 | # [CLS] + tokenized_text + [SEP] 33 | labels = [False] + labels + [False] 34 | tokenized_text = [self.cls_token] + tokenized_text + [self.sep_token] 35 | 36 | if len(tokenized_text) > self.max_len: 37 | tokenized_text = tokenized_text[: self.max_len] 38 | labels = labels[: self.max_len] 39 | else: 40 | append_length = self.max_len - len(tokenized_text) 41 | tokenized_text = tokenized_text + [self.pad_token] * append_length 42 | labels = labels + [False] * append_length 43 | 44 | attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text] 45 | 46 | ids = self.tokenizer.convert_tokens_to_ids(tokenized_text) 47 | 48 | sample = { 49 | "input_ids": torch.tensor(ids, dtype=torch.long), 50 | "attention_mask": torch.tensor(attn_mask, dtype=torch.long), 51 | "labels": torch.tensor(labels, dtype=torch.long), 52 | } 53 | return sample 54 | 55 | def __len__(self): 56 | return len(self.texts) 57 | 58 | def tokenize_and_preserve_labels(self, text, text_labels, tokenizer): 59 | """ 60 | Word piece tokenization makes it difficult to match word labels 61 | back up with individual word pieces. This function tokenizes each 62 | word one at a time so that it is easier to preserve the correct 63 | label for each subword. It is, of course, a bit slower in processing 64 | time, but it will help our model achieve higher accuracy. 65 | """ 66 | 67 | tokenized_text = [] 68 | labels = [] 69 | for word, label in zip(text, text_labels): 70 | tokenized_word = tokenizer.tokenize(word) 71 | n_subwords = len(tokenized_word) 72 | tokenized_text.extend(tokenized_word) 73 | labels.extend([label] * n_subwords) 74 | 75 | return tokenized_text, labels 76 | -------------------------------------------------------------------------------- /src/baseline/retrieve/ceiling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | from tqdm.rich import tqdm_rich 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 8 | 9 | from conf import ( 10 | CORPUS_DATA_PATH, 11 | RETRIEVE_RESULT_PATH, 12 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, 13 | ) 14 | from retrievers import Retriever 15 | from utils import load_jsonl, write_jsonl 16 | 17 | 18 | def parse_args(): 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument("--dataset", type=str, required=True) 21 | parser.add_argument("--retriever", type=str, required=True) 22 | parser.add_argument("--topk", type=int, default=10) 23 | return parser.parse_args() 24 | 25 | 26 | def main(opt: argparse.Namespace): 27 | passage_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "corpus.jsonl") 28 | if opt.retriever == "e5-base-v2": 29 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "e5-base") 30 | elif opt.retriever == "contriever": 31 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "contriever") 32 | else: 33 | raise NotImplementedError() 34 | 35 | retriever = Retriever( 36 | passage_path=passage_path, 37 | passage_embedding_path=embedding_path, 38 | index_path_dir=embedding_path, 39 | model_type=opt.retriever, 40 | ) 41 | dataset = load_jsonl( 42 | os.path.join( 43 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "valid.jsonl" 44 | ) 45 | ) 46 | 47 | queries = [ 48 | [chunk["query_info"] for chunk in d["decomposed_questions"].values()] 49 | # [chunk["sub_question"] for chunk in 
d["decomposed_questions"].values()] 50 | for d in dataset 51 | ] 52 | 53 | chunks = [ 54 | retriever.search(query_chunk, top_k=opt.topk) 55 | for query_chunk in tqdm_rich(queries, desc="Retrieving") 56 | ] 57 | 58 | results = [] 59 | for chunk_list, sample in zip(chunks, dataset): 60 | chunk_list = sum(chunk_list, []) 61 | chunk_ids = [p["id"] for p in chunk_list] 62 | sub_ids = sorted(list(sample["decomposed_questions"].keys())) 63 | oracle_ids = [ 64 | f"{sample['id']}-{sample['decomposed_questions'][sub_id]['positive_paragraph_idx']}" 65 | for sub_id in sub_ids 66 | ] 67 | 68 | results.append( 69 | { 70 | "question_id": sample["id"], 71 | "question": sample["question"], 72 | "oracle_ids": oracle_ids, 73 | "chunk_ids": chunk_ids, 74 | } 75 | ) 76 | output_path = os.path.join( 77 | RETRIEVE_RESULT_PATH, 78 | "efficient_rag", 79 | "filtered_ceiling", 80 | f"{opt.dataset}-{opt.retriever}-@{opt.topk}.jsonl", 81 | ) 82 | write_jsonl(results, output_path) 83 | 84 | 85 | if __name__ == "__main__": 86 | options = parse_args() 87 | main(options) 88 | -------------------------------------------------------------------------------- /src/data_synthesize/prompts/hotpotQA.py: -------------------------------------------------------------------------------- 1 | # Given the original question, its decomposed questions, answers, etc., 2 | # decompose the original question into successor sub_questions. 3 | 4 | hotpotQAFactPrompt = """ 5 | document for sub_question #{question_id} 6 | supporting facts: {facts}""" 7 | 8 | 9 | HotpotQAPromptComparison = """You are assigned a multi-hop question decomposition task. 10 | Your mission is to decompose a multi-hop question into a list of single-hop sub_questions based on the supporting documents, such that you (GPT-4) can answer each sub_question independently from each document. 11 | The JSON output must contain the following keys: 12 | - "question": a string, the original multi-hop question. 13 | - "decomposed_questions": a dict of sub_questions and answers. The key should be the sub_question number (string format), and each value should be a dict containing: 14 | - "sub_question": a string, the decomposed single-hop sub_question. The sub_question MUST NOT contain more information than the original question and its dependent sub_question. NEVER introduce information from the documents. 15 | - "answer": a string, the answer to the sub_question. 16 | - "dependency": an empty list, because the sub_question is independent. 17 | 18 | The original multi-hop question is: {question} 19 | The following are the documents used to answer each sub_question. 20 | You MUST decompose the original multi-hop question based on the given documents. DO NOT change the order or miss any of them. 21 | {chunks} 22 | 23 | Your output must always be a JSON object only; do not explain yourself or output anything else. 24 | Follow the documents and synthesize the sub_questions and answers one by one. NEVER miss any of them. 25 | """ 26 | 27 | HotpotQAPromptCompose = """You are assigned a multi-hop question decomposition task. 28 | Your mission is to decompose a multi-hop question into a list of single-hop sub_questions based on the supporting documents, such that you (GPT-4) can answer each sub_question independently from each document. 29 | The JSON output must contain the following keys: 30 | - "question": a string, the original multi-hop question. 31 | - "decomposed_questions": a dict of sub_questions and answers. The key should be the sub_question number (string format), and each value should be a dict containing: 32 | - "sub_question": a string, the decomposed single-hop sub_question. The sub_question MUST NOT contain more information than the original question and its dependent sub_question. NEVER introduce information from the documents. 33 | - "answer": a string, the answer to the sub_question. 34 | - "dependency": a list of sub_question numbers (string format). If the sub_question relies on the answer of other sub_questions, you should list those sub_question numbers here. 35 | 36 | The original multi-hop question is: {question} 37 | And its answer is: {answer} 38 | The following are the documents used to answer each sub_question. 39 | Make sure one sub_question depends on the other! Identify which sub_question depends on the answer of another according to the question. 40 | You MUST decompose the question based on the documents with the sub_questions and answers. DO NOT change the order or miss any of them. 41 | {chunks} 42 | 43 | Your output must always be a JSON object only; do not explain yourself or output anything else. 44 | Follow the documents and synthesize the sub_questions and answers one by one. NEVER miss any of them. 45 | """ 46 | -------------------------------------------------------------------------------- /src/utils/model.py: -------------------------------------------------------------------------------- 1 | import json 2 | import random 3 | import re 4 | import time 5 | from concurrent.futures import ThreadPoolExecutor, as_completed 6 | from typing import Callable, Literal 7 | 8 | from tenacity import retry, stop_after_attempt 9 | from tqdm.rich import tqdm_rich 10 | 11 | from language_models import LanguageModel 12 | 13 | 14 | @retry(stop=stop_after_attempt(3), reraise=False, retry_error_callback=lambda x: None) 15 | def ask_model( 16 | model: LanguageModel, 17 | prompt: str, 18 | system_msg: str = None, 19 | type: Literal["json", "text"] = "json", 20 | check_if_valid: Callable = None, 21 | sleep: bool = True, 22 | mode: Literal["chat", "completion"] = "chat", 23 | ) -> dict: 24 | if sleep: 25 | sleep_time = random.uniform(1.0, 3.0) 26 | time.sleep(sleep_time) 27 | if mode == "chat": 28 | result = model.chat(prompt, system_msg, json_mode=(type == "json")) 29 | # print(result) 30 | elif mode == "completion": 31 | result = model.complete(prompt) 32 | parser = get_type_parser(type) 33 | info = parser(result) 34 | if check_if_valid is not None and not check_if_valid(info): 35 | print(f"Invalid response {info}") 36 | raise ValueError("Invalid response") 37 | return info 38 | 39 | 40 | def ask_model_in_parallel( 41 | model: LanguageModel, 42 | prompts: list[str], 43 | system_msg: str = None, 44 | type: Literal["json", "text"] = "json", 45 | check_if_valid_list: list[Callable] = None, 46 | max_workers: int = 4, 47 | desc: str = "Processing...", 48 | verbose=True, 49 | mode: Literal["chat", "completion"] = "chat", 50 | ): 51 | if max_workers == -1: 52 | max_workers = len(prompts) 53 | assert max_workers >= 1, "max_workers should be greater than or equal to 1" 54 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 55 | if check_if_valid_list is None: 56 | check_if_valid_list = [None] * len(prompts) 57 | assert len(prompts) == len( 58 | check_if_valid_list 59 | ), "Length of prompts and check_if_valid_list should be the same" 60 | tasks = { 61 | executor.submit( 62 | ask_model, model, prompt, system_msg, type, check_if_valid, mode=mode 63 | ): idx 64 | for idx, (prompt, check_if_valid) in
enumerate( 65 | zip(prompts, check_if_valid_list) 66 | ) 67 | } 68 | results = [] 69 | for future in tqdm_rich( 70 | as_completed(tasks), total=len(tasks), desc=desc, disable=not verbose 71 | ): 72 | task_id = tasks[future] 73 | try: 74 | result = future.result() 75 | results.append((task_id, result)) 76 | finally: 77 | ... 78 | results = [result[1] for result in sorted(results, key=lambda r: r[0])] 79 | return results 80 | 81 | 82 | def get_type_parser(type: str) -> Callable: 83 | def json_parser(result: str): 84 | # pattern = r"```json(.*?)```" 85 | pattern = r"{.*?}" 86 | matches = re.findall(pattern, result, re.DOTALL) 87 | if matches: 88 | result = matches[0].strip() 89 | return json.loads(result) 90 | 91 | def text_parser(result: str): 92 | return result 93 | 94 | if type == "json": 95 | return json_parser 96 | elif type == "text": 97 | return text_parser 98 | else: 99 | raise ValueError(f"Unsupported type: {type}") 100 | -------------------------------------------------------------------------------- /src/efficient_rag/model/model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional, Tuple, Union 3 | 4 | import torch 5 | import torch.nn as nn 6 | from transformers.modeling_outputs import ModelOutput 7 | from transformers.models.deberta_v2.modeling_deberta_v2 import ( 8 | ContextPooler, 9 | DebertaV2Model, 10 | DebertaV2PreTrainedModel, 11 | StableDropout, 12 | ) 13 | 14 | 15 | @dataclass 16 | class SequenceTokenClassifierOutput(ModelOutput): 17 | loss: Optional[torch.FloatTensor] = None 18 | sequence_logits: torch.FloatTensor = None 19 | token_logits: torch.FloatTensor = None 20 | hidden_states: Optional[Tuple[torch.FloatTensor]] = None 21 | attentions: Optional[Tuple[torch.FloatTensor]] = None 22 | 23 | 24 | class DebertaForSequenceTokenClassification(DebertaV2PreTrainedModel): 25 | def __init__(self, config, token_labels: int = 2, sequence_labels: int = 3): 26 | super().__init__(config) 27 | 28 | self.token_labels = token_labels 29 | self.sequence_labels = sequence_labels 30 | 31 | self.deberta = DebertaV2Model(config) 32 | self.pooler = ContextPooler(config) 33 | output_dim = self.pooler.output_dim 34 | self.sequence_classifier = nn.Linear(output_dim, sequence_labels) 35 | sequence_dropout = getattr(config, "cls_dropout", None) 36 | sequence_dropout = self.config.hidden_dropout_prob if sequence_dropout is None else sequence_dropout 37 | self.sequence_dropout = StableDropout(sequence_dropout) 38 | 39 | self.token_classifier = nn.Linear(config.hidden_size, token_labels) 40 | self.token_dropout = nn.Dropout(config.hidden_dropout_prob) 41 | 42 | self.post_init() 43 | 44 | def get_input_embeddings(self): 45 | return self.deberta.get_input_embeddings() 46 | 47 | def set_input_embeddings(self, value): 48 | return self.deberta.set_input_embeddings(value) 49 | 50 | def forward( 51 | self, 52 | input_ids: Optional[torch.Tensor] = None, 53 | attention_mask: Optional[torch.Tensor] = None, 54 | token_type_ids: Optional[torch.Tensor] = None, 55 | position_ids: Optional[torch.Tensor] = None, 56 | inputs_embeds: Optional[torch.Tensor] = None, 57 | token_labels: Optional[torch.Tensor] = None, 58 | sequence_labels: Optional[torch.Tensor] = None, 59 | output_attentions: Optional[bool] = None, 60 | output_hidden_states: Optional[bool] = None, 61 | return_dict: Optional[bool] = None, 62 | ) -> Union[Tuple, ModelOutput]: 63 | return_dict = return_dict if return_dict is not None else 
self.config.use_return_dict 64 | 65 | outputs = self.deberta( 66 | input_ids, 67 | token_type_ids=token_type_ids, 68 | attention_mask=attention_mask, 69 | position_ids=position_ids, 70 | inputs_embeds=inputs_embeds, 71 | output_attentions=output_attentions, 72 | output_hidden_states=output_hidden_states, 73 | return_dict=return_dict, 74 | ) 75 | 76 | hidden_state = outputs[0] 77 | pooled_output = self.pooler(hidden_state) 78 | pooled_output = self.sequence_dropout(pooled_output) 79 | sequence_logits = self.sequence_classifier(pooled_output) 80 | 81 | token_output = self.token_dropout(hidden_state) 82 | token_logits = self.token_classifier(token_output) 83 | return SequenceTokenClassifierOutput( 84 | sequence_logits=sequence_logits, 85 | token_logits=token_logits, 86 | hidden_states=outputs.hidden_states, 87 | attentions=outputs.attentions, 88 | ) 89 | -------------------------------------------------------------------------------- /src/retrievers/vector_index/faiss_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | from typing import List, Tuple 4 | from typing import Literal 5 | from tqdm import tqdm 6 | 7 | import faiss 8 | import numpy as np 9 | 10 | from .base import BaseIndex 11 | 12 | IndexType = Literal[ 13 | "Flat", 14 | "HNSW64", 15 | "IVF100,Flat", 16 | "PQ16", 17 | "IVF100,PQ16", 18 | "LSH", 19 | ] 20 | 21 | 22 | class FaissIndex(BaseIndex): 23 | def __init__( 24 | self, 25 | dim: int = 768, 26 | index_type: IndexType = "Flat", 27 | max_search_batch_size: int = 2048, 28 | max_index_batch_size: int = int(1e6), 29 | ): 30 | super().__init__() 31 | self.index_fname = "index.faiss" 32 | self.index_meta_fname = "index_meta.faiss" 33 | 34 | self.index = faiss.index_factory(dim, index_type, faiss.METRIC_INNER_PRODUCT) 35 | self.idx2db = [] 36 | 37 | self.max_search_batch_size = max_search_batch_size 38 | self.max_index_batch_size = max_index_batch_size 39 | 40 | def search(self, query_vectors: np.array, top_k: int = 20) -> List[Tuple[List[object], List[float]]]: 41 | query_vectors = query_vectors.astype("float32") 42 | result = [] 43 | batches = (len(query_vectors) - 1) // self.max_search_batch_size + 1 44 | for idx in range(batches): 45 | start_idx = idx * self.max_search_batch_size 46 | end_idx = min((idx + 1) * self.max_search_batch_size, len(query_vectors)) 47 | q = query_vectors[start_idx:end_idx] 48 | scores, indexes = self.index.search(q, top_k) 49 | # convert index to passage id 50 | db_ids = [[str(self.idx2db[i]) for i in query_indexes] for query_indexes in indexes] 51 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))]) 52 | return result 53 | 54 | def serialize(self, dir_path): 55 | index_file = os.path.join(dir_path, self.index_fname) 56 | meta_file = os.path.join(dir_path, self.index_meta_fname) 57 | print(f"Serializing index to {index_file}, meta data to {meta_file}") 58 | if not os.path.exists(dir_path): 59 | os.makedirs(dir_path) 60 | 61 | faiss.write_index(self.index, index_file) 62 | with open(meta_file, "wb") as f: 63 | pickle.dump(self.idx2db, f) 64 | 65 | def deserialize(self, dir_path): 66 | index_file = os.path.join(dir_path, self.index_fname) 67 | meta_file = os.path.join(dir_path, self.index_meta_fname) 68 | print(f"Loading index from {index_file}, meta data from {meta_file}") 69 | self.index = faiss.read_index(index_file) 70 | with open(meta_file, "rb") as f: 71 | self.idx2db = pickle.load(f) 72 | assert len(self.idx2db) == self.index.ntotal, "Deserialized idx2db 
should match faiss index size" 73 | 74 | def exist_index(self, dir_path): 75 | index_file = os.path.join(dir_path, self.index_fname) 76 | meta_file = os.path.join(dir_path, self.index_meta_fname) 77 | return os.path.exists(index_file) and os.path.exists(meta_file) 78 | 79 | def load_data(self, passage_embeddings: List[str]): 80 | embeddings = np.array([]) 81 | ids = [] 82 | for fpath in tqdm(passage_embeddings, desc="Load embeddings"): 83 | with open(fpath, "rb") as fin: 84 | cur_ids, cur_embeddings = pickle.load(fin) 85 | ids.extend(cur_ids) 86 | embeddings = np.vstack((embeddings, cur_embeddings)) if embeddings.size else cur_embeddings 87 | embeddings = embeddings.astype("float32") 88 | self.idx2db = ids 89 | if not self.index.is_trained: 90 | self.index.train(embeddings) 91 | self.index.add(embeddings) 92 | print(f"Total data indexed {len(self.idx2db)}") 93 | -------------------------------------------------------------------------------- /src/data_synthesize/training_data_synthesize.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 6 | from tqdm.rich import tqdm_rich 7 | 8 | from conf import ( 9 | CONTINUE_TAG, 10 | EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, 11 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, 12 | FINISH_TAG, 13 | SYNTHESIZED_NEGATIVE_SAMPLING_EXTRACTED_DATA_PATH, 14 | TERMINATE_TAG, 15 | ) 16 | from utils import load_jsonl, write_jsonl 17 | 18 | INFO_TEMPLATE = "Info: {info}" 19 | QUERY_TEMPLATE = "Q: {query}" 20 | 21 | 22 | def build_labeler_data(samples: list[dict]): 23 | results = [] 24 | for sample in tqdm_rich(samples, desc="Building labeler data"): 25 | for subq_id, subq in sample["decomposed_questions"].items(): 26 | try: 27 | if subq_id == sorted(sample["decomposed_questions"].keys())[-1]: 28 | positive_tag = FINISH_TAG 29 | else: 30 | positive_tag = CONTINUE_TAG 31 | positive_sample = { 32 | "question": subq["filtered_query"], 33 | "chunk": subq["positive_paragraph"], 34 | "matched": subq["matched"], 35 | "chunk_tokens": subq["paragraph_tokens"], 36 | "labels": subq["labels"], 37 | "tag": positive_tag, 38 | } 39 | negative_samples = { 40 | "question": subq["filtered_query"], 41 | "chunk": subq["negative_paragraph"], 42 | "matched": subq["negative_matched"], 43 | "chunk_tokens": subq["negative_paragraph_tokens"], 44 | "labels": subq["negative_labels"], 45 | "tag": TERMINATE_TAG, 46 | } 47 | results.append(positive_sample) 48 | results.append(negative_samples) 49 | except Exception as e: 50 | continue 51 | return results 52 | 53 | 54 | def build_filter_data(samples: list[dict]): 55 | results = [] 56 | for sample in tqdm_rich(samples, desc="Building filter data"): 57 | for subq_id, subq in sample["decomposed_questions"].items(): 58 | if "query_info_tokens" not in subq.keys(): 59 | continue 60 | filter_data = { 61 | "query_info_tokens": subq["query_info_tokens"], 62 | "query_info_labels": subq["query_info_labels"], 63 | } 64 | results.append(filter_data) 65 | return results 66 | 67 | 68 | def parse_args(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument( 71 | "--dataset", 72 | type=str, 73 | choices=["hotpotQA", "musique", "2WikiMQA"], 74 | required=True, 75 | ) 76 | parser.add_argument("--split", type=str, default="demo") 77 | args = parser.parse_args() 78 | return args 79 | 80 | 81 | def main(opt: argparse.Namespace): 82 | data_path = os.path.join( 83 | 
SYNTHESIZED_NEGATIVE_SAMPLING_EXTRACTED_DATA_PATH, 84 | opt.dataset, 85 | f"{opt.split}.jsonl", 86 | ) 87 | data = load_jsonl(data_path) 88 | 89 | labeler_data_path = os.path.join( 90 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, opt.dataset, f"{opt.split}.jsonl" 91 | ) 92 | labeler_training_data = build_labeler_data(data) 93 | write_jsonl(labeler_training_data, labeler_data_path) 94 | 95 | filter_data_path = os.path.join( 96 | EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, opt.dataset, f"{opt.split}.jsonl" 97 | ) 98 | filter_training_data = build_filter_data(data) 99 | write_jsonl(filter_training_data, filter_data_path) 100 | 101 | 102 | if __name__ == "__main__": 103 | options = parse_args() 104 | main(options) 105 | -------------------------------------------------------------------------------- /src/language_models/deepseek.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from time import sleep 4 | 5 | import openai 6 | from openai import OpenAI 7 | from openai._types import NotGiven 8 | 9 | from .base import LanguageModel 10 | 11 | DEEPSEEK_BASE_URL = "https://api.deepseek.com" 12 | SLEEP_SEC = 3 13 | 14 | 15 | class DeepSeek(LanguageModel): 16 | def __init__( 17 | self, model: str = "deepseek_chat", api_key: str = None, *args, **kwargs 18 | ): 19 | super().__init__(model, *args, **kwargs) 20 | if api_key is None: 21 | api_key = open("deepseek_api_key.txt").read().strip() 22 | self.client = OpenAI(base_url=DEEPSEEK_BASE_URL, api_key=api_key) 23 | 24 | def chat(self, messages: str, system_msg: str = None, **kwargs): 25 | try: 26 | response = self._chat(messages, system_msg, **kwargs) 27 | return response 28 | except openai.BadRequestError as e: 29 | err = json.loads(e.response.text) 30 | if err["error"]["code"] == "content_filter": 31 | print("Content filter triggered!") 32 | return None 33 | print(f"The OpenAI API request was invalid: {e}") 34 | return None 35 | except openai.APIConnectionError as e: 36 | print(f"The OpenAI API connection failed: {e}") 37 | sleep(SLEEP_SEC) 38 | return self.chat(messages, system_msg, **kwargs) 39 | except openai.RateLimitError as e: 40 | print(f"Token rate limit exceeded. Retrying after {SLEEP_SEC} second...") 41 | sleep(SLEEP_SEC) 42 | return self.chat(messages, system_msg, **kwargs) 43 | except openai.AuthenticationError as e: 44 | print(f"Invalid API token: {e}") 45 | self.update_api_key() 46 | sleep(SLEEP_SEC) 47 | return self.chat(messages, system_msg, **kwargs) 48 | except openai.APIError as e: 49 | if "The operation was timeout" in str(e): 50 | # Handle the timeout error here 51 | print("The OpenAI API request timed out. Please try again later.") 52 | sleep(SLEEP_SEC) 53 | return self.chat(messages, system_msg, **kwargs) 54 | elif "DeploymentNotFound" in str(e): 55 | print("The API deployment for this resource does not exist") 56 | return None 57 | else: 58 | # Handle other API errors here 59 | print(f"The OpenAI API returned an error: {e}") 60 | sleep(SLEEP_SEC) 61 | return self.chat(messages, system_msg, **kwargs) 62 | except Exception as e: 63 | print(f"An error occurred: {e}") 64 | 65 | def _chat( 66 | self, 67 | messages: str, 68 | system_msg="", 69 | temperature: float = 0.3, 70 | max_tokens: int = 1000, 71 | top_p: float = 0.95, 72 | frequency_penalty: float = 0.0, 73 | presence_penalty: float = 0.0, 74 | json_mode: bool = False, 75 | ): 76 | if system_msg is None or system_msg == "": 77 | system_msg = "You are a helpful assistant." 
78 | msg = [ 79 | {"role": "system", "content": system_msg}, 80 | {"role": "user", "content": messages}, 81 | ] 82 | response = self.client.chat.completions.create( 83 | model=self.model, 84 | response_format={"type": "json_object"} if json_mode else NotGiven(), 85 | messages=msg, 86 | temperature=temperature, 87 | max_tokens=max_tokens, 88 | top_p=top_p, 89 | frequency_penalty=frequency_penalty, 90 | presence_penalty=presence_penalty, 91 | ) 92 | return response.choices[0].message.content 93 | 94 | 95 | if __name__ == "__main__": 96 | aoai = DeepSeek(model="deepseek-chat") 97 | print(aoai.chat("Hello, who are you?", json_mode=True)) 98 | -------------------------------------------------------------------------------- /src/data_synthesize/negative_sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | from typing import Iterator 9 | 10 | from tqdm.rich import tqdm_rich 11 | 12 | from conf import ( 13 | CORPUS_DATA_PATH, 14 | SYNTHESIZED_NEGATIVE_SAMPLING_DATA_PATH, 15 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, 16 | ) 17 | from retrievers import Retriever 18 | from retrievers.embeddings import ModelCheckpointMapping, ModelTypes 19 | from utils import load_jsonl 20 | 21 | 22 | def negative_sampling(retriever: Retriever, samples: list[dict]) -> Iterator[dict]: 23 | for sample in tqdm_rich(samples, total=len(samples), desc="Negative Sampling..."): 24 | if not all( 25 | ["filtered_query" in sample["decomposed_questions"][sub_id] for sub_id in sample["decomposed_questions"]] 26 | ): 27 | print(f"Invalid sample {sample['id']}") 28 | continue 29 | sub_ids = sorted(list(sample["decomposed_questions"].keys())) 30 | filtered_queries = [ 31 | sample["decomposed_questions"][sub_id]["filtered_query"] 32 | for sub_id in sub_ids 33 | if sample["decomposed_questions"][sub_id] 34 | ] 35 | oracle_chunk_ids = set( 36 | [ 37 | f"{sample['id']}-{'{:02d}'.format(sample['decomposed_questions'][sub_id]['positive_paragraph_idx'])}" 38 | for sub_id in sub_ids 39 | ] 40 | ) 41 | result = sample.copy() 42 | candidate_chunks = retriever.search(filtered_queries, top_k=10) 43 | for subq_id, candidate_chunk_list in zip(sub_ids, candidate_chunks): 44 | for candidate_chunk in candidate_chunk_list: 45 | chunk_idx = set(candidate_chunk["id"].split("//")) 46 | if not chunk_idx.intersection(oracle_chunk_ids): 47 | result["decomposed_questions"][subq_id]["negative_paragraph"] = candidate_chunk["text"] 48 | result["decomposed_questions"][subq_id]["negative_paragraph_idx"] = candidate_chunk["id"] 49 | break 50 | yield result 51 | 52 | 53 | def parse_args(): 54 | parser = argparse.ArgumentParser() 55 | parser.add_argument( 56 | "--dataset", 57 | type=str, 58 | choices=["musique", "2WikiMQA", "hotpotQA"], 59 | required=True, 60 | ) 61 | parser.add_argument("--split", type=str, default="demo") 62 | parser.add_argument("--retriever", type=str, choices=ModelTypes.keys(), default="contriever") 63 | args = parser.parse_args() 64 | return args 65 | 66 | 67 | def main(opts: argparse.Namespace): 68 | passage_path = os.path.join(CORPUS_DATA_PATH, opts.dataset, "corpus.jsonl") 69 | if opts.retriever == "e5-base-v2": 70 | embedding_path = os.path.join(CORPUS_DATA_PATH, opts.dataset, "e5-base") 71 | elif opts.retriever == "contriever": 72 | embedding_path = os.path.join(CORPUS_DATA_PATH, 
opts.dataset, "contriever") 73 | else: 74 | raise NotImplementedError(f"Retriever {opts.retriever} not implemented") 75 | 76 | retriever = Retriever( 77 | passage_path=passage_path, 78 | passage_embedding_path=embedding_path, 79 | index_path_dir=embedding_path, 80 | model_type=opts.retriever, 81 | model_path=ModelCheckpointMapping[opts.retriever], 82 | ) 83 | subq_data_path = os.path.join(SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opts.dataset, f"{opts.split}.jsonl") 84 | samples = load_jsonl(subq_data_path) 85 | output_data_path = os.path.join(SYNTHESIZED_NEGATIVE_SAMPLING_DATA_PATH, opts.dataset, f"{opts.split}.jsonl") 86 | with open(output_data_path, "w+", encoding="utf-8") as f: 87 | for sample in negative_sampling(retriever, samples): 88 | d = json.dumps(sample, ensure_ascii=False) 89 | f.write(d + "\n") 90 | 91 | 92 | if __name__ == "__main__": 93 | options = parse_args() 94 | main(options) 95 | -------------------------------------------------------------------------------- /src/efficient_rag/data/label_only_dataset.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class LabelOnlyDataset(Dataset): 7 | def __init__( 8 | self, 9 | questions: list[str], 10 | chunk_tokens: list[list[str]], 11 | labels: list[list[str]], 12 | tags: list[int], 13 | max_len=512, 14 | tokenizer=None, 15 | ): 16 | self.tokenizer = tokenizer 17 | self.max_len = max_len 18 | 19 | self.questions = questions 20 | self.labels = labels 21 | self.chunk_tokens = chunk_tokens 22 | self.tags = tags 23 | 24 | self.cls_token = "[CLS]" 25 | self.sep_token = "[SEP]" 26 | self.unk_token = "[UNK]" 27 | self.pad_token = "[PAD]" 28 | self.mask_token = "[MASK]" 29 | 30 | self.nlp = spacy.load("en_core_web_sm") 31 | 32 | def __getitem__(self, index): 33 | # text = self.questions[index] 34 | chunk = self.chunk_tokens[index] 35 | chunk_label = self.labels[index][:] 36 | question, question_labels = self.construct_question_labels(self.questions[index]) 37 | tokenized_question, tokenized_question_labels = self.tokenize_and_preserve_labels( 38 | question, question_labels, self.tokenizer 39 | ) 40 | assert self.labels is not None 41 | 42 | tokenized_chunk, chunk_label = self.tokenize_and_preserve_labels(chunk, chunk_label, self.tokenizer) 43 | assert len(tokenized_chunk) == len(chunk_label) 44 | 45 | # [CLS] question [SEP] chunk [SEP] 46 | tokenized_text = [self.cls_token] + tokenized_question + [self.sep_token] + tokenized_chunk + [self.sep_token] 47 | labels = [False] + tokenized_question_labels + [False] + chunk_label + [False] 48 | 49 | if len(tokenized_text) > self.max_len: 50 | tokenized_text = tokenized_text[: self.max_len] 51 | if self.labels is not None: 52 | labels = labels[: self.max_len] 53 | else: 54 | tokenized_text = tokenized_text + [self.pad_token for _ in range(self.max_len - len(tokenized_text))] 55 | if self.labels is not None: 56 | labels = labels + [False for _ in range(self.max_len - len(labels))] 57 | 58 | attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text] 59 | ids = self.tokenizer.convert_tokens_to_ids(tokenized_text) 60 | chunk_tags = self.tags[index] 61 | sample = { 62 | "input_ids": torch.tensor(ids, dtype=torch.long), 63 | "attention_mask": torch.tensor(attn_mask, dtype=torch.long), 64 | # "token_labels": torch.tensor(labels, dtype=torch.long), 65 | "labels": torch.tensor(chunk_tags, dtype=torch.long), 66 | } 67 | return sample 68 | 69 | def __len__(self): 70 | 
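        # Dataset size equals the number of (question, chunk) pairs supplied at construction.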
return len(self.questions) 71 | 72 | def tokenize_and_preserve_labels(self, text, text_labels, tokenizer): 73 | """ 74 | Word piece tokenization makes it difficult to match word labels 75 | back up with individual word pieces. This function tokenizes each 76 | word one at a time so that it is easier to preserve the correct 77 | label for each subword. It is, of course, a bit slower in processing 78 | time, but it will help our model achieve higher accuracy. 79 | """ 80 | 81 | tokenized_text = [] 82 | labels = [] 83 | for word, label in zip(text, text_labels): 84 | tokenized_word = tokenizer.tokenize(word) 85 | n_subwords = len(tokenized_word) 86 | tokenized_text.extend(tokenized_word) 87 | labels.extend([label] * n_subwords) 88 | 89 | return tokenized_text, labels 90 | 91 | def construct_question_labels(self, question, ignore_tokens=set([","])): 92 | doc = self.nlp(question) 93 | labels = [False] * len(doc) 94 | words = [] 95 | for word in doc.ents: 96 | if word.lemma_ not in ignore_tokens: 97 | words.append(word.lemma_) 98 | return words, labels 99 | -------------------------------------------------------------------------------- /src/efficient_rag/data/labeler_dataset.py: -------------------------------------------------------------------------------- 1 | import spacy 2 | import torch 3 | from torch.utils.data import Dataset 4 | 5 | 6 | class LabelerDataset(Dataset): 7 | def __init__( 8 | self, 9 | questions: list[str], 10 | chunk_tokens: list[list[str]], 11 | labels: list[list[str]], 12 | tags: list[int], 13 | max_len=512, 14 | tokenizer=None, 15 | ): 16 | self.tokenizer = tokenizer 17 | self.max_len = max_len 18 | 19 | self.questions = questions 20 | self.labels = labels 21 | self.chunk_tokens = chunk_tokens 22 | self.tags = tags 23 | 24 | self.cls_token = "[CLS]" 25 | self.sep_token = "[SEP]" 26 | self.unk_token = "[UNK]" 27 | self.pad_token = "[PAD]" 28 | self.mask_token = "[MASK]" 29 | 30 | self.nlp = spacy.load("en_core_web_sm") 31 | 32 | def __getitem__(self, index): 33 | # text = self.questions[index] 34 | chunk = self.chunk_tokens[index] 35 | chunk_label = self.labels[index][:] 36 | question, question_labels = self.construct_question_labels(self.questions[index]) 37 | tokenized_question, tokenized_question_labels = self.tokenize_and_preserve_labels( 38 | question, question_labels, self.tokenizer 39 | ) 40 | assert self.labels is not None 41 | 42 | tokenized_chunk, chunk_label = self.tokenize_and_preserve_labels(chunk, chunk_label, self.tokenizer) 43 | assert len(tokenized_chunk) == len(chunk_label) 44 | 45 | # [CLS] question [SEP] chunk [SEP] 46 | tokenized_text = [self.cls_token] + tokenized_question + [self.sep_token] + tokenized_chunk + [self.sep_token] 47 | labels = [False] + tokenized_question_labels + [False] + chunk_label + [False] 48 | 49 | if len(tokenized_text) > self.max_len: 50 | tokenized_text = tokenized_text[: self.max_len] 51 | if self.labels is not None: 52 | labels = labels[: self.max_len] 53 | else: 54 | tokenized_text = tokenized_text + [self.pad_token for _ in range(self.max_len - len(tokenized_text))] 55 | if self.labels is not None: 56 | labels = labels + [False for _ in range(self.max_len - len(labels))] 57 | 58 | attn_mask = [1 if tok != self.pad_token else 0 for tok in tokenized_text] 59 | ids = self.tokenizer.convert_tokens_to_ids(tokenized_text) 60 | chunk_tags = self.tags[index] 61 | sample = { 62 | "input_ids": torch.tensor(ids, dtype=torch.long), 63 | "attention_mask": torch.tensor(attn_mask, dtype=torch.long), 64 | "token_labels": 
torch.tensor(labels, dtype=torch.long), 65 | "sequence_labels": torch.tensor(chunk_tags, dtype=torch.long), 66 | } 67 | return sample 68 | 69 | def __len__(self): 70 | return len(self.questions) 71 | 72 | def tokenize_and_preserve_labels(self, text, text_labels, tokenizer): 73 | """ 74 | Word piece tokenization makes it difficult to match word labels 75 | back up with individual word pieces. This function tokenizes each 76 | word one at a time so that it is easier to preserve the correct 77 | label for each subword. It is, of course, a bit slower in processing 78 | time, but it will help our model achieve higher accuracy. 79 | """ 80 | 81 | tokenized_text = [] 82 | labels = [] 83 | for word, label in zip(text, text_labels): 84 | tokenized_word = tokenizer.tokenize(word) 85 | n_subwords = len(tokenized_word) 86 | tokenized_text.extend(tokenized_word) 87 | labels.extend([label] * n_subwords) 88 | 89 | return tokenized_text, labels 90 | 91 | def construct_question_labels(self, question, ignore_tokens=set([","])): 92 | doc = self.nlp(question) 93 | words = [] 94 | for word in doc: 95 | if word.lemma_ not in ignore_tokens: 96 | words.append(word.text) 97 | labels = [False] * len(words) 98 | return words, labels 99 | -------------------------------------------------------------------------------- /src/data_synthesize/negative_sampling_labeled.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | 6 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | 9 | from prompts import ( 10 | NEGATIVE_TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT, 11 | TOKEN_LABELING_SYSTEM_MSG, 12 | ) 13 | from tqdm.rich import tqdm_rich 14 | 15 | from conf import ( 16 | MODEL_DICT, 17 | SYNTHESIZED_NEGATIVE_SAMPLING_DATA_PATH, 18 | SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH, 19 | ) 20 | from language_models import AOAI 21 | from utils import ask_model, load_jsonl 22 | 23 | 24 | class NegativeTokenLabeler: 25 | def __init__(self, model: str, dataset: str, split: str) -> None: 26 | # self.model = AOAI(model) 27 | negative_sampling_data = os.path.join(SYNTHESIZED_NEGATIVE_SAMPLING_DATA_PATH, dataset, f"{split}.jsonl") 28 | self.negative_sampling_data = load_jsonl(negative_sampling_data) 29 | self.check_if_valid = lambda x: all([k in x.keys() for k in ["extracted_words"]]) 30 | 31 | def parse(self, starting: int = 0, ending: int = None, workers=10) -> list[dict]: 32 | if ending is None: 33 | ending = len(self.negative_sampling_data) 34 | labeled_data = self.negative_sampling_data[starting:ending] 35 | if workers > 1: 36 | with ThreadPoolExecutor(max_workers=workers) as executor: 37 | tasks = {executor.submit(self.parse_sample, sample): idx for idx, sample in enumerate(labeled_data)} 38 | results = [] 39 | for future in tqdm_rich(as_completed(tasks), total=len(tasks), desc="Processing..."): 40 | task_id = tasks[future] 41 | try: 42 | result = future.result() 43 | results.append((task_id, result)) 44 | finally: 45 | pass 46 | results = [r[1] for r in sorted(results, key=lambda x: x[0])] 47 | else: 48 | results = [] 49 | for idx, sample in tqdm_rich(enumerate(labeled_data), total=len(labeled_data), desc="Processing..."): 50 | result = self.parse_sample(sample) 51 | results.append(result) 52 | return results 53 | 54 | def parse_sample(self, sample: dict) -> dict: 55 | for subq_id, subq in sample["decomposed_questions"].items(): 56 | # prompt = 
NEGATIVE_TOKEN_LABEL_SYNTHESIZE_FEW_SHOT_PROMPT.format( 57 | # question=subq["sub_question"], paragraph=subq["negative_paragraph"] 58 | # ) 59 | # result = ask_model( self.model, prompt, TOKEN_LABELING_SYSTEM_MSG, type="json", check_if_valid=self.check_if_valid,) 60 | result = {"extracted_words": ""} 61 | # if result is None: 62 | # sample["decomposed_questions"][subq_id]["state"] = "error" 63 | # continue 64 | sample["decomposed_questions"][subq_id]["negative_extracted_words"] = result["extracted_words"] 65 | return sample 66 | 67 | 68 | def parse_args(): 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument( 71 | "--dataset", 72 | type=str, 73 | required=True, 74 | choices=["hotpotQA", "musique", "2WikiMQA"], 75 | ) 76 | parser.add_argument("--split", type=str, required=True, default="demo") 77 | parser.add_argument("--model", type=str, default="gpt35") 78 | parser.add_argument("--workers", type=int, default=10) 79 | args = parser.parse_args() 80 | return args 81 | 82 | 83 | def main(opt: argparse.Namespace): 84 | model = MODEL_DICT[opt.model] 85 | labeler = NegativeTokenLabeler(model, opt.dataset, opt.split) 86 | with open( 87 | os.path.join( 88 | SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH, 89 | opt.dataset, 90 | f"{opt.split}.jsonl", 91 | ), 92 | "w", 93 | ) as f: 94 | for sample in labeler.parse(): 95 | f.write(json.dumps(sample) + "\n") 96 | 97 | 98 | if __name__ == "__main__": 99 | options = parse_args() 100 | main(options) 101 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .vscode 2 | .idea 3 | wandb 4 | saved_models 5 | plt/ 6 | results 7 | model_cache/ 8 | rebuttal/ 9 | deepseek_api_key.txt 10 | 11 | # data/**/* 12 | *.json 13 | *.jsonl 14 | data/corpus/* 15 | data/corpus_full/* 16 | data/**/*.py 17 | data/backup 18 | !demo.json 19 | !demo.jsonl 20 | 21 | # Byte-compiled / optimized / DLL files 22 | __pycache__/ 23 | *.py[cod] 24 | *$py.class 25 | 26 | # C extensions 27 | *.so 28 | 29 | # Distribution / packaging 30 | .Python 31 | build/ 32 | develop-eggs/ 33 | dist/ 34 | downloads/ 35 | eggs/ 36 | .eggs/ 37 | lib/ 38 | lib64/ 39 | parts/ 40 | sdist/ 41 | var/ 42 | wheels/ 43 | share/python-wheels/ 44 | *.egg-info/ 45 | .installed.cfg 46 | *.egg 47 | MANIFEST 48 | 49 | # PyInstaller 50 | # Usually these files are written by a python script from a template 51 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
52 | *.manifest 53 | *.spec 54 | 55 | # Installer logs 56 | pip-log.txt 57 | pip-delete-this-directory.txt 58 | 59 | # Unit test / coverage reports 60 | htmlcov/ 61 | .tox/ 62 | .nox/ 63 | .coverage 64 | .coverage.* 65 | .cache 66 | nosetests.xml 67 | coverage.xml 68 | *.cover 69 | *.py,cover 70 | .hypothesis/ 71 | .pytest_cache/ 72 | cover/ 73 | 74 | # Translations 75 | *.mo 76 | *.pot 77 | 78 | # Django stuff: 79 | *.log 80 | local_settings.py 81 | db.sqlite3 82 | db.sqlite3-journal 83 | 84 | # Flask stuff: 85 | instance/ 86 | .webassets-cache 87 | 88 | # Scrapy stuff: 89 | .scrapy 90 | 91 | # Sphinx documentation 92 | docs/_build/ 93 | 94 | # PyBuilder 95 | .pybuilder/ 96 | target/ 97 | 98 | # Jupyter Notebook 99 | .ipynb_checkpoints 100 | 101 | # IPython 102 | profile_default/ 103 | ipython_config.py 104 | 105 | # pyenv 106 | # For a library or package, you might want to ignore these files since the code is 107 | # intended to run in multiple environments; otherwise, check them in: 108 | # .python-version 109 | 110 | # pipenv 111 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 112 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 113 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 114 | # install all needed dependencies. 115 | #Pipfile.lock 116 | 117 | # poetry 118 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 119 | # This is especially recommended for binary packages to ensure reproducibility, and is more 120 | # commonly ignored for libraries. 121 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 122 | #poetry.lock 123 | 124 | # pdm 125 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 126 | #pdm.lock 127 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 128 | # in version control. 129 | # https://pdm.fming.dev/#use-with-ide 130 | .pdm.toml 131 | 132 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 133 | __pypackages__/ 134 | 135 | # Celery stuff 136 | celerybeat-schedule 137 | celerybeat.pid 138 | 139 | # SageMath parsed files 140 | *.sage.py 141 | 142 | # Environments 143 | .env 144 | .venv 145 | env/ 146 | venv/ 147 | ENV/ 148 | env.bak/ 149 | venv.bak/ 150 | 151 | # Spyder project settings 152 | .spyderproject 153 | .spyproject 154 | 155 | # Rope project settings 156 | .ropeproject 157 | 158 | # mkdocs documentation 159 | /site 160 | 161 | # mypy 162 | .mypy_cache/ 163 | .dmypy.json 164 | dmypy.json 165 | 166 | # Pyre type checker 167 | .pyre/ 168 | 169 | # pytype static type analyzer 170 | .pytype/ 171 | 172 | # Cython debug symbols 173 | cython_debug/ 174 | 175 | # PyCharm 176 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 177 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 178 | # and can be added to the global gitignore or merged into this file. For a more nuclear 179 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
180 | #.idea/ 181 | 182 | *.bin 183 | .DS_Store 184 | zirO8EiG 185 | misc/ 186 | -------------------------------------------------------------------------------- /src/efficient_rag/token_weight_avg.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from typing import Literal 5 | 6 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 7 | from datetime import datetime 8 | 9 | import torch 10 | from tqdm import tqdm 11 | from transformers import DebertaV2TokenizerFast as DebertaV2Tokenizer 12 | 13 | from conf import ( 14 | EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, 15 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, 16 | TAG_MAPPING_TWO, 17 | ) 18 | from efficient_rag.data import FilterDataset, LabelerDataset 19 | from utils import load_jsonl 20 | 21 | 22 | def build_labeler_dataset( 23 | dataset: str, 24 | split: str, 25 | max_len: int = 384, 26 | tokenizer=None, 27 | test_mode: bool = False, 28 | test_sample_cnt: int = 100, 29 | ): 30 | data_path = os.path.join( 31 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, dataset, f"{split}.jsonl" 32 | ) 33 | data = load_jsonl(data_path) 34 | original_question = [d["question"] for d in data] 35 | chunk_tokens = [d["chunk_tokens"] for d in data] 36 | chunk_labels = [d["labels"] for d in data] 37 | tags = [TAG_MAPPING_TWO[d["tag"]] for d in data] 38 | 39 | if test_mode: 40 | return LabelerDataset( 41 | original_question[:test_sample_cnt], 42 | chunk_tokens[:test_sample_cnt], 43 | chunk_labels[:test_sample_cnt], 44 | tags[:test_sample_cnt], 45 | max_len, 46 | tokenizer, 47 | ) 48 | return LabelerDataset( 49 | original_question, chunk_tokens, chunk_labels, tags, max_len, tokenizer 50 | ) 51 | 52 | 53 | def build_filter_dataset( 54 | dataset: str, split: str, max_len: int = 256, tokenizer=None, test_mode=False 55 | ): 56 | data_path = os.path.join( 57 | EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, dataset, f"{split}.jsonl" 58 | ) 59 | data = load_jsonl(data_path) 60 | texts = [d["query_info_tokens"] for d in data] 61 | labels = [d["query_info_labels"] for d in data] 62 | if test_mode: 63 | return FilterDataset(texts[:100], labels[:100], max_len, tokenizer=tokenizer) 64 | return FilterDataset(texts, labels, max_len, tokenizer=tokenizer) 65 | 66 | 67 | def make_dataset( 68 | type: Literal["filter", "labeler"] = "filter", 69 | dataset: Literal["musique", "2WikiMQA", "hotpotQA"] = "musique", 70 | split: str = "train", 71 | tokenizer: DebertaV2Tokenizer=None, 72 | ): 73 | build_dataset_fn = build_filter_dataset if type == "filter" else build_labeler_dataset 74 | dataset = build_dataset_fn(dataset, split, tokenizer=tokenizer) 75 | return dataset 76 | 77 | 78 | def main(): 79 | parser = argparse.ArgumentParser() 80 | parser.add_argument("--type", type=str, default="labeler") 81 | parser.add_argument("--dataset", type=str, default="musique") 82 | parser.add_argument("--split", type=str, default="train") 83 | args = parser.parse_args() 84 | 85 | tokenizer = DebertaV2Tokenizer.from_pretrained("microsoft/deberta-v3-large") 86 | dataset = make_dataset(args.type, args.dataset, args.split, tokenizer=tokenizer) 87 | positive_cnt_sum = 0 88 | negative_cnt_sum = 0 89 | token_label_key = "token_labels" if args.type == "labeler" else "labels" 90 | for data in tqdm(dataset): 91 | attention_mask = data["attention_mask"] 92 | labels = data[token_label_key] 93 | positive_cnt = torch.sum(attention_mask & labels).item() 94 | negative_cnt = torch.sum(attention_mask ^ labels).item() 95 | 
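        # attention_mask & labels counts attended tokens labeled positive; the XOR counts
        # attended tokens labeled negative, assuming padded positions always carry a False
        # label (mask 0, label 0 -> 0). These totals feed the balanced class weights printed
        # below: weight_c = (positives + negatives) / (2 * count_c), i.e. inverse-frequency
        # weighting so that the rarer token label is not drowned out during training.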
positive_cnt_sum += positive_cnt 96 | negative_cnt_sum += negative_cnt 97 | print(f"Negative Count for {args.dataset}-{args.split}: {negative_cnt_sum}") 98 | print(f"Positive Count for {args.dataset}-{args.split}: {positive_cnt_sum}") 99 | 100 | positive_weight = (positive_cnt_sum + negative_cnt_sum) / (2 * positive_cnt_sum) 101 | negative_weight = (positive_cnt_sum + negative_cnt_sum) / (2 * negative_cnt_sum) 102 | print(f"Negative Weight for {args.dataset}-{args.split}: {negative_weight}") 103 | print(f"Positive Weight for {args.dataset}-{args.split}: {positive_weight}") 104 | 105 | 106 | if __name__ == "__main__": 107 | main() 108 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/embedder.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from typing import Literal 3 | 4 | import numpy as np 5 | import torch 6 | from tqdm import TqdmExperimentalWarning, tqdm 7 | from tqdm.rich import tqdm_rich 8 | 9 | from .ada_embedding import AdaEmbedding 10 | from .contriever import Contriever 11 | from .e5 import E5BaseV2Embedding, E5LargeV2Embedding 12 | from .utils.normalize_text import normalize 13 | 14 | warnings.filterwarnings("ignore", category=TqdmExperimentalWarning) 15 | 16 | 17 | EmbeddingModelTypes = Literal[ 18 | "contriever", 19 | "e5-base-v2", 20 | "e5-large-v2", 21 | "e5-mistral-instruct", 22 | "ada-002", 23 | ] 24 | 25 | ModelTypes = { 26 | "contriever": Contriever, 27 | "e5-base-v2": E5BaseV2Embedding, 28 | "e5-large-v2": E5LargeV2Embedding, 29 | "ada-002": AdaEmbedding, 30 | } 31 | 32 | ModelCheckpointMapping = { 33 | "contriever": "model_cache/contriever-msmarco", 34 | "e5-base-v2": "model_cache/e5-base-v2", 35 | "e5-large-v2": "model_cache/e5-large-v2", 36 | "ada-002": "text-embedding-ada-002", 37 | } 38 | 39 | 40 | class Embedder(object): 41 | def __init__( 42 | self, 43 | model_type: EmbeddingModelTypes, 44 | model_name_or_path: str = None, 45 | batch_size: int = 128, 46 | chunk_size: int = int(2e6), 47 | text_lower_case: bool = False, 48 | text_normalize: bool = False, 49 | no_title: bool = False, 50 | ): 51 | if model_name_or_path is None: 52 | model_name_or_path = ModelCheckpointMapping[model_type] 53 | self.embedder = ModelTypes[model_type](model_name_or_path) 54 | self.batch_size = batch_size 55 | self.chunk_size = chunk_size 56 | 57 | self.text_lower_case = text_lower_case 58 | self.text_normalize = text_normalize 59 | self.no_title = no_title 60 | 61 | def process_text(self, line): 62 | if isinstance(line, dict): 63 | if self.no_title or "title" not in line: 64 | text = line["text"] 65 | else: 66 | text = f"{line['title']}: {line['text']}" 67 | else: 68 | text = line 69 | 70 | if self.text_lower_case: 71 | text = text.lower() 72 | if self.text_normalize: 73 | text = normalize(text) 74 | return text 75 | 76 | def get_ids(self, data): 77 | return [line["id"] for line in data] 78 | 79 | def embed_passages(self, data): 80 | ids = self.get_ids(data) 81 | texts = [self.process_text(line) for line in data] 82 | 83 | chunkBatch = (len(texts) - 1) // self.chunk_size + 1 84 | with torch.no_grad(): 85 | for idx in range(chunkBatch): 86 | print(f"Processing chunk {idx + 1}/{chunkBatch}") 87 | chunkStartIdx = idx * self.chunk_size 88 | chunkEndIdx = min((idx + 1) * self.chunk_size, len(texts)) 89 | chunk = texts[chunkStartIdx:chunkEndIdx] 90 | chunk_ids = ids[chunkStartIdx:chunkEndIdx] 91 | chunk_embeddings = self.embed(chunk, verbose=True) 92 | yield idx, (chunk_ids, 
chunk_embeddings) 93 | 94 | def embed(self, textBatch, verbose=False): 95 | embeddings = np.array([]) 96 | textBatch = [self.process_text(text) for text in textBatch] 97 | batches = (len(textBatch) - 1) // self.batch_size + 1 98 | with torch.no_grad(): 99 | if verbose: 100 | iter_range = tqdm_rich(range(batches), desc="Embedding") 101 | else: 102 | iter_range = range(batches) 103 | # for idx in tqdm_rich(range(batches), desc="Embedding"): 104 | for idx in iter_range: 105 | start_idx = idx * self.batch_size 106 | end_idx = min((idx + 1) * self.batch_size, len(textBatch)) 107 | batch = textBatch[start_idx:end_idx] 108 | curEmbeddings = self.embedder.embed_batch(batch) 109 | embeddings = np.vstack((embeddings, curEmbeddings)) if embeddings.size else curEmbeddings 110 | return embeddings 111 | 112 | def get_dim(self): 113 | return self.embedder.embedding_vector_size 114 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/e5.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch._tensor import Tensor 3 | import torch.nn.functional as F 4 | 5 | from .dense_embedding import DenseEmbedding 6 | 7 | 8 | class E5Embedding(DenseEmbedding): 9 | def __init__( 10 | self, 11 | model_name_or_path: str, 12 | embedding_vector_size: int, 13 | pooling_type: str = None, 14 | ): 15 | if pooling_type is None: 16 | pooling_type = "e5-average" 17 | super().__init__( 18 | model_name_or_path=model_name_or_path, 19 | embedding_vector_size=embedding_vector_size, 20 | pooling_type=pooling_type, 21 | ) 22 | 23 | def pooling(self, last_hidden_states, attention_mask) -> Tensor: 24 | last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0) 25 | embeddings = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 26 | embeddings = F.normalize(embeddings, p=2, dim=1) 27 | return embeddings 28 | 29 | 30 | class E5BaseV2Embedding(E5Embedding): 31 | def __init__(self, model_name_or_path: str = None): 32 | if model_name_or_path is None: 33 | model_name_or_path = "intfloat/e5-base-v2" 34 | super().__init__( 35 | model_name_or_path=model_name_or_path, 36 | embedding_vector_size=768, 37 | ) 38 | 39 | 40 | class E5LargeV2Embedding(E5Embedding): 41 | def __init__(self, model_name_or_path: str = None): 42 | if model_name_or_path is None: 43 | model_name_or_path = "intfloat/e5-large-v2" 44 | super().__init__( 45 | model_name_or_path=model_name_or_path, 46 | embedding_vector_size=1024, 47 | ) 48 | 49 | 50 | class E5MistralInstructEmbedding(E5Embedding): 51 | def __init__(self, model_name_or_path: str = None): 52 | if model_name_or_path is None: 53 | model_name_or_path = "intfloat/e5-mistral-7b-instruct" 54 | super().__init__( 55 | model_name_or_path=model_name_or_path, 56 | embedding_vector_size=4096, 57 | pooling_type="last_token_pool", 58 | ) 59 | 60 | self.template = "Instruct: {task_description}\nQuery: {query}" 61 | self.max_length = 4096 62 | 63 | def get_detailed_instruct(self, task_description: str, query: str) -> str: 64 | return self.template.format(task_description=task_description, query=query) 65 | 66 | def pooling(self, last_hidden_states, attention_mask) -> torch.Tensor: 67 | left_padding = attention_mask[:, -1].sum() == attention_mask.shape[0] 68 | if left_padding: 69 | embeddings = last_hidden_states[:, -1] 70 | else: 71 | sequence_lengths = attention_mask.sum(dim=1) - 1 72 | batch_size = last_hidden_states.shape[0] 73 | embeddings = last_hidden_states[ 74 | 
torch.arange(batch_size, device=last_hidden_states.device), 75 | sequence_lengths, 76 | ] 77 | embeddings = F.normalize(embeddings, p=2, dim=1) 78 | return embeddings 79 | 80 | def embed_batch(self, queries: list[str]): 81 | if self.model is None or self.tokenizer is None: 82 | self.instantiate() 83 | 84 | with torch.no_grad(): 85 | batch_dict = self.tokenizer( 86 | queries, 87 | max_length=self.max_length - 1, 88 | return_attention_mask=False, 89 | padding=False, 90 | truncation=True, 91 | ) 92 | batch_dict["input_ids"] = [ 93 | input_ids + [self.tokenizer.eos_token_id] for input_ids in batch_dict["input_ids"] 94 | ] 95 | batch_dict = self.tokenizer.pad( 96 | batch_dict, 97 | padding=True, 98 | return_attention_mask=True, 99 | return_tensors="pt", 100 | ) 101 | batch_dict = {k: v.cuda() for k, v in batch_dict.items()} 102 | outputs = self.model(**batch_dict) 103 | embeddings = self.pooling(outputs.last_hidden_state, batch_dict["attention_mask"]) 104 | embeddings = embeddings.cpu().numpy() 105 | return embeddings 106 | -------------------------------------------------------------------------------- /src/efficientrag_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from concurrent.futures import ThreadPoolExecutor, as_completed 4 | from typing import Dict, List 5 | 6 | from tqdm.rich import tqdm_rich 7 | 8 | from baseline.retrieve.direct import ( 9 | DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA, 10 | DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE, 11 | DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA, 12 | ) 13 | from conf import RETRIEVE_RESULT_PATH 14 | from language_models import LanguageModel, get_model 15 | from utils import ask_model, load_jsonl, write_jsonl 16 | 17 | PROMPT_MAPPING = { 18 | "hotpotQA": DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA, 19 | "2WikiMQA": DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA, 20 | "musique": DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE, 21 | } 22 | 23 | 24 | class EfficientRAG_QA: 25 | def __init__( 26 | self, 27 | model: LanguageModel, 28 | data: List[Dict], 29 | dataset: str, 30 | num_workers: int = 10, 31 | ) -> None: 32 | self.model = model 33 | self.data = data 34 | self.num_workers = num_workers 35 | self.prompt_template = PROMPT_MAPPING[dataset] 36 | 37 | def parse_samples_in_parallel(self) -> List[Dict]: 38 | with ThreadPoolExecutor(max_workers=self.num_workers) as executor: 39 | tasks = { 40 | executor.submit(self.parse_sample, sample): idx 41 | for idx, sample in enumerate(self.data) 42 | } 43 | results = [] 44 | for future in tqdm_rich(as_completed(tasks), total=len(self.data)): 45 | idx = tasks[future] 46 | try: 47 | res = future.result() 48 | results.append((idx, res)) 49 | except Exception as e: 50 | import traceback 51 | 52 | print(f"Error processing sample {idx}: {e}") 53 | traceback.print_exc() 54 | results = [r[1] for r in sorted(results, key=lambda x: x[0])] 55 | return results 56 | 57 | def extract_chunks(self, sample: Dict): 58 | chunks = [] 59 | for iter in range(4): 60 | if f"{iter}" not in sample: 61 | break 62 | for chunk in sample[f"{iter}"]["docs"]: 63 | if chunk["label"] != "": 64 | chunks.append(chunk) 65 | return chunks 66 | 67 | def parse_sample(self, sample: Dict) -> Dict: 68 | knowledge_list = self.extract_chunks(sample) 69 | chunks = "\n".join([chunk["text"] for chunk in knowledge_list]) 70 | question = sample["query"] 71 | prompt = self.prompt_template.format(question=question, knowledge=chunks) 72 | result = ask_model( 73 | self.model, 74 | prompt, 75 | type="json", 76 | 
check_if_valid=lambda x: isinstance(x, dict) and "answer" in x, 77 | mode="chat", 78 | ) 79 | chunk_ids = [p["id"] for p in knowledge_list] 80 | response = { 81 | # "question_id": sample["id"], 82 | "question": question, 83 | "answer": sample["answer"], 84 | "model_output": result["answer"], 85 | "oracle_ids": sample["oracle"], 86 | "chunk_ids": chunk_ids, 87 | } 88 | return response 89 | 90 | 91 | def parse_args(): 92 | parser = argparse.ArgumentParser() 93 | parser.add_argument("--fpath", type=str, required=True) 94 | parser.add_argument("--model", type=str, default="llama") 95 | parser.add_argument("--dataset", type=str, required=True) 96 | parser.add_argument("--workers", type=int, default=10) 97 | parser.add_argument("--suffix", type=str, default="") 98 | args = parser.parse_args() 99 | return args 100 | 101 | 102 | def main(opts: argparse.Namespace): 103 | model: LanguageModel = get_model(opts.model) 104 | dataset = load_jsonl(opts.fpath) 105 | efficient_rag_qa = EfficientRAG_QA(model, dataset, opts.dataset) 106 | results = efficient_rag_qa.parse_samples_in_parallel() 107 | save_path = os.path.join( 108 | RETRIEVE_RESULT_PATH, 109 | "efficient_rag", 110 | f"{opts.dataset}-{opts.suffix}_qa_results.jsonl", 111 | ) 112 | write_jsonl(results, save_path) 113 | 114 | 115 | if __name__ == "__main__": 116 | options = parse_args() 117 | main(options) 118 | -------------------------------------------------------------------------------- /src/retrievers/passage_retriever.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from glob import glob 5 | from typing import List, Union 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 8 | from retrievers.embeddings import ( 9 | Embedder, 10 | EmbeddingModelTypes, 11 | ModelCheckpointMapping, 12 | ModelTypes, 13 | ) 14 | from retrievers.utils.utils import load_passages 15 | from retrievers.vector_index import FaissIndex 16 | from retrievers.vector_index.faiss_index import IndexType 17 | 18 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 19 | 20 | 21 | class Retriever(object): 22 | def __init__( 23 | self, 24 | passage_path: str, 25 | passage_embedding_path: str = None, 26 | index_path_dir: str = None, 27 | model_type: EmbeddingModelTypes = "e5-base-v2", 28 | model_path: str = None, 29 | save_or_load_index: bool = True, 30 | batch_size: int = 128, 31 | embed_vector_dim: int = None, 32 | index_type: IndexType = "Flat", 33 | max_search_batch_size: int = 2048, 34 | ): 35 | if model_path is None: 36 | model_path = ModelCheckpointMapping[model_type] 37 | self.embedder = Embedder(model_type, model_path, batch_size) 38 | if embed_vector_dim is None: 39 | embed_vector_dim = self.embedder.get_dim() 40 | self.index = FaissIndex(embed_vector_dim, index_type, max_search_batch_size) 41 | if index_path_dir is None: 42 | index_path_dir = passage_embedding_path 43 | 44 | if save_or_load_index and self.index.exist_index(index_path_dir): 45 | print(f"Loading index from {index_path_dir}") 46 | self.index.deserialize(index_path_dir) 47 | else: 48 | print(f"Building index from {passage_embedding_path}") 49 | self.load_embeddings(passage_embedding_path) 50 | if save_or_load_index: 51 | print(f"Saving index to {index_path_dir}") 52 | self.index.serialize(index_path_dir) 53 | print(f"Loading passages from {passage_path}") 54 | passages = load_passages(passage_path) 55 | self.passage_map = {p["id"]: p for p in passages} 56 | print(f"Loaded {len(passages)} passages.") 57 | 
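    # load_embeddings expects the embedding directory to contain pickle shards named
    # "passage*" (presumably written by passage_embedder.py), each holding an
    # (ids, embeddings) tuple as consumed by FaissIndex.load_data.
    #
    # Minimal usage sketch (paths are illustrative, not taken from the repo):
    #   retriever = Retriever(passage_path="corpus.jsonl", passage_embedding_path="emb_dir")
    #   docs = retriever.search("Were Scott Derrickson and Ed Wood of the same nationality?", top_k=10)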
58 | def load_embeddings(self, passage_embedding_path): 59 | embedding_file = sorted(glob(f"{passage_embedding_path}/passage*")) 60 | self.index.load_data(embedding_file) 61 | 62 | def search(self, query: Union[str, List[str]], top_k: int = 10): 63 | query = [query] if isinstance(query, str) else query 64 | query_vectors = self.embedder.embed(query) 65 | top_ids_scores = self.index.search(query_vectors, top_k) 66 | # convert passage id to passage 67 | docs = [ 68 | [self.passage_map[doc_id] for doc_id in top_docs] 69 | for top_docs, top_scores in top_ids_scores 70 | ] 71 | docs = [doc_list[:top_k] for doc_list in docs] 72 | return docs 73 | 74 | 75 | def parse_args(): 76 | parser = argparse.ArgumentParser() 77 | parser.add_argument( 78 | "--passages", type=str, required=True, help="document file path" 79 | ) 80 | parser.add_argument( 81 | "--model_type", 82 | type=str, 83 | default="e5-base-v2", 84 | choices=list(ModelTypes.keys()), 85 | help="Embedding Model", 86 | ) 87 | parser.add_argument( 88 | "--model_name_or_path", 89 | type=str, 90 | default=None, 91 | help="Embedding model checkpoint", 92 | ) 93 | parser.add_argument("--save_or_load_index", action="store_true") 94 | parser.add_argument( 95 | "--embeddings", type=str, required=True, help="Document embedding path" 96 | ) 97 | parser.add_argument("--query", type=str, help="query") 98 | args = parser.parse_args() 99 | return args 100 | 101 | 102 | def test(opt: argparse.Namespace): 103 | retriever = Retriever( 104 | opt.passages, 105 | opt.embeddings, 106 | model_type=opt.model_type, 107 | model_path=opt.model_name_or_path, 108 | ) 109 | if opt.query is None: 110 | queries = [ 111 | "Were Scott Derrickson and Ed Wood of the same nationality?", 112 | "What is the difference between llama and alpaca?", 113 | ] 114 | else: 115 | queries = [opt.query] 116 | docs = retriever.search(queries, 20) 117 | print(docs) 118 | 119 | 120 | if __name__ == "__main__": 121 | options = parse_args() 122 | test(options) 123 | -------------------------------------------------------------------------------- /src/efficient_rag/filter_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 6 | from datetime import datetime 7 | 8 | import torch 9 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score 10 | from transformers import ( 11 | DebertaV2ForTokenClassification, 12 | DebertaV2Tokenizer, 13 | EvalPrediction, 14 | Trainer, 15 | TrainingArguments, 16 | ) 17 | 18 | from conf import EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, MODEL_PATH 19 | from efficient_rag.data import FilterDataset 20 | from utils import load_jsonl 21 | 22 | os.environ["WANDB_PROJECT"] = "EfficientRAG_filter" 23 | 24 | 25 | def eval_filter(pred: EvalPrediction): 26 | preds = torch.tensor(pred.predictions.argmax(-1)) 27 | labels = torch.tensor(pred.label_ids) 28 | mask = torch.tensor(pred.inputs != 0) 29 | 30 | preds = torch.masked_select(preds, mask) 31 | labels = torch.masked_select(labels, mask) 32 | 33 | filter_f1 = f1_score(labels, preds, average=None) #noqa 34 | 35 | result = { 36 | "accuracy": accuracy_score(labels, preds), 37 | "recall": recall_score(labels, preds, average="micro"), 38 | "precision": precision_score(labels, preds, average="micro"), 39 | "f1": f1_score(labels, preds, average="micro"), 40 | "f1_marco": f1_score(labels, preds, average="macro"), 41 | "negative_f1": filter_f1[0], 42 | 
"positive_f1": filter_f1[1], 43 | } 44 | return result 45 | 46 | 47 | def parse_args(): 48 | parser = argparse.ArgumentParser(description="EfficientRAG Query Filter") 49 | parser.add_argument("--dataset", required=True, type=str) 50 | parser.add_argument("--save_path", type=str, default="saved_models/filter") 51 | parser.add_argument("--lr", help="learning rate", default=1e-5, type=float) 52 | parser.add_argument("--epoch", default=2, type=int) 53 | parser.add_argument("--batch_size", type=int, default=32) 54 | parser.add_argument("--max_length", type=int, default=256) 55 | parser.add_argument("--warmup_steps", type=int, default=200) 56 | parser.add_argument("--eval_steps", type=int, default=100) 57 | parser.add_argument("--logging_steps", type=int, default=10) 58 | parser.add_argument("--test", action="store_true") 59 | args = parser.parse_args() 60 | return args 61 | 62 | 63 | def build_dataset(dataset: str, split: str, max_len: int = 128, tokenizer=None, test_mode=False): 64 | data_path = os.path.join(EFFICIENT_RAG_FILTER_TRAINING_DATA_PATH, dataset, f"{split}.jsonl") 65 | data = load_jsonl(data_path) 66 | texts = [d["query_info_tokens"] for d in data] 67 | labels = [d["query_info_labels"] for d in data] 68 | if test_mode: 69 | return FilterDataset(texts[:100], labels[:100], max_len, tokenizer=tokenizer) 70 | return FilterDataset(texts, labels, max_len, tokenizer=tokenizer) 71 | 72 | 73 | def main(opt: argparse.Namespace): 74 | tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_PATH) 75 | model = DebertaV2ForTokenClassification.from_pretrained(MODEL_PATH, num_labels=2) 76 | save_dir = os.path.join(opt.save_path, f"filter_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}") 77 | run_name = f"{opt.dataset}-{datetime.now().strftime(r'%m%d%H%M')}" 78 | train_dataset = build_dataset(opt.dataset, "train", opt.max_length, tokenizer, test_mode=opt.test) 79 | valid_dataset = build_dataset(opt.dataset, "valid", opt.max_length, tokenizer, test_mode=opt.test) 80 | 81 | training_args = TrainingArguments( 82 | output_dir=save_dir, 83 | num_train_epochs=opt.epoch, 84 | learning_rate=opt.lr, 85 | per_device_train_batch_size=opt.batch_size, 86 | per_device_eval_batch_size=64, 87 | weight_decay=0.01, 88 | logging_dir=os.path.join(save_dir, "log"), 89 | save_strategy="epoch", 90 | evaluation_strategy="steps", 91 | eval_steps=opt.eval_steps, 92 | report_to="wandb", 93 | run_name=run_name, 94 | logging_steps=opt.logging_steps, 95 | warmup_steps=opt.warmup_steps, 96 | save_only_model=True, 97 | include_inputs_for_metrics=True, 98 | ) 99 | trainer = Trainer( 100 | model=model, 101 | args=training_args, 102 | train_dataset=train_dataset, 103 | eval_dataset=valid_dataset, 104 | tokenizer=tokenizer, 105 | compute_metrics=eval_filter, 106 | ) 107 | trainer.train() 108 | 109 | 110 | if __name__ == "__main__": 111 | options = parse_args() 112 | if options.test: 113 | os.environ["WANDB_MODE"] = "dryrun" 114 | main(options) 115 | -------------------------------------------------------------------------------- /src/data_synthesize/graph_driven.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | from typing import Literal 6 | 7 | from tqdm import tqdm 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | 11 | from language_models import AOAI 12 | 13 | QUERY_TYPE = Literal["extremely long-tail", "long tail", "common"] 14 | # QUERY_LENGTH = Literal['less than 5 words', '5 to 15 words', 'at least 10 words'] 15 
| DIFFICULTY = Literal["children", "high school", "college", "PhD"] 16 | CLARITY = Literal["clear", "understandable with some effort", "ambiguous"] 17 | DOC_WORDS = Literal["50", "100", "200", "300", "400", "500"] 18 | HOP = Literal["single-hop", "2-hop", "3-hop", "4-hop"] 19 | HOP_TYPE = Literal["combination", "extension"] 20 | 21 | # Topology generation system prompt 22 | # TODO: the examples can be seed examples from KGQA cases 23 | 24 | TOPOLOGY_GENERATION_SYSTEM_PROMPT = """Brainstorm a list of tuples of query, multi-hop reasoning path, and the answer. The reasoning path consists of a series of hops, and one hop is a relation between two entities, represented as ["entity", "relation", "entity"]. 25 | 26 | There are several basic types of reasoning paths, here are a few basic examples for your reference: 27 | - Composition: 28 | "user_query": "which country is the birthplace of the father Walter Chetwync?", 29 | "reasoning_path": ["Walter Chetwynd, 1st Viscount Chetwynd", "father", "John Chetwynd"], ["John Chetwynd", "country of citizenship", "United Kingdom of Great Britain and Ireland"] 30 | "answer": "United Kingdom of Great Britain and Ireland" 31 | 32 | - Comparison: 33 | "user_query": "which movie has a earlier publication date, Moonstruck or One step to Eternity?", 34 | "reasoning_path": ["Moonstruck", "publication date", "1987"], ["One step to Eternity", "publication date", "1954"] 35 | "answer": "One step to Eternity" 36 | 37 | - Intersection: 38 | "user_query": "Who is the only person to win an olympic medal and a Nobel prize?", 39 | "reasoning_path": ["Philip John Noel-Baker", "won", "An olympic medal"], ["Philip John Noel-Baker", "award received", "Nobel Prize in Peace"] 40 | "answer": "Philip John Noel-Baker" 41 | 42 | There can be more types of multi-hop reasoning paths that could be combination of the above types, or extension of the above types. 43 | 44 | Here are a few complex examples for your reference: 45 | - Extension of Composition: 46 | "user_query": "What influenced the architectural style of the Hagia Sophia?", 47 | "reasoning_path": ["Hagia Sophia", "construction period", "6th century"], ["6th century", "Byzantine Empire", "influence"], ["Byzantine Empire", "Roman architecture", "influence"] 48 | "answer": "Roman architecture" 49 | 50 | - Combination of Composition and Comparison: 51 | "user_query": "The director of The Millionaire's Double and the director of The Groundstar Conspiracy, who was born later?", 52 | "reasoning_path": ["The Millionaire's Double", "director", "Harry Davenport"], ["The Groundstar Conspiracy", "director", "Lamont Johnson"], ["Harry Davenport", "date of birth", "1866"], ["Lamont Johnson", "date of birth", "1922"] 53 | "answer": "Lamont Johnson" 54 | 55 | The tuple is a JSON object which must contain the following keys: 56 | - "user_query": a string, a multi-hop query. 57 | - "reasoning_chain": a string, a reasoning chain that describes the information and retrieve process of the user query. 58 | - "answer": a string, the answer to the user query. 59 | 60 | Please adhere to the following guidelines: 61 | - The "user_query" should be {query_type},and diverse in topic. 62 | - The "reasoning_chain" should contain {hop} hops. 63 | - Both the query and reasoning path should be in English. 64 | - Both the query and reasoning path require {difficulty} level education to understand. 65 | - the reasoning path should be {hop_type} of basic types of reasoning paths. 
66 | 67 | Your output must always be a python list of JSON object only, with about 5 elements. Do not explain yourself or output anything else. Be creative!""" 68 | 69 | 70 | class QueryChunkSynthesizer(object): 71 | def __init__( 72 | self, 73 | model: str = "gpt-4-1106-preview", 74 | ): 75 | if "gpt" in model: 76 | self.model = AOAI(model, api_version="2023-12-01-preview") 77 | else: 78 | raise ValueError("Model not supported") 79 | self.max_retry = 3 80 | 81 | def query_reasoning_path_generation( 82 | self, 83 | query_type: QUERY_TYPE, 84 | difficulty: DIFFICULTY, 85 | hop: HOP, 86 | hop_type: HOP_TYPE, 87 | ): 88 | trials = 0 89 | query = TOPOLOGY_GENERATION_SYSTEM_PROMPT.format( 90 | query_type=query_type, difficulty=difficulty, hop=hop, hop_type=hop_type 91 | ) 92 | while trials < self.max_retry: 93 | try: 94 | query_reasoning_path = self.model.chat(query) 95 | print(query_reasoning_path) 96 | matches = re.findall(r"```python(.*?)```", query_reasoning_path, re.DOTALL) 97 | if matches: 98 | query_reasoning_path = matches[0] 99 | query_reasoning_path = query_reasoning_path.strip("\n") 100 | data = json.loads(query_reasoning_path) 101 | return data 102 | except Exception as e: 103 | trials += 1 104 | return None 105 | 106 | 107 | def main(): 108 | synthesizer = QueryChunkSynthesizer() 109 | query_reasoning_path = synthesizer.query_reasoning_path_generation("common", "PhD", "5-hop", "combination") 110 | 111 | 112 | if __name__ == "__main__": 113 | main() 114 | -------------------------------------------------------------------------------- /src/retrievers/embeddings/utils/normalize_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from chemdataextractor.text.normalize 3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | Tools for normalizing text. 5 | https://github.com/mcs07/ChemDataExtractor 6 | :copyright: Copyright 2016 by Matt Swain. 7 | :license: MIT 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining 10 | a copy of this software and associated documentation files (the 11 | 'Software'), to deal in the Software without restriction, including 12 | without limitation the rights to use, copy, modify, merge, publish, 13 | distribute, sublicense, and/or sell copies of the Software, and to 14 | permit persons to whom the Software is furnished to do so, subject to 15 | the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be 18 | included in all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | """ 28 | 29 | #: Control characters. 
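# normalize() below removes these control characters outright (replaced with the empty string).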
30 | CONTROLS = { 31 | "\u0001", 32 | "\u0002", 33 | "\u0003", 34 | "\u0004", 35 | "\u0005", 36 | "\u0006", 37 | "\u0007", 38 | "\u0008", 39 | "\u000e", 40 | "\u000f", 41 | "\u0011", 42 | "\u0012", 43 | "\u0013", 44 | "\u0014", 45 | "\u0015", 46 | "\u0016", 47 | "\u0017", 48 | "\u0018", 49 | "\u0019", 50 | "\u001a", 51 | "\u001b", 52 | } 53 | # There are further control characters, but they are instead replaced with a space by unicode normalization 54 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 55 | 56 | 57 | #: Hyphen and dash characters. 58 | HYPHENS = { 59 | "-", # \u002d Hyphen-minus 60 | "‐", # \u2010 Hyphen 61 | "‑", # \u2011 Non-breaking hyphen 62 | "⁃", # \u2043 Hyphen bullet 63 | "‒", # \u2012 figure dash 64 | "–", # \u2013 en dash 65 | "—", # \u2014 em dash 66 | "―", # \u2015 horizontal bar 67 | } 68 | 69 | #: Minus characters. 70 | MINUSES = { 71 | "-", # \u002d Hyphen-minus 72 | "−", # \u2212 Minus 73 | "-", # \uff0d Full-width Hyphen-minus 74 | "⁻", # \u207b Superscript minus 75 | } 76 | 77 | #: Plus characters. 78 | PLUSES = { 79 | "+", # \u002b Plus 80 | "+", # \uff0b Full-width Plus 81 | "⁺", # \u207a Superscript plus 82 | } 83 | 84 | #: Slash characters. 85 | SLASHES = { 86 | "/", # \u002f Solidus 87 | "⁄", # \u2044 Fraction slash 88 | "∕", # \u2215 Division slash 89 | } 90 | 91 | #: Tilde characters. 92 | TILDES = { 93 | "~", # \u007e Tilde 94 | "˜", # \u02dc Small tilde 95 | "⁓", # \u2053 Swung dash 96 | "∼", # \u223c Tilde operator #in mbert vocab 97 | "∽", # \u223d Reversed tilde 98 | "∿", # \u223f Sine wave 99 | "〜", # \u301c Wave dash #in mbert vocab 100 | "~", # \uff5e Full-width tilde #in mbert vocab 101 | } 102 | 103 | #: Apostrophe characters. 104 | APOSTROPHES = { 105 | "'", # \u0027 106 | "’", # \u2019 107 | "՚", # \u055a 108 | "Ꞌ", # \ua78b 109 | "ꞌ", # \ua78c 110 | "'", # \uff07 111 | } 112 | 113 | #: Single quote characters. 114 | SINGLE_QUOTES = { 115 | "'", # \u0027 116 | "‘", # \u2018 117 | "’", # \u2019 118 | "‚", # \u201a 119 | "‛", # \u201b 120 | } 121 | 122 | #: Double quote characters. 123 | DOUBLE_QUOTES = { 124 | '"', # \u0022 125 | "“", # \u201c 126 | "”", # \u201d 127 | "„", # \u201e 128 | "‟", # \u201f 129 | } 130 | 131 | #: Accent characters. 132 | ACCENTS = { 133 | "`", # \u0060 134 | "´", # \u00b4 135 | } 136 | 137 | #: Prime characters. 138 | PRIMES = { 139 | "′", # \u2032 140 | "″", # \u2033 141 | "‴", # \u2034 142 | "‵", # \u2035 143 | "‶", # \u2036 144 | "‷", # \u2037 145 | "⁗", # \u2057 146 | } 147 | 148 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 
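# normalize() below handles each quote-like set individually; this combined set appears
# unused within this module and is presumably kept for external callers.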
149 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 150 | 151 | 152 | def normalize(text): 153 | for control in CONTROLS: 154 | text = text.replace(control, "") 155 | text = text.replace("\u000b", " ").replace("\u000c", " ").replace("\u0085", " ") 156 | 157 | for hyphen in HYPHENS | MINUSES: 158 | text = text.replace(hyphen, "-") 159 | text = text.replace("\u00ad", "") 160 | 161 | for double_quote in DOUBLE_QUOTES: 162 | text = text.replace(double_quote, '"') # \u0022 163 | for single_quote in SINGLE_QUOTES | APOSTROPHES | ACCENTS: 164 | text = text.replace(single_quote, "'") # \u0027 165 | text = text.replace("′", "'") # \u2032 prime 166 | text = text.replace("‵", "'") # \u2035 reversed prime 167 | text = text.replace("″", "''") # \u2033 double prime 168 | text = text.replace("‶", "''") # \u2036 reversed double prime 169 | text = text.replace("‴", "'''") # \u2034 triple prime 170 | text = text.replace("‷", "'''") # \u2037 reversed triple prime 171 | text = text.replace("⁗", "''''") # \u2057 quadruple prime 172 | 173 | text = text.replace("…", "...").replace(" . . . ", " ... ") # \u2026 174 | 175 | for slash in SLASHES: 176 | text = text.replace(slash, "/") 177 | 178 | # for tilde in TILDES: 179 | # text = text.replace(tilde, '~') 180 | 181 | return text 182 | -------------------------------------------------------------------------------- /src/baseline/retrieve/decompose.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from typing import List 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 8 | 9 | import time 10 | 11 | from direct import ( 12 | DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA, 13 | DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE, 14 | DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA, 15 | ) 16 | from tqdm.rich import tqdm_rich 17 | 18 | from conf import ( 19 | CORPUS_DATA_PATH, 20 | RETRIEVE_RESULT_PATH, 21 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, 22 | ) 23 | from empirical_retrieve import QUERY_DECOMPOSE_PROMPT 24 | from language_models import LanguageModel, get_model 25 | from retrievers import Retriever 26 | from utils import ask_model, load_jsonl, write_jsonl 27 | 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--dataset", type=str, required=True) 32 | parser.add_argument("--retriever", type=str, required=True) 33 | parser.add_argument("--topk", type=int, default=10) 34 | parser.add_argument("--model", type=str, default="llama-8B") 35 | parser.add_argument("--workers", type=int, default=10) 36 | return parser.parse_args() 37 | 38 | 39 | def llm_decompose(model, question): 40 | prompt = QUERY_DECOMPOSE_PROMPT.format(question=question) 41 | decomposition = ask_model( 42 | model, 43 | prompt, 44 | type="json", 45 | check_if_valid=lambda x: type(x) is dict and "decomposed_questions" in x, 46 | ) 47 | queries = decomposition["decomposed_questions"] 48 | return queries 49 | 50 | 51 | def process_sample( 52 | model: LanguageModel, 53 | retriever: Retriever, 54 | prompt_template: str, 55 | sample: dict, 56 | topk: int = 10, 57 | ) -> dict: 58 | question = sample["question"] 59 | decomposed_questions = llm_decompose(model, question) 60 | # TODO: noqa, check if it works!! 
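    # retriever.search returns one top-k list per decomposed sub-question; sum(..., [])
    # flattens them into a single candidate list before the id-based deduplication below.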
61 | knowledges = sum(retriever.search(decomposed_questions, top_k=topk), []) 62 | 63 | # deduplicate knowledge 64 | seen_ids = set() 65 | knowledge = [] 66 | for k in knowledges: 67 | if k["id"] not in seen_ids: 68 | knowledge.append(k) 69 | seen_ids.add(k["id"]) 70 | 71 | # knowledge = retriever.search(question, top_k=topk)[0] 72 | knowledge_chunk = "\n".join([p["text"] for p in knowledge]) 73 | chunk_ids = [p["id"] for p in knowledge] 74 | prompt = prompt_template.format(knowledge=knowledge_chunk, question=question) 75 | result = ask_model( 76 | model, 77 | prompt, 78 | type="json", 79 | check_if_valid=lambda x: type(x) is dict and "answer" in x, 80 | mode="chat", 81 | ) 82 | sub_ids = sorted(list(sample["decomposed_questions"].keys())) 83 | oracle_ids = [ 84 | f"{sample['id']}-{'{:02d}'.format(sample['decomposed_questions'][sub_id]['positive_paragraph_idx'])}" 85 | for sub_id in sub_ids 86 | ] 87 | result = { 88 | "question_id": sample["id"], 89 | "question": question, 90 | "answer": sample["answer"], 91 | "oracle_ids": oracle_ids, 92 | "chunk_ids": chunk_ids, 93 | "model_output": result["answer"], 94 | } 95 | return result 96 | 97 | 98 | def process_dataset( 99 | model: LanguageModel, 100 | retriever: Retriever, 101 | prompt_template: str, 102 | topk: int, 103 | dataset: List[dict], 104 | max_workers: int = 10, 105 | ): 106 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 107 | tasks = { 108 | executor.submit( 109 | process_sample, model, retriever, prompt_template, sample, topk 110 | ): idx 111 | for idx, sample in enumerate(dataset) 112 | } 113 | results = [] 114 | for future in tqdm_rich(as_completed(tasks), total=len(dataset)): 115 | idx = tasks[future] 116 | try: 117 | res = future.result() 118 | results.append((idx, res)) 119 | except Exception as e: 120 | print(f"Error processing sample {idx}: {e}") 121 | results = [r[1] for r in sorted(results, key=lambda x: x[0])] 122 | return results 123 | 124 | 125 | def main(opt: argparse.Namespace): 126 | start = time.time() 127 | passage_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "corpus.jsonl") 128 | if opt.retriever == "e5-base-v2": 129 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "e5-base") 130 | elif opt.retriever == "contriever": 131 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "contriever") 132 | retriever = Retriever( 133 | passage_path=passage_path, 134 | passage_embedding_path=embedding_path, 135 | index_path_dir=embedding_path, 136 | model_type=opt.retriever, 137 | ) 138 | dataset = load_jsonl( 139 | os.path.join( 140 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "valid.jsonl" 141 | ) 142 | ) 143 | 144 | prompt_template_mapping = { 145 | "musique": DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE, 146 | "2WikiMQA": DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA, 147 | "hotpotQA": DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA, 148 | } 149 | prompt_template = prompt_template_mapping[opt.dataset] 150 | model: LanguageModel 151 | model = get_model(opt.model) 152 | results = process_dataset( 153 | model, retriever, prompt_template, opt.topk, dataset, max_workers=opt.workers 154 | ) 155 | 156 | end = time.time() 157 | print(f"Retrieval time: {end - start:.2f}s") 158 | print(f"Average retrieval time: {(end - start) / len(dataset):.2f}s") 159 | output_dir = os.path.join(RETRIEVE_RESULT_PATH, "decomposed") 160 | os.makedirs(output_dir, exist_ok=True) 161 | output_path = os.path.join(output_dir, f"{opt.dataset}-@{opt.topk}.jsonl") 162 | write_jsonl(results, output_path) 163 | 164 | 165 | if 
__name__ == "__main__": 166 | options = parse_args() 167 | main(options) 168 | -------------------------------------------------------------------------------- /src/data_synthesize/chunk_sampling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from typing import Union 5 | 6 | import numpy as np 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 9 | from conf import ( 10 | CORPUS_DATA_PATH, 11 | EMBEDDING_ALIAS, 12 | MODEL_DICT, 13 | SYNTHESIZED_TOKEN_LABELING_DATA_PATH, 14 | ) 15 | from language_models import AOAI 16 | from retrievers import Retriever 17 | from retrievers.embeddings import ModelTypes 18 | from utils import load_jsonl 19 | 20 | 21 | class ChunkSampler: 22 | def __init__( 23 | self, 24 | retriever: Retriever, 25 | ) -> None: 26 | self.retriever = retriever 27 | 28 | def sample(self, query: Union[str, list[str]], top_k: int = 10) -> list[list[dict]]: 29 | def parse_chunk(c): 30 | return { 31 | "id": set( 32 | [ 33 | (((pair := cid.split("-"))[0], int(pair[1])) if "-" in cid else (cid, "-1")) 34 | for cid in c["id"].split("//") 35 | ] 36 | ), 37 | "text": c["text"], 38 | } 39 | 40 | results = self.retriever.search(query, top_k) 41 | results = [[parse_chunk(c) for c in chunk] for chunk in results] 42 | return results 43 | 44 | 45 | def sample_origin_question(sampler: ChunkSampler, dataset: list[dict], top_k: int = 10): 46 | questions = [] 47 | oracles = [] 48 | for data in dataset: 49 | questions.append(data["question"]) 50 | oracles.append( 51 | set([(data["id"], chunk["positive_paragraph_idx"]) for k, chunk in data["decomposed_questions"].items()]) 52 | ) 53 | samples = sampler.sample(questions, top_k) 54 | coverages = eval(samples, oracles, questions) 55 | return np.round(np.mean(coverages), 4) 56 | 57 | 58 | def sample_sub_question(sampler: ChunkSampler, dataset: list[dict], top_k: int = 10): 59 | questions = [] 60 | oracles = [] 61 | for data in dataset: 62 | for k, chunk in data["decomposed_questions"].items(): 63 | questions.append(chunk["sub_question"]) 64 | oracles.append(set([(data["id"], chunk["positive_paragraph_idx"])])) 65 | samples = sampler.sample(questions, top_k) 66 | coverages = eval(samples, oracles, questions) 67 | return np.round(np.mean(coverages), 4) 68 | 69 | 70 | def sample_labeled_words(sampler: ChunkSampler, dataset: list[dict], top_k: int = 10): 71 | questions = [] 72 | oracles = [] 73 | for data in dataset: 74 | for k, chunk in data["decomposed_questions"].items(): 75 | labeled_words = " ".join(chunk["labeled_words"]) 76 | questions.append(labeled_words) 77 | oracles.append(set([(data["id"], chunk["positive_paragraph_idx"])])) 78 | samples = sampler.sample(questions, top_k) 79 | coverages = eval(samples, oracles, questions) 80 | return np.round(np.mean(coverages), 4) 81 | 82 | 83 | def sample_symmetric_diff(sampler: ChunkSampler, dataset: list[dict], top_k: int = 10): 84 | def construct_symmetric_diff_query(sample: dict) -> tuple[list[str], list[str]]: 85 | subq, subo = [], [] 86 | for sub_qid, chunk in data["decomposed_questions"].items(): 87 | # TODO: implement symmetric diff query 88 | subq.append(chunk["sub_question"]) 89 | subo.append(set([(data["id"], chunk["positive_paragraph_idx"])])) 90 | return subq, subo 91 | 92 | questions, oracles = [], [] 93 | for data in dataset: 94 | subq, subo = construct_symmetric_diff_query(data) 95 | questions.extend(subq) 96 | oracles.extend(subo) 97 | samples = sampler.sample(questions, top_k) 98 | 
coverages = eval(samples, oracles, questions) 99 | return np.round(np.mean(coverages), 4) 100 | 101 | 102 | def eval(samples: list[list[dict]], oracles: list[set], questions: list[str]): 103 | coverages = [] 104 | for oracle, sample in zip(oracles, samples): 105 | chunks = set() 106 | for chunk in sample: 107 | chunks.update(chunk["id"]) 108 | coverages.append(len(chunks.intersection(oracle)) / len(oracle)) 109 | return coverages 110 | 111 | 112 | def parse_args(): 113 | parser = argparse.ArgumentParser() 114 | parser.add_argument( 115 | "--dataset", 116 | type=str, 117 | choices=["hotpotQA", "musique", "2WikiMQA"], 118 | default="musique", 119 | ) 120 | parser.add_argument("--split", type=str, choices=["train", "valid", "test", "demo"], default="valid") 121 | parser.add_argument("--embedder", type=str, choices=list(ModelTypes.keys()), help="Embedding Model") 122 | parser.add_argument("--topk", type=int, default=5, help="Top k retrieval") 123 | parser.add_argument("--model", choices=["gpt35", "gpt4"], default="gpt4") 124 | args = parser.parse_args() 125 | return args 126 | 127 | 128 | def build_sampler(dataset: str, embedder: str): 129 | corpus_path = os.path.join(CORPUS_DATA_PATH, dataset) 130 | passages = os.path.join(corpus_path, "corpus.jsonl") 131 | embedding_path = os.path.join(corpus_path, EMBEDDING_ALIAS[embedder]) 132 | retriever = Retriever( 133 | passage_path=passages, 134 | passage_embedding_path=embedding_path, 135 | model_type=embedder, 136 | save_or_load_index=True, 137 | ) 138 | chunk_sampler = ChunkSampler(retriever) 139 | return chunk_sampler 140 | 141 | 142 | def main(opt: argparse.Namespace): 143 | dataset = load_jsonl(os.path.join(SYNTHESIZED_TOKEN_LABELING_DATA_PATH, opt.dataset, f"{opt.split}.jsonl")) 144 | chunk_sampler = build_sampler(opt.dataset, opt.embedder) 145 | for topk in (1, 3, 5, 10, 50, 100, 500, 1000, 2000): 146 | mean_coverage = sample_origin_question(chunk_sampler, dataset, topk) 147 | mean_coverage = sample_sub_question(chunk_sampler, dataset, topk) 148 | mean_coverage = sample_labeled_words(chunk_sampler, dataset, topk) 149 | 150 | print(f"{opt.dataset}-{opt.split}, embedder {opt.embedder}, top-{topk} mean coverage: {mean_coverage}") 151 | 152 | 153 | if __name__ == "__main__": 154 | options = parse_args() 155 | main(options) 156 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # EfficientRAG-official 2 | 3 |
4 | 5 |
6 | 7 | Code repo for EMNLP 2024 paper - **EfficientRAG: Efficient Retriever for Multi-Hop Question Answering** 8 | 9 | EfficientRAG is a framework that trains a Labeler and a Filter to conduct multi-hop RAG without multiple LLM calls. 10 | 11 | ## Updates 12 | 13 | * 2024-09-12 open-sourced the code 14 | * 2025-03-04 released our data 15 | 16 | ## Setup 17 | 18 | You can now download our synthesized data from this [link](https://box.nju.edu.cn/f/a86b512077c7489b8da3/). 19 | 20 | Unzip the `EfficientRAG.zip` file and place all the data under the `data` directory. 21 | Within this directory, the `negative_sampling_extracted` folder contains our final synthesized data, which is referenced in [2.4 Negative Sampling](https://github.com/NIL-zhuang/EfficientRAG-official?tab=readme-ov-file#24-negative-sampling). 22 | Additionally, the `efficient_rag` directory includes two folders, `labeler` and `filter`, which store the training data constructed for the model, as referenced in [2.5 Training Data](https://github.com/NIL-zhuang/EfficientRAG-official?tab=readme-ov-file#25-training-data). 23 | 24 | ### 1. Installation 25 | 26 | Install PyTorch >= 2.1.0 first, and then install the dependent Python libraries by running 27 | 28 | ```bash 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | You can also create a conda environment with Python >= 3.9: 33 | 34 | ```bash 35 | conda create -n <env_name> python=3.9 pip 36 | conda activate <env_name> 37 | pip install -r requirements.txt 38 | ``` 39 | 40 | ### Preparation 41 | 42 | 1. Download the datasets from [HotpotQA](https://huggingface.co/datasets/hotpotqa/hotpot_qa), [2WikiMQA](https://github.com/Alab-NII/2wikimultihop) and [MuSiQue](https://huggingface.co/datasets/dgslibisey/MuSiQue). Separate them into train, dev, and test sets, and then put them under `data/dataset`. 43 | 44 | 2. Download the retriever model [Contriever](https://huggingface.co/facebook/contriever-msmarco) and the base model [DeBERTa](https://huggingface.co/microsoft/deberta-v3-large), and put them under `model_cache`. 45 | 46 | 3. Prepare the corpus by extracting documents and constructing embeddings. 47 | 48 | ```bash 49 | python src/retrievers/multihop_data_extractor.py --dataset hotpotQA 50 | ``` 51 | 52 | ```bash 53 | python src/retrievers/passage_embedder.py \ 54 | --passages data/corpus/hotpotQA/corpus.jsonl \ 55 | --output_dir data/corpus/hotpotQA/contriever \ 56 | --model_type contriever 57 | ``` 58 | 59 | 4. Deploy [LLaMA-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) with the [vLLM](https://github.com/vllm-project/vllm) framework, and configure the endpoint in `src/language_models/llama.py` (see the example launch command below). 60 | 61 | ### 2. Training Data Construction 62 | 63 | We will use the hotpotQA training set as an example. You can construct the 2WikiMQA and MuSiQue data in the same way.
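Step 4 of the Preparation above assumes a running LLaMA-3-70B-Instruct endpoint. A minimal vLLM launch command is sketched below as an illustration only; the tensor-parallel size and port are placeholders for your own hardware, and whatever endpoint you expose still needs to be configured in `src/language_models/llama.py`.

```bash
# Example only: serve LLaMA-3-70B-Instruct through vLLM's OpenAI-compatible API server.
# Adjust --tensor-parallel-size and --port to match your GPUs and configuration.
python -m vllm.entrypoints.openai.api_server \
    --model meta-llama/Meta-Llama-3-70B-Instruct \
    --tensor-parallel-size 4 \
    --port 8000
```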
64 | 65 | #### 2.1 Query Decompose 66 | 67 | ```bash 68 | python src/data_synthesize/query_decompose.py \ 69 | --dataset hotpotQA \ 70 | --split train \ 71 | --model llama3 72 | ``` 73 | 74 | #### 2.2 Token Labeling 75 | 76 | ```bash 77 | python src/data_synthesize/token_labeling.py \ 78 | --dataset hotpotQA \ 79 | --split train \ 80 | --model llama3 81 | ``` 82 | 83 | ```bash 84 | python src/data_synthesize/token_extraction.py \ 85 | --data_path data/synthesized_token_labeling/hotpotQA/train.jsonl \ 86 | --save_path data/token_extracted/hotpotQA/train.jsonl \ 87 | --verbose 88 | ``` 89 | 90 | #### 2.3 Next Query Filtering 91 | 92 | ```bash 93 | python src/data_synthesize/next_hop_query_construction.py \ 94 | --dataset hotpotQA \ 95 | --split train \ 96 | --model llama 97 | ``` 98 | 99 | ```bash 100 | python src/data_synthesize/next_hop_query_filtering.py \ 101 | --data_path data/synthesized_next_query/hotpotQA/train.jsonl \ 102 | --save_path data/next_query_extracted/hotpotQA/train.jsonl \ 103 | --verbose 104 | ``` 105 | 106 | #### 2.4 Negative Sampling 107 | 108 | ```bash 109 | python src/data_synthesize/negative_sampling.py \ 110 | --dataset hotpotQA \ 111 | --split train \ 112 | --retriever contriever 113 | ``` 114 | 115 | ```bash 116 | python src/data_synthesize/negative_sampling_labeled.py \ 117 | --dataset hotpotQA \ 118 | --split train \ 119 | --model llama 120 | ``` 121 | 122 | ```bash 123 | python src/data_synthesize/negative_token_extraction.py \ 124 | --dataset hotpotQA \ 125 | --split train \ 126 | --verbose 127 | ``` 128 | 129 | #### 2.5 Training Data 130 | 131 | ```bash 132 | python src/data_synthesize/training_data_synthesize.py \ 133 | --dataset hotpotQA \ 134 | --split train 135 | ``` 136 | 137 | ## Training 138 | 139 | Train the Filter model: 140 | 141 | ```bash 142 | python src/efficient_rag/filter_training.py \ 143 | --dataset hotpotQA \ 144 | --save_path saved_models/filter 145 | ``` 146 | 147 | Train the Labeler model: 148 | 149 | ```bash 150 | python src/efficient_rag/labeler_training.py \ 151 | --dataset hotpotQA \ 152 | --tags 2 153 | ``` 154 | 155 | ## Inference 156 | 157 | Run the EfficientRAG retrieval procedure: 158 | 159 | ```bash 160 | python src/efficientrag_retrieve.py \ 161 | --dataset hotpotQA \ 162 | --retriever contriever \ 163 | --labels 2 \ 164 | --labeler_ckpt <> \ 165 | --filter_ckpt <> \ 166 | --topk 10 167 | ``` 168 | 169 | Use LLaMA-3-8B-Instruct as the generator: 170 | ```bash 171 | python src/efficientrag_qa.py \ 172 | --fpath <> \ 173 | --model llama-8B \ 174 | --dataset hotpotQA 175 | ``` 176 | 177 | ## Citation 178 | 179 | If you find this paper or code useful, please cite: 180 | 181 | ```txt 182 | @inproceedings{zhuang2024efficientrag, 183 | title={EfficientRAG: Efficient Retriever for Multi-Hop Question Answering}, 184 | author={Zhuang, Ziyuan and Zhang, Zhiyang and Cheng, Sitao and Yang, Fangkai and Liu, Jia and Huang, Shujian and Lin, Qingwei and Rajmohan, Saravan and Zhang, Dongmei and Zhang, Qi}, 185 | booktitle={Proceedings of the 2024 Conference on Empirical Methods in Natural Language Processing}, 186 | pages={3392--3411}, 187 | year={2024} 188 | } 189 | ``` 190 | -------------------------------------------------------------------------------- /src/data_synthesize/gpt_query_chunk_synethesize.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import sys 5 | from typing import Literal 6 | 7 | from tqdm import tqdm 8 | 9 |
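# This script sketches GPT-driven synthesis of multi-hop query/chunk training data: it first
# brainstorms Wikipedia QnA retrieval tasks (TASK_GENERATION_SYSTEM_PROMPT), then asks the model
# to emit, for each task, a JSON example containing a user query, a reasoning chain, and
# positive / hard-negative document pairs (QUERY_CHUNK_SYSTEM_PROMPT).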
sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | 11 | from language_models import AOAI 12 | 13 | QUERY_TYPE = Literal["extremely long-tail", "long tail", "common"] 14 | # QUERY_LENGTH = Literal['less than 5 words', '5 to 15 words', 'at least 10 words'] 15 | DIFFICULTY = Literal["children", "high school", "college", "PhD"] 16 | CLARITY = Literal["clear", "understandable with some effort", "ambiguous"] 17 | DOC_WORDS = Literal["50", "100", "200", "300", "400", "500"] 18 | HOP = Literal["single-hop", "2-hop", "3-hop", "4-hop"] 19 | HOP_TYPE = Literal["compositional", "comparison", "bridge comparison", "inference"] 20 | 21 | # Task generation system prompt 22 | TASK_GENERATION_SYSTEM_PROMPT = """Brainstorm a list of Wikipedia QnA text retrieval tasks. Here are a few examples for your reference: 23 | - Make a comparison between two different types of animals. 24 | - Retrieve documents that bridge two people from different time periods. 25 | 26 | Please adhere to the following guidelines: 27 | - Specify what the query is, and what the desired documents are. 28 | - Each retrieval task should cover a wide range of queries, and should not be too specific. 29 | 30 | Your output must always be a python list of strings only, with about 20 elements, and each element corresponds to a distinct retrieval task in one sentence. Do not explain yourself or output anything else. Be creative!""" 31 | 32 | # Query and chunk generation system prompt 33 | QUERY_CHUNK_SYSTEM_PROMPT = """You have been assigned a retrieval task: {task} 34 | Your mission is to write one multi-hop Wikipedia QnA text retrieval example for this task in JSON format. The JSON object must contain the following keys: 35 | - "user_query": a string, a random user search query specified by the retrieval task. 36 | - "reasoning_chain": a string, a reasoning chain that describes the information and retrieval process of the user query. #! TODO: reasoning chain 37 | - "chunks": a list of search paths. Each element in the list is a JSON object with the following keys: 38 | - "positive_document": a string, a relevant document for the user query. 39 | - "hard_negative_document": a string, a hard negative document that only appears relevant to the query. 40 | 41 | Please adhere to the following guidelines: 42 | - The "user_query" should be {query_type}, {hop}, Wikipedia focused, and diverse in topic. 43 | - The "user_query" should be a {hop_type} query. 44 | - All documents must be created independent of the query. Avoid copying the query verbatim. It’s acceptable if some parts of the "positive_document" are not topically related to the query. 45 | - All documents should be at least {num_words} words long. 46 | - The "hard_negative_document" contains some useful information, but it should be less useful or comprehensive compared to the "positive_document". 47 | - Both the query and documents should be in English. 48 | - Do not provide any explanation in any document on why it is relevant or not relevant to the query. 49 | - Both the query and documents require {difficulty} level education to understand. 50 | 51 | Your output must always be a JSON object only, do not explain yourself or output anything else.
Be creative!""" 52 | 53 | 54 | class QueryChunkSynthesizer(object): 55 | def __init__( 56 | self, 57 | model: str = "gpt-4-1106-preview", 58 | ): 59 | if "gpt" in model: 60 | self.model = AOAI(model, api_version="2023-12-01-preview") 61 | else: 62 | raise ValueError("Model not supported") 63 | self.max_retry = 3 64 | 65 | def task_generation(self): 66 | trials = 0 67 | while trials < self.max_retry: 68 | try: 69 | tasks = self.model.chat(TASK_GENERATION_SYSTEM_PROMPT) 70 | matches = re.findall(r"```python(.*?)```", tasks, re.DOTALL) 71 | if matches: 72 | tasks = matches[0] 73 | tasks = tasks.strip("\n") 74 | tasks = eval(tasks) 75 | if isinstance(tasks, list): 76 | return tasks 77 | elif isinstance(tasks, str): 78 | return [tasks] 79 | except Exception as e: 80 | trials += 1 81 | return [] 82 | 83 | def query_chunk_generation( 84 | self, 85 | task: str, 86 | query_type: QUERY_TYPE, 87 | # query_length: QUERY_LENGTH, 88 | # clarity: CLARITY, 89 | num_words: DOC_WORDS, 90 | difficulty: DIFFICULTY, 91 | hop: HOP, 92 | hop_type: HOP_TYPE, 93 | ): 94 | trials = 0 95 | query = QUERY_CHUNK_SYSTEM_PROMPT.format( 96 | task=task, 97 | query_type=query_type, 98 | # query_length=query_length, clarity=clarity, 99 | num_words=num_words, 100 | difficulty=difficulty, 101 | hop=hop, 102 | hop_type=hop_type, 103 | ) 104 | while trials < self.max_retry: 105 | try: 106 | query_chunk = self.model.chat(query) 107 | matches = re.findall(r"```json(.*?)```", query_chunk, re.DOTALL) 108 | if matches: 109 | query_chunk = matches[0] 110 | query_chunk = query_chunk.strip("\n") 111 | data = json.loads(query_chunk) 112 | # for key in ("user_query", "positive_document", "hard_negative_document"): 113 | # if key not in data: 114 | # raise ValueError(f"Key {key} not found in the response") 115 | return data 116 | except Exception as e: 117 | trials += 1 118 | return None 119 | 120 | 121 | def main(): 122 | synthesizer = QueryChunkSynthesizer() 123 | tasks = synthesizer.task_generation() 124 | print(f"Tasks: {tasks}") 125 | data_list = [] 126 | for task in tqdm(tasks[:1]): 127 | data = synthesizer.query_chunk_generation(task, "common", "50", "children", "2-hop", "compositional") 128 | data_list.append(data) 129 | print(data_list) 130 | 131 | 132 | if __name__ == "__main__": 133 | main() 134 | -------------------------------------------------------------------------------- /src/data_synthesize/token_extraction.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import sys 5 | 6 | import spacy 7 | from tqdm.rich import tqdm_rich 8 | 9 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 10 | from utils import load_jsonl, write_jsonl 11 | 12 | nlp = spacy.load("en_core_web_sm") 13 | 14 | 15 | def split_string(input_string, ignore_tokens=set([","])): 16 | doc = nlp(input_string) 17 | lemma_word_list = [] 18 | word_list = [] 19 | for word in doc: 20 | if word.lemma_ not in ignore_tokens: 21 | word_list.append(word.text) 22 | lemma_word_list.append(word.lemma_) 23 | return lemma_word_list, word_list 24 | 25 | 26 | def is_equal(token1, token2): 27 | return token1.lower() == token2.lower() 28 | 29 | 30 | def label_word( 31 | origin_paragraph: str, 32 | extracted_words: str, 33 | window_size: int = 150, 34 | verbose: bool = False, 35 | ): 36 | lemma_paragraph_tokens, paragraph_tokens = split_string(origin_paragraph) 37 | lemma_comp_tokens, comp_tokens = split_string(extracted_words) 38 | origin_lemma_tokens_set = 
set(lemma_paragraph_tokens) 39 | for lemma_paragraph_token in lemma_paragraph_tokens: 40 | origin_lemma_tokens_set.add(lemma_paragraph_token.lower()) 41 | 42 | num_find = 0 43 | prev_idx = 0 44 | num_origin_tokens = len(lemma_paragraph_tokens) 45 | labels = [False] * num_origin_tokens 46 | for lemma_comp_token, comp_token in zip(lemma_comp_tokens, comp_tokens): 47 | if ( 48 | lemma_comp_token in origin_lemma_tokens_set 49 | or lemma_comp_token.lower() in origin_lemma_tokens_set 50 | ): 51 | num_find += 1 52 | for i in range(window_size): 53 | # look forward 54 | token_idx = min(prev_idx + i, num_origin_tokens - 1) 55 | if ( 56 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 57 | and not labels[token_idx] 58 | ): 59 | labels[token_idx] = True 60 | # window do not go too fast 61 | if token_idx - prev_idx > window_size // 2: 62 | prev_idx += window_size // 2 63 | else: 64 | prev_idx = token_idx 65 | if verbose: 66 | print( 67 | lemma_comp_token, 68 | token_idx, 69 | prev_idx, 70 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 71 | ) 72 | break 73 | # look backward 74 | token_idx = max(prev_idx - i, 0) 75 | if ( 76 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 77 | and not labels[token_idx] 78 | ): 79 | labels[token_idx] = True 80 | prev_idx = token_idx 81 | if verbose: 82 | print( 83 | lemma_comp_token, 84 | token_idx, 85 | prev_idx, 86 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 87 | ) 88 | break 89 | 90 | retrieval_tokens = [] 91 | for idx, token in enumerate(paragraph_tokens): 92 | if labels[idx]: 93 | retrieval_tokens.append(token) 94 | matched = " ".join(retrieval_tokens) 95 | 96 | comp_rate = len(lemma_comp_tokens) / len(lemma_paragraph_tokens) 97 | if len(lemma_comp_tokens) > 0: 98 | find_rate = num_find / len(lemma_comp_tokens) 99 | else: 100 | find_rate = 0.0 101 | variation_rate = 1 - find_rate 102 | hitting_rate = num_find / len(lemma_paragraph_tokens) 103 | matching_rate = sum(labels) / len(labels) 104 | alignment_gap = hitting_rate - matching_rate 105 | 106 | if alignment_gap > 0.1: 107 | print(origin_paragraph) 108 | print("-" * 50) 109 | print(extracted_words) 110 | print("-" * 50) 111 | print(matched) 112 | print("-" * 50) 113 | print(paragraph_tokens) 114 | print("-" * 50) 115 | print(comp_tokens) 116 | print("-" * 50) 117 | print(retrieval_tokens) 118 | print("=" * 50) 119 | print( 120 | f"compress rate: {comp_rate}, variation_rate: {variation_rate}, alignment_gap: {alignment_gap}" 121 | ) 122 | 123 | return { 124 | "labels": labels, 125 | "matched": matched, 126 | "paragraph_tokens": paragraph_tokens, 127 | "comp_rate": comp_rate, 128 | "find_rate": find_rate, 129 | "variation_rate": variation_rate, 130 | "hitting_rate": hitting_rate, 131 | "matching_rate": matching_rate, 132 | "alignment_gap": alignment_gap, 133 | } 134 | 135 | 136 | def main(opts: argparse.Namespace): 137 | data = load_jsonl(opts.data_path) 138 | 139 | infos = { 140 | "comp_rate": 0.0, 141 | "variation_rate": 0.0, 142 | "hitting_rate": 0.0, 143 | "matching_rate": 0.0, 144 | "alignment_gap": 0.0, 145 | "find_rate": 0.0, 146 | } 147 | 148 | num_samples = 0 149 | for sample in tqdm_rich(data): 150 | flag = True 151 | for sid, subq in sample["decomposed_questions"].items(): 152 | if subq.get("extracted_words", None) is None: 153 | flag = False 154 | break 155 | if not flag: 156 | del sample 157 | continue 158 | for sid, subq in sample["decomposed_questions"].items(): 159 | results = label_word( 160 | subq["positive_paragraph"], 161 | 
subq["extracted_words"], 162 | verbose=opts.verbose, 163 | ) 164 | subq["labels"] = results["labels"] 165 | subq["matched"] = results["matched"] 166 | subq["paragraph_tokens"] = results["paragraph_tokens"] 167 | 168 | num_samples += 1 169 | for k in infos.keys(): 170 | infos[k] += results[k] 171 | 172 | for k, v in infos.items(): 173 | v = v / num_samples * 100 174 | print(f"{k}: {v:.2f}") 175 | 176 | os.makedirs(os.path.dirname(opts.save_path), exist_ok=True) 177 | write_jsonl(data, opts.save_path) 178 | 179 | 180 | def parse_args(): 181 | parser = argparse.ArgumentParser(description="annotate token") 182 | parser.add_argument("--data_path", required=True, type=str) 183 | parser.add_argument("--save_path", required=True, type=str) 184 | parser.add_argument("--verbose", action="store_true", default=False) 185 | args = parser.parse_args() 186 | return args 187 | 188 | 189 | if __name__ == "__main__": 190 | options = parse_args() 191 | main(options) 192 | -------------------------------------------------------------------------------- /src/language_models/aoai.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from time import sleep 4 | from typing import Optional, Union 5 | 6 | import numpy as np 7 | import openai 8 | from openai import AzureOpenAI 9 | from openai._types import NotGiven 10 | 11 | from .base import LanguageModel 12 | from .cloudgpt import auto_refresh_token, cloudGPT_available_models, get_openai_token 13 | 14 | AOAI_ENDPOINT = os.environ.get("AOAI_ENDPOINT", None) 15 | if AOAI_ENDPOINT is None: 16 | AOAI_ENDPOINT = "https://cloudgpt-openai.azure-api.net/" 17 | 18 | SLEEP_SEC = 3 19 | 20 | 21 | class AOAI(LanguageModel): 22 | def __init__( 23 | self, 24 | model: Union[str, cloudGPT_available_models] = "gpt-4-0125-preview", 25 | embedding_model: Optional[str] = "text-embedding-ada-002", 26 | api_version: str = "2024-02-15-preview", 27 | ): 28 | super().__init__(model) 29 | self.api_version = api_version 30 | self.client = AzureOpenAI( 31 | azure_endpoint=AOAI_ENDPOINT, 32 | api_key=get_openai_token(), 33 | api_version=self.api_version, 34 | ) 35 | self.embedding_model = embedding_model 36 | auto_refresh_token() 37 | 38 | def chat(self, messages: str, system_msg: str = None, **kwargs): 39 | try: 40 | response = self._chat(messages, system_msg, **kwargs) 41 | return response 42 | except openai.BadRequestError as e: 43 | err = json.loads(e.response.text) 44 | if err["error"]["code"] == "content_filter": 45 | print("Content filter triggered!") 46 | return None 47 | print(f"The OpenAI API request was invalid: {e}") 48 | return None 49 | except openai.APIConnectionError as e: 50 | print(f"The OpenAI API connection failed: {e}") 51 | sleep(SLEEP_SEC) 52 | return self.chat(messages, system_msg, **kwargs) 53 | except openai.RateLimitError as e: 54 | print(f"Token rate limit exceeded. Retrying after {SLEEP_SEC} second...") 55 | sleep(SLEEP_SEC) 56 | return self.chat(messages, system_msg, **kwargs) 57 | except openai.AuthenticationError as e: 58 | print(f"Invalid API token: {e}") 59 | self.update_api_key() 60 | sleep(SLEEP_SEC) 61 | return self.chat(messages, system_msg, **kwargs) 62 | except openai.APIError as e: 63 | if "The operation was timeout" in str(e): 64 | # Handle the timeout error here 65 | print("The OpenAI API request timed out. 
Please try again later.") 66 | sleep(SLEEP_SEC) 67 | return self.chat(messages, system_msg, **kwargs) 68 | elif "DeploymentNotFound" in str(e): 69 | print("The API deployment for this resource does not exist") 70 | return None 71 | else: 72 | # Handle other API errors here 73 | print(f"The OpenAI API returned an error: {e}") 74 | sleep(SLEEP_SEC) 75 | return self.chat(messages, system_msg, **kwargs) 76 | except Exception as e: 77 | print(f"An error occurred: {e}") 78 | 79 | def _chat( 80 | self, 81 | messages: str, 82 | system_msg="", 83 | temperature: float = 0.3, 84 | max_tokens: int = 1000, 85 | top_p: float = 0.95, 86 | frequency_penalty: float = 0.0, 87 | presence_penalty: float = 0.0, 88 | json_mode: bool = False, 89 | ): 90 | if system_msg is None or system_msg == "": 91 | system_msg = "You are a helpful assistant." 92 | msg = [ 93 | {"role": "system", "content": system_msg}, 94 | {"role": "user", "content": messages}, 95 | ] 96 | response = self.client.chat.completions.create( 97 | model=self.model, 98 | response_format={"type": "json_object"} if json_mode else NotGiven(), 99 | messages=msg, 100 | temperature=temperature, 101 | max_tokens=max_tokens, 102 | top_p=top_p, 103 | frequency_penalty=frequency_penalty, 104 | presence_penalty=presence_penalty, 105 | ) 106 | return response.choices[0].message.content 107 | 108 | def update_api_key(self): 109 | self.client = AzureOpenAI( 110 | azure_endpoint=AOAI_ENDPOINT, 111 | api_key=get_openai_token(), 112 | api_version=self.api_version, 113 | ) 114 | 115 | def _embed(self, query: Union[str, list[str]]): 116 | response = self.client.embeddings.create( 117 | input=query, 118 | model=self.embedding_model, 119 | ) 120 | result = np.array([response.data[idx].embedding for idx in range(len(response.data))]) 121 | if isinstance(query, str): 122 | return result[0] 123 | return result 124 | 125 | def embed(self, query: Union[str, list[str]], **kwargs): 126 | try: 127 | response = self._embed(query, **kwargs) 128 | return response 129 | except openai.BadRequestError as e: 130 | err = json.loads(e.response.text) 131 | if err["error"]["code"] == "content_filter": 132 | print("Content filter triggered!") 133 | return None 134 | print(f"The OpenAI API request was invalid: {e}") 135 | return None 136 | except openai.APIConnectionError as e: 137 | print(f"The OpenAI API connection failed: {e}") 138 | sleep(SLEEP_SEC) 139 | return self.embed(query, **kwargs) 140 | except openai.RateLimitError as e: 141 | print(f"Token rate limit exceeded. Retrying after {SLEEP_SEC} second...") 142 | sleep(SLEEP_SEC) 143 | return self.embed(query, **kwargs) 144 | except openai.AuthenticationError as e: 145 | print(f"Invalid API token: {e}") 146 | self.update_api_key() 147 | sleep(SLEEP_SEC) 148 | return self.embed(query, **kwargs) 149 | except openai.APIError as e: 150 | if "The operation was timeout" in str(e): 151 | # Handle the timeout error here 152 | print("The OpenAI API request timed out. 
Please try again later.") 153 | sleep(SLEEP_SEC) 154 | return self.embed(query, **kwargs) 155 | elif "DeploymentNotFound" in str(e): 156 | print("The API deployment for this resource does not exist") 157 | return None 158 | else: 159 | # Handle other API errors here 160 | print(f"The OpenAI API returned an error: {e}") 161 | sleep(SLEEP_SEC) 162 | return self.embed(query, **kwargs) 163 | except Exception as e: 164 | print(f"An error occurred: {e}") 165 | 166 | 167 | if __name__ == "__main__": 168 | aoai = AOAI(model="gpt-4-1106-preview") 169 | print(aoai.chat("Hello, how are you?", json_mode=False)) 170 | -------------------------------------------------------------------------------- /src/evaluation/correctness.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import string 5 | import sys 6 | from collections import Counter 7 | from concurrent.futures import ThreadPoolExecutor, as_completed 8 | 9 | from tqdm.rich import tqdm_rich 10 | 11 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 12 | from language_models import LanguageModel, get_model 13 | from utils import ask_model, load_jsonl, write_jsonl 14 | 15 | LLM_EVAL_PROMPT = """ 16 | You are an experienced linguist who is responsible for evaluating the correctness of the generated responses. 17 | You are provided with question, the generated responses and the corresponding ground truth answer. 18 | Your task is to compare the generated responses with the ground truth responses and evaluate the correctness of the generated responses. 19 | Response in JSON format with key "response" and value "yes" or "no". 20 | 21 | Question: {question} 22 | Prediction: {prediction} 23 | Ground-truth Answer: {answer} 24 | Your response: 25 | """.strip() 26 | 27 | LLM_EXTRACT_ANSWER_PROMPT = """ 28 | Given a question, you should simplify the response to a more concise form of answer. If the response is already in a concise form, you can response with the same answer. If the response does not contain the answer, you can return "noanswer". 29 | You should come out the simplified answer in JSON format with key "answer" and the answer string as the value. Your response should be in markdown code block. 
Like 30 | ```json 31 | {{"answer": "simplified answer"}} 32 | ``` 33 | 34 | Question: {question} 35 | Response: {response} 36 | """.strip() 37 | 38 | 39 | def parse_args(): 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument("--fpath", type=str, required=True) 42 | parser.add_argument("--model", type=str, default="llama") 43 | parser.add_argument("--extract_answer", action="store_true") 44 | parser.add_argument("--workers", type=int, default=10) 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def normalize_answer(s): 50 | 51 | def remove_articles(text): 52 | return re.sub(r"\b(a|an|the)\b", " ", text) 53 | 54 | def white_space_fix(text): 55 | return " ".join(text.split()) 56 | 57 | def remove_punc(text): 58 | exclude = set(string.punctuation) 59 | return "".join(ch for ch in text if ch not in exclude) 60 | 61 | def lower(text): 62 | return text.lower() 63 | 64 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 65 | 66 | 67 | def f1_score(prediction, ground_truth): 68 | normalized_prediction = normalize_answer(prediction) 69 | normalized_ground_truth = normalize_answer(ground_truth) 70 | 71 | ZERO_METRIC = (0, 0, 0) 72 | 73 | if ( 74 | normalized_prediction in ["yes", "no", "noanswer"] 75 | and normalized_prediction != normalized_ground_truth 76 | ): 77 | return ZERO_METRIC 78 | if ( 79 | normalized_ground_truth in ["yes", "no", "noanswer"] 80 | and normalized_prediction != normalized_ground_truth 81 | ): 82 | return ZERO_METRIC 83 | 84 | prediction_tokens = normalized_prediction.split() 85 | ground_truth_tokens = normalized_ground_truth.split() 86 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 87 | num_same = sum(common.values()) 88 | if num_same == 0: 89 | return ZERO_METRIC 90 | precision = 1.0 * num_same / len(prediction_tokens) 91 | recall = 1.0 * num_same / len(ground_truth_tokens) 92 | f1 = (2 * precision * recall) / (precision + recall) 93 | return f1, precision, recall 94 | 95 | 96 | def exact_match(prediction, ground_truth): 97 | normalized_prediction = normalize_answer(prediction) 98 | normalized_answer = normalize_answer(ground_truth) 99 | if normalized_prediction in normalized_answer: 100 | return 1 101 | return 0 102 | 103 | 104 | def acc_evaluate(question: str, answer: str, prediction: str, model: LanguageModel): 105 | prompt = LLM_EVAL_PROMPT.format( 106 | question=question, prediction=prediction, answer=answer 107 | ) 108 | response = ask_model( 109 | model, 110 | prompt, 111 | mode="chat", 112 | type="json", 113 | check_if_valid=lambda resp: type(resp) is dict 114 | and "response" in resp 115 | and resp["response"] in ["yes", "no"], 116 | ) 117 | if response is None: 118 | return False 119 | return response["response"] == "yes" 120 | 121 | 122 | def extract_answer(question: str, response: str, model: LanguageModel): 123 | prompt = LLM_EXTRACT_ANSWER_PROMPT.format(question=question, response=response) 124 | result = ask_model( 125 | model, 126 | prompt, 127 | mode="chat", 128 | type="json", 129 | check_if_valid=lambda resp: type(resp) is dict and "answer" in resp, 130 | ) 131 | if result is None: 132 | return response 133 | return result["answer"] 134 | 135 | 136 | def evaluate_sample(sample, model: LanguageModel): 137 | question = sample["question"] 138 | answer = sample["answer"] 139 | 140 | if type(answer) is list: 141 | answer = " ".join(answer) 142 | assert type(answer) is str, f"Answer is not a string: {answer}" # noqa 143 | 144 | prediction = sample["model_output"] 145 | # prediction =
sample["model_answer"] 146 | origin_pred = prediction 147 | if opts.extract_answer: 148 | prediction = extract_answer(question, prediction, model) 149 | 150 | acc = acc_evaluate(question, answer, prediction, model) 151 | f1, precision, recall = f1_score(prediction, answer) 152 | em = exact_match(prediction, answer) 153 | return { 154 | "question": question, 155 | "answer": answer, 156 | "origin_prediction": origin_pred, 157 | "prediction": prediction, 158 | "correctness": acc, 159 | "f1": f1, 160 | "em": em, 161 | } 162 | 163 | 164 | def main(opts: argparse.Namespace): 165 | fpath = opts.fpath 166 | data = load_jsonl(fpath) 167 | model = get_model(opts.model) 168 | base_path = os.path.splitext(fpath)[0] 169 | results = [] 170 | with ThreadPoolExecutor(max_workers=opts.workers) as executor: 171 | tasks = { 172 | executor.submit(evaluate_sample, sample, model): idx 173 | for idx, sample in enumerate(data) 174 | } 175 | 176 | for future in tqdm_rich( 177 | as_completed(tasks), total=len(tasks), desc="Evaluating" 178 | ): 179 | try: 180 | idx = tasks[future] 181 | result = future.result() 182 | results.append((result, idx)) 183 | except Exception as e: 184 | print(f"Error processing sample {idx}: {e}") 185 | finally: 186 | pass 187 | 188 | results = [result for result, _ in sorted(results, key=lambda x: x[1])] 189 | 190 | output_path = f"{base_path}_correctness.jsonl" 191 | write_jsonl(results, output_path) 192 | 193 | accuracy_score = sum([result["correctness"] for result in results]) / len(results) 194 | f1_score_avg = sum([result["f1"] for result in results]) / len(results) 195 | em_score_avg = sum([result["em"] for result in results]) / len(results) 196 | print(f"EM: {em_score_avg:.4f}") 197 | print(f"F1: {f1_score_avg:.4f}") 198 | print(f"Accuracy: {accuracy_score:.4f}") 199 | 200 | 201 | if __name__ == "__main__": 202 | opts = parse_args() 203 | main(opts) 204 | -------------------------------------------------------------------------------- /src/data_synthesize/negative_token_extraction.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2023 Microsoft 2 | # Licensed under The MIT License [see LICENSE for details] 3 | 4 | import argparse 5 | import json 6 | import os 7 | import sys 8 | 9 | import spacy 10 | from tqdm.rich import tqdm_rich 11 | 12 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 13 | from conf import ( 14 | SYNTHESIZED_NEGATIVE_SAMPLING_EXTRACTED_DATA_PATH, 15 | SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH, 16 | ) 17 | from utils import load_jsonl, write_jsonl 18 | 19 | nlp = spacy.load("en_core_web_sm") 20 | 21 | 22 | def split_string(input_string, ignore_tokens=set([","])): 23 | doc = nlp(input_string) 24 | lemma_word_list = [] 25 | word_list = [] 26 | for word in doc: 27 | if word.lemma_ not in ignore_tokens: 28 | lemma_word_list.append(word.lemma_) 29 | word_list.append(word.text) 30 | return lemma_word_list, word_list 31 | 32 | 33 | def is_equal(token1, token2): 34 | return token1.lower() == token2.lower() 35 | 36 | 37 | def label_word( 38 | origin_paragraph: str, 39 | extracted_words: str, 40 | window_size: int = 150, 41 | verbose: bool = False, 42 | ): 43 | lemma_paragraph_tokens, paragraph_tokens = split_string(origin_paragraph) 44 | lemma_comp_tokens, comp_tokens = split_string(extracted_words) 45 | origin_lemma_tokens_set = set(lemma_paragraph_tokens) 46 | for lemma_paragraph_token in lemma_paragraph_tokens: 47 | origin_lemma_tokens_set.add(lemma_paragraph_token.lower()) 48 | 49 | num_find 
= 0 50 | prev_idx = 0 51 | num_origin_tokens = len(lemma_paragraph_tokens) 52 | labels = [False] * num_origin_tokens 53 | for lemma_comp_token, comp_token in zip(lemma_comp_tokens, comp_tokens): 54 | if ( 55 | lemma_comp_token in origin_lemma_tokens_set 56 | or lemma_comp_token.lower() in origin_lemma_tokens_set 57 | ): 58 | num_find += 1 59 | for i in range(window_size): 60 | # look forward 61 | token_idx = min(prev_idx + i, num_origin_tokens - 1) 62 | if ( 63 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 64 | and not labels[token_idx] 65 | ): 66 | labels[token_idx] = True 67 | # window do not go too fast 68 | if token_idx - prev_idx > window_size // 2: 69 | prev_idx += window_size // 2 70 | else: 71 | prev_idx = token_idx 72 | if verbose: 73 | print( 74 | lemma_comp_token, 75 | token_idx, 76 | prev_idx, 77 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 78 | ) 79 | break 80 | # look backward 81 | token_idx = max(prev_idx - i, 0) 82 | if ( 83 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 84 | and not labels[token_idx] 85 | ): 86 | labels[token_idx] = True 87 | prev_idx = token_idx 88 | if verbose: 89 | print( 90 | lemma_comp_token, 91 | token_idx, 92 | prev_idx, 93 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 94 | ) 95 | break 96 | 97 | retrieval_tokens = [] 98 | for idx, token in enumerate(paragraph_tokens): 99 | if labels[idx]: 100 | retrieval_tokens.append(token) 101 | matched = " ".join(retrieval_tokens) 102 | 103 | comp_rate = len(lemma_comp_tokens) / len(lemma_paragraph_tokens) 104 | if len(lemma_comp_tokens) > 0: 105 | find_rate = num_find / len(lemma_comp_tokens) 106 | else: 107 | find_rate = 0.0 108 | variation_rate = 1 - find_rate 109 | hitting_rate = num_find / len(lemma_paragraph_tokens) 110 | matching_rate = sum(labels) / len(labels) 111 | alignment_gap = hitting_rate - matching_rate 112 | 113 | if alignment_gap > 0.1: 114 | print(origin_paragraph) 115 | print("-" * 50) 116 | print(extracted_words) 117 | print("-" * 50) 118 | print(matched) 119 | print("-" * 50) 120 | print(lemma_paragraph_tokens) 121 | print("-" * 50) 122 | print(lemma_comp_tokens) 123 | print("-" * 50) 124 | print(retrieval_tokens) 125 | print("=" * 50) 126 | print( 127 | f"compress rate: {comp_rate}, variation_rate: {variation_rate}, alignment_gap: {alignment_gap}" 128 | ) 129 | 130 | return { 131 | "labels": labels, 132 | "matched": matched, 133 | "paragraph_tokens": paragraph_tokens, 134 | "comp_rate": comp_rate, 135 | "find_rate": find_rate, 136 | "variation_rate": variation_rate, 137 | "hitting_rate": hitting_rate, 138 | "matching_rate": matching_rate, 139 | "alignment_gap": alignment_gap, 140 | } 141 | 142 | 143 | def main(opts: argparse.Namespace): 144 | data = load_jsonl(opts.data_path) 145 | 146 | infos = { 147 | "comp_rate": 0.0, 148 | "variation_rate": 0.0, 149 | "hitting_rate": 0.0, 150 | "matching_rate": 0.0, 151 | "alignment_gap": 0.0, 152 | "find_rate": 0.0, 153 | } 154 | 155 | num_samples = 0 156 | for sample in tqdm_rich(data): 157 | flag = True 158 | for sid, subq in sample["decomposed_questions"].items(): 159 | if subq.get("negative_extracted_words", None) is None: 160 | flag = False 161 | break 162 | if not flag: 163 | del sample 164 | continue 165 | for sid, subq in sample["decomposed_questions"].items(): 166 | results = label_word( 167 | subq["negative_paragraph"], 168 | subq["negative_extracted_words"], 169 | verbose=opts.verbose, 170 | ) 171 | subq["negative_labels"] = results["labels"] 172 | subq["negative_matched"] = 
results["matched"] 173 | subq["negative_paragraph_tokens"] = results["paragraph_tokens"] 174 | 175 | num_samples += 1 176 | for k in infos.keys(): 177 | infos[k] += results[k] 178 | 179 | for k, v in infos.items(): 180 | v = v / num_samples * 100 181 | print(f"{k}: {v:.2f}") 182 | 183 | os.makedirs(os.path.dirname(opts.save_path), exist_ok=True) 184 | write_jsonl(data, opts.save_path) 185 | 186 | 187 | def parse_args(): 188 | parser = argparse.ArgumentParser(description="annotate token") 189 | parser.add_argument("--verbose", action="store_true", default=False) 190 | parser.add_argument( 191 | "--dataset", 192 | required=True, 193 | type=str, 194 | choices=["hotpotQA", "musique", "2WikiMQA"], 195 | ) 196 | parser.add_argument("--split", required=True, type=str, default="demo") 197 | args = parser.parse_args() 198 | args.data_path = os.path.join( 199 | SYNTHESIZED_NEGATIVE_SAMPLING_LABELED_DATA_PATH, 200 | args.dataset, 201 | f"{args.split}.jsonl", 202 | ) 203 | args.save_path = os.path.join( 204 | SYNTHESIZED_NEGATIVE_SAMPLING_EXTRACTED_DATA_PATH, 205 | args.dataset, 206 | f"{args.split}.jsonl", 207 | ) 208 | return args 209 | 210 | 211 | if __name__ == "__main__": 212 | options = parse_args() 213 | main(options) 214 | -------------------------------------------------------------------------------- /src/efficient_rag/labeler_training.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 6 | from datetime import datetime 7 | 8 | import torch 9 | import torch.nn as nn 10 | import wandb 11 | from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score 12 | from torch.nn import DataParallel 13 | from transformers import DebertaV2Tokenizer, EvalPrediction, Trainer, TrainingArguments 14 | 15 | from conf import ( 16 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, 17 | MODEL_PATH, 18 | TAG_MAPPING, 19 | TAG_MAPPING_TWO, 20 | TERMINATE_ID, 21 | ) 22 | from efficient_rag.data import LabelerDataset 23 | from efficient_rag.model import DebertaForSequenceTokenClassification 24 | from utils import load_jsonl 25 | 26 | 27 | class LabelerTrainer(Trainer): 28 | def compute_loss( 29 | self, 30 | model: DebertaForSequenceTokenClassification, 31 | inputs: dict, 32 | return_outputs: bool = False, 33 | ): 34 | inputs = {k: v.cuda() for k, v in inputs.items()} 35 | token_labels = inputs.pop("token_labels") 36 | sequence_labels = inputs.pop("sequence_labels") 37 | outputs = model(**inputs) 38 | token_logits = outputs.token_logits 39 | sequence_logits = outputs.sequence_logits 40 | 41 | weight = torch.tensor(WEIGHT_AVERAGE).cuda() # noqa 42 | loss_fct = nn.CrossEntropyLoss(weight=weight) 43 | selected_sequence_logits = sequence_logits.argmax(-1) 44 | 45 | # Remove the token logits that labeled as 46 | token_logits = token_logits[selected_sequence_logits != TERMINATE_ID] 47 | token_labels = token_labels[selected_sequence_logits != TERMINATE_ID] 48 | 49 | module = model 50 | if type(model) == DataParallel: 51 | module = model.module 52 | 53 | token_loss = loss_fct( 54 | token_logits.view(-1, module.token_labels), 55 | token_labels.view(-1), 56 | ) 57 | sequence_loss = loss_fct( 58 | sequence_logits.view(-1, module.sequence_labels), 59 | sequence_labels.view(-1), 60 | ) 61 | wandb.log({"token_loss": token_loss, "sequence_loss": sequence_loss}) 62 | loss = token_loss + sequence_loss 63 | return (loss, outputs) if return_outputs else loss 64 | 65 | 66 
| def eval_labeler(pred: EvalPrediction): 67 | tag_prediction = torch.tensor(pred.predictions[0].argmax(-1)) 68 | token_prediction = torch.tensor(pred.predictions[1].argmax(-1)) 69 | tag_label = torch.tensor(pred.label_ids[1]) 70 | token_label = torch.tensor(pred.label_ids[0]) 71 | 72 | mask = torch.tensor(pred.inputs != 0) 73 | token_prediction = torch.masked_select(token_prediction, mask) 74 | token_label = torch.masked_select(token_label, mask) 75 | 76 | result = { 77 | "tag_accuracy": accuracy_score(tag_prediction, tag_label), 78 | "token_accuracy": accuracy_score(token_prediction, token_label), 79 | } 80 | metrics = { 81 | # "recall": recall_score, 82 | # "precision": precision_score, 83 | "f1": f1_score, 84 | } 85 | 86 | tag_f1 = f1_score(tag_prediction, tag_label, average=None, zero_division=0) 87 | for tag, idx in CHUNK_TAG_MAPPING.items(): 88 | result[f"tag_f1-{tag.strip('<>')}"] = tag_f1[idx] 89 | result["tag_f1"] = f1_score(token_prediction, token_label, average="micro") 90 | 91 | token_f1 = f1_score(tag_prediction, tag_label, average=None) 92 | result["token_f1_positive"] = token_f1[1] 93 | result["token_f1_negative"] = token_f1[0] 94 | result["token_f1"] = f1_score(tag_prediction, tag_label, average="micro") 95 | return result 96 | 97 | 98 | def parse_args(): 99 | parser = argparse.ArgumentParser(description="EfficientRAG Query Labeler") 100 | parser.add_argument( 101 | "--tags", type=int, default=2, choices=[2, 3], help="number of tags" 102 | ) 103 | parser.add_argument("--dataset", required=True, type=str) 104 | parser.add_argument("--lr", help="learning rate", default=5e-6, type=float) 105 | parser.add_argument("--epoch", default=2, type=int) 106 | parser.add_argument("--batch_size", type=int, default=32) 107 | parser.add_argument("--max_length", type=int, default=384) 108 | parser.add_argument("--warmup_steps", type=int, default=200) 109 | parser.add_argument("--eval_steps", type=int, default=200) 110 | parser.add_argument("--logging_steps", type=int, default=10) 111 | parser.add_argument("--test", action="store_true") 112 | parser.add_argument("--test_samples", type=int, default=100) 113 | parser.add_argument("--weight_average", action="store_true") 114 | args = parser.parse_args() 115 | return args 116 | 117 | 118 | def build_dataset( 119 | dataset: str, 120 | split: str, 121 | max_len: int = 128, 122 | tokenizer=None, 123 | test_mode: bool = False, 124 | test_sample_cnt: int = 100, 125 | ): 126 | data_path = os.path.join( 127 | EFFICIENT_RAG_LABELER_TRAINING_DATA_PATH, dataset, f"{split}.jsonl" 128 | ) 129 | data = load_jsonl(data_path) 130 | original_question = [d["question"] for d in data] 131 | chunk_tokens = [d["chunk_tokens"] for d in data] 132 | chunk_labels = [d["labels"] for d in data] 133 | tags = [CHUNK_TAG_MAPPING[d["tag"]] for d in data] 134 | 135 | if test_mode: 136 | return LabelerDataset( 137 | original_question[:test_sample_cnt], 138 | chunk_tokens[:test_sample_cnt], 139 | chunk_labels[:test_sample_cnt], 140 | tags[:test_sample_cnt], 141 | max_len, 142 | tokenizer, 143 | ) 144 | return LabelerDataset( 145 | original_question, chunk_tokens, chunk_labels, tags, max_len, tokenizer 146 | ) 147 | 148 | 149 | def main(opt: argparse.Namespace): 150 | global tokenizer 151 | global CHUNK_TAG_MAPPING 152 | if opt.tags == 2: 153 | WANDB_PROJ_NAME = "EfficientRAG_labeler_two" 154 | CHUNK_TAG_MAPPING = TAG_MAPPING_TWO 155 | elif opt.tags == 3: 156 | WANDB_PROJ_NAME = "EfficientRAG_labeler" 157 | CHUNK_TAG_MAPPING = TAG_MAPPING 158 | os.environ["WANDB_PROJECT"] = 
WANDB_PROJ_NAME 159 | 160 | global WEIGHT_AVERAGE 161 | if opt.weight_average: 162 | # TODO: Add hotpotQA and 2WikiMQA weight here, noqa 163 | WEIGHT_AVERAGE = { 164 | "hotpotQA": [1.0, 1.0], 165 | "2WikiMQA": [0.51, 25.32], 166 | "musique": [0.51, 25.45], 167 | }[opt.dataset] 168 | else: 169 | WEIGHT_AVERAGE = [1.0, 1.0] 170 | 171 | 172 | tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_PATH) 173 | model = DebertaForSequenceTokenClassification.from_pretrained( 174 | MODEL_PATH, sequence_labels=opt.tags, token_labels=2 175 | ) 176 | save_path_mapping = { 177 | 2: "saved_models/labeler_two", 178 | 3: "saved_models/labeler", 179 | } 180 | save_dir = os.path.join( 181 | save_path_mapping[opt.tags], 182 | f"labeler_{datetime.now().strftime(r'%Y%m%d_%H%M%S')}", 183 | ) 184 | run_name = f"{opt.dataset}-{datetime.now().strftime(r'%m%d%H%M')}" 185 | train_dataset = build_dataset( 186 | opt.dataset, 187 | "train", 188 | opt.max_length, 189 | tokenizer, 190 | test_mode=opt.test, 191 | test_sample_cnt=opt.test_samples, 192 | ) 193 | valid_dataset = build_dataset( 194 | opt.dataset, 195 | "valid", 196 | opt.max_length, 197 | tokenizer, 198 | ) 199 | 200 | training_args = TrainingArguments( 201 | output_dir=save_dir, 202 | num_train_epochs=opt.epoch, 203 | learning_rate=opt.lr, 204 | per_device_train_batch_size=opt.batch_size, 205 | per_device_eval_batch_size=96, 206 | weight_decay=0.01, 207 | logging_dir=os.path.join(save_dir, "log"), 208 | save_strategy="epoch" if not opt.test else "no", 209 | evaluation_strategy="steps", 210 | eval_steps=opt.eval_steps, 211 | report_to="wandb", 212 | run_name=run_name, 213 | logging_steps=opt.logging_steps, 214 | warmup_steps=opt.warmup_steps, 215 | save_only_model=True, 216 | include_inputs_for_metrics=True, 217 | ) 218 | trainer = LabelerTrainer( 219 | model=model, 220 | args=training_args, 221 | train_dataset=train_dataset, 222 | eval_dataset=valid_dataset, 223 | tokenizer=tokenizer, 224 | compute_metrics=eval_labeler, 225 | ) 226 | trainer.train() 227 | 228 | 229 | if __name__ == "__main__": 230 | options = parse_args() 231 | if options.test: 232 | os.environ["WANDB_MODE"] = "dryrun" 233 | main(options) 234 | -------------------------------------------------------------------------------- /src/data_synthesize/next_hop_query_filtering.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import spacy 6 | from tqdm.rich import tqdm_rich 7 | 8 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 9 | from data_module.format import build_query_info_sentence 10 | from utils import load_jsonl, write_jsonl 11 | 12 | nlp = spacy.load("en_core_web_sm") 13 | 14 | 15 | def split_string(input_string, ignore_tokens=set([","])): 16 | doc = nlp(input_string) 17 | lemma_word_list = [] 18 | word_list = [] 19 | for word in doc: 20 | if word.lemma_ not in ignore_tokens: 21 | lemma_word_list.append(word.lemma_) 22 | word_list.append(word.text) 23 | return lemma_word_list, word_list 24 | 25 | 26 | def is_equal(token1, token2): 27 | return token1.lower() == token2.lower() 28 | 29 | 30 | def label_word( 31 | origin_paragraph: str, 32 | extracted_words: str, 33 | window_size: int = 150, 34 | verbose: bool = False, 35 | ): 36 | lemma_paragraph_tokens, paragraph_tokens = split_string(origin_paragraph) 37 | lemma_comp_tokens, comp_tokens = split_string(extracted_words) 38 | origin_lemma_tokens_set = set(lemma_paragraph_tokens) 39 | for lemma_paragraph_token in lemma_paragraph_tokens: 40 
| origin_lemma_tokens_set.add(lemma_paragraph_token.lower()) 41 | 42 | num_find = 0 43 | prev_idx = 0 44 | num_origin_tokens = len(lemma_paragraph_tokens) 45 | labels = [False] * num_origin_tokens 46 | for lemma_comp_token, comp_token in zip(lemma_comp_tokens, comp_tokens): 47 | if ( 48 | lemma_comp_token in origin_lemma_tokens_set 49 | or lemma_comp_token.lower() in origin_lemma_tokens_set 50 | ): 51 | num_find += 1 52 | for i in range(window_size): 53 | # look forward 54 | token_idx = min(prev_idx + i, num_origin_tokens - 1) 55 | if ( 56 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 57 | and not labels[token_idx] 58 | ): 59 | labels[token_idx] = True 60 | # window do not go too fast 61 | if token_idx - prev_idx > window_size // 2: 62 | prev_idx += window_size // 2 63 | else: 64 | prev_idx = token_idx 65 | if verbose: 66 | print( 67 | lemma_comp_token, 68 | token_idx, 69 | prev_idx, 70 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 71 | ) 72 | break 73 | # look backward 74 | token_idx = max(prev_idx - i, 0) 75 | if ( 76 | is_equal(lemma_paragraph_tokens[token_idx], lemma_comp_token) 77 | and not labels[token_idx] 78 | ): 79 | labels[token_idx] = True 80 | prev_idx = token_idx 81 | if verbose: 82 | print( 83 | lemma_comp_token, 84 | token_idx, 85 | prev_idx, 86 | lemma_paragraph_tokens[token_idx - 1 : token_idx + 2], 87 | ) 88 | break 89 | 90 | retrieval_tokens = [] 91 | for idx, token in enumerate(paragraph_tokens): 92 | if labels[idx]: 93 | retrieval_tokens.append(token) 94 | matched = " ".join(retrieval_tokens) 95 | 96 | comp_rate = len(lemma_comp_tokens) / len(lemma_paragraph_tokens) 97 | if len(lemma_comp_tokens) > 0: 98 | find_rate = num_find / len(lemma_comp_tokens) 99 | else: 100 | find_rate = 0.0 101 | variation_rate = 1 - find_rate 102 | hitting_rate = num_find / len(lemma_paragraph_tokens) 103 | matching_rate = sum(labels) / len(labels) 104 | alignment_gap = hitting_rate - matching_rate 105 | 106 | if alignment_gap > 0.1: 107 | print(origin_paragraph) 108 | print("-" * 50) 109 | print(extracted_words) 110 | print("-" * 50) 111 | print(matched) 112 | print("-" * 50) 113 | print(paragraph_tokens) 114 | print("-" * 50) 115 | print(comp_tokens) 116 | print("-" * 50) 117 | print(retrieval_tokens) 118 | print("=" * 50) 119 | print( 120 | f"compress rate: {comp_rate}, variation_rate: {variation_rate}, alignment_gap: {alignment_gap}" 121 | ) 122 | 123 | return { 124 | "labels": labels, 125 | "matched": matched, 126 | "paragraph_tokens": paragraph_tokens, 127 | "comp_rate": comp_rate, 128 | "find_rate": find_rate, 129 | "variation_rate": variation_rate, 130 | "hitting_rate": hitting_rate, 131 | "matching_rate": matching_rate, 132 | "alignment_gap": alignment_gap, 133 | } 134 | 135 | 136 | def extract_next_hop_sample(sample: dict, sid: str) -> tuple[dict]: 137 | dependency = sample["decomposed_questions"][sid]["dependency"] 138 | 139 | for dependent_id in dependency: 140 | if ( 141 | dependent_id not in sample["decomposed_questions"].keys() 142 | or dependent_id == sid 143 | ): 144 | dependency.remove(dependent_id) 145 | info_list = [ 146 | sample["decomposed_questions"][dep_id]["matched"] for dep_id in dependency 147 | ] 148 | prev_question = sample["decomposed_questions"][dependency[0]]["filtered_query"] 149 | query_info_sentence = build_query_info_sentence(info_list, prev_question) 150 | return query_info_sentence 151 | 152 | 153 | def extract_next_hop_sample_2wiki(sample: dict, sid: str) -> tuple[dict]: 154 | dependency = 
sample["decomposed_questions"][sid]["dependency"] 155 | 156 | for dependent_id in dependency: 157 | if ( 158 | dependent_id not in sample["decomposed_questions"].keys() 159 | or dependent_id == sid 160 | ): 161 | dependency.remove(dependent_id) 162 | # info_list = [ 163 | # sample["decomposed_questions"][dep_id]["matched"] for dep_id in dependency 164 | # ] 165 | info_list = [ 166 | chunk["matched"] 167 | for chunk in sample["decomposed_questions"].values() 168 | if len(chunk["dependency"]) == 0 169 | ] 170 | prev_question = sample["decomposed_questions"][dependency[0]]["filtered_query"] 171 | query_info_sentence = build_query_info_sentence(info_list, prev_question) 172 | return query_info_sentence 173 | 174 | 175 | def main(opts: argparse.Namespace): 176 | data = load_jsonl(opts.data_path) 177 | 178 | infos = { 179 | "comp_rate": 0.0, 180 | "variation_rate": 0.0, 181 | "hitting_rate": 0.0, 182 | "matching_rate": 0.0, 183 | "alignment_gap": 0.0, 184 | "find_rate": 0.0, 185 | } 186 | 187 | num_samples = 0 188 | for sample in tqdm_rich(data): 189 | flag = True 190 | for sid, subq in sample["decomposed_questions"].items(): 191 | if subq.get("filtered_query", None) is None: 192 | flag = False 193 | break 194 | if not flag: 195 | del sample 196 | continue 197 | 198 | for sid, subq in sample["decomposed_questions"].items(): 199 | if len(subq["dependency"]) == 0: 200 | subq["query_info"] = subq["filtered_query"] 201 | continue 202 | 203 | constructed_query = subq["filtered_query"] 204 | if "2wiki" in opts.data_path.lower(): 205 | query_info_pairs = extract_next_hop_sample_2wiki(sample, sid) 206 | else: 207 | query_info_pairs = extract_next_hop_sample(sample, sid) 208 | results = label_word( 209 | query_info_pairs, constructed_query, verbose=opts.verbose 210 | ) 211 | subq["query_info_labels"] = results["labels"] 212 | subq["query_info"] = results["matched"] 213 | subq["query_info_tokens"] = results["paragraph_tokens"] 214 | 215 | num_samples += 1 216 | for k in infos.keys(): 217 | infos[k] += results[k] 218 | 219 | for k, v in infos.items(): 220 | v = v / num_samples * 100 221 | print(f"{k}: {v:.2f}") 222 | 223 | os.makedirs(os.path.dirname(opts.save_path), exist_ok=True) 224 | write_jsonl(data, opts.save_path) 225 | 226 | 227 | def parse_args(): 228 | parser = argparse.ArgumentParser(description="annotate token") 229 | parser.add_argument("--data_path", required=True, type=str) 230 | parser.add_argument("--save_path", required=True, type=str) 231 | parser.add_argument("--verbose", action="store_true", default=False) 232 | args = parser.parse_args() 233 | return args 234 | 235 | 236 | if __name__ == "__main__": 237 | options = parse_args() 238 | main(options) 239 | -------------------------------------------------------------------------------- /src/baseline/retrieve/direct.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | from concurrent.futures import ThreadPoolExecutor, as_completed 5 | from typing import List 6 | 7 | sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) 8 | 9 | import time 10 | 11 | from tqdm.rich import tqdm_rich 12 | 13 | from conf import ( 14 | CORPUS_DATA_PATH, 15 | RETRIEVE_RESULT_PATH, 16 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, 17 | ) 18 | from language_models import LanguageModel, get_model 19 | from retrievers import Retriever 20 | from utils import ask_model, load_jsonl, write_jsonl 21 | 22 | DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA = """ 23 | As an assistant, 
your task is to answer the question based on the given knowledge. Your answer should be after in JSON format with key "answer" and its value should be string. 24 | Your must be wrapped by ```json and ```. 25 | The given knowledge will be embraced by and tags. You can refer to the knowledge to answer the question. If the knowledge does not contain the answer, answer the question directly. 26 | 27 | There are some examples for you to refer to: 28 | 29 | {{KNOWLEDGE FOR YOUR REFERENCE}} 30 | 31 | : Which film came out first, Blind Shaft or The Mask Of Fu Manchu? 32 | : 33 | ```json 34 | {{"answer": "The Mask Of Fu Manchu"}} 35 | ``` 36 | 37 | 38 | {{KNOWLEDGE FOR YOUR REFERENCE}} 39 | 40 | : When did John V, Prince Of Anhalt-Zerbst's father die? 41 | : 42 | ```json 43 | {{"answer": "12 June 1516"}} 44 | ``` 45 | 46 | 47 | {{KNOWLEDGE FOR YOUR REFERENCE}} 48 | 49 | : Which film has the director who was born later, El Extrano Viaje or Love In Pawn? 50 | : 51 | ```json 52 | {{"answer": "El Extrano Viaje"}} 53 | ``` 54 | 55 | Now your question and reference knowledge are as follows. 56 | 57 | {knowledge} 58 | 59 | : {question} 60 | : 61 | """.strip() 62 | 63 | DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE = """ 64 | As an assistant, your task is to answer the question based on the given knowledge. Your answer should be after in JSON format with key "answer" and its value should be string. 65 | Your must be wrapped by ```json and ```. 66 | The given knowledge will be embraced by and tags. You can refer to the knowledge to answer the question. If the knowledge does not contain the answer, answer the question directly. 67 | 68 | There are some examples for you to refer to: 69 | 70 | 71 | {{KNOWLEDGE FOR YOUR REFERENCE}} 72 | 73 | : In which year did the publisher of In Cold Blood form? 74 | : 75 | ```json 76 | {{"answer": "2001"}} 77 | ``` 78 | 79 | 80 | {{KNOWLEDGE FOR YOUR REFERENCE}} 81 | 82 | : Who was in charge of the city where The Killing of a Sacred Deer was filmed? 83 | : 84 | ```json 85 | {{"answer": "John Cranley"}} 86 | ``` 87 | 88 | 89 | {{KNOWLEDGE FOR YOUR REFERENCE}} 90 | 91 | : Where on the Avalon Peninsula is the city that Signal Hill overlooks? 92 | : 93 | ```json 94 | {{"answer": "eastern tip"}} 95 | ``` 96 | 97 | Now your question and reference knowledge are as follows. 98 | 99 | {knowledge} 100 | 101 | : {question} 102 | : 103 | """.strip() 104 | 105 | DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA = """ 106 | Answer the given question in JSON format, you can refer to the document provided. 107 | As an assistant, your task is to answer the question based on the given knowledge. Your answer should be after in JSON format with key "answer" and its value should be string. 108 | Your must be wrapped by ```json and ```. 109 | The given knowledge will be embraced by and tags. You can refer to the knowledge to answer the question. If the knowledge does not contain the answer, answer the question directly. 110 | 111 | There are some examples for you to refer to: 112 | 113 | {{KNOWLEDGE FOR YOUR REFERENCE}} 114 | 115 | : What is the name of this American musician, singer, actor, comedian, and songwriter, who worked with Modern Records and born in December 5, 1932? 116 | : 117 | ```json 118 | {{"answer": "Little Richard"}} 119 | ``` 120 | 121 | 122 | {{KNOWLEDGE FOR YOUR REFERENCE}} 123 | 124 | : Between Chinua Achebe and Rachel Carson, who had more diverse jobs? 
125 | : 126 | ```json 127 | {{"answer": "Chinua Achebe"}} 128 | ``` 129 | 130 | 131 | {{KNOWLEDGE FOR YOUR REFERENCE}} 132 | 133 | : Remember Me Ballin' is a CD single by Indo G that features an American rapper born in what year? 134 | : 135 | ```json 136 | {{"answer": "1979"}} 137 | ``` 138 | 139 | Now your question and reference knowledge are as follows. 140 | 141 | {knowledge} 142 | 143 | : {question} 144 | : 145 | """.strip() 146 | 147 | 148 | def parse_args(): 149 | parser = argparse.ArgumentParser() 150 | parser.add_argument("--dataset", type=str, required=True) 151 | parser.add_argument("--retriever", type=str, required=True) 152 | parser.add_argument("--topk", type=int, default=10) 153 | parser.add_argument("--model", type=str, default="llama-8B") 154 | parser.add_argument("--workers", type=int, default=10) 155 | return parser.parse_args() 156 | 157 | 158 | def process_sample( 159 | model: LanguageModel, 160 | retriever: Retriever, 161 | prompt_template: str, 162 | sample: dict, 163 | topk: int = 10, 164 | ) -> dict: 165 | question = sample["question"] 166 | knowledge = retriever.search(question, top_k=topk)[0] 167 | knowledge_chunk = "\n".join([p["text"] for p in knowledge]) 168 | chunk_ids = [p["id"] for p in knowledge] 169 | prompt = prompt_template.format(knowledge=knowledge_chunk, question=question) 170 | result = ask_model( 171 | model, 172 | prompt, 173 | type="json", 174 | check_if_valid=lambda x: type(x) is dict and "answer" in x, 175 | mode="chat", 176 | ) 177 | sub_ids = sorted(list(sample["decomposed_questions"].keys())) 178 | oracle_ids = [ 179 | f"{sample['id']}-{'{:02d}'.format(sample['decomposed_questions'][sub_id]['positive_paragraph_idx'])}" 180 | for sub_id in sub_ids 181 | ] 182 | result = { 183 | "question_id": sample["id"], 184 | "question": question, 185 | "answer": sample["answer"], 186 | "oracle_ids": oracle_ids, 187 | "chunk_ids": chunk_ids, 188 | "model_output": result["answer"], 189 | } 190 | return result 191 | 192 | 193 | def process_dataset( 194 | model: LanguageModel, 195 | retriever: Retriever, 196 | prompt_template: str, 197 | topk: int, 198 | dataset: List[dict], 199 | max_workers: int = 10, 200 | ): 201 | with ThreadPoolExecutor(max_workers=max_workers) as executor: 202 | tasks = { 203 | executor.submit( 204 | process_sample, model, retriever, prompt_template, sample, topk 205 | ): idx 206 | for idx, sample in enumerate(dataset) 207 | } 208 | results = [] 209 | for future in tqdm_rich(as_completed(tasks), total=len(dataset)): 210 | idx = tasks[future] 211 | try: 212 | res = future.result() 213 | results.append((idx, res)) 214 | except Exception as e: 215 | print(f"Error processing sample {idx}: {e}") 216 | results = [r[1] for r in sorted(results, key=lambda x: x[0])] 217 | return results 218 | 219 | 220 | def main(opt: argparse.Namespace): 221 | start = time.time() 222 | passage_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "corpus.jsonl") 223 | if opt.retriever == "e5-base-v2": 224 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "e5-base") 225 | elif opt.retriever == "contriever": 226 | embedding_path = os.path.join(CORPUS_DATA_PATH, opt.dataset, "contriever") 227 | retriever = Retriever( 228 | passage_path=passage_path, 229 | passage_embedding_path=embedding_path, 230 | index_path_dir=embedding_path, 231 | model_type=opt.retriever, 232 | ) 233 | dataset = load_jsonl( 234 | os.path.join( 235 | SYNTHESIZED_NEXT_QUERY_EXTRACTED_DATA_PATH, opt.dataset, "valid.jsonl" 236 | ) 237 | ) 238 | 239 | prompt_template_mapping = { 240 | 
"musique": DIRECT_RETRIEVE_ANSWER_PROMPT_MUSIQUE, 241 | "2WikiMQA": DIRECT_RETRIEVE_ANSWER_PROMPT_WIKIMQA, 242 | "hotpotQA": DIRECT_RETRIEVE_ANSWER_PROMPT_HOTPOTQA, 243 | } 244 | prompt_template = prompt_template_mapping[opt.dataset] 245 | model: LanguageModel 246 | model = get_model(opt.model) 247 | results = process_dataset( 248 | model, retriever, prompt_template, opt.topk, dataset, max_workers=opt.workers 249 | ) 250 | 251 | end = time.time() 252 | print(f"Retrieval time: {end - start:.2f}s") 253 | print(f"Average retrieval time: {(end - start) / len(dataset):.2f}s") 254 | output_path = os.path.join( 255 | RETRIEVE_RESULT_PATH, "direct", f"{opt.dataset}-@{opt.topk}.jsonl" 256 | ) 257 | write_jsonl(results, output_path) 258 | 259 | 260 | if __name__ == "__main__": 261 | options = parse_args() 262 | main(options) 263 | -------------------------------------------------------------------------------- /src/data_module/dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pprint import pprint 4 | from typing import Any, Literal, NamedTuple, Union 5 | 6 | from torch.utils.data import Dataset 7 | 8 | from conf import DATASET_PATH 9 | 10 | ChunkInfo = NamedTuple("ChunkInfo", id=int, title=str, chunk=str) 11 | AllDatasets = Literal["hotpotQA", "wikiMQA", "musique"] 12 | 13 | 14 | class MultiHopDataset(Dataset): 15 | """ 16 | Base class for multi-hop datasets 17 | return params: 18 | @id: str, unique id for the question sample 19 | @type: str, multi-hop type of the question 20 | - compose: Q(explicit) -> A -> B 21 | - compare: Q -> A, Q -> B, (A, B) 22 | - inference: Q(implicit) -> A -> B 23 | - bridge_compare: Q -> A1 -> A2, Q -> B1 -> B2, (A2, B2) 24 | @question: str, the question 25 | @chunks: list[str], list of chunks 26 | @supporting_facts: list[ChunkInfo], list of supporting facts 27 | @decomposition: 28 | """ 29 | 30 | def __init__(self, data_path: str) -> None: 31 | super().__init__() 32 | self.data = self.load_data(data_path) 33 | 34 | def __getitem__(self, index): 35 | samples = self.data[index] 36 | if isinstance(samples, list): 37 | return [self.process(sample) for sample in samples] 38 | else: 39 | return self.process(samples) 40 | 41 | def process(self, sample: dict[str, Any]) -> dict[str, Any]: 42 | question = self.get_question(sample) 43 | # TBD: Whether to remove the question mark 44 | # if question.endswith("?"): 45 | # question = question[:-1] 46 | 47 | chunks = self.get_chunks(sample) 48 | chunks = [chunk.chunk for chunk in chunks] 49 | answer = self.get_answer(sample) 50 | supporting_facts = self.get_supporting_facts(sample) 51 | supporting_facts = [fact._asdict() for fact in supporting_facts] 52 | 53 | decomposition = self.get_decomposition(sample) 54 | return { 55 | "id": self.get_id(sample), 56 | "type": self.get_type(sample), 57 | "question": question, 58 | "answer": answer, 59 | "chunks": chunks, 60 | "supporting_facts": supporting_facts, 61 | "decomposition": decomposition, 62 | } 63 | 64 | def __len__(self) -> int: 65 | return len(self.data) 66 | 67 | def get_question(self, sample: dict) -> str: 68 | return sample["question"] 69 | 70 | def get_decomposition(self, sample: dict) -> list[str]: 71 | return None 72 | 73 | def get_answer(self, sample: dict) -> str: 74 | return sample["answer"] 75 | 76 | def get_answer(self, sample: dict) -> str: 77 | return sample["answer"] 78 | 79 | def get_id(self, sample: dict) -> str: 80 | return sample["id"] 81 | 82 | def get_type(self, sample: dict) -> str: 
83 | raise NotImplementedError 84 | 85 | def get_supporting_facts(self, sample: dict) -> list[ChunkInfo]: 86 | raise NotImplementedError 87 | 88 | def get_chunks(self, sample: dict) -> list[ChunkInfo]: 89 | raise NotImplementedError 90 | 91 | def load_data(self, data_path: str): 92 | with open(data_path, "r", encoding="utf-8") as f: 93 | data = json.load(f) 94 | return data 95 | 96 | def process_chunk(self, title: str, chunk: Union[str, list[str]]) -> str: 97 | if isinstance(chunk, list): 98 | chunk = " ".join(chunk) 99 | return f"{title}: {chunk}" 100 | # Title of chunk: chunk chunk chunk 101 | 102 | def collate_fn(self, batch: list[dict[str, Any]]): 103 | raise NotImplementedError 104 | 105 | 106 | class HotpotQADataset(MultiHopDataset): 107 | def __init__(self, data_path: str) -> None: 108 | super().__init__(data_path) 109 | 110 | def get_supporting_facts(self, sample: dict) -> list[ChunkInfo]: 111 | titles = list(set(sample["supporting_facts"]["title"])) 112 | supporting_facts = [] 113 | for title in titles: 114 | idx = sample["context"]["title"].index(title) 115 | sentences = sample["context"]["sentences"][idx] 116 | chunk = self.process_chunk(title, sentences) 117 | info = ChunkInfo(idx, title, chunk) 118 | supporting_facts.append(info) 119 | return supporting_facts 120 | 121 | def get_chunks(self, sample: dict) -> list[ChunkInfo]: 122 | chunks = [] 123 | for idx, (title, context) in enumerate( 124 | zip(sample["context"]["title"], sample["context"]["sentences"]) 125 | ): 126 | chunk = self.process_chunk(title, context) 127 | info = ChunkInfo(idx, title, chunk) 128 | chunks.append(info) 129 | return chunks 130 | 131 | def get_type(self, sample: dict) -> str: 132 | type = sample["type"] 133 | if type == "bridge": 134 | return "compose" 135 | return type 136 | 137 | def get_hop(self, sample: dict) -> str: 138 | return 2 139 | 140 | 141 | class WikiMQADataset(MultiHopDataset): 142 | def __init__(self, data_path: str) -> None: 143 | super().__init__(data_path) 144 | 145 | def get_supporting_facts(self, sample: dict) -> list[ChunkInfo]: 146 | facts = sample["supporting_facts"] 147 | titles = list(set([fact[0] for fact in facts])) 148 | chunks = self.get_chunks(sample) 149 | facts = [chunk for chunk in chunks if chunk.title in titles] 150 | return facts 151 | 152 | def get_chunks(self, sample: dict) -> list[ChunkInfo]: 153 | chunks = [] 154 | for idx, content in enumerate(sample["context"]): 155 | title = content[0] 156 | chunk = self.process_chunk(content[0], content[1]) 157 | info = ChunkInfo(idx, title, chunk) 158 | chunks.append(info) 159 | return chunks 160 | 161 | def get_type(self, sample: dict) -> str: 162 | return sample["type"] 163 | 164 | def get_hop(self, sample: dict) -> str: 165 | if self.get_type(sample) == "bridge_compare": 166 | return 4 167 | return 2 168 | 169 | def get_id(self, sample: dict) -> str: 170 | return sample["_id"] 171 | 172 | def get_decomposition(self, sample: dict): 173 | chunks = self.get_chunks(sample) 174 | decompositions = [] 175 | if len(sample["supporting_facts"]) != len(sample["evidences"]): 176 | return [] 177 | for supporting_fact, evidence in zip( 178 | sample["supporting_facts"], sample["evidences"] 179 | ): 180 | title, idx = supporting_fact[0], supporting_fact[1] 181 | evidence = f"{evidence[0]} - {evidence[1]} - {evidence[2]}" 182 | fact_chunk = next(chunk for chunk in chunks if chunk.title == title) 183 | decomposition = { 184 | # 'chunk': chunks[idx].chunk, 185 | "id": fact_chunk.id, 186 | "chunk": fact_chunk.chunk, 187 | "title": title, 
188 | "evidence": evidence, 189 | } 190 | decompositions.append(decomposition) 191 | return decompositions 192 | 193 | 194 | class MuSiQueDataset(MultiHopDataset): 195 | def __init__(self, data_path) -> None: 196 | super().__init__(data_path) 197 | 198 | def get_supporting_facts(self, sample: dict) -> list[ChunkInfo]: 199 | chunks = self.get_chunks(sample) 200 | is_supports = sample["paragraphs"]["is_supporting"] 201 | supporting_facts = [] 202 | for is_support, chunk in zip(is_supports, chunks): 203 | if is_support: 204 | supporting_facts.append(chunk) 205 | return supporting_facts 206 | 207 | def get_chunks(self, sample: dict) -> list[ChunkInfo]: 208 | chunks = [] 209 | for idx, title, chunk in zip( 210 | sample["paragraphs"]["idx"], 211 | sample["paragraphs"]["title"], 212 | sample["paragraphs"]["paragraph_text"], 213 | ): 214 | chunk = self.process_chunk(title, chunk) 215 | info = ChunkInfo(idx, title, chunk) 216 | chunks.append(info) 217 | return chunks 218 | 219 | def get_decomposition(self, sample: dict) -> list[str]: 220 | return sample["question_decomposition"] 221 | 222 | def get_type(self, sample: dict) -> str: 223 | # TODO this is a difficult question 224 | return None 225 | 226 | def get_hop(self, sample: dict) -> str: 227 | return int(sample[id][0]) 228 | 229 | 230 | def get_dataset(dataset: AllDatasets, split: str) -> MultiHopDataset: 231 | if dataset == "hotpotQA": 232 | return HotpotQADataset(os.path.join(DATASET_PATH, "hotpotQA", f"{split}.json")) 233 | elif dataset == "2WikiMQA": 234 | return WikiMQADataset(os.path.join(DATASET_PATH, "2WikiMQA", f"{split}.json")) 235 | elif "musique" in dataset: 236 | return MuSiQueDataset(os.path.join(DATASET_PATH, dataset, f"{split}.json")) 237 | else: 238 | raise ValueError(f"Unknown dataset: {dataset}") 239 | 240 | 241 | def main(): 242 | dataset = get_dataset("2WikiMQA", "demo") 243 | pprint(dataset[0]) 244 | 245 | 246 | if __name__ == "__main__": 247 | main() 248 | -------------------------------------------------------------------------------- /src/data_synthesize/prompts/span_labeling.py: -------------------------------------------------------------------------------- 1 | SPAN_LABELING_SYSTEM_PROMPT = """ 2 | You are an outstanding linguistic, and very good at identifying information following the user instructions. 3 | """.strip() 4 | 5 | SPAN_LABELING_PROMPT = """ 6 | You are assigned an information labeling task. 7 | We will show you a multi-hop question, a single-hop question, a document and the answer to the single-hop question. You should identify the span in the document that contains the answer to the single-hop question, and the span in the multi-hop question that contains the single-hop question. 8 | You should embrace the span from the multi-hop question with and tags, containing the single-hop question. And the span from the document with and tags, containing the answer to the single-hop question. 9 | You first should think step by step, finding out the single-hop question span and corresponding answer, and then figure out the rephrased question, which represents the multi-hop question with the single-hop question span replaced by the answer span. Try to make the rephrased question more fluent by modifying the question span. 10 | Your response should be in JSON format and include the following keys: 11 | - "labeled_document": a string representing the document with the answer span embraced by and . 
12 | - "labeled_question": a string representing the multi-hop question with the single-hop question span embraced by and . If the multi-hop question share the same meaning with the single-hop question, embrace the whole question. 13 | 14 | Please adhere to the following guidelines: 15 | - Do not reorder, change, or add words. All the words from your response should be present in the document or the multi-hop question. 16 | - You must label both the multi-hop question and the document. 17 | - You must label ONLY ONE span in the document and the multi-hop question. 18 | 19 | Multi-hop Question: "What does the name of the organization the Haiti national football team belongs to stand for?" 20 | Single-hop Question: "What organization is the Haiti national football team a member of?" 21 | Document: "2014 Kosovo v Haiti football match: Kosovo vs Haiti was the first international match involving the Kosovar national football team to be recognised by FIFA, and the first to take place within Kosovo. The match was an international friendly between representative teams from Kosovo and Haiti." 22 | Answer: "FIFA" 23 | Your response: 24 | Thought: The single-hop question span should cover "organization the Haiti national football team belongs to", whose answer is FIFA. The rephrased question should be "What does FIFA stand for?", so the labeled question should be "What does the name of the organization the Haiti national football team belongs to stand for". 25 | ```json 26 | {{ 27 | "labeled_document": "2014 Kosovo v Haiti football match: Kosovo vs Haiti was the first international match involving the Kosovar national football team to be recognised by FIFA, and the first to take place within Kosovo. The match was an international friendly between representative teams from Kosovo and Haiti.", 28 | "labeled_question": "What does the name of the organization the Haiti national football team belongs to stand for?" 29 | }} 30 | ``` 31 | 32 | Multi-hop Question: "What does FIFA stand for?" 33 | Single-hop Question: "What is the meaning of FIFA?" 34 | Document: "Switzerland: Swiss are fans of football and the national team is nicknamed the 'Nati'. The headquarters of the sport's governing body, the International Federation of Association Football (FIFA), is located in Z\u00fcrich. Switzerland hosted the 1954 FIFA World Cup, and was the joint host, with Austria, of the Euro 2008 tournament. The Swiss Super League is the nation's professional club league. For the Brasil 2014 World Cup finals tournament, the country's German-speaking cantons will be closely monitored by local police forces to prevent celebrations beyond one hour after matches end. Europe's highest football pitch, at 2,000 metres (6,600 ft) above sea level, is located in Switzerland and is named the Ottmar Hitzfeld Stadium." 35 | Answer: "International Federation of Association Football" 36 | Your response: 37 | Thought: The question is already a single-hop question, which shares the same meaning with the single-hop question. So we should embrace the whole queston. The labeled question should be "What does FIFA stand for?". 38 | ```json 39 | {{ 40 | "labeled_document": "Switzerland: Swiss are fans of football and the national team is nicknamed the 'Nati'. The headquarters of the sport's governing body, the International Federation of Association Football (FIFA), is located in Z\u00fcrich. Switzerland hosted the 1954 FIFA World Cup, and was the joint host, with Austria, of the Euro 2008 tournament. 
The Swiss Super League is the nation's professional club league. For the Brasil 2014 World Cup finals tournament, the country's German-speaking cantons will be closely monitored by local police forces to prevent celebrations beyond one hour after matches end. Europe's highest football pitch, at 2,000 metres (6,600 ft) above sea level, is located in Switzerland and is named the Ottmar Hitzfeld Stadium.", 41 | "labeled_question": "What does FIFA stand for?" 42 | }} 43 | ``` 44 | 45 | Multi-hop Question: "The military group of which the Air Defense Artillery is a branch was unprepared for the invasion of the territory the Nazis occupied. The country of this group was the only communist country to have an embassy where?" 46 | Single-hop Question: "The Air Defense Artillery is a branch of what" 47 | Document: "United States Army: Currently, the army is divided into the Regular Army, the Army Reserve, and the Army National Guard. The army is also divided into major branches such as Air Defense Artillery, Infantry, Aviation, Signal Corps, Corps of Engineers, and Armor. Before 1903 members of the National Guard were considered state soldiers unless federalized (i.e., activated) by the President. Since the Militia Act of 1903 all National Guard soldiers have held dual status: as National Guardsmen under the authority of the governor of their state or territory and, when activated, as a reserve of the U.S. Army under the authority of the President." 48 | Answer: "the US Army" 49 | Your response: 50 | Thought: The single-hop question span should cover "the Air Defense Artillery is a branch of what", whose answer is "the US Army". The answer "the US Army" is corresponding to "United States Army" from the document. So the rephrased question should be "The military group of United States Army was unprepared for the invasion of the territory the Nazis occupied. The country of this group was the only communist country to have an embassy where?". The labeled question should be "The military group of which the Air Defense Artillery is a branch was unprepared for the invasion of the territory the Nazis occupied. The country of this group was the only communist country to have an embassy where?". 51 | ```json 52 | {{ 53 | "labeled_document": "United States Army: Currently, the army is divided into the Regular Army, the Army Reserve, and the Army National Guard. The army is also divided into major branches such as Air Defense Artillery, Infantry, Aviation, Signal Corps, Corps of Engineers, and Armor. Before 1903 members of the National Guard were considered state soldiers unless federalized (i.e., activated) by the President. Since the Militia Act of 1903 all National Guard soldiers have held dual status: as National Guardsmen under the authority of the governor of their state or territory and, when activated, as a reserve of the U.S. Army under the authority of the President.", 54 | "labeled_question": "The military group of which the Air Defense Artillery is a branch was unprepared for the invasion of the territory the Nazis occupied. The country of this group was the only communist country to have an embassy where?" 55 | }} 56 | ``` 57 | 58 | Multi-hop Question: "The military group of United States Army was unprepared for the invasion of Czechoslovakia. The country of this group was the only communist country to have an embassy where?" 59 | Single-hop Question: "What's the country of the Army that was unprepared for the invasion of Czechoslovakia?" 
60 | Document: "Josip Broz Tito: In 1968, Tito offered Czechoslovak leader Alexander Dub\u010dek to fly to Prague on three hours notice if Dub\u010dek needed help in facing down the Soviets. In April 1969, Tito removed generals Ivan Go\u0161njak and Rade Hamovi\u0107 in the aftermath of the invasion of Czechoslovakia due to the unpreparedness of the Yugoslav army to respond to a similar invasion of Yugoslavia." 61 | Answer: "Yugoslavia" 62 | Your response: 63 | Thought: The single-hop question span should cover "The country of the Army unprepared for the invasion of Czechoslovakia?", whose answer is "Yugoslavia". So the rephrased question should be "Yugoslavia was the only communist country to have an embassy where?". The labeled question should be "The military group of United States Army was unprepared for the invasion of Czechoslovakia. The country of this group was the only communist country to have an embassy where?". 64 | ```json 65 | {{ 66 | "labeled_document": "Josip Broz Tito: In 1968, Tito offered Czechoslovak leader Alexander Dub\u010dek to fly to Prague on three hours notice if Dub\u010dek needed help in facing down the Soviets. In April 1969, Tito removed generals Ivan Go\u0161njak and Rade Hamovi\u0107 in the aftermath of the invasion of Czechoslovakia due to the unpreparedness of the Yugoslav army to respond to a similar invasion of Yugoslavia .", 67 | "labeled_question": "The military group of United States Army was unprepared for the invasion of Czechoslovakia. The country of this group was the only communist country to have an embassy where?" 68 | }} 69 | ``` 70 | 71 | Multi-hop Question: "{multi_hop_question}" 72 | Single-hop Question: "{single_hop_question}" 73 | Document: "{document}" 74 | Answer: "{answer}" 75 | Your response: 76 | """.strip() 77 | -------------------------------------------------------------------------------- /src/data_synthesize/span_labeling.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | import re 5 | import sys 6 | from concurrent.futures import ThreadPoolExecutor, as_completed 7 | 8 | from tqdm.rich import tqdm_rich 9 | 10 | sys.path.append(os.path.join(os.path.dirname(__file__), "..")) 11 | 12 | from prompts.span_labeling import SPAN_LABELING_PROMPT, SPAN_LABELING_SYSTEM_PROMPT 13 | 14 | from conf import ( 15 | MODEL_DICT, 16 | SYNTHESIZED_DECOMPOSED_DATA_PATH, 17 | SYNTHESIZED_SPAN_LABELING_DATA_PATH, 18 | ) 19 | from language_models import AOAI, LlamaServer 20 | from utils import ask_model, load_jsonl 21 | 22 | BEGIN_OF_QUESTION_SPAN_TOKEN = "" 23 | END_OF_QUESTION_SPAN_TOKEN = "" 24 | BEGIN_OF_ANSWER_SPAN_TOKEN = "" 25 | END_OF_ANSWER_SPAN_TOKEN = "" 26 | 27 | QUESTION_SPAN_TOKEN_PATTERN = ( 28 | rf"{BEGIN_OF_QUESTION_SPAN_TOKEN}(.+?){END_OF_QUESTION_SPAN_TOKEN}" 29 | ) 30 | ANSWER_SPAN_TOKEN_PATTERN = ( 31 | rf"{BEGIN_OF_ANSWER_SPAN_TOKEN}(.+?){END_OF_ANSWER_SPAN_TOKEN}" 32 | ) 33 | 34 | 35 | class SpanLabeler: 36 | def __init__(self, model: str, dataset: str, split: str) -> None: 37 | if "gpt" in model: 38 | self.model = AOAI(model) 39 | elif "Llama" in model: 40 | self.model = LlamaServer(model) 41 | 42 | decomposed_data_path = os.path.join( 43 | SYNTHESIZED_DECOMPOSED_DATA_PATH, dataset, f"{split}.jsonl" 44 | ) 45 | self.data = load_jsonl(decomposed_data_path) 46 | self.data_mapping = {d["id"]: d for d in self.data} 47 | # self.check_if_valid = lambda x: True 48 | self.check_if_valid = ( 49 | lambda x: all( 50 | [ 51 | k in x.keys() 52 | for k in 
[ 53 | "labeled_question", 54 | "labeled_document", 55 | ] 56 | ] 57 | ) 58 | and x["labeled_question"].find(BEGIN_OF_QUESTION_SPAN_TOKEN) >= 0 59 | and x["labeled_question"].find(END_OF_QUESTION_SPAN_TOKEN) >= 0 60 | and x["labeled_document"].find(BEGIN_OF_ANSWER_SPAN_TOKEN) >= 0 61 | and x["labeled_document"].find(END_OF_ANSWER_SPAN_TOKEN) >= 0 62 | ) 63 | 64 | def parse(self, starting: int = 0, ending: int = None, workers=10) -> list[dict]: 65 | if ending is None: 66 | ending = len(self.data) 67 | data = self.data[starting:ending] 68 | data = [d for d in data if d.get("state", None) is None] 69 | 70 | if workers > 1: 71 | with ThreadPoolExecutor(max_workers=workers) as executor: 72 | tasks = { 73 | executor.submit(self.parse_sample, sample): idx 74 | for idx, sample in enumerate(data) 75 | } 76 | results = [] 77 | for future in tqdm_rich( 78 | as_completed(tasks), total=len(tasks), desc="Processing..." 79 | ): 80 | task_id = tasks[future] 81 | try: 82 | result = future.result() 83 | results.append((task_id, result)) 84 | finally: 85 | ... 86 | results = [result[1] for result in sorted(results, key=lambda x: x[0])] 87 | else: 88 | results = [ 89 | self.parse_sample(sample) 90 | for sample in tqdm_rich(data, desc="Processing...") 91 | ] 92 | return results 93 | 94 | def parse_sample( 95 | self, 96 | sample: dict, 97 | ) -> dict: 98 | # First hop 99 | for subq_id, chunk in sample["decomposed_questions"].items(): 100 | if len(chunk["dependency"]) == 0: 101 | chunk["current_question"] = sample["question"] 102 | 103 | max_iter = 5 104 | cur_iter = 0 105 | while cur_iter < max_iter and not all( 106 | [ 107 | "current_question" in subq.keys() 108 | for subq in sample["decomposed_questions"].values() 109 | ] 110 | ): 111 | cur_iter += 1 112 | # construct filtered_query for each sub-question 113 | prompt_list, subq_id_list, current_question_list = self.parse_prompt(sample) 114 | if len(prompt_list) == 0: 115 | break 116 | 117 | for prompt, subq_id, current_question in zip( 118 | prompt_list, subq_id_list, current_question_list 119 | ): 120 | result = ask_model( 121 | self.model, 122 | prompt, 123 | SPAN_LABELING_SYSTEM_PROMPT, 124 | type="json", 125 | check_if_valid=self.check_if_valid, 126 | sleep=False, 127 | ) 128 | chunk = sample["decomposed_questions"][subq_id] 129 | chunk["current_question"] = current_question 130 | try: 131 | assert "span_status" not in chunk.keys(), "The chunk parse failed" 132 | assert result is not None, "result is None" 133 | result = self.parse_result(result) 134 | except Exception as e: 135 | print(e) 136 | chunk["span_status"] = "error" 137 | continue 138 | 139 | for k, v in result.items(): 140 | chunk[k] = v 141 | 142 | return sample 143 | 144 | def parse_result(self, result: dict) -> dict: 145 | labeled_question = result["labeled_question"] 146 | labeled_document = result["labeled_document"] 147 | 148 | part_of_question = re.search( 149 | QUESTION_SPAN_TOKEN_PATTERN, labeled_question 150 | ).group(1) 151 | part_of_document = re.search(ANSWER_SPAN_TOKEN_PATTERN, labeled_document).group( 152 | 1 153 | ) 154 | next_question = ( 155 | labeled_question.replace(part_of_question, part_of_document) 156 | .replace(BEGIN_OF_QUESTION_SPAN_TOKEN, "") 157 | .replace(END_OF_QUESTION_SPAN_TOKEN, "") 158 | ) 159 | return { 160 | "labeled_question": labeled_question, 161 | "labeled_document": labeled_document, 162 | "part_of_question": part_of_question, 163 | "part_of_document": part_of_document, 164 | "next_question": next_question, 165 | } 166 | 167 | def parse_prompt(self, 
sample: dict) -> list[dict]: 168 | def construct_subq(subq_id, dependencies): 169 | if len(dependencies) == 0: 170 | return sample["question"] 171 | 172 | cur_question = None 173 | for dep in dependencies: 174 | if dep in sample["decomposed_questions"].keys(): 175 | subq = sample["decomposed_questions"][dep] 176 | part_of_question = subq.get("part_of_question", None) 177 | part_of_document = subq.get("part_of_document", None) 178 | if part_of_question is None or part_of_document is None: 179 | return None 180 | 181 | if cur_question is None: 182 | cur_question = subq["current_question"] 183 | cur_question = cur_question.replace( 184 | part_of_question, part_of_document 185 | ) 186 | return cur_question 187 | 188 | prompt_list = [] 189 | subq_id_list = [] 190 | current_question_list = [] 191 | 192 | for subq_id in sorted(sample["decomposed_questions"].keys()): 193 | subq = sample["decomposed_questions"][subq_id] 194 | deps = subq["dependency"] 195 | if "next_question" in subq.keys(): 196 | # already labeled 197 | continue 198 | 199 | question = construct_subq(subq_id, deps) 200 | if question is None: 201 | # not ready to be labeled 202 | continue 203 | 204 | sub_question = subq["sub_question"] 205 | paragraph = subq["positive_paragraph"] 206 | sub_answer = subq["answer"] 207 | 208 | prompt = SPAN_LABELING_PROMPT.format( 209 | multi_hop_question=question, 210 | single_hop_question=sub_question, 211 | document=paragraph, 212 | answer=sub_answer, 213 | ) 214 | prompt_list.append(prompt) 215 | subq_id_list.append(subq_id) 216 | current_question_list.append(question) 217 | return prompt_list, subq_id_list, current_question_list 218 | 219 | 220 | def parse_args(): 221 | parser = argparse.ArgumentParser() 222 | parser.add_argument( 223 | "--dataset", 224 | type=str, 225 | choices=["hotpotQA", "musique", "2WikiMQA"], 226 | default="musique", 227 | ) 228 | parser.add_argument("--split", type=str, default="demo") 229 | parser.add_argument("--model", choices=["gpt35", "gpt4", "llama"], default="llama") 230 | parser.add_argument("--workers", type=int, default=10, help="parallel processors") 231 | parser.add_argument("--starting", type=int, default=0) 232 | parser.add_argument("--ending", type=int, default=None) 233 | args = parser.parse_args() 234 | return args 235 | 236 | 237 | def main(opts: argparse.Namespace): 238 | model = MODEL_DICT[opts.model] 239 | span_labeler = SpanLabeler(model, opts.dataset, opts.split) 240 | save_path = os.path.join( 241 | SYNTHESIZED_SPAN_LABELING_DATA_PATH, opts.dataset, f"{opts.split}.jsonl" 242 | ) 243 | with open(save_path, "w+", encoding="utf-8") as f: 244 | for labeled in span_labeler.parse(workers=opts.workers): 245 | info = json.dumps(labeled, ensure_ascii=False) 246 | f.write(info + "\n") 247 | 248 | 249 | if __name__ == "__main__": 250 | options = parse_args() 251 | main(options) 252 | --------------------------------------------------------------------------------
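
The tail of src/efficient_rag/labeler_training.py selects a per-dataset pair of class weights (WEIGHT_AVERAGE) before constructing the LabelerTrainer. The trainer's actual loss is defined elsewhere in the repository; the snippet below is only a minimal sketch of how such a two-element weight vector is commonly applied to a binary token-classification cross-entropy, under the assumption that the two weights correspond to the two token labels.

```python
# Illustrative only: one plausible way a per-dataset WEIGHT_AVERAGE pair could
# weight a binary token-classification loss. LabelerTrainer's real compute_loss
# may differ.
import torch
import torch.nn.functional as F

def weighted_token_loss(
    logits: torch.Tensor,          # (batch, seq_len, num_token_labels)
    labels: torch.Tensor,          # (batch, seq_len); -100 marks ignored positions
    class_weights=(0.51, 25.45),   # e.g. the musique entry of WEIGHT_AVERAGE
) -> torch.Tensor:
    weight = torch.tensor(class_weights, dtype=logits.dtype, device=logits.device)
    return F.cross_entropy(
        logits.view(-1, logits.size(-1)),
        labels.view(-1),
        weight=weight,
        ignore_index=-100,
    )
```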
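
src/data_synthesize/next_hop_query_filtering.py aligns the extracted tokens back onto the source paragraph with a lemma-matched sliding window (label_word) and reports compression and alignment statistics. A minimal usage sketch follows; it assumes spaCy and the en_core_web_sm model are installed, and the import path is illustrative (run from src/data_synthesize/).

```python
# Usage sketch for label_word; the returned dict carries the per-token labels
# plus the compression/alignment statistics that main() averages over samples.
from next_hop_query_filtering import label_word  # illustrative import path

paragraph = (
    "Kosovo vs Haiti was the first international match involving the Kosovar "
    "national football team to be recognised by FIFA."
)
extracted = "first international match recognised by FIFA"

result = label_word(paragraph, extracted, verbose=False)
print(result["matched"])         # paragraph tokens aligned to the extraction
print(result["find_rate"])       # share of extracted tokens found in the paragraph
print(result["alignment_gap"])   # hitting_rate - matching_rate; gaps > 0.1 are dumped for inspection
```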
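
The direct-retrieval baseline (src/baseline/retrieve/direct.py) wires a dense Retriever to a corpus laid out as corpus.jsonl plus a per-model embedding/index directory. The sketch below isolates that call pattern; the paths are placeholders and assume the embeddings and index were built beforehand.

```python
# Call pattern used by the retrieval baselines; paths are placeholders for a
# prepared corpus directory.
from retrievers import Retriever

retriever = Retriever(
    passage_path="data/corpus/hotpotQA/corpus.jsonl",
    passage_embedding_path="data/corpus/hotpotQA/e5-base",
    index_path_dir="data/corpus/hotpotQA/e5-base",
    model_type="e5-base-v2",
)
passages = retriever.search("Who founded the publisher of In Cold Blood?", top_k=10)[0]
knowledge = "\n".join(p["text"] for p in passages)   # stitched into the prompt
chunk_ids = [p["id"] for p in passages]              # kept for recall evaluation
```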
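
All three direct-answer prompts instruct the model to emit {"answer": ...} inside a ```json fence, and process_sample relies on ask_model(type="json") plus a check_if_valid callback to accept only well-formed outputs. The helper below is an illustrative stand-in for that parsing step, not the repository's ask_model implementation.

```python
# Stand-in for the JSON-fence parsing that ask_model(type="json") is assumed to
# perform, with the same validity check as process_sample.
import json
import re

FENCE = "`" * 3  # avoid embedding a literal code fence inside this snippet
JSON_BLOCK = re.compile(FENCE + r"json\s*(\{.*?\})\s*" + FENCE, re.DOTALL)

def extract_json_answer(model_output: str):
    match = JSON_BLOCK.search(model_output)
    if match is None:
        return None
    try:
        parsed = json.loads(match.group(1))
    except json.JSONDecodeError:
        return None
    return parsed if isinstance(parsed, dict) and "answer" in parsed else None

reply = FENCE + 'json\n{"answer": "Little Richard"}\n' + FENCE
print(extract_json_answer(reply))  # {'answer': 'Little Richard'}
```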
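
The data_module datasets normalise HotpotQA, 2WikiMQA and MuSiQue records into one processed shape (see MultiHopDataset.process). A small usage sketch, assuming the raw split files exist under conf.DATASET_PATH in the layout dataset.py expects:

```python
# Mirrors the module's own main(): load a split and inspect one processed sample.
from pprint import pprint

from data_module.dataset import get_dataset

dataset = get_dataset("2WikiMQA", "demo")   # expects DATASET_PATH/2WikiMQA/demo.json
sample = dataset[0]
pprint({k: sample[k] for k in ("id", "type", "question", "answer")})
print(len(sample["chunks"]), "chunks;", len(sample["supporting_facts"]), "supporting facts")
```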
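
SpanLabeler.parse_result in src/data_synthesize/span_labeling.py builds the next-hop question by swapping the labeled single-hop span in the multi-hop question for the labeled answer span in the document. The sketch below reproduces that substitution with hypothetical <q>...</q> and <a>...</a> tags standing in for the module's BEGIN/END span-token constants.

```python
# Minimal sketch of the span substitution; <q>/<a> are placeholder tags, not
# the actual span tokens defined in span_labeling.py.
import re

Q_SPAN = re.compile(r"<q>(.+?)</q>")
A_SPAN = re.compile(r"<a>(.+?)</a>")

def build_next_question(labeled_question: str, labeled_document: str) -> str:
    q_span = Q_SPAN.search(labeled_question).group(1)
    a_span = A_SPAN.search(labeled_document).group(1)
    return (
        labeled_question.replace(q_span, a_span)
        .replace("<q>", "")
        .replace("</q>", "")
    )

print(build_next_question(
    "What does <q>the name of the organization the Haiti national football "
    "team belongs to</q> stand for?",
    "Kosovo vs Haiti was the first international match to be recognised by "
    "<a>FIFA</a>.",
))
# -> "What does FIFA stand for?"
```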