├── README.md ├── assets └── DeepNote.png ├── config ├── config.yaml └── ds_config_zero3.json ├── prompts └── en │ ├── base │ ├── base_asqa │ ├── base_strategyqa │ ├── base_wo_retri │ ├── base_wo_retri_asqa │ ├── base_wo_retri_strategyqa │ ├── compare │ ├── gen_answer │ ├── gen_answer_asqa │ ├── gen_answer_strategyqa │ ├── gen_new_query │ ├── init_note │ └── refine_note ├── requirements.txt └── src ├── build_index ├── emb │ └── index.py └── es │ ├── index_2wiki.py │ ├── index_hotpotqa.py │ └── index_musique.py ├── es_retrieve.py ├── eval.py ├── gen_dpo_data.py ├── main.py ├── select_dpo_data.py ├── train.py ├── train.sh └── utils.py /README.md: -------------------------------------------------------------------------------- 1 |

2 | DeepNote: Note-Centric Deep Retrieval-Augmented Generation 3 |

4 | 5 | 6 | We develop **DeepNote**, an adaptive RAG framework that achieves in-depth 7 | and robust exploration of knowledge sources through note-centric adaptive retrieval. DeepNote employs notes as carriers for refining and accumulating knowledge. During in-depth exploration, it uses these notes to determine retrieval timing, formulate retrieval queries, and iteratively assess knowledge growth, ultimately leveraging the best note for answer generation. 8 | 9 | ![DeepNote](assets/DeepNote.png) 10 | 11 | # Prepare Datasets 12 | 13 | All corpora and evaluation files should be placed in the `/data` directory. You can download the experimental data [here](https://drive.google.com/drive/folders/1NeEm-r7l43MQxGS1n7jJ8tPvltgcaPjY?usp=sharing). 14 | 15 | We use Wikipedia as the corpus for ASQA and StrategyQA. Due to its large size, please download it separately [here](https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz) and place it in `/data/corpus/wiki/`. 16 | 17 | # Retrieval Settings 18 | 19 | We employ different retrieval methods for different datasets: 20 | 21 | For 2WikiMQA, MuSiQue, and HotpotQA: 22 | - BM25 retrieval based on Elasticsearch 23 | - Dense retrieval with a FAISS index using embeddings from the BGE model 24 | 25 | For ASQA and StrategyQA: 26 | - Dense retrieval with a FAISS index using embeddings from the GTR model 27 | 28 | ## Set Up Elasticsearch 29 | 30 | Install Elasticsearch: 31 | ```bash 32 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.10.2-linux-x86_64.tar.gz 33 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.10.2-linux-x86_64.tar.gz.sha512 34 | shasum -a 512 -c elasticsearch-7.10.2-linux-x86_64.tar.gz.sha512 35 | tar -xzf elasticsearch-7.10.2-linux-x86_64.tar.gz 36 | cd elasticsearch-7.10.2/ 37 | ./bin/elasticsearch # Start the server 38 | pkill -f elasticsearch # To stop the server 39 | ``` 40 | 41 | ## Build Indices 42 | 43 | ### For BM25 44 | ```bash 45 | cd src/build_index/es 46 | 47 | # 2WikiMQA 48 | python index_2wiki.py 49 | 50 | # MuSiQue 51 | python index_musique.py 52 | 53 | # HotpotQA 54 | python index_hotpotqa.py 55 | ``` 56 | 57 | ### For Dense Retrieval 58 | 59 | #### For HotpotQA, 2WikiMQA, and MuSiQue 60 | ```bash 61 | cd src/build_index/emb 62 | python index.py --dataset hotpotqa --model bge-base-en-v1.5 # e.g., for the HotpotQA dataset 63 | ``` 64 | 65 | #### For ASQA and StrategyQA 66 | Since generating GTR embeddings for the Wikipedia corpus is time-consuming, you can download the pre-computed GTR embeddings and place them in `data/corpus/wiki/`: 67 | ```bash 68 | wget https://huggingface.co/datasets/princeton-nlp/gtr-t5-xxl-wikipedia-psgs_w100-index/resolve/main/gtr_wikipedia_index.pkl 69 | ``` 70 | 71 | Then build the FAISS index: 72 | ```bash 73 | cd src/build_index/emb 74 | python index.py --dataset asqa --model gtr-t5-xxl 75 | ``` 76 | 77 | # Configuration 78 | 79 | You can configure your API key, base URL, and other settings in the `./config/config.yaml` file. 80 | 81 | 82 | # Training DeepNote 83 | 84 | The training process consists of three main steps: 85 | 86 | ## 1. Generate Training Data 87 | Generate the initial training data using the LLaMA model: 88 | ```bash 89 | python gen_dpo_data.py \ 90 | --model llama-3.1-8b-instruct \ 91 | --batch_size 9 \ 92 | --output_path ../data/dpo_data \ 93 | --device 0,1,2,3 94 | ``` 95 | 96 | ## 2. 
Data Selection 97 | Filter and process the generated data: 98 | ```bash 99 | python select_dpo_data.py \ 100 | --output_path ../data/dpo/processed/train.jsonl \ 101 | --init_num 1900 \ 102 | --refine_num 1900 \ 103 | --query_num 1900 104 | ``` 105 | 106 | ## 3. Start Training 107 | Launch the training process: 108 | ```bash 109 | bash train.sh 110 | ``` 111 | 112 | 113 | # Running DeepNote and Evaluation 114 | 115 | ```bash 116 | python main.py --method deepnote --retrieve_top_k 5 --dataset hotpotqa --max_step 3 --max_fail_step 2 --MaxClients 5 --model gpt-4o-mini-2024-07-18 --device cuda:0 117 | ``` 118 | The predicted results and evaluation metrics will be automatically saved in the `output/{dataset}/` directory. The evaluation results can be found at the end of the file. 119 | -------------------------------------------------------------------------------- /assets/DeepNote.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/thunlp/DeepNote/cc2f0132737e04ec2d895e8c21673cee612ca1e7/assets/DeepNote.png -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | OPENAI_API_KEY: "" 3 | OPENAI_BASE_URL: "" 4 | llama-3.1-8b-instruct: "meta-llama/Llama-3.1-8B-Instruct" 5 | llama-3.1-8b-instruct-dpo: "" 6 | llama-3.1-70b-instruct: "meta-llama/Llama-3.1-70B-Instruct" 7 | qwen2.5-7b-instruct: "Qwen/Qwen2.5-7B-Instruct" 8 | qwen2.5-7b-instruct-dpo: "" 9 | bge-base-en-v1.5: "BAAI/bge-base-en-v1.5" 10 | gtr-t5-xxl: "sentence-transformers/gtr-t5-xxl" 11 | es: 12 | url: "http://localhost:9200" 13 | 14 | score: 15 | init_note: 'Task: You will receive a list of notes generated based on a given document content and question. \nYour task is to evaluate and score these notes based on their quality. Quality refers to: relevance, coherence, completeness in answering the specified question, and accuracy of information. \n\nQuestion to be answered: {query}\nDocument content: {refs}\n\nGenerated notes: {notes}\nNote format: Each note contains "_id" and "content" fields. \n\nEvaluate the generated notes. The highest-scoring note must be factually correct based on the document. If no note is correct, or if there is minimal quality difference between notes, use the same _id for both best and worst. \nOutput in the following JSON format:\n```json {{"best_id": <_id of the highest-scoring note>, "worst_id": <_id of the lowest-scoring note>}}```\n\n Do not include any explanations or additional text.' 16 | gen_new_query: 'Task: You will receive a list of new questions generated based on some notes and an existing question list to supplement a given original question.\nYour task is to evaluate these new questions based on their quality. Quality refers to: relevance, specificity, keyword richness, and non-redundancy. The goal is to identify questions that can retrieve useful information to help answer the original question. \n Notes: {notes}\n\nOriginal question: {query}\n\nExisting question list: {query_log}\n\nNew question list: {new_querys}\nQuestion format: Each question contains "_id" and "content" fields. \n\nEvaluate the new question list. The highest-scoring new question must be able to help retrieve relevant information to answer the original question. 
If no new question can help get useful information, or if there is minimal quality difference between new questions, use the same _id for both best_id and worst_id. \nOutput in the following format:\n```json {{"best_id": <_id of the highest-scoring question>, "worst_id": <_id of the lowest-scoring question>}}```\n\n Do not include any explanations or additional text.' 17 | -------------------------------------------------------------------------------- /config/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 3, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "offload_param": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "overlap_comm": true, 33 | "contiguous_gradients": true, 34 | "sub_group_size": 1e9, 35 | "reduce_bucket_size": "auto", 36 | "stage3_prefetch_bucket_size": "auto", 37 | "stage3_param_persistence_threshold": "auto", 38 | "stage3_max_live_parameters": 1e9, 39 | "stage3_max_reuse_distance": 1e9, 40 | "stage3_gather_16bit_weights_on_model_save": true 41 | }, 42 | "gradient_accumulation_steps": "auto", 43 | "gradient_clipping": "auto", 44 | "steps_per_print": 2000, 45 | "train_batch_size": "auto", 46 | "train_micro_batch_size_per_gpu": "auto", 47 | "wall_clock_breakdown": false 48 | } -------------------------------------------------------------------------------- /prompts/en/base: -------------------------------------------------------------------------------- 1 | Answer the question based on the given passages. Only give me the answer and do not output any other words. 2 | 3 | The following are given passages: 4 | {refs} 5 | 6 | Question: {query} 7 | Answer: -------------------------------------------------------------------------------- /prompts/en/base_asqa: -------------------------------------------------------------------------------- 1 | Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided passages. Use an unbiased and journalistic tone. 2 | 3 | 4 | Question: 5 | {query} 6 | 7 | 8 | Passages: 9 | {refs} 10 | 11 | -------------------------------------------------------------------------------- /prompts/en/base_strategyqa: -------------------------------------------------------------------------------- 1 | Answer the question based on the given passages. Only give me 'yes' or 'no' as your answer and do not output any other words. 2 | 3 | The following are given passages: 4 | {refs} 5 | 6 | Question: {query} 7 | Answer: -------------------------------------------------------------------------------- /prompts/en/base_wo_retri: -------------------------------------------------------------------------------- 1 | Only give me the answer and do not output any other words. 2 | 3 | Question: {query} 4 | Answer: -------------------------------------------------------------------------------- /prompts/en/base_wo_retri_asqa: -------------------------------------------------------------------------------- 1 | Instruction: Write an accurate, engaging, and concise answer for the given question. 
Use an unbiased and journalistic tone. 2 | 3 | Question: 4 | {query} 5 | -------------------------------------------------------------------------------- /prompts/en/base_wo_retri_strategyqa: -------------------------------------------------------------------------------- 1 | Only give me 'yes' or 'no' as your answer and do not output any other words. 2 | 3 | Question: {query} 4 | Answer: -------------------------------------------------------------------------------- /prompts/en/compare: -------------------------------------------------------------------------------- 1 | Task: Please help me determine which note is better based on the following evaluation criteria: 2 | 1. Contains key information directly related to the question. 3 | 2. Completeness of Information: Does it cover all relevant aspects and details? 4 | 3. Level of Detail: Does it provide enough detail to understand the issue in depth? 5 | 4. Practicality: Does the note offer practical help and solutions? 6 | 7 | Please make your judgment adhering strictly to the following rules: 8 | - If Note 2 does not add new meaningful content on top of Note 1, or only adds redundant information, return ```json {{"status":"False"}}``` directly. 9 | - If Note 2 has significant improvements over Note 1 based on the above criteria, return ```json {{"status":"True"}}``` directly; otherwise, return ```json {{"status":"False"}}```. 10 | 11 | Question: {query} 12 | Provided Note 1: {best_note} 13 | Provided Note 2: {new_note} 14 | 15 | Based on the above information, make your judgment without explanation and return the result directly. 16 | -------------------------------------------------------------------------------- /prompts/en/gen_answer: -------------------------------------------------------------------------------- 1 | Answer the question based on the given notes. Only give me the answer and do not output any other words. 2 | 3 | The following are given notes: 4 | {note} 5 | 6 | Question: {query} 7 | Answer: -------------------------------------------------------------------------------- /prompts/en/gen_answer_asqa: -------------------------------------------------------------------------------- 1 | Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided notes. Use an unbiased and journalistic tone. 2 | 3 | 4 | Question: 5 | {query} 6 | 7 | 8 | Notes: 9 | {note} -------------------------------------------------------------------------------- /prompts/en/gen_answer_strategyqa: -------------------------------------------------------------------------------- 1 | Answer the question based on the given notes. Only give me 'yes' or 'no' as your answer and do not output any other words. 2 | 3 | The following are given notes: 4 | {note} 5 | 6 | Question: {query} 7 | Answer: -------------------------------------------------------------------------------- /prompts/en/gen_new_query: -------------------------------------------------------------------------------- 1 | Task: Based on the notes, propose two new questions. These new questions will be used to retrieve documents to supplement the notes and help answer the original question. The new questions should be concise and include keywords that facilitate retrieval. The new questions should avoid duplication with the existing question list. 2 | 3 | original question: {query} 4 | notes: {note} 5 | 6 | existing question list: {query_log} 7 | 8 | Do not print any other words. Do not explain. Output only the two new questions you asked. 
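The `{query}`, `{note}`, and `{query_log}` placeholders in this template are filled at run time; the scripts under `src/` do this through the `p_template` helper imported from `utils.py`, which is not reproduced in this listing. The snippet below is only a minimal sketch of how such a template could be loaded and formatted — the `fill_prompt` name and the `prompts/en` path handling are illustrative assumptions, not the project's actual helper.

```python
from pathlib import Path


def fill_prompt(name: str, fields: dict, prompt_dir: str = "prompts/en") -> str:
    # Read a prompt template by file name and substitute its {placeholder} fields.
    # Literal braces inside templates (e.g. the JSON examples in config.yaml) are
    # written as {{ }} so that str.format leaves them intact.
    template = Path(prompt_dir, name).read_text(encoding="utf-8")
    return template.format(**fields)


if __name__ == "__main__":
    # Hypothetical example values; the real pipeline passes the current question,
    # the best note so far, and the list of queries already issued.
    prompt = fill_prompt(
        "gen_new_query",
        {
            "query": "Who founded the studio that produced Spirited Away?",
            "note": "Spirited Away (2001) was produced by Studio Ghibli.",
            "query_log": [],
        },
    )
    print(prompt)
```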
-------------------------------------------------------------------------------- /prompts/en/init_note: -------------------------------------------------------------------------------- 1 | Based on the provided document content, write a note. The note should integrate all relevant information from the original text that can help answer the specified question and form a coherent paragraph. Please ensure that the note includes all original text information useful for answering the question. 2 | 3 | 4 | Question to be answered: {query} 5 | Document content: {refs} 6 | 7 | 8 | Please provide the note you wrote: -------------------------------------------------------------------------------- /prompts/en/refine_note: -------------------------------------------------------------------------------- 1 | Task: Based on the retrieved documents, supplement the notes with content not yet included but useful for answering the question. The supplement should use the original text from the retrieved documents. The added content should include as much information from the retrieved documents as possible. 2 | 3 | question: {query} 4 | retrieved documents: {refs} 5 | 6 | notes: {note} 7 | 8 | Provide your supplemented notes: -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backoff 2 | elasticsearch==7.9.1 3 | faiss_cpu 4 | jieba 5 | numpy 6 | openai==1.57.4 7 | pandas 8 | sentence_transformers 9 | torch==2.4.0 10 | tqdm 11 | transformers==4.48.0 12 | pyyaml 13 | llama-index==0.9.30 14 | vllm==0.5.4 15 | trl==0.13.0 16 | -------------------------------------------------------------------------------- /src/build_index/emb/index.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | import json 4 | from llama_index import SimpleDirectoryReader 5 | import csv 6 | from sentence_transformers import SentenceTransformer 7 | import faiss 8 | import argparse 9 | import yaml 10 | from glob import glob 11 | from itertools import chain 12 | from tqdm import tqdm 13 | from llama_index import Document 14 | from llama_index.node_parser import SimpleNodeParser 15 | 16 | with open("../../../config/config.yaml", "r") as f: 17 | config = yaml.safe_load(f) 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | "--model", 21 | type=str, 22 | choices=["bge-base-en-v1.5", "gtr-t5-xxl"], 23 | default="bge-base-en-v1.5", 24 | help="Model to use", 25 | ) 26 | parser.add_argument( 27 | "--dataset", 28 | type=str, 29 | default="2wikimultihopqa", 30 | choices=["asqa", "strategyqa", "2wikimultihopqa", "hotpotqa", "musique"], 31 | help="Dataset to use", 32 | ) 33 | parser.add_argument("--chunk_size", type=int, default=512, help="chunk size") 34 | parser.add_argument("--chunk_overlap", type=int, default=0, help="chunk overlap") 35 | parser.add_argument("--device", type=str, default="cuda:7", help="Device to use") 36 | args = parser.parse_args() 37 | 38 | 39 | def split_text(data): 40 | 41 | documents = [] 42 | for record in data: 43 | if record["title"]: 44 | combined_text = record["title"] + "\n" + record["content"] 45 | else: 46 | combined_text = record["content"] 47 | documents.append(Document(text=combined_text)) 48 | 49 | node_parser = SimpleNodeParser.from_defaults( 50 | chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap 51 | ) 52 | nodes = node_parser.get_nodes_from_documents(documents, show_progress=True) 53 | 
54 | contents = [node.text for node in nodes] 55 | return contents 56 | 57 | 58 | def build_index(embeddings, vectorstore_path): 59 | dimension = embeddings.shape[1] 60 | index = faiss.IndexFlatIP(dimension) 61 | index.add(embeddings) 62 | faiss.write_index(index, vectorstore_path) 63 | 64 | 65 | if __name__ == "__main__": 66 | 67 | model = SentenceTransformer(config["model"][args.model], device=args.device) 68 | if args.dataset == "asqa" or args.dataset == "strategyqa": 69 | dataset_name = "wiki" 70 | else: 71 | dataset_name = args.dataset 72 | vectorstore_path = f"../../../data/corpus/{dataset_name}/{dataset_name}.index" 73 | contents = [] 74 | print("loading document ...") 75 | start = time.time() 76 | if dataset_name == "wiki": 77 | if os.path.exists("../../../data/corpus/wiki/gtr_wikipedia_index.pkl"): 78 | import pickle 79 | 80 | with open( 81 | "../../../data/corpus/wiki/gtr_wikipedia_index.pkl", "rb" 82 | ) as file: 83 | embeddings = pickle.load(file) 84 | else: 85 | with open( 86 | "../../../data/corpus/wiki/psgs_w100.tsv", "r", encoding="utf-8" 87 | ) as file: 88 | tsv_data = csv.DictReader(file, delimiter="\t") 89 | raw_data = [row["title"] + "\n" + row["text"] for row in tsv_data] 90 | print("dataset length", len(raw_data)) 91 | embeddings = model.encode(raw_data, batch_size=100) 92 | elif dataset_name == "2wikimultihopqa": 93 | train = json.load(open("../../../data/corpus/2wikimultihopqa/train.json", "r")) 94 | dev = json.load(open("../../../data/corpus/2wikimultihopqa/dev.json", "r")) 95 | test = json.load(open("../../../data/corpus/2wikimultihopqa/test.json", "r")) 96 | 97 | data = {} 98 | for item in tqdm(chain(train, dev, test)): 99 | for title, sentences in item["context"]: 100 | para = " ".join(sentences) 101 | data[para] = title 102 | contents = [ 103 | {"id": i, "content": text, "title": title} 104 | for i, (text, title) in enumerate(data.items()) 105 | ] 106 | elif dataset_name == "hotpotqa": 107 | import bz2 108 | from multiprocessing import Pool 109 | 110 | def process_line(line): 111 | data = json.loads(line) 112 | item = { 113 | "id": data["id"], 114 | "title": data["title"], 115 | "content": "".join(data["text"]), 116 | } 117 | return item 118 | 119 | def generate_indexing_queries_from_bz2(bz2file, dry=False): 120 | if dry: 121 | return 122 | 123 | with bz2.open(bz2file, "rt") as f: 124 | body = [process_line(line) for line in f] 125 | 126 | return body 127 | 128 | filelist = glob("../../../data/corpus/hotpotqa/*/wiki_*.bz2") 129 | 130 | print("Making indexing queries...") 131 | pool = Pool() 132 | 133 | for result in tqdm( 134 | pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist) 135 | ): 136 | contents.extend(result) 137 | elif dataset_name == "musique": 138 | train = [ 139 | json.loads(line.strip()) 140 | for line in open( 141 | "../../../data/corpus/musique/musique_ans_v1.0_train.jsonl" 142 | ) 143 | ] + [ 144 | json.loads(line.strip()) 145 | for line in open( 146 | "../../../data/corpus/musique/musique_full_v1.0_train.jsonl" 147 | ) 148 | ] 149 | dev = [ 150 | json.loads(line.strip()) 151 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_dev.jsonl") 152 | ] + [ 153 | json.loads(line.strip()) 154 | for line in open("../../../data/corpus/musique/musique_full_v1.0_dev.jsonl") 155 | ] 156 | test = [ 157 | json.loads(line.strip()) 158 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_test.jsonl") 159 | ] + [ 160 | json.loads(line.strip()) 161 | for line in open( 162 | 
"../../../data/corpus/musique/musique_full_v1.0_test.jsonl" 163 | ) 164 | ] 165 | 166 | tot = 0 167 | hist = set() 168 | for item in tqdm(chain(train, dev, test)): 169 | for p in item["paragraphs"]: 170 | stamp = p["title"] + " " + p["paragraph_text"] 171 | if not stamp in hist: 172 | contents.append( 173 | {"id": tot, "content": p["paragraph_text"], "title": p["title"]} 174 | ) 175 | hist.add(stamp) 176 | tot += 1 177 | 178 | contents = split_text(contents) 179 | embeddings = model.encode(contents, batch_size=600) 180 | with open( 181 | f"../../../data/corpus/{dataset_name}/chunk.json", "w", encoding="utf-8" 182 | ) as fout: 183 | json.dump(contents, fout, ensure_ascii=False) 184 | print("Building index ...") 185 | build_index(embeddings, vectorstore_path) 186 | end = time.time() 187 | -------------------------------------------------------------------------------- /src/build_index/es/index_2wiki.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from elasticsearch import Elasticsearch 3 | import html 4 | import json 5 | from tqdm import tqdm 6 | from itertools import chain 7 | 8 | 9 | def chunks(l, n): 10 | """Yield successive n-sized chunks from l.""" 11 | for i in range(0, len(l), n): 12 | yield l[i : i + n] 13 | 14 | 15 | INDEX_NAME = "2wikimultihopqa" 16 | 17 | 18 | def process_line(data): 19 | item = { 20 | "id": data["id"], 21 | "url": "empty", 22 | "title": data["title"], 23 | "title_unescape": html.unescape(data["title"]), 24 | "text": data["text"], 25 | "title_bigram": html.unescape(data["title"]), 26 | "title_unescape_bigram": html.unescape(data["title"]), 27 | "text_bigram": data["text"], 28 | "original_json": json.dumps(data), 29 | } 30 | return "{}\n{}".format( 31 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item) 32 | ) 33 | 34 | 35 | es = Elasticsearch(hosts="http://localhost:9200", timeout=100) 36 | 37 | 38 | def index_chunk(chunk): 39 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s") 40 | assert not res["errors"], res 41 | 42 | 43 | def main(args): 44 | 45 | train = json.load(open("../../../data/corpus/2wikimultihopqa/train.json", "r")) 46 | dev = json.load(open("../../../data/corpus/2wikimultihopqa/dev.json", "r")) 47 | test = json.load(open("../../../data/corpus/2wikimultihopqa/test.json", "r")) 48 | 49 | data = {} 50 | for item in tqdm(chain(train, dev, test)): 51 | for title, sentences in item["context"]: 52 | para = " ".join(sentences) 53 | data[para] = title 54 | wikipedia_data = [ 55 | {"id": i, "text": text, "title": title} 56 | for i, (text, title) in enumerate(data.items()) 57 | ] 58 | 59 | # make index 60 | if not args.dry: 61 | # add 62 | if es.indices.exists(index=INDEX_NAME): 63 | es.indices.delete(index=INDEX_NAME, ignore=[400, 403]) 64 | es.indices.create( 65 | index=INDEX_NAME, 66 | ignore=400, 67 | body=json.dumps( 68 | { 69 | "mappings": { 70 | "doc": { 71 | "properties": { 72 | "id": {"type": "keyword"}, 73 | "url": {"type": "keyword"}, 74 | "title": { 75 | "type": "text", 76 | "analyzer": "simple", 77 | "copy_to": "title_all", 78 | }, 79 | "title_unescape": { 80 | "type": "text", 81 | "analyzer": "simple", 82 | "copy_to": "title_all", 83 | }, 84 | "text": { 85 | "type": "text", 86 | "analyzer": "my_english_analyzer", 87 | }, 88 | "anchortext": { 89 | "type": "text", 90 | "analyzer": "my_english_analyzer", 91 | }, 92 | "title_bigram": { 93 | "type": "text", 94 | "analyzer": "simple_bigram_analyzer", 95 | "copy_to": 
"title_all_bigram", 96 | }, 97 | "title_unescape_bigram": { 98 | "type": "text", 99 | "analyzer": "simple_bigram_analyzer", 100 | "copy_to": "title_all_bigram", 101 | }, 102 | "text_bigram": { 103 | "type": "text", 104 | "analyzer": "bigram_analyzer", 105 | }, 106 | "anchortext_bigram": { 107 | "type": "text", 108 | "analyzer": "bigram_analyzer", 109 | }, 110 | "original_json": {"type": "string"}, 111 | } 112 | } 113 | }, 114 | "settings": { 115 | "analysis": { 116 | "my_english_analyzer": { 117 | "type": "standard", 118 | "stopwords": "_english_", 119 | }, 120 | "simple_bigram_analyzer": { 121 | "tokenizer": "standard", 122 | "filter": ["lowercase", "shingle", "asciifolding"], 123 | }, 124 | "bigram_analyzer": { 125 | "tokenizer": "standard", 126 | "filter": [ 127 | "lowercase", 128 | "stop", 129 | "shingle", 130 | "asciifolding", 131 | ], 132 | }, 133 | }, 134 | }, 135 | } 136 | ), 137 | ) 138 | 139 | print("Making indexing queries...") 140 | all_queries = [] 141 | for item in tqdm(wikipedia_data): 142 | all_queries.append(process_line(item)) 143 | 144 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2 145 | 146 | if not args.dry: 147 | print("Indexing...") 148 | chunksize = 100 149 | for chunk in tqdm( 150 | chunks(all_queries, chunksize), 151 | total=(len(all_queries) + chunksize - 1) // chunksize, 152 | ): 153 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s") 154 | assert not res["errors"], res 155 | 156 | print(f"{count} documents indexed in total") 157 | 158 | 159 | if __name__ == "__main__": 160 | parser = ArgumentParser() 161 | 162 | parser.add_argument("--reindex", action="store_true", help="Reindex everything") 163 | parser.add_argument("--dry", action="store_true", help="Dry run") 164 | 165 | args = parser.parse_args() 166 | 167 | main(args) 168 | -------------------------------------------------------------------------------- /src/build_index/es/index_hotpotqa.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | import bz2 3 | from elasticsearch import Elasticsearch 4 | from glob import glob 5 | import html 6 | import json 7 | from multiprocessing import Pool 8 | from tqdm import tqdm 9 | 10 | WIKIPEDIA_INDEX_NAME = "hotpotqa" 11 | 12 | 13 | def chunks(l, n): 14 | for i in range(0, len(l), n): 15 | yield l[i : i + n] 16 | 17 | 18 | def process_line(line): 19 | data = json.loads(line) 20 | item = { 21 | "id": data["id"], 22 | "url": data["url"], 23 | "title": data["title"], 24 | "title_unescape": html.unescape(data["title"]), 25 | "text": "".join(data["text"]), 26 | "title_bigram": html.unescape(data["title"]), 27 | "title_unescape_bigram": html.unescape(data["title"]), 28 | "text_bigram": "".join(data["text"]), 29 | "original_json": line, 30 | } 31 | return "{}\n{}".format( 32 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item) 33 | ) 34 | 35 | 36 | def generate_indexing_queries_from_bz2(bz2file, dry=False): 37 | if dry: 38 | return 39 | 40 | with bz2.open(bz2file, "rt") as f: 41 | body = [process_line(line) for line in f] 42 | 43 | return "\n".join(body) 44 | 45 | 46 | es = Elasticsearch(hosts="http://localhost:9200", timeout=50) 47 | 48 | 49 | def index_chunk(chunk): 50 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, body="\n".join(chunk), timeout="100s") 51 | assert not res["errors"], res 52 | 53 | 54 | def main(args): 55 | if not args.dry: 56 | if es.indices.exists(index=WIKIPEDIA_INDEX_NAME): 57 | 
es.indices.delete(index=WIKIPEDIA_INDEX_NAME, ignore=[400, 403]) 58 | if not es.indices.exists(index=WIKIPEDIA_INDEX_NAME): 59 | es.indices.create( 60 | index=WIKIPEDIA_INDEX_NAME, 61 | ignore=400, 62 | body=json.dumps( 63 | { 64 | "mappings": { 65 | "doc": { 66 | "properties": { 67 | "id": {"type": "keyword"}, 68 | "url": {"type": "keyword"}, 69 | "title": { 70 | "type": "text", 71 | "analyzer": "simple", 72 | "copy_to": "title_all", 73 | }, 74 | "title_unescape": { 75 | "type": "text", 76 | "analyzer": "simple", 77 | "copy_to": "title_all", 78 | }, 79 | "text": { 80 | "type": "text", 81 | "analyzer": "my_english_analyzer", 82 | }, 83 | "anchortext": { 84 | "type": "text", 85 | "analyzer": "my_english_analyzer", 86 | }, 87 | "title_bigram": { 88 | "type": "text", 89 | "analyzer": "simple_bigram_analyzer", 90 | "copy_to": "title_all_bigram", 91 | }, 92 | "title_unescape_bigram": { 93 | "type": "text", 94 | "analyzer": "simple_bigram_analyzer", 95 | "copy_to": "title_all_bigram", 96 | }, 97 | "text_bigram": { 98 | "type": "text", 99 | "analyzer": "bigram_analyzer", 100 | }, 101 | "anchortext_bigram": { 102 | "type": "text", 103 | "analyzer": "bigram_analyzer", 104 | }, 105 | "original_json": {"type": "string"}, 106 | } 107 | } 108 | }, 109 | "settings": { 110 | "analysis": { 111 | "my_english_analyzer": { 112 | "type": "standard", 113 | "stopwords": "_english_", 114 | }, 115 | "simple_bigram_analyzer": { 116 | "tokenizer": "standard", 117 | "filter": ["lowercase", "shingle", "asciifolding"], 118 | }, 119 | "bigram_analyzer": { 120 | "tokenizer": "standard", 121 | "filter": [ 122 | "lowercase", 123 | "stop", 124 | "shingle", 125 | "asciifolding", 126 | ], 127 | }, 128 | }, 129 | }, 130 | } 131 | ), 132 | ) 133 | 134 | filelist = glob("../../../data/corpus/hotpotqa/*/wiki_*.bz2") 135 | 136 | print("Making indexing queries...") 137 | pool = Pool() 138 | all_queries = list( 139 | tqdm( 140 | pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist) 141 | ) 142 | ) 143 | 144 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2 145 | 146 | if not args.dry: 147 | print("Indexing...") 148 | chunksize = 50 149 | for chunk in tqdm( 150 | chunks(all_queries, chunksize), 151 | total=(len(all_queries) + chunksize - 1) // chunksize, 152 | ): 153 | res = es.bulk( 154 | index=WIKIPEDIA_INDEX_NAME, body="\n".join(chunk), timeout="100s" 155 | ) 156 | assert not res["errors"], res 157 | 158 | print(f"{count} documents indexed in total") 159 | 160 | 161 | if __name__ == "__main__": 162 | parser = ArgumentParser() 163 | 164 | parser.add_argument("--reindex", action="store_true", help="Reindex everything") 165 | parser.add_argument("--dry", action="store_true", help="Dry run") 166 | 167 | args = parser.parse_args() 168 | 169 | main(args) 170 | -------------------------------------------------------------------------------- /src/build_index/es/index_musique.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | from elasticsearch import Elasticsearch 3 | import html 4 | import json 5 | from tqdm import tqdm 6 | 7 | from itertools import chain 8 | 9 | 10 | def chunks(l, n): 11 | for i in range(0, len(l), n): 12 | yield l[i : i + n] 13 | 14 | 15 | INDEX_NAME = "musique" 16 | 17 | 18 | def process_line(data): 19 | item = { 20 | "id": data["id"], 21 | "url": "empty", 22 | "title": data["title"], 23 | "title_unescape": html.unescape(data["title"]), 24 | "text": data["text"], 25 | "title_bigram": 
html.unescape(data["title"]), 26 | "title_unescape_bigram": html.unescape(data["title"]), 27 | "text_bigram": data["text"], 28 | "original_json": json.dumps(data), 29 | } 30 | return "{}\n{}".format( 31 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item) 32 | ) 33 | 34 | 35 | es = Elasticsearch(hosts="http://localhost:9200", timeout=50) 36 | 37 | 38 | def index_chunk(chunk): 39 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s") 40 | assert not res["errors"], res 41 | 42 | 43 | def main(args): 44 | # make index 45 | if not args.dry: 46 | if es.indices.exists(index=INDEX_NAME): 47 | es.indices.delete(index=INDEX_NAME, ignore=[400, 403]) 48 | es.indices.create( 49 | index=INDEX_NAME, 50 | ignore=400, 51 | body=json.dumps( 52 | { 53 | "mappings": { 54 | "doc": { 55 | "properties": { 56 | "id": {"type": "keyword"}, 57 | "url": {"type": "keyword"}, 58 | "title": { 59 | "type": "text", 60 | "analyzer": "simple", 61 | "copy_to": "title_all", 62 | }, 63 | "title_unescape": { 64 | "type": "text", 65 | "analyzer": "simple", 66 | "copy_to": "title_all", 67 | }, 68 | "text": { 69 | "type": "text", 70 | "analyzer": "my_english_analyzer", 71 | }, 72 | "anchortext": { 73 | "type": "text", 74 | "analyzer": "my_english_analyzer", 75 | }, 76 | "title_bigram": { 77 | "type": "text", 78 | "analyzer": "simple_bigram_analyzer", 79 | "copy_to": "title_all_bigram", 80 | }, 81 | "title_unescape_bigram": { 82 | "type": "text", 83 | "analyzer": "simple_bigram_analyzer", 84 | "copy_to": "title_all_bigram", 85 | }, 86 | "text_bigram": { 87 | "type": "text", 88 | "analyzer": "bigram_analyzer", 89 | }, 90 | "anchortext_bigram": { 91 | "type": "text", 92 | "analyzer": "bigram_analyzer", 93 | }, 94 | "original_json": {"type": "string"}, 95 | } 96 | } 97 | }, 98 | "settings": { 99 | "analysis": { 100 | "my_english_analyzer": { 101 | "type": "standard", 102 | "stopwords": "_english_", 103 | }, 104 | "simple_bigram_analyzer": { 105 | "tokenizer": "standard", 106 | "filter": ["lowercase", "shingle", "asciifolding"], 107 | }, 108 | "bigram_analyzer": { 109 | "tokenizer": "standard", 110 | "filter": [ 111 | "lowercase", 112 | "stop", 113 | "shingle", 114 | "asciifolding", 115 | ], 116 | }, 117 | }, 118 | }, 119 | } 120 | ), 121 | ) 122 | 123 | train = [ 124 | json.loads(line.strip()) 125 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_train.jsonl") 126 | ] + [ 127 | json.loads(line.strip()) 128 | for line in open("../../../data/corpus/musique/musique_full_v1.0_train.jsonl") 129 | ] 130 | dev = [ 131 | json.loads(line.strip()) 132 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_dev.jsonl") 133 | ] + [ 134 | json.loads(line.strip()) 135 | for line in open("../../../data/corpus/musique/musique_full_v1.0_dev.jsonl") 136 | ] 137 | test = [ 138 | json.loads(line.strip()) 139 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_test.jsonl") 140 | ] + [ 141 | json.loads(line.strip()) 142 | for line in open("../../../data/corpus/musique/musique_full_v1.0_test.jsonl") 143 | ] 144 | 145 | tot = 0 146 | wikipedia_data = [] 147 | hist = set() 148 | for item in tqdm(chain(train, dev, test)): 149 | for p in item["paragraphs"]: 150 | stamp = p["title"] + " " + p["paragraph_text"] 151 | if not stamp in hist: 152 | wikipedia_data.append( 153 | {"id": tot, "text": p["paragraph_text"], "title": p["title"]} 154 | ) 155 | hist.add(stamp) 156 | tot += 1 157 | 158 | print("Making indexing queries...") 159 | all_queries = [] 160 | for item in 
tqdm(wikipedia_data): 161 | all_queries.append(process_line(item)) 162 | 163 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2 164 | 165 | if not args.dry: 166 | print("Indexing...") 167 | chunksize = 100 168 | for chunk in tqdm( 169 | chunks(all_queries, chunksize), 170 | total=(len(all_queries) + chunksize - 1) // chunksize, 171 | ): 172 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s") 173 | assert not res["errors"], res 174 | 175 | print(f"{count} documents indexed in total") 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = ArgumentParser() 180 | 181 | parser.add_argument("--reindex", action="store_true", help="Reindex everything") 182 | parser.add_argument("--dry", action="store_true", help="Dry run") 183 | 184 | args = parser.parse_args() 185 | 186 | main(args) 187 | -------------------------------------------------------------------------------- /src/es_retrieve.py: -------------------------------------------------------------------------------- 1 | import json 2 | from elasticsearch import Elasticsearch 3 | import json 4 | import re 5 | import yaml 6 | 7 | with open("../config/config.yaml", "r") as f: 8 | config = yaml.safe_load(f) 9 | 10 | core_title_matcher = re.compile("([^()]+[^\s()])(?:\s*\(.+\))?") 11 | core_title_filter = lambda x: ( 12 | core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x 13 | ) 14 | 15 | 16 | class ElasticSearch: 17 | def __init__(self, index_name): 18 | self.index_name = index_name 19 | self.client = Elasticsearch(config["es"]["url"]) 20 | 21 | def _extract_one(self, item, lazy=False): 22 | res = { 23 | k: item["_source"][k] 24 | for k in ["id", "url", "title", "text", "title_unescape"] 25 | } 26 | res["_score"] = item["_score"] 27 | return res 28 | 29 | def rerank_with_query(self, query, results): 30 | def score_boost(item, query): 31 | score = item["_score"] 32 | core_title = core_title_filter(item["title_unescape"]) 33 | if query.startswith("The ") or query.startswith("the "): 34 | query1 = query[4:] 35 | else: 36 | query1 = query 37 | if query == item["title_unescape"] or query1 == item["title_unescape"]: 38 | score *= 1.5 39 | elif ( 40 | query.lower() == item["title_unescape"].lower() 41 | or query1.lower() == item["title_unescape"].lower() 42 | ): 43 | score *= 1.2 44 | elif item["title"].lower() in query: 45 | score *= 1.1 46 | elif query == core_title or query1 == core_title: 47 | score *= 1.2 48 | elif ( 49 | query.lower() == core_title.lower() 50 | or query1.lower() == core_title.lower() 51 | ): 52 | score *= 1.1 53 | elif core_title.lower() in query.lower(): 54 | score *= 1.05 55 | 56 | item["_score"] = score 57 | return item 58 | 59 | return list( 60 | sorted( 61 | [score_boost(item, query) for item in results], 62 | key=lambda item: -item["_score"], 63 | ) 64 | ) 65 | 66 | def single_text_query(self, query, topn=10, lazy=False, rerank_topn=50): 67 | 68 | constructed_query = { 69 | "multi_match": { 70 | "query": query, 71 | "fields": [ 72 | "title^1.25", 73 | "title_unescape^1.25", 74 | "text", 75 | "title_bigram^1.25", 76 | "title_unescape_bigram^1.25", 77 | "text_bigram", 78 | ], 79 | } 80 | } 81 | res = self.client.search( 82 | index=self.index_name, 83 | body={"query": constructed_query, "timeout": "100s"}, 84 | size=max(topn, rerank_topn), 85 | request_timeout=100, 86 | ) 87 | 88 | res = [self._extract_one(x, lazy=lazy) for x in res["hits"]["hits"]] 89 | res = self.rerank_with_query(query, res)[:topn] 90 | res = [{"title": _["title"], "paragraph_text": 
_["text"]} for _ in res] 91 | return res 92 | 93 | def search(self, question, k=10): 94 | try: 95 | res = self.single_text_query(query=question, topn=k) 96 | return json.dumps(res, ensure_ascii=False) 97 | except Exception as err: 98 | print(Exception, err) 99 | raise 100 | 101 | 102 | def retrieve(index_name, query, topk): 103 | ES = ElasticSearch(index_name) 104 | result = ES.search(query, topk) 105 | return json.loads(result) 106 | 107 | 108 | if __name__ == "__main__": 109 | 110 | print(retrieve("musique", "Bilbo Baggins", 2)) 111 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | import string 3 | from collections import Counter 4 | import re 5 | import numpy as np 6 | 7 | 8 | def acc_choice(predictions, answers): 9 | num_correct = 0 10 | total = len(predictions) 11 | 12 | for pred, ans_list in zip(predictions, answers): 13 | pred = pred.lower() 14 | 15 | correct = False 16 | for ans in ans_list: 17 | ans = ans.lower() 18 | if re.search(rf"\b{ans}\b", pred): 19 | correct = True 20 | break 21 | if correct: 22 | num_correct += 1 23 | 24 | acc = round(100 * num_correct / total, 2) 25 | return acc 26 | 27 | 28 | def acc_score(predictions, answers): 29 | num_correct = 0 30 | for id, answer in enumerate(answers): 31 | pred = predictions[id] 32 | correctness = ( 33 | "True" if any(ans.lower() in pred.lower() for ans in answer) else "False" 34 | ) 35 | if correctness == "True": 36 | num_correct += 1 37 | else: 38 | pass 39 | acc = num_correct / len(answers) 40 | return round(100 * acc, 2) 41 | 42 | 43 | def normalize_answer(s): 44 | """Lower text and remove punctuation, articles and extra whitespace.""" 45 | 46 | def remove_articles(text): 47 | return re.sub(r"\b(a|an|the)\b", " ", text) 48 | 49 | def white_space_fix(text): 50 | return " ".join(text.split()) 51 | 52 | def remove_punc(text): 53 | exclude = set(string.punctuation) 54 | return "".join(ch for ch in text if ch not in exclude) 55 | 56 | def lower(text): 57 | return text.lower() 58 | 59 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 60 | 61 | 62 | def f1_score(prediction, ground_truth): 63 | common = Counter(prediction) & Counter(ground_truth) 64 | num_same = sum(common.values()) 65 | if num_same == 0: 66 | return 0 67 | precision = 1.0 * num_same / len(prediction) 68 | recall = 1.0 * num_same / len(ground_truth) 69 | f1 = (2 * precision * recall) / (precision + recall) 70 | return f1 71 | 72 | 73 | def qa_f1_score(prediction, ground_truth): 74 | normalized_prediction = normalize_answer(prediction) 75 | normalized_ground_truth = normalize_answer(ground_truth) 76 | 77 | prediction_tokens = normalized_prediction.split() 78 | ground_truth_tokens = normalized_ground_truth.split() 79 | return f1_score(prediction_tokens, ground_truth_tokens) 80 | 81 | 82 | def F1_scorer(predictions, answers): 83 | total_score = 0.0 84 | for prediction, ground_truths in zip(predictions, answers): 85 | score = 0.0 86 | for ground_truth in ground_truths: 87 | score = max(score, qa_f1_score(prediction, ground_truth)) 88 | total_score += score 89 | return round(100 * total_score / len(predictions), 2) 90 | 91 | 92 | def compute_exact(predictions, answers): 93 | total_score = 0.0 94 | for prediction, ground_truths in zip(predictions, answers): 95 | score = 0.0 96 | for ground_truth in ground_truths: 97 | score = max( 98 | score, 99 | int(normalize_answer(prediction) == 
normalize_answer(ground_truth)), 100 | ) 101 | total_score += score 102 | return round(100 * total_score / len(predictions), 2) 103 | 104 | 105 | def exact_presence(short_answers, context): 106 | """Verify if any of the answers is present in the given context. 107 | Args: 108 | short_answers: list of short answers to look for in the context 109 | context: a paragraph to search for short answers 110 | Returns: 111 | true if any of the short answers is present in the context 112 | """ 113 | 114 | n_short_answers = [normalize_answer(sa) for sa in short_answers] 115 | n_context = normalize_answer(context) 116 | 117 | for ans in n_short_answers: 118 | if ans in n_context: 119 | return True 120 | 121 | return False 122 | 123 | 124 | def compute_str_em(data): 125 | """Compute STR-EM metric (only for ASQA) 126 | Args: 127 | data: requires field `qa_pairs/short_answers` and `output` 128 | Returns: 129 | STR-EM and STR-EM-HIT () 130 | """ 131 | 132 | acc = [] 133 | hit = [] 134 | 135 | for item in data: 136 | loc_acc = [] 137 | for qa_pair in item["qa_pairs"]: 138 | loc_acc.append(exact_presence(qa_pair["short_answers"], item["output"])) 139 | acc.append(np.mean(loc_acc)) 140 | hit.append(int(np.mean(loc_acc) == 1)) 141 | 142 | return 100 * np.mean(acc), 100 * np.mean(hit) 143 | 144 | 145 | def eval_asqa(pred, raw_data_path="../data/eval/asqa/test.json"): 146 | with open(raw_data_path, "r") as f: 147 | raw_data = json.load(f) 148 | normalized_data = [] 149 | for i in range(len(pred)): 150 | pred[i]["answer"] = pred[i]["output"].strip().replace("<|im_end|>", "") 151 | result = {} 152 | for id, data in enumerate(pred): 153 | normalized_data.append( 154 | { 155 | "question": raw_data[id]["question"], 156 | "qa_pairs": raw_data[id]["qa_pairs"], 157 | "output": data["answer"], 158 | } 159 | ) 160 | result["str_em"], result["str_hit"] = compute_str_em(normalized_data) 161 | 162 | return result 163 | 164 | 165 | if __name__ == "__main__": 166 | 167 | from argparse import ArgumentParser 168 | 169 | parser = ArgumentParser() 170 | parser.add_argument( 171 | "--dataset", 172 | type=str, 173 | choices=["2wikimultihopqa", "hotpotqa", "musique", "asqa", "strategyqa"], 174 | default="hotpotqa", 175 | help="Dataset to use", 176 | ) 177 | parser.add_argument("--pred_path", type=str, help="Location of the prediction file") 178 | args = parser.parse_args() 179 | 180 | outputs = [] 181 | with open(args.pred_path, "r") as fin: 182 | for d in fin: 183 | pred = json.loads(d) 184 | if "id" in pred.keys(): 185 | outputs.append(pred) 186 | if "asqa" in args.pred_path: 187 | result = eval_asqa(outputs) 188 | print("eval result:", result) 189 | else: 190 | predictions = [data["output"] for data in outputs] 191 | answers = [data["answer"] for data in outputs] 192 | if "strategyqa" in args.dataset: 193 | acc = acc_choice(predictions, answers) 194 | print("Acc:", acc) 195 | else: 196 | acc = acc_score(predictions, answers) 197 | f1 = F1_scorer(predictions, answers) 198 | em = compute_exact(predictions, answers) 199 | 200 | print("Acc:", acc, "F1:", f1, "EM:", em) 201 | -------------------------------------------------------------------------------- /src/gen_dpo_data.py: -------------------------------------------------------------------------------- 1 | import random 2 | import json 3 | import os 4 | import datetime 5 | from tqdm import tqdm 6 | from torch.utils.data import Dataset, DataLoader 7 | from vllm import LLM, SamplingParams 8 | import argparse 9 | from es_retrieve import retrieve 10 | import yaml 11 | import 
multiprocessing as mp 12 | from utils import ( 13 | p_template, 14 | seed_everything, 15 | LLM_score_gen_new_query, 16 | LLM_score_compare_note, 17 | LLM_score_init_note, 18 | LLM_score_rag, 19 | ) 20 | 21 | with open("../config/config.yaml", "r") as f: 22 | config = yaml.safe_load(f) 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument( 26 | "--dataset", 27 | type=str, 28 | choices=["2wikimultihopqa"], 29 | default="2wikimultihopqa", 30 | help="Name of the dataset to use", 31 | ) 32 | parser.add_argument( 33 | "--input_data_path", 34 | type=str, 35 | default="../data/corpus/2wikimultihopqa/train.json", 36 | help="Path to the input data file in JSON/JSONL format", 37 | ) 38 | parser.add_argument( 39 | "--output_path", 40 | type=str, 41 | default="../data/dpo_data", 42 | help="Output directory path for generated data", 43 | ) 44 | parser.add_argument( 45 | "--batch_size", type=int, default=9, help="Batch size for data processing" 46 | ) 47 | parser.add_argument( 48 | "--num_samples", 49 | type=int, 50 | default=15000, 51 | help="Number of samples to generate for each type.", 52 | ) 53 | parser.add_argument( 54 | "--per_data_gen_num", 55 | type=int, 56 | default=9, 57 | help="The number of times to generate data for each sample.", 58 | ) 59 | parser.add_argument( 60 | "--model", 61 | type=str, 62 | choices=["llama-3.1-8b-instruct", "qwen2.5-7b-instruct"], 63 | default="llama-3.1-8b-instruct", 64 | help="Base model used for data generation", 65 | ) 66 | parser.add_argument( 67 | "--score_model", 68 | type=str, 69 | choices=["gpt-4o-mini-2024-07-18"], 70 | default="gpt-4o-mini-2024-07-18", 71 | help="Model used for scoring", 72 | ) 73 | parser.add_argument( 74 | "--gpu_memory_utilization", 75 | type=float, 76 | default=0.9, 77 | help="GPU memory utilization limit", 78 | ) 79 | parser.add_argument( 80 | "--device", 81 | type=str, 82 | default="0,1,2,3,4,5,6,7", 83 | help="Comma separated list of devices to use for parallel processing", 84 | ) 85 | args = parser.parse_args() 86 | 87 | 88 | def load_and_sample_data(file_path, sample_size): 89 | with open(file_path, "r", encoding="utf-8") as f: 90 | data = ( 91 | [json.loads(line.strip()) for line in f] 92 | if "jsonl" in file_path 93 | else json.load(f) 94 | ) 95 | return random.sample(data, sample_size) 96 | 97 | 98 | def retrieve_q(question, top_k): 99 | refs = retrieve(args.dataset, question, topk=top_k) 100 | text = "" 101 | for ref in refs: 102 | text += f"Title: {ref['title']}\nText: {ref['paragraph_text']} \n\n" 103 | return text 104 | 105 | 106 | class CustomDataset(Dataset): 107 | def __init__(self, data_list, args): 108 | self.data_list = data_list 109 | self.args = args 110 | 111 | def __getitem__(self, index): 112 | item = self.data_list[index] 113 | item["id"] = index 114 | return item 115 | 116 | def __len__(self): 117 | return len(self.data_list) 118 | 119 | def collate_fn(self, batch): 120 | batch = [data for data in batch] 121 | if not batch: 122 | return None 123 | ids = [f["id"] for f in batch] 124 | questions = [f.get("question", None) for f in batch] 125 | answers = [f.get("answer", None) for f in batch] 126 | return { 127 | "ids": ids, 128 | "questions": questions, 129 | "answers": answers, 130 | } 131 | 132 | 133 | def gen_data(split, device_id): 134 | os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) 135 | llm = LLM( 136 | model=model_path, 137 | tensor_parallel_size=1, 138 | trust_remote_code=True, 139 | dtype="bfloat16", 140 | gpu_memory_utilization=args.gpu_memory_utilization, 141 | ) 142 | 
sampling_params = SamplingParams(max_tokens=1024, temperature=1, top_p=0.9) 143 | 144 | def call_llm(prompts, params=[sampling_params]): 145 | if "llama" in args.model: 146 | model_template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 147 | elif "qwen" in args.model or "minicpm" in args.model: 148 | model_template = ( 149 | "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" 150 | ) 151 | prompts = [model_template.format(prompt=p) for p in prompts] 152 | 153 | responses = [ 154 | response.outputs[0].text for response in llm.generate(prompts, params) 155 | ] 156 | return responses 157 | 158 | temperature_list = [0.1, 0.5, 0.9] 159 | top_p_list = [0.1, 0.5, 0.9] 160 | param_combinations = [ 161 | (temp, top_p) for temp in temperature_list for top_p in top_p_list 162 | ] 163 | param_combinations = ( 164 | param_combinations[: args.per_data_gen_num] 165 | if args.per_data_gen_num <= len(param_combinations) 166 | else param_combinations 167 | + random.choices( 168 | param_combinations, k=args.per_data_gen_num - len(param_combinations) 169 | ) 170 | ) 171 | 172 | params_dict = {"max_tokens": 1024} 173 | 174 | for batch in tqdm(split): 175 | if not batch: 176 | continue 177 | 178 | batch.update({"prompts": []}) 179 | ( 180 | batch_temp_rag, 181 | batch_temp_gen_query, 182 | batch_temp_init_note, 183 | batch_temp_refine_note, 184 | ) = ([], [], [], []) 185 | # Generate rag 186 | for id, question in enumerate(batch["questions"]): 187 | top_k = random.randint(3, 9) 188 | refs = retrieve_q(question, top_k) 189 | rag_prompt = p_template("base", {"query": question, "refs": refs}) 190 | init_note_prompt = p_template( 191 | "init_note", {"query": question, "refs": refs} 192 | ) 193 | rag_sampling_params = [ 194 | SamplingParams(**params_dict, temperature=temp, top_p=top_p) 195 | for temp, top_p in param_combinations 196 | ] 197 | rags_generated = call_llm( 198 | [rag_prompt] * len(rag_sampling_params), rag_sampling_params 199 | ) 200 | 201 | best_rag, worst_rag = LLM_score_rag(rags_generated, batch["answers"][id]) 202 | 203 | if best_rag and worst_rag: 204 | batch_temp_rag.append( 205 | { 206 | "id": f"rag-{batch['ids'][id]}", 207 | "raw_question": question, 208 | "prompt": rag_prompt, 209 | "chosen": best_rag, 210 | "rejected": worst_rag, 211 | "data_type": "rag", 212 | "gen_response_list": [ 213 | { 214 | "index": idx, 215 | "response": response, 216 | "temperature": param_combination[0], 217 | "top_p": param_combination[1], 218 | } 219 | for idx, (response, param_combination) in enumerate( 220 | zip(rags_generated, param_combinations) 221 | ) 222 | ], 223 | } 224 | ) 225 | 226 | # Generate init_note 227 | 228 | init_note_sampling_params = [ 229 | SamplingParams(**params_dict, temperature=temp, top_p=top_p) 230 | for temp, top_p in param_combinations 231 | ] 232 | init_notes_generated = call_llm( 233 | [init_note_prompt] * len(init_note_sampling_params), 234 | init_note_sampling_params, 235 | ) 236 | 237 | best_init_note, worst_init_note = LLM_score_init_note( 238 | args.score_model, question, refs, init_notes_generated 239 | ) 240 | if best_init_note and worst_init_note: 241 | batch_temp_init_note.append( 242 | { 243 | "id": f"init_note-{batch['ids'][id]}", 244 | "raw_question": question, 245 | "prompt": init_note_prompt, 246 | "chosen": best_init_note, 247 | "rejected": worst_init_note, 248 | "data_type": "init_note", 249 | "gen_response_list": [ 250 | { 251 | "index": idx, 252 | "response": response, 253 | 
"temperature": param_combination[0], 254 | "top_p": param_combination[1], 255 | } 256 | for idx, (response, param_combination) in enumerate( 257 | zip(init_notes_generated, param_combinations) 258 | ) 259 | ], 260 | } 261 | ) 262 | 263 | # Generate queries 264 | gen_query_prompt = p_template( 265 | "gen_new_query", 266 | {"query": question, "note": best_init_note, "query_log": []}, 267 | ) 268 | gen_query_sampling_params = [ 269 | SamplingParams(**params_dict, temperature=temp, top_p=top_p) 270 | for temp, top_p in param_combinations 271 | ] 272 | responses = call_llm( 273 | [gen_query_prompt] * len(gen_query_sampling_params), 274 | gen_query_sampling_params, 275 | ) 276 | 277 | gen_querys = responses 278 | 279 | best_query, worst_query = LLM_score_gen_new_query( 280 | args.score_model, question, best_init_note, gen_querys, [] 281 | ) 282 | 283 | refs = retrieve_q((best_query + "\n" + question)[:500], top_k) 284 | refine_note_prompt = p_template( 285 | "refine_note", {"query": question, "refs": refs, "note": best_init_note} 286 | ) 287 | batch["prompts"].append(refine_note_prompt) 288 | if best_query and worst_query: 289 | batch_temp_gen_query.append( 290 | { 291 | "id": f"new_query-{batch['ids'][id]}", 292 | "raw_question": question, 293 | "prompt": gen_query_prompt, 294 | "chosen": best_query, 295 | "rejected": worst_query, 296 | "data_type": "gen_new_query", 297 | "gen_response_list": [ 298 | { 299 | "index": idx, 300 | "response": gen_query, 301 | "temperature": param_combination[0], 302 | "top_p": param_combination[1], 303 | } 304 | for idx, (gen_query, param_combination) in enumerate( 305 | zip(gen_querys, param_combinations) 306 | ) 307 | ], 308 | } 309 | ) 310 | 311 | # Generate refined notes 312 | for temp, top_p in param_combinations: 313 | sampling_params = SamplingParams( 314 | **params_dict, temperature=temp, top_p=top_p 315 | ) 316 | prompts = batch["prompts"] 317 | 318 | responses = [call_llm(prompt, temp, top_p) for prompt in prompts] 319 | 320 | for idx, response in enumerate(responses): 321 | if len(batch_temp_refine_note) <= idx: 322 | batch_temp_refine_note.append( 323 | { 324 | "id": f"refine_note-{batch['ids'][idx]}", 325 | "raw_question": batch["questions"][idx], 326 | "prompt": prompts[idx], 327 | "data_type": "refine_note", 328 | "gen_response_list": [], 329 | } 330 | ) 331 | score = LLM_score_compare_note( 332 | args.score_model, batch["questions"][idx], best_init_note, response 333 | ) 334 | batch_temp_refine_note[idx]["gen_response_list"].append( 335 | { 336 | "index": len(batch_temp_refine_note[idx]["gen_response_list"]), 337 | "response": response, 338 | "score_flag": score, 339 | "temperature": temp, 340 | "top_p": top_p, 341 | } 342 | ) 343 | 344 | for refine_note in batch_temp_refine_note: 345 | chosen_list = [ 346 | item["response"] 347 | for item in refine_note["gen_response_list"] 348 | if item["score_flag"] 349 | ] 350 | rejected_list = [ 351 | item["response"] 352 | for item in refine_note["gen_response_list"] 353 | if not item["score_flag"] 354 | ] 355 | refine_note["chosen"] = random.choice(chosen_list) if chosen_list else "" 356 | refine_note["rejected"] = ( 357 | random.choice(rejected_list) if rejected_list else "" 358 | ) 359 | batch_temp_refine_note = [ 360 | refine_note 361 | for refine_note in batch_temp_refine_note 362 | if refine_note["chosen"] and refine_note["rejected"] 363 | ] 364 | batch_temp = ( 365 | batch_temp_init_note 366 | + batch_temp_gen_query 367 | + batch_temp_refine_note 368 | + batch_temp_rag 369 | ) 370 | if batch_temp: 371 
| with open(output_file, "a", encoding="utf-8") as f: 372 | json_lines = [ 373 | json.dumps(data, ensure_ascii=False) for data in batch_temp 374 | ] 375 | f.write("\n".join(json_lines) + "\n") 376 | 377 | 378 | def main(): 379 | 380 | visible_devices = args.device.split(",") 381 | num_devices = len(visible_devices) 382 | 383 | data = load_and_sample_data(args.input_data_path, args.num_samples) 384 | 385 | split_data = lambda lst, n: [lst[i::n] for i in range(n)] 386 | 387 | dataset = CustomDataset(data, args) 388 | dataloader = DataLoader( 389 | dataset, batch_size=args.batch_size, collate_fn=dataset.collate_fn 390 | ) 391 | 392 | splits = split_data(list(dataloader), num_devices) 393 | processes = [] 394 | 395 | for rank, device_id in enumerate(visible_devices): 396 | split_dataloader = splits[rank] 397 | p = mp.Process(target=gen_data, args=(list(split_dataloader), device_id)) 398 | p.start() 399 | processes.append(p) 400 | 401 | for p in processes: 402 | p.join() 403 | 404 | 405 | if __name__ == "__main__": 406 | current_date = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S") 407 | seed_everything(66) 408 | model_path = config["model"][args.model] 409 | output_file = os.path.join( 410 | args.output_path, 411 | f"dpo_data-{args.model}-score_{args.score_model}_num-{args.num_samples}_para-{args.per_data_gen_num}-{current_date}.jsonl", 412 | ) 413 | main() 414 | -------------------------------------------------------------------------------- /src/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import argparse 4 | import logging 5 | import datetime 6 | import yaml 7 | import faiss 8 | import pandas as pd 9 | from sentence_transformers import SentenceTransformer 10 | import backoff 11 | from multiprocessing.pool import ThreadPool 12 | from tqdm import tqdm 13 | from eval import acc_score, F1_scorer, compute_exact, eval_asqa, acc_choice 14 | from utils import seed_everything 15 | from vllm import LLM, SamplingParams 16 | 17 | parser = argparse.ArgumentParser() 18 | parser.add_argument( 19 | "--model", 20 | type=str, 21 | choices=[ 22 | "gpt-4o-mini-2024-07-18", 23 | "Llama-3-1-70B-Instruct", 24 | "llama-3.1-8b-instruct", 25 | "llama-3.1-8b-instruct-dpo", 26 | "qwen2.5-7b-instruct", 27 | "qwen2.5-7b-instruct-dpo", 28 | ], 29 | default="llama-3.1-8b-instruct", 30 | help="Model to use", 31 | ) 32 | parser.add_argument( 33 | "--max_step", type=int, default=3, help="Maximum number of update steps" 34 | ) 35 | parser.add_argument( 36 | "--max_fail_step", type=int, default=2, help="Maximum number of failed steps" 37 | ) 38 | parser.add_argument( 39 | "--MaxClients", type=int, default=1, help="Number of concurrent clients" 40 | ) 41 | parser.add_argument( 42 | "--retrieve_top_k", 43 | type=int, 44 | default=5, 45 | help="Number of documents to retrieve per query", 46 | ) 47 | parser.add_argument( 48 | "--max_top_k", 49 | type=int, 50 | default=15, 51 | help="Total maximum number of documents to retrieve", 52 | ) 53 | parser.add_argument( 54 | "--dataset", 55 | type=str, 56 | choices=[ 57 | "2wikimultihopqa", 58 | "hotpotqa", 59 | "musique", 60 | "asqa", 61 | "strategyqa", 62 | ], 63 | default="hotpotqa", 64 | help="Dataset to use", 65 | ) 66 | parser.add_argument( 67 | "--method", 68 | type=str, 69 | default="deepnote", 70 | choices=["deepnote", "base", "base_wo_retri"], 71 | help="Method to use", 72 | ) 73 | parser.add_argument( 74 | "--resume_path", 75 | type=str, 76 | default="", 77 | help="Path to the file for resuming 
generation", 78 | ) 79 | parser.add_argument( 80 | "--retrieve_method", 81 | type=str, 82 | default="es", 83 | help="Retrieval method to use (es: ElasticSearch, emb: Dense Retrieval)", 84 | ) 85 | parser.add_argument( 86 | "--device", type=str, default="cuda:0", help="Device to run inference on" 87 | ) 88 | args = parser.parse_args() 89 | 90 | with open("../config/config.yaml", "r") as f: 91 | config = yaml.safe_load(f) 92 | 93 | 94 | @backoff.on_exception(backoff.expo, (Exception), max_time=500) 95 | def call_gpt(prompt_file, variable_dict): 96 | with open(prompt_file, "r") as fin: 97 | prompt = fin.read().format(**variable_dict) 98 | res = gpt_gen(args.model, prompt) 99 | assert res is not None 100 | return res 101 | 102 | 103 | def call_local(prompt_file, variable_dict): 104 | 105 | with open(prompt_file, "r") as fin: 106 | prompt = fin.read() 107 | if "llama" in args.model: 108 | model_template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" 109 | if "qwen" in args.model: 110 | model_template = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" 111 | prompt = model_template.format(prompt=prompt.format(**variable_dict)) 112 | response = llm.generate(prompt, sampling_params)[0].outputs[0].text 113 | return response 114 | 115 | 116 | if "gpt" in args.model: 117 | from utils import gpt_gen 118 | 119 | call_llm = call_gpt 120 | 121 | else: 122 | llm = LLM( 123 | model=config["model"][args.model], 124 | tensor_parallel_size=1, 125 | trust_remote_code=True, 126 | dtype="bfloat16", 127 | gpu_memory_utilization=0.9, 128 | ) 129 | sampling_params = SamplingParams(max_tokens=1280, temperature=0.1, top_p=0.9) 130 | 131 | call_llm = call_local 132 | 133 | 134 | def get_context(data): 135 | 136 | if retrieve_method == "emb": 137 | text = "\n".join(data) 138 | else: 139 | text = "" 140 | for i in range(len(data)): 141 | text += f"Title: {data[i]['title']}\nText: {data[i]['paragraph_text']} \n\n" 142 | return text 143 | 144 | 145 | def save_log_to_file(logger, log_file="my_log", log_folder="logs"): 146 | if not os.path.exists(log_folder): 147 | os.makedirs(log_folder) 148 | current_date = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S") 149 | log_file_name = f"{log_file}-{current_date}.log" 150 | file_handler = logging.FileHandler(os.path.join(log_folder, log_file_name)) 151 | logger.addHandler(file_handler) 152 | 153 | 154 | def is_directory_empty(directory_path: str) -> bool: 155 | try: 156 | return len(os.listdir(directory_path)) == 0 157 | except OSError: 158 | return False 159 | 160 | 161 | def call_llm_template(template, variables): 162 | return call_llm(f"../prompts/{LAUGUAGE}/{template}", variables) 163 | 164 | 165 | def init_note(query, refs): 166 | return call_llm_template("init_note", {"query": query, "refs": refs}) 167 | 168 | 169 | def gen_new_query(query, note, query_log): 170 | return call_llm_template( 171 | "gen_new_query", {"query": query, "note": note, "query_log": query_log} 172 | ) 173 | 174 | 175 | def refine_note(note, query, refs): 176 | return call_llm_template( 177 | "refine_note", {"note": note, "query": query, "refs": refs} 178 | ) 179 | 180 | 181 | def gen_answer(query, note_str): 182 | if args.dataset in ["asqa", "strategyqa"]: 183 | template = f"gen_answer_{args.dataset}" 184 | else: 185 | template = "gen_answer" 186 | return call_llm_template(template, {"query": query, "note": note_str}) 187 | 188 | 189 | def compare_note(query, best_note, new_note): 190 | response = call_llm_template( 
191 | "compare", {"query": query, "best_note": best_note, "new_note": new_note} 192 | ) 193 | 194 | try: 195 | if "json" in response: 196 | response = response[response.index("{") : response.rindex("}") + 1] 197 | status = json.loads(response)["status"] 198 | return status.lower() == "true" 199 | except: 200 | return "true" in response.lower() 201 | 202 | 203 | def retrieve_note(doc_id, query, answer, top_k=2): 204 | ref_log, llm_times, note_log, query_log, query_list = [], 0, [], [], [] 205 | refs = retrieve(args.dataset, query=query, topk=top_k) 206 | note = best_note = init_note(query, get_context(refs)) 207 | ref_log.append({"refs": refs, "step": 0, "flag": "init_refs"}) 208 | note_log.append({"note": note, "step": 0, "flag": "init_note"}) 209 | llm_times += 1 210 | step, notes_status = 0, [] 211 | 212 | while step < args.max_step: 213 | all_refs = { 214 | ref 215 | for d in ref_log 216 | for ref in ( 217 | d["refs"] 218 | if retrieve_method == "emb" 219 | else [d["title"] + d["paragraph_text"] for d in d["refs"]] 220 | ) 221 | } 222 | if len(all_refs) > args.max_top_k: 223 | break 224 | max_ref = ( 225 | args.max_top_k - len(all_refs) 226 | if (len(all_refs) + top_k) > args.max_top_k 227 | else 0 228 | ) 229 | 230 | new_query = gen_new_query(query, best_note, str(query_list)) 231 | query_list.append(new_query) 232 | llm_times += 1 233 | 234 | refs = retrieve( 235 | args.dataset, query=(new_query + "\n" + query)[:500], topk=top_k 236 | ) 237 | if max_ref > 0: 238 | refs = ( 239 | [d for d in refs if d not in all_refs][:max_ref] 240 | if retrieve_method == "emb" 241 | else [ 242 | d for d in refs if d["title"] + d["paragraph_text"] not in all_refs 243 | ][:max_ref] 244 | ) 245 | 246 | note = refine_note(best_note, query, get_context(refs)).replace("\n", "") 247 | llm_times += 1 248 | status = compare_note(query, best_note, note) 249 | flag = "True" if status else "False" 250 | 251 | ref_log.append({"refs": refs, "step": step, "flag": flag}) 252 | note_log.append({"note": note, "step": step, "flag": flag}) 253 | query_log.append({"query": new_query, "step": step, "flag": flag}) 254 | 255 | if status: 256 | best_note = note 257 | notes_status.append(status) 258 | if notes_status.count(False) >= args.max_fail_step: 259 | break 260 | step += 1 261 | 262 | llm_times += 1 263 | return { 264 | "id": doc_id, 265 | "question": query, 266 | "answer": answer, 267 | "output": gen_answer(query, best_note), 268 | "deepnote": best_note, 269 | "query_log": query_log, 270 | "note_log": note_log, 271 | "ref_log": ref_log, 272 | } 273 | 274 | 275 | def process_doc_cell(idx, doc_cell, args): 276 | id_new, query, answer = idx, doc_cell["question"], doc_cell["answer"] 277 | 278 | if args.method != "deepnote": 279 | prompt_path = f"../prompts/{LAUGUAGE}" 280 | if args.dataset in ["asqa", "strategyqa"]: 281 | gen_name = f"{args.method}_{args.dataset}" 282 | else: 283 | gen_name = args.method 284 | output = ( 285 | call_llm(f"{prompt_path}/{gen_name}", {"query": query}) 286 | if "wo_retri" in args.method 287 | else call_llm( 288 | f"{prompt_path}/{gen_name}", 289 | { 290 | "query": query, 291 | "refs": get_context( 292 | retrieve(args.dataset, query=query, topk=args.retrieve_top_k) 293 | ), 294 | }, 295 | ) 296 | ) 297 | 298 | return {"id": id_new, "query": query, "answer": answer, "output": output} 299 | else: 300 | return retrieve_note(id_new, query, answer, top_k=args.retrieve_top_k) 301 | 302 | 303 | if __name__ == "__main__": 304 | 305 | LAUGUAGE = "en" 306 | seed_everything(66) 307 | 308 | logger = 
logging.getLogger(__name__) 309 | logging.basicConfig( 310 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 311 | datefmt="%m/%d/%Y %H:%M:%S", 312 | level=logging.INFO, 313 | ) 314 | save_log_to_file( 315 | logger, 316 | log_file=f"{args.dataset}_{args.method}_{args.model}", 317 | log_folder="../log", 318 | ) 319 | logger.info(f"{'*' * 30} CONFIGURATION {'*' * 30}") 320 | for key, val in sorted(vars(args).items()): 321 | keystr = "{}".format(key) + (" " * (30 - len(key))) 322 | logger.info("%s --> %s", keystr, val) 323 | dataset_name = args.dataset 324 | vector_path = f"../data/corpus/{dataset_name}/{dataset_name}.index" 325 | if args.dataset == "asqa" or args.dataset == "strategyqa": 326 | 327 | vector = faiss.read_index(f"../data/corpus/wiki/wiki.index") 328 | emb_model = SentenceTransformer(config["model"]["gtr-t5-xxl"], device="cpu") 329 | raw_data = pd.read_csv("../data/corpus/wiki/psgs_w100.tsv", sep="\t") 330 | 331 | def retrieve(_, query, topk): 332 | feature = emb_model.encode([query]) 333 | _, match_id = vector.search(feature, topk) 334 | return [ 335 | raw_data.iloc[i]["title"] + "\n" + raw_data.iloc[i]["text"] 336 | for i in match_id[0] 337 | ] 338 | 339 | elif args.retrieve_method == "emb": 340 | 341 | emb_model = SentenceTransformer( 342 | config["model"]["bge-base-en-v1.5"], device=args.device 343 | ) 344 | with open(f"../data/corpus/{args.dataset}/chunk.json", encoding="utf-8") as f: 345 | raw_data = json.load(f) 346 | vector = faiss.read_index(vector_path) 347 | 348 | def retrieve(_, query, topk): 349 | feature = emb_model.encode([query]) 350 | _, match_id = vector.search(feature, topk) 351 | return [raw_data[i] for i in match_id[0]] 352 | 353 | else: 354 | from es_retrieve import retrieve 355 | 356 | formatted_time = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S") 357 | 358 | with open(f"../data/eval/{args.dataset}/test.json", encoding="utf-8") as f: 359 | qa_data = json.load(f) 360 | 361 | retrieve_method = args.retrieve_method 362 | if args.dataset in ["asqa", "strategyqa"]: 363 | retrieve_method = "emb" 364 | 365 | save_path = f"../output/{args.dataset}/{retrieve_method}/{args.method}/{args.model}" 366 | os.makedirs(save_path, exist_ok=True) 367 | 368 | all_result = [] 369 | 370 | if args.resume_path: 371 | with open(args.resume_path, "r", encoding="utf-8") as fin: 372 | resume_data = [json.loads(i) for i in fin.readlines()] 373 | all_result = resume_data 374 | filepath = args.resume_path 375 | else: 376 | resume_data = [] 377 | filepath = ( 378 | f"{save_path}/topk-{args.retrieve_top_k}-{formatted_time}.jsonl" 379 | if args.method != "deepnote" 380 | else f"{save_path}/topk-{args.retrieve_top_k}__max_step-{args.max_step}__max_fail_step-{args.max_fail_step}-{formatted_time}.jsonl" 381 | ) 382 | logger.info(f"The predicted results will be saved in '{filepath}'.") 383 | last_id = len(resume_data) 384 | batch_size = args.MaxClients 385 | logger.info("start predicting ...") 386 | for i in tqdm(range(last_id, len(qa_data), batch_size)): 387 | pool = ThreadPool(processes=args.MaxClients) 388 | current_batch = qa_data[i : i + batch_size] 389 | tasks = [ 390 | (idx + i, doc_cell, args) for idx, doc_cell in enumerate(current_batch) 391 | ] 392 | 393 | results = pool.starmap(process_doc_cell, tasks) 394 | pool.close() 395 | pool.join() 396 | 397 | for result in results: 398 | if result: 399 | all_result.append(result) 400 | with open(filepath, "a", buffering=1) as fout: 401 | fout.write(json.dumps(result, ensure_ascii=False) + "\n") 402 | 403 | 
logger.info("start evaluating ...") 404 | 405 | predictions = [data["output"] for data in all_result] 406 | answers = [data["answer"] for data in all_result] 407 | 408 | if "asqa" in args.dataset: 409 | eval_result = eval_asqa(all_result) 410 | elif "strategyqa" in args.dataset: 411 | acc = acc_choice(predictions, answers) 412 | eval_result = {"Acc": acc} 413 | else: 414 | acc = acc_score(predictions, answers) 415 | f1 = F1_scorer(predictions, answers) 416 | em = compute_exact(predictions, answers) 417 | eval_result = {"Acc": acc, "F1": f1, "EM": em} 418 | 419 | if eval_result: 420 | with open(filepath, "a", buffering=1) as fout: 421 | fout.write(json.dumps(eval_result, ensure_ascii=False) + "\n") 422 | 423 | logger.info(f"eval result: {eval_result}") 424 | -------------------------------------------------------------------------------- /src/select_dpo_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import re 4 | import random 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--init_num", type=int, default=1900, help="Number of init_note samples" 9 | ) 10 | parser.add_argument( 11 | "--refine_num", type=int, default=1900, help="Number of refine_note samples" 12 | ) 13 | parser.add_argument( 14 | "--query_num", type=int, default=1900, help="Number of gen_query samples" 15 | ) 16 | parser.add_argument("--rag_num", type=int, default=300, help="Number of RAG samples") 17 | parser.add_argument( 18 | "--data_path", type=str, required=True, help="Path to input data file" 19 | ) 20 | parser.add_argument( 21 | "--output_path", 22 | type=str, 23 | default="../data/dpo/processed/train.jsonl", 24 | help="Path to output file", 25 | ) 26 | args = parser.parse_args() 27 | 28 | with open(args.data_path, "r", encoding="utf-8") as file: 29 | data = [json.loads(i) for i in file] 30 | random.shuffle(data) 31 | 32 | 33 | def extract_and_match(input): 34 | match = re.search(r"-(\d+)", input) 35 | return match.group(1) 36 | 37 | 38 | # Collect samples for each category 39 | init_data, refine_data, query_data, rag_data = [], [], [], [] 40 | for d in data: 41 | if not (d["chosen"] and d["rejected"]): 42 | continue 43 | 44 | fields = {"prompt": d["prompt"], "chosen": d["chosen"], "rejected": d["rejected"]} 45 | 46 | if "init" in d["id"] and len(init_data) < args.init_num: 47 | init_data.append(fields) 48 | elif "refine" in d["id"] and len(refine_data) < args.refine_num: 49 | refine_data.append(fields) 50 | elif "query" in d["id"] and len(query_data) < args.query_num: 51 | query_data.append(fields) 52 | elif "rag" in d["id"] and len(rag_data) < args.rag_num: 53 | rag_data.append(fields) 54 | 55 | # Combine and save processed data 56 | train_data = init_data + refine_data + query_data + rag_data 57 | random.shuffle(train_data) 58 | 59 | with open(args.output_path, "w") as fout: 60 | for res in train_data: 61 | fout.write(json.dumps(res, ensure_ascii=False) + "\n") 62 | print( 63 | f"Init: {len(init_data)}, Refine: {len(refine_data)}, Query: {len(query_data)}, RAG: {len(rag_data)}" 64 | ) 65 | print(f"Total training samples: {len(train_data)}") 66 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from transformers import AutoModelForCausalLM, AutoTokenizer 6 | from functools import partial 7 | from trl 
import ( 8 | DPOConfig, 9 | DPOTrainer, 10 | ModelConfig, 11 | ScriptArguments, 12 | TrlParser, 13 | get_kbit_device_map, 14 | get_peft_config, 15 | get_quantization_config, 16 | ) 17 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE 18 | 19 | 20 | def preprocessing(example, args, tokenizer): 21 | prompt = example["prompt"] 22 | if "llama" in model_args.model_name_or_path.lower(): 23 | template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n" 24 | elif "qwen" in model_args.model_name_or_path.lower(): 25 | template = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" 26 | 27 | # Ensure the prompt does not exceed the max prompt length 28 | token_prompt = tokenizer(prompt, truncation=True, max_length=args.max_prompt_length) 29 | prompt = tokenizer.decode(token_prompt.input_ids) 30 | prompt = template.format(prompt=prompt) 31 | 32 | chosen = example["chosen"] 33 | rejected = example["rejected"] 34 | 35 | one_item = {"prompt": prompt, "chosen": chosen, "rejected": rejected} 36 | return one_item 37 | 38 | 39 | def main(script_args, training_args, model_args): 40 | ################ 41 | # Model & Tokenizer 42 | ################### 43 | torch_dtype = ( 44 | model_args.torch_dtype 45 | if model_args.torch_dtype in ["auto", None] 46 | else getattr(torch, model_args.torch_dtype) 47 | ) 48 | quantization_config = get_quantization_config(model_args) 49 | model_kwargs = dict( 50 | revision=model_args.model_revision, 51 | attn_implementation=model_args.attn_implementation, 52 | torch_dtype=torch_dtype, 53 | use_cache=False if training_args.gradient_checkpointing else True, 54 | device_map=get_kbit_device_map() if quantization_config is not None else None, 55 | quantization_config=quantization_config, 56 | ) 57 | model = AutoModelForCausalLM.from_pretrained( 58 | model_args.model_name_or_path, 59 | trust_remote_code=model_args.trust_remote_code, 60 | **model_kwargs 61 | ) 62 | peft_config = get_peft_config(model_args) 63 | if peft_config is None: 64 | ref_model = AutoModelForCausalLM.from_pretrained( 65 | model_args.model_name_or_path, 66 | trust_remote_code=model_args.trust_remote_code, 67 | **model_kwargs 68 | ) 69 | else: 70 | ref_model = None 71 | tokenizer = AutoTokenizer.from_pretrained( 72 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 73 | ) 74 | if tokenizer.pad_token is None: 75 | tokenizer.pad_token = tokenizer.eos_token 76 | if tokenizer.chat_template is None: 77 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE 78 | if script_args.ignore_bias_buffers: 79 | # torch distributed hack 80 | model._ddp_params_and_buffers_to_ignore = [ 81 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool 82 | ] 83 | 84 | ################ 85 | # Dataset 86 | ################ 87 | 88 | partial_preprocess = partial(preprocessing, args=training_args, tokenizer=tokenizer) 89 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 90 | dataset = dataset.map(partial_preprocess) 91 | 92 | ########## 93 | # Training 94 | ################ 95 | trainer = DPOTrainer( 96 | model, 97 | ref_model, 98 | args=training_args, 99 | train_dataset=dataset[script_args.dataset_train_split], 100 | eval_dataset=( 101 | dataset[script_args.dataset_test_split] 102 | if training_args.eval_strategy != "no" 103 | else None 104 | ), 105 | processing_class=tokenizer, 106 | peft_config=peft_config, 107 | ) 108 | 109 | trainer.train() 110 | 111 | if training_args.eval_strategy != 
"no": 112 | metrics = trainer.evaluate() 113 | trainer.log_metrics("eval", metrics) 114 | trainer.save_metrics("eval", metrics) 115 | 116 | # Save and push to hub 117 | trainer.save_model(training_args.output_dir) 118 | if training_args.push_to_hub: 119 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 120 | 121 | 122 | def make_parser(subparsers: argparse._SubParsersAction = None): 123 | dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) 124 | if subparsers is not None: 125 | parser = subparsers.add_parser( 126 | "dpo", help="Run the DPO training script", dataclass_types=dataclass_types 127 | ) 128 | else: 129 | parser = TrlParser(dataclass_types) 130 | return parser 131 | 132 | 133 | if __name__ == "__main__": 134 | parser = make_parser() 135 | script_args, training_args, model_args = parser.parse_args_and_config() 136 | main(script_args, training_args, model_args) 137 | -------------------------------------------------------------------------------- /src/train.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node=8 --master_addr localhost --master_port 7428 --node_rank 0 train.py \ 2 | --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \ 3 | --dataset_name ../data/dpo/processed/train \ 4 | --trust_remote_code \ 5 | --max_length 8192 \ 6 | --max_prompt_length 8000 \ 7 | --output_dir ../save_model/llama \ 8 | --save_steps 500 \ 9 | --gradient_accumulation_steps 1 \ 10 | --per_device_train_batch_size 1 \ 11 | --per_device_eval_batch_size 1 \ 12 | --learning_rate 5e-7 \ 13 | --logging_strategy steps \ 14 | --logging_steps 50 \ 15 | --logging_dir ../save_model/llama \ 16 | --bf16 True \ 17 | --num_train_epochs 1 \ 18 | --report_to "tensorboard" \ 19 | --save_only_model \ 20 | --gradient_checkpointing \ 21 | --deepspeed ../config/ds_config_zero3.json 22 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os, time, json, re 2 | import torch 3 | from openai import OpenAI 4 | import random 5 | import numpy as np 6 | import yaml 7 | import backoff 8 | 9 | with open("../config/config.yaml", "r") as f: 10 | config = yaml.safe_load(f) 11 | os.environ["OPENAI_API_KEY"] = config["model"]["OPENAI_API_KEY"] 12 | os.environ["OPENAI_BASE_URL"] = config["model"]["OPENAI_BASE_URL"] 13 | client = OpenAI() 14 | 15 | 16 | def p_template(template, variables): 17 | prompt_file = f"../prompts/en/{template}" 18 | with open(prompt_file, "r") as fin: 19 | prompt = fin.read() 20 | prompt = prompt.format(**variables) 21 | return prompt 22 | 23 | 24 | def gpt_gen(model, content, temperature=0.1, top_p=0.9): 25 | try: 26 | completion = client.chat.completions.create( 27 | model=model, 28 | temperature=temperature, 29 | top_p=top_p, 30 | max_tokens=1280, 31 | messages=[{"role": "user", "content": content}], 32 | ) 33 | 34 | return completion.choices[0].message.content 35 | 36 | except Exception as e: 37 | print(f"An error occurred: {e}") 38 | time.sleep(0.5) 39 | 40 | return None 41 | 42 | 43 | @backoff.on_exception(backoff.expo, (Exception), max_time=100) 44 | def call_gpt(model, content, temperature=1, top_p=0.9): 45 | res = gpt_gen(model, content, temperature, top_p) 46 | assert res is not None 47 | return res 48 | 49 | 50 | def extract_best_worst(sequence): 51 | try: 52 | if "json" in sequence: 53 | sequence = sequence[sequence.index("{") : 
sequence.rindex("}") + 1] 54 | result = json.loads(sequence) 55 | best = result["best_id"] 56 | worst = result["worst_id"] 57 | 58 | except: 59 | 60 | pattern_best = re.compile(r'"best_id"\s*:\s*(\d+)') 61 | pattern_worst = re.compile(r'"worst_id"\s*:\s*(\d+)') 62 | 63 | best_match = pattern_best.search(sequence) 64 | worst_match = pattern_worst.search(sequence) 65 | 66 | best = int(best_match.group(1)) if best_match else None 67 | worst = int(worst_match.group(1)) if worst_match else None 68 | if best == worst: 69 | return None, None 70 | return best, worst 71 | 72 | 73 | def LLM_score_rag(responses, answer): 74 | chosen, rejected = [], [] 75 | for item in responses: 76 | if item.strip().lower() == answer.lower(): 77 | chosen.append(item) 78 | if answer.lower() not in item.lower(): 79 | rejected.append(item) 80 | if len(chosen) == 0: 81 | for item in responses: 82 | if answer.strip().lower() in item.lower() and len(item) < 2 * len(answer): 83 | chosen.append(item) 84 | if len(rejected) == 0: 85 | for item in responses: 86 | if len(item.strip()) > 2 * len(answer): 87 | rejected.append(item) 88 | if len(chosen) > 0 and len(rejected) > 0: 89 | return random.sample(chosen, 1)[0], random.sample(rejected, 1)[0] 90 | else: 91 | return "", "" 92 | 93 | 94 | def LLM_score_init_note(model_name, question, refs, init_notes): 95 | init_notes_prompt = [ 96 | {"_id": id, "content": content} for id, content in enumerate(init_notes) 97 | ] 98 | prompt = config["score"]["init_note"].format( 99 | query=question, refs=refs, notes=init_notes_prompt 100 | ) 101 | response = call_gpt(model_name, prompt) 102 | best, worst = extract_best_worst(response) 103 | try: 104 | best_note, worst_note = init_notes[best], init_notes[worst] 105 | except: 106 | return init_notes[-1], "" 107 | return best_note, worst_note 108 | 109 | 110 | def LLM_score_gen_new_query( 111 | model_name, question, best_init_note, new_querys, query_log 112 | ): 113 | querys_prompt = [ 114 | {"_id": id, "content": content} for id, content in enumerate(new_querys) 115 | ] 116 | prompt = config["score"]["gen_new_query"].format( 117 | notes=best_init_note, 118 | query=question, 119 | query_log=query_log, 120 | new_querys=querys_prompt, 121 | ) 122 | response = call_gpt(model_name, prompt) 123 | best, worst = extract_best_worst(response) 124 | 125 | try: 126 | best_query, worst_query = new_querys[best], new_querys[worst] 127 | except: 128 | return new_querys[-1], "" 129 | return best_query, worst_query 130 | 131 | 132 | def LLM_score_compare_note(model_name, query, best_note, new_note): 133 | prompt = p_template( 134 | "compare", {"query": query, "best_note": best_note, "new_note": new_note} 135 | ) 136 | response = call_gpt(model_name, prompt) 137 | 138 | try: 139 | if "json" in response: 140 | response = response[response.index("{") : response.rindex("}") + 1] 141 | status = json.loads(response)["status"] 142 | return status.lower() == "true" 143 | except: 144 | return "true" in response.lower() 145 | 146 | 147 | def seed_everything(seed): 148 | torch.manual_seed(seed) 149 | torch.cuda.manual_seed(seed) 150 | np.random.seed(seed) 151 | random.seed(seed) 152 | torch.backends.cudnn.benchmark = False 153 | torch.backends.cudnn.deterministic = True 154 | torch.cuda.manual_seed_all(seed) 155 | --------------------------------------------------------------------------------