├── data_curation ├── __init__.py ├── scripts │ ├── disambiguate.sh │ ├── decompose.sh │ └── multiturn.sh ├── llm_agent │ ├── __init__.py │ ├── prompt_template │ │ ├── __init__.py │ │ ├── QueryGeneratorTemplate.txt │ │ ├── UnambiguousGeneratorTemplate.txt │ │ ├── Template.py │ │ ├── QueryRewriterTemplate.txt │ │ ├── QAGeneratorTemplate.txt │ │ ├── QueryJudgerTemplate.txt │ │ ├── UnambiguousGeneratorTemplateLong.txt │ │ ├── UnambiguousGeneratorTemplateShort.txt │ │ ├── DecomposeGeneratorTemplate_old.txt │ │ ├── DecomposeGeneratorTemplate.txt │ │ ├── MultiTurnGeneratorTemplateForAns.txt │ │ └── MultiTurnGeneratorTemplate.txt │ ├── openai_api.py │ ├── base.py │ ├── openai_judger.py │ ├── openai_rewriter.py │ ├── openai_generator.py │ ├── openai_decomposer.py │ ├── openai_multiturn_retriever_rewriter.py │ └── datatypes.py ├── tools │ ├── duckduckgo_rapidapi.py │ ├── bm25_candidates.py │ ├── bingsearch_azure.py │ └── openai_embedding_search.py └── main_unambiguous_answer_generate.py ├── retrieval_lm ├── __init__.py ├── src │ ├── __init__.py │ ├── index.py │ ├── inbatch.py │ ├── dist_utils.py │ ├── slurm.py │ ├── normalize_text.py │ ├── moco.py │ ├── contriever.py │ ├── finetuning_data.py │ ├── options.py │ ├── evaluation.py │ ├── utils.py │ ├── beir_utils.py │ └── data.py ├── streamlit_utils │ ├── __init__.py │ └── inference_func.py ├── __pycache__ │ └── metrics.cpython-39.pyc ├── scripts │ ├── sample_from_tree │ │ ├── arc.sh │ │ └── hotpotqa.sh │ ├── inference │ │ └── search_engine │ │ │ ├── arc.sh │ │ │ └── hotpotqa.sh │ └── train │ │ └── script_finetune_7b.sh ├── stage3_no_offloading_accelerate.conf ├── generate_embeddings_sag.py ├── metrics.py ├── llama_flash_attn_monkey_patch.py ├── utils.py └── streamlit_demo.py ├── .gitignore ├── images └── data_construction.png ├── requirements.txt └── README.md /data_curation/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retrieval_lm/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /retrieval_lm/src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | *.json 3 | *.jsonl 4 | -------------------------------------------------------------------------------- /retrieval_lm/streamlit_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /images/data_construction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chanchimin/RQ-RAG/HEAD/images/data_construction.png -------------------------------------------------------------------------------- /retrieval_lm/__pycache__/metrics.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/chanchimin/RQ-RAG/HEAD/retrieval_lm/__pycache__/metrics.cpython-39.pyc -------------------------------------------------------------------------------- /data_curation/scripts/disambiguate.sh: -------------------------------------------------------------------------------- 1 
| 2 | cd ./data_curation 3 | 4 | python main_unambiguous_answer_generate.py \ 5 | --raw_data_path_ambiguous \ 6 | "your file" \ 7 | --ndocs \ 8 | 3 \ 9 | --output_path \ 10 | "your output_path" \ 11 | --search_engine_type \ 12 | duckduckgo \ 13 | --openai_api_key \ 14 | "your key" \ 15 | --overwrite_output -------------------------------------------------------------------------------- /data_curation/scripts/decompose.sh: -------------------------------------------------------------------------------- 1 | 2 | cd ./data_curation 3 | 4 | python main_decomposed_answer_generate.py \ 5 | --raw_data_path_decomposed \ 6 | "your file" \ 7 | --ndocs \ 8 | 3 \ 9 | --output_path \ 10 | "your output_path" \ 11 | --search_engine_type \ 12 | bm25_candidates \ 13 | --openai_api_key \ 14 | "your key" \ 15 | --overwrite_output -------------------------------------------------------------------------------- /data_curation/scripts/multiturn.sh: -------------------------------------------------------------------------------- 1 | 2 | cd ./data_curation 3 | 4 | python main_multiturn_answer_generate.py \ 5 | --raw_data_path_multiturn \ 6 | "your file" \ 7 | --ndocs \ 8 | 3 \ 9 | --output_path \ 10 | "your output_path" \ 11 | --search_engine_type \ 12 | duckduckgo \ 13 | --openai_api_key \ 14 | "your key" \ 15 | --overwrite_output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | spacy 3 | jsonlines 4 | transformers 5 | filelock 6 | datasets 7 | nltk 8 | dataclasses 9 | sacrebleu 10 | rouge_score 11 | peft 12 | bitsandbytes 13 | einops 14 | sentencepiece 15 | packaging 16 | accelerate 17 | deepspeed 18 | evaluate 19 | tokenizers 20 | tensorboard 21 | openai 22 | tiktoken 23 | duckduckgo_search 24 | -------------------------------------------------------------------------------- /data_curation/llm_agent/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | from .openai_rewriter import OpenAIRewriterLMAgent 3 | from .openai_judger import OpenAIJudgerLMAgent 4 | from .openai_generator import OpenAIGeneratorLMAgent 5 | from .openai_decomposer import OpenAIDecomposerLMAgent 6 | from .openai_multiturn_retriever_rewriter import OpenAIMultiTurnLMAgent 7 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/__init__.py: -------------------------------------------------------------------------------- 1 | from .Template import QueryRewriterTemplate, \ 2 | QueryJudgerTemplate, \ 3 | QueryGeneratorTemplate, \ 4 | UnambiguousGeneratorTemplateLong, \ 5 | UnambiguousGeneratorTemplateShort, \ 6 | DecomposeGeneratorTemplate, \ 7 | MultiTurnGeneratorTemplate, \ 8 | MultiTurnGeneratorTemplateForAns, \ 9 | QAGeneratorTemplate 10 | -------------------------------------------------------------------------------- /retrieval_lm/scripts/sample_from_tree/arc.sh: -------------------------------------------------------------------------------- 1 | cd ./retrieval_lm/output 2 | 3 | python sample_from_tree.py \ 4 | --run_name \ 5 | "your result dir containing final_results.json" \ 6 | --task \ 7 | arc_challenge \ 8 | --original_data \ 9 | "your original data containing ground truths" \ 10 | --model_name_or_path \ 11 | "your model" \ 12 | --calc_depth \ 13 | 1 \ 14 | 2 
\ 15 | 3 \ 16 | --calc_width \ 17 | S_Rewritten_Query \ 18 | S_Decomposed_Query \ 19 | S_Disambiguated_Query \ 20 | A_Response \ 21 | 22 | 23 | 24 | -------------------------------------------------------------------------------- /retrieval_lm/scripts/sample_from_tree/hotpotqa.sh: -------------------------------------------------------------------------------- 1 | cd ./retrieval_lm/output 2 | 3 | python sample_from_tree.py \ 4 | --run_name \ 5 | "your result dir containing final_results.json" \ 6 | --task \ 7 | hotpotqa \ 8 | --original_data \ 9 | "your original data containing ground truths" \ 10 | --model_name_or_path \ 11 | "your model" \ 12 | --calc_depth \ 13 | 1 \ 14 | 2 \ 15 | 3 \ 16 | --calc_width \ 17 | S_Rewritten_Query \ 18 | S_Decomposed_Query \ 19 | S_Disambiguated_Query \ 20 | A_Response \ 21 | --calc_retrieval_performance 22 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_api.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import json 4 | from openai import OpenAI 5 | 6 | 7 | class OpenAIClient: 8 | def __init__(self, config, model_version): 9 | 10 | self.model_version = model_version 11 | self.client = OpenAI(**config) 12 | 13 | 14 | 15 | def chat_sync(self, messages, params={"temperature": 0, "max_tokens": 500}): 16 | 17 | return self.client.chat.completions.create( 18 | model=self.model_version, 19 | messages=messages, 20 | **params 21 | ) 22 | 23 | -------------------------------------------------------------------------------- /retrieval_lm/scripts/inference/search_engine/arc.sh: -------------------------------------------------------------------------------- 1 | cd retrieval_lm 2 | export PYTHONPATH="$(pwd):$PYTHONPATH" 3 | 4 | python ./inference.py \ 5 | --model_name_or_path \ 6 | "your trained model" \ 7 | --input_file \ 8 | "your data" \ 9 | --max_new_tokens \ 10 | 100 \ 11 | --output_path \ 12 | "your output_path" \ 13 | --ndocs \ 14 | 3 \ 15 | --dtype \ 16 | half \ 17 | --use_search_engine \ 18 | --use_hf \ 19 | --task \ 20 | arc_challenge \ 21 | --tree_decode \ 22 | --oracle \ 23 | --search_engine_type \ 24 | "duckduckgo" \ 25 | --expand_on_tokens \ 26 | [S_Rewritten_Query] \ 27 | [S_Decomposed_Query] \ 28 | [S_Disambiguated_Query] \ 29 | [A_Response] -------------------------------------------------------------------------------- /retrieval_lm/scripts/inference/search_engine/hotpotqa.sh: -------------------------------------------------------------------------------- 1 | cd retrieval_lm 2 | 3 | 4 | export PYTHONPATH="$(pwd):$PYTHONPATH" 5 | 6 | python ./inference.py \ 7 | --model_name_or_path \ 8 | "your trained model" \ 9 | --input_file \ 10 | "your data" \ 11 | --max_new_tokens \ 12 | 100 \ 13 | --output_path \ 14 | "your output_path" \ 15 | --ndocs \ 16 | 3 \ 17 | --use_search_engine \ 18 | --use_hf \ 19 | --task \ 20 | hotpotqa \ 21 | --tree_decode \ 22 | --oracle \ 23 | --max_depth \ 24 | 2 \ 25 | --search_engine_type \ 26 | openai_embed \ 27 | --expand_on_tokens \ 28 | [S_Rewritten_Query] \ 29 | [S_Decomposed_Query] \ 30 | [S_Disambiguated_Query] \ 31 | [A_Response] -------------------------------------------------------------------------------- /retrieval_lm/stage3_no_offloading_accelerate.conf: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "zero_optimization": { 6 | "stage": 3, 7 | "offload_optimizer": { 8 | "device": "cpu", 9 | 
"pin_memory": true 10 | }, 11 | "overlap_comm": true, 12 | "contiguous_gradients": true, 13 | "sub_group_size": 1e9, 14 | "reduce_bucket_size": "auto", 15 | "stage3_prefetch_bucket_size": "auto", 16 | "stage3_param_persistence_threshold": "auto", 17 | "stage3_max_live_parameters": 1e9, 18 | "stage3_max_reuse_distance": 1e9, 19 | "stage3_gather_16bit_weights_on_model_save": true 20 | }, 21 | "activation_checkpointing": { 22 | "partition_activations": true, 23 | "cpu_checkpointing": true 24 | }, 25 | "gradient_accumulation_steps": "auto", 26 | "gradient_clipping": "auto", 27 | "steps_per_print": 1e5, 28 | "train_batch_size": "auto", 29 | "train_micro_batch_size_per_gpu": "auto", 30 | "wall_clock_breakdown": false 31 | } -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/QueryGeneratorTemplate.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant. Please generate a revised answer that takes into account the revised query and the newly retrieved knowledge. 2 | 3 | It is crucial that the new response does not deviate significantly from the reference output, which is treated as ground truth. 4 | 5 | The revised answer should be a balanced integration of the new information with the essence and accuracy of the original output. 6 | 7 | Ensure that the response is both relevant to the revised query and consistent with the key points and facts of the reference output. 8 | 9 | --- 10 | 11 | Original Query: 12 | ${original_query} 13 | 14 | Rewritten Query: 15 | ${rewritten_query} 16 | 17 | Retrieved Evidences: 18 | ${retrieved_evidences} 19 | 20 | Reference Output: 21 | ${reference_output} 22 | 23 | With this in mind, generate a new response that strictly adheres to the original query. 24 | 25 | The response should integrate the new information where relevant, but the priority remains to provide a clear and direct answer to the original question. 26 | 27 | Now, please aim for a balance between incorporating this new context and staying true to the original query's intent: 28 | 29 | 30 | 31 | 32 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/UnambiguousGeneratorTemplate.txt: -------------------------------------------------------------------------------- 1 | Your task is to identify and resolve ambiguity in complex questions, ensuring they are clear and unambiguous. This requires pinpointing elements of the question that could be interpreted in more than one way and refining the question to ensure a single, clear interpretation. 2 | 3 | Approach this task as follows: 4 | 5 | Analyze the Question: Read the question thoroughly to identify ambiguous parts. Consider the different ways the question could be interpreted based on its current wording. 6 | Clarify the Query: Reformulate the question to eliminate ambiguity. This may involve specifying details, narrowing down broad terms, or providing additional context to guide the interpretation. 7 | 8 | Here's an example of how to complete the task: 9 | 10 | ### Original Question: 11 | 12 | When does bunk\'d come out? 13 | 14 | ### Disambiguated Query: 15 | 16 | When does episode 42 of bunk\'d come out? 17 | When does episode 43 of bunk\'d come out? 18 | 19 | --- 20 | 21 | As outlined, please format your answer as multiple lines of text. 22 | Ensure there is exactly one line break between each line. 
23 | 24 | ### Original Question: 25 | 26 | ${original_query} 27 | 28 | ### Disambiguated Query: 29 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/Template.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 4 | 5 | with open(os.path.join(cur_dir, "QueryRewriterTemplate.txt")) as f: 6 | QueryRewriterTemplate = "".join(f.readlines()) 7 | 8 | with open(os.path.join(cur_dir, "QueryJudgerTemplate.txt")) as f: 9 | QueryJudgerTemplate = "".join(f.readlines()) 10 | 11 | with open(os.path.join(cur_dir, "QueryGeneratorTemplate.txt")) as f: 12 | QueryGeneratorTemplate = "".join(f.readlines()) 13 | 14 | with open(os.path.join(cur_dir, "UnambiguousGeneratorTemplateLong.txt")) as f: 15 | UnambiguousGeneratorTemplateLong = "".join(f.readlines()) 16 | 17 | with open(os.path.join(cur_dir, "UnambiguousGeneratorTemplateShort.txt")) as f: 18 | UnambiguousGeneratorTemplateShort = "".join(f.readlines()) 19 | 20 | with open(os.path.join(cur_dir, "DecomposeGeneratorTemplate.txt")) as f: 21 | DecomposeGeneratorTemplate = "".join(f.readlines()) 22 | 23 | with open(os.path.join(cur_dir, "MultiTurnGeneratorTemplate.txt")) as f: 24 | MultiTurnGeneratorTemplate = "".join(f.readlines()) 25 | 26 | with open(os.path.join(cur_dir, "MultiTurnGeneratorTemplateForAns.txt")) as f: 27 | MultiTurnGeneratorTemplateForAns = "".join(f.readlines()) 28 | 29 | with open(os.path.join(cur_dir, "QAGeneratorTemplate.txt")) as f: 30 | QAGeneratorTemplate = "".join(f.readlines()) 31 | 32 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/QueryRewriterTemplate.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant. Your success hinges on formulating precise and relevant search queries. In this task, you will transform given queries into more effective ones, optimizing them for search engine use. This process, known as query rewriting, involves refining, expanding, or rephrasing the original queries to enhance clarity, specificity, and relevance. 2 | 3 | Consider the following aspects while rewriting queries: 4 | 5 | Clarity: Ensure that the query is clear and unambiguous. Avoid vague terms and strive for specificity. 6 | 7 | Conciseness: While detail is important, unnecessary verbosity can dilute the query's effectiveness. Aim for a balance between brevity and descriptiveness. 8 | 9 | Relevance: Focus on including keywords and phrases that are directly related to the information needed. Exclude irrelevant or tangential terms. 10 | 11 | Intent: Pay attention to the underlying intent of the query. Whether it's informational, navigational, or transactional, align your rewritten query with this intent. 12 | 13 | Contextual Keywords: If the query is related to a specific field or subject area, incorporate relevant jargon or technical terms. 
14 | 15 | Here is the original query that I would like you to help me rewrite: 16 | 17 | ${original_query} 18 | 19 | Now, please help me write a new query that is different from ${rewritten_queries}: 20 | -------------------------------------------------------------------------------- /retrieval_lm/scripts/train/script_finetune_7b.sh: -------------------------------------------------------------------------------- 1 | export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 2 | 3 | cd retrieval_lm 4 | 5 | OUTPUT_DIR=output/ 6 | MODEL_SIZE=7B 7 | NUM_GPUS=8 8 | BATCH_SIZE_PER_GPU=1 9 | TOTAL_BATCH_SIZE=32 10 | GRADIENT_ACC_STEPS=$(($TOTAL_BATCH_SIZE/$NUM_GPUS/$BATCH_SIZE_PER_GPU)) 11 | echo "Training llama model ${MODEL_SIZE} using $NUM_GPUS GPUs, $BATCH_SIZE_PER_GPU batch size per GPU, $GRADIENT_ACC_STEPS gradient accumulation steps" 12 | 13 | mkdir -p "${OUTPUT_DIR}" 14 | 15 | accelerate launch \ 16 | --mixed_precision bf16 \ 17 | --num_machines 1 \ 18 | --num_processes $NUM_GPUS \ 19 | --use_deepspeed \ 20 | --deepspeed_config_file stage3_no_offloading_accelerate.conf \ 21 | finetune.py \ 22 | --model_name_or_path meta-llama/Llama-2-7b-hf \ 23 | --tokenizer_name meta-llama/Llama-2-7b-hf \ 24 | --use_slow_tokenizer \ 25 | --dataset_name "your dataset" \ 26 | --dataset_config_name "your dataset config" \ 27 | --max_seq_length 4096 \ 28 | --preprocessing_num_workers 16 \ 29 | --per_device_train_batch_size $BATCH_SIZE_PER_GPU \ 30 | --gradient_accumulation_steps $GRADIENT_ACC_STEPS \ 31 | --learning_rate 2e-5 \ 32 | --lr_scheduler_type linear \ 33 | --warmup_ratio 0.03 \ 34 | --weight_decay 0. \ 35 | --num_train_epochs 1 \ 36 | --output_dir ${OUTPUT_DIR} \ 37 | --with_tracking \ 38 | --report_to "all" \ 39 | --logging_steps 1 \ 40 | --use_special_tokens \ 41 | --checkpointing_steps 200 \ 42 | --sample_train_data "all" \ 43 | 2>&1 | tee ${OUTPUT_DIR}/ERROR.txt 44 | 45 | 46 | -------------------------------------------------------------------------------- /data_curation/llm_agent/base.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Dict, Any, Mapping 3 | 4 | LOGGER = logging.getLogger("Root") 5 | 6 | from .datatypes import Action, State 7 | 8 | 9 | class LMAgent: 10 | """Base class for an agent.""" 11 | 12 | def __init__(self, config: Mapping[str, Any]): 13 | self.config = config 14 | LOGGER.info(f"Initialized {self.__class__.__name__} with config: {config}") 15 | # The agent should not generate observations or expert feedback 16 | self.stop_words = ["\nObservation:", "\nExpert feedback:", "\nTask:", "\n---"] 17 | 18 | def lm_output_to_action(self, lm_output: str) -> Action: 19 | propose_solution = bool("<solution>" in lm_output) 20 | return Action(lm_output, not propose_solution) 21 | 22 | def act(self, state: State) -> Action: 23 | """ 24 | The history should be in a format like: 25 | [ 26 | {"role": "system", "content": "You are a helpful assistant."}, 27 | {"role": "user", "content": "Who won the world series in 2020?"}, 28 | {"role": "assistant", "content": "The Los Angeles Dodgers won the World Series in 2020."}, 29 | {"role": "user", "content": "Where was it played?"} 30 | ] 31 | """ 32 | raise NotImplementedError 33 | 34 | def add_system_message( 35 | self, messages: List[Dict[str, str]] 36 | ) -> List[Dict[str, str]]: 37 | # Prepend the prompt with the system message 38 | first_msg = messages[0] 39 | assert first_msg["role"] == "user" 40 | system, examples, task = first_msg["content"].split("\n---\n") 41 | messages = [ 42 | {"role": 
"system", "content": system}, 43 | {"role": "user", "content": examples + "\n---\n" + task}, 44 | ] + messages[1:] 45 | return messages 46 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/QAGeneratorTemplate.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant given a multiple-choice question. Your job is to generate a list of queries to search the internet for a better answer to the question. You can utilize the following strategy to formulate the queries: 2 | 3 | 1. Formulate a query that is broad enough to provide a comprehensive background. The aim is to obtain an overview that highlights the key components or functions related to the question. 4 | 2. Construct a query that includes all the provided options along with keywords that indicate comparison or equivalence, targeting resources that discuss similarities or functions directly. 5 | 3. After narrowing down to a likely answer, try to confirm the choice and gather detailed information on the specific function or comparison. 6 | 7 | Otherwise, if you are fairly confident that the current query cannot benefit from searching the internet, please decline to form a query. 8 | 9 | For example: 10 | 11 | --- 12 | 13 | ### Current User's Query: 14 | 15 | Which of the following organs in fish has the same function as the human lung? Please choose from "kidney", "heart", "gill", "skin". 16 | 17 | ### Response: 18 | 19 | Retrieval Necessity: 20 | Yes 21 | Query For Search Engine: 22 | "How do fish breathe?" 23 | "fish organ equivalent to human lungs kidney heart gill skin." 24 | "Do gills function like human lungs in fish?" 25 | 26 | --- 27 | 28 | If faced with Non-Search-Engine-Reliant queries which do not need to search: 29 | 30 | --- 31 | 32 | ### Current User's Query: 33 | 34 | Is 210 km/h the correct average speed for 840 km in 4 hours? 35 | 36 | ### Response: 37 | 38 | Retrieval Necessity: 39 | No 40 | 41 | --- 42 | 43 | As outlined, it is necessary to output Retrieval Necessity first, and the output should be one of the "yes" and "no" and the query for the search engine should be split by a line break. 44 | 45 | Now please answer and generate at most three queries (for most of the case, we assume that you need to search): 46 | 47 | ### Current User's Query: 48 | 49 | ${current_query} 50 | 51 | ### Response: 52 | 53 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/QueryJudgerTemplate.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant. You are tasked with evaluating and scoring a series of queries based on their relevance, clarity, and effectiveness in retrieving accurate and useful information. For each query, you will be provided with the original query, the knowledge retrieved based on that query and the reference output. Your job is to assess how well the query performed in terms of fetching relevant and precise information. 2 | 3 | Evaluation Criteria: 4 | 5 | Relevance: Does the query accurately represent the information need? Does the retrieved knowledge align closely with the query's intent? 6 | 7 | Clarity: Is the query clear and unambiguous? Could the wording of the query lead to misinterpretation or retrieval of irrelevant information? 8 | 9 | Specificity: Does the query include specific details or keywords that help in fetching precise information? 
Or is it too broad or too vague? 10 | 11 | Comprehensiveness: Does the retrieved knowledge cover the breadth of the query? Does it leave out critical aspects or information that the query seems to ask for? 12 | 13 | Actionability: If the query is action-oriented, does the retrieved knowledge provide actionable information or clear answers? 14 | 15 | Scoring: 16 | After evaluating the query on these criteria, assign a score from 1 to 5, where: 17 | 18 | 1 indicates poor performance (irrelevant or unclear query leading to inadequate information). 19 | 3 indicates average performance (the query is somewhat effective but has room for improvement). 20 | 5 indicates excellent performance (the query is clear, specific, and retrieves highly relevant and comprehensive information). 21 | 22 | --- 23 | 24 | Original Query: 25 | ${original_query} 26 | 27 | Retrieved Evidences: 28 | ${retrieved_evidences} 29 | 30 | Reference Output: 31 | ${reference_output} 32 | 33 | --- 34 | 35 | Notice that you should strictly adhere to the following format for your evaluation. 36 | 37 | Relevance: [score ONLY from 1 to 5] 38 | Clarity: [score ONLY from 1 to 5] 39 | Specificity: [score ONLY from 1 to 5] 40 | Comprehensiveness: [score ONLY from 1 to 5] 41 | Actionability: [score ONLY from 1 to 5] 42 | 43 | --- 44 | 45 | For example: 46 | 47 | Relevance: 5 48 | Clarity: 4 49 | Specificity: 4 50 | Comprehensiveness: 3 51 | Actionability: 3 52 | 53 | --- 54 | 55 | Now, please give your evaluation and do not output anything other than an integer value like N/A and none: 56 | 57 | 58 | 59 | 60 | 61 | 62 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_judger.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | import openai 3 | import logging 4 | import traceback 5 | from .datatypes import Action 6 | import backoff 7 | from .prompt_template.Template import QueryJudgerTemplate 8 | from string import Template 9 | import re 10 | from .openai_api import OpenAIClient 11 | 12 | LOGGER = logging.getLogger("Root") 13 | 14 | 15 | class OpenAIJudgerLMAgent(LMAgent): 16 | def __init__(self, api_type, config, model_version="gpt-3.5-turbo"): 17 | super().__init__(config) 18 | assert api_type in ["azure", "openai"] 19 | 20 | if api_type == "openai": 21 | self.api = OpenAIClient(config, model_version) 22 | 23 | self.usage_profiles = [] 24 | self.max_try = 3 25 | 26 | @backoff.on_exception( 27 | backoff.fibo, 28 | # https://platform.openai.com/docs/guides/error-codes/python-library-error-types 29 | ( 30 | openai.APIError, 31 | openai.Timeout, 32 | openai.RateLimitError, 33 | openai.APIConnectionError, 34 | ), 35 | ) 36 | def call_lm(self, messages): 37 | 38 | # Prepend the prompt with the system message 39 | js = self.api.chat_sync(messages=messages) 40 | 41 | response = js.choices[0].message.content 42 | 43 | # usage contains input token, output token, times 44 | usage = { 45 | "promptTokens": js.usage.prompt_tokens, 46 | "completionTokens": js.usage.completion_tokens, 47 | "totalTokens": js.usage.total_tokens, 48 | "costTimeMillis": 0 49 | } 50 | 51 | final_response = response.replace('"', "'") 52 | self.usage_profiles.append(usage) 53 | 54 | return final_response, usage 55 | 56 | def act(self, template: str, **kwargs): 57 | 58 | # TODO construct messages here 59 | 60 | llm_query = Template(template).substitute(kwargs) 61 | 62 | messages = [{ 63 | "role": "user", # user / assistant 64 | "content": 
f"{llm_query}", 65 | }] 66 | 67 | for _ in range(self.max_try): 68 | 69 | try: 70 | lm_output, usage = self.call_lm(messages) 71 | return lm_output 72 | except KeyboardInterrupt: 73 | exit() 74 | except Exception as E: # mostly due to model context window limit 75 | tb = traceback.format_exc() 76 | print(f"Some error happens when calling judger agent: \n{tb}") 77 | 78 | return f"InvalidRequestError" 79 | 80 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_rewriter.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | import openai 3 | import logging 4 | import traceback 5 | from .datatypes import Action 6 | import backoff 7 | from .prompt_template.Template import QueryRewriterTemplate 8 | from string import Template 9 | import re 10 | from .openai_api import OpenAIClient 11 | 12 | LOGGER = logging.getLogger("Root") 13 | 14 | 15 | class OpenAIRewriterLMAgent(LMAgent): 16 | def __init__(self, api_type, config, model_version="gpt-3.5-turbo"): 17 | super().__init__(config) 18 | assert api_type in ["azure", "openai"] 19 | 20 | if api_type == "openai": 21 | self.api = OpenAIClient(config, model_version) 22 | 23 | self.usage_profiles = [] 24 | self.max_try = 3 25 | 26 | @backoff.on_exception( 27 | backoff.fibo, 28 | # https://platform.openai.com/docs/guides/error-codes/python-library-error-types 29 | ( 30 | openai.APIError, 31 | openai.Timeout, 32 | openai.RateLimitError, 33 | openai.APIConnectionError, 34 | ), 35 | ) 36 | def call_lm(self, messages): 37 | 38 | # Prepend the prompt with the system message 39 | js = self.api.chat_sync(messages=messages) 40 | 41 | response = js.choices[0].message.content 42 | 43 | # usage contains input token, output token, times 44 | usage = { 45 | "promptTokens": js.usage.prompt_tokens, 46 | "completionTokens": js.usage.completion_tokens, 47 | "totalTokens": js.usage.total_tokens, 48 | "costTimeMillis": 0 49 | } 50 | 51 | final_response = response.replace('"', "'") 52 | self.usage_profiles.append(usage) 53 | 54 | return final_response, usage 55 | 56 | def act(self, template: str, **kwargs): 57 | 58 | # TODO construct messages here 59 | 60 | llm_query = Template(template).substitute(kwargs) 61 | 62 | messages = [{ 63 | "role": "user", # user / assistant 64 | "content": f"{llm_query}", 65 | }] 66 | 67 | for _ in range(self.max_try): 68 | 69 | try: 70 | lm_output, usage = self.call_lm(messages) 71 | return lm_output 72 | except KeyboardInterrupt: 73 | exit() 74 | except Exception as E: # mostly due to model context window limit 75 | tb = traceback.format_exc() 76 | print(f"Some error happens when calling rewriter agent: \n{tb}") 77 | 78 | return f"InvalidRequestError" 79 | 80 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_generator.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | import openai 3 | import logging 4 | import traceback 5 | from .datatypes import Action 6 | import backoff 7 | from .prompt_template.Template import QueryGeneratorTemplate 8 | from string import Template 9 | import json 10 | import re 11 | from .openai_api import OpenAIClient 12 | 13 | LOGGER = logging.getLogger("Root") 14 | 15 | 16 | class OpenAIGeneratorLMAgent(LMAgent): 17 | def __init__(self, api_type, config, model_version="gpt-3.5-turbo"): 18 | super().__init__(config) 19 | assert api_type in ["azure", "openai"] 20 | 21 
| if api_type == "openai": 22 | self.api = OpenAIClient(config, model_version) 23 | 24 | self.usage_profiles = [] 25 | self.max_try = 3 26 | 27 | @backoff.on_exception( 28 | backoff.fibo, 29 | # https://platform.openai.com/docs/guides/error-codes/python-library-error-types 30 | ( 31 | openai.APIError, 32 | openai.Timeout, 33 | openai.RateLimitError, 34 | openai.APIConnectionError, 35 | ), 36 | ) 37 | def call_lm(self, messages): 38 | 39 | # Prepend the prompt with the system message 40 | js = self.api.chat_sync(messages=messages) 41 | 42 | response = js.choices[0].message.content 43 | 44 | # usage contains input token, output token, times 45 | usage = { 46 | "promptTokens": js.usage.prompt_tokens, 47 | "completionTokens": js.usage.completion_tokens, 48 | "totalTokens": js.usage.total_tokens, 49 | # donot contain this information 50 | "costTimeMillis": 0 51 | } 52 | 53 | final_response = response.replace('"', "'") 54 | self.usage_profiles.append(usage) 55 | 56 | return final_response, usage 57 | 58 | def act(self, template: str, **kwargs): 59 | 60 | # TODO construct messages here 61 | 62 | llm_query = Template(template).substitute(kwargs) 63 | 64 | messages = [{ 65 | "role": "user", # user / assistant 66 | "content": f"{llm_query}", 67 | }] 68 | 69 | for _ in range(self.max_try): 70 | 71 | try: 72 | lm_output, usage = self.call_lm(messages) 73 | return lm_output 74 | except KeyboardInterrupt: 75 | exit() 76 | except Exception as E: # mostly due to model context window limit 77 | tb = traceback.format_exc() 78 | print(f"Some error happens when calling generator agent: \n{tb}") 79 | 80 | return f"InvalidRequestError" 81 | 82 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/UnambiguousGeneratorTemplateLong.txt: -------------------------------------------------------------------------------- 1 | You are provided with an ambiguous question, along with several more specific questions derived from it and the relevant evidences retrieved from external database. Your task is to answer the Original Question by covering all the aspect of the Unambiguous Questions Derived. 2 | 3 | Consider the following when making your decision: 4 | 5 | Consistency with Original Answer: If the original answer from the dataset is still valid and aligns well with the provided evidences, maintain a response that does not deviate significantly from it. If not, give more priority to the ***Answer*** from Unambiguous Questions. 6 | Updating Information: In cases where the original answer is outdated or evidently incorrect based on the new evidences, construct a new, more accurate and reasonable answer. 7 | Balance and Judgment: Exercise balanced judgment in situations where partial updating or reorganization of the original answer is required. 8 | 9 | If the retrieved evidences is not related to the question, please try to get the best information of *** Answer *** from Unambiguous question and incorporate them in your final answer. 10 | NEVER reply that you can not answer the question. 11 | 12 | For example: 13 | 14 | --- 15 | 16 | ### Original Question: 17 | 18 | 'When does bunk\'d come out?' 19 | 20 | ### Unambiguous Questions Derived and Provided Answers: 21 | 22 | Unambiguous Question0: 'When does episode 42 of bunk\'d come out?' 23 | Evidence0: [omitted ...] 24 | 25 | *** Answer0: 'May 24, 2017' *** 26 | 27 | Unambiguous Question1: 'When does episode 41 of bunk\'d come out?' 28 | Evidence1: [omitted ...] 
29 | 30 | *** Answer1: 'April 28, 2017' *** 31 | 32 | Unambiguous Question2: 'When does episode 40 of bunk\'d come out?' 33 | Evidence2: [omitted ...] 34 | 35 | *** Answer2: 'April 21, 2017' *** 36 | 37 | ### Original Answer: 38 | 39 | 'The new bunk\'d episode 41 comes out on April 21, 2017, episode 42 comes out on April 28, 2017 and episode 42 is due to come out on May 24, 2017. ' 40 | 41 | ### Your Response: 42 | 43 | The original question did not specify a particular episode. Based on the provided unambiguous questions and evidence, the episodes of bunk\'d have the following release dates: Episode 40 on April 21, 2017; Episode 41 on April 28, 2017; and Episode 42 on May 24, 2017. 44 | 45 | --- 46 | 47 | As shown above, please do not add any superfluous words; only the specific answer is needed. 48 | Now, please answer: 49 | 50 | ### Original Question: 51 | 52 | ${ambiguous_question} 53 | 54 | ### Unambiguous Questions Derived and Provided Answers: 55 | 56 | ${unambiguous_questions_with_answers} 57 | 58 | ### Original Answer: 59 | 60 | ${original_answer} 61 | 62 | ### Your Response: 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | -------------------------------------------------------------------------------- /data_curation/tools/duckduckgo_rapidapi.py: -------------------------------------------------------------------------------- 1 | """Util that calls DuckDuckGo.""" 2 | from typing import Any, Dict, List, Optional 3 | import os 4 | import requests 5 | from duckduckgo_search import DDGS 6 | import time 7 | 8 | MAX_QUERY_LENGTH = 300 9 | 10 | DuckDuckGo_DESCRIPTION = """The DuckDuckGo search tool provides access to a vast collection of web results covering a wide range of topics. 11 | Can query specific keywords or topics to retrieve accurate and comprehensive information. 12 | """ 13 | 14 | ddgs = DDGS(timeout=20) 15 | 16 | class RapidAPI: 17 | 18 | def __init__(self, rapidapi_name): 19 | self.rapidapi_name = rapidapi_name 20 | 21 | def query(self, text: str, max_results: int): 22 | 23 | time.sleep(1) 24 | response = ddgs.text(text.strip("'"), max_results=10) 25 | return response[:max_results] 26 | 27 | class DDGSQueryRun(): 28 | """Tool that adds the capability to search the web using DuckDuckGo.""" 29 | 30 | name = "duckduckgo_search" 31 | signature = f"{name}(query: str) -> str" 32 | description = DuckDuckGo_DESCRIPTION 33 | max_results = 5 34 | 35 | def __init__(self, max_results, rapidapi_name="one"): 36 | self.max_results = max_results 37 | self.api_wrapper = RapidAPI(rapidapi_name) 38 | 39 | def __call__( 40 | self, 41 | query: str, 42 | ) -> str: 43 | """Use the DDGS tool.""" 44 | 45 | max_try = -1 # used when the query is bad, and we do not want to retry too many times 46 | 47 | while True: 48 | 49 | try: 50 | output = [r for r in self.api_wrapper.query(query[:MAX_QUERY_LENGTH], max_results=self.max_results)] 51 | break 52 | except Exception as E: 53 | 54 | if isinstance(E, ValueError): 55 | 56 | # the first time we encounter a bad query 57 | if max_try == -1: 58 | max_try = 5 59 | 60 | if max_try == 0: 61 | output = [] 62 | break 63 | 64 | # a malformed query looks like "What ... ?|Which ...?|"; keep only the first part 65 | query = query.split("|")[0] 66 | max_try -= 1 67 | 68 | print(f"try again. 
dockdockgo raise the error: {E}") 69 | 70 | evidences = [] 71 | 72 | for ins in output: 73 | evidences.append({ 74 | "title": ins["title"], 75 | "text": ins["description"] if "description" in ins else ins["body"] 76 | }) 77 | 78 | if len(evidences) == 0: 79 | # do not return anything from search engine, add dummy 80 | evidences.append({ 81 | "title": "dummy", 82 | "text": "the search engine did not return anything" 83 | }) 84 | 85 | return evidences 86 | 87 | if __name__ == '__main__': 88 | 89 | engine = DDGSQueryRun(max_results=5) 90 | print(engine("What is the weather today in Beijing?")) 91 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_decomposer.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | import openai 3 | import logging 4 | import traceback 5 | from .datatypes import Action 6 | import backoff 7 | from .prompt_template.Template import QueryGeneratorTemplate 8 | from string import Template 9 | import json 10 | import re 11 | from .openai_api import OpenAIClient 12 | 13 | LOGGER = logging.getLogger("Root") 14 | 15 | 16 | class OpenAIDecomposerLMAgent(LMAgent): 17 | def __init__(self, api_type, config, model_version="gpt-3.5-turbo"): 18 | super().__init__(config) 19 | assert api_type in ["azure", "openai"] 20 | 21 | if api_type == "openai": 22 | self.api = OpenAIClient(config, model_version) 23 | 24 | self.usage_profiles = [] 25 | self.max_try = 3 26 | 27 | @backoff.on_exception( 28 | backoff.fibo, 29 | # https://platform.openai.com/docs/guides/error-codes/python-library-error-types 30 | ( 31 | openai.APIError, 32 | openai.Timeout, 33 | openai.RateLimitError, 34 | openai.APIConnectionError, 35 | ), 36 | ) 37 | def call_lm(self, messages): 38 | 39 | # Prepend the prompt with the system message 40 | js = self.api.chat_sync(messages=messages) 41 | 42 | response = js.choices[0].message.content 43 | 44 | # usage contains input token, output token, times 45 | usage = { 46 | "promptTokens": js.usage.prompt_tokens, 47 | "completionTokens": js.usage.completion_tokens, 48 | "totalTokens": js.usage.total_tokens, 49 | "costTimeMillis": 0 50 | } 51 | 52 | final_response = response.replace('"', "'") 53 | self.usage_profiles.append(usage) 54 | 55 | return final_response, usage 56 | 57 | def parser_results(self, response): 58 | # assume the response can be splitted by \n, and every element is a dict 59 | response = response.split("\n") 60 | results = [] 61 | 62 | for cur in response: 63 | results.append(cur) 64 | 65 | return results 66 | 67 | def act(self, template: str, **kwargs): 68 | 69 | # TODO construct messages here 70 | 71 | llm_query = Template(template).safe_substitute(kwargs) 72 | 73 | messages = [{ 74 | "role": "user", # user / assistant 75 | "content": f"{llm_query}", 76 | }] 77 | 78 | for _ in range(self.max_try): 79 | 80 | try: 81 | lm_output, usage = self.call_lm(messages) 82 | parsed_results = self.parser_results(lm_output) 83 | 84 | return parsed_results 85 | except KeyboardInterrupt: 86 | exit() 87 | except Exception as e: # mostly due to model context window limit 88 | tb = traceback.format_exc() 89 | print(f"Some error happens when calling generator agent: \n{tb}") 90 | 91 | return f"InvalidRequestError" 92 | 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RQ-RAG: Learning to Refine Queries for Retrieval 
Augmented Generation 2 | 3 | This is the repository for our paper "RQ-RAG: Learning to Refine Queries for Retrieval Augmented Generation". 4 | 5 | ## Announcement 6 | We are excited to announce that our paper has been accepted to the Conference on Language Modeling (COLM) 2024. Thank you for your interest and support! 7 | 8 | If you find our paper useful, please consider citing it: 9 | 10 | ``` 11 | @article{chan2024rq, 12 | title={Rq-rag: Learning to refine queries for retrieval augmented generation}, 13 | author={Chan, Chi-Min and Xu, Chunpu and Yuan, Ruibin and Luo, Hongyin and Xue, Wei and Guo, Yike and Fu, Jie}, 14 | journal={arXiv preprint arXiv:2404.00610}, 15 | year={2024} 16 | } 17 | 18 | ``` 19 | 20 | 21 | 22 | ## Release 23 | 24 | We are also releasing the curated dataset and the trained Llama2-7B checkpoint; you can download them here: [dataset](https://huggingface.co/datasets/zorowin123/rq_rag), 25 | [checkpoint](https://huggingface.co/zorowin123/rq_rag_llama2_7B). 26 | 27 | 28 | 29 | 30 | 31 | ## Getting Started 32 | 33 | 34 | 1. Installation 35 | 36 | Clone this repository and install the dependencies. 37 | 38 | ``` 39 | git clone git@github.com:chanchimin/RQ-RAG.git 40 | cd RQ-RAG 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | 2. Construct Search-Augmented Dataset 45 | 46 | First, set up your OpenAI API key: 47 | ``` 48 | export OPENAI_API_KEY="your_api_key_here" 49 | ``` 50 | 51 | Second, preprocess your data into the following format: 52 | 53 | ```python 54 | # for multi-turn data, your data should follow the messages format, for example: 55 | 56 | { 57 | "id": ..., 58 | "messages": 59 | [ 60 | { 61 | "role":"user", 62 | "content": ..., 63 | }, 64 | { 65 | "role":"assistant", 66 | "content": ..., 67 | } 68 | ] 69 | } 70 | ``` 71 | 72 | Afterward, execute the following command, making sure to substitute your data paths. 73 | 74 | ```shell 75 | cd ./data_curation 76 | 77 | python main_multiturn_answer_generate.py \ 78 | --raw_data_path_multiturn \ 79 | "your file" \ 80 | --ndocs \ 81 | 3 \ 82 | --output_path \ 83 | "your output_path" \ 84 | --search_engine_type \ 85 | duckduckgo \ 86 | --openai_api_key \ 87 | "your key" \ 88 | --overwrite_output 89 | ``` 90 | 91 | After gathering the intermediate results, run: 92 | 93 | ```shell 94 | python merge_intermediate.py \ 95 | --raw_data_path \ 96 | "intermediate data path" \ 97 | --output_path \ 98 | "final data output path" 99 | ``` 100 | 101 | ![](images/data_construction.png) 102 | 103 | 3. Train the model 104 | ```shell 105 | cd .. 106 | bash retrieval_lm/scripts/train/script_finetune_7b.sh 107 | ``` 108 | 109 | 4. Run inference and sample the results 110 | 111 | ```shell 112 | # take hotpotqa as an example 113 | # first, run inference with the model; all trajectories will be saved to "your_output_dir/final_results.json" 114 | bash retrieval_lm/scripts/inference/search_engine/hotpotqa.sh 115 | # then sample from the decoding tree 116 | bash retrieval_lm/scripts/sample_from_tree/hotpotqa.sh 117 | ``` 118 | -------------------------------------------------------------------------------- /retrieval_lm/src/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
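#
# A minimal usage sketch of the Indexer class defined below (illustrative only, not part of
# the original file). It assumes the passage embeddings were computed elsewhere and that the
# vector size (768 here) matches the retriever; the ids and arrays are hypothetical:
#
#     import numpy as np
#     indexer = Indexer(vector_sz=768)                         # flat inner-product index
#     ids = ["passage_0", "passage_1"]                         # external passage ids
#     embeddings = np.random.rand(2, 768).astype("float32")    # stand-in for real embeddings
#     indexer.index_data(ids, embeddings)                      # adds vectors and records the id mapping
#     top_ids_and_scores = indexer.search_knn(np.random.rand(1, 768).astype("float32"), top_docs=2)
#     print(top_ids_and_scores[0])                             # -> (["passage_0", ...], [score, ...])
#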
6 | 7 | import os 8 | import pickle 9 | from typing import List, Tuple 10 | 11 | import faiss 12 | import numpy as np 13 | from tqdm import tqdm 14 | 15 | class Indexer(object): 16 | 17 | def __init__(self, vector_sz, n_subquantizers=0, n_bits=8): 18 | if n_subquantizers > 0: 19 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT) 20 | else: 21 | self.index = faiss.IndexFlatIP(vector_sz) 22 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64) 23 | self.index_id_to_db_id = [] 24 | 25 | def index_data(self, ids, embeddings): 26 | self._update_id_mapping(ids) 27 | embeddings = embeddings.astype('float32') 28 | if not self.index.is_trained: 29 | self.index.train(embeddings) 30 | self.index.add(embeddings) 31 | 32 | print(f'Total data indexed {len(self.index_id_to_db_id)}') 33 | 34 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]: 35 | query_vectors = query_vectors.astype('float32') 36 | result = [] 37 | nbatch = (len(query_vectors)-1) // index_batch_size + 1 38 | for k in tqdm(range(nbatch)): 39 | start_idx = k*index_batch_size 40 | end_idx = min((k+1)*index_batch_size, len(query_vectors)) 41 | q = query_vectors[start_idx: end_idx] 42 | scores, indexes = self.index.search(q, top_docs) 43 | # convert to external ids 44 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes] 45 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))]) 46 | return result 47 | 48 | def serialize(self, dir_path): 49 | index_file = os.path.join(dir_path, 'index.faiss') 50 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 51 | print(f'Serializing index to {index_file}, meta data to {meta_file}') 52 | 53 | faiss.write_index(self.index, index_file) 54 | with open(meta_file, mode='wb') as f: 55 | pickle.dump(self.index_id_to_db_id, f) 56 | 57 | def deserialize_from(self, dir_path): 58 | index_file = os.path.join(dir_path, 'index.faiss') 59 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 60 | print(f'Loading index from {index_file}, meta data from {meta_file}') 61 | 62 | self.index = faiss.read_index(index_file) 63 | print('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) 64 | 65 | with open(meta_file, "rb") as reader: 66 | self.index_id_to_db_id = pickle.load(reader) 67 | assert len( 68 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' 69 | 70 | def _update_id_mapping(self, db_ids: List): 71 | #new_ids = np.array(db_ids, dtype=np.int64) 72 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0) 73 | self.index_id_to_db_id.extend(db_ids) -------------------------------------------------------------------------------- /retrieval_lm/src/inbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | import random 8 | import transformers 9 | import logging 10 | import torch.distributed as dist 11 | 12 | from src import contriever, dist_utils, utils 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class InBatch(nn.Module): 18 | def __init__(self, opt, retriever=None, tokenizer=None): 19 | super(InBatch, self).__init__() 20 | 21 | self.opt = opt 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.label_smoothing = opt.label_smoothing 25 | if retriever is None or tokenizer is None: 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | self.tokenizer = tokenizer 30 | self.encoder = retriever 31 | 32 | def _load_retriever(self, model_id, pooling, random_init): 33 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 34 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 35 | 36 | if "xlm" in model_id: 37 | model_class = contriever.XLMRetriever 38 | else: 39 | model_class = contriever.Contriever 40 | 41 | if random_init: 42 | retriever = model_class(cfg) 43 | else: 44 | retriever = utils.load_hf(model_class, model_id) 45 | 46 | if "bert-" in model_id: 47 | if tokenizer.bos_token_id is None: 48 | tokenizer.bos_token = "[CLS]" 49 | if tokenizer.eos_token_id is None: 50 | tokenizer.eos_token = "[SEP]" 51 | 52 | retriever.config.pooling = pooling 53 | 54 | return retriever, tokenizer 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 60 | 61 | bsz = len(q_tokens) 62 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 63 | 64 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 65 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 66 | 67 | gather_fn = dist_utils.gather 68 | 69 | gather_kemb = gather_fn(kemb) 70 | 71 | labels = labels + dist_utils.get_rank() * len(kemb) 72 | 73 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 74 | 75 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 76 | 77 | # log stats 78 | if len(stats_prefix) > 0: 79 | stats_prefix = stats_prefix + "/" 80 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 81 | 82 | predicted_idx = torch.argmax(scores, dim=-1) 83 | accuracy = 100 * (predicted_idx == labels).float().mean() 84 | stdq = torch.std(qemb, dim=0).mean().item() 85 | stdk = torch.std(kemb, dim=0).mean().item() 86 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 87 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 88 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 89 | 90 | return loss, iter_stats 91 | -------------------------------------------------------------------------------- /data_curation/llm_agent/openai_multiturn_retriever_rewriter.py: -------------------------------------------------------------------------------- 1 | from .base import LMAgent 2 | import openai 3 | import logging 4 | import traceback 5 | from .datatypes import Action 6 | import backoff 7 | from .prompt_template.Template import QueryGeneratorTemplate 8 | from string import Template 9 | import json 10 | import re 11 | from .openai_api import OpenAIClient 12 | 13 | LOGGER = logging.getLogger("Root") 14 | 15 | 16 | class 
OpenAIMultiTurnLMAgent(LMAgent): 17 | def __init__(self, api_type, config, model_version="gpt-3.5-turbo"): 18 | super().__init__(config) 19 | assert api_type in ["azure", "openai"] 20 | 21 | if api_type == "openai": 22 | self.api = OpenAIClient(config, model_version) 23 | 24 | self.usage_profiles = [] 25 | self.max_try = 3 26 | 27 | @backoff.on_exception( 28 | backoff.fibo, 29 | # https://platform.openai.com/docs/guides/error-codes/python-library-error-types 30 | ( 31 | openai.APIError, 32 | openai.Timeout, 33 | openai.RateLimitError, 34 | openai.APIConnectionError, 35 | ), 36 | ) 37 | def call_lm(self, messages): 38 | 39 | # Prepend the prompt with the system message 40 | js = self.api.chat_sync(messages=messages) 41 | 42 | response = js.choices[0].message.content 43 | 44 | # usage contains input token, output token, times 45 | usage = { 46 | "promptTokens": js.usage.prompt_tokens, 47 | "completionTokens": js.usage.completion_tokens, 48 | "totalTokens": js.usage.total_tokens, 49 | "costTimeMillis": 0 50 | } 51 | 52 | final_response = response.replace('"', "'") 53 | self.usage_profiles.append(usage) 54 | 55 | return final_response, usage 56 | 57 | def parser_results(self, response): 58 | # assume the response can be splitted by \n, and every element is a dict 59 | retrieval_necessity = re.search(r'Retrieval Necessity:\s*(\w+)', response) 60 | queries = re.search(r'Query For Search Engine:\s*([\s\S]+)', response) 61 | 62 | # Output the results 63 | retrieval_necessity = retrieval_necessity.group(1) if retrieval_necessity else None 64 | queries = queries.group(1).split("\n") if queries else None 65 | 66 | if retrieval_necessity in ["yes", "Yes", "YES"]: 67 | retrieval_necessity = True 68 | 69 | if (retrieval_necessity is None) or (queries is None): 70 | return False, [] 71 | 72 | return retrieval_necessity, queries 73 | 74 | def act(self, template: str, **kwargs): 75 | 76 | # TODO construct messages here 77 | 78 | llm_query = Template(template).substitute(kwargs) 79 | 80 | messages = [{ 81 | "role": "user", # user / assistant 82 | "content": f"{llm_query}", 83 | }] 84 | 85 | for _ in range(self.max_try): 86 | 87 | try: 88 | lm_output, usage = self.call_lm(messages) 89 | retrieval_necessity, queries = self.parser_results(lm_output) 90 | 91 | return retrieval_necessity, queries 92 | except KeyboardInterrupt: 93 | exit() 94 | except Exception as E: # mostly due to model context window limit 95 | tb = traceback.format_exc() 96 | print(f"Some error happens when calling generator agent: \n{tb}") 97 | 98 | return False, [] 99 | 100 | -------------------------------------------------------------------------------- /retrieval_lm/generate_embeddings_sag.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import os 3 | import src.normalize_text 4 | import argparse 5 | import csv 6 | import logging 7 | import pickle 8 | import numpy as np 9 | import torch 10 | import transformers 11 | from openai import OpenAI 12 | from tqdm import tqdm 13 | 14 | client = OpenAI() 15 | 16 | def embed_passages(args, input_data, context_num): 17 | 18 | all_embeddings = [] 19 | 20 | for ins_num, ins in tqdm(enumerate(input_data), total=len(input_data)): 21 | context_embeddings_for_cur_ins = [] 22 | 23 | for context in ins["contexts"]: 24 | text = context["title"] + " " + context["paragraph_text"] 25 | text = text.lower() 26 | text = src.normalize_text.normalize(text) 27 | text_embedding = client.embeddings.create(input=[text], 
model="text-embedding-3-large").data[0].embedding 28 | context_embeddings_for_cur_ins.append(text_embedding) 29 | 30 | if len(context_embeddings_for_cur_ins) < context_num: 31 | while len(context_embeddings_for_cur_ins) < context_num: 32 | context_embeddings_for_cur_ins.append([-100] * len(context_embeddings_for_cur_ins[0])) 33 | 34 | all_embeddings.append(context_embeddings_for_cur_ins) 35 | 36 | all_embeddings = torch.tensor(all_embeddings) 37 | 38 | torch.save(all_embeddings, args.output_file) 39 | assert all_embeddings.shape[0] == len(input_data) 40 | return all_embeddings 41 | 42 | def embed_align_on_the_fly(args, input_data, context_num): 43 | 44 | all_embeddings = [] 45 | 46 | for ins_num, ins in tqdm(enumerate(input_data), total=len(input_data)): 47 | 48 | text = ins["text"].split("*****")[-1] 49 | text = text.lower() 50 | text = src.normalize_text.normalize(text) 51 | text_embedding = client.embeddings.create(input=[text], model="text-embedding-3-large").data[0].embedding 52 | all_embeddings.append(text_embedding) 53 | 54 | all_embeddings = torch.tensor(all_embeddings) 55 | torch.save(all_embeddings, args.output_file) 56 | return all_embeddings 57 | 58 | def main(args, input_data, context_num): 59 | 60 | all_embeddings = embed_passages(args, input_data, context_num) 61 | 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser() 66 | 67 | parser.add_argument("--input_file", type=str, default=None, help="Path to passages, eval_data path") 68 | parser.add_argument("--task", type=str, help="because it is data-specific") 69 | parser.add_argument("--output_file", type=str, default="wikipedia_embeddings", help="dir path to save embeddings") 70 | parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage") 71 | parser.add_argument( 72 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 73 | ) 74 | parser.add_argument("--no_title", action="store_true", help="title not added to the passage body") 75 | parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding") 76 | parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding") 77 | 78 | args = parser.parse_args() 79 | 80 | input_data = [] 81 | if args.input_file.endswith("jsonl"): 82 | with jsonlines.open(args.input_file) as f: 83 | for line in f: 84 | input_data.append(line) 85 | 86 | if args.task == "musique": 87 | context_num = 20 88 | else: 89 | context_num = 10 90 | 91 | main(args, input_data, context_num) 92 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/UnambiguousGeneratorTemplateShort.txt: -------------------------------------------------------------------------------- 1 | You are provided with a specific question, along with several retrieved evidence from internet which might be relevant or not. Your task is to thoroughly analyze both the questions and the evidence. After this analysis, compare your findings with the original answer previously provided from an original dataset. Based on this comparison, determine if it is necessary to provide a new answer. 2 | 3 | Consider the following when making your decision: 4 | 5 | Consistency with Original Answer: If the original answer from the dataset is still valid and aligns well with the provided evidences, maintain a response that does not deviate significantly from it. 
6 | Updating Information: In cases where the original answer is outdated or evidently incorrect based on the new evidences, construct a new, more accurate and reasonable answer. 7 | Balance and Judgment: Exercise balanced judgment in situations where partial updating or reorganization of the original answer is required. 8 | 9 | Analyze the question, evidences, and the original answer. Then directly provide your reorganized answer if you find the retrieved evidences are useful, otherwise, please output the original answers. 10 | 11 | For example: 12 | 13 | 'When does bunk\'d come out?' 14 | 15 | 16 | 17 | --- 18 | 19 | ### Question: 20 | 21 | When does episode 42 of bunk\'d come out? 22 | 23 | ### Retrieved Evidence: 24 | 25 | [{'title': "List of Bunk'd episodes - Wikipedia", 'text': 'July 31, 2015. ( 2015-07-31) 101. 4.24 [6] Emma, Ravi, and Zuri Ross leave the loud and noisy streets of New York to visit Camp Kikiwaka, which is the same camp where their parents met. While Ravi and Zuri meet Jorge and Tiffany, Emma befriends Lou and develops an interest in Xander while competing against Hazel.'}, {'title': "Bunk'd (TV Series 2015-2024) - IMDb", 'text': "Bunk'd: Created by Pamela Eells. With Miranda May, Mallory James Mahoney, Israel Johnson, Will Buie Jr.. Emma, Ravi, and Zuri Ross head off to the rustic summer camp in Maine where their parents met as teens. With their new friends, a love triangle, and the fears of the camp, they settle into their exciting, challenging new lives."}, {'title': "Bunk'd TV Show Air Dates & Track Episodes - Next Episode", 'text': "The Mandalorian. House of the Dragon. The Boys. The Witcher. The Last of Us. The Lord of the Rings: The Rings of Power. Loki. The Umbrella Academy. The Handmaid's Tale."}, {'title': "Season 5 | Bunk'd Wiki | Fandom", 'text': "Season 5 of Bunk'd was officially announced on February 24, 2020. It represents the first fifth season order for a series in Disney Channel history. Production began on September 30, 2020, (coincidently, on the 9th anniversary of Jessie) and wrapped on April 23, 2021. The season premiered on January 15, 2021, and concluded on August 6, 2021. Back for another exciting summer at the sometimes ..."}, {'title': "Bunk'd (TV Series 2015-2024) - Episode list - IMDb", 'text': "Bunk'd. When Gwen returns for a surprise visit to camp, Noah and Destiny hatch a plan to create dangerous situations for Gwen to fix, hoping it will compel her to stay. Lou wins an award from a magazine, but Bill is suspicious of the reporter's motives. Parker and Jake prepare Noah for a date."}] 26 | 27 | ### Original Short Answer from Dataset: 28 | 29 | May 24, 2017 30 | 31 | ### Response: 32 | 33 | May 24, 2017 34 | 35 | --- 36 | 37 | Now, please answer: 38 | 39 | ### Question: 40 | 41 | ${original_query} 42 | 43 | ### Retrieved Evidence: 44 | 45 | ${search_results} 46 | 47 | ### Original Short Answer from Dataset: 48 | 49 | ${original_answers} 50 | 51 | ### Response: 52 | 53 | 54 | 55 | 56 | -------------------------------------------------------------------------------- /data_curation/tools/bm25_candidates.py: -------------------------------------------------------------------------------- 1 | """Util that calls bm25.""" 2 | from typing import Any, Dict, List, Optional 3 | import os 4 | import requests 5 | from rank_bm25 import BM25Okapi 6 | 7 | MAX_QUERY_LENGTH = 300 8 | 9 | BM25_description = """The BM25 Search tool provides access to a vast collection of articles covering a wide range of topics. 
10 | Can query specific keywords or topics to retrieve accurate and comprehensive information. 11 | """ 12 | 13 | 14 | class BM25API: 15 | 16 | def query(self, text: str, candidates: list, ): 17 | 18 | tokenized_corpus = [doc.split(" ") for doc in candidates] 19 | 20 | bm25 = BM25Okapi(tokenized_corpus) 21 | tokenized_query = text.split(" ") 22 | 23 | doc_scores = bm25.get_scores(tokenized_query) 24 | 25 | return doc_scores 26 | 27 | 28 | class BM25Run(): 29 | """Tool that adds the capability to search using the Wikipedia API.""" 30 | 31 | name = "wikipedia_search" 32 | signature = f"{name}(query: str) -> str" 33 | description = BM25_description 34 | api_wrapper = BM25API() 35 | max_results = 5 36 | 37 | def __init__(self, max_results): 38 | self.max_results = max_results 39 | 40 | def flatten_corpus(self, corpus): 41 | candidates = [] 42 | for obj in corpus: 43 | candidates.append(f"Title: {obj['title']}\nText: {obj['paragraph_text']}") 44 | 45 | return candidates 46 | 47 | def __call__( 48 | self, 49 | query: str, 50 | corpus: list 51 | ) -> str: 52 | """Use the DDGS tool.""" 53 | 54 | max_try = -1 # used when the query is bad, and we do not want to retry so many time 55 | 56 | candidates = self.flatten_corpus(corpus) 57 | 58 | while True: 59 | 60 | try: 61 | output = [r for r in self.api_wrapper.query(query[:MAX_QUERY_LENGTH], candidates,)] 62 | break 63 | except Exception as E: 64 | 65 | if isinstance(E, ValueError): 66 | 67 | # the first time encounter bad query 68 | if max_try == -1: 69 | max_try = 5 70 | 71 | if max_try == 0: 72 | output = [] 73 | break 74 | 75 | # cur wrong text "What ... ?|Which ...?|" 76 | query = query.split("|")[0] 77 | max_try -= 1 78 | 79 | print(f"try again. BM25 raise the error: {E}") 80 | 81 | evidences = [] 82 | 83 | top_indices = sorted(range(len(output)), key=lambda i: output[i], reverse=True)[:self.max_results] 84 | 85 | for index in top_indices: 86 | evidences.append({ 87 | "title": corpus[index]["title"], 88 | "text": corpus[index]["paragraph_text"] 89 | }) 90 | 91 | if len(evidences) == 0: 92 | # do not return anything from search engine, add dummy 93 | evidences.append({ 94 | "title": "dummy", 95 | "text": "the search engine did not return anything" 96 | }) 97 | 98 | return evidences, top_indices 99 | 100 | if __name__ == '__main__': 101 | 102 | engine = BM25Run(max_results=5) 103 | corpus = [ 104 | {"title": "1", "text": "Hello there good man!"}, 105 | {"title": "2", "text": "It is quite windy in London"}, 106 | {"title": "3", "text": "How is the weather today?"}, 107 | ] 108 | print(engine("What is the weather today in Beijing?", corpus)) 109 | -------------------------------------------------------------------------------- /data_curation/tools/bingsearch_azure.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pprint import pprint 4 | import requests 5 | from typing import Any, Dict, List, Optional 6 | import os 7 | import requests 8 | 9 | ''' 10 | This sample makes a call to the Bing Web Search API with a query and returns relevant web search. 11 | Documentation: https://docs.microsoft.com/en-us/bing/search-apis/bing-web-search/overview 12 | ''' 13 | 14 | # Add your Bing Search V7 subscription key and endpoint to your environment variables. 
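# NOTE: the values assigned just below are placeholders; "your key" must be replaced with a real
# subscription key before any request will succeed. One possible alternative (a sketch, not part
# of the original script) is to export both variables in the shell and drop the hard-coded lines:
#   export BING_SEARCH_V7_SUBSCRIPTION_KEY="<your real key>"
#   export BING_SEARCH_V7_ENDPOINT="https://api.bing.microsoft.com"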
15 | os.environ['BING_SEARCH_V7_SUBSCRIPTION_KEY'] = "your key" 16 | os.environ['BING_SEARCH_V7_ENDPOINT'] = "https://api.bing.microsoft.com" 17 | subscription_key = os.environ['BING_SEARCH_V7_SUBSCRIPTION_KEY'] 18 | endpoint = os.environ['BING_SEARCH_V7_ENDPOINT'] + "/v7.0/search" 19 | MAX_QUERY_LENGTH = 300 20 | 21 | BingSearch_DESCRIPTION = """The bing search tool provides access to a vast collection of articles covering a wide range of topics. 22 | Can query specific keywords or topics to retrieve accurate and comprehensive information. 23 | """ 24 | 25 | class BingSearchAPI: 26 | 27 | def query(self, text: str): 28 | 29 | params = {'q': text, 'mkt': 'en-US'} 30 | headers = {'Ocp-Apim-Subscription-Key': subscription_key} 31 | 32 | response = requests.get(endpoint, headers=headers, params=params) 33 | response.raise_for_status() 34 | 35 | if response.status_code == 200: 36 | if len(response.json()['webPages']['value']) == 0: 37 | # do not know why, retry 38 | raise ValueError("status==200, but do not know why return nothing, retry") 39 | return response.json()['webPages']['value'] 40 | 41 | else: 42 | raise ValueError(f" bing search error\n{response}") 43 | 44 | 45 | class BingSearchQueryRun(): 46 | """Tool that adds the capability to search using the Wikipedia API.""" 47 | 48 | name = "bing_search" 49 | signature = f"{name}(query: str) -> str" 50 | description = BingSearch_DESCRIPTION 51 | api_wrapper = BingSearchAPI() 52 | 53 | def __init__(self, max_results): 54 | self.max_results = max_results 55 | 56 | def __call__( 57 | self, 58 | query: str, 59 | ) -> str: 60 | """Use the DDGS tool.""" 61 | 62 | max_try = -1 # used when the query is bad, and we do not want to retry so many time 63 | 64 | while True: 65 | 66 | try: 67 | output = [r for r in self.api_wrapper.query(query[:MAX_QUERY_LENGTH])] 68 | break 69 | except Exception as E: 70 | 71 | if isinstance(E, ValueError): 72 | 73 | # the first time encounter bad query 74 | if max_try == -1: 75 | max_try = 5 76 | 77 | if max_try == 0: 78 | output = [] 79 | break 80 | 81 | # cur wrong text "What ... ?|Which ...?|" 82 | query = query.split("|")[0] 83 | max_try -= 1 84 | 85 | print(f"try again. 
bing raise the error: {E}") 86 | 87 | evidences = [] 88 | 89 | for ins in output: 90 | evidences.append({ 91 | "title": ins["name"], 92 | "text": ins["snippet"] 93 | }) 94 | 95 | if len(evidences) == 0: 96 | # do not return anything from search engine, add dummy 97 | evidences.append({ 98 | "title": "dummy", 99 | "text": "the search engine did not return anything" 100 | }) 101 | 102 | return evidences[:self.max_results] 103 | 104 | if __name__ == '__main__': 105 | 106 | engine = BingSearchQueryRun(max_results=5) 107 | print(engine("What is the weather today in Beijing?")) 108 | -------------------------------------------------------------------------------- /data_curation/tools/openai_embedding_search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import faiss 3 | from openai import OpenAI 4 | import torch 5 | import numpy as np 6 | import sys 7 | sys.path.append("../..") 8 | import retrieval_lm.src.normalize_text as normalize_text 9 | 10 | 11 | class OpenAIEmbedSearch: 12 | def __init__(self, ndocs, task, args, use_calculated_embeds=True, is_train=False): 13 | self.ndocs = ndocs 14 | self.task = task 15 | self.args = args 16 | self.client = OpenAI() 17 | 18 | # set the question embedding save path and retrieval results save path 19 | if use_calculated_embeds: 20 | 21 | if is_train: 22 | 23 | embeddings_path = os.path.join(os.path.dirname(self.args.input_file), "train_context_embeddings.pt") 24 | self.all_embeddings = torch.load(embeddings_path) 25 | 26 | else: 27 | 28 | embeddings_path = os.path.join(os.path.dirname(self.args.input_file), "test_context_embeddings.pt") 29 | self.all_embeddings = torch.load(embeddings_path) 30 | 31 | def __call__(self, query: str, corpus: list, index: int = None): 32 | 33 | # get the normalized text and embedding 34 | 35 | normalized_query = query.lower() 36 | normalized_query = normalize_text.normalize(normalized_query) 37 | 38 | try: 39 | normalized_query_emb = self.client.embeddings.create(input=[normalized_query], model="text-embedding-3-large").data[0].embedding 40 | 41 | except Exception as E: 42 | print(f"bad request for openai, use dummy input : {E}") 43 | normalized_query_emb = self.client.embeddings.create(input=["dummy"], model="text-embedding-3-large").data[0].embedding 44 | 45 | normalized_query_emb = torch.tensor(normalized_query_emb) 46 | 47 | top_indices = self.find_most_similar_context(normalized_query_emb, index) 48 | 49 | evidences = [] 50 | for top_index in top_indices: 51 | if "title" in corpus[top_index]: 52 | evidences.append({ 53 | "title": corpus[top_index]["title"], 54 | "text": corpus[top_index]["paragraph_text"] 55 | }) 56 | elif "text" in corpus[top_index]: 57 | evidences.append({ 58 | "title": "Retrieved Documents for Reference", 59 | "text": corpus[top_index]["text"].split("*****")[-1] 60 | }) 61 | 62 | if len(evidences) == 0: 63 | # do not return anything from search engine, add dummy 64 | evidences.append({ 65 | "title": "dummy", 66 | "text": "the search engine did not return anything" 67 | }) 68 | 69 | return evidences, top_indices 70 | 71 | def cosine_similarity(self, a, b): 72 | """calc the similarity of a and b""" 73 | dot_product = np.dot(a, b) 74 | norm_a = np.linalg.norm(a) 75 | norm_b = np.linalg.norm(b) 76 | return dot_product / (norm_a * norm_b) 77 | 78 | def find_most_similar_context(self, query_embedding, cur_index): 79 | """find the most relevant context 80 | 81 | parameter: 82 | query_embedding -- query embedding (1D numpy array) 83 | context_embeddings -- 
context list (2D numpy array) 84 | 85 | return: 86 | the most similar index and embedding 87 | """ 88 | 89 | if cur_index is not None: 90 | cur_context_embeddings = self.all_embeddings[cur_index] 91 | else: 92 | cur_context_embeddings = self.all_embeddings 93 | 94 | similarities = [self.cosine_similarity(query_embedding, context_embedding) for context_embedding in 95 | cur_context_embeddings] 96 | 97 | most_similar_index = sorted(range(len(similarities)), key=lambda i: similarities[i], reverse=True)[:self.ndocs] 98 | 99 | return most_similar_index 100 | -------------------------------------------------------------------------------- /retrieval_lm/src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class Gather(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x: torch.tensor): 10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 11 | dist.all_gather(output, x) 12 | return tuple(output) 13 | 14 | @staticmethod 15 | def backward(ctx, *grads): 16 | all_gradients = torch.stack(grads) 17 | dist.all_reduce(all_gradients) 18 | return all_gradients[dist.get_rank()] 19 | 20 | 21 | def gather(x: torch.tensor): 22 | if not dist.is_initialized(): 23 | return x 24 | x_gather = Gather.apply(x) 25 | x_gather = torch.cat(x_gather, dim=0) 26 | return x_gather 27 | 28 | 29 | @torch.no_grad() 30 | def gather_nograd(x: torch.tensor): 31 | if not dist.is_initialized(): 32 | return x 33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())] 34 | dist.all_gather(x_gather, x, async_op=False) 35 | 36 | x_gather = torch.cat(x_gather, dim=0) 37 | return x_gather 38 | 39 | 40 | @torch.no_grad() 41 | def varsize_gather_nograd(x: torch.Tensor): 42 | """gather tensors of different sizes along the first dimension""" 43 | if not dist.is_initialized(): 44 | return x 45 | 46 | # determine max size 47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 49 | dist.all_gather(allsizes, size) 50 | max_size = max([size.cpu().max() for size in allsizes]) 51 | 52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device) 53 | padded[: x.shape[0]] = x 54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())] 55 | dist.all_gather(output, padded) 56 | 57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)] 58 | output = torch.cat(output, dim=0) 59 | 60 | return output 61 | 62 | 63 | @torch.no_grad() 64 | def get_varsize(x: torch.Tensor): 65 | """gather tensors of different sizes along the first dimension""" 66 | if not dist.is_initialized(): 67 | return [x.shape[0]] 68 | 69 | # determine max size 70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 72 | dist.all_gather(allsizes, size) 73 | allsizes = torch.cat(allsizes) 74 | return allsizes 75 | 76 | 77 | def get_rank(): 78 | if not dist.is_available(): 79 | return 0 80 | if not dist.is_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main(): 86 | return get_rank() == 0 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_initialized(): 91 | return 1 92 | else: 93 | return dist.get_world_size() 94 | 95 | 96 | def barrier(): 97 | if dist.is_initialized(): 98 | 
dist.barrier() 99 | 100 | 101 | def average_main(x): 102 | if not dist.is_initialized(): 103 | return x 104 | if dist.is_initialized() and dist.get_world_size() > 1: 105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 106 | if is_main(): 107 | x = x / dist.get_world_size() 108 | return x 109 | 110 | 111 | def sum_main(x): 112 | if not dist.is_initialized(): 113 | return x 114 | if dist.is_initialized() and dist.get_world_size() > 1: 115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 116 | return x 117 | 118 | 119 | def weighted_average(x, count): 120 | if not dist.is_initialized(): 121 | if isinstance(x, torch.Tensor): 122 | x = x.item() 123 | return x, count 124 | t_loss = torch.tensor([x * count]).cuda() 125 | t_total = torch.tensor([count]).cuda() 126 | t_loss = sum_main(t_loss) 127 | t_total = sum_main(t_total) 128 | return (t_loss / t_total).item(), t_total.item() 129 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/DecomposeGeneratorTemplate_old.txt: -------------------------------------------------------------------------------- 1 | Your task is to effectively decompose complex, multihop questions into simpler, manageable sub-questions or tasks. This process involves breaking down a question that requires information from multiple sources or steps into smaller, more direct questions that can be answered individually. Here's how you should approach this: 2 | 3 | Analyze the Question: Carefully read the multihop question to understand its different components. Identify what specific pieces of information are needed to answer the main question. 4 | 5 | Specifically, we will provide the different type of the complex query which might help you to decompose the question; 6 | 7 | 1. Comparison question; is a type of question that compares two or more entities from the same group in some aspects of the entity. For instance, a comparison question compares two or more people with the date of birth or date of death (e.g., Who was born first, Albert Einstein or Abraham Lincoln?). 8 | 2. Inference question; is created from the two triples (e, r1, e1) and (e1, r2, e2) in the KB. We utilized the logical rule to acquire the new triple (e, r, e2), where r is the inference relation obtained from the two relations r1 and r2. A question–answer pair is created by using the new triple (e, r, e2), its question is created from (e, r) and its answer is e2. For instance, using two triples (Abraham Lincoln, mother, Nancy Hanks Lincoln) and (Nancy Hanks Lincoln, father, James Hanks), we obtain a new triple (Abraham Lincoln, maternal grandfather, James Hanks). A question is: Who is the maternal grandfather of Abraham Lincoln? An answer is James Hanks. 9 | 3. Compositional question; is created from the two triples (e, r1, e1) and (e1, r2, e2) in the KB. Compared with inference question, the difference is that no inference relation r exists from the two relations r1 and r2. For instance, there are two triples (La La Land, distributor, Summit Entertainment) and (Summit Entertainment, founded by, Bernd Eichinger). There is no inference relation r from the two relations distributor and founded-by. In this case, a question is created from the entity e and the two relations r1 and r2: Who is the founder of the company that distributed La La Land film? An answer is the entity e2 of the second triple: Bernd Eichinger. 10 | 4. Bridge-comparison question; is a type of question that combines the bridge question with the comparison question. 
It requires both finding the bridge entities and doing comparisons to obtain the answer. For instance, instead of directly compare two films, we compare the information of the directors of the two films, e.g., Which movie has the director born first, La La Land or Tenet? To answer this type of question, the model needs to find the bridge entity that connects the two paragraphs, one about the film and one about the director, to get the date of birth information. Then, making a comparison to obtain the final answer. 11 | 12 | Here are an example of how I want you to solve the task: 13 | 14 | ### Multihop Question: 15 | 16 | Are the directors of the films 'Finding Your Feet' and 'The Notorious Miss Lisle' from the same country? 17 | 18 | ### Question Type: 19 | 20 | Bridge-comparison 21 | 22 | ### Evidences Triplets or Supporting Facts that Might Help: 23 | 24 | ['Finding Your Feet', 'director', 'Richard Loncraine'] 25 | ['The Notorious Miss Lisle', 'director', 'James Young'] 26 | ['James Young (director)', 'country of citizenship', 'American'] 27 | ['Richard Loncraine', 'country of citizenship', 'British'] 28 | 29 | ### Decomposed queries: 30 | 31 | "Who is the director of the film 'Finding Your Feet'?" 32 | "Who is the director of the film 'The Notorious Miss Lisle'?" 33 | "What is the nationality of James Young (director)?" 34 | "What is the nationality of Richard Loncraine?" 35 | 36 | --- 37 | 38 | As outlined, please format your answer as multiple lines of text. And make sure every query is a singlehop question. 39 | Ensure there is exactly one line break between each line. 40 | 41 | Now please answer: 42 | 43 | ### Multihop Question: 44 | 45 | ${original_query} 46 | 47 | ### Question Type: 48 | 49 | ${type} 50 | 51 | ### Evidences Triplets or Supporting Facts that Might Help: 52 | 53 | ${evidences} 54 | 55 | ### Decomposed queries: 56 | 57 | 58 | -------------------------------------------------------------------------------- /retrieval_lm/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import string 3 | import re 4 | from collections import Counter 5 | import re 6 | 7 | def exact_match_score(prediction, ground_truth): 8 | return (normalize_answer(prediction) == normalize_answer(ground_truth)) 9 | 10 | 11 | def metric_max_over_ground_truths(metric_fn, prediction, ground_truths): 12 | scores_for_ground_truths = [] 13 | for ground_truth in ground_truths: 14 | score = metric_fn(prediction, ground_truth) 15 | scores_for_ground_truths.append(score) 16 | return max(scores_for_ground_truths) 17 | 18 | 19 | def accuracy(preds, labels): 20 | match_count = 0 21 | for pred, label in zip(preds, labels): 22 | if pred == label: 23 | match_count += 1 24 | 25 | return match_count / len(preds) 26 | 27 | 28 | def f1(decoded_preds, decoded_labels): 29 | f1_all = [] 30 | for prediction, answers in zip(decoded_preds, decoded_labels): 31 | if type(answers) == list: 32 | if len(answers) == 0: 33 | return 0 34 | f1_all.append(np.max([qa_f1_score(prediction, gt) 35 | for gt in answers])) 36 | else: 37 | f1_all.append(qa_f1_score(prediction, answers)) 38 | return 100 * np.mean(f1_all) 39 | 40 | 41 | def qa_f1_score(prediction, ground_truth): 42 | prediction_tokens = normalize_answer(prediction).split() 43 | ground_truth_tokens = normalize_answer(ground_truth).split() 44 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 45 | num_same = sum(common.values()) 46 | if num_same == 0: 47 | return 0 48 | precision = 1.0 * num_same 
/ len(prediction_tokens) 49 | recall = 1.0 * num_same / len(ground_truth_tokens) 50 | f1 = (2 * precision * recall) / (precision + recall) 51 | return f1 52 | 53 | 54 | def normalize_answer(s): 55 | def remove_articles(text): 56 | return re.sub(r'\b(a|an|the)\b', ' ', text) 57 | 58 | def white_space_fix(text): 59 | return ' '.join(text.split()) 60 | 61 | def remove_punc(text): 62 | exclude = set(string.punctuation) 63 | return ''.join(ch for ch in text if ch not in exclude) 64 | 65 | def lower(text): 66 | return text.lower() 67 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 68 | 69 | 70 | def find_entity_tags(sentence): 71 | entity_regex = r'(.+?)(?=\s<|$)' 72 | tag_regex = r'<(.+?)>' 73 | entity_names = re.findall(entity_regex, sentence) 74 | tags = re.findall(tag_regex, sentence) 75 | 76 | results = {} 77 | for entity, tag in zip(entity_names, tags): 78 | if "<" in entity: 79 | results[entity.split("> ")[1]] = tag 80 | else: 81 | results[entity] = tag 82 | return results 83 | 84 | 85 | def match(prediction, ground_truth): 86 | for gt in ground_truth: 87 | if gt in prediction: 88 | return 1 89 | return 0 90 | 91 | 92 | def match_batch(predictions, ground_truths): 93 | 94 | tmp = [] 95 | for prediction, ground_truth in zip(predictions, ground_truths): 96 | if match(prediction, ground_truth["answers"]): 97 | tmp.append(1) 98 | else: 99 | tmp.append(0) 100 | return tmp 101 | 102 | 103 | def calculate_retrieval_em_f1(predicted_support_idxs, gold_support_idxs): 104 | # Taken from hotpot_eval 105 | cur_sp_pred = set(map(int, predicted_support_idxs)) 106 | gold_sp_pred = set(map(int, gold_support_idxs)) 107 | tp, fp, fn = 0, 0, 0 108 | for e in cur_sp_pred: 109 | if e in gold_sp_pred: 110 | tp += 1 111 | else: 112 | fp += 1 113 | for e in gold_sp_pred: 114 | if e not in cur_sp_pred: 115 | fn += 1 116 | prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0 117 | recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0 118 | f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0 119 | em = 1.0 if fp + fn == 0 else 0.0 120 | 121 | # In case everything is empty, set both f1, em to be 1.0. 122 | # Without this change, em gets 1 and f1 gets 0 123 | if not cur_sp_pred and not gold_sp_pred: 124 | f1, em = 1.0, 1.0 125 | return f1, em -------------------------------------------------------------------------------- /retrieval_lm/src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | from logging import getLogger 8 | import os 9 | import sys 10 | import torch 11 | import socket 12 | import signal 13 | import subprocess 14 | 15 | 16 | logger = getLogger() 17 | 18 | def sig_handler(signum, frame): 19 | logger.warning("Signal handler called with signal " + str(signum)) 20 | prod_id = int(os.environ['SLURM_PROCID']) 21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 22 | if prod_id == 0: 23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 25 | else: 26 | logger.warning("Not the main process, no need to requeue.") 27 | sys.exit(-1) 28 | 29 | 30 | def term_handler(signum, frame): 31 | logger.warning("Signal handler called with signal " + str(signum)) 32 | logger.warning("Bypassing SIGTERM.") 33 | 34 | 35 | def init_signal_handler(): 36 | """ 37 | Handle signals sent by SLURM for time limit / pre-emption. 38 | """ 39 | signal.signal(signal.SIGUSR1, sig_handler) 40 | signal.signal(signal.SIGTERM, term_handler) 41 | 42 | 43 | def init_distributed_mode(params): 44 | """ 45 | Handle single and multi-GPU / multi-node / SLURM jobs. 46 | Initialize the following variables: 47 | - local_rank 48 | - global_rank 49 | - world_size 50 | """ 51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ 52 | has_local_rank = hasattr(params, 'local_rank') 53 | 54 | # SLURM job without torch.distributed.launch 55 | if is_slurm_job and has_local_rank: 56 | 57 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 58 | 59 | # local rank on the current node / global rank 60 | params.local_rank = int(os.environ['SLURM_LOCALID']) 61 | params.global_rank = int(os.environ['SLURM_PROCID']) 62 | params.world_size = int(os.environ['SLURM_NTASKS']) 63 | 64 | # define master address and master port 65 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 66 | params.main_addr = hostnames.split()[0].decode('utf-8') 67 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1 68 | 69 | # set environment variables for 'env://' 70 | os.environ['MASTER_ADDR'] = params.main_addr 71 | os.environ['MASTER_PORT'] = str(params.main_port) 72 | os.environ['WORLD_SIZE'] = str(params.world_size) 73 | os.environ['RANK'] = str(params.global_rank) 74 | is_distributed = True 75 | 76 | 77 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 78 | elif has_local_rank and params.local_rank != -1: 79 | 80 | assert params.main_port == -1 81 | 82 | # read environment variables 83 | params.global_rank = int(os.environ['RANK']) 84 | params.world_size = int(os.environ['WORLD_SIZE']) 85 | 86 | is_distributed = True 87 | 88 | # local job (single GPU) 89 | else: 90 | params.local_rank = 0 91 | params.global_rank = 0 92 | params.world_size = 1 93 | is_distributed = False 94 | 95 | # set GPU device 96 | torch.cuda.set_device(params.local_rank) 97 | 98 | # initialize multi-GPU 99 | if is_distributed: 100 | 101 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 102 | # 'env://' will read these environment variables: 103 | # MASTER_PORT - required; has to be a free port on machine with rank 0 104 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 105 | # WORLD_SIZE - required; can be set either here, or in a call to init function 106 | # RANK - required; can be set either here, or in a call to init function 107 | 108 | 
#print("Initializing PyTorch distributed ...") 109 | torch.distributed.init_process_group( 110 | init_method='env://', 111 | backend='nccl', 112 | #world_size=params.world_size, 113 | #rank=params.global_rank, 114 | ) -------------------------------------------------------------------------------- /retrieval_lm/llama_flash_attn_monkey_patch.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import transformers 7 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb 8 | 9 | from einops import rearrange 10 | 11 | try: 12 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func as flash_attn_unpadded_qkvpacked_func 13 | except ImportError: 14 | from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func 15 | 16 | from flash_attn.bert_padding import unpad_input, pad_input 17 | 18 | 19 | def forward( 20 | self, 21 | hidden_states: torch.Tensor, 22 | attention_mask: Optional[torch.Tensor] = None, 23 | position_ids: Optional[torch.Tensor] = None, 24 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 25 | output_attentions: bool = False, 26 | use_cache: bool = False, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | .transpose(1, 2) 38 | ) 39 | key_states = ( 40 | self.k_proj(hidden_states) 41 | .view(bsz, q_len, self.num_heads, self.head_dim) 42 | .transpose(1, 2) 43 | ) 44 | value_states = ( 45 | self.v_proj(hidden_states) 46 | .view(bsz, q_len, self.num_heads, self.head_dim) 47 | .transpose(1, 2) 48 | ) 49 | # [bsz, q_len, nh, hd] 50 | # [bsz, nh, q_len, hd] 51 | 52 | kv_seq_len = key_states.shape[-2] 53 | assert past_key_value is None, "past_key_value is not supported" 54 | 55 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 56 | query_states, key_states = apply_rotary_pos_emb( 57 | query_states, key_states, cos, sin, position_ids 58 | ) 59 | # [bsz, nh, t, hd] 60 | assert not output_attentions, "output_attentions is not supported" 61 | assert not use_cache, "use_cache is not supported" 62 | 63 | # Flash attention codes from 64 | # https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py 65 | 66 | # transform the data into the format required by flash attention 67 | qkv = torch.stack( 68 | [query_states, key_states, value_states], dim=2 69 | ) # [bsz, nh, 3, q_len, hd] 70 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 71 | # We have disabled _prepare_decoder_attention_mask in LlamaModel 72 | # the attention_mask should be the same as the key_padding_mask 73 | key_padding_mask = attention_mask 74 | 75 | if key_padding_mask is None: 76 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 77 | max_s = q_len 78 | cu_q_lens = torch.arange( 79 | 0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device 80 | ) 81 | output = flash_attn_unpadded_qkvpacked_func( 82 | qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 83 | ) 84 | output = rearrange(output, "(b s) ... 
-> b s ...", b=bsz) 85 | else: 86 | nheads = qkv.shape[-2] 87 | x = rearrange(qkv, "b s three h d -> b s (three h d)") 88 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask) 89 | x_unpad = rearrange( 90 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads 91 | ) 92 | output_unpad = flash_attn_unpadded_qkvpacked_func( 93 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True 94 | ) 95 | output = rearrange( 96 | pad_input( 97 | rearrange(output_unpad, 98 | "nnz h d -> nnz (h d)"), indices, bsz, q_len 99 | ), 100 | "b s (h d) -> b s h d", 101 | h=nheads, 102 | ) 103 | return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None 104 | 105 | 106 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 107 | # requires the attention mask to be the same as the key_padding_mask 108 | def _prepare_decoder_attention_mask( 109 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 110 | ): 111 | # [bsz, seq_len] 112 | return attention_mask 113 | 114 | 115 | def replace_llama_attn_with_flash_attn(): 116 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( 117 | _prepare_decoder_attention_mask 118 | ) 119 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 120 | -------------------------------------------------------------------------------- /retrieval_lm/src/normalize_text.py: -------------------------------------------------------------------------------- 1 | """ 2 | adapted from chemdataextractor.text.normalize 3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 4 | Tools for normalizing text. 5 | https://github.com/mcs07/ChemDataExtractor 6 | :copyright: Copyright 2016 by Matt Swain. 7 | :license: MIT 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining 10 | a copy of this software and associated documentation files (the 11 | 'Software'), to deal in the Software without restriction, including 12 | without limitation the rights to use, copy, modify, merge, publish, 13 | distribute, sublicense, and/or sell copies of the Software, and to 14 | permit persons to whom the Software is furnished to do so, subject to 15 | the following conditions: 16 | 17 | The above copyright notice and this permission notice shall be 18 | included in all copies or substantial portions of the Software. 19 | 20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND, 21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY 24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, 25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE 26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 27 | """ 28 | 29 | #: Control characters. 30 | CONTROLS = { 31 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011', 32 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b', 33 | } 34 | # There are further control characters, but they are instead replaced with a space by unicode normalization 35 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 36 | 37 | 38 | #: Hyphen and dash characters. 
39 | HYPHENS = { 40 | '-', # \u002d Hyphen-minus 41 | '‐', # \u2010 Hyphen 42 | '‑', # \u2011 Non-breaking hyphen 43 | '⁃', # \u2043 Hyphen bullet 44 | '‒', # \u2012 figure dash 45 | '–', # \u2013 en dash 46 | '—', # \u2014 em dash 47 | '―', # \u2015 horizontal bar 48 | } 49 | 50 | #: Minus characters. 51 | MINUSES = { 52 | '-', # \u002d Hyphen-minus 53 | '−', # \u2212 Minus 54 | '-', # \uff0d Full-width Hyphen-minus 55 | '⁻', # \u207b Superscript minus 56 | } 57 | 58 | #: Plus characters. 59 | PLUSES = { 60 | '+', # \u002b Plus 61 | '+', # \uff0b Full-width Plus 62 | '⁺', # \u207a Superscript plus 63 | } 64 | 65 | #: Slash characters. 66 | SLASHES = { 67 | '/', # \u002f Solidus 68 | '⁄', # \u2044 Fraction slash 69 | '∕', # \u2215 Division slash 70 | } 71 | 72 | #: Tilde characters. 73 | TILDES = { 74 | '~', # \u007e Tilde 75 | '˜', # \u02dc Small tilde 76 | '⁓', # \u2053 Swung dash 77 | '∼', # \u223c Tilde operator #in mbert vocab 78 | '∽', # \u223d Reversed tilde 79 | '∿', # \u223f Sine wave 80 | '〜', # \u301c Wave dash #in mbert vocab 81 | '~', # \uff5e Full-width tilde #in mbert vocab 82 | } 83 | 84 | #: Apostrophe characters. 85 | APOSTROPHES = { 86 | "'", # \u0027 87 | '’', # \u2019 88 | '՚', # \u055a 89 | 'Ꞌ', # \ua78b 90 | 'ꞌ', # \ua78c 91 | ''', # \uff07 92 | } 93 | 94 | #: Single quote characters. 95 | SINGLE_QUOTES = { 96 | "'", # \u0027 97 | '‘', # \u2018 98 | '’', # \u2019 99 | '‚', # \u201a 100 | '‛', # \u201b 101 | 102 | } 103 | 104 | #: Double quote characters. 105 | DOUBLE_QUOTES = { 106 | '"', # \u0022 107 | '“', # \u201c 108 | '”', # \u201d 109 | '„', # \u201e 110 | '‟', # \u201f 111 | } 112 | 113 | #: Accent characters. 114 | ACCENTS = { 115 | '`', # \u0060 116 | '´', # \u00b4 117 | } 118 | 119 | #: Prime characters. 120 | PRIMES = { 121 | '′', # \u2032 122 | '″', # \u2033 123 | '‴', # \u2034 124 | '‵', # \u2035 125 | '‶', # \u2036 126 | '‷', # \u2037 127 | '⁗', # \u2057 128 | } 129 | 130 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 131 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 132 | 133 | def normalize(text): 134 | for control in CONTROLS: 135 | text = text.replace(control, '') 136 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ') 137 | 138 | for hyphen in HYPHENS | MINUSES: 139 | text = text.replace(hyphen, '-') 140 | text = text.replace('\u00ad', '') 141 | 142 | for double_quote in DOUBLE_QUOTES: 143 | text = text.replace(double_quote, '"') # \u0022 144 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS): 145 | text = text.replace(single_quote, "'") # \u0027 146 | text = text.replace('′', "'") # \u2032 prime 147 | text = text.replace('‵', "'") # \u2035 reversed prime 148 | text = text.replace('″', "''") # \u2033 double prime 149 | text = text.replace('‶', "''") # \u2036 reversed double prime 150 | text = text.replace('‴', "'''") # \u2034 triple prime 151 | text = text.replace('‷', "'''") # \u2037 reversed triple prime 152 | text = text.replace('⁗', "''''") # \u2057 quadruple prime 153 | 154 | text = text.replace('…', '...').replace(' . . . ', ' ... 
') # \u2026 155 | 156 | for slash in SLASHES: 157 | text = text.replace(slash, '/') 158 | 159 | #for tilde in TILDES: 160 | # text = text.replace(tilde, '~') 161 | 162 | return text 163 | -------------------------------------------------------------------------------- /retrieval_lm/src/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import copy 7 | import transformers 8 | 9 | from src import contriever, dist_utils, utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MoCo(nn.Module): 15 | def __init__(self, opt): 16 | super(MoCo, self).__init__() 17 | 18 | self.queue_size = opt.queue_size 19 | self.momentum = opt.momentum 20 | self.temperature = opt.temperature 21 | self.label_smoothing = opt.label_smoothing 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode 25 | 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | 30 | self.tokenizer = tokenizer 31 | self.encoder_q = retriever 32 | self.encoder_k = copy.deepcopy(retriever) 33 | 34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 35 | param_k.data.copy_(param_q.data) 36 | param_k.requires_grad = False 37 | 38 | # create the queue 39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size)) 40 | self.queue = nn.functional.normalize(self.queue, dim=0) 41 | 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | def _load_retriever(self, model_id, pooling, random_init): 45 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 47 | 48 | if "xlm" in model_id: 49 | model_class = contriever.XLMRetriever 50 | else: 51 | model_class = contriever.Contriever 52 | 53 | if random_init: 54 | retriever = model_class(cfg) 55 | else: 56 | retriever = utils.load_hf(model_class, model_id) 57 | 58 | if "bert-" in model_id: 59 | if tokenizer.bos_token_id is None: 60 | tokenizer.bos_token = "[CLS]" 61 | if tokenizer.eos_token_id is None: 62 | tokenizer.eos_token = "[SEP]" 63 | 64 | retriever.config.pooling = pooling 65 | 66 | return retriever, tokenizer 67 | 68 | def get_encoder(self, return_encoder_k=False): 69 | if return_encoder_k: 70 | return self.encoder_k 71 | else: 72 | return self.encoder_q 73 | 74 | def _momentum_update_key_encoder(self): 75 | """ 76 | Update of the key encoder 77 | """ 78 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 79 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum) 80 | 81 | @torch.no_grad() 82 | def _dequeue_and_enqueue(self, keys): 83 | # gather keys before updating queue 84 | keys = dist_utils.gather_nograd(keys.contiguous()) 85 | 86 | batch_size = keys.shape[0] 87 | 88 | ptr = int(self.queue_ptr) 89 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity 90 | 91 | # replace the keys at ptr (dequeue and enqueue) 92 | self.queue[:, ptr : ptr + batch_size] = keys.T 93 | ptr = (ptr + batch_size) % self.queue_size # move pointer 94 | 95 | self.queue_ptr[0] = ptr 96 | 97 | def _compute_logits(self, q, k): 98 | l_pos = 
torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 99 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) 100 | 101 | logits = torch.cat([l_pos, l_neg], dim=1) 102 | return logits 103 | 104 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 105 | bsz = q_tokens.size(0) 106 | 107 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 108 | 109 | # compute key features 110 | with torch.no_grad(): # no gradient to keys 111 | self._momentum_update_key_encoder() # update the key encoder 112 | 113 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k: 114 | self.encoder_k.eval() 115 | 116 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 117 | 118 | logits = self._compute_logits(q, k) / self.temperature 119 | 120 | # labels: positive key indicators 121 | labels = torch.zeros(bsz, dtype=torch.long).cuda() 122 | 123 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing) 124 | 125 | self._dequeue_and_enqueue(k) 126 | 127 | # log stats 128 | if len(stats_prefix) > 0: 129 | stats_prefix = stats_prefix + "/" 130 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 131 | 132 | predicted_idx = torch.argmax(logits, dim=-1) 133 | accuracy = 100 * (predicted_idx == labels).float().mean() 134 | stdq = torch.std(q, dim=0).mean().item() 135 | stdk = torch.std(k, dim=0).mean().item() 136 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 137 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 138 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 139 | 140 | return loss, iter_stats 141 | -------------------------------------------------------------------------------- /retrieval_lm/src/contriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | import torch 5 | import transformers 6 | from transformers import BertModel, XLMRobertaModel 7 | 8 | from src import utils 9 | 10 | 11 | class Contriever(BertModel): 12 | def __init__(self, config, pooling="average", **kwargs): 13 | super().__init__(config, add_pooling_layer=False) 14 | if not hasattr(config, "pooling"): 15 | self.config.pooling = pooling 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | position_ids=None, 23 | head_mask=None, 24 | inputs_embeds=None, 25 | encoder_hidden_states=None, 26 | encoder_attention_mask=None, 27 | output_attentions=None, 28 | output_hidden_states=None, 29 | normalize=False, 30 | ): 31 | 32 | model_output = super().forward( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | token_type_ids=token_type_ids, 36 | position_ids=position_ids, 37 | head_mask=head_mask, 38 | inputs_embeds=inputs_embeds, 39 | encoder_hidden_states=encoder_hidden_states, 40 | encoder_attention_mask=encoder_attention_mask, 41 | output_attentions=output_attentions, 42 | output_hidden_states=output_hidden_states, 43 | ) 44 | 45 | last_hidden = model_output["last_hidden_state"] 46 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 47 | 48 | if self.config.pooling == "average": 49 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 50 | elif self.config.pooling == "cls": 51 | emb = last_hidden[:, 0] 52 | 53 | if normalize: 54 | emb = torch.nn.functional.normalize(emb, dim=-1) 55 | return emb 56 | 57 | 58 | class XLMRetriever(XLMRobertaModel): 59 | def __init__(self, config, pooling="average", **kwargs): 60 | super().__init__(config, add_pooling_layer=False) 61 | if not hasattr(config, "pooling"): 62 | self.config.pooling = pooling 63 | 64 | def forward( 65 | self, 66 | input_ids=None, 67 | attention_mask=None, 68 | token_type_ids=None, 69 | position_ids=None, 70 | head_mask=None, 71 | inputs_embeds=None, 72 | encoder_hidden_states=None, 73 | encoder_attention_mask=None, 74 | output_attentions=None, 75 | output_hidden_states=None, 76 | normalize=False, 77 | ): 78 | 79 | model_output = super().forward( 80 | input_ids=input_ids, 81 | attention_mask=attention_mask, 82 | token_type_ids=token_type_ids, 83 | position_ids=position_ids, 84 | head_mask=head_mask, 85 | inputs_embeds=inputs_embeds, 86 | encoder_hidden_states=encoder_hidden_states, 87 | encoder_attention_mask=encoder_attention_mask, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | ) 91 | 92 | last_hidden = model_output["last_hidden_state"] 93 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 94 | if self.config.pooling == "average": 95 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 96 | elif self.config.pooling == "cls": 97 | emb = last_hidden[:, 0] 98 | if normalize: 99 | emb = torch.nn.functional.normalize(emb, dim=-1) 100 | return emb 101 | 102 | 103 | def load_retriever(model_path, pooling="average", random_init=False): 104 | # try: check if model exists locally 105 | path = os.path.join(model_path, "checkpoint.pth") 106 | if os.path.exists(path): 107 | pretrained_dict = torch.load(path, map_location="cpu") 108 | opt = pretrained_dict["opt"] 109 | if hasattr(opt, "retriever_model_id"): 110 | retriever_model_id = opt.retriever_model_id 111 | else: 112 | # retriever_model_id = "bert-base-uncased" 113 | retriever_model_id = "bert-base-multilingual-cased" 114 | 
tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id) 115 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id) 116 | if "xlm" in retriever_model_id: 117 | model_class = XLMRetriever 118 | else: 119 | model_class = Contriever 120 | retriever = model_class(cfg) 121 | pretrained_dict = pretrained_dict["model"] 122 | 123 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class 124 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k} 125 | elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class 126 | pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k} 127 | retriever.load_state_dict(pretrained_dict, strict=False) 128 | else: 129 | retriever_model_id = model_path 130 | if "xlm" in retriever_model_id: 131 | model_class = XLMRetriever 132 | else: 133 | model_class = Contriever 134 | cfg = utils.load_hf(transformers.AutoConfig, model_path) 135 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path) 136 | retriever = utils.load_hf(model_class, model_path) 137 | 138 | return retriever, tokenizer, retriever_model_id 139 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/DecomposeGeneratorTemplate.txt: -------------------------------------------------------------------------------- 1 | Your task is to effectively decompose complex, multihop questions into simpler, manageable sub-questions or tasks. This process involves breaking down a question that requires information from multiple sources or steps into smaller, more direct questions that can be answered individually. Here's how you should approach this: 2 | 3 | Analyze the Question: Carefully read the multihop question to understand its different components. Identify what specific pieces of information are needed to answer the main question. 4 | 5 | Here are an example of how I want you to solve the task: 6 | 7 | ### Provided Contexts: 8 | 9 | Title: Ishqbaaaz 10 | Text: Ishqbaaaz (English: "Lovers") is an Indian drama television series which is broadcast on Star Plus. It premiered on 27 June 2016 and airs Mon-Fri 10-11pm IST.Nakuul Mehta, Kunal Jaisingh and Leenesh Mattoo respectively portray Shivaay, Omkara and Rudra, the three heirs of the Oberoi family. The show initially focused on the tale of three brothers, later become centered on the love story of Shivaay and Annika (Surbhi Chandna); with the story of Omkara and Rudra being shifted to the spinoff series "Dil Boley Oberoi". In July 2017 "Dil Boley Oberoi" ended and the storylines were merged back into "Ishqbaaaz" which doubled its runtime. 11 | Title: Hotel Tallcorn 12 | Text: The Hotel Tallcorn is located in Marshalltown, Iowa. Today it is called the Tallcorn Towers Apartments. Built in 1928 by the Eppley Hotel Company, local citizens contributed $120,000 to ensure the successful completion of this seven-story hotel. It was completed in connection to the seventy-fifth anniversary of Marshalltown. The hotel's sale in 1956 from the Eppley chain to the Sheraton Corporation was part of the second largest hotel sale in United States history. The Tallcorn was listed as a contributing property in the Marshalltown Downtown Historic District on the National Register of Historic Places in 2002. 
13 | Title: The Oberoi Group 14 | Text: The Oberoi Group is a hotel company with its head office in Delhi. Founded in 1934, the company owns and/or operates 30+ luxury hotels and two river cruise ships in six countries, primarily under its Oberoi Hotels & Resorts and Trident Hotels brands. 15 | Title: Hotel Bond 16 | Text: Hotel Bond is a historic hotel, built in two stages in 1913 and 1921, in downtown Hartford, Connecticut by hotelier Harry S. Bond. It is located near Bushnell Park, and was considered the grandest hotel in Hartford during its heyday. The second section is a 12 story building attached to the 6 story first section. A Statler Hotel opened in the area in 1954, creating competition, and the Bond Hotel company declared bankruptcy shortly after that. It was bought by the California-based Masaglia Hotel chain, which began an incremental renovation program. In 1964 it was sold to a Cincinnati, Ohio investment group which announced extensive renovation plans. However, the financing plans fell through and the hotel was again in bankruptcy. The building was sold at auction to the Roman Catholic Archdiocese of Hartford in 1965, and it became the home of the Saint Francis Hospital School of Nursing. The Bond Ballroom reopened in 2001, with the rest of the building becoming a Homewood Suites by Hilton in 2006. 17 | Title: Future Fibre Technologies 18 | Text: Future Fibre Technologies (FFT) is a fiber optic sensing technologies company based in Melbourne, Australia, with its US head office in Mountain View, California, Middle East head office in Dubai, Indian head office in New Delhi and European head office in London. Founded in 1994, Future Fibre Technologies product line provides optical fiber intrusion detection systems for perimeters, buried oil and gas pipelines and data communication networks. 19 | Title: Mohan Singh Oberoi 20 | Text: Rai Bahadur Mohan Singh Oberoi (15 August 1898 – 3 May 2002) was an Indian hotelier, the founder and chairman of Oberoi Hotels & Resorts, India's second-largest hotel company, with 35 hotels in India, Sri Lanka, Nepal, Egypt, Australia and Hungary. 21 | Title: Oberoi family 22 | Text: The Oberoi family is an Indian family that is famous for its involvement in hotels, namely through The Oberoi Group. 23 | Title: 289th Military Police Company 24 | Text: The 289th Military Police Company was activated on 1 November 1994 and attached to Hotel Company, 3rd Infantry (The Old Guard), Fort Myer, Virginia. Hotel Company is the regiment's specialty company. 25 | Title: Ritz-Carlton Jakarta 26 | Text: The Ritz-Carlton Jakarta is a hotel and skyscraper in Jakarta, Indonesia and 14th Tallest building in Jakarta. It is located in city center of Jakarta, near Mega Kuningan, adjacent to the sister JW Marriott Hotel. It is operated by The Ritz-Carlton Hotel Company. The complex has two towers that comprises a hotel and the Airlangga Apartment respectively. The hotel was opened in 2005. 27 | Title: Glennwanis Hotel 28 | Text: The Glennwanis Hotel is a historic hotel in Glennville, Georgia, Tattnall County, Georgia, built on the site of the Hughes Hotel. The hotel is located at 209-215 East Barnard Street. The old Hughes Hotel was built out of Georgia pine circa 1905 and burned in 1920. The Glennwanis was built in brick in 1926. The local Kiwanis club led the effort to get the replacement hotel built, and organized a Glennville Hotel Company with directors being local business leaders. 
The wife of a local doctor won a naming contest with the name "Glennwanis Hotel", a suggestion combining "Glennville" and "Kiwanis". 29 | 30 | ### Multihop Question: 31 | 32 | The Oberoi family is part of a hotel company that has a head office in what city? 33 | 34 | ### Decomposed queries: 35 | 36 | "Which hotel company is the Oberoi family part of?" 37 | "Where is the head office of The Oberoi Group located?" 38 | 39 | --- 40 | 41 | As outlined, please format your answer as multiple lines of text. Ensure that each subsequent question follows from the previous one and is self-contained and be capable of being answered on its own. 42 | Ensure there is exactly one line break between each line. 43 | 44 | Now please answer: 45 | 46 | ### Provided Contexts: 47 | 48 | ${contexts} 49 | 50 | ### Multihop Question: 51 | 52 | ${original_query} 53 | 54 | ### Decomposed queries: 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /retrieval_lm/utils.py: -------------------------------------------------------------------------------- 1 | import jsonlines 2 | import json 3 | import copy 4 | import re 5 | 6 | system_prompt = { 7 | "arc_challenge": "Given four answer candidates, A, B, C and D, choose the best answer choice.", 8 | "popqa_longtail_w_gs": "Please answer the question.", 9 | "openbookqa": "Given four answer candidates, A, B, C and D, choose the best answer choice.", 10 | "hotpotqa": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response.", 11 | "2wikimultihopqa":"Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response.", 12 | "musique": "Given a question that requires multi-hop reasoning, you need to decompose the question and answer based on the given context. Please provide a short and concise response.", 13 | } 14 | 15 | 16 | def preprocess_single_data(user_query, tokenizer, previous_history=""): 17 | 18 | user_query = "<|user|>\n" + user_query + tokenizer.eos_token + "\n" 19 | user_query += "<|assistant|>\n" 20 | 21 | return [user_query] 22 | 23 | 24 | def preprocess_multi_turn_inference_data(user_query, previous_history, tokenizer, task, turn, retrieve_type): 25 | 26 | # apparently multi_turn inference is hard to support batch decoding. 
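    # How the prompt is seeded (summarizing the branches below): the assistant turn is opened with
    # a control token so decoding is steered toward one behaviour, assuming these special tokens
    # were registered with the tokenizer at training time:
    #   retrieve_type == "always" -> "<|assistant|>\n[S_Rewritten_Query]"  (force a rewritten search query)
    #   retrieve_type == "never"  -> "<|assistant|>\n[A_Response]"         (answer directly, no retrieval)
    #   retrieve_type == "pure"   -> "<|assistant|>\n"                     (let the model choose)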
27 | 28 | if task in ["mt_bench", "alpaca_eval"]: 29 | 30 | if previous_history == "": 31 | eval_example = "" 32 | else: 33 | eval_example = "" 34 | previous_history += "\n" 35 | eval_example += previous_history + "<|user|>\n" + user_query[turn] + tokenizer.eos_token + "\n" 36 | 37 | if retrieve_type == "always": 38 | eval_example += "<|assistant|>\n[S_Rewritten_Query]" 39 | elif retrieve_type == "never": 40 | eval_example += "<|assistant|>\n[A_Response]" 41 | elif retrieve_type == "pure": 42 | eval_example += "<|assistant|>\n" 43 | 44 | else: 45 | raise ValueError(f"task {task} do not in the lists") 46 | 47 | return [eval_example] 48 | 49 | 50 | def preprocess_eval_data(row, tokenizer, task): 51 | 52 | eval_examples = [] 53 | 54 | for cur in row: 55 | if task in ["popqa_longtail_w_gs", "popqa_longtail_w_gs_may_refers_to"]: 56 | eval_example = f"<|system|>\n{system_prompt[task]}" + tokenizer.eos_token + "\n<|user|>\n" + cur["question"] + tokenizer.eos_token + "\n" 57 | eval_example += "<|assistant|>\n" 58 | 59 | elif task in ["hotpotqa", "2wikimultihopqa", "musique"]: 60 | eval_example = f"<|system|>\n{system_prompt[task]}" + tokenizer.eos_token + "\n<|user|>\n" + cur["question_text"] + tokenizer.eos_token + "\n" 61 | eval_example += "<|assistant|>\n" 62 | 63 | elif task == "arc_challenge": 64 | label_dict = { 65 | "A": "A", 66 | "B": "B", 67 | "C": "C", 68 | "D": "D", 69 | "1": "A", 70 | "2": "B", 71 | "3": "C", 72 | "4": "D", 73 | } 74 | 75 | if cur["question"].endswith("?"): 76 | user_query = f"{cur['question']}" 77 | else: 78 | user_query = f"{cur['question']}?" 79 | 80 | user_query += "\n" 81 | user_query += "Please choose from following options:\n" 82 | 83 | for option_index, option_text in enumerate(cur["choices"]["text"]): 84 | if cur["choices"]["label"][option_index] not in label_dict: 85 | print(cur["choices"]["label"][option_index]) 86 | continue 87 | 88 | user_query += "{0}: {1}\n".format(label_dict[cur["choices"]["label"][option_index]], option_text) 89 | 90 | eval_example = f"<|system|>\n{system_prompt[task]}" + tokenizer.eos_token + "\n<|user|>\n" + user_query + tokenizer.eos_token + "\n" 91 | eval_example += "<|assistant|>\n" 92 | 93 | elif task == "openbookqa": 94 | 95 | label_dict = { 96 | "A": "A", 97 | "B": "B", 98 | "C": "C", 99 | "D": "D", 100 | "1": "A", 101 | "2": "B", 102 | "3": "C", 103 | "4": "D", 104 | } 105 | 106 | if cur["question_stem"].endswith("?"): 107 | user_query = f"{cur['question_stem']}" 108 | else: 109 | user_query = f"{cur['question_stem']}?" 
110 | 111 | user_query += "\n" 112 | user_query += "Please choose from following options:\n" 113 | 114 | for option_index, option_text in enumerate(cur["choices"]["text"]): 115 | if cur["choices"]["label"][option_index] not in label_dict: 116 | print(cur["choices"]["label"][option_index]) 117 | continue 118 | 119 | user_query += "{0}: {1}\n".format(label_dict[cur["choices"]["label"][option_index]], option_text) 120 | 121 | eval_example = f"<|system|>\n{system_prompt[task]}" + tokenizer.eos_token + "\n<|user|>\n" + user_query + tokenizer.eos_token + "\n" 122 | eval_example += "<|assistant|>\n" 123 | 124 | else: 125 | raise ValueError(f"task {task} is not in the supported task list") 126 | 127 | eval_examples.append(eval_example) 128 | 129 | return eval_examples 130 | 131 | 132 | def load_sag_special_tokens(tokenizer): 133 | 134 | special_tokens_dict = {} 135 | 136 | for token in tokenizer.additional_special_tokens: 137 | special_tokens_dict[token] = tokenizer.convert_tokens_to_ids(token) 138 | 139 | return special_tokens_dict 140 | 141 | 142 | def fix_spacing(input_text): 143 | # Add a space after periods that lack whitespace 144 | output_text = re.sub(r'(?<=\w)([.!?])(?=\w)', r'\1 ', input_text) 145 | return output_text 146 | 147 | 148 | def load_jsonlines(file): 149 | with jsonlines.open(file, 'r') as jsonl_f: 150 | lst = [obj for obj in jsonl_f] 151 | return lst 152 | 153 | 154 | def load_file(input_fp): 155 | if input_fp.endswith(".json"): 156 | input_data = json.load(open(input_fp)) 157 | else: 158 | input_data = load_jsonlines(input_fp) 159 | return input_data 160 | 161 | 162 | def save_file_jsonl(data, fp): 163 | with jsonlines.open(fp, mode='w') as writer: 164 | writer.write_all(data) 165 | -------------------------------------------------------------------------------- /data_curation/llm_agent/datatypes.py: -------------------------------------------------------------------------------- 1 | import enum 2 | from dataclasses import dataclass 3 | from typing import List, Dict, Any, Optional, Tuple 4 | from collections import defaultdict 5 | 6 | @dataclass(frozen=True) 7 | class Action: 8 | value: str # LM returned string for now 9 | use_tool: bool # if use_tool == False -> propose answer 10 | error: Optional[str] = None 11 | 12 | def to_dict(self) -> Dict[str, Any]: 13 | return { 14 | "action_value": self.value, 15 | "action_use_tool": self.use_tool, 16 | "action_error": self.error, 17 | } 18 | 19 | class State: 20 | """This should contain everything needed to continue the conversation. 21 | 22 | For example, the history of the conversation, the current task (success/failure) at each step, etc.
23 | """ 24 | 25 | def __init__( 26 | self, 27 | history: List[Dict[str, Any]] = None, 28 | finished: bool = False, 29 | success: bool = False, 30 | latest_output: Dict[str, Any] = None, 31 | agent_action_count: Dict[str, int] = None, 32 | terminate_reason: str = None, 33 | ): 34 | self.history: List[Dict[str, Any]] = history 35 | self.finished: bool = finished 36 | self.success: bool = success 37 | self.latest_output: Dict[str, Any] = latest_output 38 | self.agent_action_count: Dict[str, int] = agent_action_count 39 | self.token_counter: Dict[str, int] = defaultdict(int) 40 | self.terminate_reason: str = terminate_reason 41 | self.error: Optional[str] = None 42 | 43 | @property 44 | def empty(self): 45 | return len(self.history) == 0 46 | 47 | def to_dict(self) -> Dict[str, Any]: 48 | return { 49 | "history": self.history, 50 | "finished": self.finished, 51 | "success": self.success, 52 | "latest_output": self.latest_output, 53 | "agent_action_count": self.agent_action_count, 54 | "token_counter": dict(self.token_counter), 55 | "terminate_reason": self.terminate_reason, 56 | "error": self.error, 57 | } 58 | 59 | 60 | class FeedbackType(enum.Enum): 61 | FEEDBACK_WITH_GT = "feedback_with_gt" 62 | FEEDBACK_WO_GT = "feedback_wo_gt" 63 | NO_FEEDBACK = "no_feedback" 64 | 65 | 66 | class StepOutput: 67 | def __init__( 68 | self, 69 | observation: str = None, 70 | feedback: Action = Action(value="",use_tool=True), 71 | feedback_type: FeedbackType = FeedbackType.NO_FEEDBACK, 72 | success: bool = False, 73 | extra: Dict[str, Any] = None, 74 | turn_info: Tuple[int, int] = None, 75 | ): 76 | self.observation: str = observation 77 | self.feedback: Action = feedback 78 | self.feedback_type: FeedbackType = feedback_type 79 | self.success: bool = success 80 | self.extra: Dict[str, Any] = extra 81 | self.turn_info = turn_info 82 | 83 | def __repr__(self) -> str: 84 | return self.observation 85 | 86 | def to_str1(self) -> str: 87 | # output = "Observation:\n" 88 | # print("=================observation===============") 89 | # print(self.observation) 90 | # print("=================feedback==================") 91 | # print(self.feedback.value) 92 | # print("===========================================") 93 | # if self.observation is not None: 94 | # output += self.observation + "\n" 95 | # else: 96 | # if self.success == False: 97 | # output += "Your answer is wrong.\n" 98 | 99 | # if self.turn_info != None: 100 | # n_steps_left, n_propose_solution_left = self.turn_info 101 | # output += "You have {} steps left and {} chances to propose solution left.\n".format( 102 | # n_steps_left, n_propose_solution_left 103 | # ) 104 | # if n_steps_left <= 1: 105 | # output += "You should take the last step to propose a solution.\n" 106 | 107 | # if self.feedback_type != FeedbackType.NO_FEEDBACK and self.feedback != "": 108 | # # output += ( 109 | # # f"\nYour {action_type} is commented by an expert " 110 | # # f"{'without' if self.feedback_type == FeedbackType.FEEDBACK_WO_GT else 'with'}" 111 | # # " access to the correct answer.\nExpert feedback:\n" 112 | # # ) 113 | # # Maybe we don't want to tell the model whether the expert has acces to GT 114 | # # print("self.feedback.value") 115 | # # print(self.feedback.value) 116 | # # output += self.feedback.value + "\n" 117 | output = "Assistant2:\n" + self.feedback.value + "\n" 118 | return output 119 | 120 | def to_str(self) -> str: 121 | output = "Observation:\n" 122 | if self.observation is not None: 123 | output += self.observation + "\n" 124 | else: 125 | if self.success 
== False: 126 | output += "Your answer is wrong.\n" 127 | 128 | if self.turn_info != None: 129 | n_steps_left, n_propose_solution_left = self.turn_info 130 | output += "You have {} steps left and {} chances to propose solution left.\n".format( 131 | n_steps_left, n_propose_solution_left 132 | ) 133 | if n_steps_left <= 1: 134 | output += "You should take the last step to propose a solution.\n" 135 | 136 | if self.feedback_type != FeedbackType.NO_FEEDBACK and self.feedback != "": 137 | # output += ( 138 | # f"\nYour {action_type} is commented by an expert " 139 | # f"{'without' if self.feedback_type == FeedbackType.FEEDBACK_WO_GT else 'with'}" 140 | # " access to the correct answer.\nExpert feedback:\n" 141 | # ) 142 | # Maybe we don't want to tell the model whether the expert has access to GT 143 | # print("self.feedback.value") 144 | # print(self.feedback.value) 145 | # output += self.feedback.value + "\n" 146 | output += "\nAssistant2:\n" + self.feedback.value + "\n" 147 | 148 | return output 149 | 150 | def to_dict(self) -> Dict[str, Any]: 151 | return { 152 | # "observation": self.observation, 153 | "feedback": self.feedback, 154 | # "feedback_type": self.feedback_type.value, 155 | # "success": self.success, 156 | } 157 | -------------------------------------------------------------------------------- /retrieval_lm/src/finetuning_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import random 5 | import json 6 | import sys 7 | import numpy as np 8 | from src import normalize_text 9 | 10 | 11 | class Dataset(torch.utils.data.Dataset): 12 | def __init__( 13 | self, 14 | datapaths, 15 | negative_ctxs=1, 16 | negative_hard_ratio=0.0, 17 | negative_hard_min_idx=0, 18 | training=False, 19 | global_rank=-1, 20 | world_size=-1, 21 | maxload=None, 22 | normalize=False, 23 | ): 24 | self.negative_ctxs = negative_ctxs 25 | self.negative_hard_ratio = negative_hard_ratio 26 | self.negative_hard_min_idx = negative_hard_min_idx 27 | self.training = training 28 | self.normalize_fn = normalize_text.normalize if normalize else lambda x: x # gate on the `normalize` flag; the imported module is always truthy 29 | self._load_data(datapaths, global_rank, world_size, maxload) 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, index): 35 | example = self.data[index] 36 | question = example["question"] 37 | if self.training: 38 | gold = random.choice(example["positive_ctxs"]) 39 | 40 | n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example) 41 | negatives = [] 42 | if n_random_negatives > 0: 43 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives) 44 | negatives += random_negatives 45 | if n_hard_negatives > 0: 46 | hard_negatives = random.sample( 47 | example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives 48 | ) 49 | negatives += hard_negatives 50 | else: 51 | gold = example["positive_ctxs"][0] 52 | nidx = 0 53 | if "negative_ctxs" in example: 54 | negatives = [example["negative_ctxs"][nidx]] 55 | else: 56 | negatives = [] 57 | 58 | gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"] 59 | 60 | negatives = [ 61 | n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in
negatives], 68 | } 69 | return example 70 | 71 | def _load_data(self, datapaths, global_rank, world_size, maxload): 72 | counter = 0 73 | self.data = [] 74 | for path in datapaths: 75 | path = str(path) 76 | if path.endswith(".jsonl"): 77 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 78 | elif path.endswith(".json"): 79 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 80 | self.data.extend(file_data) 81 | if maxload is not None and maxload > 0 and counter >= maxload: 82 | break 83 | 84 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 85 | examples = [] 86 | with open(path, "r") as fin: 87 | data = json.load(fin) 88 | for example in data: 89 | counter += 1 90 | if global_rank > -1 and not counter % world_size == global_rank: 91 | continue 92 | examples.append(example) 93 | if maxload is not None and maxload > 0 and counter == maxload: 94 | break 95 | 96 | return examples, counter 97 | 98 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None): 99 | examples = [] 100 | with open(path, "r") as fin: 101 | for line in fin: 102 | counter += 1 103 | if global_rank > -1 and not counter % world_size == global_rank: 104 | continue 105 | example = json.loads(line) 106 | examples.append(example) 107 | if maxload is not None and maxload > 0 and counter == maxload: 108 | break 109 | 110 | return examples, counter 111 | 112 | def sample_n_hard_negatives(self, ex): 113 | 114 | if "hard_negative_ctxs" in ex: 115 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)]) 116 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :])) 117 | else: 118 | n_hard_negatives = 0 119 | n_random_negatives = self.negative_ctxs - n_hard_negatives 120 | if "negative_ctxs" in ex: 121 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"])) 122 | else: 123 | n_random_negatives = 0 124 | return n_hard_negatives, n_random_negatives 125 | 126 | 127 | class Collator(object): 128 | def __init__(self, tokenizer, passage_maxlength=200): 129 | self.tokenizer = tokenizer 130 | self.passage_maxlength = passage_maxlength 131 | 132 | def __call__(self, batch): 133 | queries = [ex["query"] for ex in batch] 134 | golds = [ex["gold"] for ex in batch] 135 | negs = [item for ex in batch for item in ex["negatives"]] 136 | allpassages = golds + negs 137 | 138 | qout = self.tokenizer.batch_encode_plus( 139 | queries, 140 | max_length=self.passage_maxlength, 141 | truncation=True, 142 | padding=True, 143 | add_special_tokens=True, 144 | return_tensors="pt", 145 | ) 146 | kout = self.tokenizer.batch_encode_plus( 147 | allpassages, 148 | max_length=self.passage_maxlength, 149 | truncation=True, 150 | padding=True, 151 | add_special_tokens=True, 152 | return_tensors="pt", 153 | ) 154 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool() 155 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool() 156 | 157 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)] 158 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :] 159 | 160 | batch = { 161 | "q_tokens": q_tokens, 162 | "q_mask": q_mask, 163 | "k_tokens": k_tokens, 164 | "k_mask": k_mask, 165 | "g_tokens": g_tokens, 166 | "g_mask": g_mask, 167 | "n_tokens": n_tokens, 168 | "n_mask": n_mask, 169 | } 170 | 171 | return batch 172 | 
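173 | # Illustrative usage sketch (not from the original file): wires the Dataset and Collator defined above into a DataLoader. 174 | # The tokenizer name and data path below are placeholders; any DPR-style JSON with "question", "positive_ctxs" and "negative_ctxs" fields should work, 175 | # and the sketch assumes `retrieval_lm` is on PYTHONPATH so the `from src import normalize_text` import at the top of this file resolves. 176 | if __name__ == "__main__": 177 | from transformers import AutoTokenizer 178 | tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased") 179 | dataset = Dataset(datapaths=["dpr_style_train.json"], negative_ctxs=1, training=True) 180 | collator = Collator(tokenizer, passage_maxlength=200) 181 | loader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collator) 182 | batch = next(iter(loader)) 183 | # q_tokens: tokenized queries, g_tokens: gold passages, n_tokens: sampled negatives 184 | print(batch["q_tokens"].shape, batch["g_tokens"].shape, batch["n_tokens"].shape)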
-------------------------------------------------------------------------------- /retrieval_lm/src/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import argparse 4 | import os 5 | 6 | 7 | class Options: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # basic parameters 14 | self.parser.add_argument( 15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here" 16 | ) 17 | self.parser.add_argument( 18 | "--train_data", 19 | nargs="+", 20 | default=[], 21 | help="Data used for training, passed as a list of directories splitted into tensor files.", 22 | ) 23 | self.parser.add_argument( 24 | "--eval_data", 25 | nargs="+", 26 | default=[], 27 | help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.", 28 | ) 29 | self.parser.add_argument( 30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format" 31 | ) 32 | self.parser.add_argument( 33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored" 34 | ) 35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining") 36 | self.parser.add_argument("--continue_training", action="store_true") 37 | self.parser.add_argument("--num_workers", type=int, default=5) 38 | 39 | self.parser.add_argument("--chunk_length", type=int, default=256) 40 | self.parser.add_argument("--loading_mode", type=str, default="split") 41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing") 42 | self.parser.add_argument( 43 | "--sampling_coefficient", 44 | type=float, 45 | default=0.0, 46 | help="coefficient used for sampling between different datasets during training, \ 47 | by default sampling is uniform over datasets", 48 | ) 49 | self.parser.add_argument("--augmentation", type=str, default="none") 50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0) 51 | 52 | self.parser.add_argument("--dropout", type=float, default=0.1) 53 | self.parser.add_argument("--rho", type=float, default=0.05) 54 | 55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco") 56 | self.parser.add_argument("--queue_size", type=int, default=65536) 57 | self.parser.add_argument("--temperature", type=float, default=1.0) 58 | self.parser.add_argument("--momentum", type=float, default=0.999) 59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true") 60 | self.parser.add_argument("--eval_normalize_text", action="store_true") 61 | self.parser.add_argument("--norm_query", action="store_true") 62 | self.parser.add_argument("--norm_doc", action="store_true") 63 | self.parser.add_argument("--projection_size", type=int, default=768) 64 | 65 | self.parser.add_argument("--ratio_min", type=float, default=0.1) 66 | self.parser.add_argument("--ratio_max", type=float, default=0.5) 67 | self.parser.add_argument("--score_function", type=str, default="dot") 68 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased") 69 | self.parser.add_argument("--pooling", type=str, default="average") 70 | self.parser.add_argument("--random_init", action="store_true", help="init model with random weights") 71 | 72 | # 
dataset parameters 73 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.") 74 | self.parser.add_argument( 75 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation." 76 | ) 77 | self.parser.add_argument("--total_steps", type=int, default=1000) 78 | self.parser.add_argument("--warmup_steps", type=int, default=-1) 79 | 80 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 81 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)") 82 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization") 83 | # training parameters 84 | self.parser.add_argument("--optim", type=str, default="adamw") 85 | self.parser.add_argument("--scheduler", type=str, default="linear") 86 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 87 | self.parser.add_argument( 88 | "--lr_min_ratio", 89 | type=float, 90 | default=0.0, 91 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate", 92 | ) 93 | self.parser.add_argument("--weight_decay", type=float, default=0.01, help="learning rate") 94 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1") 95 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2") 96 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps") 97 | self.parser.add_argument( 98 | "--log_freq", type=int, default=100, help="log train stats every steps during training" 99 | ) 100 | self.parser.add_argument( 101 | "--eval_freq", type=int, default=500, help="evaluate model every steps during training" 102 | ) 103 | self.parser.add_argument("--save_freq", type=int, default=50000) 104 | self.parser.add_argument("--maxload", type=int, default=None) 105 | self.parser.add_argument("--label_smoothing", type=float, default=0.0) 106 | 107 | # finetuning options 108 | self.parser.add_argument("--negative_ctxs", type=int, default=1) 109 | self.parser.add_argument("--negative_hard_min_idx", type=int, default=0) 110 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0) 111 | 112 | def print_options(self, opt): 113 | message = "" 114 | for k, v in sorted(vars(opt).items()): 115 | comment = "" 116 | default = self.parser.get_default(k) 117 | if v != default: 118 | comment = f"\t[default: %s]" % str(default) 119 | message += f"{str(k):>40}: {str(v):<40}{comment}\n" 120 | print(message, flush=True) 121 | model_dir = os.path.join(opt.output_dir, "models") 122 | if not os.path.exists(model_dir): 123 | os.makedirs(os.path.join(opt.output_dir, "models")) 124 | file_name = os.path.join(opt.output_dir, "opt.txt") 125 | with open(file_name, "wt") as opt_file: 126 | opt_file.write(message) 127 | opt_file.write("\n") 128 | 129 | def parse(self): 130 | opt, _ = self.parser.parse_known_args() 131 | # opt = self.parser.parse_args() 132 | return opt 133 | -------------------------------------------------------------------------------- /retrieval_lm/src/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | import collections 9 | import logging 10 | import regex 11 | import string 12 | import unicodedata 13 | from functools import partial 14 | from multiprocessing import Pool as ProcessPool 15 | from typing import Tuple, List, Dict 16 | import numpy as np 17 | 18 | """ 19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 20 | """ 21 | 22 | class SimpleTokenizer(object): 23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 24 | NON_WS = r'[^\p{Z}\p{C}]' 25 | 26 | def __init__(self): 27 | """ 28 | Args: 29 | annotators: None or empty set (only tokenizes). 30 | """ 31 | self._regexp = regex.compile( 32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 34 | ) 35 | 36 | def tokenize(self, text, uncased=False): 37 | matches = [m for m in self._regexp.finditer(text)] 38 | if uncased: 39 | tokens = [m.group().lower() for m in matches] 40 | else: 41 | tokens = [m.group() for m in matches] 42 | return tokens 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 47 | 48 | def calculate_matches(data: List, workers_num: int): 49 | """ 50 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 51 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 52 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 53 | :param answers: list of answers's list. One list per question 54 | :param closest_docs: document ids of the top results along with their scores 55 | :param workers_num: amount of parallel threads to process data 56 | :param match_type: type of answer matching. Refer to has_answer code for available options 57 | :return: matching information tuple. 58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 59 | valid matches across an entire dataset. 
questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 61 | """ 62 | 63 | logger.info('Matching answers in top docs...') 64 | 65 | tokenizer = SimpleTokenizer() 66 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 67 | 68 | processes = ProcessPool(processes=workers_num) 69 | scores = processes.map(get_score_partial, data) 70 | 71 | logger.info('Per question validation results len=%d', len(scores)) 72 | 73 | n_docs = len(data[0]['ctxs']) 74 | top_k_hits = [0] * n_docs 75 | for question_hits in scores: 76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 77 | if best_hit is not None: 78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 79 | 80 | return QAMatchStats(top_k_hits, scores) 81 | 82 | def check_answer(example, tokenizer) -> List[bool]: 83 | """Search through all the top docs to see if they have any of the answers.""" 84 | answers = example['answers'] 85 | ctxs = example['ctxs'] 86 | 87 | hits = [] 88 | 89 | for i, doc in enumerate(ctxs): 90 | text = doc['text'] 91 | 92 | if text is None: # cannot find the document for some reason 93 | logger.warning("no doc in db") 94 | hits.append(False) 95 | continue 96 | 97 | hits.append(has_answer(answers, text, tokenizer)) 98 | 99 | return hits 100 | 101 | def has_answer(answers, text, tokenizer) -> bool: 102 | """Check if a document contains an answer string.""" 103 | text = _normalize(text) 104 | text = tokenizer.tokenize(text, uncased=True) 105 | 106 | for answer in answers: 107 | answer = _normalize(answer) 108 | answer = tokenizer.tokenize(answer, uncased=True) 109 | for i in range(0, len(text) - len(answer) + 1): 110 | if answer == text[i: i + len(answer)]: 111 | return True 112 | return False 113 | 114 | ################################################# 115 | ######## READER EVALUATION ######## 116 | ################################################# 117 | 118 | def _normalize(text): 119 | return unicodedata.normalize('NFD', text) 120 | 121 | # Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 122 | def normalize_answer(s): 123 | def remove_articles(text): 124 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 125 | 126 | def white_space_fix(text): 127 | return ' '.join(text.split()) 128 | 129 | def remove_punc(text): 130 | exclude = set(string.punctuation) 131 | return ''.join(ch for ch in text if ch not in exclude) 132 | 133 | def lower(text): 134 | return text.lower() 135 | 136 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 137 | 138 | def em(prediction, ground_truth): 139 | return normalize_answer(prediction) == normalize_answer(ground_truth) 140 | 141 | def f1(prediction, ground_truth): 142 | prediction_tokens = normalize_answer(prediction).split() 143 | ground_truth_tokens = normalize_answer(ground_truth).split() 144 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) # only the collections module is imported above 145 | num_same = sum(common.values()) 146 | if num_same == 0: 147 | return 0 148 | precision = 1.0 * num_same / len(prediction_tokens) 149 | recall = 1.0 * num_same / len(ground_truth_tokens) 150 | f1 = (2 * precision * recall) / (precision + recall) 151 | return f1 152 | 153 | def f1_score(prediction, ground_truths): 154 | return max([f1(prediction, gt) for gt in ground_truths]) 155 | 156 | def exact_match_score(prediction, ground_truths): 157 | return max([em(prediction, gt) for gt in ground_truths]) 158 | 159 | 
#################################################### 160 | ######## RETRIEVER EVALUATION ######## 161 | #################################################### 162 | 163 | def eval_batch(scores, inversions, avg_topk, idx_topk): 164 | for k, s in enumerate(scores): 165 | s = s.cpu().numpy() 166 | sorted_idx = np.argsort(-s) 167 | score(sorted_idx, inversions, avg_topk, idx_topk) 168 | 169 | def count_inversions(arr): 170 | inv_count = 0 171 | lenarr = len(arr) 172 | for i in range(lenarr): 173 | for j in range(i + 1, lenarr): 174 | if (arr[i] > arr[j]): 175 | inv_count += 1 176 | return inv_count 177 | 178 | def score(x, inversions, avg_topk, idx_topk): 179 | x = np.array(x) 180 | inversions.append(count_inversions(x)) 181 | for k in avg_topk: 182 | # ratio of passages in the predicted top-k that are 183 | # also in the topk given by gold score 184 | avg_pred_topk = (x[:k]", "") 53 | return_str += "\n" 54 | 55 | # st_msgs will end with user 56 | return_str += f"<|assistant|>\n" 57 | 58 | return return_str 59 | 60 | 61 | if "initial" not in st.session_state: 62 | st.session_state.initial = True 63 | st.session_state.msgs = [] 64 | langdetect_model = fasttext.load_model('lid.176.bin') 65 | model, tokenizer, search_engine_api = init_model_and_tokenizer_and_tool( 66 | model_name_or_path=MODEL_PATH) 67 | special_tokens_dict = load_sag_special_tokens(tokenizer) 68 | st.session_state.model = model 69 | st.session_state.langdetect_model = langdetect_model 70 | st.session_state.tokenizer = tokenizer 71 | st.session_state.search_engine_api = search_engine_api 72 | st.session_state.special_tokens_dict = special_tokens_dict 73 | 74 | 75 | special_tokens = ["[S_Rewritten_Query]", "[S_Decomposed_Query]", "[S_Disambiguated_Query]", ] # do not made [EOS] to zero, or it won't stop, and do not scale [A_Response] 76 | disable_special_tokens = ["[S_Rewritten_Query]", "[S_Decomposed_Query]", "[S_Disambiguated_Query]", "[A_Response]"] 77 | special_tokens_ids = [st.session_state.tokenizer.convert_tokens_to_ids(special) for special in special_tokens] 78 | disable_special_tokens_ids = [st.session_state.tokenizer.convert_tokens_to_ids(special) for special in disable_special_tokens] 79 | 80 | with st.sidebar: 81 | if st.button("Reset"): 82 | st.session_state.msgs = [] 83 | 84 | option = st.selectbox( 85 | 'Which mode do you want to use', 86 | ('disable search', 'adaptive search', 'force search')) 87 | 88 | st.write('You selected:', option) 89 | 90 | if option == "disable search": 91 | threshold_processor = SpecialTokenLogitsProcessor(st.session_state.tokenizer, disable_special_tokens_ids, enable_search=False) 92 | st.session_state.logits_processor = LogitsProcessorList([threshold_processor]) 93 | 94 | elif option == "adaptive search": 95 | # do not need to search 96 | st.session_state.logits_processor = LogitsProcessorList([]) 97 | 98 | elif option == "force search": 99 | threshold_processor = SpecialTokenLogitsProcessor(st.session_state.tokenizer, special_tokens_ids, enable_search=True) 100 | st.session_state.logits_processor = LogitsProcessorList([threshold_processor]) 101 | 102 | st.chat_message("assistant").write("How can I help you?") 103 | 104 | # show cur_session results 105 | for idx, msg in enumerate(st.session_state.msgs): 106 | 107 | with st.chat_message(msg["role"]): 108 | # Render intermediate steps if any were saved 109 | if msg["role"] == "assistant": 110 | for step in msg["search_results"]: 111 | with st.status(label="done") as cur_status: 112 | cur_status.markdown(step) 113 | 
st.markdown(msg["display_results"]) 114 | 115 | if prompt := st.chat_input(placeholder="Hi, how are you today?"): 116 | st.chat_message("user").write(prompt) 117 | 118 | cur_lang = st.session_state.langdetect_model.predict(prompt.replace("\n", ""))[0][0] # "__label__en" 119 | 120 | display_prompt = prompt 121 | real_prompt = prompt 122 | 123 | st.session_state.msgs.append({ 124 | "role": "user", 125 | "content": real_prompt, 126 | "display_results": display_prompt 127 | }) 128 | 129 | with st.chat_message("assistant"): 130 | st.container() 131 | 132 | generator = generate_and_retrieve( 133 | examples=[concat_msgs(st.session_state.msgs, lang="en")], 134 | model=st.session_state.model, 135 | tokenizer=st.session_state.tokenizer, 136 | special_tokens_dict=st.session_state.special_tokens_dict, 137 | search_engine_api=st.session_state.search_engine_api, 138 | search_limit=2, 139 | logits_processor=st.session_state.logits_processor) 140 | 141 | # TODO if we want to change the logic, that determine which status is now (searching or answering), 142 | # we may have to change the yield keyword in generator. 143 | 144 | search_results = [] 145 | while True: 146 | with st.status(label="Answering..." if option == "disable search" else "Searching...") as cur_status: 147 | 148 | try: 149 | result_dict = next(generator) 150 | if "search_query" in result_dict: 151 | cur_status.update(label="done") 152 | cur_status.markdown(f"**Search Query:{result_dict['search_query'][0]}**") 153 | newline = "\n" 154 | cur_status.markdown(f"Search Results: \n{result_dict['evidence_list'][0].replace(newline, ' '+newline)}") 155 | search_results.append(f"**Search Query:{result_dict['search_query'][0]}** \nSearch Results: \n{result_dict['evidence_list'][0].replace(newline, ' '+newline)}") 156 | 157 | except StopIteration as e: 158 | result_dict = e.value 159 | cur_status.update(label="done") 160 | cur_status.markdown(result_dict["final_response"][0]) 161 | break 162 | 163 | st.session_state.msgs.append({"role": "assistant", 164 | "content": result_dict["cur_examples"][0].split("<|assistant|>\n")[-1], 165 | "search_results": search_results, 166 | "display_results": result_dict["final_response"][0], # previous output from model 167 | }) 168 | 169 | # above is written in a st.expanded 170 | st.markdown(result_dict["final_response"][0]) -------------------------------------------------------------------------------- /retrieval_lm/src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | import math # math.cos is used by CosineScheduler below 3 | import os 4 | import sys 5 | import logging 6 | import torch 7 | import errno 8 | from typing import Union, Tuple, List, Dict 9 | from collections import defaultdict 10 | 11 | from src import dist_utils 12 | 13 | Number = Union[float, int] 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | def init_logger(args, stdout_only=False): 19 | if torch.distributed.is_initialized(): 20 | torch.distributed.barrier() 21 | stdout_handler = logging.StreamHandler(sys.stdout) 22 | handlers = [stdout_handler] 23 | if not stdout_only: 24 | file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log")) 25 | handlers.append(file_handler) 26 | logging.basicConfig( 27 | datefmt="%m/%d/%Y %H:%M:%S", 28 | level=logging.INFO if dist_utils.is_main() else logging.WARN, 29 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 30 | handlers=handlers, 31 | ) 32 | return logger 33 | 34 | 35 | def symlink_force(target, link_name): 36 | try: 37 | os.symlink(target, link_name) 38 | except OSError as e: 39 | if e.errno == errno.EEXIST: 40 | os.remove(link_name) 41 | os.symlink(target, link_name) 42 | else: 43 | raise e 44 | 45 | 46 | def save(model, optimizer, scheduler, step, opt, dir_path, name): 47 | model_to_save = model.module if hasattr(model, "module") else model 48 | path = os.path.join(dir_path, "checkpoint") 49 | epoch_path = os.path.join(path, name) # "step-%s" % step) 50 | os.makedirs(epoch_path, exist_ok=True) 51 | cp = os.path.join(path, "latest") 52 | fp = os.path.join(epoch_path, "checkpoint.pth") 53 | checkpoint = { 54 | "step": step, 55 | "model": model_to_save.state_dict(), 56 | "optimizer": optimizer.state_dict(), 57 | "scheduler": scheduler.state_dict(), 58 | "opt": opt, 59 | } 60 | torch.save(checkpoint, fp) 61 | symlink_force(epoch_path, cp) 62 | if not name == "lastlog": 63 | logger.info(f"Saving model to {epoch_path}") 64 | 65 | 66 | def load(model_class, dir_path, opt, reset_params=False): 67 | epoch_path = os.path.realpath(dir_path) 68 | checkpoint_path = os.path.join(epoch_path, "checkpoint.pth") 69 | logger.info(f"loading checkpoint {checkpoint_path}") 70 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 71 | opt_checkpoint = checkpoint["opt"] 72 | state_dict = checkpoint["model"] 73 | 74 | model = model_class(opt_checkpoint) 75 | model.load_state_dict(state_dict, strict=True) 76 | model = model.cuda() 77 | step = checkpoint["step"] 78 | if not reset_params: 79 | optimizer, scheduler = set_optim(opt_checkpoint, model) 80 | scheduler.load_state_dict(checkpoint["scheduler"]) 81 | optimizer.load_state_dict(checkpoint["optimizer"]) 82 | else: 83 | optimizer, scheduler = set_optim(opt, model) 84 | 85 | return model, optimizer, scheduler, opt_checkpoint, step 86 | 87 | 88 | ############ OPTIM 89 | 90 | 91 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): 92 | def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1): 93 | self.warmup = warmup 94 | self.total = total 95 | self.ratio = ratio 96 | super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 97 | 98 | def lr_lambda(self, step): 99 | if step < self.warmup: 100 | return (1 - self.ratio) * step / float(max(1, self.warmup)) 101 | 102 | return max( 103 | 0.0, 104 | 1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)), 105 | ) 106 | 107 | 108 | class CosineScheduler(torch.optim.lr_scheduler.LambdaLR): 109 | def __init__(self, optimizer, warmup, total,
ratio=0.1, last_epoch=-1): 110 | self.warmup = warmup 111 | self.total = total 112 | self.ratio = ratio 113 | super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 114 | 115 | def lr_lambda(self, step): 116 | if step < self.warmup: 117 | return float(step) / self.warmup 118 | s = float(step - self.warmup) / (self.total - self.warmup) 119 | return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s) 120 | 121 | 122 | def set_optim(opt, model): 123 | if opt.optim == "adamw": 124 | optimizer = torch.optim.AdamW( 125 | model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay 126 | ) 127 | else: 128 | raise NotImplementedError("optimizer class not implemented") 129 | 130 | scheduler_args = { 131 | "warmup": opt.warmup_steps, 132 | "total": opt.total_steps, 133 | "ratio": opt.lr_min_ratio, 134 | } 135 | if opt.scheduler == "linear": 136 | scheduler_class = WarmupLinearScheduler 137 | elif opt.scheduler == "cosine": 138 | scheduler_class = CosineScheduler 139 | else: 140 | raise ValueError 141 | scheduler = scheduler_class(optimizer, **scheduler_args) 142 | return optimizer, scheduler 143 | 144 | 145 | def get_parameters(net, verbose=False): 146 | num_params = 0 147 | for param in net.parameters(): 148 | num_params += param.numel() 149 | message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6) 150 | return message 151 | 152 | 153 | class WeightedAvgStats: 154 | """provides an average over a bunch of stats""" 155 | 156 | def __init__(self): 157 | self.raw_stats: Dict[str, float] = defaultdict(float) 158 | self.total_weights: Dict[str, float] = defaultdict(float) 159 | 160 | def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None: 161 | for key, (value, weight) in vals.items(): 162 | self.raw_stats[key] += value * weight 163 | self.total_weights[key] += weight 164 | 165 | @property 166 | def stats(self) -> Dict[str, float]: 167 | return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()} 168 | 169 | @property 170 | def tuple_stats(self) -> Dict[str, Tuple[float, float]]: 171 | return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()} 172 | 173 | def reset(self) -> None: 174 | self.raw_stats = defaultdict(float) 175 | self.total_weights = defaultdict(float) 176 | 177 | @property 178 | def average_stats(self) -> Dict[str, float]: 179 | keys = sorted(self.raw_stats.keys()) 180 | if torch.distributed.is_initialized(): 181 | torch.distributed.broadcast_object_list(keys, src=0) 182 | global_dict = {} 183 | for k in keys: 184 | if not k in self.total_weights: 185 | v = 0.0 186 | else: 187 | v = self.raw_stats[k] / self.total_weights[k] 188 | v, _ = dist_utils.weighted_average(v, self.total_weights[k]) 189 | global_dict[k] = v 190 | return global_dict 191 | 192 | 193 | def load_hf(object_class, model_name): 194 | try: 195 | obj = object_class.from_pretrained(model_name, local_files_only=True) 196 | except: 197 | obj = object_class.from_pretrained(model_name, local_files_only=False) 198 | return obj 199 | 200 | 201 | def init_tb_logger(output_dir): 202 | try: 203 | from torch.utils import tensorboard 204 | 205 | if dist_utils.is_main(): 206 | tb_logger = tensorboard.SummaryWriter(output_dir) 207 | else: 208 | tb_logger = None 209 | except: 210 | logger.warning("Tensorboard is not available.") 211 | tb_logger = None 212 | 213 | return tb_logger 214 | 
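215 | # Illustrative sketch (not from the original file): shows the warmup-then-decay behaviour of WarmupLinearScheduler defined above. 216 | # The tiny Linear model, learning rate and step counts are placeholders chosen only for this example; run with `retrieval_lm` on PYTHONPATH so `from src import dist_utils` resolves. 217 | if __name__ == "__main__": 218 | model = torch.nn.Linear(4, 4) 219 | optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4) 220 | scheduler = WarmupLinearScheduler(optimizer, warmup=10, total=100, ratio=0.0) 221 | for step in range(100): 222 | optimizer.step() 223 | scheduler.step() 224 | if step % 25 == 0: 225 | # lr ramps linearly to 1e-4 over the first 10 steps, then decays linearly towards lr * ratio at `total` 226 | print(step, scheduler.get_last_lr()[0])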
-------------------------------------------------------------------------------- /retrieval_lm/src/beir_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | from collections import defaultdict 5 | from typing import List, Dict 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | 10 | import beir.util 11 | from beir.datasets.data_loader import GenericDataLoader 12 | from beir.retrieval.evaluation import EvaluateRetrieval 13 | from beir.retrieval.search.dense import DenseRetrievalExactSearch 14 | 15 | from beir.reranking.models import CrossEncoder 16 | from beir.reranking import Rerank 17 | 18 | import src.dist_utils as dist_utils 19 | from src import normalize_text 20 | 21 | 22 | class DenseEncoderModel: 23 | def __init__( 24 | self, 25 | query_encoder, 26 | doc_encoder=None, 27 | tokenizer=None, 28 | max_length=512, 29 | add_special_tokens=True, 30 | norm_query=False, 31 | norm_doc=False, 32 | lower_case=False, 33 | normalize_text=False, 34 | **kwargs, 35 | ): 36 | self.query_encoder = query_encoder 37 | self.doc_encoder = doc_encoder 38 | self.tokenizer = tokenizer 39 | self.max_length = max_length 40 | self.add_special_tokens = add_special_tokens 41 | self.norm_query = norm_query 42 | self.norm_doc = norm_doc 43 | self.lower_case = lower_case 44 | self.normalize_text = normalize_text 45 | 46 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray: 47 | 48 | if dist.is_initialized(): 49 | idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()] 50 | else: 51 | idx = range(len(queries)) 52 | 53 | queries = [queries[i] for i in idx] 54 | if self.normalize_text: 55 | queries = [normalize_text.normalize(q) for q in queries] 56 | if self.lower_case: 57 | queries = [q.lower() for q in queries] 58 | 59 | allemb = [] 60 | nbatch = (len(queries) - 1) // batch_size + 1 61 | with torch.no_grad(): 62 | for k in range(nbatch): 63 | start_idx = k * batch_size 64 | end_idx = min((k + 1) * batch_size, len(queries)) 65 | 66 | qencode = self.tokenizer.batch_encode_plus( 67 | queries[start_idx:end_idx], 68 | max_length=self.max_length, 69 | padding=True, 70 | truncation=True, 71 | add_special_tokens=self.add_special_tokens, 72 | return_tensors="pt", 73 | ) 74 | qencode = {key: value.cuda() for key, value in qencode.items()} 75 | emb = self.query_encoder(**qencode, normalize=self.norm_query) 76 | allemb.append(emb.cpu()) 77 | 78 | allemb = torch.cat(allemb, dim=0) 79 | allemb = allemb.cuda() 80 | if dist.is_initialized(): 81 | allemb = dist_utils.varsize_gather_nograd(allemb) 82 | allemb = allemb.cpu().numpy() 83 | return allemb 84 | 85 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs): 86 | 87 | if dist.is_initialized(): 88 | idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()] 89 | else: 90 | idx = range(len(corpus)) 91 | corpus = [corpus[i] for i in idx] 92 | corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus] 93 | if self.normalize_text: 94 | corpus = [normalize_text.normalize(c) for c in corpus] 95 | if self.lower_case: 96 | corpus = [c.lower() for c in corpus] 97 | 98 | allemb = [] 99 | nbatch = (len(corpus) - 1) // batch_size + 1 100 | with torch.no_grad(): 101 | for k in range(nbatch): 102 | start_idx = k * batch_size 103 | end_idx = min((k + 1) * batch_size, len(corpus)) 104 | 105 | 
cencode = self.tokenizer.batch_encode_plus( 106 | corpus[start_idx:end_idx], 107 | max_length=self.max_length, 108 | padding=True, 109 | truncation=True, 110 | add_special_tokens=self.add_special_tokens, 111 | return_tensors="pt", 112 | ) 113 | cencode = {key: value.cuda() for key, value in cencode.items()} 114 | emb = self.doc_encoder(**cencode, normalize=self.norm_doc) 115 | allemb.append(emb.cpu()) 116 | 117 | allemb = torch.cat(allemb, dim=0) 118 | allemb = allemb.cuda() 119 | if dist.is_initialized(): 120 | allemb = dist_utils.varsize_gather_nograd(allemb) 121 | allemb = allemb.cpu().numpy() 122 | return allemb 123 | 124 | 125 | def evaluate_model( 126 | query_encoder, 127 | doc_encoder, 128 | tokenizer, 129 | dataset, 130 | batch_size=128, 131 | add_special_tokens=True, 132 | norm_query=False, 133 | norm_doc=False, 134 | is_main=True, 135 | split="test", 136 | score_function="dot", 137 | beir_dir="BEIR/datasets", 138 | save_results_path=None, 139 | lower_case=False, 140 | normalize_text=False, 141 | ): 142 | 143 | metrics = defaultdict(list) # store final results 144 | 145 | if hasattr(query_encoder, "module"): 146 | query_encoder = query_encoder.module 147 | query_encoder.eval() 148 | 149 | if doc_encoder is not None: 150 | if hasattr(doc_encoder, "module"): 151 | doc_encoder = doc_encoder.module 152 | doc_encoder.eval() 153 | else: 154 | doc_encoder = query_encoder 155 | 156 | dmodel = DenseRetrievalExactSearch( 157 | DenseEncoderModel( 158 | query_encoder=query_encoder, 159 | doc_encoder=doc_encoder, 160 | tokenizer=tokenizer, 161 | add_special_tokens=add_special_tokens, 162 | norm_query=norm_query, 163 | norm_doc=norm_doc, 164 | lower_case=lower_case, 165 | normalize_text=normalize_text, 166 | ), 167 | batch_size=batch_size, 168 | ) 169 | retriever = EvaluateRetrieval(dmodel, score_function=score_function) 170 | data_path = os.path.join(beir_dir, dataset) 171 | 172 | if not os.path.isdir(data_path) and is_main: 173 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) 174 | data_path = beir.util.download_and_unzip(url, beir_dir) 175 | dist_utils.barrier() 176 | 177 | if not dataset == "cqadupstack": 178 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 179 | results = retriever.retrieve(corpus, queries) 180 | if is_main: 181 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 182 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 183 | if isinstance(metric, str): 184 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 185 | for key, value in metric.items(): 186 | metrics[key].append(value) 187 | if save_results_path is not None: 188 | torch.save(results, f"{save_results_path}") 189 | elif dataset == "cqadupstack": # compute macroaverage over datasets 190 | paths = glob.glob(data_path) 191 | for path in paths: 192 | corpus, queries, qrels = GenericDataLoader(data_folder=data_folder).load(split=split) 193 | results = retriever.retrieve(corpus, queries) 194 | if is_main: 195 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 196 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 197 | if isinstance(metric, str): 198 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 199 | for key, value in metric.items(): 200 | metrics[key].append(value) 201 | for key, value in metrics.items(): 202 | assert ( 203 | 
len(value) == 12 204 | ), f"cqadupstack includes 12 datasets, only {len(value)} values were compute for the {key} metric" 205 | 206 | metrics = {key: 100 * np.mean(value) for key, value in metrics.items()} 207 | 208 | return metrics 209 | -------------------------------------------------------------------------------- /retrieval_lm/src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import glob 5 | import torch 6 | import random 7 | import json 8 | import csv 9 | import numpy as np 10 | import numpy.random 11 | import logging 12 | from collections import defaultdict 13 | import torch.distributed as dist 14 | 15 | from src import dist_utils 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def load_data(opt, tokenizer): 21 | datasets = {} 22 | for path in opt.train_data: 23 | data = load_dataset(path, opt.loading_mode) 24 | if data is not None: 25 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt) 26 | dataset = MultiDataset(datasets) 27 | dataset.set_prob(coeff=opt.sampling_coefficient) 28 | return dataset 29 | 30 | 31 | def load_dataset(data_path, loading_mode): 32 | files = glob.glob(os.path.join(data_path, "*.p*")) 33 | files.sort() 34 | tensors = [] 35 | if loading_mode == "split": 36 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()] 37 | for filepath in files_split: 38 | try: 39 | tensors.append(torch.load(filepath, map_location="cpu")) 40 | except: 41 | logger.warning(f"Unable to load file {filepath}") 42 | elif loading_mode == "full": 43 | for fin in files: 44 | tensors.append(torch.load(fin, map_location="cpu")) 45 | elif loading_mode == "single": 46 | tensors.append(torch.load(files[0], map_location="cpu")) 47 | if len(tensors) == 0: 48 | return None 49 | tensor = torch.cat(tensors) 50 | return tensor 51 | 52 | 53 | class MultiDataset(torch.utils.data.Dataset): 54 | def __init__(self, datasets): 55 | 56 | self.datasets = datasets 57 | self.prob = [1 / len(self.datasets) for _ in self.datasets] 58 | self.dataset_ids = list(self.datasets.keys()) 59 | 60 | def __len__(self): 61 | return sum([len(dataset) for dataset in self.datasets.values()]) 62 | 63 | def __getitem__(self, index): 64 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0] 65 | did = self.dataset_ids[dataset_idx] 66 | index = random.randint(0, len(self.datasets[did]) - 1) 67 | sample = self.datasets[did][index] 68 | sample["dataset_id"] = did 69 | return sample 70 | 71 | def generate_offset(self): 72 | for dataset in self.datasets.values(): 73 | dataset.generate_offset() 74 | 75 | def set_prob(self, coeff=0.0): 76 | 77 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()]) 78 | prob /= prob.sum() 79 | prob = np.array([p**coeff for p in prob]) 80 | prob /= prob.sum() 81 | self.prob = prob 82 | 83 | 84 | class Dataset(torch.utils.data.Dataset): 85 | """Monolingual dataset based on a list of paths""" 86 | 87 | def __init__(self, data, chunk_length, tokenizer, opt): 88 | 89 | self.data = data 90 | self.chunk_length = chunk_length 91 | self.tokenizer = tokenizer 92 | self.opt = opt 93 | self.generate_offset() 94 | 95 | def __len__(self): 96 | return (self.data.size(0) - self.offset) // self.chunk_length 97 | 98 | def __getitem__(self, index): 99 | start_idx = self.offset + index * self.chunk_length 100 | end_idx = start_idx + self.chunk_length 101 | tokens = 
self.data[start_idx:end_idx] 102 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 103 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 104 | q_tokens = apply_augmentation(q_tokens, self.opt) 105 | q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 106 | k_tokens = apply_augmentation(k_tokens, self.opt) 107 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 108 | 109 | return {"q_tokens": q_tokens, "k_tokens": k_tokens} 110 | 111 | def generate_offset(self): 112 | self.offset = random.randint(0, self.chunk_length - 1) 113 | 114 | 115 | class Collator(object): 116 | def __init__(self, opt): 117 | self.opt = opt 118 | 119 | def __call__(self, batch_examples): 120 | 121 | batch = defaultdict(list) 122 | for example in batch_examples: 123 | for k, v in example.items(): 124 | batch[k].append(v) 125 | 126 | q_tokens, q_mask = build_mask(batch["q_tokens"]) 127 | k_tokens, k_mask = build_mask(batch["k_tokens"]) 128 | 129 | batch["q_tokens"] = q_tokens 130 | batch["q_mask"] = q_mask 131 | batch["k_tokens"] = k_tokens 132 | batch["k_mask"] = k_mask 133 | 134 | return batch 135 | 136 | 137 | def randomcrop(x, ratio_min, ratio_max): 138 | 139 | ratio = random.uniform(ratio_min, ratio_max) 140 | length = int(len(x) * ratio) 141 | start = random.randint(0, len(x) - length) 142 | end = start + length 143 | crop = x[start:end].clone() 144 | return crop 145 | 146 | 147 | def build_mask(tensors): 148 | shapes = [x.shape for x in tensors] 149 | maxlength = max([len(x) for x in tensors]) 150 | returnmasks = [] 151 | ids = [] 152 | for k, x in enumerate(tensors): 153 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x)))) 154 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x)))))) 155 | ids = torch.stack(ids, dim=0).long() 156 | returnmasks = torch.stack(returnmasks, dim=0).bool() 157 | return ids, returnmasks 158 | 159 | 160 | def add_token(x, token): 161 | x = torch.cat((torch.tensor([token]), x)) 162 | return x 163 | 164 | 165 | def deleteword(x, p=0.1): 166 | mask = np.random.rand(len(x)) 167 | x = [e for e, m in zip(x, mask) if m > p] 168 | return x 169 | 170 | 171 | def replaceword(x, min_random, max_random, p=0.1): 172 | mask = np.random.rand(len(x)) 173 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)] 174 | return x 175 | 176 | 177 | def maskword(x, mask_id, p=0.1): 178 | mask = np.random.rand(len(x)) 179 | x = [e if m > p else mask_id for e, m in zip(x, mask)] 180 | return x 181 | 182 | 183 | def shuffleword(x, p=0.1): 184 | count = (np.random.rand(len(x)) < p).sum() 185 | """Shuffles any n number of values in a list""" 186 | indices_to_shuffle = random.sample(range(len(x)), k=count) 187 | to_shuffle = [x[i] for i in indices_to_shuffle] 188 | random.shuffle(to_shuffle) 189 | for index, value in enumerate(to_shuffle): 190 | old_index = indices_to_shuffle[index] 191 | x[old_index] = value 192 | return x 193 | 194 | 195 | def apply_augmentation(x, opt): 196 | if opt.augmentation == "mask": 197 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation)) 198 | elif opt.augmentation == "replace": 199 | return torch.tensor( 200 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation) 201 | ) 202 | elif opt.augmentation == "delete": 203 | return torch.tensor(deleteword(x, p=opt.prob_augmentation)) 204 | elif opt.augmentation == "shuffle": 205 | 
return torch.tensor(shuffleword(x, p=opt.prob_augmentation)) 206 | else: 207 | if not isinstance(x, torch.Tensor): 208 | x = torch.Tensor(x) 209 | return x 210 | 211 | 212 | def add_bos_eos(x, bos_token_id, eos_token_id): 213 | if not isinstance(x, torch.Tensor): 214 | x = torch.Tensor(x) 215 | if bos_token_id is None and eos_token_id is not None: 216 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])]) 217 | elif bos_token_id is not None and eos_token_id is None: 218 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()]) 219 | elif bos_token_id is None and eos_token_id is None: 220 | pass 221 | else: 222 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])]) 223 | return x 224 | 225 | 226 | # Used for passage retrieval 227 | def load_passages(path): 228 | if not os.path.exists(path): 229 | logger.info(f"{path} does not exist") 230 | return 231 | logger.info(f"Loading passages from: {path}") 232 | passages = [] 233 | with open(path) as fin: 234 | if path.endswith(".jsonl"): 235 | for k, line in enumerate(fin): 236 | ex = json.loads(line) 237 | passages.append(ex) 238 | else: 239 | reader = csv.reader(fin, delimiter="\t") 240 | for k, row in enumerate(reader): 241 | if not row[0] == "id": 242 | ex = {"id": row[0], "title": row[2], "text": row[1]} 243 | passages.append(ex) 244 | return passages 245 | -------------------------------------------------------------------------------- /data_curation/main_unambiguous_answer_generate.py: -------------------------------------------------------------------------------- 1 | import datasets 2 | import os 3 | import json 4 | import argparse 5 | from llm_agent import LMAgent, \ 6 | OpenAIRewriterLMAgent, \ 7 | OpenAIJudgerLMAgent, \ 8 | OpenAIGeneratorLMAgent 9 | from llm_agent.datatypes import State, Action 10 | from llm_agent.prompt_template import UnambiguousGeneratorTemplateLong, UnambiguousGeneratorTemplateShort 11 | from tools.duckduckgo_rapidapi import DDGSQueryRun 12 | from tqdm import tqdm 13 | 14 | def call_DDGS(ins: dict, DDGS_agent: DDGSQueryRun, rewriter_agent=None, rewritten_queries=[], state:State = None): 15 | """ 16 | directly return the top_k results 17 | 18 | :param ins: 19 | :return: 20 | retrieved_content 21 | query 22 | """ 23 | 24 | original_query = ins["question"] 25 | 26 | if rewriter_agent is not None: 27 | response = rewriter_agent.act(original_query=original_query, rewritten_queries=rewritten_queries) 28 | else: 29 | response = original_query 30 | 31 | 32 | retrieved_evidences = DDGS_agent(response) 33 | 34 | return retrieved_evidences, response 35 | 36 | def call_judge(ins: dict, original_query: str, retrieved_evidences: list, reference_output:str, judger_agent:OpenAIJudgerLMAgent): 37 | """ 38 | call the llm to give the reward on the query and retrieved results, (investigate the fine grained prompt later) 39 | 40 | :param ins: 41 | :return: 42 | """ 43 | 44 | response = judger_agent.act(original_query=original_query, retrieved_evidences=retrieved_evidences, reference_output=reference_output) 45 | 46 | return response 47 | 48 | 49 | def call_generator_short(ins: dict, prompt_template: str, generator_agent: OpenAIGeneratorLMAgent, **kwargs): 50 | 51 | """ 52 | 53 | this function is used to generate the short answer for each unambiguous question; 54 | we take in the unambiguous question and the reference answer from dataset and ask llm to 55 | generate a reorganized answer if necessary. 
56 | 57 | """ 58 | 59 | kwargs["original_query"] = ins["question"] 60 | kwargs["search_results"] = ins["search_results"] 61 | kwargs["original_answers"] = ins["short_answers"] 62 | 63 | response = generator_agent.act(template=prompt_template, 64 | **kwargs) 65 | 66 | return response 67 | 68 | 69 | def call_generator_long(ins: dict, prompt_template: str, generator_agent: OpenAIGeneratorLMAgent, **kwargs): 70 | 71 | """ 72 | 73 | similar to func: call_generator_short; 74 | this function is used to generate the long-form answer given all the unambiguous answers and their corresponding contexts, 75 | and the LLM is asked to generate a reorganized answer if necessary. 76 | 77 | we wrote two separate functions because not only the prompt template but also the logic and some of the required information differ between the two cases. 78 | 79 | """ 80 | 81 | kwargs["ambiguous_question"] = ins["ambiguous_question"] 82 | 83 | unambiguous_answers = "" 84 | 85 | for index, cur_unambiguous in enumerate(ins["qa_pairs"]): 86 | unambiguous_answers += f"Unambiguous Question{index}: {cur_unambiguous['question']}\n" 87 | unambiguous_answers += f"Evidence{index}: {cur_unambiguous['search_results']} \n" 88 | unambiguous_answers += f"\n*** Answer{index}: {cur_unambiguous['gpt_responses']} ***\n\n" 89 | 90 | kwargs["unambiguous_questions_with_answers"] = unambiguous_answers 91 | 92 | original_long_form_answer = "" 93 | 94 | for index, long_form_answer in enumerate(ins["annotations"]): 95 | original_long_form_answer += f"{long_form_answer['long_answer']} \n" 96 | 97 | kwargs["original_answer"] = original_long_form_answer 98 | 99 | response = generator_agent.act(template=prompt_template, **kwargs) 100 | 101 | return response 102 | 103 | 104 | def print_usage(agent_list:list): 105 | 106 | total_usage_profiles = {} 107 | 108 | for agent in agent_list: 109 | for usage_profile in agent.usage_profiles: 110 | for key, value in usage_profile.items(): 111 | if key in total_usage_profiles: 112 | total_usage_profiles[key] += value 113 | else: 114 | total_usage_profiles[key] = value 115 | 116 | print("\n", total_usage_profiles) 117 | 118 | 119 | 120 | def main(): 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('--raw_data_path_ambiguous', type=str, nargs="+", default=None, help="for ambiguous datasets we assume the fields differ from the original ['instruction', 'input', 'output']") 123 | parser.add_argument('--ndocs', type=int, default=3, help="the number of retrieved evidences") 124 | parser.add_argument('--output_path', type=str, default=None, help="the output path of our generated data") 125 | parser.add_argument('--n_samples', type=int, default=None, help="only create data for the first n_samples examples") 126 | parser.add_argument('--overwrite_output', default=False, action="store_true", help="whether to overwrite the existing output file") 127 | parser.add_argument('--api_type', type=str, default="openai", help="choose from: openai, azure") 128 | parser.add_argument('--search_engine_type', type=str, default="duckduckgo") 129 | parser.add_argument('--openai_api_key', type=str, default=None) 130 | parser.add_argument('--openai_api_base', type=str, default=None) 131 | parser.add_argument('--is_qa', default=False, action="store_true") 132 | args = parser.parse_args() 133 | 134 | openai_config = { 135 | "api_key": args.openai_api_key, 136 | "base_url": args.openai_api_base, 137 | } 138 | 139 | if args.search_engine_type == "duckduckgo": 140 | search_engine_api = DDGSQueryRun(max_results=args.ndocs) 141 | 142 | if args.api_type == 
"azure": 143 | 144 | azure_config = { 145 | } 146 | 147 | rewriter_agent = OpenAIRewriterLMAgent(api_type="azure", config=azure_config) 148 | judger_agent = OpenAIJudgerLMAgent(api_type="azure", config=azure_config) 149 | generator_agent = OpenAIGeneratorLMAgent(api_type="azure", config=azure_config) 150 | 151 | elif args.api_type == "openai": 152 | 153 | rewriter_agent = OpenAIRewriterLMAgent(api_type="openai", config=openai_config) 154 | judger_agent = OpenAIJudgerLMAgent(api_type="openai", config=openai_config) 155 | generator_agent = OpenAIGeneratorLMAgent(api_type="openai", config=openai_config) 156 | 157 | os.makedirs(os.path.dirname(args.output_path), exist_ok=True) 158 | 159 | original_data = [] 160 | for path in args.raw_data_path_ambiguous: 161 | if path.endswith(".json"): 162 | cur_data = json.load(open(path)) 163 | original_data.extend(cur_data) 164 | elif "hf_dataset" in path: 165 | cur_data = datasets.load_from_disk(path)["train"] 166 | cur_data_list = [item for item in cur_data] 167 | original_data.extend(cur_data_list) 168 | else: 169 | # assert the data from hf 170 | cur_data = datasets.load_dataset(path)["train"] 171 | cur_data_list = [item for item in cur_data] 172 | original_data.extend(cur_data_list) 173 | 174 | start_index = 0 175 | if os.path.exists(args.output_path)and not args.overwrite_output: 176 | 177 | with open(args.output_path) as f: 178 | generated_data = json.load(f) 179 | 180 | if not args.overwrite_output: 181 | print(f"Detect pre-existing data {len(generated_data)} samples, continue!") 182 | start_index = len(generated_data) 183 | else: 184 | print(f"Detect pre-existing data, but overwrite! check if expected!") 185 | generated_data = [] 186 | else: 187 | generated_data = [] 188 | 189 | if args.n_samples is not None: 190 | original_data = original_data[:args.n_samples] 191 | original_data = original_data[start_index:] 192 | 193 | for index, ins in tqdm(enumerate(original_data), total=len(original_data)): 194 | # continue to count index 195 | real_index = index + start_index 196 | 197 | # 1. for ambiguous data like asqa we will need to first retrieve background for various disambigous question 198 | # 2. 
then generate a refined answer if necessary 199 | 200 | for cur_unambiguous in ins["qa_pairs"]: 201 | cur_evidences, cur_query = call_DDGS(cur_unambiguous, DDGS_agent=search_engine_api, state=None) 202 | 203 | # add search results to every unambiguous question 204 | cur_unambiguous["search_results"] = cur_evidences 205 | 206 | cur_unambiguous_response = call_generator_short(cur_unambiguous, 207 | prompt_template=UnambiguousGeneratorTemplateShort, 208 | generator_agent=generator_agent,) 209 | 210 | # add the LLM-generated answer to every unambiguous question 211 | cur_unambiguous["gpt_responses"] = cur_unambiguous_response 212 | 213 | cur_long_ans = call_generator_long(ins, 214 | prompt_template=UnambiguousGeneratorTemplateLong, 215 | generator_agent=generator_agent,) 216 | 217 | # add the LLM-generated long-form answer to the ambiguous question 218 | ins["gpt_responses_long_form"] = cur_long_ans 219 | 220 | print_usage([rewriter_agent, judger_agent, generator_agent]) 221 | 222 | generated_data.append(ins) 223 | 224 | with open(args.output_path, "w") as w: 225 | json.dump(generated_data, w, indent=4) 226 | 227 | print("done") 228 | 229 | 230 | if __name__ == '__main__': 231 | main() 232 | -------------------------------------------------------------------------------- /retrieval_lm/streamlit_utils/inference_func.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | import os 4 | import spacy 5 | import jsonlines 6 | import asyncio 7 | from transformers import AutoTokenizer, AutoModelForCausalLM 8 | # from vllm import LLM, SamplingParams 9 | import random 10 | import torch 11 | import numpy as np 12 | import openai 13 | from tqdm import tqdm 14 | import json 15 | import argparse 16 | import ast 17 | import re 18 | from tqdm import tqdm, trange 19 | from collections import Counter 20 | import string 21 | import sys 22 | import time 23 | 24 | from data_curation.tools.duckduckgo_rapidapi import DDGSQueryRun 25 | 26 | seed = 633 27 | 28 | torch.backends.cudnn.deterministic = True 29 | random.seed(seed) 30 | np.random.seed(seed) 31 | torch.manual_seed(seed) 32 | torch.cuda.manual_seed_all(seed) 33 | 34 | 35 | def init_model_and_tokenizer_and_tool(model_name_or_path:str): 36 | 37 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, padding_side="left", trust_remote_code=True) 38 | model = AutoModelForCausalLM.from_pretrained( 39 | model_name_or_path, 40 | from_tf=bool(".ckpt" in model_name_or_path), 41 | device_map="auto", 42 | torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16, 43 | ) 44 | 45 | # this line is important for controlling when we stop generating 46 | model.generation_config.eos_token_id = [tokenizer.convert_tokens_to_ids("</s>"), 47 | tokenizer.convert_tokens_to_ids("[EOS]")] 48 | 49 | model.generation_config.max_new_tokens = 512 50 | 51 | search_engine_api = DDGSQueryRun(3) 52 | 53 | return model, tokenizer, search_engine_api 54 | 55 | 56 | def format_evidences(evidences: list): 57 | tmp = [] 58 | 59 | for in_batch_evidences in evidences: 60 | ttmp = "" 61 | for cur_evidence in in_batch_evidences: 62 | ttmp += "Title: " + cur_evidence["title"] + "\n" 63 | ttmp += "Text: " + cur_evidence["text"] + "\n" 64 | 65 | tmp.append(ttmp) 66 | 67 | return tmp 68 | 69 | 70 | async def async_wrapper(search_engine_api, query): 71 | loop = asyncio.get_event_loop() 72 | 73 | if query == "": 74 | print("in async_wrapper, this query is '', which probably means the search is finished, " 75 | "but to support batch decoding we still 
need to return []; it has no effect.") 76 | return [] 77 | 78 | result = await loop.run_in_executor(None, search_engine_api, query) 79 | return result 80 | 81 | 82 | # coroutine function 83 | async def process_query_list(search_engine_api, query_list): 84 | tasks = [async_wrapper(search_engine_api, query) for query in query_list] 85 | results = await asyncio.gather(*tasks) 86 | return results 87 | 88 | 89 | def generate_and_retrieve(examples, model, tokenizer, special_tokens_dict, search_engine_api, 90 | search_limit, **kwargs): 91 | special_tokens_id = [] 92 | for special in ["[S_Decomposed_Query]", "[S_Rewritten_Query]", "[S_Disambiguated_Query]", "[S_Response]", 93 | "[A_Response]"]: 94 | special_tokens_id.append(special_tokens_dict[special]) 95 | 96 | formatted_prediction = [""] * len(examples) 97 | raw_prediction = [""] * len(examples) 98 | 99 | # we are going to generate iteratively until we reach the final answer, but we limit the number of searches 100 | search_nums = [0] * len(examples) 101 | is_done = [False] * len(examples) 102 | cur_examples = examples 103 | while True: 104 | # the llama tokenizer always adds <s> before the sentence; 105 | # for iterative generation we do not want this feature, so we add <s> ourselves during preprocessing 106 | inputs = tokenizer(cur_examples, return_tensors="pt", add_special_tokens=False, padding=True).to(model.device) 107 | # intermediate_prediction = model.generate(**inputs, stopping_criteria=stopping_criteria) 108 | intermediate_prediction = model.generate(**inputs, logits_processor=kwargs["logits_processor"]) 109 | 110 | # find what is before the stop word (batch). 111 | batch_special_index = [] 112 | for in_batch_index, cur_ins in enumerate(intermediate_prediction): 113 | # that's the reason we do not change the generation_config.pad from 0 to 32..., 114 | # because we want it to be zero in this case 115 | find_special_token = None 116 | special_index = len(cur_ins) - 1 117 | while special_index >= 0: 118 | # [... [EOS] ... 
] we do not want to stop at EOS 119 | if cur_ins[special_index] in special_tokens_id and cur_ins[special_index] != tokenizer.convert_tokens_to_ids("[EOS]"): 120 | find_special_token = cur_ins[special_index] 121 | break 122 | special_index -= 1 123 | 124 | batch_special_index.append([find_special_token, special_index]) 125 | 126 | # get_batch_query 127 | batch_queries = [""] * len(examples) 128 | batch_intermediate_string = tokenizer.batch_decode(intermediate_prediction) 129 | 130 | for in_batch_index, (find_special_token, special_index) in enumerate(batch_special_index): 131 | 132 | if find_special_token == tokenizer.convert_tokens_to_ids("[A_Response]"): 133 | # currently we let the model decide when to generate [A_Response]; a hard constraint can be added later 134 | # we generate the final answer 135 | if not batch_intermediate_string[in_batch_index].endswith("</s>"): 136 | cur_example = batch_intermediate_string[in_batch_index] + "</s>" # make sure to add the end-of-sequence token 137 | else: 138 | cur_example = batch_intermediate_string[in_batch_index] 139 | cur_examples[in_batch_index] = cur_example.replace("<s>", "").replace("</s>", "") 140 | is_done[in_batch_index] = True 141 | continue 142 | elif find_special_token == tokenizer.convert_tokens_to_ids("[S_Response]"): 143 | # this has no effect on completion, because it is only the intermediate answer to a sub-query; continue 144 | cur_example = tokenizer.decode(intermediate_prediction[in_batch_index]) 145 | cur_examples[in_batch_index] = cur_example.replace("<s>", "").replace("</s>", "") 146 | continue 147 | 148 | elif find_special_token is None: 149 | # with mixed-data training, the model sometimes does not emit any special token 150 | if not batch_intermediate_string[in_batch_index].endswith("</s>"): 151 | cur_example = batch_intermediate_string[in_batch_index] + "</s>" # make sure to add the end-of-sequence token 152 | else: 153 | cur_example = batch_intermediate_string[in_batch_index] 154 | cur_examples[in_batch_index] = cur_example.replace("<s>", "").replace("</s>", "") 155 | is_done[in_batch_index] = True 156 | continue 157 | else: 158 | 159 | if is_done[in_batch_index]: 160 | continue 161 | # at this point we assume we met a query token 162 | if search_nums[in_batch_index] >= search_limit: 163 | # search limit reached: skip the search and force the final answer instead 164 | cur_examples[in_batch_index] += "[A_Response]" 165 | continue 166 | 167 | search_nums[in_batch_index] += 1 168 | # we do not want the postfix </s> here 169 | query_for_search = tokenizer.decode(intermediate_prediction[in_batch_index][special_index + 1:-1], 170 | skip_special_tokens=True) 171 | batch_queries[in_batch_index] = query_for_search 172 | 173 | if all(is_done): 174 | break 175 | 176 | evidences = asyncio.run(process_query_list(search_engine_api, batch_queries)) 177 | evidences_list = format_evidences(evidences) 178 | 179 | for in_batch_index, in_batch_str in enumerate(batch_intermediate_string): 180 | if len(evidences_list[in_batch_index]) != 0: 181 | cur_examples[in_batch_index] = in_batch_str.replace("<s>", "").replace("</s>", "")\ 182 | + "[R_Evidences]" + evidences_list[in_batch_index] + "[/R_Evidences]" 183 | else: 184 | # we already bailed out at the search_nums[in_batch_index] >= search_limit check above and appended [A_Response] 185 | continue 186 | 187 | # yield every intermediate result here, because we are going to show it in the st.status widget 188 | yield { 189 | "search_query": batch_queries, 190 | "evidence_list": evidences_list, 191 | "cur_examples": cur_examples, 192 | } 193 | 194 | # extract the answer 195 | answers = [] 196 | for cur in cur_examples: 197 | pattern = r"\[A_Response\](.*?)\[EOS\]" 
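# A purely illustrative example of what this extraction does (the decoded string below is assumed, not taken from a real run):
#   cur = "<|user|> ... [S_Rewritten_Query] ... [R_Evidences] ... [/R_Evidences] [A_Response] It was released in 1969. [EOS]"
#   re.findall(pattern, cur, re.DOTALL) -> [" It was released in 1969. "]
# We keep the last match and strip surrounding whitespace; if nothing matches, we fall back to pattern2 below.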
198 | matches = re.findall(pattern, cur, re.DOTALL) 199 | if len(matches) > 0: 200 | result = matches[-1].strip() 201 | answers.append(result) 202 | else: 203 | 204 | # in this case no special token was detected; the model might have generated the answer directly 205 | pattern2 = r"<\|assistant\|>(.*?)</s>" 206 | matches2 = re.findall(pattern2, cur, re.DOTALL) 207 | 208 | if len(matches2) > 0: 209 | result2 = matches2[-1].strip() 210 | answers.append(result2) 211 | else: 212 | answers.append("") 213 | 214 | return {"final_response": answers, "cur_examples": cur_examples} 215 | 216 | 217 | 218 | 219 | 220 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/MultiTurnGeneratorTemplateForAns.txt: -------------------------------------------------------------------------------- 1 | You are provided with a conversation history with a user, along with your previous answers, several relevant evidences retrieved from an external database, and a recommended answer that might help to resolve the query. Your task is to answer the current query provided by the user. 2 | 3 | Consider the following when making your decision: 4 | 5 | Consistency with Recommended Answer: If the recommended answer from the dataset is still valid and aligns well with the provided evidences, maintain a response that does not deviate significantly from it. If not, give more priority to your own knowledge and the provided context. 6 | Updating Information: In cases where the recommended answer is outdated or evidently incorrect based on the new evidences, construct a new, more accurate and reasonable answer. 7 | Balance and Judgment: Exercise balanced judgment in situations where partial updating or reorganization of the recommended answer is required. 8 | 9 | NEVER reply that you cannot answer the question. 10 | 11 | For example: 12 | 13 | --- 14 | 15 | ### Conversation History: 16 | 17 | {'role': 'user', 'content': 'Can you explain contrastive learning in machine learning in simple terms for someone new to the field of ML?', 'all_evidences': [[{'title': 'Understanding Contrastive Learning | by Ekin Tiu | Towards Data Science', 'text': "Contrastive learning is a machine learning technique used to learn the general features of a dataset without labels by teaching the model which data points are similar or different. Let's begin with a simplistic example. Imagine that you are a newborn baby that is trying to make sense of the world."}, {'title': 'Advances in Understanding, Improving, and Applying Contrastive Learning ...', 'text': 'Overview Over the past few years, contrastive learning has emerged as a powerful method for training machine learning models. It has driven a revolution in learning visual representations, powering methods like SimCLR, CLIP, and DALL-E 2. The empirical success of these methods has begged the question - what makes contrastive learning so powerful?'}, {'title': 'Contrastive Learning: A Tutorial | Built In', 'text': 'Contrastive learning involves training a model to differentiate between similar and dissimilar pairs of data points by maximizing their similarity within the same class and minimizing it between different classes. This technique has a wide range of applications, including computer vision and natural language processing.'}], [{'title': 'The Beginner's Guide to Contrastive Learning - v7labs.com', 'text': "V7 Data Annotation V7 Model Training Let's dive in. What is contrastive learning? 
Contrastive Learning is a Machine Learning paradigm where unlabeled data points are juxtaposed against each other to teach a model which points are similar and which are different."}, {'title': 'Contrastive Representation Learning: A Framework and Review', 'text': 'Examples of how contrastive learning has been applied in computer vision, natural language processing, audio processing, and others, as well as in Reinforcement Learning are also presented. Finally, we discuss the challenges and some of the most promising future research directions ahead. Submission history From: Phuc Le Khac [ view email]'}, {'title': 'Contrastive Representation Learning: A Framework and Review - arXiv.org', 'text': 'Examples of how contrastive learning has been applied in computer vision, natural language processing, audio processing, and others, as well as in Reinforcement Learning are also presented. Finally, we discuss the challenges and some of the most promising future research directions ahead. INDEX TERMS Contrastive learning, representation ...'}], [{'title': '[2106.05967] Revisiting Contrastive Methods for Unsupervised Learning of Visual Representations', 'text': 'Our results show that current contrastive approaches work surprisingly well across: (i) object- versus scene-centric, (ii) uniform versus long-tailed and (iii) general versus domain-specific datasets. Second, given the generality of the approach, we try to realize further gains with minor modifications.'}, {'title': 'Contrastive Learning: A Tutorial | Built In', 'text': 'Contrastive learning is most notably used for self-supervised learning, a type of unsupervised learning where the label, or supervisory signal, comes from the data itself. In the self-supervised setting, contrastive learning allows us to train encoders to learn from massive amounts of unlabeled data. Specifically, in self-supervised contrastive ...'}, {'title': 'Supervised vs. Unsupervised Learning: Key Differences - Scribbr', 'text': 'There are two main approaches to machine learning: supervised and unsupervised learning. The main difference between the two is the type of data used to train the computer. However, there are also more subtle differences.'}]], 'all_queries': ["'Explanation of contrastive learning in machine learning.'", "'Examples of contrastive learning in computer vision.'", "'Comparison of contrastive learning with other types of unsupervised learning.'"]} 18 | {'role': 'assistant', 'content': 'Sure! Let\'s say you want to build a model which can distinguish between images of cats and dogs. You gather your dataset, consisting of many cat and dog pictures. Then you put them through a neural net of your choice, which produces some representation for each image, a sequence of numbers like [0.123, 0.045, 0.334, ...]. The problem is, if your model is unfamiliar with cat and dog images, these representations will be quite random. At one time a cat and a dog picture could have very similar representations (their numbers would be close to each other), while at others two cat images may be represented far apart. In simple terms, the model wouldn\'t be able to tell cats and dogs apart. This is where contrastive learning comes in.\n\nThe point of contrastive learning is to take pairs of samples (in this case images of cats and dogs), then train the model to "pull" representations of similar pairs (cat-cat or dog-dog) closer to each other and "push" representations of different pairs (cat-dog) apart. 
After doing this for a sufficient number of steps, your model will be able to produce unique, reliable representations for cats and dogs, in essence tell them apart.\n\nThis method is not limited to images, you can typically use it with any dataset that has similar and dissimilar data points.'} 19 | 20 | ### Current Query: 21 | 22 | {'role': 'user', 'content': "I didn't understand how pulling and pushing works. What are we pulling or pushing? And how does it work in terms of code or maths?"} 23 | 24 | ### Recommended Answer: 25 | 26 | {'role': 'assistant', 'content': 'It has to do with a process called differentiation, which means finding a derivative: a function that tells us whether another function is going up or down at a given point. For example, the derivative of `f(x) = x` is `f\'(x) = 1`, because it\'s always going up at a 1:1 ratio. This can be done for a variety of functions; notably, if you know the derivatives of `f(x)` and `g(x)`, you can also get the derivative of `f(g(x))` using a formula called the chain rule. Neural networks happen to be made of differentiable functions, so we can take the derivative of parts or all of it.\n\nTo use this for "pushing" and "pulling", we\'ll put two images through the neural network. Let\'s say the images are of a cat and a dog, so we want to increase the distance between the two. We pick one neuron weight from the network and make it a variable `x`, then construct a function that calculates the output of the network based on it and all the other parameters; let\'s call it `N(x)`. The distance between the cat and dog outputs would be `f(x) = N(cat) - N(dog)`. (Of course, the real output would have more than 1 dimension, but we\'re simplifying.) We now want to nudge the weight such that it moves the two outputs slightly further apart. For that, we can simply take the derivative! If `f\'(x)` is positive, that means that increasing the weight will move them further apart, so we should do that. If it\'s negative, then it\'ll move them closer, so we\'ll want to slightly decrease the weight instead. Apply this to all the neurons enough times and your network will soon converge to a pretty good cat-dog separator!'} 27 | 28 | ### Your Response: 29 | 30 | In contrastive learning, we use a process called differentiation to adjust the weights of the neural network so that it produces the desired representations. We take pairs of samples and train the model to 'pull' representations of similar pairs closer to each other and 'push' representations of different pairs apart. For example, if we have a dataset of cat and dog images, we want to train the model to distinguish between them. We put the images through a neural network, which produces some representation for each image. The representations for cat images should be similar to each other, and the representations for dog images should be similar to each other, but the representations for cat and dog images should be different. We pick one neuron weight from the network and make it a variable `x`, then construct a function that calculates the output of the network based on it and all the other parameters; let's call it `N(x)`. The distance between the cat and dog outputs would be `f(x) = N(cat) - N(dog)`. We now want to nudge the weight such that it moves the two outputs slightly further apart. If `f\'(x)` is positive, that means that increasing the weight will move them further apart, so we should do that. 
If it's negative, then it'll move them closer, so we'll want to slightly decrease the weight instead. We repeat this process for many pairs of samples until the model produces unique, reliable representations for cats and dogs. This method can be used with any dataset that has similar and dissimilar data points, not just images. 31 | 32 | --- 33 | 34 | Now, please answer: 35 | 36 | ### Conversation History: 37 | 38 | ${cur_history_with_context} 39 | 40 | ### Current Query: 41 | 42 | ${current_query} 43 | 44 | ### Recommended Answer: 45 | 46 | ${recommended_answer} 47 | 48 | ### Your Response: 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | -------------------------------------------------------------------------------- /data_curation/llm_agent/prompt_template/MultiTurnGeneratorTemplate.txt: -------------------------------------------------------------------------------- 1 | In a multi-turn dialogue scenario, your task is to determine whether it is necessary to use a search engine to answer a user's query. Consider the following two scenarios: 2 | 3 | Non-Informational Replies: Sometimes, users may respond with statements or expressions that do not require information retrieval, such as "thank you" or "okay." In these cases, assess whether a search engine query is necessary. 4 | 5 | Ambiguous or Unclear Queries: At times, a user's query might be unclear or lack specific details. Your role is to recognize the user's intent and rewrite the query to make it clearer and more precise, facilitating an effective search engine query. 6 | 7 | Previously Answered Queries: Check if the current query or a similar one has been previously asked and answered in the conversation history. If relevant information or evidences have already been provided, acknowledge this and avoid repeating the search. 8 | 9 | Based on the above 3 scenarios, please reply with the following format strictly: 10 | 11 | For the case that do not need to query the search engine, output as follows: 12 | 13 | ### Response: 14 | 15 | Retrieval Necessity: 16 | No 17 | 18 | For the case that the retrieval is necessary, output as follows: 19 | 20 | ### Response For Retrieval Necessity: 21 | 22 | Retrieval Necessity: 23 | Yes 24 | Query For Search Engine: 25 | [...] 26 | 27 | For examples: 28 | 29 | --- 30 | 31 | In the case that do not need to retrieve; 32 | 33 | ### Conversation History: 34 | 35 | 36 | {'role': 'user', 'content': 'Explain the significance of the American Revolution, including the events that led up to it, the impact it had on the world, and its ongoing relevance today.'} 37 | {'role': 'assistant', 'content': "The American Revolution was a significant event in world history as it marked the first successful colonial war of independence against a European power. The events leading up to the revolution were a combination of political, economic, and philosophical factors. The British government's attempts to exert greater control over the colonies and the imposition of taxes without representation in the colonial legislatures were among the key factors that led to the revolution.\n\nThe American Revolution had a profound impact on the world and changed the course of history. It led to the creation of the United States of America and established the principle that governments derive their power from the consent of the governed. 
This concept was a cornerstone of modern democracy and has been influential in shaping political systems around the world.\n\nThe American Revolution also had a profound impact on the ideas of liberty and freedom, inspiring other revolutionary movements and helping to spread the ideals of democracy and human rights. The American Revolution is still relevant today as its principles continue to inspire people around the world who are fighting for their own freedom and democracy. Additionally, the ideas and political institutions established during the American Revolution continue to shape American society and the broader world today."} 38 | 39 | 40 | ### Current User's Query: 41 | 42 | {'role': 'user', 'content': 'This is pretty much I would like to know, thank you!'} 43 | 44 | ### Response For Retrieval Necessity: 45 | 46 | Retrieval Necessity: 47 | No 48 | 49 | --- 50 | 51 | In the case that retrieval is necessary, and the original query from user is fine; 52 | 53 | ### Conversation History: 54 | 55 | {'role': 'user', 'content': 'Explain the significance of the American Revolution, including the events that led up to it, the impact it had on the world, and its ongoing relevance today.'} 56 | {'role': 'assistant', 'content': "The American Revolution was a significant event in world history as it marked the first successful colonial war of independence against a European power. The events leading up to the revolution were a combination of political, economic, and philosophical factors. The British government's attempts to exert greater control over the colonies and the imposition of taxes without representation in the colonial legislatures were among the key factors that led to the revolution.\n\nThe American Revolution had a profound impact on the world and changed the course of history. It led to the creation of the United States of America and established the principle that governments derive their power from the consent of the governed. This concept was a cornerstone of modern democracy and has been influential in shaping political systems around the world.\n\nThe American Revolution also had a profound impact on the ideas of liberty and freedom, inspiring other revolutionary movements and helping to spread the ideals of democracy and human rights. The American Revolution is still relevant today as its principles continue to inspire people around the world who are fighting for their own freedom and democracy. Additionally, the ideas and political institutions established during the American Revolution continue to shape American society and the broader world today."} 57 | {'role': 'user', 'content': 'Which had the most impact between American and French Revolution?'} 58 | {'role': 'assistant', 'content': 'It is difficult to determine which revolution had the most impact as both the American Revolution and the French Revolution had far-reaching effects on world history. The American Revolution was the first successful colonial revolt against a European power and resulted in the formation of the United States of America as an independent nation. This revolution inspired other colonial movements and helped spread the idea of democracy and liberty around the world.\n\nOn the other hand, the French Revolution was a major turning point in European history and had a profound impact on the political and social structures of France and Europe as a whole. 
It marked the end of the monarchy and the rise of the modern nation-state, leading to the establishment of a democratic republic in France and inspiring similar movements in other parts of Europe. The French Revolution also contributed to the spread of nationalism, challenged the power of the Catholic Church, and laid the foundation for the concept of human rights.\n\nIn conclusion, both the American and French Revolutions had significant impacts on world history and it is difficult to say which one had the most impact. Each revolution played a key role in shaping the modern world in its own unique way.'} 59 | 60 | ### Current User's Query: 61 | 62 | {'role': 'user', 'content': 'In how many possible states can a Rubik\'s Cube be?'} 63 | 64 | ### Response For Retrieval Necessity: 65 | 66 | Retrieval Necessity: 67 | Yes 68 | Query For Search Engine: 69 | "In how many possible states can a Rubik\'s Cube be?" 70 | 71 | --- 72 | 73 | In the case that retrieval is necessary, and the user's query is not clear, thus help it rewritten; 74 | 75 | ### Conversation History: 76 | 77 | {'role': 'user', 'content': 'Explain the significance of the American Revolution, including the events that led up to it, the impact it had on the world, and its ongoing relevance today.'} 78 | {'role': 'assistant', 'content': "The American Revolution was a significant event in world history as it marked the first successful colonial war of independence against a European power. The events leading up to the revolution were a combination of political, economic, and philosophical factors. The British government's attempts to exert greater control over the colonies and the imposition of taxes without representation in the colonial legislatures were among the key factors that led to the revolution.\n\nThe American Revolution had a profound impact on the world and changed the course of history. It led to the creation of the United States of America and established the principle that governments derive their power from the consent of the governed. This concept was a cornerstone of modern democracy and has been influential in shaping political systems around the world.\n\nThe American Revolution also had a profound impact on the ideas of liberty and freedom, inspiring other revolutionary movements and helping to spread the ideals of democracy and human rights. The American Revolution is still relevant today as its principles continue to inspire people around the world who are fighting for their own freedom and democracy. Additionally, the ideas and political institutions established during the American Revolution continue to shape American society and the broader world today."} 79 | {'role': 'user', 'content': 'Which had the most impact between American and French Revolution?'} 80 | {'role': 'assistant', 'content': 'It is difficult to determine which revolution had the most impact as both the American Revolution and the French Revolution had far-reaching effects on world history. The American Revolution was the first successful colonial revolt against a European power and resulted in the formation of the United States of America as an independent nation. This revolution inspired other colonial movements and helped spread the idea of democracy and liberty around the world.\n\nOn the other hand, the French Revolution was a major turning point in European history and had a profound impact on the political and social structures of France and Europe as a whole. 
It marked the end of the monarchy and the rise of the modern nation-state, leading to the establishment of a democratic republic in France and inspiring similar movements in other parts of Europe. The French Revolution also contributed to the spread of nationalism, challenged the power of the Catholic Church, and laid the foundation for the concept of human rights.\n\nIn conclusion, both the American and French Revolutions had significant impacts on world history and it is difficult to say which one had the most impact. Each revolution played a key role in shaping the modern world in its own unique way.'} 81 | 82 | ### Current User's Query: 83 | 84 | {'role': 'user', 'content': 'Tell me something about that famous space thing.'} 85 | 86 | ### Response For Retrieval Necessity: 87 | 88 | Retrieval Necessity: 89 | Yes 90 | Query For Search Engine: 91 | "Provide details about the Apollo 11 moon landing mission." 92 | "Biography and achievements of astronaut Neil Armstrong." 93 | "Summary and reviews of the movie 'Interstellar'." 94 | 95 | --- 96 | 97 | As outlined, output the Retrieval Necessity first; it must be either "Yes" or "No", and the queries for the search engine should be separated by line breaks. 98 | In most cases, the retrieval process will help you answer the question better; only skip the retrieval process when you are fairly confident it is unnecessary. 99 | Now, please answer: 100 | 101 | ### Conversation History: 102 | 103 | ${conversation_history} 104 | 105 | ### Current User's Query: 106 | 107 | ${current_query} 108 | 109 | ### Response For Retrieval Necessity: 110 | 111 | 112 | --------------------------------------------------------------------------------
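The "Response For Retrieval Necessity" block produced by the template above has a fixed shape, so downstream code can parse it mechanically. Below is a minimal sketch of such a parser, assuming the model's reply is available as a plain string; the helper name parse_retrieval_decision and the exact quoting of the queries are illustrative assumptions, not part of the pipeline.

import re

def parse_retrieval_decision(reply: str):
    """Parse 'Retrieval Necessity: Yes/No' plus any search-engine queries from a reply following the template above."""
    # The decision is the word right after "Retrieval Necessity:" (possibly on the next line).
    necessity = re.search(r"Retrieval Necessity:\s*(Yes|No)", reply, re.IGNORECASE)
    need_retrieval = necessity is not None and necessity.group(1).lower() == "yes"

    queries = []
    if need_retrieval:
        # Queries follow "Query For Search Engine:" and are split by line breaks, optionally wrapped in quotes.
        parts = reply.split("Query For Search Engine:", 1)
        if len(parts) == 2:
            for line in parts[1].strip().splitlines():
                line = line.strip().strip('"').strip()
                if line:
                    queries.append(line)
    return need_retrieval, queries

# Example: applied to the second worked example above, this would yield
# (True, ["In how many possible states can a Rubik's Cube be?"]).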