├── README.md
├── assets
│   └── DeepNote.png
├── config
│   ├── config.yaml
│   └── ds_config_zero3.json
├── prompts
│   └── en
│       ├── base
│       ├── base_asqa
│       ├── base_strategyqa
│       ├── base_wo_retri
│       ├── base_wo_retri_asqa
│       ├── base_wo_retri_strategyqa
│       ├── compare
│       ├── gen_answer
│       ├── gen_answer_asqa
│       ├── gen_answer_strategyqa
│       ├── gen_new_query
│       ├── init_note
│       └── refine_note
├── requirements.txt
└── src
    ├── build_index
    │   ├── emb
    │   │   └── index.py
    │   └── es
    │       ├── index_2wiki.py
    │       ├── index_hotpotqa.py
    │       └── index_musique.py
    ├── es_retrieve.py
    ├── eval.py
    ├── gen_dpo_data.py
    ├── main.py
    ├── select_dpo_data.py
    ├── train.py
    ├── train.sh
    └── utils.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 | # DeepNote: Note-Centric Deep Retrieval-Augmented Generation
3 |
4 |
5 |
6 | We develop **DeepNote**, an adaptive RAG framework that achieves in-depth
7 | and robust exploration of knowledge sources through note-centric adaptive retrieval. DeepNote employs notes as carriers for refining and accumulating knowledge. During in-depth exploration, it uses these notes to determine retrieval timing, formulate retrieval queries, and iteratively assess knowledge growth, ultimately leveraging the best note for answer generation.
8 |
9 | 
10 |
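Conceptually, each DeepNote run follows the loop below. This is a minimal sketch only; the helper names are illustrative placeholders rather than the actual `src/main.py` API, and the step limits correspond to the `--max_step` / `--max_fail_step` options.

```python
# Illustrative sketch of the note-centric loop (hypothetical helper callables,
# not the repository's actual function signatures).
def deepnote_loop(question, retrieve, init_note, gen_new_query, refine_note,
                  note_improved, gen_answer, max_step=3, max_fail_step=2):
    refs = retrieve(question)
    best_note = init_note(question, refs)            # distill retrieved passages into a note
    query_log, fails = [question], 0
    for _ in range(max_step):
        new_query = gen_new_query(question, best_note, query_log)  # decide what to retrieve next
        query_log.append(new_query)
        refs = retrieve(new_query)
        new_note = refine_note(question, best_note, refs)           # fold new evidence into the note
        if note_improved(question, best_note, new_note):             # assess knowledge growth
            best_note, fails = new_note, 0
        else:
            fails += 1
            if fails >= max_fail_step:                                # stop when the note stops improving
                break
    return gen_answer(question, best_note)            # answer from the best note
```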
11 | # Prepare Datasets
12 |
13 | All corpus and evaluation files should be placed in the `data/` directory at the repository root. You can download the experimental data [here](https://drive.google.com/drive/folders/1NeEm-r7l43MQxGS1n7jJ8tPvltgcaPjY?usp=sharing).
14 |
15 | We use Wikipedia as the corpus for ASQA and StrategyQA. Due to its large size, please download it separately [here](https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz), decompress it, and place the resulting `psgs_w100.tsv` in `data/corpus/wiki/`.
16 |
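For example, from the repository root:

```bash
mkdir -p data/corpus/wiki
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gunzip psgs_w100.tsv.gz
mv psgs_w100.tsv data/corpus/wiki/
```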
17 | # Retrieval Settings
18 |
19 | We use different retrieval methods depending on the dataset.
20 |
21 | For 2WikiMQA, MuSiQue, and HotpotQA:
22 | - BM25 retrieval via Elasticsearch
23 | - Dense retrieval over a FAISS index with BGE embeddings
24 |
25 | For ASQA and StrategyQA:
26 | - Dense retrieval over a FAISS index with GTR embeddings
27 |
28 | ## Set up Elasticsearch
29 |
30 | Install and start Elasticsearch 7.10.2:
31 | ```bash
32 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.10.2-linux-x86_64.tar.gz
33 | wget https://artifacts.elastic.co/downloads/elasticsearch/elasticsearch-7.10.2-linux-x86_64.tar.gz.sha512
34 | shasum -a 512 -c elasticsearch-7.10.2-linux-x86_64.tar.gz.sha512
35 | tar -xzf elasticsearch-7.10.2-linux-x86_64.tar.gz
36 | cd elasticsearch-7.10.2/
37 | ./bin/elasticsearch # Start the server
38 | pkill -f elasticsearch # To stop the server
39 | ```
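
You can verify the server is reachable before building indices (the response should be a JSON document reporting version 7.10.2):

```bash
curl http://localhost:9200
```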
40 |
41 | ## Build Indices
42 |
43 | ### For BM25
44 | ```bash
45 | cd src/build_index/es
46 |
47 | # 2WikiMQA
48 | python index_2wiki.py
49 |
50 | # MuSiQue
51 | python index_musique.py
52 |
53 | # HotpotQA
54 | python index_hotpotqa.py
55 | ```
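
After indexing finishes, you can check the document counts of the created indices (`2wikimultihopqa`, `musique`, and `hotpotqa`):

```bash
curl "http://localhost:9200/_cat/indices?v"
```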
56 |
57 | ### For Dense Retrieval
58 |
59 | #### For HotpotQA, 2WikiMQA, and MuSiQue
60 | ```bash
61 | cd src/build_index/emb
62 | python index.py --dataset hotpotqa --model bge-base-en-v1.5 # e.g., for HotpotQA dataset
63 | ```
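
Once built, the index can be sanity-checked with a quick search. This is a minimal sketch assuming the default output layout of `index.py` (`data/corpus/{dataset}/{dataset}.index` plus the accompanying `chunk.json`, built with `IndexFlatIP`); the query string is only an example:

```python
import json
import faiss
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("BAAI/bge-base-en-v1.5")
index = faiss.read_index("data/corpus/hotpotqa/hotpotqa.index")
chunks = json.load(open("data/corpus/hotpotqa/chunk.json", encoding="utf-8"))

# Inner-product search over the chunk embeddings
query_emb = model.encode(["Who directed the film mentioned in the question?"])
scores, ids = index.search(query_emb, 5)
for score, i in zip(scores[0], ids[0]):
    print(round(float(score), 3), chunks[i][:100])
```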
64 |
65 | #### For ASQA and StrategyQA
66 | Since generating GTR embeddings for the Wikipedia corpus is time-consuming, you can download the pre-computed GTR embeddings and place them in `data/corpus/wiki/`:
67 | ```bash
68 | wget https://huggingface.co/datasets/princeton-nlp/gtr-t5-xxl-wikipedia-psgs_w100-index/resolve/main/gtr_wikipedia_index.pkl
69 | ```
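
Move the downloaded file into place:

```bash
mkdir -p data/corpus/wiki
mv gtr_wikipedia_index.pkl data/corpus/wiki/
```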
70 |
71 | Then build the FAISS index:
72 | ```bash
73 | cd src/build_index/emb
74 | python index.py --dataset asqa --model gtr-t5-xxl
75 | ```
76 |
77 | # Configuration
78 |
79 | Configure your API key, base URL, model paths, and other settings in `./config/config.yaml`.
80 |
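For reference, the relevant fields look like this (keys mirror `config/config.yaml`; the API key below is a placeholder, and the `*-dpo` entries point to your own DPO-trained checkpoints once available):

```yaml
model:
  OPENAI_API_KEY: "sk-..."                                   # placeholder, use your own key
  OPENAI_BASE_URL: ""                                        # optional OpenAI-compatible endpoint
  llama-3.1-8b-instruct: "meta-llama/Llama-3.1-8B-Instruct"
  llama-3.1-8b-instruct-dpo: ""                              # path to a local DPO checkpoint
  bge-base-en-v1.5: "BAAI/bge-base-en-v1.5"
  gtr-t5-xxl: "sentence-transformers/gtr-t5-xxl"
es:
  url: "http://localhost:9200"
```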
81 |
82 | # Training DeepNote
83 |
84 | The training process consists of three main steps. Run the commands below from the `src/` directory, since the scripts resolve paths relative to it:
85 |
86 | ## 1. Generate Training Data
87 | Generate the initial training data with the LLaMA model:
88 | ```bash
89 | python gen_dpo_data.py \
90 | --model llama-3.1-8b-instruct \
91 | --batch_size 9 \
92 | --output_path ../data/dpo_data \
93 | --device 0,1,2,3
94 | ```
95 |
96 | ## 2. Data Selection
97 | Filter and process the generated data:
98 | ```bash
99 | python select_dpo_data.py \
100 | --output_path ../data/dpo/processed/train.jsonl \
101 | --init_num 1900 \
102 | --refine_num 1900 \
103 | --query_num 1900
104 | ```
105 |
106 | ## 3. Start Training
107 | Launch the training process:
108 | ```bash
109 | bash train.sh
110 | ```
111 |
112 |
113 | # Running DeepNote and Evaluation
114 |
115 | ```bash
116 | python main.py --method deepnote --retrieve_top_k 5 --dataset hotpotqa --max_step 3 --max_fail_step 2 --MaxClients 5 --model gpt-4o-mini-2024-07-18 --device cuda:0
117 | ```
118 | The predictions and evaluation metrics are saved automatically to the `output/{dataset}/` directory; the evaluation scores are appended at the end of the prediction file.
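To re-score an existing prediction file without re-running generation, `src/eval.py` can be invoked directly (run it from `src/`; the prediction path below is a placeholder):

```bash
cd src
python eval.py --dataset hotpotqa --pred_path <path_to_prediction_file>
```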
119 |
--------------------------------------------------------------------------------
/assets/DeepNote.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/thunlp/DeepNote/cc2f0132737e04ec2d895e8c21673cee612ca1e7/assets/DeepNote.png
--------------------------------------------------------------------------------
/config/config.yaml:
--------------------------------------------------------------------------------
1 | model:
2 | OPENAI_API_KEY: ""
3 | OPENAI_BASE_URL: ""
4 | llama-3.1-8b-instruct: "meta-llama/Llama-3.1-8B-Instruct"
5 | llama-3.1-8b-instruct-dpo: ""
6 | llama-3.1-70b-instruct: "meta-llama/Llama-3.1-70B-Instruct"
7 | qwen2.5-7b-instruct: "Qwen/Qwen2.5-7B-Instruct"
8 | qwen2.5-7b-instruct-dpo: ""
9 | bge-base-en-v1.5: "BAAI/bge-base-en-v1.5"
10 | gtr-t5-xxl: "sentence-transformers/gtr-t5-xxl"
11 | es:
12 | url: "http://localhost:9200"
13 |
14 | score:
15 | init_note: 'Task: You will receive a list of notes generated based on a given document content and question. \nYour task is to evaluate and score these notes based on their quality. Quality refers to: relevance, coherence, completeness in answering the specified question, and accuracy of information. \n\nQuestion to be answered: {query}\nDocument content: {refs}\n\nGenerated notes: {notes}\nNote format: Each note contains "_id" and "content" fields. \n\nEvaluate the generated notes. The highest-scoring note must be factually correct based on the document. If no note is correct, or if there is minimal quality difference between notes, use the same _id for both best and worst. \nOutput in the following JSON format:\n```json {{"best_id": <_id of the highest-scoring note>, "worst_id": <_id of the lowest-scoring note>}}```\n\n Do not include any explanations or additional text.'
16 | gen_new_query: 'Task: You will receive a list of new questions generated based on some notes and an existing question list to supplement a given original question.\nYour task is to evaluate these new questions based on their quality. Quality refers to: relevance, specificity, keyword richness, and non-redundancy. The goal is to identify questions that can retrieve useful information to help answer the original question. \n Notes: {notes}\n\nOriginal question: {query}\n\nExisting question list: {query_log}\n\nNew question list: {new_querys}\nQuestion format: Each question contains "_id" and "content" fields. \n\nEvaluate the new question list. The highest-scoring new question must be able to help retrieve relevant information to answer the original question. If no new question can help get useful information, or if there is minimal quality difference between new questions, use the same _id for both best_id and worst_id. \nOutput in the following format:\n```json {{"best_id": <_id of the highest-scoring question>, "worst_id": <_id of the lowest-scoring question>}}```\n\n Do not include any explanations or additional text.'
17 |
--------------------------------------------------------------------------------
/config/ds_config_zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "stage3_gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "steps_per_print": 2000,
45 | "train_batch_size": "auto",
46 | "train_micro_batch_size_per_gpu": "auto",
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
/prompts/en/base:
--------------------------------------------------------------------------------
1 | Answer the question based on the given passages. Only give me the answer and do not output any other words.
2 |
3 | The following are given passages:
4 | {refs}
5 |
6 | Question: {query}
7 | Answer:
--------------------------------------------------------------------------------
/prompts/en/base_asqa:
--------------------------------------------------------------------------------
1 | Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided passages. Use an unbiased and journalistic tone.
2 |
3 |
4 | Question:
5 | {query}
6 |
7 |
8 | Passages:
9 | {refs}
10 |
11 |
--------------------------------------------------------------------------------
/prompts/en/base_strategyqa:
--------------------------------------------------------------------------------
1 | Answer the question based on the given passages. Only give me 'yes' or 'no' as your answer and do not output any other words.
2 |
3 | The following are given passages:
4 | {refs}
5 |
6 | Question: {query}
7 | Answer:
--------------------------------------------------------------------------------
/prompts/en/base_wo_retri:
--------------------------------------------------------------------------------
1 | Only give me the answer and do not output any other words.
2 |
3 | Question: {query}
4 | Answer:
--------------------------------------------------------------------------------
/prompts/en/base_wo_retri_asqa:
--------------------------------------------------------------------------------
1 | Instruction: Write an accurate, engaging, and concise answer for the given question. Use an unbiased and journalistic tone.
2 |
3 | Question:
4 | {query}
5 |
--------------------------------------------------------------------------------
/prompts/en/base_wo_retri_strategyqa:
--------------------------------------------------------------------------------
1 | Only give me 'yes' or 'no' as your answer and do not output any other words.
2 |
3 | Question: {query}
4 | Answer:
--------------------------------------------------------------------------------
/prompts/en/compare:
--------------------------------------------------------------------------------
1 | Task: Please help me determine which note is better based on the following evaluation criteria:
2 | 1. Contains key information directly related to the question.
3 | 2. Completeness of Information: Does it cover all relevant aspects and details?
4 | 3. Level of Detail: Does it provide enough detail to understand the issue in depth?
5 | 4. Practicality: Does the note offer practical help and solutions?
6 |
7 | Please make your judgment adhering strictly to the following rules:
8 | - If Note 2 does not add new meaningful content on top of Note 1, or only adds redundant information, return ```json {{"status":"False"}}``` directly.
9 | - If Note 2 has significant improvements over Note 1 based on the above criteria, return ```json {{"status":"True"}}``` directly; otherwise, return ```json {{"status":"False"}}```.
10 |
11 | Question: {query}
12 | Provided Note 1: {best_note}
13 | Provided Note 2: {new_note}
14 |
15 | Based on the above information, make your judgment without explanation and return the result directly.
16 |
--------------------------------------------------------------------------------
/prompts/en/gen_answer:
--------------------------------------------------------------------------------
1 | Answer the question based on the given notes. Only give me the answer and do not output any other words.
2 |
3 | The following are given notes:
4 | {note}
5 |
6 | Question: {query}
7 | Answer:
--------------------------------------------------------------------------------
/prompts/en/gen_answer_asqa:
--------------------------------------------------------------------------------
1 | Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided notes. Use an unbiased and journalistic tone.
2 |
3 |
4 | Question:
5 | {query}
6 |
7 |
8 | Notes:
9 | {note}
--------------------------------------------------------------------------------
/prompts/en/gen_answer_strategyqa:
--------------------------------------------------------------------------------
1 | Answer the question based on the given notes. Only give me 'yes' or 'no' as your answer and do not output any other words.
2 |
3 | The following are given notes:
4 | {note}
5 |
6 | Question: {query}
7 | Answer:
--------------------------------------------------------------------------------
/prompts/en/gen_new_query:
--------------------------------------------------------------------------------
1 | Task: Based on the notes, propose two new questions. These new questions will be used to retrieve documents to supplement the notes and help answer the original question. The new questions should be concise and include keywords that facilitate retrieval. The new questions should avoid duplication with the existing question list.
2 |
3 | original question: {query}
4 | notes: {note}
5 |
6 | existing question list: {query_log}
7 |
8 | Do not print any other words. Do not explain. Output only the two new questions you asked.
--------------------------------------------------------------------------------
/prompts/en/init_note:
--------------------------------------------------------------------------------
1 | Based on the provided document content, write a note. The note should integrate all relevant information from the original text that can help answer the specified question and form a coherent paragraph. Please ensure that the note includes all original text information useful for answering the question.
2 |
3 |
4 | Question to be answered: {query}
5 | Document content: {refs}
6 |
7 |
8 | Please provide the note you wrote:
--------------------------------------------------------------------------------
/prompts/en/refine_note:
--------------------------------------------------------------------------------
1 | Task: Based on the retrieved documents, supplement the notes with content not yet included but useful for answering the question. The supplement should use the original text from the retrieved documents. The added content should include as much information from the retrieved documents as possible.
2 |
3 | question: {query}
4 | retrieved documents: {refs}
5 |
6 | notes: {note}
7 |
8 | Provide your supplemented notes:
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | backoff
2 | elasticsearch==7.9.1
3 | faiss_cpu
4 | jieba
5 | numpy
6 | openai==1.57.4
7 | pandas
8 | sentence_transformers
9 | torch==2.4.0
10 | tqdm
11 | transformers==4.48.0
12 | pyyaml
13 | llama-index==0.9.30
14 | vllm==0.5.4
15 | trl==0.13.0
16 |
--------------------------------------------------------------------------------
/src/build_index/emb/index.py:
--------------------------------------------------------------------------------
1 | import time
2 | import os
3 | import json
4 | from llama_index import SimpleDirectoryReader
5 | import csv
6 | from sentence_transformers import SentenceTransformer
7 | import faiss
8 | import argparse
9 | import yaml
10 | from glob import glob
11 | from itertools import chain
12 | from tqdm import tqdm
13 | from llama_index import Document
14 | from llama_index.node_parser import SimpleNodeParser
15 |
16 | with open("../../../config/config.yaml", "r") as f:
17 | config = yaml.safe_load(f)
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument(
20 | "--model",
21 | type=str,
22 | choices=["bge-base-en-v1.5", "gtr-t5-xxl"],
23 | default="bge-base-en-v1.5",
24 | help="Model to use",
25 | )
26 | parser.add_argument(
27 | "--dataset",
28 | type=str,
29 | default="2wikimultihopqa",
30 | choices=["asqa", "strategyqa", "2wikimultihopqa", "hotpotqa", "musique"],
31 | help="Dataset to use",
32 | )
33 | parser.add_argument("--chunk_size", type=int, default=512, help="chunk size")
34 | parser.add_argument("--chunk_overlap", type=int, default=0, help="chunk overlap")
35 | parser.add_argument("--device", type=str, default="cuda:7", help="Device to use")
36 | args = parser.parse_args()
37 |
38 |
39 | def split_text(data):
40 |
41 | documents = []
42 | for record in data:
43 | if record["title"]:
44 | combined_text = record["title"] + "\n" + record["content"]
45 | else:
46 | combined_text = record["content"]
47 | documents.append(Document(text=combined_text))
48 |
49 | node_parser = SimpleNodeParser.from_defaults(
50 | chunk_size=args.chunk_size, chunk_overlap=args.chunk_overlap
51 | )
52 | nodes = node_parser.get_nodes_from_documents(documents, show_progress=True)
53 |
54 | contents = [node.text for node in nodes]
55 | return contents
56 |
57 |
58 | def build_index(embeddings, vectorstore_path):
59 | dimension = embeddings.shape[1]
60 | index = faiss.IndexFlatIP(dimension)
61 | index.add(embeddings)
62 | faiss.write_index(index, vectorstore_path)
63 |
64 |
65 | if __name__ == "__main__":
66 |
67 | model = SentenceTransformer(config["model"][args.model], device=args.device)
68 | if args.dataset == "asqa" or args.dataset == "strategyqa":
69 | dataset_name = "wiki"
70 | else:
71 | dataset_name = args.dataset
72 | vectorstore_path = f"../../../data/corpus/{dataset_name}/{dataset_name}.index"
73 | contents = []
74 | print("loading document ...")
75 | start = time.time()
76 | if dataset_name == "wiki":
77 | if os.path.exists("../../../data/corpus/wiki/gtr_wikipedia_index.pkl"):
78 | import pickle
79 |
80 | with open(
81 | "../../../data/corpus/wiki/gtr_wikipedia_index.pkl", "rb"
82 | ) as file:
83 | embeddings = pickle.load(file)
84 | else:
85 | with open(
86 | "../../../data/corpus/wiki/psgs_w100.tsv", "r", encoding="utf-8"
87 | ) as file:
88 | tsv_data = csv.DictReader(file, delimiter="\t")
89 | raw_data = [row["title"] + "\n" + row["text"] for row in tsv_data]
90 | print("dataset length", len(raw_data))
91 | embeddings = model.encode(raw_data, batch_size=100)
92 | elif dataset_name == "2wikimultihopqa":
93 | train = json.load(open("../../../data/corpus/2wikimultihopqa/train.json", "r"))
94 | dev = json.load(open("../../../data/corpus/2wikimultihopqa/dev.json", "r"))
95 | test = json.load(open("../../../data/corpus/2wikimultihopqa/test.json", "r"))
96 |
97 | data = {}
98 | for item in tqdm(chain(train, dev, test)):
99 | for title, sentences in item["context"]:
100 | para = " ".join(sentences)
101 | data[para] = title
102 | contents = [
103 | {"id": i, "content": text, "title": title}
104 | for i, (text, title) in enumerate(data.items())
105 | ]
106 | elif dataset_name == "hotpotqa":
107 | import bz2
108 | from multiprocessing import Pool
109 |
110 | def process_line(line):
111 | data = json.loads(line)
112 | item = {
113 | "id": data["id"],
114 | "title": data["title"],
115 | "content": "".join(data["text"]),
116 | }
117 | return item
118 |
119 | def generate_indexing_queries_from_bz2(bz2file, dry=False):
120 | if dry:
121 | return
122 |
123 | with bz2.open(bz2file, "rt") as f:
124 | body = [process_line(line) for line in f]
125 |
126 | return body
127 |
128 | filelist = glob("../../../data/corpus/hotpotqa/*/wiki_*.bz2")
129 |
130 | print("Making indexing queries...")
131 | pool = Pool()
132 |
133 | for result in tqdm(
134 | pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist)
135 | ):
136 | contents.extend(result)
137 | elif dataset_name == "musique":
138 | train = [
139 | json.loads(line.strip())
140 | for line in open(
141 | "../../../data/corpus/musique/musique_ans_v1.0_train.jsonl"
142 | )
143 | ] + [
144 | json.loads(line.strip())
145 | for line in open(
146 | "../../../data/corpus/musique/musique_full_v1.0_train.jsonl"
147 | )
148 | ]
149 | dev = [
150 | json.loads(line.strip())
151 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_dev.jsonl")
152 | ] + [
153 | json.loads(line.strip())
154 | for line in open("../../../data/corpus/musique/musique_full_v1.0_dev.jsonl")
155 | ]
156 | test = [
157 | json.loads(line.strip())
158 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_test.jsonl")
159 | ] + [
160 | json.loads(line.strip())
161 | for line in open(
162 | "../../../data/corpus/musique/musique_full_v1.0_test.jsonl"
163 | )
164 | ]
165 |
166 | tot = 0
167 | hist = set()
168 | for item in tqdm(chain(train, dev, test)):
169 | for p in item["paragraphs"]:
170 | stamp = p["title"] + " " + p["paragraph_text"]
171 | if not stamp in hist:
172 | contents.append(
173 | {"id": tot, "content": p["paragraph_text"], "title": p["title"]}
174 | )
175 | hist.add(stamp)
176 | tot += 1
177 |
178 | contents = split_text(contents)
179 | embeddings = model.encode(contents, batch_size=600)
180 | with open(
181 | f"../../../data/corpus/{dataset_name}/chunk.json", "w", encoding="utf-8"
182 | ) as fout:
183 | json.dump(contents, fout, ensure_ascii=False)
184 | print("Building index ...")
185 | build_index(embeddings, vectorstore_path)
186 | end = time.time()
187 |
--------------------------------------------------------------------------------
/src/build_index/es/index_2wiki.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from elasticsearch import Elasticsearch
3 | import html
4 | import json
5 | from tqdm import tqdm
6 | from itertools import chain
7 |
8 |
9 | def chunks(l, n):
10 | """Yield successive n-sized chunks from l."""
11 | for i in range(0, len(l), n):
12 | yield l[i : i + n]
13 |
14 |
15 | INDEX_NAME = "2wikimultihopqa"
16 |
17 |
18 | def process_line(data):
19 | item = {
20 | "id": data["id"],
21 | "url": "empty",
22 | "title": data["title"],
23 | "title_unescape": html.unescape(data["title"]),
24 | "text": data["text"],
25 | "title_bigram": html.unescape(data["title"]),
26 | "title_unescape_bigram": html.unescape(data["title"]),
27 | "text_bigram": data["text"],
28 | "original_json": json.dumps(data),
29 | }
30 | return "{}\n{}".format(
31 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item)
32 | )
33 |
34 |
35 | es = Elasticsearch(hosts="http://localhost:9200", timeout=100)
36 |
37 |
38 | def index_chunk(chunk):
39 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s")
40 | assert not res["errors"], res
41 |
42 |
43 | def main(args):
44 |
45 | train = json.load(open("../../../data/corpus/2wikimultihopqa/train.json", "r"))
46 | dev = json.load(open("../../../data/corpus/2wikimultihopqa/dev.json", "r"))
47 | test = json.load(open("../../../data/corpus/2wikimultihopqa/test.json", "r"))
48 |
49 | data = {}
50 | for item in tqdm(chain(train, dev, test)):
51 | for title, sentences in item["context"]:
52 | para = " ".join(sentences)
53 | data[para] = title
54 | wikipedia_data = [
55 | {"id": i, "text": text, "title": title}
56 | for i, (text, title) in enumerate(data.items())
57 | ]
58 |
59 | # make index
60 | if not args.dry:
61 | # add
62 | if es.indices.exists(index=INDEX_NAME):
63 | es.indices.delete(index=INDEX_NAME, ignore=[400, 403])
64 | es.indices.create(
65 | index=INDEX_NAME,
66 | ignore=400,
67 | body=json.dumps(
68 | {
69 | "mappings": {
70 | "doc": {
71 | "properties": {
72 | "id": {"type": "keyword"},
73 | "url": {"type": "keyword"},
74 | "title": {
75 | "type": "text",
76 | "analyzer": "simple",
77 | "copy_to": "title_all",
78 | },
79 | "title_unescape": {
80 | "type": "text",
81 | "analyzer": "simple",
82 | "copy_to": "title_all",
83 | },
84 | "text": {
85 | "type": "text",
86 | "analyzer": "my_english_analyzer",
87 | },
88 | "anchortext": {
89 | "type": "text",
90 | "analyzer": "my_english_analyzer",
91 | },
92 | "title_bigram": {
93 | "type": "text",
94 | "analyzer": "simple_bigram_analyzer",
95 | "copy_to": "title_all_bigram",
96 | },
97 | "title_unescape_bigram": {
98 | "type": "text",
99 | "analyzer": "simple_bigram_analyzer",
100 | "copy_to": "title_all_bigram",
101 | },
102 | "text_bigram": {
103 | "type": "text",
104 | "analyzer": "bigram_analyzer",
105 | },
106 | "anchortext_bigram": {
107 | "type": "text",
108 | "analyzer": "bigram_analyzer",
109 | },
110 | "original_json": {"type": "string"},
111 | }
112 | }
113 | },
114 | "settings": {
115 | "analysis": {
116 | "my_english_analyzer": {
117 | "type": "standard",
118 | "stopwords": "_english_",
119 | },
120 | "simple_bigram_analyzer": {
121 | "tokenizer": "standard",
122 | "filter": ["lowercase", "shingle", "asciifolding"],
123 | },
124 | "bigram_analyzer": {
125 | "tokenizer": "standard",
126 | "filter": [
127 | "lowercase",
128 | "stop",
129 | "shingle",
130 | "asciifolding",
131 | ],
132 | },
133 | },
134 | },
135 | }
136 | ),
137 | )
138 |
139 | print("Making indexing queries...")
140 | all_queries = []
141 | for item in tqdm(wikipedia_data):
142 | all_queries.append(process_line(item))
143 |
144 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2
145 |
146 | if not args.dry:
147 | print("Indexing...")
148 | chunksize = 100
149 | for chunk in tqdm(
150 | chunks(all_queries, chunksize),
151 | total=(len(all_queries) + chunksize - 1) // chunksize,
152 | ):
153 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s")
154 | assert not res["errors"], res
155 |
156 | print(f"{count} documents indexed in total")
157 |
158 |
159 | if __name__ == "__main__":
160 | parser = ArgumentParser()
161 |
162 | parser.add_argument("--reindex", action="store_true", help="Reindex everything")
163 | parser.add_argument("--dry", action="store_true", help="Dry run")
164 |
165 | args = parser.parse_args()
166 |
167 | main(args)
168 |
--------------------------------------------------------------------------------
/src/build_index/es/index_hotpotqa.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | import bz2
3 | from elasticsearch import Elasticsearch
4 | from glob import glob
5 | import html
6 | import json
7 | from multiprocessing import Pool
8 | from tqdm import tqdm
9 |
10 | WIKIPEDIA_INDEX_NAME = "hotpotqa"
11 |
12 |
13 | def chunks(l, n):
14 | for i in range(0, len(l), n):
15 | yield l[i : i + n]
16 |
17 |
18 | def process_line(line):
19 | data = json.loads(line)
20 | item = {
21 | "id": data["id"],
22 | "url": data["url"],
23 | "title": data["title"],
24 | "title_unescape": html.unescape(data["title"]),
25 | "text": "".join(data["text"]),
26 | "title_bigram": html.unescape(data["title"]),
27 | "title_unescape_bigram": html.unescape(data["title"]),
28 | "text_bigram": "".join(data["text"]),
29 | "original_json": line,
30 | }
31 | return "{}\n{}".format(
32 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item)
33 | )
34 |
35 |
36 | def generate_indexing_queries_from_bz2(bz2file, dry=False):
37 | if dry:
38 | return
39 |
40 | with bz2.open(bz2file, "rt") as f:
41 | body = [process_line(line) for line in f]
42 |
43 | return "\n".join(body)
44 |
45 |
46 | es = Elasticsearch(hosts="http://localhost:9200", timeout=50)
47 |
48 |
49 | def index_chunk(chunk):
50 | res = es.bulk(index=WIKIPEDIA_INDEX_NAME, body="\n".join(chunk), timeout="100s")
51 | assert not res["errors"], res
52 |
53 |
54 | def main(args):
55 | if not args.dry:
56 | if es.indices.exists(index=WIKIPEDIA_INDEX_NAME):
57 | es.indices.delete(index=WIKIPEDIA_INDEX_NAME, ignore=[400, 403])
58 | if not es.indices.exists(index=WIKIPEDIA_INDEX_NAME):
59 | es.indices.create(
60 | index=WIKIPEDIA_INDEX_NAME,
61 | ignore=400,
62 | body=json.dumps(
63 | {
64 | "mappings": {
65 | "doc": {
66 | "properties": {
67 | "id": {"type": "keyword"},
68 | "url": {"type": "keyword"},
69 | "title": {
70 | "type": "text",
71 | "analyzer": "simple",
72 | "copy_to": "title_all",
73 | },
74 | "title_unescape": {
75 | "type": "text",
76 | "analyzer": "simple",
77 | "copy_to": "title_all",
78 | },
79 | "text": {
80 | "type": "text",
81 | "analyzer": "my_english_analyzer",
82 | },
83 | "anchortext": {
84 | "type": "text",
85 | "analyzer": "my_english_analyzer",
86 | },
87 | "title_bigram": {
88 | "type": "text",
89 | "analyzer": "simple_bigram_analyzer",
90 | "copy_to": "title_all_bigram",
91 | },
92 | "title_unescape_bigram": {
93 | "type": "text",
94 | "analyzer": "simple_bigram_analyzer",
95 | "copy_to": "title_all_bigram",
96 | },
97 | "text_bigram": {
98 | "type": "text",
99 | "analyzer": "bigram_analyzer",
100 | },
101 | "anchortext_bigram": {
102 | "type": "text",
103 | "analyzer": "bigram_analyzer",
104 | },
105 | "original_json": {"type": "string"},
106 | }
107 | }
108 | },
109 | "settings": {
110 | "analysis": {
111 | "my_english_analyzer": {
112 | "type": "standard",
113 | "stopwords": "_english_",
114 | },
115 | "simple_bigram_analyzer": {
116 | "tokenizer": "standard",
117 | "filter": ["lowercase", "shingle", "asciifolding"],
118 | },
119 | "bigram_analyzer": {
120 | "tokenizer": "standard",
121 | "filter": [
122 | "lowercase",
123 | "stop",
124 | "shingle",
125 | "asciifolding",
126 | ],
127 | },
128 | },
129 | },
130 | }
131 | ),
132 | )
133 |
134 | filelist = glob("../../../data/corpus/hotpotqa/*/wiki_*.bz2")
135 |
136 | print("Making indexing queries...")
137 | pool = Pool()
138 | all_queries = list(
139 | tqdm(
140 | pool.imap(generate_indexing_queries_from_bz2, filelist), total=len(filelist)
141 | )
142 | )
143 |
144 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2
145 |
146 | if not args.dry:
147 | print("Indexing...")
148 | chunksize = 50
149 | for chunk in tqdm(
150 | chunks(all_queries, chunksize),
151 | total=(len(all_queries) + chunksize - 1) // chunksize,
152 | ):
153 | res = es.bulk(
154 | index=WIKIPEDIA_INDEX_NAME, body="\n".join(chunk), timeout="100s"
155 | )
156 | assert not res["errors"], res
157 |
158 | print(f"{count} documents indexed in total")
159 |
160 |
161 | if __name__ == "__main__":
162 | parser = ArgumentParser()
163 |
164 | parser.add_argument("--reindex", action="store_true", help="Reindex everything")
165 | parser.add_argument("--dry", action="store_true", help="Dry run")
166 |
167 | args = parser.parse_args()
168 |
169 | main(args)
170 |
--------------------------------------------------------------------------------
/src/build_index/es/index_musique.py:
--------------------------------------------------------------------------------
1 | from argparse import ArgumentParser
2 | from elasticsearch import Elasticsearch
3 | import html
4 | import json
5 | from tqdm import tqdm
6 |
7 | from itertools import chain
8 |
9 |
10 | def chunks(l, n):
11 | for i in range(0, len(l), n):
12 | yield l[i : i + n]
13 |
14 |
15 | INDEX_NAME = "musique"
16 |
17 |
18 | def process_line(data):
19 | item = {
20 | "id": data["id"],
21 | "url": "empty",
22 | "title": data["title"],
23 | "title_unescape": html.unescape(data["title"]),
24 | "text": data["text"],
25 | "title_bigram": html.unescape(data["title"]),
26 | "title_unescape_bigram": html.unescape(data["title"]),
27 | "text_bigram": data["text"],
28 | "original_json": json.dumps(data),
29 | }
30 | return "{}\n{}".format(
31 | json.dumps({"index": {"_id": "wiki-{}".format(data["id"])}}), json.dumps(item)
32 | )
33 |
34 |
35 | es = Elasticsearch(hosts="http://localhost:9200", timeout=50)
36 |
37 |
38 | def index_chunk(chunk):
39 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s")
40 | assert not res["errors"], res
41 |
42 |
43 | def main(args):
44 | # make index
45 | if not args.dry:
46 | if es.indices.exists(index=INDEX_NAME):
47 | es.indices.delete(index=INDEX_NAME, ignore=[400, 403])
48 | es.indices.create(
49 | index=INDEX_NAME,
50 | ignore=400,
51 | body=json.dumps(
52 | {
53 | "mappings": {
54 | "doc": {
55 | "properties": {
56 | "id": {"type": "keyword"},
57 | "url": {"type": "keyword"},
58 | "title": {
59 | "type": "text",
60 | "analyzer": "simple",
61 | "copy_to": "title_all",
62 | },
63 | "title_unescape": {
64 | "type": "text",
65 | "analyzer": "simple",
66 | "copy_to": "title_all",
67 | },
68 | "text": {
69 | "type": "text",
70 | "analyzer": "my_english_analyzer",
71 | },
72 | "anchortext": {
73 | "type": "text",
74 | "analyzer": "my_english_analyzer",
75 | },
76 | "title_bigram": {
77 | "type": "text",
78 | "analyzer": "simple_bigram_analyzer",
79 | "copy_to": "title_all_bigram",
80 | },
81 | "title_unescape_bigram": {
82 | "type": "text",
83 | "analyzer": "simple_bigram_analyzer",
84 | "copy_to": "title_all_bigram",
85 | },
86 | "text_bigram": {
87 | "type": "text",
88 | "analyzer": "bigram_analyzer",
89 | },
90 | "anchortext_bigram": {
91 | "type": "text",
92 | "analyzer": "bigram_analyzer",
93 | },
94 | "original_json": {"type": "string"},
95 | }
96 | }
97 | },
98 | "settings": {
99 | "analysis": {
100 | "my_english_analyzer": {
101 | "type": "standard",
102 | "stopwords": "_english_",
103 | },
104 | "simple_bigram_analyzer": {
105 | "tokenizer": "standard",
106 | "filter": ["lowercase", "shingle", "asciifolding"],
107 | },
108 | "bigram_analyzer": {
109 | "tokenizer": "standard",
110 | "filter": [
111 | "lowercase",
112 | "stop",
113 | "shingle",
114 | "asciifolding",
115 | ],
116 | },
117 | },
118 | },
119 | }
120 | ),
121 | )
122 |
123 | train = [
124 | json.loads(line.strip())
125 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_train.jsonl")
126 | ] + [
127 | json.loads(line.strip())
128 | for line in open("../../../data/corpus/musique/musique_full_v1.0_train.jsonl")
129 | ]
130 | dev = [
131 | json.loads(line.strip())
132 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_dev.jsonl")
133 | ] + [
134 | json.loads(line.strip())
135 | for line in open("../../../data/corpus/musique/musique_full_v1.0_dev.jsonl")
136 | ]
137 | test = [
138 | json.loads(line.strip())
139 | for line in open("../../../data/corpus/musique/musique_ans_v1.0_test.jsonl")
140 | ] + [
141 | json.loads(line.strip())
142 | for line in open("../../../data/corpus/musique/musique_full_v1.0_test.jsonl")
143 | ]
144 |
145 | tot = 0
146 | wikipedia_data = []
147 | hist = set()
148 | for item in tqdm(chain(train, dev, test)):
149 | for p in item["paragraphs"]:
150 | stamp = p["title"] + " " + p["paragraph_text"]
151 | if not stamp in hist:
152 | wikipedia_data.append(
153 | {"id": tot, "text": p["paragraph_text"], "title": p["title"]}
154 | )
155 | hist.add(stamp)
156 | tot += 1
157 |
158 | print("Making indexing queries...")
159 | all_queries = []
160 | for item in tqdm(wikipedia_data):
161 | all_queries.append(process_line(item))
162 |
163 | count = sum(len(queries.split("\n")) for queries in all_queries) // 2
164 |
165 | if not args.dry:
166 | print("Indexing...")
167 | chunksize = 100
168 | for chunk in tqdm(
169 | chunks(all_queries, chunksize),
170 | total=(len(all_queries) + chunksize - 1) // chunksize,
171 | ):
172 | res = es.bulk(index=INDEX_NAME, body="\n".join(chunk), timeout="100s")
173 | assert not res["errors"], res
174 |
175 | print(f"{count} documents indexed in total")
176 |
177 |
178 | if __name__ == "__main__":
179 | parser = ArgumentParser()
180 |
181 | parser.add_argument("--reindex", action="store_true", help="Reindex everything")
182 | parser.add_argument("--dry", action="store_true", help="Dry run")
183 |
184 | args = parser.parse_args()
185 |
186 | main(args)
187 |
--------------------------------------------------------------------------------
/src/es_retrieve.py:
--------------------------------------------------------------------------------
1 | import json
2 | from elasticsearch import Elasticsearch
3 | import json
4 | import re
5 | import yaml
6 |
7 | with open("../config/config.yaml", "r") as f:
8 | config = yaml.safe_load(f)
9 |
10 | core_title_matcher = re.compile("([^()]+[^\s()])(?:\s*\(.+\))?")
11 | core_title_filter = lambda x: (
12 | core_title_matcher.match(x).group(1) if core_title_matcher.match(x) else x
13 | )
14 |
15 |
16 | class ElasticSearch:
17 | def __init__(self, index_name):
18 | self.index_name = index_name
19 | self.client = Elasticsearch(config["es"]["url"])
20 |
21 | def _extract_one(self, item, lazy=False):
22 | res = {
23 | k: item["_source"][k]
24 | for k in ["id", "url", "title", "text", "title_unescape"]
25 | }
26 | res["_score"] = item["_score"]
27 | return res
28 |
29 | def rerank_with_query(self, query, results):
30 | def score_boost(item, query):
31 | score = item["_score"]
32 | core_title = core_title_filter(item["title_unescape"])
33 | if query.startswith("The ") or query.startswith("the "):
34 | query1 = query[4:]
35 | else:
36 | query1 = query
37 | if query == item["title_unescape"] or query1 == item["title_unescape"]:
38 | score *= 1.5
39 | elif (
40 | query.lower() == item["title_unescape"].lower()
41 | or query1.lower() == item["title_unescape"].lower()
42 | ):
43 | score *= 1.2
44 | elif item["title"].lower() in query:
45 | score *= 1.1
46 | elif query == core_title or query1 == core_title:
47 | score *= 1.2
48 | elif (
49 | query.lower() == core_title.lower()
50 | or query1.lower() == core_title.lower()
51 | ):
52 | score *= 1.1
53 | elif core_title.lower() in query.lower():
54 | score *= 1.05
55 |
56 | item["_score"] = score
57 | return item
58 |
59 | return list(
60 | sorted(
61 | [score_boost(item, query) for item in results],
62 | key=lambda item: -item["_score"],
63 | )
64 | )
65 |
66 | def single_text_query(self, query, topn=10, lazy=False, rerank_topn=50):
67 |
68 | constructed_query = {
69 | "multi_match": {
70 | "query": query,
71 | "fields": [
72 | "title^1.25",
73 | "title_unescape^1.25",
74 | "text",
75 | "title_bigram^1.25",
76 | "title_unescape_bigram^1.25",
77 | "text_bigram",
78 | ],
79 | }
80 | }
81 | res = self.client.search(
82 | index=self.index_name,
83 | body={"query": constructed_query, "timeout": "100s"},
84 | size=max(topn, rerank_topn),
85 | request_timeout=100,
86 | )
87 |
88 | res = [self._extract_one(x, lazy=lazy) for x in res["hits"]["hits"]]
89 | res = self.rerank_with_query(query, res)[:topn]
90 | res = [{"title": _["title"], "paragraph_text": _["text"]} for _ in res]
91 | return res
92 |
93 | def search(self, question, k=10):
94 | try:
95 | res = self.single_text_query(query=question, topn=k)
96 | return json.dumps(res, ensure_ascii=False)
97 | except Exception as err:
98 | print(Exception, err)
99 | raise
100 |
101 |
102 | def retrieve(index_name, query, topk):
103 | ES = ElasticSearch(index_name)
104 | result = ES.search(query, topk)
105 | return json.loads(result)
106 |
107 |
108 | if __name__ == "__main__":
109 |
110 | print(retrieve("musique", "Bilbo Baggins", 2))
111 |
--------------------------------------------------------------------------------
/src/eval.py:
--------------------------------------------------------------------------------
1 | import json
2 | import string
3 | from collections import Counter
4 | import re
5 | import numpy as np
6 |
7 |
8 | def acc_choice(predictions, answers):
9 | num_correct = 0
10 | total = len(predictions)
11 |
12 | for pred, ans_list in zip(predictions, answers):
13 | pred = pred.lower()
14 |
15 | correct = False
16 | for ans in ans_list:
17 | ans = ans.lower()
18 | if re.search(rf"\b{ans}\b", pred):
19 | correct = True
20 | break
21 | if correct:
22 | num_correct += 1
23 |
24 | acc = round(100 * num_correct / total, 2)
25 | return acc
26 |
27 |
28 | def acc_score(predictions, answers):
29 | num_correct = 0
30 | for id, answer in enumerate(answers):
31 | pred = predictions[id]
32 | correctness = (
33 | "True" if any(ans.lower() in pred.lower() for ans in answer) else "False"
34 | )
35 | if correctness == "True":
36 | num_correct += 1
37 | else:
38 | pass
39 | acc = num_correct / len(answers)
40 | return round(100 * acc, 2)
41 |
42 |
43 | def normalize_answer(s):
44 | """Lower text and remove punctuation, articles and extra whitespace."""
45 |
46 | def remove_articles(text):
47 | return re.sub(r"\b(a|an|the)\b", " ", text)
48 |
49 | def white_space_fix(text):
50 | return " ".join(text.split())
51 |
52 | def remove_punc(text):
53 | exclude = set(string.punctuation)
54 | return "".join(ch for ch in text if ch not in exclude)
55 |
56 | def lower(text):
57 | return text.lower()
58 |
59 | return white_space_fix(remove_articles(remove_punc(lower(s))))
60 |
61 |
62 | def f1_score(prediction, ground_truth):
63 | common = Counter(prediction) & Counter(ground_truth)
64 | num_same = sum(common.values())
65 | if num_same == 0:
66 | return 0
67 | precision = 1.0 * num_same / len(prediction)
68 | recall = 1.0 * num_same / len(ground_truth)
69 | f1 = (2 * precision * recall) / (precision + recall)
70 | return f1
71 |
72 |
73 | def qa_f1_score(prediction, ground_truth):
74 | normalized_prediction = normalize_answer(prediction)
75 | normalized_ground_truth = normalize_answer(ground_truth)
76 |
77 | prediction_tokens = normalized_prediction.split()
78 | ground_truth_tokens = normalized_ground_truth.split()
79 | return f1_score(prediction_tokens, ground_truth_tokens)
80 |
81 |
82 | def F1_scorer(predictions, answers):
83 | total_score = 0.0
84 | for prediction, ground_truths in zip(predictions, answers):
85 | score = 0.0
86 | for ground_truth in ground_truths:
87 | score = max(score, qa_f1_score(prediction, ground_truth))
88 | total_score += score
89 | return round(100 * total_score / len(predictions), 2)
90 |
91 |
92 | def compute_exact(predictions, answers):
93 | total_score = 0.0
94 | for prediction, ground_truths in zip(predictions, answers):
95 | score = 0.0
96 | for ground_truth in ground_truths:
97 | score = max(
98 | score,
99 | int(normalize_answer(prediction) == normalize_answer(ground_truth)),
100 | )
101 | total_score += score
102 | return round(100 * total_score / len(predictions), 2)
103 |
104 |
105 | def exact_presence(short_answers, context):
106 | """Verify if any of the answers is present in the given context.
107 | Args:
108 | short_answers: list of short answers to look for in the context
109 | context: a paragraph to search for short answers
110 | Returns:
111 | true if any of the short answers is present in the context
112 | """
113 |
114 | n_short_answers = [normalize_answer(sa) for sa in short_answers]
115 | n_context = normalize_answer(context)
116 |
117 | for ans in n_short_answers:
118 | if ans in n_context:
119 | return True
120 |
121 | return False
122 |
123 |
124 | def compute_str_em(data):
125 | """Compute STR-EM metric (only for ASQA)
126 | Args:
127 | data: requires field `qa_pairs/short_answers` and `output`
128 | Returns:
129 | STR-EM and STR-EM-HIT ()
130 | """
131 |
132 | acc = []
133 | hit = []
134 |
135 | for item in data:
136 | loc_acc = []
137 | for qa_pair in item["qa_pairs"]:
138 | loc_acc.append(exact_presence(qa_pair["short_answers"], item["output"]))
139 | acc.append(np.mean(loc_acc))
140 | hit.append(int(np.mean(loc_acc) == 1))
141 |
142 | return 100 * np.mean(acc), 100 * np.mean(hit)
143 |
144 |
145 | def eval_asqa(pred, raw_data_path="../data/eval/asqa/test.json"):
146 | with open(raw_data_path, "r") as f:
147 | raw_data = json.load(f)
148 | normalized_data = []
149 | for i in range(len(pred)):
150 | pred[i]["answer"] = pred[i]["output"].strip().replace("<|im_end|>", "")
151 | result = {}
152 | for id, data in enumerate(pred):
153 | normalized_data.append(
154 | {
155 | "question": raw_data[id]["question"],
156 | "qa_pairs": raw_data[id]["qa_pairs"],
157 | "output": data["answer"],
158 | }
159 | )
160 | result["str_em"], result["str_hit"] = compute_str_em(normalized_data)
161 |
162 | return result
163 |
164 |
165 | if __name__ == "__main__":
166 |
167 | from argparse import ArgumentParser
168 |
169 | parser = ArgumentParser()
170 | parser.add_argument(
171 | "--dataset",
172 | type=str,
173 | choices=["2wikimultihopqa", "hotpotqa", "musique", "asqa", "strategyqa"],
174 | default="hotpotqa",
175 | help="Dataset to use",
176 | )
177 | parser.add_argument("--pred_path", type=str, help="Location of the prediction file")
178 | args = parser.parse_args()
179 |
180 | outputs = []
181 | with open(args.pred_path, "r") as fin:
182 | for d in fin:
183 | pred = json.loads(d)
184 | if "id" in pred.keys():
185 | outputs.append(pred)
186 | if "asqa" in args.pred_path:
187 | result = eval_asqa(outputs)
188 | print("eval result:", result)
189 | else:
190 | predictions = [data["output"] for data in outputs]
191 | answers = [data["answer"] for data in outputs]
192 | if "strategyqa" in args.dataset:
193 | acc = acc_choice(predictions, answers)
194 | print("Acc:", acc)
195 | else:
196 | acc = acc_score(predictions, answers)
197 | f1 = F1_scorer(predictions, answers)
198 | em = compute_exact(predictions, answers)
199 |
200 | print("Acc:", acc, "F1:", f1, "EM:", em)
201 |
--------------------------------------------------------------------------------
/src/gen_dpo_data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import json
3 | import os
4 | import datetime
5 | from tqdm import tqdm
6 | from torch.utils.data import Dataset, DataLoader
7 | from vllm import LLM, SamplingParams
8 | import argparse
9 | from es_retrieve import retrieve
10 | import yaml
11 | import multiprocessing as mp
12 | from utils import (
13 | p_template,
14 | seed_everything,
15 | LLM_score_gen_new_query,
16 | LLM_score_compare_note,
17 | LLM_score_init_note,
18 | LLM_score_rag,
19 | )
20 |
21 | with open("../config/config.yaml", "r") as f:
22 | config = yaml.safe_load(f)
23 |
24 | parser = argparse.ArgumentParser()
25 | parser.add_argument(
26 | "--dataset",
27 | type=str,
28 | choices=["2wikimultihopqa"],
29 | default="2wikimultihopqa",
30 | help="Name of the dataset to use",
31 | )
32 | parser.add_argument(
33 | "--input_data_path",
34 | type=str,
35 | default="../data/corpus/2wikimultihopqa/train.json",
36 | help="Path to the input data file in JSON/JSONL format",
37 | )
38 | parser.add_argument(
39 | "--output_path",
40 | type=str,
41 | default="../data/dpo_data",
42 | help="Output directory path for generated data",
43 | )
44 | parser.add_argument(
45 | "--batch_size", type=int, default=9, help="Batch size for data processing"
46 | )
47 | parser.add_argument(
48 | "--num_samples",
49 | type=int,
50 | default=15000,
51 | help="Number of samples to generate for each type.",
52 | )
53 | parser.add_argument(
54 | "--per_data_gen_num",
55 | type=int,
56 | default=9,
57 | help="The number of times to generate data for each sample.",
58 | )
59 | parser.add_argument(
60 | "--model",
61 | type=str,
62 | choices=["llama-3.1-8b-instruct", "qwen2.5-7b-instruct"],
63 | default="llama-3.1-8b-instruct",
64 | help="Base model used for data generation",
65 | )
66 | parser.add_argument(
67 | "--score_model",
68 | type=str,
69 | choices=["gpt-4o-mini-2024-07-18"],
70 | default="gpt-4o-mini-2024-07-18",
71 | help="Model used for scoring",
72 | )
73 | parser.add_argument(
74 | "--gpu_memory_utilization",
75 | type=float,
76 | default=0.9,
77 | help="GPU memory utilization limit",
78 | )
79 | parser.add_argument(
80 | "--device",
81 | type=str,
82 | default="0,1,2,3,4,5,6,7",
83 | help="Comma separated list of devices to use for parallel processing",
84 | )
85 | args = parser.parse_args()
86 |
87 |
88 | def load_and_sample_data(file_path, sample_size):
89 | with open(file_path, "r", encoding="utf-8") as f:
90 | data = (
91 | [json.loads(line.strip()) for line in f]
92 | if "jsonl" in file_path
93 | else json.load(f)
94 | )
95 | return random.sample(data, sample_size)
96 |
97 |
98 | def retrieve_q(question, top_k):
99 | refs = retrieve(args.dataset, question, topk=top_k)
100 | text = ""
101 | for ref in refs:
102 | text += f"Title: {ref['title']}\nText: {ref['paragraph_text']} \n\n"
103 | return text
104 |
105 |
106 | class CustomDataset(Dataset):
107 | def __init__(self, data_list, args):
108 | self.data_list = data_list
109 | self.args = args
110 |
111 | def __getitem__(self, index):
112 | item = self.data_list[index]
113 | item["id"] = index
114 | return item
115 |
116 | def __len__(self):
117 | return len(self.data_list)
118 |
119 | def collate_fn(self, batch):
120 | batch = [data for data in batch]
121 | if not batch:
122 | return None
123 | ids = [f["id"] for f in batch]
124 | questions = [f.get("question", None) for f in batch]
125 | answers = [f.get("answer", None) for f in batch]
126 | return {
127 | "ids": ids,
128 | "questions": questions,
129 | "answers": answers,
130 | }
131 |
132 |
133 | def gen_data(split, device_id):
134 | os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)
135 | llm = LLM(
136 | model=model_path,
137 | tensor_parallel_size=1,
138 | trust_remote_code=True,
139 | dtype="bfloat16",
140 | gpu_memory_utilization=args.gpu_memory_utilization,
141 | )
142 | sampling_params = SamplingParams(max_tokens=1024, temperature=1, top_p=0.9)
143 |
144 | def call_llm(prompts, params=[sampling_params]):
145 | if "llama" in args.model:
146 | model_template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
147 | elif "qwen" in args.model or "minicpm" in args.model:
148 | model_template = (
149 | "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
150 | )
151 | prompts = [model_template.format(prompt=p) for p in prompts]
152 |
153 | responses = [
154 | response.outputs[0].text for response in llm.generate(prompts, params)
155 | ]
156 | return responses
157 |
158 | temperature_list = [0.1, 0.5, 0.9]
159 | top_p_list = [0.1, 0.5, 0.9]
160 | param_combinations = [
161 | (temp, top_p) for temp in temperature_list for top_p in top_p_list
162 | ]
163 | param_combinations = (
164 | param_combinations[: args.per_data_gen_num]
165 | if args.per_data_gen_num <= len(param_combinations)
166 | else param_combinations
167 | + random.choices(
168 | param_combinations, k=args.per_data_gen_num - len(param_combinations)
169 | )
170 | )
171 |
172 | params_dict = {"max_tokens": 1024}
173 |
174 | for batch in tqdm(split):
175 | if not batch:
176 | continue
177 |
178 | batch.update({"prompts": []})
179 | (
180 | batch_temp_rag,
181 | batch_temp_gen_query,
182 | batch_temp_init_note,
183 | batch_temp_refine_note,
184 | ) = ([], [], [], [])
185 | # Generate rag
186 | for id, question in enumerate(batch["questions"]):
187 | top_k = random.randint(3, 9)
188 | refs = retrieve_q(question, top_k)
189 | rag_prompt = p_template("base", {"query": question, "refs": refs})
190 | init_note_prompt = p_template(
191 | "init_note", {"query": question, "refs": refs}
192 | )
193 | rag_sampling_params = [
194 | SamplingParams(**params_dict, temperature=temp, top_p=top_p)
195 | for temp, top_p in param_combinations
196 | ]
197 | rags_generated = call_llm(
198 | [rag_prompt] * len(rag_sampling_params), rag_sampling_params
199 | )
200 |
201 | best_rag, worst_rag = LLM_score_rag(rags_generated, batch["answers"][id])
202 |
203 | if best_rag and worst_rag:
204 | batch_temp_rag.append(
205 | {
206 | "id": f"rag-{batch['ids'][id]}",
207 | "raw_question": question,
208 | "prompt": rag_prompt,
209 | "chosen": best_rag,
210 | "rejected": worst_rag,
211 | "data_type": "rag",
212 | "gen_response_list": [
213 | {
214 | "index": idx,
215 | "response": response,
216 | "temperature": param_combination[0],
217 | "top_p": param_combination[1],
218 | }
219 | for idx, (response, param_combination) in enumerate(
220 | zip(rags_generated, param_combinations)
221 | )
222 | ],
223 | }
224 | )
225 |
226 | # Generate init_note
227 |
228 | init_note_sampling_params = [
229 | SamplingParams(**params_dict, temperature=temp, top_p=top_p)
230 | for temp, top_p in param_combinations
231 | ]
232 | init_notes_generated = call_llm(
233 | [init_note_prompt] * len(init_note_sampling_params),
234 | init_note_sampling_params,
235 | )
236 |
237 | best_init_note, worst_init_note = LLM_score_init_note(
238 | args.score_model, question, refs, init_notes_generated
239 | )
240 | if best_init_note and worst_init_note:
241 | batch_temp_init_note.append(
242 | {
243 | "id": f"init_note-{batch['ids'][id]}",
244 | "raw_question": question,
245 | "prompt": init_note_prompt,
246 | "chosen": best_init_note,
247 | "rejected": worst_init_note,
248 | "data_type": "init_note",
249 | "gen_response_list": [
250 | {
251 | "index": idx,
252 | "response": response,
253 | "temperature": param_combination[0],
254 | "top_p": param_combination[1],
255 | }
256 | for idx, (response, param_combination) in enumerate(
257 | zip(init_notes_generated, param_combinations)
258 | )
259 | ],
260 | }
261 | )
262 |
263 | # Generate queries
264 | gen_query_prompt = p_template(
265 | "gen_new_query",
266 | {"query": question, "note": best_init_note, "query_log": []},
267 | )
268 | gen_query_sampling_params = [
269 | SamplingParams(**params_dict, temperature=temp, top_p=top_p)
270 | for temp, top_p in param_combinations
271 | ]
272 | responses = call_llm(
273 | [gen_query_prompt] * len(gen_query_sampling_params),
274 | gen_query_sampling_params,
275 | )
276 |
277 | gen_querys = responses
278 |
279 | best_query, worst_query = LLM_score_gen_new_query(
280 | args.score_model, question, best_init_note, gen_querys, []
281 | )
282 |
283 | refs = retrieve_q((best_query + "\n" + question)[:500], top_k)
284 | refine_note_prompt = p_template(
285 | "refine_note", {"query": question, "refs": refs, "note": best_init_note}
286 | )
287 | batch["prompts"].append(refine_note_prompt)
288 | if best_query and worst_query:
289 | batch_temp_gen_query.append(
290 | {
291 | "id": f"new_query-{batch['ids'][id]}",
292 | "raw_question": question,
293 | "prompt": gen_query_prompt,
294 | "chosen": best_query,
295 | "rejected": worst_query,
296 | "data_type": "gen_new_query",
297 | "gen_response_list": [
298 | {
299 | "index": idx,
300 | "response": gen_query,
301 | "temperature": param_combination[0],
302 | "top_p": param_combination[1],
303 | }
304 | for idx, (gen_query, param_combination) in enumerate(
305 | zip(gen_querys, param_combinations)
306 | )
307 | ],
308 | }
309 | )
310 |
311 | # Generate refined notes
312 | for temp, top_p in param_combinations:
313 | sampling_params = SamplingParams(
314 | **params_dict, temperature=temp, top_p=top_p
315 | )
316 | prompts = batch["prompts"]
317 |
318 | responses = call_llm(prompts, [sampling_params] * len(prompts))  # one refined note per prompt
319 |
320 | for idx, response in enumerate(responses):
321 | if len(batch_temp_refine_note) <= idx:
322 | batch_temp_refine_note.append(
323 | {
324 | "id": f"refine_note-{batch['ids'][idx]}",
325 | "raw_question": batch["questions"][idx],
326 | "prompt": prompts[idx],
327 | "data_type": "refine_note",
328 | "gen_response_list": [],
329 | }
330 | )
331 | score = LLM_score_compare_note(
332 | args.score_model, batch["questions"][idx], best_init_note, response
333 | )
334 | batch_temp_refine_note[idx]["gen_response_list"].append(
335 | {
336 | "index": len(batch_temp_refine_note[idx]["gen_response_list"]),
337 | "response": response,
338 | "score_flag": score,
339 | "temperature": temp,
340 | "top_p": top_p,
341 | }
342 | )
343 |
344 | for refine_note in batch_temp_refine_note:
345 | chosen_list = [
346 | item["response"]
347 | for item in refine_note["gen_response_list"]
348 | if item["score_flag"]
349 | ]
350 | rejected_list = [
351 | item["response"]
352 | for item in refine_note["gen_response_list"]
353 | if not item["score_flag"]
354 | ]
355 | refine_note["chosen"] = random.choice(chosen_list) if chosen_list else ""
356 | refine_note["rejected"] = (
357 | random.choice(rejected_list) if rejected_list else ""
358 | )
359 | batch_temp_refine_note = [
360 | refine_note
361 | for refine_note in batch_temp_refine_note
362 | if refine_note["chosen"] and refine_note["rejected"]
363 | ]
364 | batch_temp = (
365 | batch_temp_init_note
366 | + batch_temp_gen_query
367 | + batch_temp_refine_note
368 | + batch_temp_rag
369 | )
370 | if batch_temp:
371 | with open(output_file, "a", encoding="utf-8") as f:
372 | json_lines = [
373 | json.dumps(data, ensure_ascii=False) for data in batch_temp
374 | ]
375 | f.write("\n".join(json_lines) + "\n")
376 |
377 |
378 | def main():
379 |
380 | visible_devices = args.device.split(",")
381 | num_devices = len(visible_devices)
382 |
383 | data = load_and_sample_data(args.input_data_path, args.num_samples)
384 |
385 | split_data = lambda lst, n: [lst[i::n] for i in range(n)]
386 |
387 | dataset = CustomDataset(data, args)
388 | dataloader = DataLoader(
389 | dataset, batch_size=args.batch_size, collate_fn=dataset.collate_fn
390 | )
391 |
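392 | # Split the batches round-robin across the visible devices and launch one generation worker process per GPU.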
392 | splits = split_data(list(dataloader), num_devices)
393 | processes = []
394 |
395 | for rank, device_id in enumerate(visible_devices):
396 | split_dataloader = splits[rank]
397 | p = mp.Process(target=gen_data, args=(list(split_dataloader), device_id))
398 | p.start()
399 | processes.append(p)
400 |
401 | for p in processes:
402 | p.join()
403 |
404 |
405 | if __name__ == "__main__":
406 | current_date = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S")
407 | seed_everything(66)
408 | model_path = config["model"][args.model]
409 | output_file = os.path.join(
410 | args.output_path,
411 | f"dpo_data-{args.model}-score_{args.score_model}_num-{args.num_samples}_para-{args.per_data_gen_num}-{current_date}.jsonl",
412 | )
413 | main()
414 |
--------------------------------------------------------------------------------
/src/main.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import argparse
4 | import logging
5 | import datetime
6 | import yaml
7 | import faiss
8 | import pandas as pd
9 | from sentence_transformers import SentenceTransformer
10 | import backoff
11 | from multiprocessing.pool import ThreadPool
12 | from tqdm import tqdm
13 | from eval import acc_score, F1_scorer, compute_exact, eval_asqa, acc_choice
14 | from utils import seed_everything
15 | from vllm import LLM, SamplingParams
16 |
17 | parser = argparse.ArgumentParser()
18 | parser.add_argument(
19 | "--model",
20 | type=str,
21 | choices=[
22 | "gpt-4o-mini-2024-07-18",
23 | "Llama-3-1-70B-Instruct",
24 | "llama-3.1-8b-instruct",
25 | "llama-3.1-8b-instruct-dpo",
26 | "qwen2.5-7b-instruct",
27 | "qwen2.5-7b-instruct-dpo",
28 | ],
29 | default="llama-3.1-8b-instruct",
30 | help="Model to use",
31 | )
32 | parser.add_argument(
33 | "--max_step", type=int, default=3, help="Maximum number of update steps"
34 | )
35 | parser.add_argument(
36 | "--max_fail_step", type=int, default=2, help="Maximum number of failed steps"
37 | )
38 | parser.add_argument(
39 | "--MaxClients", type=int, default=1, help="Number of concurrent clients"
40 | )
41 | parser.add_argument(
42 | "--retrieve_top_k",
43 | type=int,
44 | default=5,
45 | help="Number of documents to retrieve per query",
46 | )
47 | parser.add_argument(
48 | "--max_top_k",
49 | type=int,
50 | default=15,
51 | help="Total maximum number of documents to retrieve",
52 | )
53 | parser.add_argument(
54 | "--dataset",
55 | type=str,
56 | choices=[
57 | "2wikimultihopqa",
58 | "hotpotqa",
59 | "musique",
60 | "asqa",
61 | "strategyqa",
62 | ],
63 | default="hotpotqa",
64 | help="Dataset to use",
65 | )
66 | parser.add_argument(
67 | "--method",
68 | type=str,
69 | default="deepnote",
70 | choices=["deepnote", "base", "base_wo_retri"],
71 | help="Method to use",
72 | )
73 | parser.add_argument(
74 | "--resume_path",
75 | type=str,
76 | default="",
77 | help="Path to the file for resuming generation",
78 | )
79 | parser.add_argument(
80 | "--retrieve_method",
81 | type=str,
82 | default="es",
83 | help="Retrieval method to use (es: ElasticSearch, emb: Dense Retrieval)",
84 | )
85 | parser.add_argument(
86 | "--device", type=str, default="cuda:0", help="Device to run inference on"
87 | )
88 | args = parser.parse_args()
89 |
90 | with open("../config/config.yaml", "r") as f:
91 | config = yaml.safe_load(f)
92 |
93 |
94 | @backoff.on_exception(backoff.expo, (Exception), max_time=500)
95 | def call_gpt(prompt_file, variable_dict):
96 | with open(prompt_file, "r") as fin:
97 | prompt = fin.read().format(**variable_dict)
98 | res = gpt_gen(args.model, prompt)
99 | assert res is not None
100 | return res
101 |
102 |
103 | def call_local(prompt_file, variable_dict):
104 |
105 | with open(prompt_file, "r") as fin:
106 | prompt = fin.read()
107 | if "llama" in args.model:
108 | model_template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
109 | elif "qwen" in args.model:
110 | model_template = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
111 | prompt = model_template.format(prompt=prompt.format(**variable_dict))
112 | response = llm.generate(prompt, sampling_params)[0].outputs[0].text
113 | return response
114 |
115 |
116 | if "gpt" in args.model:
117 | from utils import gpt_gen
118 |
119 | call_llm = call_gpt
120 |
121 | else:
122 | llm = LLM(
123 | model=config["model"][args.model],
124 | tensor_parallel_size=1,
125 | trust_remote_code=True,
126 | dtype="bfloat16",
127 | gpu_memory_utilization=0.9,
128 | )
129 | sampling_params = SamplingParams(max_tokens=1280, temperature=0.1, top_p=0.9)
130 |
131 | call_llm = call_local
132 |
133 |
134 | def get_context(data):
135 |
136 | if retrieve_method == "emb":
137 | text = "\n".join(data)
138 | else:
139 | text = ""
140 | for i in range(len(data)):
141 | text += f"Title: {data[i]['title']}\nText: {data[i]['paragraph_text']} \n\n"
142 | return text
143 |
144 |
145 | def save_log_to_file(logger, log_file="my_log", log_folder="logs"):
146 | if not os.path.exists(log_folder):
147 | os.makedirs(log_folder)
148 | current_date = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S")
149 | log_file_name = f"{log_file}-{current_date}.log"
150 | file_handler = logging.FileHandler(os.path.join(log_folder, log_file_name))
151 | logger.addHandler(file_handler)
152 |
153 |
154 | def is_directory_empty(directory_path: str) -> bool:
155 | try:
156 | return len(os.listdir(directory_path)) == 0
157 | except OSError:
158 | return False
159 |
160 |
161 | def call_llm_template(template, variables):
162 | return call_llm(f"../prompts/{LANGUAGE}/{template}", variables)
163 |
164 |
165 | def init_note(query, refs):
166 | return call_llm_template("init_note", {"query": query, "refs": refs})
167 |
168 |
169 | def gen_new_query(query, note, query_log):
170 | return call_llm_template(
171 | "gen_new_query", {"query": query, "note": note, "query_log": query_log}
172 | )
173 |
174 |
175 | def refine_note(note, query, refs):
176 | return call_llm_template(
177 | "refine_note", {"note": note, "query": query, "refs": refs}
178 | )
179 |
180 |
181 | def gen_answer(query, note_str):
182 | if args.dataset in ["asqa", "strategyqa"]:
183 | template = f"gen_answer_{args.dataset}"
184 | else:
185 | template = "gen_answer"
186 | return call_llm_template(template, {"query": query, "note": note_str})
187 |
188 |
189 | def compare_note(query, best_note, new_note):
190 | response = call_llm_template(
191 | "compare", {"query": query, "best_note": best_note, "new_note": new_note}
192 | )
193 |
194 | try:
195 | if "json" in response:
196 | response = response[response.index("{") : response.rindex("}") + 1]
197 | status = json.loads(response)["status"]
198 | return status.lower() == "true"
199 | except Exception:
200 | return "true" in response.lower()
201 |
202 |
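203 | # Core DeepNote loop: draft an initial note from retrieved references, then repeatedly generate a new query, retrieve, refine the note, and keep the refined note only when the comparison judges it an improvement.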
203 | def retrieve_note(doc_id, query, answer, top_k=2):
204 | ref_log, llm_times, note_log, query_log, query_list = [], 0, [], [], []
205 | refs = retrieve(args.dataset, query=query, topk=top_k)
206 | note = best_note = init_note(query, get_context(refs))
207 | ref_log.append({"refs": refs, "step": 0, "flag": "init_refs"})
208 | note_log.append({"note": note, "step": 0, "flag": "init_note"})
209 | llm_times += 1
210 | step, notes_status = 0, []
211 |
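212 | # Iterate until max_step rounds are reached, the max_top_k reference budget is exhausted, or max_fail_step refinements have been rejected.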
212 | while step < args.max_step:
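213 | # Collect every reference seen so far (raw passages for dense retrieval, title + text for ES) so newly retrieved documents can be de-duplicated.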
213 | all_refs = {
214 | ref
215 | for d in ref_log
216 | for ref in (
217 | d["refs"]
218 | if retrieve_method == "emb"
219 | else [r["title"] + r["paragraph_text"] for r in d["refs"]]
220 | )
221 | }
222 | if len(all_refs) > args.max_top_k:
223 | break
224 | max_ref = (
225 | args.max_top_k - len(all_refs)
226 | if (len(all_refs) + top_k) > args.max_top_k
227 | else 0
228 | )
229 |
230 | new_query = gen_new_query(query, best_note, str(query_list))
231 | query_list.append(new_query)
232 | llm_times += 1
233 |
234 | refs = retrieve(
235 | args.dataset, query=(new_query + "\n" + query)[:500], topk=top_k
236 | )
237 | if max_ref > 0:
238 | refs = (
239 | [d for d in refs if d not in all_refs][:max_ref]
240 | if retrieve_method == "emb"
241 | else [
242 | d for d in refs if d["title"] + d["paragraph_text"] not in all_refs
243 | ][:max_ref]
244 | )
245 |
246 | note = refine_note(best_note, query, get_context(refs)).replace("\n", "")
247 | llm_times += 1
248 | status = compare_note(query, best_note, note)
249 | flag = "True" if status else "False"
250 |
251 | ref_log.append({"refs": refs, "step": step, "flag": flag})
252 | note_log.append({"note": note, "step": step, "flag": flag})
253 | query_log.append({"query": new_query, "step": step, "flag": flag})
254 |
255 | if status:
256 | best_note = note
257 | notes_status.append(status)
258 | if notes_status.count(False) >= args.max_fail_step:
259 | break
260 | step += 1
261 |
262 | llm_times += 1
263 | return {
264 | "id": doc_id,
265 | "question": query,
266 | "answer": answer,
267 | "output": gen_answer(query, best_note),
268 | "deepnote": best_note,
269 | "query_log": query_log,
270 | "note_log": note_log,
271 | "ref_log": ref_log,
272 | }
273 |
274 |
275 | def process_doc_cell(idx, doc_cell, args):
276 | id_new, query, answer = idx, doc_cell["question"], doc_cell["answer"]
277 |
278 | if args.method != "deepnote":
279 | prompt_path = f"../prompts/{LANGUAGE}"
280 | if args.dataset in ["asqa", "strategyqa"]:
281 | gen_name = f"{args.method}_{args.dataset}"
282 | else:
283 | gen_name = args.method
284 | output = (
285 | call_llm(f"{prompt_path}/{gen_name}", {"query": query})
286 | if "wo_retri" in args.method
287 | else call_llm(
288 | f"{prompt_path}/{gen_name}",
289 | {
290 | "query": query,
291 | "refs": get_context(
292 | retrieve(args.dataset, query=query, topk=args.retrieve_top_k)
293 | ),
294 | },
295 | )
296 | )
297 |
298 | return {"id": id_new, "query": query, "answer": answer, "output": output}
299 | else:
300 | return retrieve_note(id_new, query, answer, top_k=args.retrieve_top_k)
301 |
302 |
303 | if __name__ == "__main__":
304 |
305 | LANGUAGE = "en"
306 | seed_everything(66)
307 |
308 | logger = logging.getLogger(__name__)
309 | logging.basicConfig(
310 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
311 | datefmt="%m/%d/%Y %H:%M:%S",
312 | level=logging.INFO,
313 | )
314 | save_log_to_file(
315 | logger,
316 | log_file=f"{args.dataset}_{args.method}_{args.model}",
317 | log_folder="../log",
318 | )
319 | logger.info(f"{'*' * 30} CONFIGURATION {'*' * 30}")
320 | for key, val in sorted(vars(args).items()):
321 | keystr = "{}".format(key) + (" " * (30 - len(key)))
322 | logger.info("%s --> %s", keystr, val)
323 | dataset_name = args.dataset
324 | vector_path = f"../data/corpus/{dataset_name}/{dataset_name}.index"
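325 | # Select the retriever: GTR + FAISS over the Wikipedia corpus for ASQA/StrategyQA, BGE + FAISS for dense retrieval, otherwise BM25 via es_retrieve.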
325 | if args.dataset in ["asqa", "strategyqa"]:
326 |
327 | vector = faiss.read_index("../data/corpus/wiki/wiki.index")
328 | emb_model = SentenceTransformer(config["model"]["gtr-t5-xxl"], device="cpu")
329 | raw_data = pd.read_csv("../data/corpus/wiki/psgs_w100.tsv", sep="\t")
330 |
331 | def retrieve(_, query, topk):
332 | feature = emb_model.encode([query])
333 | _, match_id = vector.search(feature, topk)
334 | return [
335 | raw_data.iloc[i]["title"] + "\n" + raw_data.iloc[i]["text"]
336 | for i in match_id[0]
337 | ]
338 |
339 | elif args.retrieve_method == "emb":
340 |
341 | emb_model = SentenceTransformer(
342 | config["model"]["bge-base-en-v1.5"], device=args.device
343 | )
344 | with open(f"../data/corpus/{args.dataset}/chunk.json", encoding="utf-8") as f:
345 | raw_data = json.load(f)
346 | vector = faiss.read_index(vector_path)
347 |
348 | def retrieve(_, query, topk):
349 | feature = emb_model.encode([query])
350 | _, match_id = vector.search(feature, topk)
351 | return [raw_data[i] for i in match_id[0]]
352 |
353 | else:
354 | from es_retrieve import retrieve
355 |
356 | formatted_time = datetime.datetime.now().strftime("%Y%m%d-%H:%M:%S")
357 |
358 | with open(f"../data/eval/{args.dataset}/test.json", encoding="utf-8") as f:
359 | qa_data = json.load(f)
360 |
361 | retrieve_method = args.retrieve_method
362 | if args.dataset in ["asqa", "strategyqa"]:
363 | retrieve_method = "emb"
364 |
365 | save_path = f"../output/{args.dataset}/{retrieve_method}/{args.method}/{args.model}"
366 | os.makedirs(save_path, exist_ok=True)
367 |
368 | all_result = []
369 |
370 | if args.resume_path:
371 | with open(args.resume_path, "r", encoding="utf-8") as fin:
372 | resume_data = [json.loads(i) for i in fin.readlines()]
373 | all_result = resume_data
374 | filepath = args.resume_path
375 | else:
376 | resume_data = []
377 | filepath = (
378 | f"{save_path}/topk-{args.retrieve_top_k}-{formatted_time}.jsonl"
379 | if args.method != "deepnote"
380 | else f"{save_path}/topk-{args.retrieve_top_k}__max_step-{args.max_step}__max_fail_step-{args.max_fail_step}-{formatted_time}.jsonl"
381 | )
382 | logger.info(f"The predicted results will be saved in '{filepath}'.")
383 | last_id = len(resume_data)
384 | batch_size = args.MaxClients
385 | logger.info("start predicting ...")
386 | for i in tqdm(range(last_id, len(qa_data), batch_size)):
387 | pool = ThreadPool(processes=args.MaxClients)
388 | current_batch = qa_data[i : i + batch_size]
389 | tasks = [
390 | (idx + i, doc_cell, args) for idx, doc_cell in enumerate(current_batch)
391 | ]
392 |
393 | results = pool.starmap(process_doc_cell, tasks)
394 | pool.close()
395 | pool.join()
396 |
397 | for result in results:
398 | if result:
399 | all_result.append(result)
400 | with open(filepath, "a", buffering=1) as fout:
401 | fout.write(json.dumps(result, ensure_ascii=False) + "\n")
402 |
403 | logger.info("start evaluating ...")
404 |
405 | predictions = [data["output"] for data in all_result]
406 | answers = [data["answer"] for data in all_result]
407 |
408 | if "asqa" in args.dataset:
409 | eval_result = eval_asqa(all_result)
410 | elif "strategyqa" in args.dataset:
411 | acc = acc_choice(predictions, answers)
412 | eval_result = {"Acc": acc}
413 | else:
414 | acc = acc_score(predictions, answers)
415 | f1 = F1_scorer(predictions, answers)
416 | em = compute_exact(predictions, answers)
417 | eval_result = {"Acc": acc, "F1": f1, "EM": em}
418 |
419 | if eval_result:
420 | with open(filepath, "a", buffering=1) as fout:
421 | fout.write(json.dumps(eval_result, ensure_ascii=False) + "\n")
422 |
423 | logger.info(f"eval result: {eval_result}")
424 |
--------------------------------------------------------------------------------
/src/select_dpo_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import re
4 | import random
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | "--init_num", type=int, default=1900, help="Number of init_note samples"
9 | )
10 | parser.add_argument(
11 | "--refine_num", type=int, default=1900, help="Number of refine_note samples"
12 | )
13 | parser.add_argument(
14 | "--query_num", type=int, default=1900, help="Number of gen_query samples"
15 | )
16 | parser.add_argument("--rag_num", type=int, default=300, help="Number of RAG samples")
17 | parser.add_argument(
18 | "--data_path", type=str, required=True, help="Path to input data file"
19 | )
20 | parser.add_argument(
21 | "--output_path",
22 | type=str,
23 | default="../data/dpo/processed/train.jsonl",
24 | help="Path to output file",
25 | )
26 | args = parser.parse_args()
27 |
28 | with open(args.data_path, "r", encoding="utf-8") as file:
29 | data = [json.loads(i) for i in file]
30 | random.shuffle(data)
31 |
32 |
33 | def extract_and_match(input):
34 | match = re.search(r"-(\d+)", input)
35 | return match.group(1) if match else None
36 |
37 |
38 | # Collect samples for each category
39 | init_data, refine_data, query_data, rag_data = [], [], [], []
40 | for d in data:
41 | if not (d["chosen"] and d["rejected"]):
42 | continue
43 |
44 | fields = {"prompt": d["prompt"], "chosen": d["chosen"], "rejected": d["rejected"]}
45 |
46 | if "init" in d["id"] and len(init_data) < args.init_num:
47 | init_data.append(fields)
48 | elif "refine" in d["id"] and len(refine_data) < args.refine_num:
49 | refine_data.append(fields)
50 | elif "query" in d["id"] and len(query_data) < args.query_num:
51 | query_data.append(fields)
52 | elif "rag" in d["id"] and len(rag_data) < args.rag_num:
53 | rag_data.append(fields)
54 |
55 | # Combine and save processed data
56 | train_data = init_data + refine_data + query_data + rag_data
57 | random.shuffle(train_data)
58 |
59 | with open(args.output_path, "w") as fout:
60 | for res in train_data:
61 | fout.write(json.dumps(res, ensure_ascii=False) + "\n")
62 | print(
63 | f"Init: {len(init_data)}, Refine: {len(refine_data)}, Query: {len(query_data)}, RAG: {len(rag_data)}"
64 | )
65 | print(f"Total training samples: {len(train_data)}")
66 |
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import torch
4 | from datasets import load_dataset
5 | from transformers import AutoModelForCausalLM, AutoTokenizer
6 | from functools import partial
7 | from trl import (
8 | DPOConfig,
9 | DPOTrainer,
10 | ModelConfig,
11 | ScriptArguments,
12 | TrlParser,
13 | get_kbit_device_map,
14 | get_peft_config,
15 | get_quantization_config,
16 | )
17 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
18 |
19 |
20 | def preprocessing(example, args, tokenizer):
21 | prompt = example["prompt"]
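22 | # Wrap the raw prompt in the model-specific chat template (Llama-3 header tags or Qwen-2.5 im_start/im_end markers).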
22 | if "llama" in model_args.model_name_or_path.lower():
23 | template = "<|start_header_id|>user<|end_header_id|>\n\n{prompt}<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"
24 | elif "qwen" in model_args.model_name_or_path.lower():
25 | template = "<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
26 |
27 | # Ensure the prompt does not exceed the max prompt length
28 | token_prompt = tokenizer(prompt, truncation=True, max_length=args.max_prompt_length)
29 | prompt = tokenizer.decode(token_prompt.input_ids)
30 | prompt = template.format(prompt=prompt)
31 |
32 | chosen = example["chosen"]
33 | rejected = example["rejected"]
34 |
35 | one_item = {"prompt": prompt, "chosen": chosen, "rejected": rejected}
36 | return one_item
37 |
38 |
39 | def main(script_args, training_args, model_args):
40 | ################
41 | # Model & Tokenizer
42 | ###################
43 | torch_dtype = (
44 | model_args.torch_dtype
45 | if model_args.torch_dtype in ["auto", None]
46 | else getattr(torch, model_args.torch_dtype)
47 | )
48 | quantization_config = get_quantization_config(model_args)
49 | model_kwargs = dict(
50 | revision=model_args.model_revision,
51 | attn_implementation=model_args.attn_implementation,
52 | torch_dtype=torch_dtype,
53 | use_cache=False if training_args.gradient_checkpointing else True,
54 | device_map=get_kbit_device_map() if quantization_config is not None else None,
55 | quantization_config=quantization_config,
56 | )
57 | model = AutoModelForCausalLM.from_pretrained(
58 | model_args.model_name_or_path,
59 | trust_remote_code=model_args.trust_remote_code,
60 | **model_kwargs
61 | )
62 | peft_config = get_peft_config(model_args)
63 | if peft_config is None:
64 | ref_model = AutoModelForCausalLM.from_pretrained(
65 | model_args.model_name_or_path,
66 | trust_remote_code=model_args.trust_remote_code,
67 | **model_kwargs
68 | )
69 | else:
70 | ref_model = None
71 | tokenizer = AutoTokenizer.from_pretrained(
72 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
73 | )
74 | if tokenizer.pad_token is None:
75 | tokenizer.pad_token = tokenizer.eos_token
76 | if tokenizer.chat_template is None:
77 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
78 | if script_args.ignore_bias_buffers:
79 | # torch distributed hack
80 | model._ddp_params_and_buffers_to_ignore = [
81 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
82 | ]
83 |
84 | ################
85 | # Dataset
86 | ################
87 |
88 | partial_preprocess = partial(preprocessing, args=training_args, tokenizer=tokenizer)
89 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
90 | dataset = dataset.map(partial_preprocess)
91 |
92 | ##########
93 | # Training
94 | ################
95 | trainer = DPOTrainer(
96 | model,
97 | ref_model,
98 | args=training_args,
99 | train_dataset=dataset[script_args.dataset_train_split],
100 | eval_dataset=(
101 | dataset[script_args.dataset_test_split]
102 | if training_args.eval_strategy != "no"
103 | else None
104 | ),
105 | processing_class=tokenizer,
106 | peft_config=peft_config,
107 | )
108 |
109 | trainer.train()
110 |
111 | if training_args.eval_strategy != "no":
112 | metrics = trainer.evaluate()
113 | trainer.log_metrics("eval", metrics)
114 | trainer.save_metrics("eval", metrics)
115 |
116 | # Save and push to hub
117 | trainer.save_model(training_args.output_dir)
118 | if training_args.push_to_hub:
119 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
120 |
121 |
122 | def make_parser(subparsers: argparse._SubParsersAction = None):
123 | dataclass_types = (ScriptArguments, DPOConfig, ModelConfig)
124 | if subparsers is not None:
125 | parser = subparsers.add_parser(
126 | "dpo", help="Run the DPO training script", dataclass_types=dataclass_types
127 | )
128 | else:
129 | parser = TrlParser(dataclass_types)
130 | return parser
131 |
132 |
133 | if __name__ == "__main__":
134 | parser = make_parser()
135 | script_args, training_args, model_args = parser.parse_args_and_config()
136 | main(script_args, training_args, model_args)
137 |
--------------------------------------------------------------------------------
/src/train.sh:
--------------------------------------------------------------------------------
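1 | # Launch DPO training with torchrun on 8 GPUs using DeepSpeed ZeRO-3 (../config/ds_config_zero3.json).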
1 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes=1 --nproc_per_node=8 --master_addr localhost --master_port 7428 --node_rank 0 train.py \
2 | --model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
3 | --dataset_name ../data/dpo/processed/train \
4 | --trust_remote_code \
5 | --max_length 8192 \
6 | --max_prompt_length 8000 \
7 | --output_dir ../save_model/llama \
8 | --save_steps 500 \
9 | --gradient_accumulation_steps 1 \
10 | --per_device_train_batch_size 1 \
11 | --per_device_eval_batch_size 1 \
12 | --learning_rate 5e-7 \
13 | --logging_strategy steps \
14 | --logging_steps 50 \
15 | --logging_dir ../save_model/llama \
16 | --bf16 True \
17 | --num_train_epochs 1 \
18 | --report_to "tensorboard" \
19 | --save_only_model \
20 | --gradient_checkpointing \
21 | --deepspeed ../config/ds_config_zero3.json
22 |
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os, time, json, re
2 | import torch
3 | from openai import OpenAI
4 | import random
5 | import numpy as np
6 | import yaml
7 | import backoff
8 |
9 | with open("../config/config.yaml", "r") as f:
10 | config = yaml.safe_load(f)
11 | os.environ["OPENAI_API_KEY"] = config["model"]["OPENAI_API_KEY"]
12 | os.environ["OPENAI_BASE_URL"] = config["model"]["OPENAI_BASE_URL"]
13 | client = OpenAI()
14 |
15 |
16 | def p_template(template, variables):
17 | prompt_file = f"../prompts/en/{template}"
18 | with open(prompt_file, "r") as fin:
19 | prompt = fin.read()
20 | prompt = prompt.format(**variables)
21 | return prompt
22 |
23 |
24 | def gpt_gen(model, content, temperature=0.1, top_p=0.9):
25 | try:
26 | completion = client.chat.completions.create(
27 | model=model,
28 | temperature=temperature,
29 | top_p=top_p,
30 | max_tokens=1280,
31 | messages=[{"role": "user", "content": content}],
32 | )
33 |
34 | return completion.choices[0].message.content
35 |
36 | except Exception as e:
37 | print(f"An error occurred: {e}")
38 | time.sleep(0.5)
39 |
40 | return None
41 |
42 |
43 | @backoff.on_exception(backoff.expo, (Exception), max_time=100)
44 | def call_gpt(model, content, temperature=1, top_p=0.9):
45 | res = gpt_gen(model, content, temperature, top_p)
46 | assert res is not None
47 | return res
48 |
49 |
50 | def extract_best_worst(sequence):
51 | try:
52 | if "json" in sequence:
53 | sequence = sequence[sequence.index("{") : sequence.rindex("}") + 1]
54 | result = json.loads(sequence)
55 | best = result["best_id"]
56 | worst = result["worst_id"]
57 |
58 | except Exception:
59 |
60 | pattern_best = re.compile(r'"best_id"\s*:\s*(\d+)')
61 | pattern_worst = re.compile(r'"worst_id"\s*:\s*(\d+)')
62 |
63 | best_match = pattern_best.search(sequence)
64 | worst_match = pattern_worst.search(sequence)
65 |
66 | best = int(best_match.group(1)) if best_match else None
67 | worst = int(worst_match.group(1)) if worst_match else None
68 | if best == worst:
69 | return None, None
70 | return best, worst
71 |
72 |
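73 | # Heuristic pairing for RAG answers: responses matching the gold answer (exactly, or containing it while staying short) become "chosen"; responses missing it (or overly long) become "rejected"; return one random pair, or empty strings if either side is missing.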
73 | def LLM_score_rag(responses, answer):
74 | chosen, rejected = [], []
75 | for item in responses:
76 | if item.strip().lower() == answer.lower():
77 | chosen.append(item)
78 | if answer.lower() not in item.lower():
79 | rejected.append(item)
80 | if len(chosen) == 0:
81 | for item in responses:
82 | if answer.strip().lower() in item.lower() and len(item) < 2 * len(answer):
83 | chosen.append(item)
84 | if len(rejected) == 0:
85 | for item in responses:
86 | if len(item.strip()) > 2 * len(answer):
87 | rejected.append(item)
88 | if len(chosen) > 0 and len(rejected) > 0:
89 | return random.sample(chosen, 1)[0], random.sample(rejected, 1)[0]
90 | else:
91 | return "", ""
92 |
93 |
94 | def LLM_score_init_note(model_name, question, refs, init_notes):
95 | init_notes_prompt = [
96 | {"_id": id, "content": content} for id, content in enumerate(init_notes)
97 | ]
98 | prompt = config["score"]["init_note"].format(
99 | query=question, refs=refs, notes=init_notes_prompt
100 | )
101 | response = call_gpt(model_name, prompt)
102 | best, worst = extract_best_worst(response)
103 | try:
104 | best_note, worst_note = init_notes[best], init_notes[worst]
105 | except Exception:
106 | return init_notes[-1], ""
107 | return best_note, worst_note
108 |
109 |
110 | def LLM_score_gen_new_query(
111 | model_name, question, best_init_note, new_querys, query_log
112 | ):
113 | querys_prompt = [
114 | {"_id": id, "content": content} for id, content in enumerate(new_querys)
115 | ]
116 | prompt = config["score"]["gen_new_query"].format(
117 | notes=best_init_note,
118 | query=question,
119 | query_log=query_log,
120 | new_querys=querys_prompt,
121 | )
122 | response = call_gpt(model_name, prompt)
123 | best, worst = extract_best_worst(response)
124 |
125 | try:
126 | best_query, worst_query = new_querys[best], new_querys[worst]
127 | except Exception:
128 | return new_querys[-1], ""
129 | return best_query, worst_query
130 |
131 |
132 | def LLM_score_compare_note(model_name, query, best_note, new_note):
133 | prompt = p_template(
134 | "compare", {"query": query, "best_note": best_note, "new_note": new_note}
135 | )
136 | response = call_gpt(model_name, prompt)
137 |
138 | try:
139 | if "json" in response:
140 | response = response[response.index("{") : response.rindex("}") + 1]
141 | status = json.loads(response)["status"]
142 | return status.lower() == "true"
143 | except Exception:
144 | return "true" in response.lower()
145 |
146 |
147 | def seed_everything(seed):
148 | torch.manual_seed(seed)
149 | torch.cuda.manual_seed(seed)
150 | np.random.seed(seed)
151 | random.seed(seed)
152 | torch.backends.cudnn.benchmark = False
153 | torch.backends.cudnn.deterministic = True
154 | torch.cuda.manual_seed_all(seed)
155 |
--------------------------------------------------------------------------------