├── BERRI
│   ├── README.md
│   ├── berri_instructions.tsv
│   ├── create_tart_dual_train_data.py
│   ├── create_tart_full_train_data.py
│   ├── denoising.py
│   └── enc_t5
│       ├── __init__.py
│       ├── modeling_enc_t5.py
│       └── tokenization_enc_t5.py
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── TART
│   ├── custom_metrics.py
│   ├── eval_beir.py
│   ├── eval_cross_task.py
│   ├── finetuning.py
│   ├── finetuning_tart_full.py
│   ├── generate_passage_embeddings.py
│   ├── interactive.py
│   ├── passage_retrieval.py
│   ├── preprocess.py
│   ├── requirements.txt
│   ├── src
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── contriever.cpython-39.pyc
│   │   │   ├── data.cpython-39.pyc
│   │   │   ├── dist_utils.cpython-39.pyc
│   │   │   ├── evaluation.cpython-39.pyc
│   │   │   ├── index.cpython-39.pyc
│   │   │   ├── modeling_enc_t5.cpython-39.pyc
│   │   │   ├── normalize_text.cpython-39.pyc
│   │   │   ├── slurm.cpython-39.pyc
│   │   │   ├── tokenization_enc_t5.cpython-39.pyc
│   │   │   └── utils.cpython-39.pyc
│   │   ├── beir_utils.py
│   │   ├── contriever.py
│   │   ├── data.py
│   │   ├── dist_utils.py
│   │   ├── evaluation.py
│   │   ├── finetuning_data.py
│   │   ├── inbatch.py
│   │   ├── index.py
│   │   ├── moco.py
│   │   ├── modeling_enc_t5.py
│   │   ├── normalize_text.py
│   │   ├── options.py
│   │   ├── rerank.py
│   │   ├── slurm.py
│   │   ├── tokenization_enc_t5.py
│   │   └── utils.py
│   └── train.py
├── cross_task_cross_eval
│   ├── create_cross_task_data.py
│   └── download_create_data.sh
└── figures
    └── intro.png
/BERRI/README.md:
--------------------------------------------------------------------------------
1 | # BERRI
2 |
3 | **BERRI** is a collection of retrieval tasks with instructions.
4 |
5 | For legal reasons, Meta cannot directly release the preprocessed data, so this repository contains the scripts to re-process and create it.
6 | You can also download the data processed by a third party below.
7 |
8 | You can download the processed source data (from processing step (i)) as well as the final training data for TART-dual and TART-full, processed by a third party, here:
9 | - [source data (22 GB)](https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link)
10 | - [TART-full training data (1 GB)](https://drive.google.com/file/d/1oijzAb2gWKT54OgeE7_KB9VcHvA7UxpQ/view?usp=share_link)
11 | - [TART-dual training data (14 GB)](https://drive.google.com/file/d/1lMmD5lTxYWYf0z0ua0-GaGKz2qs2mG1r/view?usp=share_link)
12 |
13 |
14 | ## Preprocessing
15 | First, please download the corpus (`corpus.tsv`) and the source data file (`qa.jsonl`) for all retrieval tasks [here](https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link).
16 |
17 |
18 | ### Step 1: run Contriever to find the top documents
19 |
20 | - generate embeddings
21 | First, generate passage embeddings using `facebook/contriever-msmarco`.
22 |
23 | ```sh
24 | cd ../TART
25 | for i in {0..7}; do
26 | export CUDA_VISIBLE_DEVICES=${i}
27 | nohup python generate_passage_embeddings.py --model_name_or_path facebook/contriever-msmarco --output_dir OUTPUT_DIR_NAME \
28 | --passages ../BERRI/berri_corpus_data/TASK_NAME/corpus.tsv --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 &
29 | done
30 | ```
31 |
32 | Then, retrieve top passages as follows:
33 |
34 | ```sh
35 | python passage_retrieval.py \
36 | --model_name_or_path facebook/contriever-msmarco \
37 | --passages ../BERRI/berri_corpus_data/TASK_NAME/corpus.tsv \
38 | --passages_embeddings "YOUR_EMBEDDING_PATH/*" \
39 | --data ../BERRI/berri_corpus_data/TASK_NAME/qa_data.json \
40 | --output_dir PATH_TO_OUTPUT_DIR --n_docs 100
41 | ```
42 |
43 | ### Step 2: Denoise the passages
44 |
45 | ```sh
46 | python denoising.py \
47 | --task_name TASK_NAME \
48 | --train_file PATH_TO_TRAIN_FILE \
49 | --test_input_file output_dir/qa_data.json \
50 | --model_name_or_path PATH_TO_DENOISING_MODEL_NAME \
51 | --output_dir PATH_TO_OUTPUT_DIR \
52 | --do_predict \
53 | --evaluation_strategy steps \
54 | --max_seq_length 512 --overwrite_cache --top_k 30 \
55 | --instruction_file berri_instructions.tsv # only for creating tart-dual training data.
56 | ```
57 |
58 | ### Step 3: Combine denoised results and create training data
59 |
60 | - Creating TART-dual training data
61 | ```sh
62 | python create_tart_dual_train_data.py \
63 | --inst_file berri_instructions.tsv \
64 |     --output_file PATH_TO_OUTPUT_FILE_PREFIX \
65 |     --input_dir PATH_TO_DENOISED_RESULTS
66 | ```
67 | - Creating TART-full training data
68 | ```sh
69 | python create_tart_full_train_data.py \
70 | --input_file berri_instructions.tsv \
71 |     --output_dir PATH_TO_OUTPUT_DIR \
72 |     --input_dir PATH_TO_DENOISED_RESULTS
73 | ```
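
Both scripts write JSON Lines output: `create_tart_full_train_data.py` writes `tart_full_train.json` and `tart_full_dev.json` under `--output_dir`, while `create_tart_dual_train_data.py` writes `<output_file>_train.jsonl` and `<output_file>_dev.jsonl`. The snippet below is a small sketch (the paths are placeholders) for inspecting a few of the generated examples:

```py
# A small sketch to peek at the generated training data (paths are placeholders).
import jsonlines

# TART-full: cross-encoder style records, i.e. {"query": "<instruction> <question>", "document": ..., "label": 0 or 1}.
with jsonlines.open("PATH_TO_OUTPUT_DIR/tart_full_train.json") as reader:
    example = next(iter(reader))
print(example["query"], example["label"])

# TART-dual: DPR-style records whose "question" field is "<instruction> [SEP] <question>".
with jsonlines.open("PATH_TO_OUTPUT_FILE_PREFIX_train.jsonl") as reader:
    example = next(iter(reader))
print(example["question"], len(example["positive_ctxs"]))
```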
74 |
75 |
76 |
77 |
--------------------------------------------------------------------------------
/BERRI/berri_instructions.tsv:
--------------------------------------------------------------------------------
1 | dataset prompt_1 prompt_2 prompt_3 prompt_4 prompt_5 prompt_6 prompt_7 prompt_8
2 | altlex Find an sentence from Simple Wikipedia that corresponds to the following Wikipedia sentence Retrieve a sentence that talks about the same as the following in a simple or shorter English A simplified sentence from Simple Wikipedia of this Wikipedia sentence You need to find a simplified sentence from Simple Wikipedia corresponding to the following sentence
3 | cnn_dailymail The following sentences are the summaries of an news article. Find the source news article. Retrieve a news article that is summarized as following Find the original news article based on which the following highlights are written
4 | coco_captions Retrieve a image caption sentence that is written for the same image as the following caption Find a image caption describing the same image as Can you find an image caption talking about the same image as
5 | codesearch_go Match the following natural language instruction to Go codes Retrieve Go implementations achieving the following features Find a Go code implementation on GitHub for the following natural language instruction I want to find a Go code implementing the following function from Github.
6 | codesearch_java Match the following natural language instruction to Java codes Retrieve Java implementations achieving the following features Find a Java code implementation on GitHub for the following natural language instruction I want to find Java code implementing the following function from Github.
7 | codesearch_javascript Match the following natural language instruction to JavaScript codes Retrieve JavaScript implementations achieving the following features Find a JavaScript code implementation on GitHub for the following natural language instruction I want to find JavaScript code implementing the following function from Github.
8 | codesearch_ruby Match the following natural language instruction to Ruby codes Retrieve Ruby implementations achieving the following features Find a Ruby code implementation on GitHub for the following natural language instruction I want to find Ruby code implementing the following function from Github.
9 | eli5_question_answer Find an answer to this question for me. Retrieve an paragraph-length answer to this question. You have to find answer to this question.
10 | fever Retrieve a Wikipedia paragraph to verify this claim Find an evidence paragraph from Wikipedia to confirm the statement is correct I want to know if this sentence is a fact or not. Can you find relented Wikipedia passages for me? You need to find Wikipedia paragraphs that support or refute this sentence
11 | gigaword Retrieve a extremely short summary of the following Gigaword article Find a corresponding headline of the this Gigaword article summary. I want to retrieve a headline of this gigaword news
12 | hotpotqa Retrieve a Wikipedia paragraph that provides useful evidence or answer for this question I want to know the answer of this question. Please find related Wikipedia passages for me. Find a paragraph that provides useful information to answer this question You need to find multiple Wikipedia passages to answer this multi-hop question. Can you find related passages?
13 | yahoo_answers_title_answer Find an answer from a community QA forum for the following question Retrieve the most voted answer for this question from Yahoo Answers. This is an question's title posted in Yahoo Answers, a community forum. Please retrieve the answer post.
14 | mdmcqa Find an evidence from medical text book to answer this medical exam question I want to know the answer of this following medical exam question. Retrieve an evidence passage answering question Find the explanation for the correct answer of this medical question
15 | medical_sim Please retrieve a medical paper summary that is written in a simple language so that my patient can understand You need to find a simple summary of medical paper that corresponds to the following paper abstract You need to retrieve a medical paper summary written in a simple English of this paper Match this following medical paper abstract to a simplified English summary for non experts
16 | msmarco-triplets I want to know the answer to the question. Can you find good evidence on the web? Retrieve a web paragraph that answers the following Find an evidence paragraph on the web with evidence and answer for this
17 | multilexsum Find a short summary of this following legal case Map this legal case summary to a sentence-long summary An extremely short summary sentence of this legal case Retrieve a one-sentence summary of the following legal case
18 | nq retrieve passages from Wikipedia that provides answers to the following question You have to find a Wikipedia paragraph that provides the answer to the question I want to find an answer for question. Can you find some paragraphs that provide evidence from Wikipedia? I'm looking for a Wikipedia paragraph that answers this question. Give me a Wikipedia paragraph answering this open-domain question. A Wikipedia paragraph providing sufficient evidence to answer this question Your job is to find a Wikipedia paragraph that answers my question You need to retrieve an evidence paragraph from Wikipedia to answer this question
19 | oqa Find a question that is paraphrased of this Retrieve a question that is duplicated with the following You need to find duplicate questions in Wiki forum. Could you find a question that is similar to this question Find an open-domain question that is similar to the following An open-domain question that is duplicated with the following
20 | agnews Find a news summary sentence corresponding to the following header Retrieve an sentence-long news summary for this header Please find a good summary of this news from the news summary collection
21 | pubmedqa Find a related medical paper to answer the following question I want to check if the following statement is true or not. Retrieve a scientific paper from PubMed for me Help me to find a highly related pubmed paper to answer this question Retrieve a medical paper abstract answering the following question
22 | qrecc Find a meaningful dialogue response to answer the user's question Retrieve a good dialogue response that answers this question You need to find a good response from a collection of previous responses and help users to know this topic more
23 | record Find a News article to verify the following sentence I want to know if this sentence is true. Please retrieve a highly-relevant news article News articles that provide a piece of sufficient evidence to verify the following statement
24 | scitldr Find a sentence-length summary of this paper. Retrieve a really short summary of this paper abstract. What is the TLDR of this paper? Retrieve a sentence that summarizes the following abstract
25 | searchQA_top5_snippets Pick up the top web search results' snippets for the following question. Find the top 5 Web snippets that answer this You have to match this question to top five web search snippets.
26 | sentence-compression Find a short sentence that compresses the following long sentence You have to match this long sentence to a shorter compressed one Retrieve a compressed version of the following sentence written by human annotators
27 | npr Given a news article headline published at npr.org, find a corresponding summary of the news Retrieve a news summary of the following news headline You have to match the following news article to a short summary
28 | squad_pairs.jsonl Find a Wikipedia paragraph that answer the question You have to retrieve a paragraph from Wikipedia that provides evidence and answer for this question This question asks about the details written in a Wikipedia paragraph. Select the paragraph the question is about
29 | stackexchange_duplicate_questions_title-body_title-body Find a question paragraph that is duplicated with the following question paragraph at StackExchange. I want to find a question similar to this question already asked in StackExchange. Retrieve a question (main text) that is similar to this following paragraph StackExchange is a community QA forum for diverse topics including technical or science. Help me to find a question paragraph that duplicates with my following question paragraph
30 | stackexchange_duplicate_questions_title_title Find an duplicated question on StackExchange, a community forum. Find a question title that is similar to the following question title asked in StackExchange, a community QA forum I want to find a related question asked in StackExchange. Can you find one for me?
31 | paq Find a web paragraph long answer for this question Can you answer my question by finding an article on the web? Retrieve a paragraph answer for the following question
32 | triviaqa retrieve a related web article that provides evidences and answers for the following Trivia question You have to find a answer paragraph on the web for this question I want to find an answer for the following Trivia questions. Can you find some paragraphs that provide evidence on the web?
33 | wikihow Find the how-to article to achieve the following goal from Wikihow, a website collecting how-to articles. WikiHow is an wiki-style community and database of how-to guides. Suggest the best article for the following Find a detailed paragraph from WikiHow that explains how-to to achieve
34 | wow Find an Wikipedia paragraph that is related to the current conversation topic to generate a meaningful response. Retrieve a paragraph from Wikipedia to help a conversational AI to generate a knowledge-grounded dialogue You must find a Wikipedia paragraph to help a chatbot to generate informative response. Can you find one? Find an Wikipedia paragraph related to the following conversation topic.
35 | xsum Find a short summary of this following legal case Map this legal case summary to a sentence-long summary An extremely short summary sentence of this legal case
36 | quora Find a question that is duplicated with the following question asked in Quora, a community QA forum I want to find a question similar to this question already asked in Quora. Retrieve a question body that is similar to Check if a Quora question is duplicated with this question
37 | ccnews Retrieve a news article that corresponds to this title Find the original news article that this title is written for I want to know the details of this news. Can you find a detailed news article on this for me?
38 | wow response Find a knowledgeable response for an AI chat bot given the following user's inputs Retrieve a informative knowledge-grounded dialogue response given this AI-user chat history A good chat bot response to answer this user's query
--------------------------------------------------------------------------------
/BERRI/create_tart_dual_train_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | import argparse
8 | import json
9 | import glob
10 | import random
11 | import jsonlines
12 | import numpy as np
13 | import pandas as pd
14 |
15 | import numpy as np
16 | from tqdm import tqdm
17 | import copy
18 |
19 |
20 | def load_data(data_path):
21 | data = []
22 | with open(data_path, "r") as fin:
23 | for k, example in enumerate(fin):
24 | example = json.loads(example)
25 | data.append(example)
26 | return data
27 |
28 |
29 | def process_instruction_unfollowing_sample(input_data, prompts):
30 | new_data = {}
31 | for item in input_data:
32 | query = item["question"]
33 | sampled_context = random.sample(item["ctxs"], k=1)
34 | new_data[query] = sampled_context
35 | return new_data
36 |
37 |
38 | def load_jsonlines(file_name):
39 | with jsonlines.open(file_name, 'r') as jsonl_f:
40 | data = [obj for obj in jsonl_f]
41 | return data
42 |
43 |
44 | def process_prompts(row):
45 | # load instructions
46 | prompts = []
47 | for i in range(1, 9):
48 | column_name = "prompt_{}".format(i)
49 |         if type(row[column_name]) == str:
50 | prompts.append(row[column_name])
51 | return prompts
52 |
53 |
54 | def main(args):
55 |     print(f"Loading instructions from: {args.inst_file}")
56 | final_data = []
57 | min_neg_nums = 100
58 | file_names = pd.read_csv(args.inst_file, sep="\t")
59 |
60 | if args.instruction_unfollowing_file is not None:
61 | inst_unfollow_dict = {}
62 | inst_unfollow_file_names = open(
63 | args.instruction_unfollowing_file).read().split("\n")[:-1]
64 | # Load instruction unfollowing files.
65 | # src indicates the original task name (query task) while tgt indicate the corpus task name.
66 | for file in inst_unfollow_file_names:
67 | src_task_name = (file.split("src_")[1]).split("_tgt")[0]
68 | inst_unfollow_dict.setdefault(src_task_name, [])
69 | inst_unfollow_dict[src_task_name].append(file)
70 |
71 | for _, file_data in tqdm(file_names.iterrows()):
72 | input_file = os.path.join(
73 | args.input_dir, file_data["dataset"] + ".jsonl")
74 | prompts = process_prompts(file_data)
75 | input_data = load_jsonlines(input_file)
76 | # mostly for ablation. Skip tasks that are specified in `ignore_tasks`, or that are not specified in `task_names`.
77 | if args.ignore_tasks is not None and file_data["dataset"] in args.ignore_tasks:
78 | print("skip {0}".format(file_data["dataset"]))
79 | continue
80 | if args.task_names is not None and file_data["dataset"] not in args.task_names:
81 | print("skip {0}".format(file_data["dataset"]))
82 | continue
83 | print("# of data: {}".format(len(input_data)))
84 |
85 | # add instruction-unfollowing data.
86 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict:
87 | for unfollowing_file_name in inst_unfollow_dict[file_data["dataset"]]:
88 | unfollowing_file_name = glob.glob(
89 | unfollowing_file_name + "/*.json*")[0]
90 | print(unfollowing_file_name)
91 | unfollowing_input_data = load_data(unfollowing_file_name)
92 | processed_training_data_unfollowing = process_instruction_unfollowing_sample(
93 | unfollowing_input_data, prompts)
94 |
95 | for datapoint in input_data:
96 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict and datapoint["question"] in processed_training_data_unfollowing:
97 | instructions_unfollowing_negatives = processed_training_data_unfollowing[
98 | datapoint["question"]]
99 | else:
100 | instructions_unfollowing_negatives = []
101 | if len(datapoint["question"]) < 20:
102 | continue
103 |
104 | # To make a model robust to different instructions, sample two prompts per instance.
105 | if len(prompts) > 2:
106 | sampled_prompts = random.sample(prompts, k=2)
107 | else:
108 | sampled_prompts = prompts
109 | true_negatives = []
110 | true_hard_negatives = []
111 | # final check if the paragraph is not included in the labeled gold paragraphs.
112 | for neg in datapoint["negative_ctxs"]:
113 | skip = False
114 | for pos in datapoint["positive_ctxs"]:
115 | if neg["text"] == pos["text"]:
116 | print("false negatives")
117 | skip = True
118 | if skip is False:
119 | true_negatives.append(neg)
120 |
121 | # a cross encoder can falsely predict a positive paragraph "negative".
122 | # check if the paragraph is not included in the labeled gold paragraphs.
123 | for neg in datapoint["hard_negative_ctxs"]:
124 | skip = False
125 | for pos in datapoint["positive_ctxs"]:
126 | if neg["text"] == pos["text"]:
127 | print("false negatives")
128 | skip = True
129 | if skip is False:
130 | true_hard_negatives.append(neg)
131 | datapoint["negative_ctxs"] = true_negatives
132 | datapoint["hard_negative_ctxs"] = true_hard_negatives
133 |
134 | # create training samples.
135 |
136 | for p in sampled_prompts:
137 | new_data = copy.deepcopy(datapoint)
138 | new_data["question"] = "{0} [SEP] {1}".format(
139 | p, new_data["question"])
140 | if args.only_negative is True:
141 | # do not add additional CE scores
142 | new_data["positive_ctxs"] = [
143 | pos for pos in new_data["positive_ctxs"] if "ce_score" not in pos]
144 |
145 | if len(new_data["positive_ctxs"]) > args.num_positive_paragraphs:
146 | new_data["positive_ctxs"] = random.sample(
147 | new_data["positive_ctxs"], k=args.num_positive_paragraphs)
148 | for pos in new_data["positive_ctxs"]:
149 |
150 | if "title" not in pos:
151 |                         pos["title"] = ""
152 |
153 | if pos["title"] is None:
154 | print("none title")
155 | pos["title"] = ""
156 |
157 | if type(pos["text"]) is list:
158 | pos["text"] = pos["text"][0]
159 |
160 | if len(new_data["negative_ctxs"]) > args.num_negative_paragraphs:
161 | new_data["negative_ctxs"] = random.sample(
162 | new_data["negative_ctxs"], k=args.num_negative_paragraphs)
163 | for neg in new_data["negative_ctxs"]:
164 | if type(neg["text"]) is list:
165 | neg["text"] = neg["text"][0]
166 |
167 | if "title" not in neg:
168 | neg["title"] = ""
169 |
170 | if neg["title"] is None:
171 | neg["title"] = ""
172 |
173 | if len(new_data["hard_negative_ctxs"]) > args.num_hard_negative_paragraphs:
174 | new_data["hard_negative_ctxs"] = random.sample(
175 | new_data["hard_negative_ctxs"], k=args.num_hard_negative_paragraphs)
176 | for neg in new_data["hard_negative_ctxs"]:
177 | if type(neg["text"]) is list:
178 | neg["text"] = neg["text"][0]
179 | if "title" not in neg:
180 | neg["title"] = ""
181 | if neg["title"] is None:
182 | neg["title"] = ""
183 |
184 |                 if len(instructions_unfollowing_negatives) > args.num_instructions_unfollowing_negatives:
185 |                     instructions_unfollowing_negatives = random.sample(
186 |                         instructions_unfollowing_negatives, k=args.num_instructions_unfollowing_negatives)
187 | for neg in instructions_unfollowing_negatives:
188 | if type(neg["text"]) is list:
189 | neg["text"] = neg["text"][0]
190 | if "title" not in neg:
191 | neg["title"] = ""
192 | if neg["title"] is None:
193 | neg["title"] = ""
194 | new_data["hard_negative_ctxs"].append(neg)
195 |
196 | assert len(new_data["positive_ctxs"]) > 0
197 | final_data.append(new_data)
198 | if len(new_data["negative_ctxs"]) + len(new_data["hard_negative_ctxs"]) < min_neg_nums:
199 | min_neg_nums = len(
200 | new_data["negative_ctxs"]) + len(new_data["hard_negative_ctxs"])
201 | print("# of data: {}".format(len(final_data)))
202 |
203 | random.shuffle(final_data)
204 | # split data into train and dev set for development purpose.
205 | train_data, dev_data = final_data[5000:], final_data[:5000]
206 |
207 | with jsonlines.open(args.output_file + "_train.jsonl", "w") as writer:
208 | writer.write_all(train_data)
209 | with jsonlines.open(args.output_file + "_dev.jsonl", "w") as writer:
210 | writer.write_all(dev_data)
211 |
212 |
213 | if __name__ == "__main__":
214 | parser = argparse.ArgumentParser()
215 |
216 | parser.add_argument(
217 | "--inst_file",
218 | required=True,
219 | type=str,
220 | default=None,
221 |         help="tsv file with the per-dataset instructions (e.g., berri_instructions.tsv)",
222 | )
223 |
224 | parser.add_argument(
225 | "--only_negative",
226 | action="store_true"
227 | )
228 |
229 | parser.add_argument(
230 | "--kd",
231 | action="store_true"
232 | )
233 | parser.add_argument(
234 | "--ignore_tasks",
235 | type=str,
236 | nargs="+"
237 | )
238 |
239 | parser.add_argument(
240 | "--input_dir",
241 | required=True,
242 | type=str,
243 | default=None,
244 |         help="directory containing the denoised .jsonl files, one per dataset",
245 | )
246 |     parser.add_argument("--output_file", type=str, default=None,
247 |                         help="prefix of the output files; _train.jsonl and _dev.jsonl are appended")
248 |     parser.add_argument("--sample_dataset_num", type=int, default=None,
249 |                         help="if set, randomly sample this many datasets")
250 |     parser.add_argument("--output_dir", type=str, default=None,
251 |                         help="output directory")
252 | parser.add_argument("--task_names", type=str, default=None,
253 | help="task names filter", nargs="+")
254 | parser.add_argument("--prompts", type=str, default=None,
255 | help="prompt", nargs="+")
256 | parser.add_argument("--per_gpu_batch_size", type=int,
257 | default=64, help="Batch size for question encoding")
258 | parser.add_argument("--n_docs", type=int, default=100,
259 | help="Number of documents to retrieve per questions")
260 | parser.add_argument("--instruction_unfollowing_file", type=str, default=None,
261 | help="instruction unfollowing file")
262 | parser.add_argument(
263 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
264 | )
265 | parser.add_argument("--no_fp16", action="store_true",
266 | help="inference in fp32")
267 | parser.add_argument("--question_maxlength", type=int,
268 | default=512, help="Maximum number of tokens in a question")
269 | parser.add_argument("--start_idx", type=int, default=None,)
270 | parser.add_argument("--end_idx", type=int, default=None,)
271 | parser.add_argument("--num_positive_paragraphs", type=int, default=2,)
272 | parser.add_argument("--num_negative_paragraphs", type=int, default=7,)
273 | parser.add_argument("--num_hard_negative_paragraphs", type=int, default=3,)
274 | parser.add_argument(
275 | "--num_instructions_unfollowing_negatives", type=int, default=1,)
276 | args = parser.parse_args()
277 | main(args)
278 |
--------------------------------------------------------------------------------
/BERRI/create_tart_full_train_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import os
7 | import argparse
8 | import json
9 | import glob
10 | import random
11 | import jsonlines
12 | import pandas as pd
13 |
14 | from tqdm import tqdm
15 |
16 |
17 | def load_jsonlines(file_name):
18 | with jsonlines.open(file_name, 'r') as jsonl_f:
19 | data = [obj for obj in jsonl_f]
20 | return data
21 |
22 |
23 | def process_data(input_data, prompts):
24 | new_data = []
25 | false_negatives = 0
26 | for item in input_data:
27 | query = item["question"]
28 | positive_ctxs = item["positive_ctxs"]
29 |         if len(positive_ctxs) > args.num_positive_paragraphs:
30 |             positive_ctxs = random.sample(
31 |                 positive_ctxs, k=args.num_positive_paragraphs)
32 | negative_ctxs = item["negative_ctxs"] + \
33 | item["hard_negative_ctxs"] if "hard_negative_ctxs" in item else item["negative_ctxs"]
34 | final_negatives = []
35 | for neg in negative_ctxs:
36 | if neg["text"] in [pos["text"] for pos in item["positive_ctxs"]]:
37 | false_negatives += 1
38 | continue
39 | else:
40 | final_negatives.append(neg)
41 | negative_ctxs = final_negatives
42 |
43 |         if len(negative_ctxs) > args.num_negative_paragraphs:
44 |             negative_ctxs = random.sample(
45 |                 negative_ctxs, k=args.num_negative_paragraphs)
46 |
47 | for pos in positive_ctxs:
48 | prompt = random.sample(prompts, k=1)[0]
49 | prompted_query = "{0} {1}".format(prompt, query)
50 |
51 | title = pos["title"]
52 | text = pos["text"]
53 | if title is not None and len(title) > 0:
54 | text = "{0} {1}".format(title, text)
55 | new_data.append(
56 | {"query": prompted_query, "document": text, "label": 1})
57 | for neg in negative_ctxs:
58 | prompt = random.sample(prompts, k=1)[0]
59 | prompted_query = "{0} {1}".format(prompt, query)
60 |
61 | title = neg["title"]
62 | text = neg["text"]
63 | if title is not None and len(title) > 0:
64 | text = "{0} {1}".format(title, text)
65 | new_data.append(
66 | {"query": prompted_query, "document": text, "label": 0})
67 | print("{} data created".format(len(new_data)))
68 | print("false negatives {}".format(false_negatives))
69 | return new_data
70 |
71 |
72 | def process_instruction_unfollowing_sample(input_data, prompts, full_data_num):
73 | inst_num = min(int(full_data_num * 0.8 * 0.2), 10000)
74 | if len(input_data) > 7500:
75 | input_data = random.sample(input_data, k=inst_num)
76 | new_data = []
77 | for item in input_data:
78 | query = item["question"]
79 | sampled_context = random.sample(item["ctxs"], k=1)
80 | for ctx in sampled_context:
81 | prompt = random.sample(prompts, k=1)[0]
82 | prompted_query = "{0} {1}".format(prompt, query)
83 |
84 | title = ctx["title"]
85 | text = ctx["text"]
86 |             if title is not None and len(title) > 0:
87 | text = "{0} {1}".format(title, text)
88 | new_data.append(
89 | {"query": prompted_query, "document": text, "label": 0})
90 | print("instructions unfollowing samples")
91 | print("{} data created".format(len(new_data)))
92 | return new_data
93 |
94 |
95 | def process_prompts(row):
96 | # load instructions
97 | prompts = []
98 | for i in range(1, 9):
99 | column_name = "prompt_{}".format(i)
100 |         if type(row[column_name]) == str:
101 | prompts.append(row[column_name])
102 | return prompts
103 |
104 |
105 | def load_data(data_path):
106 | data = []
107 | with open(data_path, "r") as fin:
108 | for k, example in enumerate(fin):
109 | example = json.loads(example)
110 | data.append(example)
111 | return data
112 |
113 |
114 | def main(args):
115 | if not os.path.exists(args.output_dir):
116 | os.makedirs(args.output_dir)
117 |
118 | all_data = []
119 | file_names = pd.read_csv(args.input_file, sep="\t")
120 | if args.instruction_unfollowing_file is not None:
121 | inst_unfollow_dict = {}
122 | inst_unfollow_file_names = open(
123 | args.instruction_unfollowing_file).read().split("\n")[:-1]
124 | for file in inst_unfollow_file_names:
125 | src_task_name = (file.split("src_")[1]).split("_tgt")[0]
126 | inst_unfollow_dict.setdefault(src_task_name, [])
127 | inst_unfollow_dict[src_task_name].append(file)
128 |
129 | if args.sample_dataset_num is not None:
130 | dataset_names = file_names["dataset"]
131 | sampled_datasets = random.sample(
132 | list(dataset_names), k=args.sample_dataset_num)
133 |
134 | for idx, file_data in tqdm(file_names.iterrows()):
135 | if args.task_names is not None and file_data["dataset"] not in args.task_names:
136 | continue
137 | if args.start_idx is not None and idx < args.start_idx:
138 | continue
139 | if args.end_idx is not None and idx > args.end_idx:
140 | continue
141 |
142 | if os.path.exists(os.path.join(args.input_dir, file_data["dataset"] + ".jsonl")) is False:
143 |             print("file is missing:")
144 | print(os.path.join(args.input_dir,
145 | file_data["dataset"] + ".jsonl"))
146 | print(file_data["dataset"])
147 | continue
148 |
149 | if args.sample_dataset_num is not None and file_data["dataset"] not in sampled_datasets:
150 | continue
151 |
152 | task_input_file = os.path.join(
153 | args.input_dir, file_data["dataset"] + ".jsonl")
154 | print(task_input_file)
155 | prompts = process_prompts(file_data)
156 | input_data = load_data(task_input_file)
157 | processed_training_data = process_data(input_data, prompts)
158 | all_data += processed_training_data
159 |
160 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict:
161 | for unfollowing_file_name in inst_unfollow_dict[file_data["dataset"]]:
162 | unfollowing_file_name = glob.glob(
163 | unfollowing_file_name + "/*.json*")[0]
164 | print(unfollowing_file_name)
165 | unfollowing_input_data = load_data(unfollowing_file_name)
166 | processed_training_data_unfollowing = process_instruction_unfollowing_sample(
167 | unfollowing_input_data, prompts, len(processed_training_data))
168 | all_data += processed_training_data_unfollowing
169 |
170 | if len(all_data) > args.max_train_data:
171 | all_data = random.sample(all_data, k=args.max_train_data)
172 |
173 | random.shuffle(all_data)
174 | train_data, dev_data = all_data[10000:], all_data[:10000]
175 |
176 | with jsonlines.open(os.path.join(args.output_dir, "tart_full_train.json"), 'w') as writer:
177 | writer.write_all(train_data)
178 | with jsonlines.open(os.path.join(args.output_dir, "tart_full_dev.json"), 'w') as writer:
179 | writer.write_all(dev_data)
180 |
181 |
182 | if __name__ == "__main__":
183 | parser = argparse.ArgumentParser()
184 |
185 | parser.add_argument(
186 | "--input_file",
187 | required=True,
188 | type=str,
189 | default=None,
190 |         help="tsv file with the per-dataset instructions (e.g., berri_instructions.tsv)",
191 | )
192 |     parser.add_argument("--output_file", type=str, default=None,
193 |                         help="unused; outputs are written under --output_dir")
194 |     parser.add_argument("--output_dir", type=str, default=None,
195 |                         help="output directory for tart_full_train.json / tart_full_dev.json")
196 |     parser.add_argument("--input_dir", type=str, default=None,
197 |                         help="directory containing the denoised .jsonl files, one per dataset")
198 |     parser.add_argument("--sample_dataset_num", type=int, default=None,
199 |                         help="if set, randomly sample this many datasets")
200 | parser.add_argument("--instruction_unfollowing_file", type=str, default=None,
201 | help="instruction unfollowing file")
202 | parser.add_argument("--prompts", type=str, default=None,
203 | help="prompt", nargs="+")
204 | parser.add_argument("--task_names", type=str, default=None,
205 | help="task names filter", nargs="+")
206 | parser.add_argument("--instance_idx_start", type=int,
207 | default=None, help="instance start index")
208 | parser.add_argument("--instance_idx_end", type=int,
209 | default=None, help="instance end index")
210 | parser.add_argument("--per_gpu_batch_size", type=int,
211 | default=64, help="Batch size for question encoding")
212 | parser.add_argument("--n_docs", type=int, default=100,
213 | help="Number of documents to retrieve per questions")
214 | parser.add_argument(
215 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
216 | )
217 | parser.add_argument("--no_fp16", action="store_true",
218 | help="inference in fp32")
219 | parser.add_argument("--question_maxlength", type=int,
220 | default=512, help="Maximum number of tokens in a question")
221 | parser.add_argument("--start_idx", type=int, default=None,)
222 | parser.add_argument("--end_idx", type=int, default=None,)
223 | parser.add_argument("--max_train_data", type=int, default=3000000)
224 | parser.add_argument("--num_positive_paragraphs", type=int, default=1,)
225 | parser.add_argument("--num_negative_paragraphs", type=int, default=4,)
226 | args = parser.parse_args()
227 |     main(args)
228 |
229 |
--------------------------------------------------------------------------------
/BERRI/enc_t5/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from .modeling_enc_t5 import EncT5ForSequenceClassification
7 | from .tokenization_enc_t5 import EncT5Tokenizer
8 |
--------------------------------------------------------------------------------
/BERRI/enc_t5/modeling_enc_t5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | import copy
7 | import torch
8 | from torch import nn
9 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10 | from transformers.modeling_outputs import SequenceClassifierOutput
11 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
12 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
13 |
14 |
15 | class EncT5ForSequenceClassification(T5PreTrainedModel):
16 | _keys_to_ignore_on_load_missing = [
17 | r"encoder\.embed_tokens\.weight",
18 | ]
19 |
20 | def __init__(self, config: T5Config, dropout=0.1):
21 | super().__init__(config)
22 | self.num_labels = config.num_labels
23 | self.config = config
24 |
25 | self.shared = nn.Embedding(config.vocab_size, config.d_model)
26 |
27 | encoder_config = copy.deepcopy(config)
28 | encoder_config.use_cache = False
29 | encoder_config.is_encoder_decoder = False
30 | self.encoder = T5Stack(encoder_config, self.shared)
31 |
32 | self.dropout = nn.Dropout(dropout)
33 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
34 |
35 | # Initialize weights and apply final processing
36 | self.post_init()
37 |
38 | # Model parallel
39 | self.model_parallel = False
40 | self.device_map = None
41 |
42 | def parallelize(self, device_map=None):
43 | self.device_map = (
44 | get_device_map(len(self.encoder.block),
45 | range(torch.cuda.device_count()))
46 | if device_map is None
47 | else device_map
48 | )
49 | assert_device_map(self.device_map, len(self.encoder.block))
50 | self.encoder.parallelize(self.device_map)
51 | self.classifier = self.classifier.to(self.encoder.first_device)
52 | self.model_parallel = True
53 |
54 | def deparallelize(self):
55 | self.encoder.deparallelize()
56 | self.encoder = self.encoder.to("cpu")
57 | self.model_parallel = False
58 | self.device_map = None
59 | torch.cuda.empty_cache()
60 |
61 | def get_input_embeddings(self):
62 | return self.shared
63 |
64 | def set_input_embeddings(self, new_embeddings):
65 | self.shared = new_embeddings
66 | self.encoder.set_input_embeddings(new_embeddings)
67 |
68 | def get_encoder(self):
69 | return self.encoder
70 |
71 | def _prune_heads(self, heads_to_prune):
72 | """
73 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
74 | class PreTrainedModel
75 | """
76 | for layer, heads in heads_to_prune.items():
77 | self.encoder.layer[layer].attention.prune_heads(heads)
78 |
79 | def forward(
80 | self,
81 | input_ids=None,
82 | attention_mask=None,
83 | head_mask=None,
84 | inputs_embeds=None,
85 | labels=None,
86 | output_attentions=None,
87 | output_hidden_states=None,
88 | return_dict=None,
89 | ):
90 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
91 |
92 | outputs = self.encoder(
93 | input_ids=input_ids,
94 | attention_mask=attention_mask,
95 | inputs_embeds=inputs_embeds,
96 | head_mask=head_mask,
97 | output_attentions=output_attentions,
98 | output_hidden_states=output_hidden_states,
99 | return_dict=return_dict,
100 | )
101 |
102 | hidden_states = outputs[0]
103 |         # Take the bos (first) token, equivalent to [CLS]
104 | pooled_output = hidden_states[:, 0, :]
105 |
106 | pooled_output = self.dropout(pooled_output)
107 | logits = self.classifier(pooled_output)
108 |
109 | loss = None
110 | if labels is not None:
111 | if self.config.problem_type is None:
112 | if self.num_labels == 1:
113 | self.config.problem_type = "regression"
114 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
115 | self.config.problem_type = "single_label_classification"
116 | else:
117 | self.config.problem_type = "multi_label_classification"
118 |
119 | if self.config.problem_type == "regression":
120 | loss_fct = MSELoss()
121 | if self.num_labels == 1:
122 | loss = loss_fct(logits.squeeze(), labels.squeeze())
123 | else:
124 | loss = loss_fct(logits, labels)
125 | elif self.config.problem_type == "single_label_classification":
126 | loss_fct = CrossEntropyLoss()
127 | loss = loss_fct(
128 | logits.view(-1, self.num_labels), labels.view(-1))
129 | elif self.config.problem_type == "multi_label_classification":
130 | loss_fct = BCEWithLogitsLoss()
131 | loss = loss_fct(logits, labels)
132 | if not return_dict:
133 | output = (logits,) + outputs[1:]
134 | return ((loss,) + output) if loss is not None else output
135 |
136 | return SequenceClassifierOutput(
137 | loss=loss,
138 | logits=logits,
139 | hidden_states=outputs.hidden_states,
140 | attentions=outputs.attentions,
141 | )
142 |
--------------------------------------------------------------------------------
/BERRI/enc_t5/tokenization_enc_t5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | from typing import Any, Dict, List, Optional
7 | from transformers import T5Tokenizer
8 |
9 |
10 | class EncT5Tokenizer(T5Tokenizer):
11 | def __init__(
12 | self,
13 | vocab_file,
14 |         bos_token="<s>",
15 |         eos_token="</s>",
16 |         unk_token="<unk>",
17 |         pad_token="<pad>",
18 | extra_ids=100,
19 | additional_special_tokens=None,
20 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
21 | **kwargs,
22 | ) -> None:
23 | sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
24 |
25 | super().__init__(
26 | vocab_file=vocab_file,
27 | bos_token=bos_token,
28 | eos_token=eos_token,
29 | unk_token=unk_token,
30 | pad_token=pad_token,
31 | extra_ids=extra_ids,
32 | additional_special_tokens=additional_special_tokens,
33 | sp_model_kwargs=sp_model_kwargs,
34 | **kwargs,
35 | )
36 |
37 | def get_special_tokens_mask(
38 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
39 | ) -> List[int]:
40 | """
41 | Retrieve sequence ids from a token list that has no special tokens added.
42 | This will be called when adding special tokens using the tokenizer `prepare_for_model` method.
43 | """
44 | if already_has_special_tokens:
45 | return super().get_special_tokens_mask(
46 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
47 | )
48 |
49 | # normal case: some special tokens
50 | if token_ids_1 is None:
51 | return [1] + ([0] * len(token_ids_0)) + [1]
52 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
53 |
54 | def create_token_type_ids_from_sequences(
55 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
56 | ) -> List[int]:
57 | bos = [self.bos_token_id]
58 | eos = [self.eos_token_id]
59 |
60 | if token_ids_1 is None:
61 | return len(bos + token_ids_0 + eos) * [0]
62 | return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
63 |
64 | def build_inputs_with_special_tokens(
65 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
66 | ) -> List[int]:
67 | """
68 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
69 | adding special tokens. A sequence has the following format:
70 |         - single sequence: `<s> X </s>`
71 |         - pair of sequences: `<s> A </s> B </s>`
72 | """
73 | if token_ids_1 is None:
74 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
75 | else:
76 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
77 |
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 |
3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
4 | Please read the [full text](https://code.fb.com/codeofconduct/)
5 | so that you can understand what actions will and will not be tolerated.
6 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing to this repo
2 |
3 | ## Pull Requests
4 |
5 | In order to accept your pull request, we need you to submit a CLA. You only need
6 | to do this once to work on any of Facebook's open source projects.
7 |
8 | Complete your CLA here: <https://code.facebook.com/cla>
9 |
10 | ## Issues
11 | We use GitHub issues to track public bugs. Please ensure your description is
12 | clear and has sufficient instructions to be able to reproduce the issue.
13 |
14 | ## License
15 | By contributing to this repo, you agree that your contributions will be licensed
16 | under the LICENSE file in the root directory of this source tree.
17 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Task-aware Retrieval with Instructions
2 |
3 | This is the official repository for our preprint, [Task-aware Retrieval with Instructions](https://arxiv.org/abs/2211.09260).
4 |
5 | We introduce a new retrieval task formulation, **retrieval with instructions**, construct **BERRI**, the first large-scale collection of retrieval datasets with instructions, and present **TART**, a family of multi-task instruction-following retrieval models trained on BERRI.
6 |
7 |
8 | ![TART overview](figures/intro.png)
9 |
10 |
11 | ## Content
12 |
13 | 1. [Getting started](#getting-started)
14 | 2. [Pretrained Checkpoints](#pre-trained-checkpoints)
15 | - [Pre-trained checkpoints ](#pre-trained-checkpoints)
16 | - [Embeddings](#embeddings)
17 | 3. [Evaluation](#evaluation)
18 | - [BEIR](#beir)
19 | - [LOTTE](#lotte)
20 | - [Cross-task Cross-domain evaluation](#cross-task-cross-domain-dataset)
21 | 4. [Training](#training)
22 | 5. [Dataset: BERRI](#dataset-berri)
23 | 6. [Citations and Contact](#citation-and-contact)
24 |
25 |
26 | *Updates*
27 | **2022.12**: Released initial code.
28 | **2023.2**: Released scripts to create BERRI, with links to the data created by a third party.
29 | **2023.2**: Released training instructions.
30 |
31 | ## Getting started
32 | Pre-trained models can be loaded through the HuggingFace transformers library:
33 |
34 | ```py
35 | from src.modeling_enc_t5 import EncT5ForSequenceClassification
36 | from src.tokenization_enc_t5 import EncT5Tokenizer
37 | import torch
38 | import torch.nn.functional as F
39 | import numpy as np
40 | # load TART full and tokenizer
41 | model = EncT5ForSequenceClassification.from_pretrained("facebook/tart-full-flan-t5-xl")
42 | tokenizer = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")
43 | ```
44 |
45 | Then, you can test TART-full as follows:
46 | ```py
47 | model.eval()
48 | q = "What is the population of Tokyo?"
49 | in_answer = "retrieve a passage that answers this question from Wikipedia"
50 |
51 | p_1 = "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million at the start of 2022, the first decline since 1996, the metropolitan government reported Monday."
52 | p_2 = "Tokyo, officially the Tokyo Metropolis (東京都, Tōkyō-to), is the capital and largest city of Japan."
53 |
54 | # 1. TART-full can identify the more relevant paragraph.
55 | features = tokenizer(['{0} [SEP] {1}'.format(in_answer, q), '{0} [SEP] {1}'.format(in_answer, q)], [p_1, p_2], padding=True, truncation=True, return_tensors="pt")
56 | with torch.no_grad():
57 | scores = model(**features).logits
58 | normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
59 |
60 | print([p_1, p_2][np.argmax(normalized_scores)]) # "The population of Japan's capital, Tokyo, dropped by about 48,600 people to just under 14 million ... "
61 |
62 | # 2. TART-full can identify the document that is more relevant AND follows instructions.
63 | in_sim = "You need to find duplicated questions in Wiki forum. Could you find a question that is similar to this question"
64 | q_1 = "How many people live in Tokyo?"
65 | features = tokenizer(['{0} [SEP] {1}'.format(in_sim, q), '{0} [SEP] {1}'.format(in_sim, q)], [p_1, q_1], padding=True, truncation=True, return_tensors="pt")
66 | with torch.no_grad():
67 | scores = model(**features).logits
68 | normalized_scores = [float(score[1]) for score in F.softmax(scores, dim=1)]
69 |
70 | print([p_1, q_1][np.argmax(normalized_scores)]) # "How many people live in Tokyo?"
71 |
72 | ```
73 | As you can see, TART-full not only scores the more relevant paragraph higher (p_1 vs. p_2), but also gives a lower score to a document that is relevant yet does not follow the instruction (p_1 vs. q_1).
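
To rerank a longer candidate list, the same pattern extends naturally: pair the instruction-prefixed query with every candidate, score the batch, and sort by the probability of the positive class. Below is a minimal sketch reusing `model`, `tokenizer`, `in_answer`, and `q` from the snippet above; the candidate list itself is a placeholder.

```py
# A minimal reranking sketch, reusing model, tokenizer, in_answer, and q from above.
candidates = [p_1, p_2, "Kyoto is a city in the Kansai region of Japan."]  # placeholder candidates

queries = ['{0} [SEP] {1}'.format(in_answer, q)] * len(candidates)
features = tokenizer(queries, candidates, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    scores = model(**features).logits
relevance = F.softmax(scores, dim=1)[:, 1]  # probability of the "relevant" class

ranked = sorted(zip(candidates, relevance.tolist()), key=lambda x: x[1], reverse=True)
for passage, score in ranked:
    print(round(score, 3), passage[:60])
```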
74 |
75 | ### Interactive mode
76 | Interactive mode lets you type arbitrary questions and retrieve from pre-encoded documents. To run it, you need the pre-encoded document embeddings as well as the models.
77 |
78 | ```sh
79 | python interactive.py \
80 | --passages scifact/corpus.jsonl \
81 | --passages_embeddings "PATH_TO_YOUR_EMBEDDINGS/*" \
82 | --model_name_or_path PATH_TO_YOUR_BE_MODEL \
83 | --ce_model PATH_TO_YOUR_CE_MODEL \
84 | ```
85 |
86 |
87 | ## Pre-trained checkpoints
88 | ### TART-full
89 | We release TART-full models trained on BERRI using different initial encoder weights. The TART-full model in the paper is based on T0-3B. We will release TART-full models based on smaller backbones soon!
90 |
91 | Models are all on huggingface hub.
92 |
93 | | name | size | initialization |
94 | | ----------- | ----------- | ----------- |
95 | | [facebook/tart-full-flan-t0-3b](https://huggingface.co/facebook/tart-full-flan-t0-3b) | 1.5 billion |[T0-3B](https://huggingface.co/bigscience/T0_3B)|
96 | | [facebook/tart-full-flan-t5-xl](https://huggingface.co/facebook/tart-full-flan-t5-xl) | 1.5 billion | [FLAN-T5-XL](https://huggingface.co/google/flan-t5-xl)|
97 |
98 | ### TART-dual
99 | TART-dual is an efficient bi-encoder model that shares a single encoder for queries and documents.
100 | | name | size | initialization |
101 | | ----------- | ----------- | ----------- |
102 | | [facebook/tart-dual-contriever-msmarco](https://homes.cs.washington.edu/~akari/models/tart-dual-contriever-msmarco.zip) | 110 million |[facebook/contriever-msmarco](https://huggingface.co/facebook/contriever-msmarco)|
103 |
104 | The main model in the paper uses [Contriever-MS MARCO](https://huggingface.co/facebook/contriever-msmarco), pre-trained on the Wikipedia 2020 dump.
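
TART-dual is a Contriever-style BERT encoder, so once the checkpoint above is unzipped it can be loaded with the Hugging Face `transformers` library. The snippet below is a minimal sketch rather than an official API: the checkpoint path is a placeholder, the mean-pooling helper mirrors the Contriever recipe used in this repository, and the instruction is prepended to the query with ` [SEP] `, as in the TART-dual training data.

```py
# A minimal sketch: score a query against passages with TART-dual (paths are placeholders).
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("PATH_TO_UNZIPPED_TART_DUAL_DIR")
model = AutoModel.from_pretrained("PATH_TO_UNZIPPED_TART_DUAL_DIR")
model.eval()

def mean_pooling(token_embeddings, mask):
    # Contriever-style mean pooling over non-padding tokens.
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    return token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

instruction = "retrieve a passage that answers this question from Wikipedia"
query = "{0} [SEP] {1}".format(instruction, "What is the population of Tokyo?")
passages = [
    "The population of Japan's capital, Tokyo, dropped to just under 14 million at the start of 2022.",
    "Tokyo, officially the Tokyo Metropolis, is the capital and largest city of Japan.",
]

with torch.no_grad():
    q_inputs = tokenizer([query], padding=True, truncation=True, return_tensors="pt")
    p_inputs = tokenizer(passages, padding=True, truncation=True, return_tensors="pt")
    q_emb = mean_pooling(model(**q_inputs)[0], q_inputs["attention_mask"])
    p_emb = mean_pooling(model(**p_inputs)[0], p_inputs["attention_mask"])
    scores = (q_emb @ p_emb.T).squeeze(0)  # inner-product relevance scores

print(passages[int(scores.argmax())])
```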
105 |
106 | ## Embeddings
107 | We release the pre-encoded embeddings for the BEIR datasets [here](https://drive.google.com/file/d/1qLqFLlvLpQ2OGHowO0Jwv_di8mJgh2be/view?usp=sharing):
108 |
109 |
110 | ## Evaluation
111 | ### BEIR
112 | You can evaluate the models on BEIR, by running `eval_beir.py` or `eval_cross_task.py`.
113 |
114 | `eval_beir.py` is adapted from the official BEIR repository and encodes and runs inference on a single GPU each time, while `eval_cross_task.py` assumes that you have already encoded the document embeddings and parallelizes inference across multiple GPUs. If you have multiple GPUs or want to evaluate TART on datasets with millions of documents (e.g., Climate-FEVER), we recommend the `eval_cross_task.py` script.
115 |
116 | #### Run evaluation with `eval_beir.py`
117 |
118 | ```sh
119 | python eval_beir.py \
120 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \
121 | --dataset BEIR_DATASET_NAME \
122 |     --output_dir YOUR_OUTPUT_DIR \
123 |     --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \
124 |     --prompt "YOUR INSTRUCTIONS"
125 | ```
126 |
127 |
128 |
129 | #### Run evaluation with `eval_cross_task.py`
130 | As mentioned above, running the `eval_cross_task.py` script involves two steps: **STEP 1: encode all documents** and **STEP 2: run evaluations using the encoded embeddings**.
131 |
132 | ##### STEP 1: Encode all of the documents
133 | To encode documents using a single GPU, run the command below:
134 |
135 | ```sh
136 | python generate_passage_embeddings.py --model_name_or_path YOUR_MODEL_NAME --output_dir OUTPUT_DIR_NAME \
137 |     --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl --shard_id 0 --num_shards 1
138 | ```
139 |
140 | If you want to use multiple GPUs to speed up the process, you can run the following command:
141 |
142 | ```sh
143 | for i in {0..7}; do
144 | export CUDA_VISIBLE_DEVICES=${i}
145 | nohup python generate_passage_embeddings.py --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH --output_dir OUTPUT_DIR_NAME \
146 | --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 &
147 | done
148 | ```
149 |
150 | The corpus file is a `jsonlines` file, where each item contains `text` and `title`, plus optional `_id` and `metadata` fields.
151 |
152 | e.g.,
153 |
154 | ```
155 | {"_id": "doc9", "title": "Chicago Fire (season 4)", "text": "Hermann is rushed to Chicago Med after being stabbed at Molly's. After losing a lot a blood, it is determined he needs emergency surgery. Feeling guilty about Hermann's present state, Cruz searches for Freddy to turn him in. Severide is reinstated as Lieutenant while Borelli grows more concerned about Chili's erratic behavior. Mouch considers finally proposing to Platt.", "metadata": {}}
156 | ```
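
If your documents live in another format, a few lines of Python are enough to convert them into this layout. The snippet below is a small sketch (document contents and the output path are placeholders) using the `jsonlines` package that this repository already depends on:

```py
# A small sketch: write documents out in the corpus.jsonl layout described above.
import jsonlines

docs = [
    {"_id": "doc0", "title": "Example title", "text": "Example passage text.", "metadata": {}},
    {"_id": "doc1", "title": "Another title", "text": "Another passage text.", "metadata": {}},
]

with jsonlines.open("corpus.jsonl", "w") as writer:
    writer.write_all(docs)
```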
157 |
158 | ##### STEP 2: Run predictions
159 |
160 | Once you encode passages, you can run the evaluations as follows:
161 | ```sh
162 | python eval_cross_task.py \
163 | --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl \
164 | --passages_embeddings "PATH_TO_YOUR_EMBEDDING_OUTPUT_DIR/passages_*" \
165 | --qrels PATH_TO_YOUR_INPUT_DATA_DIR/qrels/test.csv \
166 | --output_dir OUT_PUT_DIR_NAME \
167 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \
168 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \
169 | --data PATH_TO_YOUR_INPUT_DATA_DIR/queries.jsonl \
170 | --prompt "YOUR INSTRUCTIONS"
171 | ```
172 |
173 | ### LOTTE
174 | We evaluate our model on LOTTE-search (pooled). To run the evaluations on LOTTE, you can download our processed data (the data itself is unchanged, but we converted the input file formats and added instructions) as follows:
175 |
176 | ```sh
177 | wget https://homes.cs.washington.edu/~akari/tart/processed_lotte_search_pooled.zip
178 | unzip processed_lotte_search_pooled.zip
179 | ```
180 |
181 | Encode passages as in the previous section.
182 |
183 | ```sh
184 | for i in {0..7}; do
185 | export CUDA_VISIBLE_DEVICES=${i}
186 | nohup python generate_passage_embeddings.py --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH --output_dir OUTPUT_DIR_NAME \
187 | --passages processed_lotte_search_pooled/corpus.jsonl --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 &
188 | done
189 | ```
190 |
191 | Once you encode the passages, you can run evaluations
192 | ```sh
193 | python eval_cross_task.py \
194 | --passages processed_lotte_search_pooled/corpus.jsonl \
195 | --passages_embeddings "contriever_lotte_corpus/passages_*" \
196 | --qrels processed_lotte_search_pooled/qrels/test.tsv \
197 | --output_dir OUT_PUT_DIR_NAME \
198 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \
199 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \
200 | --data processed_lotte_search_pooled/queries_w_instructions_sep.jsonl \
201 | --lotte
202 | ```
203 | This command outputs predictions in LOTTE's official evaluation format under `CROSS_ENCODER_MODEL_NAME_OR_PATH/`.
204 | Then you can run the official evaluation script as follows:
205 |
206 | ```sh
207 | cd lotte
208 | python evaluate_lotte_rankings.py --k 5 --split test --data_path ../lotte --rankings_path PATH_TO_PREDICTION_FILE
209 | ```
210 |
211 | ### Cross-task Cross-domain dataset
212 | In this paper, we introduce a new cross-task cross-domain evaluation, where, given an instruction and a single large-scale domain corpus, a system needs to retrieve documents that follow the instruction.
213 |
214 | Due to legal reasons, Meta cannot host this data. The script to create the cross-task cross-domain dataset is available at [cross_task_cross_eval](https://github.com/facebookresearch/tart/tree/main/cross_task_cross_eval), and you can also download the processed cross-task dataset as follows.
215 |
216 | ```sh
217 | wget https://homes.cs.washington.edu/~akari/tart/cross_task_cross_domain_final.zip
218 | unzip cross_task_cross_domain_final.zip
219 | ```
220 |
221 | Due to the larger corpus, we highly recommend encoding all documents beforehand.
222 | Pre-encoded documents are available in the [Embeddings](#embeddings) section.
223 |
224 | Then you can run evaluations on the cross-task cross-domain data as follows:
225 | ```sh
226 | python eval_cross_task.py \
227 |     --passages ../cross_2_eval/nq/corpus.jsonl ../cross_2_eval/scifact/corpus.jsonl ../cross_2_eval/gooaq_med/corpus.jsonl ../cross_2_eval/linkso_py/corpus.jsonl ../cross_2_eval/ambig/corpus.jsonl ../cross_2_eval/wikiqa/corpus.jsonl ../cross_2_eval/gooaq_technical/corpus.jsonl ../cross_2_eval/codesearch_py/corpus_new.jsonl \
228 | --passages_embeddings "linkso_py_contriever_embeddings/passages_*" "ambig_contriever_embeddings/passages_*" "scifact_contriever_embeddings/*" "nq_contriever_embeddings/passages_*" "gooaq_technical_contriever_embeddings/passages_*" "codesearch_py_contriever_embeddings/passages_*" "wikiqa_contriever_embeddings/passages_*" \
229 | --qrels ../cross_task_eval/linkso/qrels/test_new.tsv \
230 | --output_dir YOUR_OUTPUT_DIR \
231 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \
232 |     --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH
233 | ```
234 |
235 | ## Training
236 | ### TART-full
237 | To train the TART-full model, run the script below. We use 8 GPUs for training. We found that training the 3B cross-encoder for too long can cause it to overfit the data, so we train the models for only one epoch and pick the best checkpoint based on the development score.
238 |
239 | ```
240 | python finetuning_tart_full.py --task_name ce \
241 | --train_data PATH_TO_TRAIN_DATA \
242 | --eval_data PATH_TO_DEV_DATA \
243 | --model_name_or_path [bigscience/T0_3B, google/flan-t5-xl] \
244 | --output_dir PATH_TO_YOUR_OUTPUT_DIR \
245 | --do_train --do_eval \
246 | --overwrite_output_dir --evaluation_strategy steps \
247 | --eval_steps 2000 --save_steps 2000 \
248 | --metric_for_best_model accuracy \
249 | --per_gpu_train_batch_size 1 --num_train_epochs 1 \
250 | --gradient_accumulation_steps 8 --max_seq_length 512 \
251 | --load_best_model_at_end
252 | ```
253 |
254 | ### TART-dual
255 | To train the TART-dual model, run the script below. We use 64 GPUs for training.
256 |
257 | ```
258 | cd TART
259 | python finetuning.py \
260 | --model_path facebook/contriever-msmarco \
261 | --train_data PATH_TO_TRAIN_DATA \
262 | --eval_data PATH_TO_DEV_DATA \
263 | --chunk_length 256 --negative_ctxs 5 --warmup_steps 1000 --total_steps 50000 \
264 | --lr 0.00001 --scheduler linear --optim adamw --per_gpu_batch_size 16 --temperature 0.05 \
265 | --per_gpu_eval_batch_size 16 --output_dir PATH_TO_OUTPUT_DIR \
266 | --save_freq 5000 --eval_freq 5000 --negative_hard_ratio 0.1
267 | ```
268 |
269 |
270 | ## Dataset: BERRI
271 |
272 | The [`BERRI`](BERRI) directory includes the scripts and instruction data used to construct the BERRI dataset.
273 |
274 | ### Instructions
275 | All of the annotated instructions are in [berri_instructions.tsv](BERRI/berri_instructions.tsv); a quick way to inspect them is sketched below.
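
This is a minimal sketch that prints the first two rows of the annotation file; it only assumes a tab-separated file, and the column layout is whatever the released TSV contains.

```python
import csv

# Print the first two rows of the instruction annotations.
with open("BERRI/berri_instructions.tsv", newline="") as f:
    reader = csv.reader(f, delimiter="\t")
    print(next(reader))
    print(next(reader))
```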
276 |
277 | ### Preprocessing script and processed data
278 | BERRI data construction consists of (i) converting existing datasets into retrieval tasks, adding initial positive and randomly sampled negative passages, (ii) retrieving top paragraphs using an efficient bi-encoder (i.e., Contriever), and then (iii) denoising the top documents.
279 |
280 | [`BERRI/README.md`](BERRI/README.md) provides detailed instructions.
281 |
282 | You can download the processed source data (from the process (i)) as well as the final training data for TART-dual and full, processed by a third party here:
283 | - [source data (22 GB)](https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link)
284 | - [TART-full training data (1 GB)](https://drive.google.com/file/d/1oijzAb2gWKT54OgeE7_KB9VcHvA7UxpQ/view?usp=share_link)
285 | - [TART-dual training data (14 GB)](https://drive.google.com/file/d/1lMmD5lTxYWYf0z0ua0-GaGKz2qs2mG1r/view?usp=share_link)
286 |
287 |
288 | ## Citation and Contact
289 | If you find this repository helpful, please cite our paper.
290 |
291 | ```
292 | @article{asai2022tart,
293 | title={Task-aware Retrieval with Instructions},
294 | author={Asai, Akari and Schick, Timo and Lewis, Patrick and Chen, Xilun and Izacard, Gautier and Riedel, Sebastian and Hajishirzi, Hannaneh and Yih, Wen-tau},
295 | journal={arXiv preprint arXiv:2211.09260},
296 | year={2022}
297 | }
298 | ```
299 |
300 | If you have any questions about the paper, feel free to contact Akari Asai (akari[at]cs.washington.edu) or open an issue and mention @AkariAsai.
301 |
302 | ### License
303 | See the [LICENSE](LICENSE) file for more details.
304 |
--------------------------------------------------------------------------------
/TART/custom_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import logging
8 | from typing import List, Dict, Union, Tuple
9 |
10 | def mrr(qrels: Dict[str, Dict[str, int]],
11 | results: Dict[str, Dict[str, float]],
12 | k_values: List[int]) -> Tuple[Dict[str, float]]:
13 |
14 | MRR = {}
15 |
16 | for k in k_values:
17 | MRR[f"MRR@{k}"] = 0.0
18 |
19 | k_max, top_hits = max(k_values), {}
20 | logging.info("\n")
21 |
22 | for query_id, doc_scores in results.items():
23 | top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
24 |
25 | for query_id in top_hits:
26 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
27 | for k in k_values:
28 | for rank, hit in enumerate(top_hits[query_id][0:k]):
29 | if hit[0] in query_relevant_docs:
30 | MRR[f"MRR@{k}"] += 1.0 / (rank + 1)
31 | break
32 |
33 | for k in k_values:
34 | MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"]/len(qrels), 5)
35 | logging.info("MRR@{}: {:.4f}".format(k, MRR[f"MRR@{k}"]))
36 |
37 | return MRR
38 |
39 | def recall_cap(qrels: Dict[str, Dict[str, int]],
40 | results: Dict[str, Dict[str, float]],
41 | k_values: List[int]) -> Tuple[Dict[str, float]]:
42 |
43 | capped_recall = {}
44 |
45 | for k in k_values:
46 | capped_recall[f"R_cap@{k}"] = 0.0
47 |
48 | k_max = max(k_values)
49 | logging.info("\n")
50 |
51 | for query_id, doc_scores in results.items():
52 | top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
53 | query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]
54 | for k in k_values:
55 | retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0]
56 | denominator = min(len(query_relevant_docs), k)
57 | capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator)
58 |
59 | for k in k_values:
60 | capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5)
61 | logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"]))
62 |
63 | return capped_recall
64 |
65 |
66 | def hole(qrels: Dict[str, Dict[str, int]],
67 | results: Dict[str, Dict[str, float]],
68 | k_values: List[int]) -> Tuple[Dict[str, float]]:
69 |
70 | Hole = {}
71 |
72 | for k in k_values:
73 | Hole[f"Hole@{k}"] = 0.0
74 |
75 | annotated_corpus = set()
76 | for _, docs in qrels.items():
77 | for doc_id, score in docs.items():
78 | annotated_corpus.add(doc_id)
79 |
80 | k_max = max(k_values)
81 | logging.info("\n")
82 |
83 | for _, scores in results.items():
84 | top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]
85 | for k in k_values:
86 | hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus]
87 | Hole[f"Hole@{k}"] += len(hole_docs) / k
88 |
89 | for k in k_values:
90 | Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(qrels), 5)
91 | logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"]))
92 |
93 | return Hole
94 |
95 | def top_k_accuracy(
96 | qrels: Dict[str, Dict[str, int]],
97 | results: Dict[str, Dict[str, float]],
98 | k_values: List[int]) -> Tuple[Dict[str, float]]:
99 |
100 | top_k_acc = {}
101 |
102 | for k in k_values:
103 | top_k_acc[f"Accuracy@{k}"] = 0.0
104 |
105 | k_max, top_hits = max(k_values), {}
106 | logging.info("\n")
107 |
108 | for query_id, doc_scores in results.items():
109 | top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max]]
110 |
111 | for query_id in top_hits:
112 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0])
113 | for k in k_values:
114 | for relevant_doc_id in query_relevant_docs:
115 | if relevant_doc_id in top_hits[query_id][0:k]:
116 | top_k_acc[f"Accuracy@{k}"] += 1.0
117 | break
118 |
119 | for k in k_values:
120 | top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5)
121 | logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"]))
122 |
123 | return top_k_acc
--------------------------------------------------------------------------------
/TART/eval_beir.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import sys
8 | import argparse
9 |
10 | import torch
11 | import logging
12 | import json
13 | import numpy as np
14 | import os
15 | import copy
16 |
17 | import src.slurm
18 | import src.contriever
19 | import src.beir_utils
20 | import src.utils
21 | import src.dist_utils
22 | import src.contriever
23 |
24 | logger = logging.getLogger(__name__)
25 |
26 |
27 | def main(args):
28 |
29 | src.slurm.init_distributed_mode(args)
30 | src.slurm.init_signal_handler()
31 |
32 | os.makedirs(args.output_dir, exist_ok=True)
33 |
34 | logger = src.utils.init_logger(args)
35 |
36 | model, tokenizer, _ = src.contriever.load_retriever(
37 | args.model_name_or_path)
38 | if args.bi_encoder is True:
39 | if args.ckpt_path is not None:
40 | state_dict = torch.load(args.ckpt_path)["model"]
41 | query_encoder = copy.deepcopy(model)
42 | doc_encoder = copy.deepcopy(model)
43 | query_encoder_dic = {}
44 | doc_encoder_dic = {}
45 |             for name, param in state_dict.items():  # remap "q_encoder.*" / "p_encoder.*" keys back to the base encoder's parameter names
46 | print(name)
47 | if "q_encoder" in name:
48 | if "encoder" in name:
49 | orig_name = name.replace(
50 | "q_encoder.encoder", "encoder")
51 | if "embeddings" in name:
52 | orig_name = name.replace(
53 | "q_encoder.embeddings", "embeddings")
54 | query_encoder_dic[orig_name] = param
55 | print(orig_name)
56 | if "p_encoder" in name:
57 | if "encoder" in name:
58 | orig_name = name.replace(
59 | "p_encoder.encoder", "encoder")
60 | if "embeddings" in name:
61 | orig_name = name.replace(
62 | "p_encoder.embeddings", "embeddings")
63 | print(orig_name)
64 | doc_encoder_dic[orig_name] = param
65 |
66 | # print(query_encoder_dic.keys())
67 | # print(doc_encoder_dic.keys())
68 | query_encoder.load_state_dict(query_encoder_dic)
69 | # doc_encoder.load_state_dict(doc_encoder_dic)
70 |
71 | query_encoder = query_encoder.cuda()
72 | query_encoder.eval()
73 |
74 | doc_encoder = doc_encoder.cuda()
75 | doc_encoder.eval()
76 | else:
77 | print("for biencoder model, you have to preload the fine-tuned checkpoints.")
78 | raise NotImplementedError()
79 |
80 | else:
81 | if args.ckpt_path is not None:
82 | print("loading model")
83 | state_dict = torch.load(args.ckpt_path)["model"]
84 |
85 | new_dict = {}
86 | # state_dict = {k.replace("encoder_q.", ""): v for k, v in state_dict.items() if "encoder_q." in k}
87 | print(state_dict.keys())
88 | # print(dict(model.named_parameters()).keys())
89 |
90 | for name, param in model.named_parameters():
91 | if name in state_dict:
92 | new_dict[name] = state_dict[name]
93 | print("updated")
94 | print(name)
95 | else:
96 | new_dict[name] = param
97 |
98 | # print(new_dict.keys())
99 | # print(model.keys())
100 | # assert model.keys() == new_dict.keys()
101 |
102 | model.load_state_dict(new_dict, strict=False)
103 |
104 | model = model.cuda()
105 | model.eval()
106 | query_encoder = model
107 | doc_encoder = model
108 |
109 | logger.info("Start indexing")
110 |
111 | if args.multiple_prompts is not None:
112 | metrics = src.beir_utils.evaluate_model_multiple(
113 | query_encoder=query_encoder,
114 | doc_encoder=doc_encoder,
115 | tokenizer=tokenizer,
116 | dataset=args.dataset,
117 | batch_size=args.per_gpu_batch_size,
118 | norm_query=args.norm_query,
119 | norm_doc=args.norm_doc,
120 | is_main=src.dist_utils.is_main(),
121 | split="dev" if args.dataset == "msmarco" else "test",
122 | score_function=args.score_function,
123 | beir_dir=args.beir_dir,
124 | save_results_path=args.save_results_path,
125 | lower_case=args.lower_case,
126 | normalize_text=args.normalize_text,
127 | prompt=args.prompt,
128 | multiple_prompts=args.multiple_prompts
129 | )
130 | else:
131 | metrics = src.beir_utils.evaluate_model(
132 | query_encoder=query_encoder,
133 | doc_encoder=doc_encoder,
134 | tokenizer=tokenizer,
135 | dataset=args.dataset,
136 | batch_size=args.per_gpu_batch_size,
137 | norm_query=args.norm_query,
138 | norm_doc=args.norm_doc,
139 | is_main=src.dist_utils.is_main(),
140 | split="dev" if args.dataset == "msmarco" else "test",
141 | score_function=args.score_function,
142 | beir_dir=args.beir_dir,
143 | save_results_path=args.save_results_path,
144 | lower_case=args.lower_case,
145 | normalize_text=args.normalize_text,
146 | prompt=args.prompt,
147 | emb_load_path=args.emb_load_path,
148 | emb_save_path=args.emb_save_path
149 | )
150 |
151 | if src.dist_utils.is_main():
152 | for key, value in metrics.items():
153 | logger.info(f"{args.dataset} : {key}: {value:.1f}")
154 |
155 | print("saving results")
156 | if os.path.exists(os.path.join(args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id))) is True:
157 | results_log = json.load(open(os.path.join(
158 | args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id))))
159 | else:
160 | results_log = {}
161 | results_log.setdefault(args.prompt, {})
162 | results_log[args.prompt] = metrics
163 |
164 | with open(os.path.join(args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id)), "w") as outfile:
165 | json.dump(results_log, outfile)
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser(
170 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
171 |
172 | parser.add_argument("--dataset", type=str,
173 | help="Evaluation dataset from the BEIR benchmark")
174 | parser.add_argument("--beir_dir", type=str, default="./",
175 | help="Directory to save and load beir datasets")
176 | parser.add_argument("--text_maxlength", type=int,
177 | default=512, help="Maximum text length")
178 | parser.add_argument("--emb_load_path", type=str, default=None,
179 | help="path to load already computed embeddings.", nargs="+")
180 | parser.add_argument("--emb_save_path", type=str, default=None,
181 | help="path to save already computed embeddings.")
182 |
183 | parser.add_argument("--per_gpu_batch_size", default=128,
184 | type=int, help="Batch size per GPU/CPU for indexing.")
185 | parser.add_argument("--output_dir", type=str,
186 | default="./my_experiment", help="Output directory")
187 | parser.add_argument("--model_name_or_path", type=str,
188 | help="Model name or path")
189 | parser.add_argument(
190 | "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings"
191 | )
192 | parser.add_argument("--norm_query", action="store_true",
193 | help="Normalize query representation")
194 | parser.add_argument("--norm_doc", action="store_true",
195 | help="Normalize document representation")
196 | parser.add_argument("--lower_case", action="store_true",
197 | help="lowercase query and document text")
198 | parser.add_argument(
199 | "--normalize_text", action="store_true", help="Apply function to normalize some common characters"
200 | )
201 | parser.add_argument("--save_results_path", type=str,
202 | default=None, help="Path to save result object")
203 |
204 | parser.add_argument("--local_rank", type=int, default=-1,
205 | help="For distributed training: local_rank")
206 | parser.add_argument("--main_port", type=int, default=-1,
207 | help="Main port (for multi-node SLURM jobs)")
208 | parser.add_argument("--ckpt_path", type=str, help="Model name or path")
209 | parser.add_argument("--bi_encoder", action="store_true", )
210 | parser.add_argument(
211 | "--prompt", type=str, default=None, help="instructional prompt."
212 | )
213 | parser.add_argument(
214 | "--multiple_prompts", type=str, nargs='+'
215 | )
216 | parser.add_argument(
217 | "--model_id", type=str, default=None, help="for logging"
218 | )
219 | args, _ = parser.parse_known_args()
220 | main(args)
221 |
--------------------------------------------------------------------------------
/TART/generate_passage_embeddings.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 |
9 | import argparse
10 | import csv
11 | import logging
12 | import pickle
13 |
14 | import numpy as np
15 | import torch
16 |
17 | import transformers
18 |
19 | import src.slurm
20 | import src.contriever
21 | import src.utils
22 | import src.data
23 | import src.normalize_text
24 |
25 |
26 | def embed_passages(args, passages, model, tokenizer):
27 | total = 0
28 | allids, allembeddings = [], []
29 | batch_ids, batch_text = [], []
30 | with torch.no_grad():
31 | for k, p in enumerate(passages):
32 | batch_ids.append(p["id"] if "id" in p else p["_id"])
33 | if args.no_title or not "title" in p:
34 | text = p["text"]
35 | else:
36 | if type(p["title"]) is not str:
37 | print("title incorrect")
38 | print(p)
39 | if type(p["text"]) is not str:
40 | print("text incorrect")
41 | print(p)
42 | text = p["title"] + " " + p["text"]
43 | if args.lowercase:
44 | text = text.lower()
45 | if args.normalize_text:
46 | text = src.normalize_text.normalize(text)
47 | batch_text.append(text)
48 |
49 | if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1:
50 |
51 | encoded_batch = tokenizer.batch_encode_plus(
52 | batch_text,
53 | return_tensors="pt",
54 | max_length=args.passage_maxlength,
55 | padding=True,
56 | truncation=True,
57 | )
58 |
59 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
60 | embeddings = model(**encoded_batch)
61 |
62 | embeddings = embeddings.cpu()
63 | total += len(batch_ids)
64 | allids.extend(batch_ids)
65 | allembeddings.append(embeddings)
66 |
67 | batch_text = []
68 | batch_ids = []
69 | if k % 100000 == 0 and k > 0:
70 | print(f"Encoded passages {total}")
71 |
72 | allembeddings = torch.cat(allembeddings, dim=0).numpy()
73 | return allids, allembeddings
74 |
75 |
76 | def main(args):
77 | model, tokenizer, _ = src.contriever.load_retriever(
78 | args.model_name_or_path)
79 | print(f"Model loaded from {args.model_name_or_path}.", flush=True)
80 | model.eval()
81 | model = model.cuda()
82 | if not args.no_fp16:
83 | model = model.half()
84 |
85 | passages = src.data.load_passages(args.passages)
86 |
87 | shard_size = len(passages) // args.num_shards
88 | start_idx = args.shard_id * shard_size
89 | end_idx = start_idx + shard_size
90 | if args.shard_id == args.num_shards - 1:
91 | end_idx = len(passages)
92 |
93 | passages = passages[start_idx:end_idx]
94 | print(
95 | f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.")
96 |
97 | allids, allembeddings = embed_passages(args, passages, model, tokenizer)
98 |
99 | save_file = os.path.join(
100 | args.output_dir, args.prefix + f"_{args.shard_id:02d}")
101 | os.makedirs(args.output_dir, exist_ok=True)
102 | print(f"Saving {len(allids)} passage embeddings to {save_file}.")
103 | with open(save_file, mode="wb") as f:
104 | pickle.dump((allids, allembeddings), f)
105 |
106 | print(f"Total passages processed {len(allids)}. Written to {save_file}.")
107 |
108 |
109 | if __name__ == "__main__":
110 | parser = argparse.ArgumentParser()
111 |
112 | parser.add_argument("--passages", type=str, default=None,
113 | help="Path to passages (.tsv file)")
114 | parser.add_argument("--output_dir", type=str,
115 | default="wikipedia_embeddings", help="dir path to save embeddings")
116 | parser.add_argument("--prefix", type=str, default="passages",
117 | help="prefix path to save embeddings")
118 | parser.add_argument("--shard_id", type=int, default=0,
119 | help="Id of the current shard")
120 | parser.add_argument("--num_shards", type=int, default=1,
121 | help="Total number of shards")
122 | parser.add_argument(
123 | "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass"
124 | )
125 | parser.add_argument("--passage_maxlength", type=int,
126 | default=512, help="Maximum number of tokens in a passage")
127 | parser.add_argument(
128 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
129 | )
130 | parser.add_argument("--no_fp16", action="store_true",
131 | help="inference in fp32")
132 | parser.add_argument("--no_title", action="store_true",
133 | help="title not added to the passage body")
134 | parser.add_argument("--lowercase", action="store_true",
135 | help="lowercase text before encoding")
136 | parser.add_argument("--normalize_text", action="store_true",
137 |                         help="normalize text before encoding")
138 |
139 | args = parser.parse_args()
140 |
141 | src.slurm.init_distributed_mode(args)
142 |
143 | main(args)
144 |
--------------------------------------------------------------------------------
/TART/passage_retrieval.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import argparse
9 | import csv
10 | import json
11 | import logging
12 | import pickle
13 | import time
14 | import glob
15 | from pathlib import Path
16 |
17 | import numpy as np
18 | import torch
19 | import transformers
20 |
21 | import src.index
22 | import src.contriever
23 | import src.utils
24 | import src.slurm
25 | import src.data
26 | from src.evaluation import calculate_matches
27 | import src.normalize_text
28 |
29 | os.environ["TOKENIZERS_PARALLELISM"] = "true"
30 |
31 |
32 | def embed_queries(args, queries, model, tokenizer):
33 | model.eval()
34 | embeddings, batch_question = [], []
35 | with torch.no_grad():
36 |
37 | for k, q in enumerate(queries):
38 | if args.lowercase:
39 | q = q.lower()
40 | if args.normalize_text:
41 | q = src.normalize_text.normalize(q)
42 | batch_question.append(q)
43 |
44 | if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1:
45 |
46 | encoded_batch = tokenizer.batch_encode_plus(
47 | batch_question,
48 | return_tensors="pt",
49 | max_length=args.question_maxlength,
50 | padding=True,
51 | truncation=True,
52 | )
53 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()}
54 | output = model(**encoded_batch)
55 | embeddings.append(output.cpu())
56 |
57 | batch_question = []
58 |
59 | embeddings = torch.cat(embeddings, dim=0)
60 | print(f"Questions embeddings shape: {embeddings.size()}")
61 |
62 | return embeddings.numpy()
63 |
64 |
65 | def index_encoded_data(index, embedding_files, indexing_batch_size):
66 | allids = []
67 | allembeddings = np.array([])
68 | for i, file_path in enumerate(embedding_files):
69 | print(f"Loading file {file_path}")
70 | with open(file_path, "rb") as fin:
71 | ids, embeddings = pickle.load(fin)
72 |
73 | allembeddings = np.vstack(
74 | (allembeddings, embeddings)) if allembeddings.size else embeddings
75 | allids.extend(ids)
76 | while allembeddings.shape[0] > indexing_batch_size:
77 | allembeddings, allids = add_embeddings(
78 | index, allembeddings, allids, indexing_batch_size)
79 |
80 | while allembeddings.shape[0] > 0:
81 | allembeddings, allids = add_embeddings(
82 | index, allembeddings, allids, indexing_batch_size)
83 |
84 | print("Data indexing completed.")
85 |
86 |
87 | def add_embeddings(index, embeddings, ids, indexing_batch_size):
88 | end_idx = min(indexing_batch_size, embeddings.shape[0])
89 | ids_toadd = ids[:end_idx]
90 | embeddings_toadd = embeddings[:end_idx]
91 | ids = ids[end_idx:]
92 | embeddings = embeddings[end_idx:]
93 | index.index_data(ids_toadd, embeddings_toadd)
94 | return embeddings, ids
95 |
96 |
97 | def validate(data, workers_num):
98 | match_stats = calculate_matches(data, workers_num)
99 | top_k_hits = match_stats.top_k_hits
100 |
101 |     print("Validation results: top k documents hits %s" % top_k_hits)
102 | top_k_hits = [v / len(data) for v in top_k_hits]
103 | message = ""
104 | for k in [5, 10, 20, 100]:
105 | if k <= len(top_k_hits):
106 | message += f"R@{k}: {top_k_hits[k-1]} "
107 | print(message)
108 | return match_stats.questions_doc_hits
109 |
110 |
111 | def add_passages(data, passages, top_passages_and_scores):
112 | # add passages to original data
113 | merged_data = []
114 | assert len(data) == len(top_passages_and_scores)
115 | for i, d in enumerate(data):
116 | results_and_scores = top_passages_and_scores[i]
117 | docs = [passages[doc_id] for doc_id in results_and_scores[0]]
118 | scores = [str(score) for score in results_and_scores[1]]
119 | ctxs_num = len(docs)
120 | d["ctxs"] = [
121 | {
122 | "id": results_and_scores[0][c],
123 | "title": docs[c]["title"] if "title" in docs[c] else "",
124 | "text": docs[c]["text"],
125 | "score": scores[c],
126 | }
127 | for c in range(ctxs_num)
128 | ]
129 |
130 |
131 | def add_hasanswer(data, hasanswer):
132 | # add hasanswer to data
133 | for i, ex in enumerate(data):
134 | for k, d in enumerate(ex["ctxs"]):
135 | d["hasanswer"] = hasanswer[i][k]
136 |
137 | # fix me
138 |
139 |
140 | def load_data(data_path):
141 | if data_path.endswith(".json"):
142 | with open(data_path, "r") as fin:
143 | data = json.load(fin)
144 | elif data_path.endswith(".jsonl"):
145 | data = []
146 | with open(data_path, "r") as fin:
147 | for k, example in enumerate(fin):
148 | example = json.loads(example)
149 | data.append(example)
150 | return data
151 |
152 |
153 | def main(args):
154 |
155 | print(f"Loading model from: {args.model_name_or_path}")
156 | model, tokenizer, _ = src.contriever.load_retriever(
157 | args.model_name_or_path)
158 | model.eval()
159 | model = model.cuda()
160 | if not args.no_fp16:
161 | model = model.half()
162 |
163 | index = src.index.Indexer(args.projection_size,
164 | args.n_subquantizers, args.n_bits)
165 |
166 | # index all passages
167 | input_paths = glob.glob(args.passages_embeddings)
168 | input_paths = sorted(input_paths)
169 | embeddings_dir = os.path.dirname(input_paths[0])
170 | index_path = os.path.join(embeddings_dir, "index.faiss")
171 | if args.save_or_load_index and os.path.exists(index_path):
172 | index.deserialize_from(embeddings_dir)
173 | else:
174 | print(f"Indexing passages from files {input_paths}")
175 | start_time_indexing = time.time()
176 | index_encoded_data(index, input_paths, args.indexing_batch_size)
177 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.")
178 | if args.save_or_load_index:
179 | index.serialize(embeddings_dir)
180 |
181 | # load passages
182 | passages = src.data.load_passages(args.passages)
183 | passage_id_map = {x["id"]: x for x in passages}
184 |
185 | data_paths = glob.glob(args.data)
186 | alldata = []
187 | for path in data_paths:
188 | data = load_data(path)
189 | output_path = os.path.join(args.output_dir, os.path.basename(path))
190 |
191 | queries = [ex["question"] for ex in data]
192 | questions_embedding = embed_queries(args, queries, model, tokenizer)
193 |
194 | # get top k results
195 | start_time_retrieval = time.time()
196 | top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs)
197 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.")
198 |
199 | add_passages(data, passage_id_map, top_ids_and_scores)
200 | hasanswer = validate(data, args.validation_workers)
201 | add_hasanswer(data, hasanswer)
202 | os.makedirs(os.path.dirname(output_path), exist_ok=True)
203 | with open(output_path, "w") as fout:
204 | for ex in data:
205 | json.dump(ex, fout, ensure_ascii=False)
206 | fout.write("\n")
207 | print(f"Saved results to {output_path}")
208 |
209 |
210 | if __name__ == "__main__":
211 | parser = argparse.ArgumentParser()
212 |
213 | parser.add_argument(
214 | "--data",
215 | required=True,
216 | type=str,
217 | default=None,
218 | help=".json file containing question and answers, similar format to reader data",
219 | )
220 | parser.add_argument("--passages", type=str, default=None,
221 | help="Path to passages (.tsv file)")
222 | parser.add_argument("--passages_embeddings", type=str,
223 | default=None, help="Glob path to encoded passages")
224 | parser.add_argument(
225 | "--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix"
226 | )
227 | parser.add_argument("--n_docs", type=int, default=100,
228 | help="Number of documents to retrieve per questions")
229 | parser.add_argument(
230 | "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results"
231 | )
232 | parser.add_argument("--per_gpu_batch_size", type=int,
233 | default=64, help="Batch size for question encoding")
234 | parser.add_argument(
235 | "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists"
236 | )
237 | parser.add_argument(
238 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file"
239 | )
240 | parser.add_argument("--no_fp16", action="store_true",
241 | help="inference in fp32")
242 | parser.add_argument("--question_maxlength", type=int,
243 | default=512, help="Maximum number of tokens in a question")
244 | parser.add_argument(
245 | "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed"
246 | )
247 | parser.add_argument("--projection_size", type=int, default=768)
248 | parser.add_argument(
249 | "--n_subquantizers",
250 | type=int,
251 | default=0,
252 | help="Number of subquantizer used for vector quantization, if 0 flat index is used",
253 | )
254 | parser.add_argument("--n_bits", type=int, default=8,
255 | help="Number of bits per subquantizer")
256 | parser.add_argument("--lang", nargs="+")
257 | parser.add_argument("--dataset", type=str, default="none")
258 | parser.add_argument("--lowercase", action="store_true",
259 | help="lowercase text before encoding")
260 | parser.add_argument("--normalize_text",
261 | action="store_true", help="normalize text")
262 |
263 | args = parser.parse_args()
264 | src.slurm.init_distributed_mode(args)
265 | main(args)
266 |
--------------------------------------------------------------------------------
/TART/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | import os
8 | import argparse
9 | import torch
10 |
11 | import transformers
12 | from src.normalize_text import normalize
13 |
14 |
15 | def save(tensor, split_path):
16 | if not os.path.exists(os.path.dirname(split_path)):
17 | os.makedirs(os.path.dirname(split_path))
18 | with open(split_path, 'wb') as fout:
19 | torch.save(tensor, fout)
20 |
21 |
22 | def apply_tokenizer(path, tokenizer, normalize_text=False):
23 | alltokens = []
24 | lines = []
25 | with open(path, "r", encoding="utf-8") as fin:
26 | for k, line in enumerate(fin):
27 | if normalize_text:
28 | line = normalize(line)
29 |
30 | lines.append(line)
31 | if len(lines) > 1000000:
32 | tokens = tokenizer.batch_encode_plus(
33 | lines, add_special_tokens=False)['input_ids']
34 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
35 | alltokens.extend(tokens)
36 | lines = []
37 |
38 | tokens = tokenizer.batch_encode_plus(
39 | lines, add_special_tokens=False)['input_ids']
40 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens]
41 | alltokens.extend(tokens)
42 |
43 | alltokens = torch.cat(alltokens)
44 | return alltokens
45 |
46 |
47 | def tokenize_file(args):
48 | filename = os.path.basename(args.datapath)
49 | savepath = os.path.join(args.outdir, f"{filename}.pkl")
50 | if os.path.exists(savepath):
51 | if args.overwrite:
52 | print(f"File {savepath} already exists, overwriting")
53 | else:
54 | print(f"File {savepath} already exists, exiting")
55 | return
56 | try:
57 | tokenizer = transformers.AutoTokenizer.from_pretrained(
58 | args.tokenizer, local_files_only=True)
59 | except:
60 | tokenizer = transformers.AutoTokenizer.from_pretrained(
61 | args.tokenizer, local_files_only=False)
62 | print(f"Encoding {args.datapath}...")
63 | tokens = apply_tokenizer(args.datapath, tokenizer,
64 | normalize_text=args.normalize_text)
65 |
66 | print(f"Saving at {savepath}...")
67 | save(tokens, savepath)
68 |
69 |
70 | if __name__ == '__main__':
71 | parser = argparse.ArgumentParser(
72 | formatter_class=argparse.ArgumentDefaultsHelpFormatter)
73 | parser.add_argument("--datapath", type=str)
74 | parser.add_argument("--outdir", type=str)
75 | parser.add_argument("--tokenizer", type=str)
76 | parser.add_argument("--overwrite", action="store_true")
77 | parser.add_argument("--normalize_text", action="store_true")
78 |
79 | args, _ = parser.parse_known_args()
80 | tokenize_file(args)
81 |
--------------------------------------------------------------------------------
/TART/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.11.0
2 | transformers==4.18.0
3 | beir==1.0.0
4 | pytrec_eval
5 | jsonlines
6 | gdown
--------------------------------------------------------------------------------
/TART/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__init__.py
--------------------------------------------------------------------------------
/TART/src/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/contriever.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/contriever.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/data.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/data.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/dist_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/dist_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/evaluation.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/evaluation.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/index.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/index.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/modeling_enc_t5.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/modeling_enc_t5.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/normalize_text.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/normalize_text.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/slurm.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/slurm.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/tokenization_enc_t5.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/tokenization_enc_t5.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/TART/src/contriever.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | from string import printable
5 | import torch
6 | import transformers
7 | from transformers import BertModel, XLMRobertaModel, AlbertModel, T5EncoderModel
8 |
9 | from src import utils
10 |
11 |
12 | class Contriever(BertModel):
13 | def __init__(self, config, pooling="average", **kwargs):
14 | super().__init__(config, add_pooling_layer=False)
15 | if not hasattr(config, "pooling"):
16 | self.config.pooling = pooling
17 |
18 | def forward(
19 | self,
20 | input_ids=None,
21 | attention_mask=None,
22 | token_type_ids=None,
23 | position_ids=None,
24 | head_mask=None,
25 | inputs_embeds=None,
26 | encoder_hidden_states=None,
27 | encoder_attention_mask=None,
28 | output_attentions=None,
29 | output_hidden_states=None,
30 | normalize=False,
31 | ):
32 |
33 | model_output = super().forward(
34 | input_ids=input_ids,
35 | attention_mask=attention_mask,
36 | token_type_ids=token_type_ids,
37 | position_ids=position_ids,
38 | head_mask=head_mask,
39 | inputs_embeds=inputs_embeds,
40 | encoder_hidden_states=encoder_hidden_states,
41 | encoder_attention_mask=encoder_attention_mask,
42 | output_attentions=output_attentions,
43 | output_hidden_states=output_hidden_states,
44 | )
45 |
46 | last_hidden = model_output["last_hidden_state"]
47 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
48 |
49 | if self.config.pooling == "average":
50 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
51 | elif self.config.pooling == "cls":
52 | emb = last_hidden[:, 0]
53 |
54 | if normalize:
55 | emb = torch.nn.functional.normalize(emb, dim=-1)
56 | return emb
57 |
58 |
59 | class XLMRetriever(XLMRobertaModel):
60 | def __init__(self, config, pooling="average", **kwargs):
61 | super().__init__(config, add_pooling_layer=False)
62 | if not hasattr(config, "pooling"):
63 | self.config.pooling = pooling
64 |
65 | def forward(
66 | self,
67 | input_ids=None,
68 | attention_mask=None,
69 | token_type_ids=None,
70 | position_ids=None,
71 | head_mask=None,
72 | inputs_embeds=None,
73 | encoder_hidden_states=None,
74 | encoder_attention_mask=None,
75 | output_attentions=None,
76 | output_hidden_states=None,
77 | normalize=False,
78 | ):
79 |
80 | model_output = super().forward(
81 | input_ids=input_ids,
82 | attention_mask=attention_mask,
83 | token_type_ids=token_type_ids,
84 | position_ids=position_ids,
85 | head_mask=head_mask,
86 | inputs_embeds=inputs_embeds,
87 | encoder_hidden_states=encoder_hidden_states,
88 | encoder_attention_mask=encoder_attention_mask,
89 | output_attentions=output_attentions,
90 | output_hidden_states=output_hidden_states,
91 | )
92 |
93 | last_hidden = model_output["last_hidden_state"]
94 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
95 | if self.config.pooling == "average":
96 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
97 | elif self.config.pooling == "cls":
98 | emb = last_hidden[:, 0]
99 | if normalize:
100 | emb = torch.nn.functional.normalize(emb, dim=-1)
101 | return emb
102 |
103 |
104 | class ALBERTRetriever(AlbertModel):
105 | def __init__(self, config, pooling="average", **kwargs):
106 | super().__init__(config, add_pooling_layer=False)
107 | if not hasattr(config, "pooling"):
108 | self.config.pooling = pooling
109 |
110 | def forward(
111 | self,
112 | input_ids=None,
113 | attention_mask=None,
114 | token_type_ids=None,
115 | position_ids=None,
116 | head_mask=None,
117 | inputs_embeds=None,
118 | encoder_hidden_states=None,
119 | encoder_attention_mask=None,
120 | output_attentions=None,
121 | output_hidden_states=None,
122 | normalize=False,
123 | ):
124 |
125 | model_output = super().forward(
126 | input_ids=input_ids,
127 | attention_mask=attention_mask,
128 | token_type_ids=token_type_ids,
129 | position_ids=position_ids,
130 | head_mask=head_mask,
131 | inputs_embeds=inputs_embeds,
132 | output_attentions=output_attentions,
133 | output_hidden_states=output_hidden_states,
134 | )
135 |
136 | last_hidden = model_output["last_hidden_state"]
137 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
138 | if self.config.pooling == "average":
139 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
140 | elif self.config.pooling == "cls":
141 | emb = last_hidden[:, 0]
142 | if normalize:
143 | emb = torch.nn.functional.normalize(emb, dim=-1)
144 | return emb
145 |
146 | class T5Contriever(T5EncoderModel):
147 | def __init__(self, config, pooling="average", **kwargs):
148 | super().__init__(config)
149 | if not hasattr(config, "pooling"):
150 | self.config.pooling = pooling
151 |
152 | def forward(
153 | self,
154 | input_ids=None,
155 | attention_mask=None,
156 | token_type_ids=None,
157 | position_ids=None,
158 | head_mask=None,
159 | inputs_embeds=None,
160 | encoder_hidden_states=None,
161 | encoder_attention_mask=None,
162 | output_attentions=None,
163 | output_hidden_states=None,
164 | normalize=False,
165 | ):
166 |
167 | model_output = super().forward(
168 | input_ids=input_ids,
169 | attention_mask=attention_mask,
170 | head_mask=head_mask,
171 | inputs_embeds=inputs_embeds,
172 | output_attentions=output_attentions,
173 | output_hidden_states=output_hidden_states,
174 | )
175 |
176 | last_hidden = model_output["last_hidden_state"]
177 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0)
178 |
179 | if self.config.pooling == "average":
180 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
181 | elif self.config.pooling == "cls":
182 | emb = last_hidden[:, 0]
183 |
184 | if normalize:
185 | emb = torch.nn.functional.normalize(emb, dim=-1)
186 | return emb
187 |
188 | def load_retriever(model_path, pooling="average", random_init=False):
189 | # try: check if model exists locally
190 | path = os.path.join(model_path, "checkpoint.pth")
191 | if os.path.exists(path):
192 | pretrained_dict = torch.load(path, map_location="cpu")
193 | opt = pretrained_dict["opt"]
194 | if hasattr(opt, "retriever_model_id"):
195 | retriever_model_id = opt.retriever_model_id
196 | else:
197 | # retriever_model_id = "bert-base-uncased"
198 | retriever_model_id = "bert-base-multilingual-cased"
199 | tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id)
200 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id)
201 | if "xlm" in retriever_model_id:
202 | model_class = XLMRetriever
203 | elif "albert" in retriever_model_id:
204 | print("Albert Contriever")
205 | model_class = ALBERTRetriever
206 | elif "t5" in retriever_model_id or "T0" in retriever_model_id or "gtr" in retriever_model_id:
207 | model_class = T5Contriever
208 | else:
209 | model_class = Contriever
210 | retriever = model_class(cfg)
211 | pretrained_dict = pretrained_dict["model"]
212 |
213 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class
214 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k}
215 | # elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class
216 | # pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k}
217 | retriever.load_state_dict(pretrained_dict)
218 | else:
219 | retriever_model_id = model_path
220 | if "xlm" in retriever_model_id:
221 | model_class = XLMRetriever
222 | elif "albert" in retriever_model_id:
223 | model_class = ALBERTRetriever
224 | elif "t5" in retriever_model_id or "T0" in retriever_model_id or "gtr" in retriever_model_id:
225 | model_class = T5Contriever
226 | else:
227 | model_class = Contriever
228 | cfg = utils.load_hf(transformers.AutoConfig, model_path)
229 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path)
230 | retriever = utils.load_hf(model_class, model_path)
231 |
232 | return retriever, tokenizer, retriever_model_id
233 |
--------------------------------------------------------------------------------
/TART/src/data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import glob
5 | import torch
6 | import random
7 | import json
8 | import csv
9 | import numpy as np
10 | import numpy.random
11 | import logging
12 | from collections import defaultdict
13 | import torch.distributed as dist
14 | import io
15 |
16 | from src import dist_utils
17 |
18 | logger = logging.getLogger(__name__)
19 |
20 |
21 | def load_data(opt, tokenizer):
22 | datasets = {}
23 | for path in opt.train_data:
24 | data = load_dataset(path, opt.loading_mode)
25 | if data is not None:
26 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt)
27 | dataset = MultiDataset(datasets)
28 | dataset.set_prob(coeff=opt.sampling_coefficient)
29 | return dataset
30 |
31 |
32 | def load_dataset(data_path, loading_mode):
33 | files = glob.glob(os.path.join(data_path, "*.p*"))
34 | files.sort()
35 | tensors = []
36 | if loading_mode == "split":
37 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()]
38 | for filepath in files_split:
39 | try:
40 | tensors.append(torch.load(filepath, map_location="cpu"))
41 | except:
42 | logger.warning(f"Unable to load file {filepath}")
43 | elif loading_mode == "full":
44 | for fin in files:
45 | tensors.append(torch.load(fin, map_location="cpu"))
46 | elif loading_mode == "single":
47 | tensors.append(torch.load(files[0], map_location="cpu"))
48 | if len(tensors) == 0:
49 | return None
50 | tensor = torch.cat(tensors)
51 | return tensor
52 |
53 |
54 | class MultiDataset(torch.utils.data.Dataset):
55 | def __init__(self, datasets):
56 |
57 | self.datasets = datasets
58 | self.prob = [1 / len(self.datasets) for _ in self.datasets]
59 | self.dataset_ids = list(self.datasets.keys())
60 |
61 | def __len__(self):
62 | return sum([len(dataset) for dataset in self.datasets.values()])
63 |
64 | def __getitem__(self, index):
65 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0]
66 | did = self.dataset_ids[dataset_idx]
67 | index = random.randint(0, len(self.datasets[did]) - 1)
68 | sample = self.datasets[did][index]
69 | sample["dataset_id"] = did
70 | return sample
71 |
72 | def generate_offset(self):
73 | for dataset in self.datasets.values():
74 | dataset.generate_offset()
75 |
76 | def set_prob(self, coeff=0.0):
77 |
78 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()])
79 | prob /= prob.sum()
80 | prob = np.array([p**coeff for p in prob])
81 | prob /= prob.sum()
82 | self.prob = prob
83 |
84 |
85 | class Dataset(torch.utils.data.Dataset):
86 | """Monolingual dataset based on a list of paths"""
87 |
88 | def __init__(self, data, chunk_length, tokenizer, opt):
89 |
90 | self.data = data
91 | self.chunk_length = chunk_length
92 | self.tokenizer = tokenizer
93 | self.opt = opt
94 | self.generate_offset()
95 |
96 | def __len__(self):
97 | return (self.data.size(0) - self.offset) // self.chunk_length
98 |
99 | def __getitem__(self, index):
100 | start_idx = self.offset + index * self.chunk_length
101 | end_idx = start_idx + self.chunk_length
102 | tokens = self.data[start_idx:end_idx]
103 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
104 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max)
105 | q_tokens = apply_augmentation(q_tokens, self.opt)
106 | q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
107 | k_tokens = apply_augmentation(k_tokens, self.opt)
108 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id)
109 |
110 | return {"q_tokens": q_tokens, "k_tokens": k_tokens}
111 |
112 | def generate_offset(self):
113 | self.offset = random.randint(0, self.chunk_length - 1)
114 |
115 |
116 | class Collator(object):
117 | def __init__(self, opt):
118 | self.opt = opt
119 |
120 | def __call__(self, batch_examples):
121 |
122 | batch = defaultdict(list)
123 | for example in batch_examples:
124 | for k, v in example.items():
125 | batch[k].append(v)
126 |
127 | q_tokens, q_mask = build_mask(batch["q_tokens"])
128 | k_tokens, k_mask = build_mask(batch["k_tokens"])
129 |
130 | batch["q_tokens"] = q_tokens
131 | batch["q_mask"] = q_mask
132 | batch["k_tokens"] = k_tokens
133 | batch["k_mask"] = k_mask
134 |
135 | return batch
136 |
137 |
138 | def randomcrop(x, ratio_min, ratio_max):
139 |
140 | ratio = random.uniform(ratio_min, ratio_max)
141 | length = int(len(x) * ratio)
142 | start = random.randint(0, len(x) - length)
143 | end = start + length
144 | crop = x[start:end].clone()
145 | return crop
146 |
147 |
148 | def build_mask(tensors):
149 | shapes = [x.shape for x in tensors]
150 | maxlength = max([len(x) for x in tensors])
151 | returnmasks = []
152 | ids = []
153 | for k, x in enumerate(tensors):
154 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x))))
155 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x))))))
156 | ids = torch.stack(ids, dim=0).long()
157 | returnmasks = torch.stack(returnmasks, dim=0).bool()
158 | return ids, returnmasks
159 |
160 |
161 | def add_token(x, token):
162 | x = torch.cat((torch.tensor([token]), x))
163 | return x
164 |
165 |
166 | def deleteword(x, p=0.1):
167 | mask = np.random.rand(len(x))
168 | x = [e for e, m in zip(x, mask) if m > p]
169 | return x
170 |
171 |
172 | def replaceword(x, min_random, max_random, p=0.1):
173 | mask = np.random.rand(len(x))
174 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)]
175 | return x
176 |
177 |
178 | def maskword(x, mask_id, p=0.1):
179 | mask = np.random.rand(len(x))
180 | x = [e if m > p else mask_id for e, m in zip(x, mask)]
181 | return x
182 |
183 |
184 | def shuffleword(x, p=0.1):
185 |     """Shuffles any n number of values in a list"""
186 |     count = (np.random.rand(len(x)) < p).sum()
187 | indices_to_shuffle = random.sample(range(len(x)), k=count)
188 | to_shuffle = [x[i] for i in indices_to_shuffle]
189 | random.shuffle(to_shuffle)
190 | for index, value in enumerate(to_shuffle):
191 | old_index = indices_to_shuffle[index]
192 | x[old_index] = value
193 | return x
194 |
195 |
196 | def apply_augmentation(x, opt):
197 | if opt.augmentation == "mask":
198 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation))
199 | elif opt.augmentation == "replace":
200 | return torch.tensor(
201 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation)
202 | )
203 | elif opt.augmentation == "delete":
204 | return torch.tensor(deleteword(x, p=opt.prob_augmentation))
205 | elif opt.augmentation == "shuffle":
206 | return torch.tensor(shuffleword(x, p=opt.prob_augmentation))
207 | else:
208 | if not isinstance(x, torch.Tensor):
209 | x = torch.Tensor(x)
210 | return x
211 |
212 |
213 | def add_bos_eos(x, bos_token_id, eos_token_id):
214 | if not isinstance(x, torch.Tensor):
215 | x = torch.Tensor(x)
216 | if bos_token_id is None and eos_token_id is not None:
217 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])])
218 | elif bos_token_id is not None and eos_token_id is None:
219 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()])
220 | elif bos_token_id is None and eos_token_id is None:
221 | pass
222 | else:
223 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])])
224 | return x
225 |
226 |
227 | # Used for passage retrieval
228 | def load_passages(path):
229 | if not os.path.exists(path):
230 | logger.info(f"{path} does not exist")
231 | return
232 | logger.info(f"Loading passages from: {path}")
233 | passages = []
234 | with open(path, encoding='utf-8') as fin:
235 | if path.endswith(".jsonl"):
236 | for k, line in enumerate(fin):
237 | ex = json.loads(line)
238 | passages.append(ex)
239 | else:
240 | data = fin.read()
241 | data = data.replace("\0", "") # get rid of null-bytes
242 | data_repr = io.StringIO(data)
243 | reader = csv.reader(data_repr, delimiter="\t")
244 | for k, row in enumerate(reader):
245 | if not row[0] == "id":
246 | ex = {"id": row[0], "title": row[2], "text": row[1]} if len(row) == 3 else {"id": row[0], "text": row[1]}
247 | passages.append(ex)
248 | if path.endswith(".jsonl") and "_id" in passages[0]:
249 | for item in passages:
250 | item["id"] = item["_id"]
251 | if type(item["text"]) is float:
252 | item["text"] = item["title"]
253 | if type(item["title"]) is float:
254 | item["title"] = ""
255 | return passages
256 |
--------------------------------------------------------------------------------
/TART/src/dist_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.distributed as dist
5 |
6 |
7 | class Gather(torch.autograd.Function):
8 | @staticmethod
9 | def forward(ctx, x: torch.tensor):
10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
11 | dist.all_gather(output, x)
12 | return tuple(output)
13 |
14 | @staticmethod
15 | def backward(ctx, *grads):
16 | all_gradients = torch.stack(grads)
17 | dist.all_reduce(all_gradients)
18 | return all_gradients[dist.get_rank()]
19 |
20 |
21 | def gather(x: torch.tensor):
22 | if not dist.is_initialized():
23 | return x
24 | x_gather = Gather.apply(x)
25 | x_gather = torch.cat(x_gather, dim=0)
26 | return x_gather
27 |
28 |
29 | @torch.no_grad()
30 | def gather_nograd(x: torch.tensor):
31 | if not dist.is_initialized():
32 | return x
33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())]
34 | dist.all_gather(x_gather, x, async_op=False)
35 |
36 | x_gather = torch.cat(x_gather, dim=0)
37 | return x_gather
38 |
39 |
40 | @torch.no_grad()
41 | def varsize_gather_nograd(x: torch.Tensor):
42 | """gather tensors of different sizes along the first dimension"""
43 | if not dist.is_initialized():
44 | return x
45 |
46 | # determine max size
47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
49 | dist.all_gather(allsizes, size)
50 | max_size = max([size.cpu().max() for size in allsizes])
51 |
52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device)
53 | padded[: x.shape[0]] = x
54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())]
55 | dist.all_gather(output, padded)
56 |
57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)]
58 | output = torch.cat(output, dim=0)
59 |
60 | return output
61 |
62 |
63 | @torch.no_grad()
64 | def get_varsize(x: torch.Tensor):
65 | """gather tensors of different sizes along the first dimension"""
66 | if not dist.is_initialized():
67 | return [x.shape[0]]
68 |
69 | # determine max size
70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int)
71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())]
72 | dist.all_gather(allsizes, size)
73 | allsizes = torch.cat(allsizes)
74 | return allsizes
75 |
76 |
77 | def get_rank():
78 | if not dist.is_available():
79 | return 0
80 | if not dist.is_initialized():
81 | return 0
82 | return dist.get_rank()
83 |
84 |
85 | def is_main():
86 | return get_rank() == 0
87 |
88 |
89 | def get_world_size():
90 | if not dist.is_initialized():
91 | return 1
92 | else:
93 | return dist.get_world_size()
94 |
95 |
96 | def barrier():
97 | if dist.is_initialized():
98 | dist.barrier()
99 |
100 |
101 | def average_main(x):
102 | if not dist.is_initialized():
103 | return x
104 | if dist.is_initialized() and dist.get_world_size() > 1:
105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM)
106 | if is_main():
107 | x = x / dist.get_world_size()
108 | return x
109 |
110 |
111 | def sum_main(x):
112 | if not dist.is_initialized():
113 | return x
114 | if dist.is_initialized() and dist.get_world_size() > 1:
115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM)
116 | return x
117 |
118 |
119 | def weighted_average(x, count):
120 | if not dist.is_initialized():
121 | if isinstance(x, torch.Tensor):
122 | x = x.item()
123 | return x, count
124 | t_loss = torch.tensor([x * count]).cuda()
125 | t_total = torch.tensor([count]).cuda()
126 | t_loss = sum_main(t_loss)
127 | t_total = sum_main(t_total)
128 | return (t_loss / t_total).item(), t_total.item()
129 |
--------------------------------------------------------------------------------
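
A small sanity-check sketch for the helpers above, assuming the repository root is on `PYTHONPATH`. Without an initialized process group every helper falls back to single-process behaviour, which is what lets the same training code run on one GPU or many.

```python
# Sketch (not part of the repo): single-process fallbacks of dist_utils.
import torch
from src import dist_utils

x = torch.randn(3, 4)
print(torch.equal(dist_utils.gather(x), x))                 # True: no process group, gather is a no-op
print(dist_utils.get_rank(), dist_utils.get_world_size())   # 0 1
print(dist_utils.weighted_average(0.5, count=10))           # (0.5, 10)
```
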
/TART/src/evaluation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import collections
4 | import logging
5 | import regex
6 | import string
7 | import unicodedata
8 | from functools import partial
9 | from multiprocessing import Pool as ProcessPool
10 | from typing import Tuple, List, Dict
11 | import numpy as np
12 | from collections import Counter
13 |
14 | """
15 | Evaluation code from DPR: https://github.com/facebookresearch/DPR
16 | """
17 |
18 | class SimpleTokenizer(object):
19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
20 | NON_WS = r'[^\p{Z}\p{C}]'
21 |
22 | def __init__(self):
23 | """
24 | Args:
25 | annotators: None or empty set (only tokenizes).
26 | """
27 | self._regexp = regex.compile(
28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
30 | )
31 |
32 | def tokenize(self, text, uncased=False):
33 | matches = [m for m in self._regexp.finditer(text)]
34 | if uncased:
35 | tokens = [m.group().lower() for m in matches]
36 | else:
37 | tokens = [m.group() for m in matches]
38 | return tokens
39 |
40 | logger = logging.getLogger(__name__)
41 |
42 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits'])
43 |
44 | def calculate_matches(data: List, workers_num: int):
45 | """
46 |     Evaluates answer presence in the set of retrieved documents. This function is supposed to be used with a large
47 |     collection of documents and results. It internally forks multiple sub-processes for evaluation and then merges results
48 |     :param data: list of examples; each example holds an 'answers' list and a 'ctxs' list with the retrieved documents
49 |     (one dict per retrieved document, containing at least the document text under 'text')
50 |     :param workers_num: number of parallel processes used to check the answers;
51 |     matching itself is delegated to check_answer / has_answer below
52 |     :return: matching information tuple.
53 |     top_k_hits - a list where the index is the number of top documents retrieved and the value is the total number of
54 |     valid matches across the entire dataset.
55 |     questions_doc_hits - more detailed info with answer matches for every question and every
56 |     retrieved document
57 |     """
58 |
59 | logger.info('Matching answers in top docs...')
60 |
61 | tokenizer = SimpleTokenizer()
62 | get_score_partial = partial(check_answer, tokenizer=tokenizer)
63 |
64 | processes = ProcessPool(processes=workers_num)
65 | scores = processes.map(get_score_partial, data)
66 |
67 | logger.info('Per question validation results len=%d', len(scores))
68 |
69 | n_docs = len(data[0]['ctxs'])
70 | top_k_hits = [0] * n_docs
71 | for question_hits in scores:
72 | best_hit = next((i for i, x in enumerate(question_hits) if x), None)
73 | if best_hit is not None:
74 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]
75 |
76 | return QAMatchStats(top_k_hits, scores)
77 |
78 | def check_answer(example, tokenizer) -> List[bool]:
79 | """Search through all the top docs to see if they have any of the answers."""
80 | answers = example['answers']
81 | ctxs = example['ctxs']
82 |
83 | hits = []
84 |
85 | for i, doc in enumerate(ctxs):
86 | text = doc['text']
87 |
88 | if text is None: # cannot find the document for some reason
89 | logger.warning("no doc in db")
90 | hits.append(False)
91 | continue
92 |
93 | hits.append(has_answer(answers, text, tokenizer))
94 |
95 | return hits
96 |
97 | def has_answer(answers, text, tokenizer) -> bool:
98 | """Check if a document contains an answer string."""
99 | text = _normalize(text)
100 | text = tokenizer.tokenize(text, uncased=True)
101 |
102 | for answer in answers:
103 | answer = _normalize(answer)
104 | answer = tokenizer.tokenize(answer, uncased=True)
105 | for i in range(0, len(text) - len(answer) + 1):
106 | if answer == text[i: i + len(answer)]:
107 | return True
108 | return False
109 |
110 | #################################################
111 | ######## READER EVALUATION ########
112 | #################################################
113 |
114 | def _normalize(text):
115 | return unicodedata.normalize('NFD', text)
116 |
117 | #Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
118 | def normalize_answer(s):
119 | def remove_articles(text):
120 | return regex.sub(r'\b(a|an|the)\b', ' ', text)
121 |
122 | def white_space_fix(text):
123 | return ' '.join(text.split())
124 |
125 | def remove_punc(text):
126 | exclude = set(string.punctuation)
127 | return ''.join(ch for ch in text if ch not in exclude)
128 |
129 | def lower(text):
130 | return text.lower()
131 |
132 | return white_space_fix(remove_articles(remove_punc(lower(s))))
133 |
134 | def em(prediction, ground_truth):
135 | return normalize_answer(prediction) == normalize_answer(ground_truth)
136 |
137 | def f1(prediction, ground_truth):
138 | prediction_tokens = normalize_answer(prediction).split()
139 | ground_truth_tokens = normalize_answer(ground_truth).split()
140 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
141 | num_same = sum(common.values())
142 | if num_same == 0:
143 | return 0
144 | precision = 1.0 * num_same / len(prediction_tokens)
145 | recall = 1.0 * num_same / len(ground_truth_tokens)
146 | f1 = (2 * precision * recall) / (precision + recall)
147 | return f1
148 |
149 | def f1_score(prediction, ground_truths):
150 | return max([f1(prediction, gt) for gt in ground_truths])
151 |
152 | def exact_match_score(prediction, ground_truths):
153 | return max([em(prediction, gt) for gt in ground_truths])
154 |
155 | ####################################################
156 | ######## RETRIEVER EVALUATION ########
157 | ####################################################
158 |
159 | def eval_batch(scores, inversions, avg_topk, idx_topk):
160 | for k, s in enumerate(scores):
161 | s = s.cpu().numpy()
162 | sorted_idx = np.argsort(-s)
163 | score(sorted_idx, inversions, avg_topk, idx_topk)
164 |
165 | def count_inversions(arr):
166 | inv_count = 0
167 | lenarr = len(arr)
168 | for i in range(lenarr):
169 | for j in range(i + 1, lenarr):
170 | if (arr[i] > arr[j]):
171 | inv_count += 1
172 | return inv_count
173 |
174 | def score(x, inversions, avg_topk, idx_topk):
175 | x = np.array(x)
176 | inversions.append(count_inversions(x))
177 | for k in avg_topk:
178 | # ratio of passages in the predicted top-k that are
179 | # also in the topk given by gold score
180 | avg_pred_topk = (x[:k] < k).mean()
181 | avg_topk[k].append(avg_pred_topk)
--------------------------------------------------------------------------------
/TART/src/finetuning_data.py:
--------------------------------------------------------------------------------
44 | if n_random_negatives > 0:
45 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives)
46 | negatives += random_negatives
47 | if n_hard_negatives > 0:
48 | hard_negatives = random.sample(
49 | example["hard_negative_ctxs"][self.hard_negative_min_idx :], n_hard_negatives
50 | )
51 | negatives += hard_negatives
52 | else:
53 | gold = example["positive_ctxs"][0]
54 | nidx = 0
55 | if "negative_ctxs" in example:
56 | negatives = [example["negative_ctxs"][nidx]]
57 | else:
58 | negatives = []
59 |
60 | if "title" in gold and type(gold["title"]) is str and len(gold["title"]) > 0:
61 | gold = gold["title"] + " " + gold["text"]
62 | else:
63 | gold = gold["text"]
64 |
65 | negatives_new = []
66 | for n in negatives:
67 | if "title" in n and type(n["title"]) is str and len(n["title"]) > 0:
68 | negatives_new.append(n["title"] + " " + n["text"])
69 | else:
70 | negatives_new.append(n["text"])
71 |
72 | negatives = negatives_new
73 | # negatives = [
74 | # n["title"] + " " + n["text"] if ("title" in n and type(n["title"]) is str and len(n["title"]) > 0) else n["text"] for n in negatives
75 | # ]
76 |
77 | example = {
78 | "query": self.normalize_fn(question),
79 | "gold": self.normalize_fn(gold),
80 | "negatives": [self.normalize_fn(n) for n in negatives],
81 | }
82 | return example
83 |
84 | def _load_data(self, datapaths, global_rank, world_size, maxload):
85 | counter = 0
86 | self.data = []
87 | for path in datapaths:
88 | path = str(path)
89 | if path.endswith(".jsonl"):
90 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
91 | elif path.endswith(".json"):
92 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
93 | self.data.extend(file_data)
94 | if maxload is not None and maxload > 0 and counter >= maxload:
95 | break
96 |
97 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
98 | examples = []
99 | with open(path, "r") as fin:
100 | data = json.load(fin)
101 | for example in data:
102 | counter += 1
103 | if global_rank > -1 and not counter % world_size == global_rank:
104 | continue
105 | examples.append(example)
106 | if maxload is not None and maxload > 0 and counter == maxload:
107 | break
108 |
109 | return examples, counter
110 |
111 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
112 | examples = []
113 | with open(path, "r") as fin:
114 | for line in fin:
115 | counter += 1
116 | if global_rank > -1 and not counter % world_size == global_rank:
117 | continue
118 | example = json.loads(line)
119 | examples.append(example)
120 | if maxload is not None and maxload > 0 and counter == maxload:
121 | break
122 |
123 | return examples, counter
124 |
125 | def sample_n_hard_negatives(self, ex):
126 |
127 | if "hard_negative_ctxs" in ex:
128 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)])
129 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :]))
130 | else:
131 | n_hard_negatives = 0
132 | n_random_negatives = self.negative_ctxs - n_hard_negatives
133 | if "negative_ctxs" in ex:
134 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"]))
135 | else:
136 | n_random_negatives = 0
137 | # print(n_hard_negatives, n_random_negatives)
138 | return n_hard_negatives, n_random_negatives
139 |
140 |
141 | class Collator(object):
142 | def __init__(self, tokenizer, passage_maxlength=200):
143 | self.tokenizer = tokenizer
144 | self.passage_maxlength = passage_maxlength
145 |
146 | def __call__(self, batch):
147 | queries = [ex["query"] for ex in batch]
148 | golds = [ex["gold"] for ex in batch]
149 | negs = [item for ex in batch for item in ex["negatives"]]
150 | allpassages = golds + negs
151 |
152 | qout = self.tokenizer.batch_encode_plus(
153 | queries,
154 | max_length=self.passage_maxlength,
155 | truncation=True,
156 | padding=True,
157 | add_special_tokens=True,
158 | return_tensors="pt",
159 | )
160 | kout = self.tokenizer.batch_encode_plus(
161 | allpassages,
162 | max_length=self.passage_maxlength,
163 | truncation=True,
164 | padding=True,
165 | add_special_tokens=True,
166 | return_tensors="pt",
167 | )
168 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool()
169 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool()
170 |
171 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)]
172 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :]
173 |
174 | batch = {
175 | "q_tokens": q_tokens,
176 | "q_mask": q_mask,
177 | "k_tokens": k_tokens,
178 | "k_mask": k_mask,
179 | "g_tokens": g_tokens,
180 | "g_mask": g_mask,
181 | "n_tokens": n_tokens,
182 | "n_mask": n_mask,
183 | }
184 |
185 | return batch
186 |
187 | class DatasetKD(torch.utils.data.Dataset):
188 | def __init__(self,
189 | datapaths,
190 | n_context=50,
191 | question_prefix='',
192 | title_prefix='',
193 | passage_prefix='',
194 | global_rank=-1,
195 | world_size=-1,
196 | maxload=None,
197 | random_sort=False):
198 |
199 | self._load_data(datapaths, global_rank, world_size, maxload)
200 | self.n_context = n_context
201 | self.question_prefix = question_prefix
202 | self.title_prefix = title_prefix
203 | self.passage_prefix = passage_prefix
204 | self.random_sort = random_sort
205 | self.normalize_fn = normalize_text.normalize if normalize_text else lambda x: x
206 | self.sort_data()
207 |
208 | def __len__(self):
209 | return len(self.data)
210 |
211 | def __getitem__(self, index):
212 | example = self.data[index]
213 | question = self.question_prefix + " " + example['question']
214 |
215 | if 'ctxs' in example and self.n_context is not None:
216 | f = self.title_prefix + " {} " + self.passage_prefix + " {}"
217 | if len(example['ctxs']) > self.n_context and self.random_sort is True:
218 | contexts = random.sample(example['ctxs'], k=self.n_context)
219 | else:
220 | contexts = example['ctxs'][:self.n_context]
221 | passages = [f.format(c['title'], c['text']) for c in contexts]
222 | scores = [float(c['gold_score']) for c in contexts]
223 | scores = torch.tensor(scores)
224 | if len(contexts) == 0:
225 | contexts = [question]
226 | else:
227 | passages, scores = None, None
228 |
229 | return {
230 | 'index' : index,
231 | 'question' : question,
232 | 'passages' : passages,
233 | 'scores' : scores
234 | }
235 |
236 | def sort_data(self):
237 | if self.n_context is None or not 'gold_score' in self.data[0]['ctxs'][0]:
238 | return
239 | for ex in self.data:
240 | ex['ctxs'].sort(key=lambda x: float(x['gold_score']), reverse=True)
241 |
242 | def get_example(self, index):
243 | return self.data[index]
244 |
245 | def _load_data(self, datapaths, global_rank, world_size, maxload):
246 | counter = 0
247 | self.data = []
248 | for path in datapaths:
249 | path = str(path)
250 | if path.endswith(".jsonl"):
251 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload)
252 | elif path.endswith(".json"):
253 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload)
254 | self.data.extend(file_data)
255 | if maxload is not None and maxload > 0 and counter >= maxload:
256 | break
257 |
258 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None):
259 | examples = []
260 | with open(path, "r") as fin:
261 | data = json.load(fin)
262 | for example in data:
263 | counter += 1
264 | if global_rank > -1 and not counter % world_size == global_rank:
265 | continue
266 | examples.append(example)
267 | if maxload is not None and maxload > 0 and counter == maxload:
268 | break
269 |
270 | return examples, counter
271 |
272 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None):
273 | examples = []
274 | with open(path, "r") as fin:
275 | for line in fin:
276 | counter += 1
277 | if global_rank > -1 and not counter % world_size == global_rank:
278 | continue
279 | example = json.loads(line)
280 | examples.append(example)
281 | if maxload is not None and maxload > 0 and counter == maxload:
282 | break
283 |
284 | return examples, counter
285 |
286 | def encode_passages(batch_text_passages, tokenizer, max_length):
287 | passage_ids, passage_masks = [], []
288 | for k, text_passages in enumerate(batch_text_passages):
289 | p = tokenizer.batch_encode_plus(
290 | text_passages,
291 | max_length=max_length,
292 | pad_to_max_length=True,
293 | return_tensors='pt',
294 | truncation=True
295 | )
296 | passage_ids.append(p['input_ids'][None])
297 | passage_masks.append(p['attention_mask'][None])
298 |
299 | passage_ids = torch.cat(passage_ids, dim=0)
300 | passage_masks = torch.cat(passage_masks, dim=0)
301 | return passage_ids, passage_masks.bool()
302 |
303 | def load_data(data_path=None, global_rank=-1, world_size=-1):
304 | assert data_path
305 | if data_path.endswith('.jsonl'):
306 | data = open(data_path, 'r')
307 | elif data_path.endswith('.json'):
308 | with open(data_path, 'r') as fin:
309 | data = json.load(fin)
310 | examples = []
311 | for k, example in enumerate(data):
312 | if global_rank > -1 and not k%world_size==global_rank:
313 | continue
314 | if data_path is not None and data_path.endswith('.jsonl'):
315 | example = json.loads(example)
316 | if not 'id' in example:
317 | example['id'] = k
318 | for c in example['ctxs']:
319 | if not 'score' in c:
320 | c['score'] = 1.0 / (k + 1)
321 | examples.append(example)
322 | if data_path is not None and data_path.endswith('.jsonl'):
323 | data.close()
324 |
325 | return examples
326 |
327 | class CollatorKD(object):
328 | def __init__(self, tokenizer, passage_maxlength=200, question_maxlength=40):
329 | self.tokenizer = tokenizer
330 | self.passage_maxlength = passage_maxlength
331 | self.question_maxlength = question_maxlength
332 |
333 | def __call__(self, batch):
334 | index = torch.tensor([ex['index'] for ex in batch])
335 |
336 | question = [ex['question'] for ex in batch]
337 | question = self.tokenizer.batch_encode_plus(
338 | question,
339 | pad_to_max_length=True,
340 | return_tensors="pt",
341 | max_length=self.question_maxlength,
342 | truncation=True
343 | )
344 | question_ids = question['input_ids']
345 | question_mask = question['attention_mask'].bool()
346 |
347 | if batch[0]['scores'] is None or batch[0]['passages'] is None:
348 | return index, question_ids, question_mask, None, None, None
349 |
350 | scores = [ex['scores'] for ex in batch]
351 | scores = torch.stack(scores, dim=0)
352 |
353 | passages = [ex['passages'] for ex in batch]
354 | passage_ids, passage_masks = encode_passages(
355 | passages,
356 | self.tokenizer,
357 | self.passage_maxlength
358 | )
359 |
360 | batch = {
361 | "question_ids": question_ids,
362 | "question_mask": question_mask,
363 | "passage_ids": passage_ids,
364 | "passage_mask": passage_masks,
365 | "gold_score": scores
366 | }
367 |
368 | return batch
369 |
--------------------------------------------------------------------------------
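
A sketch (field values invented for illustration) of one training example in the layout `Dataset` above consumes; in training mode `__getitem__` turns it into a `query` / `gold` / `negatives` triple and `Collator` tokenizes those into the `q_*` / `k_*` tensors used by the trainers.

```python
# Illustration only: one example in the JSON/JSONL layout read by Dataset above.
example = {
    "question": "what is dense retrieval?",
    "positive_ctxs": [
        {"title": "Dense retrieval", "text": "Dense retrieval encodes queries and passages into vectors ..."}
    ],
    "negative_ctxs": [
        {"title": "Unrelated", "text": "A random passage that does not answer the question."}
    ],
    "hard_negative_ctxs": [
        {"title": "Sparse retrieval", "text": "BM25 is a lexical, not a dense, retrieval method."}
    ],
}
# __getitem__ samples negatives and returns
# {"query": ..., "gold": "<title> <text> of one positive", "negatives": [...]};
# Collator then batches these into q_tokens/q_mask and k_tokens/k_mask.
```
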
/TART/src/inbatch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | import math
7 | import random
8 | import transformers
9 | import logging
10 | import torch.distributed as dist
11 | import copy
12 | from src import contriever, dist_utils, utils
13 | import torch.nn.functional as F
14 |
15 | logger = logging.getLogger(__name__)
16 |
17 |
18 | class InBatch(nn.Module):
19 | def __init__(self, opt, retriever=None, tokenizer=None):
20 | super(InBatch, self).__init__()
21 |
22 | self.opt = opt
23 | self.norm_doc = opt.norm_doc
24 | self.norm_query = opt.norm_query
25 | self.label_smoothing = opt.label_smoothing
26 | if retriever is None or tokenizer is None:
27 | retriever, tokenizer = self._load_retriever(
28 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
29 | )
30 | self.tokenizer = tokenizer
31 | self.encoder = retriever
32 |
33 | def _load_retriever(self, model_id, pooling, random_init):
34 | print("load retrieval")
35 | print(model_id)
36 | if "xlm" in model_id:
37 | model_class = contriever.XLMRetriever
38 | elif "t5" in model_id or "T0" in model_id or "gtr" in model_id:
39 | print("loading t0")
40 | model_class = contriever.T5Contriever
41 | print(model_class)
42 | else:
43 | model_class = contriever.Contriever
44 |
45 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
47 | if random_init:
48 | retriever = model_class(cfg)
49 | else:
50 | retriever = utils.load_hf(model_class, model_id)
51 |
52 | if "bert-" in model_id:
53 | if tokenizer.bos_token_id is None:
54 | tokenizer.bos_token = "[CLS]"
55 | if tokenizer.eos_token_id is None:
56 | tokenizer.eos_token = "[SEP]"
57 |
58 | retriever.config.pooling = pooling
59 |
60 | return retriever, tokenizer
61 |
62 | def get_encoder(self):
63 | return self.encoder
64 |
65 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, gold_scores=None, stats_prefix="", iter_stats={}, **kwargs):
66 |
67 | bsz = len(q_tokens)
68 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
69 |
70 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
71 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
72 |
73 | gather_fn = dist_utils.gather
74 |
75 | gather_kemb = gather_fn(kemb)
76 |
77 | labels = labels + dist_utils.get_rank() * len(kemb)
78 |
79 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
80 |
81 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
82 |
83 | # log stats
84 | if len(stats_prefix) > 0:
85 | stats_prefix = stats_prefix + "/"
86 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
87 |
88 | predicted_idx = torch.argmax(scores, dim=-1)
89 | accuracy = 100 * (predicted_idx == labels).float().mean()
90 | stdq = torch.std(qemb, dim=0).mean().item()
91 | stdk = torch.std(kemb, dim=0).mean().item()
92 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
93 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
94 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
95 |
96 | return loss, iter_stats
97 |
98 |
99 |
100 | class ByInBatch(nn.Module):
101 | def __init__(self, opt, retriever=None, tokenizer=None):
102 | super(ByInBatch, self).__init__()
103 |
104 | self.opt = opt
105 | self.norm_doc = opt.norm_doc
106 | self.norm_query = opt.norm_query
107 | self.label_smoothing = opt.label_smoothing
108 | if retriever is None or tokenizer is None:
109 | retriever, tokenizer = self._load_retriever(
110 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
111 | )
112 | self.tokenizer = tokenizer
113 | self.q_encoder = copy.deepcopy(retriever)
114 | self.p_encoder = copy.deepcopy(retriever)
115 |
116 | def _load_retriever(self, model_id, pooling, random_init):
117 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
118 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
119 |
120 | if "xlm" in model_id:
121 | model_class = contriever.XLMRetriever
122 | else:
123 | model_class = contriever.Contriever
124 |
125 | if random_init:
126 | retriever = model_class(cfg)
127 | else:
128 | retriever = utils.load_hf(model_class, model_id)
129 |
130 | if "bert-" in model_id:
131 | if tokenizer.bos_token_id is None:
132 | tokenizer.bos_token = "[CLS]"
133 | if tokenizer.eos_token_id is None:
134 | tokenizer.eos_token = "[SEP]"
135 |
136 | retriever.config.pooling = pooling
137 |
138 | return retriever, tokenizer
139 |
140 | def get_q_encoder(self):
141 | return self.q_encoder
142 |
143 | def get_p_encoder(self):
144 | return self.p_encoder
145 |
146 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
147 |
148 | bsz = len(q_tokens)
149 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device)
150 |
151 | qemb = self.q_encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
152 | kemb = self.p_encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
153 |
154 | gather_fn = dist_utils.gather
155 |
156 | gather_kemb = gather_fn(kemb)
157 |
158 | labels = labels + dist_utils.get_rank() * len(kemb)
159 |
160 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb)
161 |
162 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing)
163 |
164 | # log stats
165 | if len(stats_prefix) > 0:
166 | stats_prefix = stats_prefix + "/"
167 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
168 |
169 | predicted_idx = torch.argmax(scores, dim=-1)
170 | accuracy = 100 * (predicted_idx == labels).float().mean()
171 | stdq = torch.std(qemb, dim=0).mean().item()
172 | stdk = torch.std(kemb, dim=0).mean().item()
173 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
174 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
175 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
176 |
177 | return loss, iter_stats
178 |
179 |
180 | class InBatch_KD(nn.Module):
181 | def __init__(self, opt, retriever=None, tokenizer=None, loss_type="kl", temperature=1):
182 | super(InBatch_KD, self).__init__()
183 |
184 | self.opt = opt
185 | self.norm_doc = opt.norm_doc
186 | self.norm_query = opt.norm_query
187 | self.label_smoothing = opt.label_smoothing
188 | if retriever is None or tokenizer is None:
189 | retriever, tokenizer = self._load_retriever(
190 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
191 | )
192 | self.tokenizer = tokenizer
193 | self.encoder = retriever
194 | self.loss_type = loss_type
195 | self.temperature = temperature
196 | if loss_type == "kl":
197 | self.loss_fct = torch.nn.KLDivLoss()
198 | elif loss_type == "mse":
199 | self.loss_fct = torch.nn.MSELoss()
200 | else:
201 | raise NotImplementedError
202 |
203 | def _load_retriever(self, model_id, pooling, random_init):
204 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
205 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
206 |
207 | if "xlm" in model_id:
208 | model_class = contriever.XLMRetriever
209 | elif "t5" in model_id or "T0" in model_id or "gtr" in model_id:
210 | model_class = contriever.T5Contriever
211 | else:
212 | model_class = contriever.Contriever
213 |
214 | if random_init:
215 | retriever = model_class(cfg)
216 | else:
217 | retriever = utils.load_hf(model_class, model_id)
218 |
219 | if "bert-" in model_id:
220 | if tokenizer.bos_token_id is None:
221 | tokenizer.bos_token = "[CLS]"
222 | if tokenizer.eos_token_id is None:
223 | tokenizer.eos_token = "[SEP]"
224 |
225 | retriever.config.pooling = pooling
226 |
227 | return retriever, tokenizer
228 |
229 | def get_encoder(self):
230 | return self.encoder
231 |
232 | def forward(self, question_ids, question_mask, passage_ids, passage_mask, gold_score, stats_prefix="", iter_stats={}, **kwargs):
233 | question_output = self.encoder(input_ids=question_ids, attention_mask=question_mask, normalize=self.norm_query)
234 | bsz, n_passages, plen = passage_ids.size()
235 | passage_ids = passage_ids.view(bsz * n_passages, plen)
236 | passage_mask = passage_mask.view(bsz * n_passages, plen)
237 | passage_output = self.encoder(input_ids=passage_ids, attention_mask=passage_mask, normalize=self.norm_doc)
238 |
239 | score = torch.einsum(
240 | 'bd,bid->bi',
241 | question_output,
242 | passage_output.view(bsz, n_passages, -1)
243 | )
244 |
245 | score = score / np.sqrt(question_output.size(-1))
246 | if gold_score is not None:
247 | if self.loss_type == "kl":
248 | loss = self.kldivloss(score, gold_score)
249 | else:
250 | loss = self.mseloss(score, gold_score)
251 | else:
252 | loss = None
253 | # log stats
254 | if len(stats_prefix) > 0:
255 | stats_prefix = stats_prefix + "/"
256 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
257 |
258 | # predicted_idx = torch.argmax(scores, dim=-1)
259 | # accuracy = 100 * (predicted_idx == labels).float().mean()
260 | # stdq = torch.std(qemb, dim=0).mean().item()
261 | # stdk = torch.std(kemb, dim=0).mean().item()
262 | # iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
263 | # iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
264 | # iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
265 | # print(loss)
266 | return loss, iter_stats
267 |
268 | def kldivloss(self, score, gold_score):
269 | # print("scores")
270 | # print(gold_score[0,:10])
271 | gold_score = torch.softmax(gold_score / self.temperature, dim=-1)
272 | # print(gold_score[0,:10])
273 | # print(score[0,:10])
274 | score = torch.nn.functional.log_softmax(score / self.temperature, dim=-1)
275 | # print(score[0,:10])
276 | loss = self.loss_fct(score, gold_score) * (self.temperature**2)
277 | # loss = F.kl_div(score, gold_score, size_average=False) * (self.temperature**2)
278 | # print(loss)
279 | # print(loss.size())
280 |
281 | return loss
282 |
283 | def mseloss(self, score, gold_score):
284 | # print("scores")
285 | # print(gold_score[0,:10])
286 | gold_score = torch.softmax(gold_score, dim=-1)
287 | # print(gold_score[0,:10])
288 | # print(score[0,:10])
289 | score = torch.softmax(score, dim=-1)
290 | # print(score[0,:10])
291 | loss = self.loss_fct(score, gold_score)
292 | # print(loss)
293 | # print(loss.size())
294 | return loss
--------------------------------------------------------------------------------
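
The core of `InBatch.forward` above is a temperature-scaled in-batch contrastive loss: every query is scored against all (gathered) passages and the gold passage sits on the diagonal. A self-contained sketch on toy embeddings, without the distributed gather:

```python
import torch
import torch.nn.functional as F

temperature = 0.05
qemb = F.normalize(torch.randn(4, 8), dim=-1)   # 4 query embeddings
kemb = F.normalize(torch.randn(4, 8), dim=-1)   # their 4 gold passage embeddings

scores = torch.einsum("id,jd->ij", qemb / temperature, kemb)  # (4, 4) similarity matrix
labels = torch.arange(scores.size(0))           # positive for query i is passage i
loss = F.cross_entropy(scores, labels)
accuracy = 100 * (scores.argmax(dim=-1) == labels).float().mean()
print(loss.item(), accuracy.item())
```
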
/TART/src/index.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import pickle
5 | from typing import List, Tuple
6 |
7 | import faiss
8 | import numpy as np
9 | from tqdm import tqdm
10 |
11 | class Indexer(object):
12 |
13 | def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
14 | if n_subquantizers > 0:
15 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
16 | else:
17 | self.index = faiss.IndexFlatIP(vector_sz)
18 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64)
19 | self.index_id_to_db_id = []
20 |
21 | def index_data(self, ids, embeddings):
22 | self._update_id_mapping(ids)
23 | embeddings = embeddings.astype('float32')
24 | if not self.index.is_trained:
25 | self.index.train(embeddings)
26 | print(len(embeddings[0]))
27 | self.index.add(embeddings)
28 |
29 | print(f'Total data indexed {len(self.index_id_to_db_id)}')
30 |
31 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
32 | query_vectors = query_vectors.astype('float32')
33 | result = []
34 | nbatch = (len(query_vectors)-1) // index_batch_size + 1
35 | for k in tqdm(range(nbatch)):
36 | start_idx = k*index_batch_size
37 | end_idx = min((k+1)*index_batch_size, len(query_vectors))
38 | q = query_vectors[start_idx: end_idx]
39 | scores, indexes = self.index.search(q, top_docs)
40 | # convert to external ids
41 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
42 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
43 | return result
44 |
45 | def serialize(self, dir_path):
46 | index_file = os.path.join(dir_path, 'index.faiss')
47 | meta_file = os.path.join(dir_path, 'index_meta.faiss')
48 | print(f'Serializing index to {index_file}, meta data to {meta_file}')
49 |
50 | faiss.write_index(self.index, index_file)
51 | with open(meta_file, mode='wb') as f:
52 | pickle.dump(self.index_id_to_db_id, f)
53 |
54 | def deserialize_from(self, dir_path):
55 | index_file = os.path.join(dir_path, 'index.faiss')
56 | meta_file = os.path.join(dir_path, 'index_meta.faiss')
57 | print(f'Loading index from {index_file}, meta data from {meta_file}')
58 |
59 | self.index = faiss.read_index(index_file)
60 | print('Loaded index of type %s and size %d' % (type(self.index), self.index.ntotal))
61 |
62 | with open(meta_file, "rb") as reader:
63 | self.index_id_to_db_id = pickle.load(reader)
64 | assert len(
65 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
66 |
67 | def _update_id_mapping(self, db_ids: List):
68 | #new_ids = np.array(db_ids, dtype=np.int64)
69 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
70 | self.index_id_to_db_id.extend(db_ids)
--------------------------------------------------------------------------------
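
A usage sketch for `Indexer` above, assuming `faiss` and `numpy` are installed and the repository root is on `PYTHONPATH`. With `n_subquantizers=0` the index is an exact inner-product (flat) index, so querying with an already-indexed vector returns its own id first.

```python
import numpy as np
from src.index import Indexer

dim = 768
embeddings = np.random.rand(1000, dim).astype("float32")
ids = [str(i) for i in range(len(embeddings))]

indexer = Indexer(vector_sz=dim)          # flat (exact) inner-product index
indexer.index_data(ids, embeddings)

results = indexer.search_knn(embeddings[:2], top_docs=5)
for db_ids, scores in results:
    print(db_ids, scores)                 # each query's own id should rank first
```
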
/TART/src/moco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import torch
4 | import torch.nn as nn
5 | import logging
6 | import copy
7 | import transformers
8 |
9 | from src import contriever, dist_utils, utils
10 |
11 | logger = logging.getLogger(__name__)
12 |
13 |
14 | class MoCo(nn.Module):
15 | def __init__(self, opt):
16 | super(MoCo, self).__init__()
17 |
18 | self.queue_size = opt.queue_size
19 | self.momentum = opt.momentum
20 | self.temperature = opt.temperature
21 | self.label_smoothing = opt.label_smoothing
22 | self.norm_doc = opt.norm_doc
23 | self.norm_query = opt.norm_query
24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode
25 |
26 | retriever, tokenizer = self._load_retriever(
27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
28 | )
29 |
30 | self.tokenizer = tokenizer
31 | self.encoder_q = retriever
32 | self.encoder_k = copy.deepcopy(retriever)
33 |
34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
35 | param_k.data.copy_(param_q.data)
36 | param_k.requires_grad = False
37 |
38 | # create the queue
39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
40 | self.queue = nn.functional.normalize(self.queue, dim=0)
41 |
42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 |
44 | def _load_retriever(self, model_id, pooling, random_init):
45 | cfg = utils.load_hf(transformers.AutoConfig, model_id)
46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
47 |
48 | if "xlm" in model_id:
49 | model_class = contriever.XLMRetriever
50 |         elif "albert" in model_id:
51 | model_class = contriever.ALBERTRetriever
52 | elif "t5" in model_id or "T0" in model_id:
53 | model_class = contriever.T5Contriever
54 | else:
55 | model_class = contriever.Contriever
56 |
57 | if random_init:
58 | retriever = model_class(cfg)
59 | else:
60 | retriever = utils.load_hf(model_class, model_id)
61 |
62 | if "bert-" in model_id:
63 | if tokenizer.bos_token_id is None:
64 | tokenizer.bos_token = "[CLS]"
65 | if tokenizer.eos_token_id is None:
66 | tokenizer.eos_token = "[SEP]"
67 |
68 | retriever.config.pooling = pooling
69 |
70 | return retriever, tokenizer
71 |
72 | def get_encoder(self, return_encoder_k=False):
73 | if return_encoder_k:
74 | return self.encoder_k
75 | else:
76 | return self.encoder_q
77 |
78 | def _momentum_update_key_encoder(self):
79 | """
80 | Update of the key encoder
81 | """
82 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
83 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
84 |
85 | @torch.no_grad()
86 | def _dequeue_and_enqueue(self, keys):
87 | # gather keys before updating queue
88 | keys = dist_utils.gather_nograd(keys.contiguous())
89 |
90 | batch_size = keys.shape[0]
91 |
92 | ptr = int(self.queue_ptr)
93 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity
94 |
95 | # replace the keys at ptr (dequeue and enqueue)
96 | self.queue[:, ptr : ptr + batch_size] = keys.T
97 | ptr = (ptr + batch_size) % self.queue_size # move pointer
98 |
99 | self.queue_ptr[0] = ptr
100 |
101 | def _compute_logits(self, q, k):
102 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
103 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
104 |
105 | logits = torch.cat([l_pos, l_neg], dim=1)
106 | return logits
107 |
108 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs):
109 | bsz = q_tokens.size(0)
110 |
111 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
112 | # compute key features
113 | with torch.no_grad(): # no gradient to keys
114 | self._momentum_update_key_encoder() # update the key encoder
115 |
116 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
117 | self.encoder_k.eval()
118 |
119 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
120 | logits = self._compute_logits(q, k) / self.temperature
121 |
122 | # labels: positive key indicators
123 | labels = torch.zeros(bsz, dtype=torch.long).cuda()
124 |
125 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
126 |
127 | self._dequeue_and_enqueue(k)
128 |
129 | # log stats
130 | if len(stats_prefix) > 0:
131 | stats_prefix = stats_prefix + "/"
132 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
133 |
134 | predicted_idx = torch.argmax(logits, dim=-1)
135 | accuracy = 100 * (predicted_idx == labels).float().mean()
136 | stdq = torch.std(q, dim=0).mean().item()
137 | stdk = torch.std(k, dim=0).mean().item()
138 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
139 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
140 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
141 |
142 | return loss, iter_stats
143 |
--------------------------------------------------------------------------------
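
Two pieces of the MoCo machinery above, shown in isolation as a sketch: the exponential-moving-average update that keeps the key encoder close to the query encoder, and the circular queue pointer arithmetic from `_dequeue_and_enqueue`.

```python
import torch

# Momentum (EMA) update of a key-encoder parameter toward the query encoder.
momentum = 0.999
param_q, param_k = torch.tensor([1.0]), torch.tensor([0.0])
param_k = param_k * momentum + param_q * (1.0 - momentum)
print(param_k)                                # tensor([0.0010]): slow drift toward param_q

# Circular queue update: newest keys overwrite the oldest slots, pointer wraps around.
queue_size, batch_size, ptr = 8, 2, 6
queue = torch.zeros(4, queue_size)            # (projection_size, queue_size)
keys = torch.ones(batch_size, 4)              # gathered keys from the current batch
queue[:, ptr:ptr + batch_size] = keys.T
ptr = (ptr + batch_size) % queue_size
print(ptr)                                    # 0: the pointer wrapped back to the start
```
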
/TART/src/modeling_enc_t5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import copy
4 |
5 | import torch
6 | from torch import nn
7 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
8 | from transformers.modeling_outputs import SequenceClassifierOutput
9 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack
10 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
11 |
12 |
13 | class EncT5ForSequenceClassification(T5PreTrainedModel):
14 | _keys_to_ignore_on_load_missing = [
15 | r"encoder\.embed_tokens\.weight",
16 | ]
17 |
18 | def __init__(self, config: T5Config, dropout=0.1):
19 | super().__init__(config)
20 | self.num_labels = config.num_labels
21 | self.config = config
22 |
23 | self.shared = nn.Embedding(config.vocab_size, config.d_model)
24 |
25 | encoder_config = copy.deepcopy(config)
26 | encoder_config.use_cache = False
27 | encoder_config.is_encoder_decoder = False
28 | self.encoder = T5Stack(encoder_config, self.shared)
29 |
30 | self.dropout = nn.Dropout(dropout)
31 | self.classifier = nn.Linear(config.hidden_size, config.num_labels)
32 |
33 | # Initialize weights and apply final processing
34 | self.post_init()
35 |
36 | # Model parallel
37 | self.model_parallel = False
38 | self.device_map = None
39 |
40 | def parallelize(self, device_map=None):
41 | self.device_map = (
42 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
43 | if device_map is None
44 | else device_map
45 | )
46 | assert_device_map(self.device_map, len(self.encoder.block))
47 | self.encoder.parallelize(self.device_map)
48 | self.classifier = self.classifier.to(self.encoder.first_device)
49 | self.model_parallel = True
50 |
51 | def deparallelize(self):
52 | self.encoder.deparallelize()
53 | self.encoder = self.encoder.to("cpu")
54 | self.model_parallel = False
55 | self.device_map = None
56 | torch.cuda.empty_cache()
57 |
58 | def get_input_embeddings(self):
59 | return self.shared
60 |
61 | def set_input_embeddings(self, new_embeddings):
62 | self.shared = new_embeddings
63 | self.encoder.set_input_embeddings(new_embeddings)
64 |
65 | def get_encoder(self):
66 | return self.encoder
67 |
68 | def _prune_heads(self, heads_to_prune):
69 | """
70 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
71 | class PreTrainedModel
72 | """
73 | for layer, heads in heads_to_prune.items():
74 | self.encoder.layer[layer].attention.prune_heads(heads)
75 |
76 | def forward(
77 | self,
78 | input_ids=None,
79 | attention_mask=None,
80 | head_mask=None,
81 | inputs_embeds=None,
82 | labels=None,
83 | output_attentions=None,
84 | output_hidden_states=None,
85 | return_dict=None,
86 | ):
87 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
88 |
89 | outputs = self.encoder(
90 | input_ids=input_ids,
91 | attention_mask=attention_mask,
92 | inputs_embeds=inputs_embeds,
93 | head_mask=head_mask,
94 | output_attentions=output_attentions,
95 | output_hidden_states=output_hidden_states,
96 | return_dict=return_dict,
97 | )
98 |
99 | hidden_states = outputs[0]
100 |         pooled_output = hidden_states[:, 0, :]  # Take bos token (equiv. to <s>)
101 |
102 | pooled_output = self.dropout(pooled_output)
103 | logits = self.classifier(pooled_output)
104 |
105 | loss = None
106 | if labels is not None:
107 | if self.config.problem_type is None:
108 | if self.num_labels == 1:
109 | self.config.problem_type = "regression"
110 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
111 | self.config.problem_type = "single_label_classification"
112 | else:
113 | self.config.problem_type = "multi_label_classification"
114 |
115 | if self.config.problem_type == "regression":
116 | loss_fct = MSELoss()
117 | if self.num_labels == 1:
118 | loss = loss_fct(logits.squeeze(), labels.squeeze())
119 | else:
120 | loss = loss_fct(logits, labels)
121 | elif self.config.problem_type == "single_label_classification":
122 | loss_fct = CrossEntropyLoss()
123 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
124 | elif self.config.problem_type == "multi_label_classification":
125 | loss_fct = BCEWithLogitsLoss()
126 | loss = loss_fct(logits, labels)
127 | if not return_dict:
128 | output = (logits,) + outputs[1:]
129 | return ((loss,) + output) if loss is not None else output
130 |
131 | return SequenceClassifierOutput(
132 | loss=loss,
133 | logits=logits,
134 | hidden_states=outputs.hidden_states,
135 | attentions=outputs.attentions,
136 | )
137 |
--------------------------------------------------------------------------------
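
A sketch of scoring one (instruction + query, passage) pair with the encoder-only classifier above, mirroring how `rerank.py` below uses it. The checkpoint name is an assumption (any TART-full style checkpoint with two labels would do), and loading it downloads a large model.

```python
import torch
import torch.nn.functional as F
from src.modeling_enc_t5 import EncT5ForSequenceClassification
from src.tokenization_enc_t5 import EncT5Tokenizer

model_name = "facebook/tart-full-flan-t5-xl"    # assumed checkpoint name
model = EncT5ForSequenceClassification.from_pretrained(model_name).eval()
tokenizer = EncT5Tokenizer.from_pretrained(model_name)

query = "Retrieve a passage that answers this question. [SEP] what is dense retrieval?"
passage = "Dense retrieval encodes queries and passages into vectors ..."
features = tokenizer([query], [passage], padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    logits = model(**features).logits
print(float(F.softmax(logits, dim=1)[:, 1]))     # probability of the "relevant" class
```
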
/TART/src/normalize_text.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | #: Control characters.
4 | CONTROLS = {
5 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
6 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
7 | }
8 | # There are further control characters, but they are instead replaced with a space by unicode normalization
9 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
10 |
11 |
12 | #: Hyphen and dash characters.
13 | HYPHENS = {
14 | '-', # \u002d Hyphen-minus
15 | '‐', # \u2010 Hyphen
16 | '‑', # \u2011 Non-breaking hyphen
17 | '⁃', # \u2043 Hyphen bullet
18 | '‒', # \u2012 figure dash
19 | '–', # \u2013 en dash
20 | '—', # \u2014 em dash
21 | '―', # \u2015 horizontal bar
22 | }
23 |
24 | #: Minus characters.
25 | MINUSES = {
26 | '-', # \u002d Hyphen-minus
27 | '−', # \u2212 Minus
28 | '－', # \uff0d Full-width Hyphen-minus
29 | '⁻', # \u207b Superscript minus
30 | }
31 |
32 | #: Plus characters.
33 | PLUSES = {
34 | '+', # \u002b Plus
35 | '＋', # \uff0b Full-width Plus
36 | '⁺', # \u207a Superscript plus
37 | }
38 |
39 | #: Slash characters.
40 | SLASHES = {
41 | '/', # \u002f Solidus
42 | '⁄', # \u2044 Fraction slash
43 | '∕', # \u2215 Division slash
44 | }
45 |
46 | #: Tilde characters.
47 | TILDES = {
48 | '~', # \u007e Tilde
49 | '˜', # \u02dc Small tilde
50 | '⁓', # \u2053 Swung dash
51 | '∼', # \u223c Tilde operator #in mbert vocab
52 | '∽', # \u223d Reversed tilde
53 | '∿', # \u223f Sine wave
54 | '〜', # \u301c Wave dash #in mbert vocab
55 | '～', # \uff5e Full-width tilde #in mbert vocab
56 | }
57 |
58 | #: Apostrophe characters.
59 | APOSTROPHES = {
60 | "'", # \u0027
61 | '’', # \u2019
62 | '՚', # \u055a
63 | 'Ꞌ', # \ua78b
64 | 'ꞌ', # \ua78c
65 | '＇', # \uff07
66 | }
67 |
68 | #: Single quote characters.
69 | SINGLE_QUOTES = {
70 | "'", # \u0027
71 | '‘', # \u2018
72 | '’', # \u2019
73 | '‚', # \u201a
74 | '‛', # \u201b
75 |
76 | }
77 |
78 | #: Double quote characters.
79 | DOUBLE_QUOTES = {
80 | '"', # \u0022
81 | '“', # \u201c
82 | '”', # \u201d
83 | '„', # \u201e
84 | '‟', # \u201f
85 | }
86 |
87 | #: Accent characters.
88 | ACCENTS = {
89 | '`', # \u0060
90 | '´', # \u00b4
91 | }
92 |
93 | #: Prime characters.
94 | PRIMES = {
95 | '′', # \u2032
96 | '″', # \u2033
97 | '‴', # \u2034
98 | '‵', # \u2035
99 | '‶', # \u2036
100 | '‷', # \u2037
101 | '⁗', # \u2057
102 | }
103 |
104 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
105 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
106 |
107 | def normalize(text):
108 | for control in CONTROLS:
109 | text = text.replace(control, '')
110 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
111 |
112 | for hyphen in HYPHENS | MINUSES:
113 | text = text.replace(hyphen, '-')
114 | text = text.replace('\u00ad', '')
115 |
116 | for double_quote in DOUBLE_QUOTES:
117 | text = text.replace(double_quote, '"') # \u0022
118 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
119 | text = text.replace(single_quote, "'") # \u0027
120 | text = text.replace('′', "'") # \u2032 prime
121 | text = text.replace('‵', "'") # \u2035 reversed prime
122 | text = text.replace('″', "''") # \u2033 double prime
123 | text = text.replace('‶', "''") # \u2036 reversed double prime
124 | text = text.replace('‴', "'''") # \u2034 triple prime
125 | text = text.replace('‷', "'''") # \u2037 reversed triple prime
126 | text = text.replace('⁗', "''''") # \u2057 quadruple prime
127 |
128 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026
129 |
130 | for slash in SLASHES:
131 | text = text.replace(slash, '/')
132 |
133 | #for tilde in TILDES:
134 | # text = text.replace(tilde, '~')
135 |
136 | return text
--------------------------------------------------------------------------------
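
A quick sketch of what `normalize` above does to quotes, dashes and ellipses, assuming the repository root is on `PYTHONPATH`:

```python
from src import normalize_text

raw = "“TART’s retriever” – fast… and robust"
print(normalize_text.normalize(raw))
# "TART's retriever" - fast... and robust
```
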
/TART/src/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import argparse
4 | import os
5 |
6 |
7 | class Options:
8 | def __init__(self):
9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 | self.initialize()
11 |
12 | def initialize(self):
13 | # basic parameters
14 | self.parser.add_argument(
15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
16 | )
17 | self.parser.add_argument(
18 | "--train_data",
19 | nargs="+",
20 | default=[],
21 |             help="Data used for training, passed as a list of directories split into tensor files.",
22 | )
23 | self.parser.add_argument(
24 | "--eval_data",
25 | nargs="+",
26 | default=[],
27 | help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.",
28 | )
29 | self.parser.add_argument(
30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
31 | )
32 | self.parser.add_argument(
33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
34 | )
35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
36 | self.parser.add_argument("--continue_training", action="store_true")
37 | self.parser.add_argument("--num_workers", type=int, default=5)
38 |
39 | self.parser.add_argument("--chunk_length", type=int, default=256)
40 | self.parser.add_argument("--loading_mode", type=str, default="split")
41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
42 | self.parser.add_argument(
43 | "--sampling_coefficient",
44 | type=float,
45 | default=0.0,
46 | help="coefficient used for sampling between different datasets during training, \
47 | by default sampling is uniform over datasets",
48 | )
49 | self.parser.add_argument("--augmentation", type=str, default="none")
50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0)
51 |
52 | self.parser.add_argument("--dropout", type=float, default=0.1)
53 | self.parser.add_argument("--rho", type=float, default=0.05)
54 |
55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco")
56 | self.parser.add_argument("--queue_size", type=int, default=65536)
57 | self.parser.add_argument("--temperature", type=float, default=1.0)
58 | self.parser.add_argument("--momentum", type=float, default=0.999)
59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
60 | self.parser.add_argument("--eval_normalize_text", action="store_true")
61 | self.parser.add_argument("--norm_query", action="store_true")
62 | self.parser.add_argument("--norm_doc", action="store_true")
63 | self.parser.add_argument("--projection_size", type=int, default=768)
64 |
65 | # bi encoder
66 | self.parser.add_argument("--bi_encoder", action="store_true", help="instead of sharing a single encoder, we use separate encoders.")
67 | self.parser.add_argument("--freeze_ctx_encoder", action="store_true", help="if we use bi encoder but only want to update the query encoder.")
68 |
69 | self.parser.add_argument("--ratio_min", type=float, default=0.1)
70 | self.parser.add_argument("--ratio_max", type=float, default=0.5)
71 | self.parser.add_argument("--score_function", type=str, default="dot")
72 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
73 | self.parser.add_argument("--pooling", type=str, default="average")
74 | self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")
75 |
76 | # dataset parameters
77 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
78 | self.parser.add_argument(
79 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
80 | )
81 | self.parser.add_argument("--total_steps", type=int, default=1000)
82 | self.parser.add_argument("--warmup_steps", type=int, default=-1)
83 |
84 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
85 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
86 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
87 | self.parser.add_argument("--hard_order", action="store_true", help="use the most related hard negatives.")
88 |
89 | # training parameters
90 | self.parser.add_argument("--optim", type=str, default="adamw")
91 | self.parser.add_argument("--scheduler", type=str, default="linear")
92 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
93 | self.parser.add_argument(
94 | "--lr_min_ratio",
95 | type=float,
96 | default=0.0,
97 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
98 | )
99 |         self.parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay")
100 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
101 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
102 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
103 | self.parser.add_argument(
104 |             "--log_freq", type=int, default=100, help="log train stats every <log_freq> steps during training"
105 | )
106 | self.parser.add_argument(
107 |             "--eval_freq", type=int, default=500, help="evaluate model every <eval_freq> steps during training"
108 | )
109 | self.parser.add_argument("--save_freq", type=int, default=50000)
110 | self.parser.add_argument("--maxload", type=int, default=None)
111 | self.parser.add_argument("--label_smoothing", type=float, default=0.0)
112 |
113 | # finetuning options
114 | self.parser.add_argument("--negative_ctxs", type=int, default=1)
115 | self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
116 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)
117 | self.parser.add_argument("--kd", action="store_true")
118 | self.parser.add_argument("--loss_type", type=str, default="kl")
119 |         self.parser.add_argument("--T", type=float, default=0.1, help="temperature used for distillation")
120 | self.parser.add_argument("--n_context", type=int, default=50)
121 | self.parser.add_argument("--random_sort", action="store_true", help="randomly sampling top N for distillation")
122 |
123 | def print_options(self, opt):
124 | message = ""
125 | for k, v in sorted(vars(opt).items()):
126 | comment = ""
127 | default = self.parser.get_default(k)
128 | if v != default:
129 | comment = f"\t[default: %s]" % str(default)
130 | message += f"{str(k):>40}: {str(v):<40}{comment}\n"
131 | print(message, flush=True)
132 | model_dir = os.path.join(opt.output_dir, "models")
133 | if not os.path.exists(model_dir):
134 | os.makedirs(os.path.join(opt.output_dir, "models"))
135 | file_name = os.path.join(opt.output_dir, "opt.txt")
136 | with open(file_name, "wt") as opt_file:
137 | opt_file.write(message)
138 | opt_file.write("\n")
139 |
140 | def parse(self):
141 | opt, _ = self.parser.parse_known_args()
142 | # opt = self.parser.parse_args()
143 | return opt
144 |
--------------------------------------------------------------------------------
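
A sketch of parsing options programmatically with the `Options` class above (the argument values here are placeholders); `parse()` uses `parse_known_args`, so unknown flags are ignored rather than raising an error.

```python
import sys
from src.options import Options

sys.argv = [
    "train.py",
    "--retriever_model_id", "bert-base-uncased",
    "--train_data", "training_data.jsonl",      # hypothetical path
    "--per_gpu_batch_size", "32",
    "--negative_ctxs", "2",
]
opt = Options().parse()
print(opt.retriever_model_id, opt.per_gpu_batch_size, opt.negative_ctxs)
# bert-base-uncased 32 2
```
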
/TART/src/rerank.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import logging
4 | from typing import Dict, List
5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification
6 | import torch
7 | import torch.nn.functional as F
8 | from tqdm import tqdm
9 | from src.modeling_enc_t5 import EncT5ForSequenceClassification
10 | from src.tokenization_enc_t5 import EncT5Tokenizer
11 |
12 | logger = logging.getLogger(__name__)
13 |
14 | #Parent class for any reranking model
15 | class Rerank:
16 | def __init__(self, model_name_or_path, batch_size: int = 128, **kwargs):
17 | if "t0" in model_name_or_path or "t5" in model_name_or_path:
18 | self.model = EncT5ForSequenceClassification.from_pretrained(model_name_or_path)
19 | self.tokenizer = EncT5Tokenizer.from_pretrained(model_name_or_path)
20 | else:
21 | self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
23 | self.batch_size = batch_size
24 | self.rerank_results = {}
25 | self.model.to('cuda')
26 |
27 | self.model.eval()
28 |
29 | def rerank(self,
30 | corpus: Dict[str, Dict[str, str]],
31 | queries: Dict[str, str],
32 | results: Dict[str, Dict[str, float]],
33 | top_k: int,
34 | prompt: str = None) -> Dict[str, Dict[str, float]]:
35 |
36 | sentence_pairs, pair_ids = [], []
37 |
38 | self.rerank_results = {query_id: {} for query_id in results}
39 | for query_id in tqdm(results):
40 | docs = []
41 | query= []
42 | doc_ids = []
43 |
44 | if len(results[query_id]) > top_k:
45 | for (doc_id, _) in sorted(results[query_id].items(), key=lambda item: item[1], reverse=True)[:top_k]:
46 | pair_ids.append([query_id, doc_id])
47 | corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip()
48 | # print(corpus_text)
49 | # print(corpus_text)
50 | docs.append(corpus_text)
51 | doc_ids.append(doc_id)
52 | if prompt is None:
53 | # sentence_pairs.append([queries[query_id], corpus_text])
54 | query.append(queries[query_id])
55 | else:
56 | # sentence_pairs.append(["{0} [SEP] {1}".format(prompt, queries[query_id]), corpus_text])
57 | query.append("{0} [SEP] {1}".format(prompt, queries[query_id]))
58 | # print(query[-1])
59 |
60 | else:
61 | for doc_id in results[query_id]:
62 | pair_ids.append([query_id, doc_id])
63 | corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip()
64 | # print(corpus_text)
65 | docs.append(corpus_text)
66 | doc_ids.append(doc_id)
67 | if prompt is None:
68 | query.append(queries[query_id])
69 | # sentence_pairs.append([queries[query_id], corpus_text])
70 | else:
71 | query.append("{0} [SEP] {1}".format(prompt, queries[query_id]))
72 | # sentence_pairs.append(["{0} [SEP] {1}".format(prompt, queries[query_id]), corpus_text])
73 |
74 | # run inference
75 | features = self.tokenizer(query, docs, padding=True, truncation=True, max_length=512, return_tensors="pt").to('cuda')
76 | with torch.no_grad():
77 | scores = self.model(**features).logits
78 | normalized_scores = F.softmax(scores, dim=1)
79 | final_scores = [float(score[1]) for score in normalized_scores]
80 | # print(final_scores)
81 | for doc_id, score in zip(doc_ids, final_scores):
82 | self.rerank_results[query_id][doc_id] = score
83 |
84 | # #### Starting to Rerank using cross-attention
85 | # logging.info("Starting To Rerank Top-{}....".format(top_k))
86 |
87 | # rerank_scores = [float(score[1]) for score in self.cross_encoder.predict(sentence_pairs, batch_size=self.batch_size)]
88 | # #### Reranking results
89 | # self.rerank_results = {query_id: {} for query_id in results}
90 | # for pair, score in zip(pair_ids, rerank_scores):
91 | # query_id, doc_id = pair[0], pair[1]
92 | # self.rerank_results[query_id][doc_id] = score
93 | # print(self.rerank_results)
94 | return self.rerank_results
--------------------------------------------------------------------------------
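
For orientation, here is a minimal, hypothetical usage sketch of the `Rerank` class above: a toy BEIR-style corpus, one query, first-stage retriever scores, and an instruction prompt. The checkpoint name and prompt text are placeholders, and a CUDA device is required because the constructor moves the model to `'cuda'`.

```python
from src.rerank import Rerank

corpus = {
    "d1": {"title": "Prion", "text": "Prions are misfolded proteins."},
    "d2": {"title": "Pizza", "text": "Pizza is a baked dish of flattened dough."},
}
queries = {"q1": "what is a prion"}
first_stage = {"q1": {"d1": 12.3, "d2": 3.1}}  # e.g. scores from a dense retriever

# placeholder checkpoint name; any sequence-classification cross-encoder works here
reranker = Rerank("facebook/tart-full-flan-t5-xl", batch_size=32)
scores = reranker.rerank(
    corpus, queries, first_stage, top_k=2,
    prompt="Retrieve a passage that answers this question",
)
print(scores["q1"])  # {"d1": ..., "d2": ...}: softmax probability of the positive class
```
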
/TART/src/slurm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 |
7 | from logging import getLogger
8 | import os
9 | import sys
10 | import torch
11 | import socket
12 | import signal
13 | import subprocess
14 |
15 |
16 | logger = getLogger()
17 |
18 | def sig_handler(signum, frame):
19 | logger.warning("Signal handler called with signal " + str(signum))
20 | prod_id = int(os.environ['SLURM_PROCID'])
21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
22 | if prod_id == 0:
23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
25 | else:
26 | logger.warning("Not the main process, no need to requeue.")
27 | sys.exit(-1)
28 |
29 |
30 | def term_handler(signum, frame):
31 | logger.warning("Signal handler called with signal " + str(signum))
32 | logger.warning("Bypassing SIGTERM.")
33 |
34 |
35 | def init_signal_handler():
36 | """
37 | Handle signals sent by SLURM for time limit / pre-emption.
38 | """
39 | signal.signal(signal.SIGUSR1, sig_handler)
40 | signal.signal(signal.SIGTERM, term_handler)
41 |
42 |
43 | def init_distributed_mode(params):
44 | """
45 | Handle single and multi-GPU / multi-node / SLURM jobs.
46 | Initialize the following variables:
47 | - local_rank
48 | - global_rank
49 | - world_size
50 | """
51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ
52 | print("is slurm job {}".format(is_slurm_job))
53 | if 'WORLD_SIZE' in os.environ:
54 | print("world size: {}".format(os.environ['WORLD_SIZE']))
55 | has_local_rank = hasattr(params, 'local_rank')
56 | print("has local rank? {}".format(has_local_rank))
57 |
58 | # SLURM job without torch.distributed.launch
59 | if is_slurm_job and has_local_rank:
60 |
61 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM
62 |
63 | # local rank on the current node / global rank
64 | params.local_rank = int(os.environ['SLURM_LOCALID'])
65 | params.global_rank = int(os.environ['SLURM_PROCID'])
66 | params.world_size = int(os.environ['SLURM_NTASKS'])
67 |
68 | # define master address and master port
69 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
70 | params.main_addr = hostnames.split()[0].decode('utf-8')
71 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1
72 |
73 | # set environment variables for 'env://'
74 | os.environ['MASTER_ADDR'] = params.main_addr
75 | os.environ['MASTER_PORT'] = str(params.main_port)
76 | os.environ['WORLD_SIZE'] = str(params.world_size)
77 | os.environ['RANK'] = str(params.global_rank)
78 | is_distributed = True
79 |
80 |
81 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
82 | elif has_local_rank and params.local_rank != -1:
83 |
84 | assert params.main_port == -1
85 |
86 | # read environment variables
87 | params.global_rank = int(os.environ['RANK'])
88 | params.world_size = int(os.environ['WORLD_SIZE'])
89 |
90 | is_distributed = True
91 |
92 | # local job (single GPU)
93 | else:
94 | params.local_rank = 0
95 | params.global_rank = 0
96 | params.world_size = 1
97 | is_distributed = False
98 |
99 | # set GPU device
100 | torch.cuda.set_device(params.local_rank)
101 |
102 | # initialize multi-GPU
103 | if is_distributed:
104 |
105 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
106 | # 'env://' will read these environment variables:
107 | # MASTER_PORT - required; has to be a free port on machine with rank 0
108 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node
109 | # WORLD_SIZE - required; can be set either here, or in a call to init function
110 | # RANK - required; can be set either here, or in a call to init function
111 |
112 | #print("Initializing PyTorch distributed ...")
113 | torch.distributed.init_process_group(
114 | init_method='env://',
115 | backend='nccl',
116 | #world_size=params.world_size,
117 | #rank=params.global_rank,
118 | )
119 |
--------------------------------------------------------------------------------
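
As a quick illustration of the fall-through logic in `init_distributed_mode`, the sketch below drives it in the single-process case (no `SLURM_JOB_ID` in the environment, no `torch.distributed.launch`). The `argparse.Namespace` stands in for the parsed training options, and a CUDA device is assumed because the function calls `torch.cuda.set_device`.

```python
import argparse
from src.slurm import init_distributed_mode

# stand-in for the parsed options; only the fields the function reads are set
params = argparse.Namespace(local_rank=-1, main_port=-1)
init_distributed_mode(params)                  # takes the "local job (single GPU)" branch
print(params.global_rank, params.world_size)   # 0 1
```
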
/TART/src/tokenization_enc_t5.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | from typing import Any, Dict, List, Optional
4 |
5 | from transformers import T5Tokenizer
6 |
7 |
8 | class EncT5Tokenizer(T5Tokenizer):
9 | def __init__(
10 | self,
11 | vocab_file,
12 |         bos_token="<s>",
13 |         eos_token="</s>",
14 |         unk_token="<unk>",
15 |         pad_token="<pad>",
16 | extra_ids=100,
17 | additional_special_tokens=None,
18 | sp_model_kwargs: Optional[Dict[str, Any]] = None,
19 | **kwargs,
20 | ) -> None:
21 | sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
22 |
23 | super().__init__(
24 | vocab_file=vocab_file,
25 | bos_token=bos_token,
26 | eos_token=eos_token,
27 | unk_token=unk_token,
28 | pad_token=pad_token,
29 | extra_ids=extra_ids,
30 | additional_special_tokens=additional_special_tokens,
31 | sp_model_kwargs=sp_model_kwargs,
32 | **kwargs,
33 | )
34 |
35 | def get_special_tokens_mask(
36 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
37 | ) -> List[int]:
38 | """
39 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
40 | special tokens using the tokenizer `prepare_for_model` method.
41 | Args:
42 | token_ids_0 (`List[int]`):
43 | List of IDs.
44 | token_ids_1 (`List[int]`, *optional*):
45 | Optional second list of IDs for sequence pairs.
46 | already_has_special_tokens (`bool`, *optional*, defaults to `False`):
47 | Whether or not the token list is already formatted with special tokens for the model.
48 | Returns:
49 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
50 | """
51 | if already_has_special_tokens:
52 | return super().get_special_tokens_mask(
53 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True
54 | )
55 |
56 | # normal case: some special tokens
57 | if token_ids_1 is None:
58 | return [1] + ([0] * len(token_ids_0)) + [1]
59 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
60 |
61 | def create_token_type_ids_from_sequences(
62 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
63 | ) -> List[int]:
64 | """
65 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
66 | use of token type ids, therefore a list of zeros is returned.
67 | Args:
68 | token_ids_0 (`List[int]`):
69 | List of IDs.
70 | token_ids_1 (`List[int]`, *optional*):
71 | Optional second list of IDs for sequence pairs.
72 | Returns:
73 | `List[int]`: List of zeros.
74 | """
75 | bos = [self.bos_token_id]
76 | eos = [self.eos_token_id]
77 |
78 | if token_ids_1 is None:
79 | return len(bos + token_ids_0 + eos) * [0]
80 | return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0]
81 |
82 | def build_inputs_with_special_tokens(
83 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
84 | ) -> List[int]:
85 | """
86 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
87 | adding special tokens. A sequence has the following format:
88 |         - single sequence: `<s> X </s>`
89 |         - pair of sequences: `<s> A </s> B </s>`
90 | Args:
91 | token_ids_0 (`List[int]`):
92 | List of IDs to which the special tokens will be added.
93 | token_ids_1 (`List[int]`, *optional*):
94 | Optional second list of IDs for sequence pairs.
95 | Returns:
96 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
97 | """
98 | if token_ids_1 is None:
99 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id]
100 | else:
101 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id]
102 |
--------------------------------------------------------------------------------
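
The three overridden methods are easiest to read side by side. A small sketch, assuming a T5 SentencePiece vocabulary is available (the public `t5-base` files are used here purely as a stand-in for the repository's own EncT5 checkpoints):

```python
from src.tokenization_enc_t5 import EncT5Tokenizer

tok = EncT5Tokenizer.from_pretrained("t5-base")  # stand-in checkpoint
ids_a = tok.encode("what is a prion", add_special_tokens=False)
ids_b = tok.encode("Prions are misfolded proteins.", add_special_tokens=False)

pair = tok.build_inputs_with_special_tokens(ids_a, ids_b)        # <s> A </s> B </s>
mask = tok.get_special_tokens_mask(ids_a, ids_b)                 # 1 marks the <s>/</s> slots
types = tok.create_token_type_ids_from_sequences(ids_a, ids_b)   # all zeros: T5 has no segment ids
```
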
/TART/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import sys
5 | import logging
6 | import torch
7 | import errno
8 | from typing import Union, Tuple, List, Dict
9 | from collections import defaultdict
10 | from transformers import T5EncoderModel
11 | import math
12 | from src import dist_utils
13 |
14 | Number = Union[float, int]
15 |
16 | logger = logging.getLogger(__name__)
17 |
18 |
19 | def init_logger(args, stdout_only=False):
20 | if torch.distributed.is_initialized():
21 | torch.distributed.barrier()
22 | stdout_handler = logging.StreamHandler(sys.stdout)
23 | handlers = [stdout_handler]
24 | if not stdout_only:
25 | file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log"))
26 | handlers.append(file_handler)
27 | logging.basicConfig(
28 | datefmt="%m/%d/%Y %H:%M:%S",
29 | level=logging.INFO if dist_utils.is_main() else logging.WARN,
30 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
31 | handlers=handlers,
32 | )
33 | return logger
34 |
35 |
36 | def symlink_force(target, link_name):
37 | try:
38 | os.symlink(target, link_name)
39 | except OSError as e:
40 | if e.errno == errno.EEXIST:
41 | os.remove(link_name)
42 | os.symlink(target, link_name)
43 | else:
44 | raise e
45 |
46 |
47 | def save(model, optimizer, scheduler, step, opt, dir_path, name):
48 | model_to_save = model.module if hasattr(model, "module") else model
49 | path = os.path.join(dir_path, "checkpoint")
50 | epoch_path = os.path.join(path, name) # "step-%s" % step)
51 | os.makedirs(epoch_path, exist_ok=True)
52 | cp = os.path.join(path, "latest")
53 | fp = os.path.join(epoch_path, "checkpoint.pth")
54 | checkpoint = {
55 | "step": step,
56 | "model": model_to_save.state_dict(),
57 | "optimizer": optimizer.state_dict(),
58 | "scheduler": scheduler.state_dict(),
59 | "opt": opt,
60 | }
61 | torch.save(checkpoint, fp)
62 | symlink_force(epoch_path, cp)
63 | if not name == "lastlog":
64 | logger.info(f"Saving model to {epoch_path}")
65 |
66 |
67 | def load(model_class, dir_path, opt, reset_params=False):
68 | epoch_path = os.path.realpath(dir_path)
69 | checkpoint_path = os.path.join(epoch_path, "checkpoint.pth")
70 | logger.info(f"loading checkpoint {checkpoint_path}")
71 | checkpoint = torch.load(checkpoint_path, map_location="cpu")
72 | opt_checkpoint = checkpoint["opt"]
73 | state_dict = checkpoint["model"]
74 |
75 | model = model_class(opt_checkpoint)
76 | model.load_state_dict(state_dict, strict=True)
77 | model = model.cuda()
78 | step = checkpoint["step"]
79 | if not reset_params:
80 | optimizer, scheduler = set_optim(opt_checkpoint, model)
81 | scheduler.load_state_dict(checkpoint["scheduler"])
82 | optimizer.load_state_dict(checkpoint["optimizer"])
83 | else:
84 | optimizer, scheduler = set_optim(opt, model)
85 |
86 | return model, optimizer, scheduler, opt_checkpoint, step
87 |
88 |
89 | ############ OPTIM
90 |
91 |
92 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
93 | def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
94 | self.warmup = warmup
95 | self.total = total
96 | self.ratio = ratio
97 | super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
98 |
99 | def lr_lambda(self, step):
100 | if step < self.warmup:
101 | return (1 - self.ratio) * step / float(max(1, self.warmup))
102 |
103 | return max(
104 | 0.0,
105 | 1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)),
106 | )
107 |
108 |
109 | class CosineScheduler(torch.optim.lr_scheduler.LambdaLR):
110 | def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1):
111 | self.warmup = warmup
112 | self.total = total
113 | self.ratio = ratio
114 | super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
115 |
116 | def lr_lambda(self, step):
117 | if step < self.warmup:
118 | return float(step) / self.warmup
119 | s = float(step - self.warmup) / (self.total - self.warmup)
120 | return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s)
121 |
122 |
123 | def set_optim(opt, model):
124 | if opt.optim == "adamw":
125 | optimizer = torch.optim.AdamW(
126 | model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay
127 | )
128 | else:
129 | raise NotImplementedError("optimizer class not implemented")
130 |
131 | scheduler_args = {
132 | "warmup": opt.warmup_steps,
133 | "total": opt.total_steps,
134 | "ratio": opt.lr_min_ratio,
135 | }
136 | if opt.scheduler == "linear":
137 | scheduler_class = WarmupLinearScheduler
138 | elif opt.scheduler == "cosine":
139 | scheduler_class = CosineScheduler
140 | else:
141 | raise ValueError
142 | scheduler = scheduler_class(optimizer, **scheduler_args)
143 | return optimizer, scheduler
144 |
145 |
146 | def get_parameters(net, verbose=False):
147 | num_params = 0
148 | for param in net.parameters():
149 | num_params += param.numel()
150 | message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6)
151 | return message
152 |
153 |
154 | class WeightedAvgStats:
155 | """provides an average over a bunch of stats"""
156 |
157 | def __init__(self):
158 | self.raw_stats: Dict[str, float] = defaultdict(float)
159 | self.total_weights: Dict[str, float] = defaultdict(float)
160 |
161 | def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None:
162 | for key, (value, weight) in vals.items():
163 | self.raw_stats[key] += value * weight
164 | self.total_weights[key] += weight
165 |
166 | @property
167 | def stats(self) -> Dict[str, float]:
168 | return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}
169 |
170 | @property
171 | def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
172 | return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}
173 |
174 | def reset(self) -> None:
175 | self.raw_stats = defaultdict(float)
176 | self.total_weights = defaultdict(float)
177 |
178 | @property
179 | def average_stats(self) -> Dict[str, float]:
180 | keys = sorted(self.raw_stats.keys())
181 | if torch.distributed.is_initialized():
182 | torch.distributed.broadcast_object_list(keys, src=0)
183 | global_dict = {}
184 | for k in keys:
185 | if not k in self.total_weights:
186 | v = 0.0
187 | else:
188 | v = self.raw_stats[k] / self.total_weights[k]
189 | v, _ = dist_utils.weighted_average(v, self.total_weights[k])
190 | global_dict[k] = v
191 | return global_dict
192 |
193 |
194 | def load_hf(object_class, model_name):
195 |     if "gtr-t5" in model_name:
196 |         return T5EncoderModel.from_pretrained(model_name)
197 | try:
198 | obj = object_class.from_pretrained(model_name, local_files_only=True)
199 | except:
200 | obj = object_class.from_pretrained(model_name, local_files_only=False)
201 | return obj
202 |
203 |
204 | def init_tb_logger(output_dir):
205 | try:
206 | from torch.utils import tensorboard
207 |
208 | if dist_utils.is_main():
209 | tb_logger = tensorboard.SummaryWriter(output_dir)
210 | else:
211 | tb_logger = None
212 | except:
213 | logger.warning("Tensorboard is not available.")
214 | tb_logger = None
215 |
216 | return tb_logger
217 |
--------------------------------------------------------------------------------
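
`WeightedAvgStats` is the helper most callers interact with: it keeps a weighted running average per key, and in `train.py` the per-step stats returned by the model are aggregated this way. A tiny sketch with made-up numbers:

```python
from src.utils import WeightedAvgStats

stats = WeightedAvgStats()
stats.update({"train/loss": (0.8, 32)})  # (value, weight), e.g. batch loss and batch size
stats.update({"train/loss": (0.6, 32)})
print(stats.stats)   # {'train/loss': 0.7} (weighted average)
stats.reset()
```
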
/TART/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import os
4 | import time
5 | import sys
6 | import torch
7 | import logging
8 | import json
9 | import numpy as np
10 | import random
11 | import pickle
12 |
13 | import torch.distributed as dist
14 | from torch.utils.data import DataLoader, RandomSampler
15 |
16 | from src.options import Options
17 | from src import data, beir_utils, slurm, dist_utils, utils
18 | from src import moco, inbatch
19 |
20 |
21 | logger = logging.getLogger(__name__)
22 |
23 |
24 | def train(opt, model, optimizer, scheduler, step):
25 |
26 | run_stats = utils.WeightedAvgStats()
27 |
28 | tb_logger = utils.init_tb_logger(opt.output_dir)
29 |
30 | logger.info("Data loading")
31 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
32 | tokenizer = model.module.tokenizer
33 | else:
34 | tokenizer = model.tokenizer
35 | collator = data.Collator(opt=opt)
36 | train_dataset = data.load_data(opt, tokenizer)
37 | logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")
38 |
39 | train_sampler = RandomSampler(train_dataset)
40 | train_dataloader = DataLoader(
41 | train_dataset,
42 | sampler=train_sampler,
43 | batch_size=opt.per_gpu_batch_size,
44 | drop_last=True,
45 | num_workers=opt.num_workers,
46 | collate_fn=collator,
47 | )
48 |
49 | epoch = 1
50 |
51 | model.train()
52 | while step < opt.total_steps:
53 | train_dataset.generate_offset()
54 |
55 | logger.info(f"Start epoch {epoch}")
56 | for i, batch in enumerate(train_dataloader):
57 | step += 1
58 |
59 | batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
60 | train_loss, iter_stats = model(**batch, stats_prefix="train")
61 |
62 | train_loss.backward()
63 | optimizer.step()
64 |
65 | scheduler.step()
66 | model.zero_grad()
67 |
68 | run_stats.update(iter_stats)
69 |
70 | if step % opt.log_freq == 0:
71 | log = f"{step} / {opt.total_steps}"
72 | for k, v in sorted(run_stats.average_stats.items()):
73 | log += f" | {k}: {v:.3f}"
74 | if tb_logger:
75 | tb_logger.add_scalar(k, v, step)
76 | log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
77 |                 log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GB"
78 |
79 | logger.info(log)
80 | run_stats.reset()
81 |
82 | if step % opt.eval_freq == 0:
83 | if isinstance(model, torch.nn.parallel.DistributedDataParallel):
84 | encoder = model.module.get_encoder()
85 | else:
86 | encoder = model.get_encoder()
87 | eval_model(
88 | opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
89 | )
90 |
91 | if dist_utils.is_main() and step % opt.save_freq == 0:
92 | utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")
93 |
94 | model.train()
95 | if step > opt.total_steps:
96 | break
97 | epoch += 1
98 |
99 |
100 | def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
101 | for datasetname in opt.eval_datasets:
102 | metrics = beir_utils.evaluate_model(
103 | query_encoder,
104 | doc_encoder,
105 | tokenizer,
106 | dataset=datasetname,
107 | batch_size=opt.per_gpu_eval_batch_size,
108 | norm_doc=opt.norm_doc,
109 | norm_query=opt.norm_query,
110 | beir_dir=opt.eval_datasets_dir,
111 | score_function=opt.score_function,
112 | lower_case=opt.lower_case,
113 | normalize_text=opt.eval_normalize_text,
114 | )
115 |
116 | message = []
117 | if dist_utils.is_main():
118 | for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
119 | message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
120 | if tb_logger is not None:
121 | tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
122 | logger.info(" | ".join(message))
123 |
124 |
125 | if __name__ == "__main__":
126 | logger.info("Start")
127 |
128 | options = Options()
129 | opt = options.parse()
130 |
131 | torch.manual_seed(opt.seed)
132 | slurm.init_distributed_mode(opt)
133 | slurm.init_signal_handler()
134 |
135 | directory_exists = os.path.isdir(opt.output_dir)
136 | if dist.is_initialized():
137 | dist.barrier()
138 | os.makedirs(opt.output_dir, exist_ok=True)
139 | if not directory_exists and dist_utils.is_main():
140 | options.print_options(opt)
141 | if dist.is_initialized():
142 | dist.barrier()
143 | utils.init_logger(opt)
144 |
145 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
146 |
147 | if opt.contrastive_mode == "moco":
148 | model_class = moco.MoCo
149 | elif opt.contrastive_mode == "inbatch":
150 | model_class = inbatch.InBatch
151 | else:
152 | raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")
153 |
154 | if not directory_exists and opt.model_path == "none":
155 | model = model_class(opt)
156 | model = model.cuda()
157 | optimizer, scheduler = utils.set_optim(opt, model)
158 | step = 0
159 | elif directory_exists:
160 | model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
161 | model, optimizer, scheduler, opt_checkpoint, step = utils.load(
162 | model_class,
163 | model_path,
164 | opt,
165 | reset_params=False,
166 | )
167 | logger.info(f"Model loaded from {opt.output_dir}")
168 | else:
169 | model, optimizer, scheduler, opt_checkpoint, step = utils.load(
170 | model_class,
171 | opt.model_path,
172 | opt,
173 | reset_params=False if opt.continue_training else True,
174 | )
175 | if not opt.continue_training:
176 | step = 0
177 | logger.info(f"Model loaded from {opt.model_path}")
178 |
179 | logger.info(utils.get_parameters(model))
180 |
181 | if dist.is_initialized():
182 | model = torch.nn.parallel.DistributedDataParallel(
183 | model,
184 | device_ids=[opt.local_rank],
185 | output_device=opt.local_rank,
186 | find_unused_parameters=False,
187 | )
188 | dist.barrier()
189 |
190 | logger.info("Start training")
191 | train(opt, model, optimizer, scheduler, step)
192 |
--------------------------------------------------------------------------------
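
The loop above only assumes a small contract from the model classes (`moco.MoCo` / `inbatch.InBatch`, not shown here): `forward(**batch, stats_prefix=...)` returns a scalar loss plus a dict of `(value, weight)` pairs that `WeightedAvgStats` can aggregate. A toy stand-in that satisfies that contract, purely for illustration:

```python
import torch

class ToyBiEncoder(torch.nn.Module):
    """Minimal stand-in for MoCo/InBatch: same call signature and return shape, dummy loss."""

    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(8, 8)

    def forward(self, q_tokens=None, k_tokens=None, stats_prefix="", **kwargs):
        q, k = self.encoder(q_tokens), self.encoder(k_tokens)
        loss = torch.nn.functional.mse_loss(q, k)  # placeholder for the contrastive loss
        stats = {f"{stats_prefix}/loss": (loss.item(), q_tokens.size(0))}
        return loss, stats

batch = {"q_tokens": torch.randn(4, 8), "k_tokens": torch.randn(4, 8)}
train_loss, iter_stats = ToyBiEncoder()(**batch, stats_prefix="train")
```
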
/cross_task_cross_eval/create_cross_task_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | import pandas as pd
4 | import collections
5 | import numpy
6 | import json
7 | import jsonlines
8 | import csv
9 | import os
10 | import random
11 | from tqdm import tqdm
12 | import datasets
13 | wikiqa = datasets.load_dataset("wiki_qa")
14 |
15 | def load_jsonlines(file_name):
16 | with jsonlines.open(file_name, 'r') as jsonl_f:
17 | data = [obj for obj in jsonl_f]
18 | return data
19 |
20 |
21 | os.makedirs("ambig/qrels", exist_ok=True)
22 | os.makedirs("wikiqa/qrels", exist_ok=True)
23 | os.makedirs("gooaq_technical/qrels", exist_ok=True)
24 | os.makedirs("linkso_py/qrels", exist_ok=True)
25 | os.makedirs("codesearch_py/qrels", exist_ok=True)
26 |
27 |
28 | # WIKIQA
29 | corpus = {}
30 | for split in ["train", "test", "validation"]:
31 | for item in wikiqa[split]:
32 | corpus.setdefault(item["document_title"], [])
33 | if item["answer"] not in corpus[item["document_title"]]:
34 | corpus[item["document_title"]].append(item["answer"])
35 |
36 | final_corpus = []
37 | for title in corpus:
38 | for idx, doc in enumerate(corpus[title]):
39 | final_corpus.append({"title": title, "text": doc, "_id":"{0}_{1}".format(title, idx), "metadata": {}})
40 |
41 | final_qrel_data = []
42 | final_queries = []
43 | for item in wikiqa["validation"]:
44 | question_id = item["question_id"]
45 | question = item["question"]
46 | if item["label"] == 1:
47 | corpus_id = corpus[item["document_title"]].index(item["answer"])
48 | final_queries.append({"_id": "wikiqa_{}".format(question_id), "text": question, "metadata": {}})
49 | final_qrel_data.append({"query-id": "wikiqa_{}".format(question_id), "corpus-id": "{0}_{1}".format(item["document_title"], corpus_id), "score": 1})
50 | for item in wikiqa["test"]:
51 | question_id = item["question_id"]
52 | question = item["question"]
53 | if item["label"] == 1:
54 | corpus_id = corpus[item["document_title"]].index(item["answer"])
55 | final_queries.append({"_id": "wikiqa_{}".format(question_id), "text": question, "metadata": {}})
56 | final_qrel_data.append({"query-id": "wikiqa_{}".format(question_id), "corpus-id": "{0}_{1}".format(item["document_title"], corpus_id), "score": 1})
57 |
58 | q2dic = {}
59 | for item in final_queries:
60 | q2dic[item["_id"]] = item
61 | final_queries = list(q2dic.values())
62 |
63 | with jsonlines.open('wikiqa/queries.jsonl', 'w') as writer:
64 | writer.write_all(final_queries)
65 | with jsonlines.open('wikiqa/corpus.jsonl', 'w') as writer:
66 | writer.write_all(final_corpus)
67 | with open('wikiqa/qrels/test.tsv', 'wt') as out_file:
68 | tsv_writer = csv.writer(out_file, delimiter='\t')
69 | tsv_writer.writerow(['query-id', 'corpus-id', "score"])
70 | for item in final_qrel_data:
71 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]])
72 |
73 |
74 | # Ambig QA
75 | ambigqa_path = "data/ambignq_light/"
76 | ambigqa_dev_data = load_jsonlines(ambigqa_path + "dev_light.json")[0]
77 | ambigqa_train_data = load_jsonlines(ambigqa_path + "train_light.json")[0]
78 |
79 | ambig_evals_qq = {}
80 | final_queries = []
81 | final_corpus = []
82 | count = 0
83 | for item in ambigqa_train_data:
84 | for an in item["annotations"]:
85 | if an["type"] == "multipleQAs":
86 | qa_pairs = an["qaPairs"]
87 | for q in qa_pairs:
88 | final_corpus.append({"_id": "ambig_train_{0}".format(count), "text": q["question"], "title": "", "metadata": {}})
89 | count += 1
90 |
91 | final_qrels = []
92 | count = 0
93 | for item in ambigqa_dev_data:
94 | for an in item["annotations"]:
95 | if an["type"] == "multipleQAs":
96 | qa_pairs = an["qaPairs"]
97 | for qa in qa_pairs:
98 | target_id = "ambig_test_{0}".format(count)
99 | final_corpus.append({"_id": "ambig_test_{0}".format(count), "text": qa["question"], "title": "" , "metadata": {}})
100 | final_queries.append({"_id": "ambig_nq_source_{}".format(item["id"]), "text": item["question"], "metadata": {}})
101 | final_qrels.append({"corpus-id": "ambig_test_{0}".format(count), "query-id": "ambig_nq_source_{}".format(item["id"]), "score": 1})
102 | count += 1
103 |
104 | q2dic = {}
105 | for item in final_queries:
106 | q2dic[item["_id"]] = item
107 | final_queries = list(q2dic.values())
108 |
109 | with jsonlines.open('ambig/queries.jsonl', 'w') as writer:
110 | writer.write_all(final_queries)
111 | with jsonlines.open('ambig/corpus.jsonl', 'w') as writer:
112 | writer.write_all(final_corpus)
113 | with open('ambig/qrels/test.tsv', 'wt') as out_file:
114 | tsv_writer = csv.writer(out_file, delimiter='\t')
115 | tsv_writer.writerow(['query-id', 'corpus-id', "score"])
116 | for item in final_qrels:
117 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]])
118 |
119 |
120 | # GooAQ technical
121 | gooqa_technical = []
122 | gooaq_data = load_jsonlines("data/gooaq.jsonl")
123 |
124 | for item in gooaq_data:
125 | if item["answer_type"] == "feat_snip" and item["answer_url"] is not None and "https://" in item["answer_url"] :
126 | url = item["answer_url"].split("https://")[1].split("/")[0]
127 | if url == "stackoverflow.com":
128 | item["url_processed"] = url
129 | gooqa_technical.append(item)
130 |
131 | random_sampled_qooaq_tech = random.sample(gooqa_technical, k=1000)
132 |
133 | full_corpus = []
134 | full_queries = []
135 | full_qrels = []
136 | answer2id = {}
137 |
138 | for idx, item in enumerate(gooqa_technical):
139 | full_corpus.append({"_id": "{0}_{1}".format(item["url_processed"], idx), "text": item["answer"], "title": "", "metadata": {}})
140 | answer2id[item["answer"]] = "{0}_{1}".format(item["url_processed"], idx)
141 |
142 |
143 | for item in random_sampled_qooaq_tech:
144 | full_queries.append({"_id": "gooaq_technical_{}".format(item["id"]), "text": item["question"], "metadata": {}})
145 | corpus_id = answer2id[item["answer"]]
146 | full_qrels.append({"query-id": "gooaq_technical_{}".format(item["id"]), "corpus-id": corpus_id, "score": 1})
147 |
148 |
149 | os.makedirs("gooaq_technical/qrels", exist_ok=True)
150 | with jsonlines.open('gooaq_technical/queries.jsonl', 'w') as writer:
151 | writer.write_all(full_queries)
152 | with jsonlines.open('gooaq_technical/corpus.jsonl', 'w') as writer:
153 | writer.write_all(full_corpus)
154 | with open('gooaq_technical/qrels/test.tsv', 'wt') as out_file:
155 | tsv_writer = csv.writer(out_file, delimiter='\t')
156 | tsv_writer.writerow(['query-id', 'corpus-id', "score"])
157 | for item in full_qrels:
158 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]])
159 |
160 |
161 | # LinkSO
162 | def find_duplicated_qestions(dir, lang):
163 | duplicated_q_pairs = []
164 | qid2all = pd.read_csv(os.path.join(dir, "{}_qid2all.txt".format(lang)), sep="\t", header=None)
165 | qid2all_dic = {}
166 | for idx, row in qid2all.iterrows():
167 | qid2all_dic[int(row[0])] = {"title": row[1], "body": row[2]}
168 |
169 | cosin = pd.read_csv(os.path.join(dir, "{}_cosidf.txt".format(lang)), sep="\t")
170 | dup_pair_ids = {}
171 | for idx, row in cosin.iterrows():
172 | if row["label"] == 1:
173 | dup_pair_ids[int(row["qid1"])] = int(row["qid2"])
174 |
175 | test_qs = open(os.path.join(dir, "{}_test_qid.txt".format(lang))).read().split("\n")[:-1]
176 | for q_id in test_qs:
177 | if int(q_id) in dup_pair_ids:
178 | dup_id = dup_pair_ids[int(q_id)]
179 | duplicated_q_pairs.append((qid2all_dic[int(q_id)], qid2all_dic[dup_id]))
180 | return duplicated_q_pairs
181 |
182 | linkso_data_python = "/private/home/akariasai/inst_dpr/preprocessing/linkso_data/topublish/python"
183 | linkso_dups_python = find_duplicated_qestions(linkso_data_python, "python")
184 | full_corpus = [{"_id": "linkso_py_doc_{}".format(i), "title": "", "text": "{0} {1}".format(d["title"], d["body"]), "metadata": {}} for i, (q, d) in enumerate(linkso_dups_python)]
185 | full_queries = [{"_id": "linkso_py_query_{}".format(i), "text": q["title"], "metadata": {}} for i, (q, d) in enumerate(linkso_dups_python)]
186 | full_qrels = [{"query-id": "linkso_py_query_{}".format(i), "corpus-id": "linkso_py_doc_{}".format(i), "score": 1} for i, _ in enumerate(linkso_dups_python)]
187 |
188 | with jsonlines.open('linkso_py/queries.jsonl', 'w') as writer:
189 | writer.write_all(full_queries)
190 | with jsonlines.open('linkso_py/corpus.jsonl', 'w') as writer:
191 | writer.write_all(full_corpus)
192 | with open('linkso_py/qrels/test.tsv', 'wt') as out_file:
193 | tsv_writer = csv.writer(out_file, delimiter='\t')
194 | tsv_writer.writerow(['query-id', 'corpus-id', "score"])
195 | for item in full_qrels:
196 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]])
197 |
198 | # CodeSearch Net Py
199 | python_code_serach_net = datasets.load_dataset("code_search_net", "python")
200 | python_short_descs = [item for item in python_code_serach_net["test"] if len(item["func_documentation_string"]) < 300 and len(item["func_documentation_string"]) > 50]
201 |
202 | full_corpus = []
203 | full_queries = []
204 | full_qrels = []
205 | answer2id = {}
206 |
207 | for idx, item in tqdm(enumerate(python_code_serach_net["train"])):
208 | doc_id = "codeserachnet_python_train_{0}_{1}".format(idx, item["func_name"])
209 | if '"""' in item["func_code_string"]:
210 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "")
211 | elif "'''" in item["func_code_string"]:
212 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "")
213 | else:
214 | code = item["func_code_string"]
215 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" })
216 | answer2id[code] = doc_id
217 |
218 | for idx, item in tqdm(enumerate(python_code_serach_net["validation"])):
219 | doc_id = "codeserachnet_python_validation_{0}_{1}".format(idx, item["func_name"])
220 | if '"""' in item["func_code_string"]:
221 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "")
222 | elif "'''" in item["func_code_string"]:
223 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "")
224 | else:
225 | code = item["func_code_string"]
226 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" })
227 | answer2id[code] = doc_id
228 |
229 | for idx, item in tqdm(enumerate(python_code_serach_net["test"])):
230 | doc_id = "codeserachnet_python_test_{0}_{1}".format(idx, item["func_name"])
231 | if '"""' in item["func_code_string"]:
232 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "")
233 | elif "'''" in item["func_code_string"]:
234 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "")
235 | else:
236 | code = item["func_code_string"]
237 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" })
238 | answer2id[code] = doc_id
239 |
240 | random_sampled_python_short_descs= random.sample(python_short_descs, k = 1000)
241 |
242 | for idx, item in enumerate(random_sampled_python_short_descs):
243 | qid = "python_codesearch_{}".format(idx)
244 |
245 | query = item["func_documentation_string"]
246 | if '"""' in item["func_code_string"]:
247 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "")
248 | elif "'''" in item["func_code_string"]:
249 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "")
250 | else:
251 | code = item["func_code_string"]
252 | corpus_id = answer2id[code]
253 | full_queries.append({"_id": qid, "text": query, "metadata": {}})
254 | full_qrels.append({"corpus-id": corpus_id, "query-id": qid, "score": 1})
255 |
256 | os.makedirs("codesearch_py/qrels", exist_ok=True)
257 | with jsonlines.open('codesearch_py/queries.jsonl', 'w') as writer:
258 | writer.write_all(full_queries)
259 | with jsonlines.open('codesearch_py/corpus.jsonl', 'w') as writer:
260 | writer.write_all(full_corpus)
261 | with open('codesearch_py/qrels/test.tsv', 'wt') as out_file:
262 | tsv_writer = csv.writer(out_file, delimiter='\t')
263 | tsv_writer.writerow(['query-id', 'corpus-id', "score"])
264 | for item in full_qrels:
265 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]])
--------------------------------------------------------------------------------
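
The five folders written above follow the BEIR layout (`corpus.jsonl`, `queries.jsonl`, `qrels/test.tsv`), so they can be loaded directly with the BEIR data loader, assuming the `beir` package is installed:

```python
from beir.datasets.data_loader import GenericDataLoader

corpus, queries, qrels = GenericDataLoader(data_folder="wikiqa").load(split="test")
print(len(corpus), len(queries), len(qrels))
```
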
/cross_task_cross_eval/download_create_data.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 |
3 | mkdir data
4 | cd data
5 | wget https://github.com/allenai/gooaq/raw/main/data/gooaq.jsonl
6 | gdown 1X5GoVi_OcRxahXH1pRW7TSesZUeMH3ss
7 | wget https://nlp.cs.washington.edu/ambigqa/data/ambignq_light.zip
8 |
9 | tar xvzf linkso.tar.gz
10 | unzip ambignq_light
11 |
12 | cd ..
13 | python create_cross_task_data.py
--------------------------------------------------------------------------------
/figures/intro.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/figures/intro.png
--------------------------------------------------------------------------------