├── BERRI ├── README.md ├── berri_instructions.tsv ├── create_tart_dual_train_data.py ├── create_tart_full_train_data.py ├── denoising.py └── enc_t5 │ ├── __init__.py │ ├── modeling_enc_t5.py │ └── tokenization_enc_t5.py ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── TART ├── custom_metrics.py ├── eval_beir.py ├── eval_cross_task.py ├── finetuning.py ├── finetuning_tart_full.py ├── generate_passage_embeddings.py ├── interactive.py ├── passage_retrieval.py ├── preprocess.py ├── requirements.txt ├── src │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-39.pyc │ │ ├── contriever.cpython-39.pyc │ │ ├── data.cpython-39.pyc │ │ ├── dist_utils.cpython-39.pyc │ │ ├── evaluation.cpython-39.pyc │ │ ├── index.cpython-39.pyc │ │ ├── modeling_enc_t5.cpython-39.pyc │ │ ├── normalize_text.cpython-39.pyc │ │ ├── slurm.cpython-39.pyc │ │ ├── tokenization_enc_t5.cpython-39.pyc │ │ └── utils.cpython-39.pyc │ ├── beir_utils.py │ ├── contriever.py │ ├── data.py │ ├── dist_utils.py │ ├── evaluation.py │ ├── finetuning_data.py │ ├── inbatch.py │ ├── index.py │ ├── moco.py │ ├── modeling_enc_t5.py │ ├── normalize_text.py │ ├── options.py │ ├── rerank.py │ ├── slurm.py │ ├── tokenization_enc_t5.py │ └── utils.py └── train.py ├── cross_task_cross_eval ├── create_cross_task_data.py └── download_create_data.sh └── figures └── intro.png /BERRI/README.md: -------------------------------------------------------------------------------- 1 | # BERRI 2 | 3 | **BERRI** is a collection of retrieval task with instructions. 4 | 5 | Due to some legal reasons, Meta cannot directly release the preprocessed scripts, so this repository contains the script to re-process and create data. 6 | You can also download the data processed by third party below: 7 | 8 | You can download the processed source data (from the process (i)) as well as the final training data for TART-dual and full, processed by a third party here: 9 | - [source data (22 GB)](https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link) 10 | - [TART-full training data (1 GB)](https://drive.google.com/file/d/1oijzAb2gWKT54OgeE7_KB9VcHvA7UxpQ/view?usp=share_link) 11 | - [TART-dual training data (14 GB)](https://drive.google.com/file/d/1lMmD5lTxYWYf0z0ua0-GaGKz2qs2mG1r/view?usp=share_link) 12 | 13 | 14 | ## Preprocessing 15 | First please download the corpus (`corpus.tsv`) and source data file (`qa.jsonl`) for all retrieval tasks [here]((https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link). 16 | 17 | 18 | ### Step 1: run Contriever to find the top documents 19 | 20 | - generate embeddings 21 | First, generate passage embeddings using `facebook/contriever-msmarco`. 
22 | 23 | ```sh 24 | cd ../TART 25 | for i in {0..7}; do 26 | export CUDA_VISIBLE_DEVICES=${i} 27 | nohup python generate_passage_embeddings.py --model_name_or_path facebook/contriever-msmarco --output_dir OUTPUT_DIR_NAME \ 28 | --passages ../BERRI/berri_corpus_data/TASK_NAME/corpus.tsv --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 & 29 | done 30 | ``` 31 | 32 | Then, retrieve top passages as follows: 33 | 34 | ``` 35 | python passage_retrieval.py \ 36 | --model_name_or_path facebook/contriever-msmarco \ 37 | --passages ../BERRI/berri_corpus_data/TASK_NAME/corpus.tsv \ 38 | --passages_embeddings "YOUR_EMBEDDING_PATH/*" \ 39 | --data ../BERRI/berri_corpus_data/TASK_NAME/qa_data.json \ 40 | --output_dir PATH_TO_OUTPUT_DIR --n_docs 100 41 | ``` 42 | 43 | ### Step 2: Denoise the passages 44 | 45 | ``` 46 | python denoising.py \ 47 | --task_name TASK_NAME \ 48 | --train_file PATH_TO_TRAIN_FILE \ 49 | --test_input_file output_dir/qa_data.json \ 50 | --model_name_or_path PATH_TO_DENOISING_MODEL_NAME \ 51 | --output_dir PATH_TO_OUTPUT_DIR \ 52 | --do_predict \ 53 | --evaluation_strategy steps \ 54 | --max_seq_length 512 --overwrite_cache --top_k 30 \ 55 | --instruction_file berri_instructions.tsv # only for creating tart-dual training data. 56 | ``` 57 | 58 | ### Step 3: Combine denoised results and create training data 59 | 60 | - Creating TART-dual training data 61 | ``` 62 | python create_tart_dual_train_data.py \ 63 | --inst_file berri_instructions.tsv \ 64 | --output_file PATH_TO_OUTPUT_DIR \ 65 | --input_dir PATH_TO_DENOISED_RESULTS \ 66 | ``` 67 | - Creating TART-full training data 68 | ``` 69 | python create_tart_full_train_data.py \ 70 | --input_file berri_instructions.tsv \ 71 | --output_dir --output_file PATH_TO_OUTPUT_DIR \ 72 | --input_dir PATH_TO_DENOISED_RESULTS \ 73 | ``` 74 | 75 | 76 | 77 | -------------------------------------------------------------------------------- /BERRI/berri_instructions.tsv: -------------------------------------------------------------------------------- 1 | dataset prompt_1 prompt_2 prompt_3 prompt_4 prompt_5 prompt_6 prompt_7 prompt_8 2 | altlex Find an sentence from Simple Wikipedia that corresponds to the following Wikipedia sentence Retrieve a sentence that talks about the same as the following in a simple or shorter English A simplified sentence from Simple Wikipedia of this Wikipedia sentence You need to find a simplified sentence from Simple Wikipedia corresponding to the following sentence 3 | cnn_dailymail The following sentences are the summaries of an news article. Find the source news article. Retrieve a news article that is summarized as following Find the original news article based on which the following highlights are written 4 | coco_captions Retrieve a image caption sentence that is written for the same image as the following caption Find a image caption describing the same image as Can you find an image caption talking about the same image as 5 | codesearch_go Match the following natural language instruction to Go codes Retrieve Go implementations achieving the following features Find a Go code implementation on GitHub for the following natural language instruction I want to find a Go code implementing the following function from Github. 
6 | codesearch_java Match the following natural language instruction to Java codes Retrieve Java implementations achieving the following features Find a Java code implementation on GitHub for the following natural language instruction I want to find Java code implementing the following function from Github. 7 | codesearch_javascript Match the following natural language instruction to JavaScript codes Retrieve JavaScript implementations achieving the following features Find a JavaScript code implementation on GitHub for the following natural language instruction I want to find JavaScript code implementing the following function from Github. 8 | codesearch_ruby Match the following natural language instruction to Ruby codes Retrieve Ruby implementations achieving the following features Find a Ruby code implementation on GitHub for the following natural language instruction I want to find Ruby code implementing the following function from Github. 9 | eli5_question_answer Find an answer to this question for me. Retrieve an paragraph-length answer to this question. You have to find answer to this question. 10 | fever Retrieve a Wikipedia paragraph to verify this claim Find an evidence paragraph from Wikipedia to confirm the statement is correct I want to know if this sentence is a fact or not. Can you find relented Wikipedia passages for me? You need to find Wikipedia paragraphs that support or refute this sentence 11 | gigaword Retrieve a extremely short summary of the following Gigaword article Find a corresponding headline of the this Gigaword article summary. I want to retrieve a headline of this gigaword news 12 | hotpotqa Retrieve a Wikipedia paragraph that provides useful evidence or answer for this question I want to know the answer of this question. Please find related Wikipedia passages for me. Find a paragraph that provides useful information to answer this question You need to find multiple Wikipedia passages to answer this multi-hop question. Can you find related passages? 13 | yahoo_answers_title_answer Find an answer from a community QA forum for the following question Retrieve the most voted answer for this question from Yahoo Answers. This is an question's title posted in Yahoo Answers, a community forum. Please retrieve the answer post. 14 | mdmcqa Find an evidence from medical text book to answer this medical exam question I want to know the answer of this following medical exam question. Retrieve an evidence passage answering question Find the explanation for the correct answer of this medical question 15 | medical_sim Please retrieve a medical paper summary that is written in a simple language so that my patient can understand You need to find a simple summary of medical paper that corresponds to the following paper abstract You need to retrieve a medical paper summary written in a simple English of this paper Match this following medical paper abstract to a simplified English summary for non experts 16 | msmarco-triplets I want to know the answer to the question. Can you find good evidence on the web? 
Retrieve a web paragraph that answers the following Find an evidence paragraph on the web with evidence and answer for this 17 | multilexsum Find a short summary of this following legal case Map this legal case summary to a sentence-long summary An extremely short summary sentence of this legal case Retrieve a one-sentence summary of the following legal case 18 | nq retrieve passages from Wikipedia that provides answers to the following question You have to find a Wikipedia paragraph that provides the answer to the question I want to find an answer for question. Can you find some paragraphs that provide evidence from Wikipedia? I'm looking for a Wikipedia paragraph that answers this question. Give me a Wikipedia paragraph answering this open-domain question. A Wikipedia paragraph providing sufficient evidence to answer this question Your job is to find a Wikipedia paragraph that answers my question You need to retrieve an evidence paragraph from Wikipedia to answer this question 19 | oqa Find a question that is paraphrased of this Retrieve a question that is duplicated with the following You need to find duplicate questions in Wiki forum. Could you find a question that is similar to this question Find an open-domain question that is similar to the following An open-domain question that is duplicated with the following 20 | agnews Find a news summary sentence corresponding to the following header Retrieve an sentence-long news summary for this header Please find a good summary of this news from the news summary collection 21 | pubmedqa Find a related medical paper to answer the following question I want to check if the following statement is true or not. Retrieve a scientific paper from PubMed for me Help me to find a highly related pubmed paper to answer this question Retrieve a medical paper abstract answering the following question 22 | qrecc Find a meaningful dialogue response to answer the user's question Retrieve a good dialogue response that answers this question You need to find a good response from a collection of previous responses and help users to know this topic more 23 | record Find a News article to verify the following sentence I want to know if this sentence is true. Please retrieve a highly-relevant news article News articles that provide a piece of sufficient evidence to verify the following statement 24 | scitldr Find a sentence-length summary of this paper. Retrieve a really short summary of this paper abstract. What is the TLDR of this paper? Retrieve a sentence that summarizes the following abstract 25 | searchQA_top5_snippets Pick up the top web search results' snippets for the following question. Find the top 5 Web snippets that answer this You have to match this question to top five web search snippets. 26 | sentence-compression Find a short sentence that compresses the following long sentence You have to match this long sentence to a shorter compressed one Retrieve a compressed version of the following sentence written by human annotators 27 | npr Given a news article headline published at npr.org, find a corresponding summary of the news Retrieve a news summary of the following news headline You have to match the following news article to a short summary 28 | squad_pairs.jsonl Find a Wikipedia paragraph that answer the question You have to retrieve a paragraph from Wikipedia that provides evidence and answer for this question This question asks about the details written in a Wikipedia paragraph. 
Select the paragraph the question is about 29 | stackexchange_duplicate_questions_title-body_title-body Find a question paragraph that is duplicated with the following question paragraph at StackExchange. I want to find a question similar to this question already asked in StackExchange. Retrieve a question (main text) that is similar to this following paragraph StackExchange is a community QA forum for diverse topics including technical or science. Help me to find a question paragraph that duplicates with my following question paragraph 30 | stackexchange_duplicate_questions_title_title Find an duplicated question on StackExchange, a community forum. Find a question title that is similar to the following question title asked in StackExchange, a community QA forum I want to find a related question asked in StackExchange. Can you find one for me? 31 | paq Find a web paragraph long answer for this question Can you answer my question by finding an article on the web? Retrieve a paragraph answer for the following question 32 | triviaqa retrieve a related web article that provides evidences and answers for the following Trivia question You have to find a answer paragraph on the web for this question I want to find an answer for the following Trivia questions. Can you find some paragraphs that provide evidence on the web? 33 | wikihow Find the how-to article to achieve the following goal from Wikihow, a website collecting how-to articles. WikiHow is an wiki-style community and database of how-to guides. Suggest the best article for the following Find a detailed paragraph from WikiHow that explains how-to to achieve 34 | wow Find an Wikipedia paragraph that is related to the current conversation topic to generate a meaningful response. Retrieve a paragraph from Wikipedia to help a conversational AI to generate a knowledge-grounded dialogue You must find a Wikipedia paragraph to help a chatbot to generate informative response. Can you find one? Find an Wikipedia paragraph related to the following conversation topic. 35 | xsum Find a short summary of this following legal case Map this legal case summary to a sentence-long summary An extremely short summary sentence of this legal case 36 | quora Find a question that is duplicated with the following question asked in Quora, a community QA forum I want to find a question similar to this question already asked in Quora. Retrieve a question body that is similar to Check if a Quora question is duplicated with this question 37 | ccnews Retrieve a news article that corresponds to this title Find the original news article that this title is written for I want to know the details of this news. Can you find a detailed news article on this for me? 38 | wow response Find a knowledgeable response for an AI chat bot given the following user's inputs Retrieve a informative knowledge-grounded dialogue response given this AI-user chat history A good chat bot response to answer this user's query -------------------------------------------------------------------------------- /BERRI/create_tart_dual_train_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
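# What this script does: it builds TART-dual (bi-encoder) training data. For
# each task listed in berri_instructions.tsv it loads the denoised retrieval
# results, prepends a sampled instruction to every query with a " [SEP] "
# separator, drops "negatives" whose text duplicates a gold positive,
# subsamples positive / negative / hard-negative passages, optionally mixes in
# instruction-unfollowing negatives, and writes shuffled train / dev splits as
# DPR-style jsonl files.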
6 | import os 7 | import argparse 8 | import json 9 | import glob 10 | import random 11 | import jsonlines 12 | import numpy as np 13 | import pandas as pd 14 | 15 | import numpy as np 16 | from tqdm import tqdm 17 | import copy 18 | 19 | 20 | def load_data(data_path): 21 | data = [] 22 | with open(data_path, "r") as fin: 23 | for k, example in enumerate(fin): 24 | example = json.loads(example) 25 | data.append(example) 26 | return data 27 | 28 | 29 | def process_instruction_unfollowing_sample(input_data, prompts): 30 | new_data = {} 31 | for item in input_data: 32 | query = item["question"] 33 | sampled_context = random.sample(item["ctxs"], k=1) 34 | new_data[query] = sampled_context 35 | return new_data 36 | 37 | 38 | def load_jsonlines(file_name): 39 | with jsonlines.open(file_name, 'r') as jsonl_f: 40 | data = [obj for obj in jsonl_f] 41 | return data 42 | 43 | 44 | def process_prompts(row): 45 | # load instructions 46 | prompts = [] 47 | for i in range(1, 9): 48 | column_name = "prompt_{}".format(i) 49 | if type(column_name) == str: 50 | prompts.append(row[column_name]) 51 | return prompts 52 | 53 | 54 | def main(args): 55 | print(f"Loading model from: {args.model_name_or_path}") 56 | final_data = [] 57 | min_neg_nums = 100 58 | file_names = pd.read_csv(args.inst_file, sep="\t") 59 | 60 | if args.instruction_unfollowing_file is not None: 61 | inst_unfollow_dict = {} 62 | inst_unfollow_file_names = open( 63 | args.instruction_unfollowing_file).read().split("\n")[:-1] 64 | # Load instruction unfollowing files. 65 | # src indicates the original task name (query task) while tgt indicate the corpus task name. 66 | for file in inst_unfollow_file_names: 67 | src_task_name = (file.split("src_")[1]).split("_tgt")[0] 68 | inst_unfollow_dict.setdefault(src_task_name, []) 69 | inst_unfollow_dict[src_task_name].append(file) 70 | 71 | for _, file_data in tqdm(file_names.iterrows()): 72 | input_file = os.path.join( 73 | args.input_dir, file_data["dataset"] + ".jsonl") 74 | prompts = process_prompts(file_data) 75 | input_data = load_jsonlines(input_file) 76 | # mostly for ablation. Skip tasks that are specified in `ignore_tasks`, or that are not specified in `task_names`. 77 | if args.ignore_tasks is not None and file_data["dataset"] in args.ignore_tasks: 78 | print("skip {0}".format(file_data["dataset"])) 79 | continue 80 | if args.task_names is not None and file_data["dataset"] not in args.task_names: 81 | print("skip {0}".format(file_data["dataset"])) 82 | continue 83 | print("# of data: {}".format(len(input_data))) 84 | 85 | # add instruction-unfollowing data. 
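        # Instruction-unfollowing negatives are passages retrieved for the same
        # query but from a different task's corpus (src = query task, tgt =
        # corpus task). They are appended to the hard negatives further below so
        # that the model is penalized for ignoring the instruction.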
86 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict: 87 | for unfollowing_file_name in inst_unfollow_dict[file_data["dataset"]]: 88 | unfollowing_file_name = glob.glob( 89 | unfollowing_file_name + "/*.json*")[0] 90 | print(unfollowing_file_name) 91 | unfollowing_input_data = load_data(unfollowing_file_name) 92 | processed_training_data_unfollowing = process_instruction_unfollowing_sample( 93 | unfollowing_input_data, prompts) 94 | 95 | for datapoint in input_data: 96 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict and datapoint["question"] in processed_training_data_unfollowing: 97 | instructions_unfollowing_negatives = processed_training_data_unfollowing[ 98 | datapoint["question"]] 99 | else: 100 | instructions_unfollowing_negatives = [] 101 | if len(datapoint["question"]) < 20: 102 | continue 103 | 104 | # To make a model robust to different instructions, sample two prompts per instance. 105 | if len(prompts) > 2: 106 | sampled_prompts = random.sample(prompts, k=2) 107 | else: 108 | sampled_prompts = prompts 109 | true_negatives = [] 110 | true_hard_negatives = [] 111 | # final check if the paragraph is not included in the labeled gold paragraphs. 112 | for neg in datapoint["negative_ctxs"]: 113 | skip = False 114 | for pos in datapoint["positive_ctxs"]: 115 | if neg["text"] == pos["text"]: 116 | print("false negatives") 117 | skip = True 118 | if skip is False: 119 | true_negatives.append(neg) 120 | 121 | # a cross encoder can falsely predict a positive paragraph "negative". 122 | # check if the paragraph is not included in the labeled gold paragraphs. 123 | for neg in datapoint["hard_negative_ctxs"]: 124 | skip = False 125 | for pos in datapoint["positive_ctxs"]: 126 | if neg["text"] == pos["text"]: 127 | print("false negatives") 128 | skip = True 129 | if skip is False: 130 | true_hard_negatives.append(neg) 131 | datapoint["negative_ctxs"] = true_negatives 132 | datapoint["hard_negative_ctxs"] = true_hard_negatives 133 | 134 | # create training samples. 
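            # Each datapoint is duplicated once per sampled prompt: the prompt
            # is prepended to the query with a " [SEP] " separator, and the
            # positive / negative / hard-negative lists are subsampled down to
            # the --num_*_paragraphs limits before being written out.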
135 | 136 | for p in sampled_prompts: 137 | new_data = copy.deepcopy(datapoint) 138 | new_data["question"] = "{0} [SEP] {1}".format( 139 | p, new_data["question"]) 140 | if args.only_negative is True: 141 | # do not add additional CE scores 142 | new_data["positive_ctxs"] = [ 143 | pos for pos in new_data["positive_ctxs"] if "ce_score" not in pos] 144 | 145 | if len(new_data["positive_ctxs"]) > args.num_positive_paragraphs: 146 | new_data["positive_ctxs"] = random.sample( 147 | new_data["positive_ctxs"], k=args.num_positive_paragraphs) 148 | for pos in new_data["positive_ctxs"]: 149 | 150 | if "title" not in pos: 151 | neg["title"] = "" 152 | 153 | if pos["title"] is None: 154 | print("none title") 155 | pos["title"] = "" 156 | 157 | if type(pos["text"]) is list: 158 | pos["text"] = pos["text"][0] 159 | 160 | if len(new_data["negative_ctxs"]) > args.num_negative_paragraphs: 161 | new_data["negative_ctxs"] = random.sample( 162 | new_data["negative_ctxs"], k=args.num_negative_paragraphs) 163 | for neg in new_data["negative_ctxs"]: 164 | if type(neg["text"]) is list: 165 | neg["text"] = neg["text"][0] 166 | 167 | if "title" not in neg: 168 | neg["title"] = "" 169 | 170 | if neg["title"] is None: 171 | neg["title"] = "" 172 | 173 | if len(new_data["hard_negative_ctxs"]) > args.num_hard_negative_paragraphs: 174 | new_data["hard_negative_ctxs"] = random.sample( 175 | new_data["hard_negative_ctxs"], k=args.num_hard_negative_paragraphs) 176 | for neg in new_data["hard_negative_ctxs"]: 177 | if type(neg["text"]) is list: 178 | neg["text"] = neg["text"][0] 179 | if "title" not in neg: 180 | neg["title"] = "" 181 | if neg["title"] is None: 182 | neg["title"] = "" 183 | 184 | if len(instructions_unfollowing_negatives) > args.num_instructions_unfollowing_negatives: 185 | new_data["hard_negative_ctxs"] = random.sample( 186 | new_data["hard_negative_ctxs"], k=args.num_instructions_unfollowing_negatives) 187 | for neg in instructions_unfollowing_negatives: 188 | if type(neg["text"]) is list: 189 | neg["text"] = neg["text"][0] 190 | if "title" not in neg: 191 | neg["title"] = "" 192 | if neg["title"] is None: 193 | neg["title"] = "" 194 | new_data["hard_negative_ctxs"].append(neg) 195 | 196 | assert len(new_data["positive_ctxs"]) > 0 197 | final_data.append(new_data) 198 | if len(new_data["negative_ctxs"]) + len(new_data["hard_negative_ctxs"]) < min_neg_nums: 199 | min_neg_nums = len( 200 | new_data["negative_ctxs"]) + len(new_data["hard_negative_ctxs"]) 201 | print("# of data: {}".format(len(final_data))) 202 | 203 | random.shuffle(final_data) 204 | # split data into train and dev set for development purpose. 
205 | train_data, dev_data = final_data[5000:], final_data[:5000] 206 | 207 | with jsonlines.open(args.output_file + "_train.jsonl", "w") as writer: 208 | writer.write_all(train_data) 209 | with jsonlines.open(args.output_file + "_dev.jsonl", "w") as writer: 210 | writer.write_all(dev_data) 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser() 215 | 216 | parser.add_argument( 217 | "--inst_file", 218 | required=True, 219 | type=str, 220 | default=None, 221 | help=".json file containing question and answers, similar format to reader data", 222 | ) 223 | 224 | parser.add_argument( 225 | "--only_negative", 226 | action="store_true" 227 | ) 228 | 229 | parser.add_argument( 230 | "--kd", 231 | action="store_true" 232 | ) 233 | parser.add_argument( 234 | "--ignore_tasks", 235 | type=str, 236 | nargs="+" 237 | ) 238 | 239 | parser.add_argument( 240 | "--input_dir", 241 | required=True, 242 | type=str, 243 | default=None, 244 | help=".json file containing question and answers, similar format to reader data", 245 | ) 246 | parser.add_argument("--output_file", type=str, default=None, 247 | help="Path to passages (.tsv file)") 248 | parser.add_argument("--sample_dataset_num", type=int, default=None, 249 | help="Path to passages (.tsv file)") 250 | parser.add_argument("--output_dir", type=str, default=None, 251 | help="Path to passages (.tsv file)") 252 | parser.add_argument("--task_names", type=str, default=None, 253 | help="task names filter", nargs="+") 254 | parser.add_argument("--prompts", type=str, default=None, 255 | help="prompt", nargs="+") 256 | parser.add_argument("--per_gpu_batch_size", type=int, 257 | default=64, help="Batch size for question encoding") 258 | parser.add_argument("--n_docs", type=int, default=100, 259 | help="Number of documents to retrieve per questions") 260 | parser.add_argument("--instruction_unfollowing_file", type=str, default=None, 261 | help="instruction unfollowing file") 262 | parser.add_argument( 263 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 264 | ) 265 | parser.add_argument("--no_fp16", action="store_true", 266 | help="inference in fp32") 267 | parser.add_argument("--question_maxlength", type=int, 268 | default=512, help="Maximum number of tokens in a question") 269 | parser.add_argument("--start_idx", type=int, default=None,) 270 | parser.add_argument("--end_idx", type=int, default=None,) 271 | parser.add_argument("--num_positive_paragraphs", type=int, default=2,) 272 | parser.add_argument("--num_negative_paragraphs", type=int, default=7,) 273 | parser.add_argument("--num_hard_negative_paragraphs", type=int, default=3,) 274 | parser.add_argument( 275 | "--num_instructions_unfollowing_negatives", type=int, default=1,) 276 | args = parser.parse_args() 277 | main(args) 278 | -------------------------------------------------------------------------------- /BERRI/create_tart_full_train_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
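# What this script does: it builds TART-full (cross-encoder) training data. For
# each task in the instruction file it loads the denoised retrieval results and
# turns every (query, passage) pair into a binary classification example of the
# form {"query": "<instruction> <query>", "document": "<title> <text>",
# "label": 1 or 0}, optionally adds instruction-unfollowing passages as extra
# negatives, caps the total size at --max_train_data, and writes shuffled
# train / dev splits.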
6 | import os 7 | import argparse 8 | import json 9 | import glob 10 | import random 11 | import jsonlines 12 | import pandas as pd 13 | 14 | from tqdm import tqdm 15 | 16 | 17 | def load_jsonlines(file_name): 18 | with jsonlines.open(file_name, 'r') as jsonl_f: 19 | data = [obj for obj in jsonl_f] 20 | return data 21 | 22 | 23 | def process_data(input_data, prompts): 24 | new_data = [] 25 | false_negatives = 0 26 | for item in input_data: 27 | query = item["question"] 28 | positive_ctxs = item["positive_ctxs"] 29 | if len(positive_ctxs) > args.num_positive_ctxs: 30 | positive_ctxs = random.sample( 31 | positive_ctxs, k=args.num_positive_ctxs) 32 | negative_ctxs = item["negative_ctxs"] + \ 33 | item["hard_negative_ctxs"] if "hard_negative_ctxs" in item else item["negative_ctxs"] 34 | final_negatives = [] 35 | for neg in negative_ctxs: 36 | if neg["text"] in [pos["text"] for pos in item["positive_ctxs"]]: 37 | false_negatives += 1 38 | continue 39 | else: 40 | final_negatives.append(neg) 41 | negative_ctxs = final_negatives 42 | 43 | if len(negative_ctxs) > args.num_negative_ctxs: 44 | negative_ctxs = random.sample( 45 | negative_ctxs, k=args.num_negative_ctxs) 46 | 47 | for pos in positive_ctxs: 48 | prompt = random.sample(prompts, k=1)[0] 49 | prompted_query = "{0} {1}".format(prompt, query) 50 | 51 | title = pos["title"] 52 | text = pos["text"] 53 | if title is not None and len(title) > 0: 54 | text = "{0} {1}".format(title, text) 55 | new_data.append( 56 | {"query": prompted_query, "document": text, "label": 1}) 57 | for neg in negative_ctxs: 58 | prompt = random.sample(prompts, k=1)[0] 59 | prompted_query = "{0} {1}".format(prompt, query) 60 | 61 | title = neg["title"] 62 | text = neg["text"] 63 | if title is not None and len(title) > 0: 64 | text = "{0} {1}".format(title, text) 65 | new_data.append( 66 | {"query": prompted_query, "document": text, "label": 0}) 67 | print("{} data created".format(len(new_data))) 68 | print("false negatives {}".format(false_negatives)) 69 | return new_data 70 | 71 | 72 | def process_instruction_unfollowing_sample(input_data, prompts, full_data_num): 73 | inst_num = min(int(full_data_num * 0.8 * 0.2), 10000) 74 | if len(input_data) > 7500: 75 | input_data = random.sample(input_data, k=inst_num) 76 | new_data = [] 77 | for item in input_data: 78 | query = item["question"] 79 | sampled_context = random.sample(item["ctxs"], k=1) 80 | for ctx in sampled_context: 81 | prompt = random.sample(prompts, k=1)[0] 82 | prompted_query = "{0} {1}".format(prompt, query) 83 | 84 | title = ctx["title"] 85 | text = ctx["text"] 86 | if len(title) > 0: 87 | text = "{0} {1}".format(title, text) 88 | new_data.append( 89 | {"query": prompted_query, "document": text, "label": 0}) 90 | print("instructions unfollowing samples") 91 | print("{} data created".format(len(new_data))) 92 | return new_data 93 | 94 | 95 | def process_prompts(row): 96 | # load instructions 97 | prompts = [] 98 | for i in range(1, 9): 99 | column_name = "prompt_{}".format(i) 100 | if type(column_name) == str: 101 | prompts.append(row[column_name]) 102 | return prompts 103 | 104 | 105 | def load_data(data_path): 106 | data = [] 107 | with open(data_path, "r") as fin: 108 | for k, example in enumerate(fin): 109 | example = json.loads(example) 110 | data.append(example) 111 | return data 112 | 113 | 114 | def main(args): 115 | if not os.path.exists(args.output_dir): 116 | os.makedirs(args.output_dir) 117 | 118 | all_data = [] 119 | file_names = pd.read_csv(args.input_file, sep="\t") 120 | if 
args.instruction_unfollowing_file is not None: 121 | inst_unfollow_dict = {} 122 | inst_unfollow_file_names = open( 123 | args.instruction_unfollowing_file).read().split("\n")[:-1] 124 | for file in inst_unfollow_file_names: 125 | src_task_name = (file.split("src_")[1]).split("_tgt")[0] 126 | inst_unfollow_dict.setdefault(src_task_name, []) 127 | inst_unfollow_dict[src_task_name].append(file) 128 | 129 | if args.sample_dataset_num is not None: 130 | dataset_names = file_names["dataset"] 131 | sampled_datasets = random.sample( 132 | list(dataset_names), k=args.sample_dataset_num) 133 | 134 | for idx, file_data in tqdm(file_names.iterrows()): 135 | if args.task_names is not None and file_data["dataset"] not in args.task_names: 136 | continue 137 | if args.start_idx is not None and idx < args.start_idx: 138 | continue 139 | if args.end_idx is not None and idx > args.end_idx: 140 | continue 141 | 142 | if os.path.exists(os.path.join(args.input_dir, file_data["dataset"] + ".jsonl")) is False: 143 | print("file name is mssing ") 144 | print(os.path.join(args.input_dir, 145 | file_data["dataset"] + ".jsonl")) 146 | print(file_data["dataset"]) 147 | continue 148 | 149 | if args.sample_dataset_num is not None and file_data["dataset"] not in sampled_datasets: 150 | continue 151 | 152 | task_input_file = os.path.join( 153 | args.input_dir, file_data["dataset"] + ".jsonl") 154 | print(task_input_file) 155 | prompts = process_prompts(file_data) 156 | input_data = load_data(task_input_file) 157 | processed_training_data = process_data(input_data, prompts) 158 | all_data += processed_training_data 159 | 160 | if args.instruction_unfollowing_file is not None and file_data["dataset"] in inst_unfollow_dict: 161 | for unfollowing_file_name in inst_unfollow_dict[file_data["dataset"]]: 162 | unfollowing_file_name = glob.glob( 163 | unfollowing_file_name + "/*.json*")[0] 164 | print(unfollowing_file_name) 165 | unfollowing_input_data = load_data(unfollowing_file_name) 166 | processed_training_data_unfollowing = process_instruction_unfollowing_sample( 167 | unfollowing_input_data, prompts, len(processed_training_data)) 168 | all_data += processed_training_data_unfollowing 169 | 170 | if len(all_data) > args.max_train_data: 171 | all_data = random.sample(all_data, k=args.max_train_data) 172 | 173 | random.shuffle(all_data) 174 | train_data, dev_data = all_data[10000:], all_data[:10000] 175 | 176 | with jsonlines.open(os.path.join(args.output_dir, "tart_full_train.json"), 'w') as writer: 177 | writer.write_all(train_data) 178 | with jsonlines.open(os.path.join(args.output_dir, "tart_full_dev.json"), 'w') as writer: 179 | writer.write_all(dev_data) 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser() 184 | 185 | parser.add_argument( 186 | "--input_file", 187 | required=True, 188 | type=str, 189 | default=None, 190 | help=".json file containing question and answers, similar format to reader data", 191 | ) 192 | parser.add_argument("--output_file", type=str, default=None, 193 | help="Path to passages (.tsv file)") 194 | parser.add_argument("--output_dir", type=str, default=None, 195 | help="Path to passages (.tsv file)") 196 | parser.add_argument("--input_dir", type=str, default=None, 197 | help="Path to passages (.tsv file)") 198 | parser.add_argument("--sample_dataset_num", type=int, default=None, 199 | help="Path to passages (.tsv file)") 200 | parser.add_argument("--instruction_unfollowing_file", type=str, default=None, 201 | help="instruction unfollowing file") 202 | 
parser.add_argument("--prompts", type=str, default=None, 203 | help="prompt", nargs="+") 204 | parser.add_argument("--task_names", type=str, default=None, 205 | help="task names filter", nargs="+") 206 | parser.add_argument("--instance_idx_start", type=int, 207 | default=None, help="instance start index") 208 | parser.add_argument("--instance_idx_end", type=int, 209 | default=None, help="instance end index") 210 | parser.add_argument("--per_gpu_batch_size", type=int, 211 | default=64, help="Batch size for question encoding") 212 | parser.add_argument("--n_docs", type=int, default=100, 213 | help="Number of documents to retrieve per questions") 214 | parser.add_argument( 215 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 216 | ) 217 | parser.add_argument("--no_fp16", action="store_true", 218 | help="inference in fp32") 219 | parser.add_argument("--question_maxlength", type=int, 220 | default=512, help="Maximum number of tokens in a question") 221 | parser.add_argument("--start_idx", type=int, default=None,) 222 | parser.add_argument("--end_idx", type=int, default=None,) 223 | parser.add_argument("--max_train_data", type=int, default=3000000) 224 | parser.add_argument("--num_positive_paragraphs", type=int, default=1,) 225 | parser.add_argument("--num_negative_paragraphs", type=int, default=4,) 226 | args = parser.parse_args() 227 | src.slurm.init_distributed_mode(args) 228 | main(args) 229 | -------------------------------------------------------------------------------- /BERRI/enc_t5/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from .modeling_enc_t5 import EncT5ForSequenceClassification 7 | from .tokenization_enc_t5 import EncT5Tokenizer 8 | -------------------------------------------------------------------------------- /BERRI/enc_t5/modeling_enc_t5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
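# EncT5ForSequenceClassification is an encoder-only T5 model used as the
# TART-full cross-encoder: it runs the T5 encoder stack, pools the first (bos)
# token, and applies dropout plus a linear classification head to score how
# well a document matches an instruction-prefixed query.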
6 | import copy 7 | import torch 8 | from torch import nn 9 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 10 | from transformers.modeling_outputs import SequenceClassifierOutput 11 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack 12 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 13 | 14 | 15 | class EncT5ForSequenceClassification(T5PreTrainedModel): 16 | _keys_to_ignore_on_load_missing = [ 17 | r"encoder\.embed_tokens\.weight", 18 | ] 19 | 20 | def __init__(self, config: T5Config, dropout=0.1): 21 | super().__init__(config) 22 | self.num_labels = config.num_labels 23 | self.config = config 24 | 25 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 26 | 27 | encoder_config = copy.deepcopy(config) 28 | encoder_config.use_cache = False 29 | encoder_config.is_encoder_decoder = False 30 | self.encoder = T5Stack(encoder_config, self.shared) 31 | 32 | self.dropout = nn.Dropout(dropout) 33 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 34 | 35 | # Initialize weights and apply final processing 36 | self.post_init() 37 | 38 | # Model parallel 39 | self.model_parallel = False 40 | self.device_map = None 41 | 42 | def parallelize(self, device_map=None): 43 | self.device_map = ( 44 | get_device_map(len(self.encoder.block), 45 | range(torch.cuda.device_count())) 46 | if device_map is None 47 | else device_map 48 | ) 49 | assert_device_map(self.device_map, len(self.encoder.block)) 50 | self.encoder.parallelize(self.device_map) 51 | self.classifier = self.classifier.to(self.encoder.first_device) 52 | self.model_parallel = True 53 | 54 | def deparallelize(self): 55 | self.encoder.deparallelize() 56 | self.encoder = self.encoder.to("cpu") 57 | self.model_parallel = False 58 | self.device_map = None 59 | torch.cuda.empty_cache() 60 | 61 | def get_input_embeddings(self): 62 | return self.shared 63 | 64 | def set_input_embeddings(self, new_embeddings): 65 | self.shared = new_embeddings 66 | self.encoder.set_input_embeddings(new_embeddings) 67 | 68 | def get_encoder(self): 69 | return self.encoder 70 | 71 | def _prune_heads(self, heads_to_prune): 72 | """ 73 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 74 | class PreTrainedModel 75 | """ 76 | for layer, heads in heads_to_prune.items(): 77 | self.encoder.layer[layer].attention.prune_heads(heads) 78 | 79 | def forward( 80 | self, 81 | input_ids=None, 82 | attention_mask=None, 83 | head_mask=None, 84 | inputs_embeds=None, 85 | labels=None, 86 | output_attentions=None, 87 | output_hidden_states=None, 88 | return_dict=None, 89 | ): 90 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 91 | 92 | outputs = self.encoder( 93 | input_ids=input_ids, 94 | attention_mask=attention_mask, 95 | inputs_embeds=inputs_embeds, 96 | head_mask=head_mask, 97 | output_attentions=output_attentions, 98 | output_hidden_states=output_hidden_states, 99 | return_dict=return_dict, 100 | ) 101 | 102 | hidden_states = outputs[0] 103 | # Take bos token (equiv. 
to ) 104 | pooled_output = hidden_states[:, 0, :] 105 | 106 | pooled_output = self.dropout(pooled_output) 107 | logits = self.classifier(pooled_output) 108 | 109 | loss = None 110 | if labels is not None: 111 | if self.config.problem_type is None: 112 | if self.num_labels == 1: 113 | self.config.problem_type = "regression" 114 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 115 | self.config.problem_type = "single_label_classification" 116 | else: 117 | self.config.problem_type = "multi_label_classification" 118 | 119 | if self.config.problem_type == "regression": 120 | loss_fct = MSELoss() 121 | if self.num_labels == 1: 122 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 123 | else: 124 | loss = loss_fct(logits, labels) 125 | elif self.config.problem_type == "single_label_classification": 126 | loss_fct = CrossEntropyLoss() 127 | loss = loss_fct( 128 | logits.view(-1, self.num_labels), labels.view(-1)) 129 | elif self.config.problem_type == "multi_label_classification": 130 | loss_fct = BCEWithLogitsLoss() 131 | loss = loss_fct(logits, labels) 132 | if not return_dict: 133 | output = (logits,) + outputs[1:] 134 | return ((loss,) + output) if loss is not None else output 135 | 136 | return SequenceClassifierOutput( 137 | loss=loss, 138 | logits=logits, 139 | hidden_states=outputs.hidden_states, 140 | attentions=outputs.attentions, 141 | ) 142 | -------------------------------------------------------------------------------- /BERRI/enc_t5/tokenization_enc_t5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | from typing import Any, Dict, List, Optional 7 | from transformers import T5Tokenizer 8 | 9 | 10 | class EncT5Tokenizer(T5Tokenizer): 11 | def __init__( 12 | self, 13 | vocab_file, 14 | bos_token="", 15 | eos_token="", 16 | unk_token="", 17 | pad_token="", 18 | extra_ids=100, 19 | additional_special_tokens=None, 20 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 21 | **kwargs, 22 | ) -> None: 23 | sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 24 | 25 | super().__init__( 26 | vocab_file=vocab_file, 27 | bos_token=bos_token, 28 | eos_token=eos_token, 29 | unk_token=unk_token, 30 | pad_token=pad_token, 31 | extra_ids=extra_ids, 32 | additional_special_tokens=additional_special_tokens, 33 | sp_model_kwargs=sp_model_kwargs, 34 | **kwargs, 35 | ) 36 | 37 | def get_special_tokens_mask( 38 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 39 | ) -> List[int]: 40 | """ 41 | Retrieve sequence ids from a token list that has no special tokens added. 42 | This will be called when adding special tokens using the tokenizer `prepare_for_model` method. 
43 | """ 44 | if already_has_special_tokens: 45 | return super().get_special_tokens_mask( 46 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 47 | ) 48 | 49 | # normal case: some special tokens 50 | if token_ids_1 is None: 51 | return [1] + ([0] * len(token_ids_0)) + [1] 52 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 53 | 54 | def create_token_type_ids_from_sequences( 55 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 56 | ) -> List[int]: 57 | bos = [self.bos_token_id] 58 | eos = [self.eos_token_id] 59 | 60 | if token_ids_1 is None: 61 | return len(bos + token_ids_0 + eos) * [0] 62 | return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0] 63 | 64 | def build_inputs_with_special_tokens( 65 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 66 | ) -> List[int]: 67 | """ 68 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 69 | adding special tokens. A sequence has the following format: 70 | - single sequence: ` X ` 71 | - pair of sequences: ` A B ` 72 | """ 73 | if token_ids_1 is None: 74 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] 75 | else: 76 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] 77 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repo 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: 9 | 10 | ## Issues 11 | We use GitHub issues to track public bugs. Please ensure your description is 12 | clear and has sufficient instructions to be able to reproduce the issue. 13 | 14 | ## License 15 | By contributing to this repo, you agree that your contributions will be licensed 16 | under the LICENSE file in the root directory of this source tree. 17 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Task-aware Retrieval with Instructions 2 | 3 | This is the official repository for our preprint, [Task-aware Retrieval with Instructions](https://arxiv.org/abs/2211.09260). 4 | 5 | We introduce a new retrieval task formulation, **retrieval with instructions**, constructs **BERRI**, the first large-scale collection of retrieval datasets with instructions, and present **TART**, multi-task instruction-following retrieval models trained on BERRI. 6 | 7 |

![Overview of TART: task-aware retrieval with instructions](figures/intro.png)
77 | 78 | ```sh 79 | python interactive.py \ 80 | --passages scifact/corpus.jsonl \ 81 | --passages_embeddings "PATH_TO_YOUR_EMBEDDINGS/*" \ 82 | --model_name_or_path PATH_TO_YOUR_BE_MODEL \ 83 | --ce_model PATH_TO_YOUR_CE_MODEL \ 84 | ``` 85 | 86 | 87 | ## Pre-trained checkpoints 88 | ### TART-full 89 | We release TART-full models trained on BERRI using different initial encoder weights. Our TART-full model on the paper is based on T0-3B. We will release TART-full using smaller model soon! 90 | 91 | Models are all on huggingface hub. 92 | 93 | | name | size | initialization | 94 | | ----------- | ----------- | ----------- | 95 | | [facebook/tart-full-flan-t0-3b](https://huggingface.co/facebook/tart-full-flan-t0-3b) | 1.5 billions |[T0-3B](https://huggingface.co/bigscience/T0_3B)| 96 | | [facebook/tart-full-flan-t5-xl](https://huggingface.co/facebook/tart-full-flan-t5-xl) | 1.5 billions | [FLANT5-XL](https://huggingface.co/google/flan-t5-xl)| 97 | 98 | ### TART-dual 99 | TART-dual is an efficient bi-encoder model sharing an encoder for document and query encodings. 100 | | name | size | initialization | 101 | | ----------- | ----------- | ----------- | 102 | | [facebook/tart-dual-contriever-msmarco](https://homes.cs.washington.edu/~akari/models/tart-dual-contriever-msmarco.zip) | 110 millions |[facebook/contriever-msmarco](https://huggingface.co/facebook/contriever-msmarco)| 103 | 104 | The main model on the paper uses [Contriever-MS MARCO](https://huggingface.co/facebook/contriever-msmarco) pre-trained on Wikipedia 2020 dump. 105 | 106 | ## Embeddings 107 | We release the pre-encoded embeddings for the BEIR datasets [here](https://drive.google.com/file/d/1qLqFLlvLpQ2OGHowO0Jwv_di8mJgh2be/view?usp=sharing): 108 | 109 | 110 | ## Evaluation 111 | ### BEIR 112 | You can evaluate the models on BEIR, by running `eval_beir.py` or `eval_cross_task.py`. 113 | 114 | `eval_beir.py` is adopted from the official BEIR repository, encodes and runs inference using a single GPU every time, while `eval_cross_task.py` assumes that you have encoded document embeddings and parallelize inference using multiple GPUs. If you have multiple GPUs or try to evaluate TART on datasets with millions of documents (e.g., Climate-FEVER), we recommend using `eval_cross_task.py` script. 115 | 116 | #### Run evaluation with `eval_beir.py` 117 | 118 | ```sh 119 | python eval_beir.py \ 120 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \ 121 | --dataset BEIR_DATASET_NAME \ 122 | --output_dir YOUR_OUTPUT_DIR 123 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \ 124 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \ 125 | --prompt "YOUR INSTRUCTIONS" 126 | ``` 127 | 128 | 129 | #### Run evaluation with `eval_cross_task.py` 130 | As mentioned above, there are two steps to run `eval_cross_task.py` script: **STEP1: encode all documents**, and **STEP2: run evaluations using encoded embeddings**. 
131 | 132 | ##### STEP1: Encode all of the document 133 | To encode document using a single GPU, please run the command below: 134 | 135 | ```sh 136 | python generate_passage_embeddings.py --model_name_or_path YOUR_MODEL_NAME --output_dir OUTPUT_DIR_NAME \ 137 | --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl --shard_id ${i} --num_shards 1 138 | ``` 139 | 140 | If you want to use multiple GPUs to speed up the process, you can run the following command: 141 | 142 | ```sh 143 | for i in {0..7}; do 144 | export CUDA_VISIBLE_DEVICES=${i} 145 | nohup python generate_passage_embeddings.py --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH --output_dir OUTPUT_DIR_NAME \ 146 | --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 & 147 | done 148 | ``` 149 | 150 | The corpus file is a `jsonlines` file, where each item contains `text` and `title`, and optional `_id` and `meta_data`. 151 | 152 | e.g., 153 | 154 | ``` 155 | {"_id": "doc9", "title": "Chicago Fire (season 4)", "text": "Hermann is rushed to Chicago Med after being stabbed at Molly's. After losing a lot a blood, it is determined he needs emergency surgery. Feeling guilty about Hermann's present state, Cruz searches for Freddy to turn him in. Severide is reinstated as Lieutenant while Borelli grows more concerned about Chili's erratic behavior. Mouch considers finally proposing to Platt.", "metadata": {}} 156 | ``` 157 | 158 | ##### STEP2: Run predictions 159 | 160 | Once you encode passages, you can run the evaluations as follows: 161 | ```sh 162 | python eval_cross_task.py \ 163 | --passages PATH_TO_YOUR_INPUT_DATA_DIR/corpus.jsonl \ 164 | --passages_embeddings "PATH_TO_YOUR_EMBEDDING_OUTPUT_DIR/passages_*" \ 165 | --qrels PATH_TO_YOUR_INPUT_DATA_DIR/qrels/test.csv \ 166 | --output_dir OUT_PUT_DIR_NAME \ 167 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \ 168 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \ 169 | --data PATH_TO_YOUR_INPUT_DATA_DIR/queries.jsonl \ 170 | --prompt "YOUR INSTRUCTIONS" 171 | ``` 172 | 173 | ### LOTTE 174 | We evaluate our model on LOTTE-search (pooled). To run the evaluations on LOTTE, you can download our processed data (the data itself is the same but we convert the input data file formats and add instructions) as follows: 175 | 176 | ```sh 177 | wget https://homes.cs.washington.edu/~akari/tart/processed_lotte_search_pooled.zip 178 | unzip processed_lotte_search_pooled.zip 179 | ``` 180 | 181 | Encode passages as in the previous section. 
182 | 183 | ```sh 184 | for i in {0..7}; do 185 | export CUDA_VISIBLE_DEVICES=${i} 186 | nohup python generate_passage_embeddings.py --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH --output_dir OUTPUT_DIR_NAME \ 187 | --passages processed_lotte_search_pooled/corpus.jsonl --shard_id ${i} --num_shards 8 > ./log/nohup.log.${i} 2>&1 & 188 | done 189 | ``` 190 | 191 | Once you encode the passages, you can run evaluations 192 | ```sh 193 | python eval_cross_task.py \ 194 | --passages processed_lotte_search_pooled/corpus.jsonl \ 195 | --passages_embeddings "contriever_lotte_corpus/passages_*" \ 196 | --qrels processed_lotte_search_pooled/qrels/test.tsv \ 197 | --output_dir OUT_PUT_DIR_NAME \ 198 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \ 199 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \ 200 | --data processed_lotte_search_pooled/queries_w_instructions_sep.jsonl \ 201 | --lotte 202 | ``` 203 | This code output the lotte's official evaluation script format data under `CROSS_ENCODER_MODEL_NAME_OR_PATH/` 204 | Then you can run the official evaluation script as follows: 205 | 206 | ```sh 207 | cp lotte 208 | python evaluate_lotte_rankings.py --k 5 --split test --data_path ../lotte --rankings_path PATH_TO_PREDICTION_FILE 209 | ``` 210 | 211 | ### Cross-task Cross-domain dataset 212 | In this paper, we newly introduce cross-task cross-domain evaluation, where given an instruction and a single large-scale domain, a system needs to retrieve documents that follow instructions. 213 | 214 | Due to legal reasons, Meta cannot host this data. The script to create cross-task cross-domain dataset is available at [cross_task_cross_domain](https://github.com/facebookresearch/tart/tree/main/cross_task_cross_eval), and you can also download the processed cross task dataset as follows. 215 | 216 | ```sh 217 | wget https://homes.cs.washington.edu/~akari/tart/cross_task_cross_domain_final.zip 218 | unzip https://homes.cs.washington.edu/~akari/tart/cross_task_cross_domain_final.zip 219 | ``` 220 | 221 | Due to the larger corpus, we highly recommend encoding every documents beforehand. 222 | Encoded documents are available at the [encoded documents](#embeddings) Section. 223 | 224 | Then you can run evaluations on the cross-task cross-domain data as follows: 225 | ```sh 226 | python eval_cross_task.py \ 227 | --passages ../cross_2_eval/nq/corpus.jsonl ../cross_tacross_2_evalsk_eval/scifact/corpus.jsonl ../cross_2_eval/gooaq_med/corpus.jsonl ../cross_2_eval/linkso_py/corpus.jsonl ../cross_2_eval/ambig/corpus.jsonl ../cross_2_eval/wikiqa/corpus.jsonl ../cross_2_eval/gooaq_technical/corpus.jsonl ../cross_2_eval/codesearch_py/corpus_new.jsonl \ 228 | --passages_embeddings "linkso_py_contriever_embeddings/passages_*" "ambig_contriever_embeddings/passages_*" "scifact_contriever_embeddings/*" "nq_contriever_embeddings/passages_*" "gooaq_technical_contriever_embeddings/passages_*" "codesearch_py_contriever_embeddings/passages_*" "wikiqa_contriever_embeddings/passages_*" \ 229 | --qrels ../cross_task_eval/linkso/qrels/test_new.tsv \ 230 | --output_dir YOUR_OUTPUT_DIR \ 231 | --model_name_or_path BI_ENCODER_MODEL_NAME_OR_PATH \ 232 | --ce_model CROSS_ENCODER_MODEL_NAME_OR_PATH \ 233 | ``` 234 | 235 | ## Training 236 | ### TART-full 237 | To train TART-full model, run the script below. We use 8 GBU for training. We found that training 3B cross-encoder too long can make the model overfit to the data, so we only train the models for one epoch and pick up the best model based on the development sore. 
235 | ## Training
236 | ### TART-full
237 | To train the TART-full model, run the script below. We use 8 GPUs for training. We found that training the 3B cross-encoder for too long can make the model overfit to the data, so we train the models for only one epoch and pick the best checkpoint based on the development score.
238 | 
239 | ```
240 | python finetuning_tart_full.py --task_name ce \
241 |     --train_data PATH_TO_TRAIN_DATA \
242 |     --eval_data PATH_TO_DEV_DATA \
243 |     --model_name_or_path [bigscience/T0_3B, google/flan-t5-xl] \
244 |     --output_dir PATH_TO_YOUR_OUTPUT_DIR \
245 |     --do_train --do_eval \
246 |     --overwrite_output_dir --evaluation_strategy steps \
247 |     --eval_steps 2000 --save_steps 2000 \
248 |     --metric_for_best_model accuracy \
249 |     --per_gpu_train_batch_size 1 --num_train_epochs 1 \
250 |     --gradient_accumulation_steps 8 --max_seq_length 512 \
251 |     --load_best_model_at_end
252 | ```
253 | 
254 | ### TART-dual
255 | To train the TART-dual model, run the script below. We use 64 GPUs for training.
256 | 
257 | ```
258 | cd TART
259 | python finetuning.py \
260 |     --model_path facebook/contriever-msmarco \
261 |     --train_data PATH_TO_TRAIN_DATA \
262 |     --eval_data PATH_TO_DEV_DATA \
263 |     --chunk_length 256 --negative_ctxs 5 --warmup_steps 1000 --total_steps 50000 \
264 |     --lr 0.00001 --scheduler linear --optim adamw --per_gpu_batch_size 16 --temperature 0.05 \
265 |     --per_gpu_eval_batch_size 16 --output_dir PATH_TO_OUTPUT_DIR \
266 |     --save_freq 5000 --eval_freq 5000 --negative_hard_ratio 0.1
267 | ```
268 | 
269 | 
270 | ## Dataset: BERRI
271 | 
272 | The [`BERRI`](BERRI) directory includes the scripts and instruction data to construct the BERRI dataset.
273 | 
274 | ### Instructions
275 | All of the annotated instructions are in [berri_instructions.tsv](BERRI/berri_instructions.tsv).
276 | 
277 | ### Preprocessing script and processed data
278 | BERRI data construction consists of (i) converting existing datasets into retrieval tasks, adding initial positive and randomly sampled negative passages, (ii) retrieving top paragraphs with an efficient bi-encoder (i.e., Contriever), and then (iii) denoising the top documents.
279 | 
280 | [`BERRI/README.md`](BERRI/README.md) provides the detailed instructions.
281 | 
282 | You can download the processed source data (from process (i)) as well as the final training data for TART-dual and TART-full, processed by a third party, here:
283 | - [source data (22 GB)](https://drive.google.com/file/d/1hzlN4cEFOZRkdVeCMq62NUxvMNTopB1o/view?usp=share_link)
284 | - [TART-full training data (1 GB)](https://drive.google.com/file/d/1oijzAb2gWKT54OgeE7_KB9VcHvA7UxpQ/view?usp=share_link)
285 | - [TART-dual training data (14 GB)](https://drive.google.com/file/d/1lMmD5lTxYWYf0z0ua0-GaGKz2qs2mG1r/view?usp=share_link)
286 | 
287 | 
288 | ## Citation and Contact
289 | If you find this repository helpful, please cite our paper.
290 | 
291 | ```
292 | @article{asai2022tart,
293 |   title={Task-aware Retrieval with Instructions},
294 |   author={Asai, Akari and Schick, Timo and Lewis, Patrick and Chen, Xilun and Izacard, Gautier and Riedel, Sebastian and Hajishirzi, Hannaneh and Yih, Wen-tau},
295 |   journal={arXiv preprint arXiv:2211.09260},
296 |   year={2022}
297 | }
298 | ```
299 | 
300 | If you have any questions about the paper, feel free to contact Akari Asai (akari[at]cs.washington.edu) or open an issue and mention @AkariAsai.
301 | 
302 | ### License
303 | See the [LICENSE](LICENSE) file for more details.
304 | 
--------------------------------------------------------------------------------
/TART/custom_metrics.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import logging 8 | from typing import List, Dict, Union, Tuple 9 | 10 | def mrr(qrels: Dict[str, Dict[str, int]], 11 | results: Dict[str, Dict[str, float]], 12 | k_values: List[int]) -> Tuple[Dict[str, float]]: 13 | 14 | MRR = {} 15 | 16 | for k in k_values: 17 | MRR[f"MRR@{k}"] = 0.0 18 | 19 | k_max, top_hits = max(k_values), {} 20 | logging.info("\n") 21 | 22 | for query_id, doc_scores in results.items(): 23 | top_hits[query_id] = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 24 | 25 | for query_id in top_hits: 26 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]) 27 | for k in k_values: 28 | for rank, hit in enumerate(top_hits[query_id][0:k]): 29 | if hit[0] in query_relevant_docs: 30 | MRR[f"MRR@{k}"] += 1.0 / (rank + 1) 31 | break 32 | 33 | for k in k_values: 34 | MRR[f"MRR@{k}"] = round(MRR[f"MRR@{k}"]/len(qrels), 5) 35 | logging.info("MRR@{}: {:.4f}".format(k, MRR[f"MRR@{k}"])) 36 | 37 | return MRR 38 | 39 | def recall_cap(qrels: Dict[str, Dict[str, int]], 40 | results: Dict[str, Dict[str, float]], 41 | k_values: List[int]) -> Tuple[Dict[str, float]]: 42 | 43 | capped_recall = {} 44 | 45 | for k in k_values: 46 | capped_recall[f"R_cap@{k}"] = 0.0 47 | 48 | k_max = max(k_values) 49 | logging.info("\n") 50 | 51 | for query_id, doc_scores in results.items(): 52 | top_hits = sorted(doc_scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 53 | query_relevant_docs = [doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0] 54 | for k in k_values: 55 | retrieved_docs = [row[0] for row in top_hits[0:k] if qrels[query_id].get(row[0], 0) > 0] 56 | denominator = min(len(query_relevant_docs), k) 57 | capped_recall[f"R_cap@{k}"] += (len(retrieved_docs) / denominator) 58 | 59 | for k in k_values: 60 | capped_recall[f"R_cap@{k}"] = round(capped_recall[f"R_cap@{k}"]/len(qrels), 5) 61 | logging.info("R_cap@{}: {:.4f}".format(k, capped_recall[f"R_cap@{k}"])) 62 | 63 | return capped_recall 64 | 65 | 66 | def hole(qrels: Dict[str, Dict[str, int]], 67 | results: Dict[str, Dict[str, float]], 68 | k_values: List[int]) -> Tuple[Dict[str, float]]: 69 | 70 | Hole = {} 71 | 72 | for k in k_values: 73 | Hole[f"Hole@{k}"] = 0.0 74 | 75 | annotated_corpus = set() 76 | for _, docs in qrels.items(): 77 | for doc_id, score in docs.items(): 78 | annotated_corpus.add(doc_id) 79 | 80 | k_max = max(k_values) 81 | logging.info("\n") 82 | 83 | for _, scores in results.items(): 84 | top_hits = sorted(scores.items(), key=lambda item: item[1], reverse=True)[0:k_max] 85 | for k in k_values: 86 | hole_docs = [row[0] for row in top_hits[0:k] if row[0] not in annotated_corpus] 87 | Hole[f"Hole@{k}"] += len(hole_docs) / k 88 | 89 | for k in k_values: 90 | Hole[f"Hole@{k}"] = round(Hole[f"Hole@{k}"]/len(qrels), 5) 91 | logging.info("Hole@{}: {:.4f}".format(k, Hole[f"Hole@{k}"])) 92 | 93 | return Hole 94 | 95 | def top_k_accuracy( 96 | qrels: Dict[str, Dict[str, int]], 97 | results: Dict[str, Dict[str, float]], 98 | k_values: List[int]) -> Tuple[Dict[str, float]]: 99 | 100 | top_k_acc = {} 101 | 102 | for k in k_values: 103 | top_k_acc[f"Accuracy@{k}"] = 0.0 104 | 105 | k_max, top_hits = max(k_values), {} 106 | logging.info("\n") 107 | 108 | for query_id, doc_scores in results.items(): 109 | top_hits[query_id] = [item[0] for item in sorted(doc_scores.items(), key=lambda item: item[1], 
reverse=True)[0:k_max]] 110 | 111 | for query_id in top_hits: 112 | query_relevant_docs = set([doc_id for doc_id in qrels[query_id] if qrels[query_id][doc_id] > 0]) 113 | for k in k_values: 114 | for relevant_doc_id in query_relevant_docs: 115 | if relevant_doc_id in top_hits[query_id][0:k]: 116 | top_k_acc[f"Accuracy@{k}"] += 1.0 117 | break 118 | 119 | for k in k_values: 120 | top_k_acc[f"Accuracy@{k}"] = round(top_k_acc[f"Accuracy@{k}"]/len(qrels), 5) 121 | logging.info("Accuracy@{}: {:.4f}".format(k, top_k_acc[f"Accuracy@{k}"])) 122 | 123 | return top_k_acc -------------------------------------------------------------------------------- /TART/eval_beir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import sys 8 | import argparse 9 | from turtle import update 10 | import torch 11 | import logging 12 | import json 13 | import numpy as np 14 | import os 15 | import copy 16 | 17 | import src.slurm 18 | import src.contriever 19 | import src.beir_utils 20 | import src.utils 21 | import src.dist_utils 22 | import src.contriever 23 | 24 | logger = logging.getLogger(__name__) 25 | 26 | 27 | def main(args): 28 | 29 | src.slurm.init_distributed_mode(args) 30 | src.slurm.init_signal_handler() 31 | 32 | os.makedirs(args.output_dir, exist_ok=True) 33 | 34 | logger = src.utils.init_logger(args) 35 | 36 | model, tokenizer, _ = src.contriever.load_retriever( 37 | args.model_name_or_path) 38 | if args.bi_encoder is True: 39 | if args.ckpt_path is not None: 40 | state_dict = torch.load(args.ckpt_path)["model"] 41 | query_encoder = copy.deepcopy(model) 42 | doc_encoder = copy.deepcopy(model) 43 | query_encoder_dic = {} 44 | doc_encoder_dic = {} 45 | for name, param in state_dict.items(): 46 | print(name) 47 | if "q_encoder" in name: 48 | if "encoder" in name: 49 | orig_name = name.replace( 50 | "q_encoder.encoder", "encoder") 51 | if "embeddings" in name: 52 | orig_name = name.replace( 53 | "q_encoder.embeddings", "embeddings") 54 | query_encoder_dic[orig_name] = param 55 | print(orig_name) 56 | if "p_encoder" in name: 57 | if "encoder" in name: 58 | orig_name = name.replace( 59 | "p_encoder.encoder", "encoder") 60 | if "embeddings" in name: 61 | orig_name = name.replace( 62 | "p_encoder.embeddings", "embeddings") 63 | print(orig_name) 64 | doc_encoder_dic[orig_name] = param 65 | 66 | # print(query_encoder_dic.keys()) 67 | # print(doc_encoder_dic.keys()) 68 | query_encoder.load_state_dict(query_encoder_dic) 69 | # doc_encoder.load_state_dict(doc_encoder_dic) 70 | 71 | query_encoder = query_encoder.cuda() 72 | query_encoder.eval() 73 | 74 | doc_encoder = doc_encoder.cuda() 75 | doc_encoder.eval() 76 | else: 77 | print("for biencoder model, you have to preload the fine-tuned checkpoints.") 78 | raise NotImplementedError() 79 | 80 | else: 81 | if args.ckpt_path is not None: 82 | print("loading model") 83 | state_dict = torch.load(args.ckpt_path)["model"] 84 | 85 | new_dict = {} 86 | # state_dict = {k.replace("encoder_q.", ""): v for k, v in state_dict.items() if "encoder_q." 
in k} 87 | print(state_dict.keys()) 88 | # print(dict(model.named_parameters()).keys()) 89 | 90 | for name, param in model.named_parameters(): 91 | if name in state_dict: 92 | new_dict[name] = state_dict[name] 93 | print("updated") 94 | print(name) 95 | else: 96 | new_dict[name] = param 97 | 98 | # print(new_dict.keys()) 99 | # print(model.keys()) 100 | # assert model.keys() == new_dict.keys() 101 | 102 | model.load_state_dict(new_dict, strict=False) 103 | 104 | model = model.cuda() 105 | model.eval() 106 | query_encoder = model 107 | doc_encoder = model 108 | 109 | logger.info("Start indexing") 110 | 111 | if args.multiple_prompts is not None: 112 | metrics = src.beir_utils.evaluate_model_multiple( 113 | query_encoder=query_encoder, 114 | doc_encoder=doc_encoder, 115 | tokenizer=tokenizer, 116 | dataset=args.dataset, 117 | batch_size=args.per_gpu_batch_size, 118 | norm_query=args.norm_query, 119 | norm_doc=args.norm_doc, 120 | is_main=src.dist_utils.is_main(), 121 | split="dev" if args.dataset == "msmarco" else "test", 122 | score_function=args.score_function, 123 | beir_dir=args.beir_dir, 124 | save_results_path=args.save_results_path, 125 | lower_case=args.lower_case, 126 | normalize_text=args.normalize_text, 127 | prompt=args.prompt, 128 | multiple_prompts=args.multiple_prompts 129 | ) 130 | else: 131 | metrics = src.beir_utils.evaluate_model( 132 | query_encoder=query_encoder, 133 | doc_encoder=doc_encoder, 134 | tokenizer=tokenizer, 135 | dataset=args.dataset, 136 | batch_size=args.per_gpu_batch_size, 137 | norm_query=args.norm_query, 138 | norm_doc=args.norm_doc, 139 | is_main=src.dist_utils.is_main(), 140 | split="dev" if args.dataset == "msmarco" else "test", 141 | score_function=args.score_function, 142 | beir_dir=args.beir_dir, 143 | save_results_path=args.save_results_path, 144 | lower_case=args.lower_case, 145 | normalize_text=args.normalize_text, 146 | prompt=args.prompt, 147 | emb_load_path=args.emb_load_path, 148 | emb_save_path=args.emb_save_path 149 | ) 150 | 151 | if src.dist_utils.is_main(): 152 | for key, value in metrics.items(): 153 | logger.info(f"{args.dataset} : {key}: {value:.1f}") 154 | 155 | print("saving results") 156 | if os.path.exists(os.path.join(args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id))) is True: 157 | results_log = json.load(open(os.path.join( 158 | args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id)))) 159 | else: 160 | results_log = {} 161 | results_log.setdefault(args.prompt, {}) 162 | results_log[args.prompt] = metrics 163 | 164 | with open(os.path.join(args.output_dir, "{0}_{1}_results.json".format(args.dataset, args.model_id)), "w") as outfile: 165 | json.dump(results_log, outfile) 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser( 170 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 171 | 172 | parser.add_argument("--dataset", type=str, 173 | help="Evaluation dataset from the BEIR benchmark") 174 | parser.add_argument("--beir_dir", type=str, default="./", 175 | help="Directory to save and load beir datasets") 176 | parser.add_argument("--text_maxlength", type=int, 177 | default=512, help="Maximum text length") 178 | parser.add_argument("--emb_load_path", type=str, default=None, 179 | help="path to load already computed embeddings.", nargs="+") 180 | parser.add_argument("--emb_save_path", type=str, default=None, 181 | help="path to save already computed embeddings.") 182 | 183 | parser.add_argument("--per_gpu_batch_size", default=128, 184 | 
type=int, help="Batch size per GPU/CPU for indexing.") 185 | parser.add_argument("--output_dir", type=str, 186 | default="./my_experiment", help="Output directory") 187 | parser.add_argument("--model_name_or_path", type=str, 188 | help="Model name or path") 189 | parser.add_argument( 190 | "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings" 191 | ) 192 | parser.add_argument("--norm_query", action="store_true", 193 | help="Normalize query representation") 194 | parser.add_argument("--norm_doc", action="store_true", 195 | help="Normalize document representation") 196 | parser.add_argument("--lower_case", action="store_true", 197 | help="lowercase query and document text") 198 | parser.add_argument( 199 | "--normalize_text", action="store_true", help="Apply function to normalize some common characters" 200 | ) 201 | parser.add_argument("--save_results_path", type=str, 202 | default=None, help="Path to save result object") 203 | 204 | parser.add_argument("--local_rank", type=int, default=-1, 205 | help="For distributed training: local_rank") 206 | parser.add_argument("--main_port", type=int, default=-1, 207 | help="Main port (for multi-node SLURM jobs)") 208 | parser.add_argument("--ckpt_path", type=str, help="Model name or path") 209 | parser.add_argument("--bi_encoder", action="store_true", ) 210 | parser.add_argument( 211 | "--prompt", type=str, default=None, help="instructional prompt." 212 | ) 213 | parser.add_argument( 214 | "--multiple_prompts", type=str, nargs='+' 215 | ) 216 | parser.add_argument( 217 | "--model_id", type=str, default=None, help="for logging" 218 | ) 219 | args, _ = parser.parse_known_args() 220 | main(args) 221 | -------------------------------------------------------------------------------- /TART/generate_passage_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | 9 | import argparse 10 | import csv 11 | import logging 12 | import pickle 13 | 14 | import numpy as np 15 | import torch 16 | 17 | import transformers 18 | 19 | import src.slurm 20 | import src.contriever 21 | import src.utils 22 | import src.data 23 | import src.normalize_text 24 | 25 | 26 | def embed_passages(args, passages, model, tokenizer): 27 | total = 0 28 | allids, allembeddings = [], [] 29 | batch_ids, batch_text = [], [] 30 | with torch.no_grad(): 31 | for k, p in enumerate(passages): 32 | batch_ids.append(p["id"] if "id" in p else p["_id"]) 33 | if args.no_title or not "title" in p: 34 | text = p["text"] 35 | else: 36 | if type(p["title"]) is not str: 37 | print("title incorrect") 38 | print(p) 39 | if type(p["text"]) is not str: 40 | print("text incorrect") 41 | print(p) 42 | text = p["title"] + " " + p["text"] 43 | if args.lowercase: 44 | text = text.lower() 45 | if args.normalize_text: 46 | text = src.normalize_text.normalize(text) 47 | batch_text.append(text) 48 | 49 | if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1: 50 | 51 | encoded_batch = tokenizer.batch_encode_plus( 52 | batch_text, 53 | return_tensors="pt", 54 | max_length=args.passage_maxlength, 55 | padding=True, 56 | truncation=True, 57 | ) 58 | 59 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 60 | embeddings = model(**encoded_batch) 61 | 62 | embeddings = embeddings.cpu() 63 | total += len(batch_ids) 64 | allids.extend(batch_ids) 65 | allembeddings.append(embeddings) 66 | 67 | batch_text = [] 68 | batch_ids = [] 69 | if k % 100000 == 0 and k > 0: 70 | print(f"Encoded passages {total}") 71 | 72 | allembeddings = torch.cat(allembeddings, dim=0).numpy() 73 | return allids, allembeddings 74 | 75 | 76 | def main(args): 77 | model, tokenizer, _ = src.contriever.load_retriever( 78 | args.model_name_or_path) 79 | print(f"Model loaded from {args.model_name_or_path}.", flush=True) 80 | model.eval() 81 | model = model.cuda() 82 | if not args.no_fp16: 83 | model = model.half() 84 | 85 | passages = src.data.load_passages(args.passages) 86 | 87 | shard_size = len(passages) // args.num_shards 88 | start_idx = args.shard_id * shard_size 89 | end_idx = start_idx + shard_size 90 | if args.shard_id == args.num_shards - 1: 91 | end_idx = len(passages) 92 | 93 | passages = passages[start_idx:end_idx] 94 | print( 95 | f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.") 96 | 97 | allids, allembeddings = embed_passages(args, passages, model, tokenizer) 98 | 99 | save_file = os.path.join( 100 | args.output_dir, args.prefix + f"_{args.shard_id:02d}") 101 | os.makedirs(args.output_dir, exist_ok=True) 102 | print(f"Saving {len(allids)} passage embeddings to {save_file}.") 103 | with open(save_file, mode="wb") as f: 104 | pickle.dump((allids, allembeddings), f) 105 | 106 | print(f"Total passages processed {len(allids)}. 
Written to {save_file}.") 107 | 108 | 109 | if __name__ == "__main__": 110 | parser = argparse.ArgumentParser() 111 | 112 | parser.add_argument("--passages", type=str, default=None, 113 | help="Path to passages (.tsv file)") 114 | parser.add_argument("--output_dir", type=str, 115 | default="wikipedia_embeddings", help="dir path to save embeddings") 116 | parser.add_argument("--prefix", type=str, default="passages", 117 | help="prefix path to save embeddings") 118 | parser.add_argument("--shard_id", type=int, default=0, 119 | help="Id of the current shard") 120 | parser.add_argument("--num_shards", type=int, default=1, 121 | help="Total number of shards") 122 | parser.add_argument( 123 | "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass" 124 | ) 125 | parser.add_argument("--passage_maxlength", type=int, 126 | default=512, help="Maximum number of tokens in a passage") 127 | parser.add_argument( 128 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 129 | ) 130 | parser.add_argument("--no_fp16", action="store_true", 131 | help="inference in fp32") 132 | parser.add_argument("--no_title", action="store_true", 133 | help="title not added to the passage body") 134 | parser.add_argument("--lowercase", action="store_true", 135 | help="lowercase text before encoding") 136 | parser.add_argument("--normalize_text", action="store_true", 137 | help="lowercase text before encoding") 138 | 139 | args = parser.parse_args() 140 | 141 | src.slurm.init_distributed_mode(args) 142 | 143 | main(args) 144 | -------------------------------------------------------------------------------- /TART/passage_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import argparse 9 | import csv 10 | import json 11 | import logging 12 | import pickle 13 | import time 14 | import glob 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | import torch 19 | import transformers 20 | 21 | import src.index 22 | import src.contriever 23 | import src.utils 24 | import src.slurm 25 | import src.data 26 | from src.evaluation import calculate_matches 27 | import src.normalize_text 28 | 29 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 30 | 31 | 32 | def embed_queries(args, queries, model, tokenizer): 33 | model.eval() 34 | embeddings, batch_question = [], [] 35 | with torch.no_grad(): 36 | 37 | for k, q in enumerate(queries): 38 | if args.lowercase: 39 | q = q.lower() 40 | if args.normalize_text: 41 | q = src.normalize_text.normalize(q) 42 | batch_question.append(q) 43 | 44 | if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1: 45 | 46 | encoded_batch = tokenizer.batch_encode_plus( 47 | batch_question, 48 | return_tensors="pt", 49 | max_length=args.question_maxlength, 50 | padding=True, 51 | truncation=True, 52 | ) 53 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 54 | output = model(**encoded_batch) 55 | embeddings.append(output.cpu()) 56 | 57 | batch_question = [] 58 | 59 | embeddings = torch.cat(embeddings, dim=0) 60 | print(f"Questions embeddings shape: {embeddings.size()}") 61 | 62 | return embeddings.numpy() 63 | 64 | 65 | def index_encoded_data(index, embedding_files, indexing_batch_size): 66 | allids = [] 67 | allembeddings = np.array([]) 68 | for i, file_path in enumerate(embedding_files): 69 | print(f"Loading file {file_path}") 70 | with open(file_path, "rb") as fin: 71 | ids, embeddings = pickle.load(fin) 72 | 73 | allembeddings = np.vstack( 74 | (allembeddings, embeddings)) if allembeddings.size else embeddings 75 | allids.extend(ids) 76 | while allembeddings.shape[0] > indexing_batch_size: 77 | allembeddings, allids = add_embeddings( 78 | index, allembeddings, allids, indexing_batch_size) 79 | 80 | while allembeddings.shape[0] > 0: 81 | allembeddings, allids = add_embeddings( 82 | index, allembeddings, allids, indexing_batch_size) 83 | 84 | print("Data indexing completed.") 85 | 86 | 87 | def add_embeddings(index, embeddings, ids, indexing_batch_size): 88 | end_idx = min(indexing_batch_size, embeddings.shape[0]) 89 | ids_toadd = ids[:end_idx] 90 | embeddings_toadd = embeddings[:end_idx] 91 | ids = ids[end_idx:] 92 | embeddings = embeddings[end_idx:] 93 | index.index_data(ids_toadd, embeddings_toadd) 94 | return embeddings, ids 95 | 96 | 97 | def validate(data, workers_num): 98 | match_stats = calculate_matches(data, workers_num) 99 | top_k_hits = match_stats.top_k_hits 100 | 101 | print("Validation results: top k documents hits %s", top_k_hits) 102 | top_k_hits = [v / len(data) for v in top_k_hits] 103 | message = "" 104 | for k in [5, 10, 20, 100]: 105 | if k <= len(top_k_hits): 106 | message += f"R@{k}: {top_k_hits[k-1]} " 107 | print(message) 108 | return match_stats.questions_doc_hits 109 | 110 | 111 | def add_passages(data, passages, top_passages_and_scores): 112 | # add passages to original data 113 | merged_data = [] 114 | assert len(data) == len(top_passages_and_scores) 115 | for i, d in enumerate(data): 116 | results_and_scores = top_passages_and_scores[i] 117 | docs = [passages[doc_id] for doc_id in results_and_scores[0]] 118 | scores = [str(score) for score in results_and_scores[1]] 119 | ctxs_num = len(docs) 120 | d["ctxs"] = [ 121 | { 122 | "id": 
results_and_scores[0][c], 123 | "title": docs[c]["title"] if "title" in docs[c] else "", 124 | "text": docs[c]["text"], 125 | "score": scores[c], 126 | } 127 | for c in range(ctxs_num) 128 | ] 129 | 130 | 131 | def add_hasanswer(data, hasanswer): 132 | # add hasanswer to data 133 | for i, ex in enumerate(data): 134 | for k, d in enumerate(ex["ctxs"]): 135 | d["hasanswer"] = hasanswer[i][k] 136 | 137 | # fix me 138 | 139 | 140 | def load_data(data_path): 141 | if data_path.endswith(".json"): 142 | with open(data_path, "r") as fin: 143 | data = json.load(fin) 144 | elif data_path.endswith(".jsonl"): 145 | data = [] 146 | with open(data_path, "r") as fin: 147 | for k, example in enumerate(fin): 148 | example = json.loads(example) 149 | data.append(example) 150 | return data 151 | 152 | 153 | def main(args): 154 | 155 | print(f"Loading model from: {args.model_name_or_path}") 156 | model, tokenizer, _ = src.contriever.load_retriever( 157 | args.model_name_or_path) 158 | model.eval() 159 | model = model.cuda() 160 | if not args.no_fp16: 161 | model = model.half() 162 | 163 | index = src.index.Indexer(args.projection_size, 164 | args.n_subquantizers, args.n_bits) 165 | 166 | # index all passages 167 | input_paths = glob.glob(args.passages_embeddings) 168 | input_paths = sorted(input_paths) 169 | embeddings_dir = os.path.dirname(input_paths[0]) 170 | index_path = os.path.join(embeddings_dir, "index.faiss") 171 | if args.save_or_load_index and os.path.exists(index_path): 172 | index.deserialize_from(embeddings_dir) 173 | else: 174 | print(f"Indexing passages from files {input_paths}") 175 | start_time_indexing = time.time() 176 | index_encoded_data(index, input_paths, args.indexing_batch_size) 177 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.") 178 | if args.save_or_load_index: 179 | index.serialize(embeddings_dir) 180 | 181 | # load passages 182 | passages = src.data.load_passages(args.passages) 183 | passage_id_map = {x["id"]: x for x in passages} 184 | 185 | data_paths = glob.glob(args.data) 186 | alldata = [] 187 | for path in data_paths: 188 | data = load_data(path) 189 | output_path = os.path.join(args.output_dir, os.path.basename(path)) 190 | 191 | queries = [ex["question"] for ex in data] 192 | questions_embedding = embed_queries(args, queries, model, tokenizer) 193 | 194 | # get top k results 195 | start_time_retrieval = time.time() 196 | top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs) 197 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.") 198 | 199 | add_passages(data, passage_id_map, top_ids_and_scores) 200 | hasanswer = validate(data, args.validation_workers) 201 | add_hasanswer(data, hasanswer) 202 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 203 | with open(output_path, "w") as fout: 204 | for ex in data: 205 | json.dump(ex, fout, ensure_ascii=False) 206 | fout.write("\n") 207 | print(f"Saved results to {output_path}") 208 | 209 | 210 | if __name__ == "__main__": 211 | parser = argparse.ArgumentParser() 212 | 213 | parser.add_argument( 214 | "--data", 215 | required=True, 216 | type=str, 217 | default=None, 218 | help=".json file containing question and answers, similar format to reader data", 219 | ) 220 | parser.add_argument("--passages", type=str, default=None, 221 | help="Path to passages (.tsv file)") 222 | parser.add_argument("--passages_embeddings", type=str, 223 | default=None, help="Glob path to encoded passages") 224 | parser.add_argument( 225 | "--output_dir", type=str, default=None, help="Results are 
written to outputdir with data suffix" 226 | ) 227 | parser.add_argument("--n_docs", type=int, default=100, 228 | help="Number of documents to retrieve per questions") 229 | parser.add_argument( 230 | "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results" 231 | ) 232 | parser.add_argument("--per_gpu_batch_size", type=int, 233 | default=64, help="Batch size for question encoding") 234 | parser.add_argument( 235 | "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists" 236 | ) 237 | parser.add_argument( 238 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 239 | ) 240 | parser.add_argument("--no_fp16", action="store_true", 241 | help="inference in fp32") 242 | parser.add_argument("--question_maxlength", type=int, 243 | default=512, help="Maximum number of tokens in a question") 244 | parser.add_argument( 245 | "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed" 246 | ) 247 | parser.add_argument("--projection_size", type=int, default=768) 248 | parser.add_argument( 249 | "--n_subquantizers", 250 | type=int, 251 | default=0, 252 | help="Number of subquantizer used for vector quantization, if 0 flat index is used", 253 | ) 254 | parser.add_argument("--n_bits", type=int, default=8, 255 | help="Number of bits per subquantizer") 256 | parser.add_argument("--lang", nargs="+") 257 | parser.add_argument("--dataset", type=str, default="none") 258 | parser.add_argument("--lowercase", action="store_true", 259 | help="lowercase text before encoding") 260 | parser.add_argument("--normalize_text", 261 | action="store_true", help="normalize text") 262 | 263 | args = parser.parse_args() 264 | src.slurm.init_distributed_mode(args) 265 | main(args) 266 | -------------------------------------------------------------------------------- /TART/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import argparse 9 | import torch 10 | 11 | import transformers 12 | from src.normalize_text import normalize 13 | 14 | 15 | def save(tensor, split_path): 16 | if not os.path.exists(os.path.dirname(split_path)): 17 | os.makedirs(os.path.dirname(split_path)) 18 | with open(split_path, 'wb') as fout: 19 | torch.save(tensor, fout) 20 | 21 | 22 | def apply_tokenizer(path, tokenizer, normalize_text=False): 23 | alltokens = [] 24 | lines = [] 25 | with open(path, "r", encoding="utf-8") as fin: 26 | for k, line in enumerate(fin): 27 | if normalize_text: 28 | line = normalize(line) 29 | 30 | lines.append(line) 31 | if len(lines) > 1000000: 32 | tokens = tokenizer.batch_encode_plus( 33 | lines, add_special_tokens=False)['input_ids'] 34 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens] 35 | alltokens.extend(tokens) 36 | lines = [] 37 | 38 | tokens = tokenizer.batch_encode_plus( 39 | lines, add_special_tokens=False)['input_ids'] 40 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens] 41 | alltokens.extend(tokens) 42 | 43 | alltokens = torch.cat(alltokens) 44 | return alltokens 45 | 46 | 47 | def tokenize_file(args): 48 | filename = os.path.basename(args.datapath) 49 | savepath = os.path.join(args.outdir, f"{filename}.pkl") 50 | if os.path.exists(savepath): 51 | if args.overwrite: 52 | print(f"File {savepath} already exists, overwriting") 53 | else: 54 | print(f"File {savepath} already exists, exiting") 55 | return 56 | try: 57 | tokenizer = transformers.AutoTokenizer.from_pretrained( 58 | args.tokenizer, local_files_only=True) 59 | except: 60 | tokenizer = transformers.AutoTokenizer.from_pretrained( 61 | args.tokenizer, local_files_only=False) 62 | print(f"Encoding {args.datapath}...") 63 | tokens = apply_tokenizer(args.datapath, tokenizer, 64 | normalize_text=args.normalize_text) 65 | 66 | print(f"Saving at {savepath}...") 67 | save(tokens, savepath) 68 | 69 | 70 | if __name__ == '__main__': 71 | parser = argparse.ArgumentParser( 72 | formatter_class=argparse.ArgumentDefaultsHelpFormatter) 73 | parser.add_argument("--datapath", type=str) 74 | parser.add_argument("--outdir", type=str) 75 | parser.add_argument("--tokenizer", type=str) 76 | parser.add_argument("--overwrite", action="store_true") 77 | parser.add_argument("--normalize_text", action="store_true") 78 | 79 | args, _ = parser.parse_known_args() 80 | tokenize_file(args) 81 | -------------------------------------------------------------------------------- /TART/requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | transformers==4.18.0 3 | beir==1.0.0 4 | pytrec_eval 5 | jsonlines 6 | gdown -------------------------------------------------------------------------------- /TART/src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__init__.py -------------------------------------------------------------------------------- /TART/src/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/contriever.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/contriever.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/data.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/data.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/dist_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/dist_utils.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/evaluation.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/evaluation.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/index.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/index.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/modeling_enc_t5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/modeling_enc_t5.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/normalize_text.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/normalize_text.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/slurm.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/slurm.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/tokenization_enc_t5.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/tokenization_enc_t5.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/TART/src/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /TART/src/contriever.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | from string import printable 5 | import torch 6 | import transformers 7 | from transformers import BertModel, XLMRobertaModel, AlbertModel, T5EncoderModel 8 | 9 | from src import utils 10 | 11 | 12 | class Contriever(BertModel): 13 | def __init__(self, config, pooling="average", **kwargs): 14 | super().__init__(config, add_pooling_layer=False) 15 | if not hasattr(config, "pooling"): 16 | self.config.pooling = pooling 17 | 18 | def forward( 19 | self, 20 | input_ids=None, 21 | attention_mask=None, 22 | token_type_ids=None, 23 | position_ids=None, 24 | head_mask=None, 25 | inputs_embeds=None, 26 | encoder_hidden_states=None, 27 | encoder_attention_mask=None, 28 | output_attentions=None, 29 | output_hidden_states=None, 30 | normalize=False, 31 | ): 32 | 33 | model_output = super().forward( 34 | input_ids=input_ids, 35 | attention_mask=attention_mask, 36 | token_type_ids=token_type_ids, 37 | position_ids=position_ids, 38 | head_mask=head_mask, 39 | inputs_embeds=inputs_embeds, 40 | encoder_hidden_states=encoder_hidden_states, 41 | encoder_attention_mask=encoder_attention_mask, 42 | output_attentions=output_attentions, 43 | output_hidden_states=output_hidden_states, 44 | ) 45 | 46 | last_hidden = model_output["last_hidden_state"] 47 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 48 | 49 | if self.config.pooling == "average": 50 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 51 | elif self.config.pooling == "cls": 52 | emb = last_hidden[:, 0] 53 | 54 | if normalize: 55 | emb = torch.nn.functional.normalize(emb, dim=-1) 56 | return emb 57 | 58 | 59 | class XLMRetriever(XLMRobertaModel): 60 | def __init__(self, config, pooling="average", **kwargs): 61 | super().__init__(config, add_pooling_layer=False) 62 | if not hasattr(config, "pooling"): 63 | self.config.pooling = pooling 64 | 65 | def forward( 66 | self, 67 | input_ids=None, 68 | attention_mask=None, 69 | token_type_ids=None, 70 | position_ids=None, 71 | head_mask=None, 72 | inputs_embeds=None, 73 | encoder_hidden_states=None, 74 | encoder_attention_mask=None, 75 | output_attentions=None, 76 | output_hidden_states=None, 77 | normalize=False, 78 | ): 79 | 80 | model_output = super().forward( 81 | input_ids=input_ids, 82 | attention_mask=attention_mask, 83 | token_type_ids=token_type_ids, 84 | position_ids=position_ids, 85 | head_mask=head_mask, 86 | inputs_embeds=inputs_embeds, 87 | encoder_hidden_states=encoder_hidden_states, 88 | encoder_attention_mask=encoder_attention_mask, 89 | output_attentions=output_attentions, 90 | output_hidden_states=output_hidden_states, 91 | ) 92 | 93 | last_hidden = model_output["last_hidden_state"] 94 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 95 | if self.config.pooling == "average": 96 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 97 | elif self.config.pooling == "cls": 98 | emb = last_hidden[:, 0] 99 | if normalize: 100 | emb = torch.nn.functional.normalize(emb, dim=-1) 101 | return emb 102 | 103 | 104 | class ALBERTRetriever(AlbertModel): 105 | def __init__(self, config, pooling="average", **kwargs): 106 | super().__init__(config, add_pooling_layer=False) 107 | if not hasattr(config, "pooling"): 108 | self.config.pooling = pooling 109 | 110 | def forward( 111 | self, 112 | input_ids=None, 113 | attention_mask=None, 114 
| token_type_ids=None, 115 | position_ids=None, 116 | head_mask=None, 117 | inputs_embeds=None, 118 | encoder_hidden_states=None, 119 | encoder_attention_mask=None, 120 | output_attentions=None, 121 | output_hidden_states=None, 122 | normalize=False, 123 | ): 124 | 125 | model_output = super().forward( 126 | input_ids=input_ids, 127 | attention_mask=attention_mask, 128 | token_type_ids=token_type_ids, 129 | position_ids=position_ids, 130 | head_mask=head_mask, 131 | inputs_embeds=inputs_embeds, 132 | output_attentions=output_attentions, 133 | output_hidden_states=output_hidden_states, 134 | ) 135 | 136 | last_hidden = model_output["last_hidden_state"] 137 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 138 | if self.config.pooling == "average": 139 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 140 | elif self.config.pooling == "cls": 141 | emb = last_hidden[:, 0] 142 | if normalize: 143 | emb = torch.nn.functional.normalize(emb, dim=-1) 144 | return emb 145 | 146 | class T5Contriever(T5EncoderModel): 147 | def __init__(self, config, pooling="average", **kwargs): 148 | super().__init__(config) 149 | if not hasattr(config, "pooling"): 150 | self.config.pooling = pooling 151 | 152 | def forward( 153 | self, 154 | input_ids=None, 155 | attention_mask=None, 156 | token_type_ids=None, 157 | position_ids=None, 158 | head_mask=None, 159 | inputs_embeds=None, 160 | encoder_hidden_states=None, 161 | encoder_attention_mask=None, 162 | output_attentions=None, 163 | output_hidden_states=None, 164 | normalize=False, 165 | ): 166 | 167 | model_output = super().forward( 168 | input_ids=input_ids, 169 | attention_mask=attention_mask, 170 | head_mask=head_mask, 171 | inputs_embeds=inputs_embeds, 172 | output_attentions=output_attentions, 173 | output_hidden_states=output_hidden_states, 174 | ) 175 | 176 | last_hidden = model_output["last_hidden_state"] 177 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 178 | 179 | if self.config.pooling == "average": 180 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 181 | elif self.config.pooling == "cls": 182 | emb = last_hidden[:, 0] 183 | 184 | if normalize: 185 | emb = torch.nn.functional.normalize(emb, dim=-1) 186 | return emb 187 | 188 | def load_retriever(model_path, pooling="average", random_init=False): 189 | # try: check if model exists locally 190 | path = os.path.join(model_path, "checkpoint.pth") 191 | if os.path.exists(path): 192 | pretrained_dict = torch.load(path, map_location="cpu") 193 | opt = pretrained_dict["opt"] 194 | if hasattr(opt, "retriever_model_id"): 195 | retriever_model_id = opt.retriever_model_id 196 | else: 197 | # retriever_model_id = "bert-base-uncased" 198 | retriever_model_id = "bert-base-multilingual-cased" 199 | tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id) 200 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id) 201 | if "xlm" in retriever_model_id: 202 | model_class = XLMRetriever 203 | elif "albert" in retriever_model_id: 204 | print("Albert Contriever") 205 | model_class = ALBERTRetriever 206 | elif "t5" in retriever_model_id or "T0" in retriever_model_id or "gtr" in retriever_model_id: 207 | model_class = T5Contriever 208 | else: 209 | model_class = Contriever 210 | retriever = model_class(cfg) 211 | pretrained_dict = pretrained_dict["model"] 212 | 213 | if any("encoder_q." 
in key for key in pretrained_dict.keys()): # test if model is defined with moco class 214 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k} 215 | # elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class 216 | # pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k} 217 | retriever.load_state_dict(pretrained_dict) 218 | else: 219 | retriever_model_id = model_path 220 | if "xlm" in retriever_model_id: 221 | model_class = XLMRetriever 222 | elif "albert" in retriever_model_id: 223 | model_class = ALBERTRetriever 224 | elif "t5" in retriever_model_id or "T0" in retriever_model_id or "gtr" in retriever_model_id: 225 | model_class = T5Contriever 226 | else: 227 | model_class = Contriever 228 | cfg = utils.load_hf(transformers.AutoConfig, model_path) 229 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path) 230 | retriever = utils.load_hf(model_class, model_path) 231 | 232 | return retriever, tokenizer, retriever_model_id 233 | -------------------------------------------------------------------------------- /TART/src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import glob 5 | import torch 6 | import random 7 | import json 8 | import csv 9 | import numpy as np 10 | import numpy.random 11 | import logging 12 | from collections import defaultdict 13 | import torch.distributed as dist 14 | import io 15 | 16 | from src import dist_utils 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def load_data(opt, tokenizer): 22 | datasets = {} 23 | for path in opt.train_data: 24 | data = load_dataset(path, opt.loading_mode) 25 | if data is not None: 26 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt) 27 | dataset = MultiDataset(datasets) 28 | dataset.set_prob(coeff=opt.sampling_coefficient) 29 | return dataset 30 | 31 | 32 | def load_dataset(data_path, loading_mode): 33 | files = glob.glob(os.path.join(data_path, "*.p*")) 34 | files.sort() 35 | tensors = [] 36 | if loading_mode == "split": 37 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()] 38 | for filepath in files_split: 39 | try: 40 | tensors.append(torch.load(filepath, map_location="cpu")) 41 | except: 42 | logger.warning(f"Unable to load file {filepath}") 43 | elif loading_mode == "full": 44 | for fin in files: 45 | tensors.append(torch.load(fin, map_location="cpu")) 46 | elif loading_mode == "single": 47 | tensors.append(torch.load(files[0], map_location="cpu")) 48 | if len(tensors) == 0: 49 | return None 50 | tensor = torch.cat(tensors) 51 | return tensor 52 | 53 | 54 | class MultiDataset(torch.utils.data.Dataset): 55 | def __init__(self, datasets): 56 | 57 | self.datasets = datasets 58 | self.prob = [1 / len(self.datasets) for _ in self.datasets] 59 | self.dataset_ids = list(self.datasets.keys()) 60 | 61 | def __len__(self): 62 | return sum([len(dataset) for dataset in self.datasets.values()]) 63 | 64 | def __getitem__(self, index): 65 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0] 66 | did = self.dataset_ids[dataset_idx] 67 | index = random.randint(0, len(self.datasets[did]) - 1) 68 | sample = self.datasets[did][index] 69 | sample["dataset_id"] = did 70 | return sample 71 | 72 | def generate_offset(self): 73 | for dataset in 
self.datasets.values(): 74 | dataset.generate_offset() 75 | 76 | def set_prob(self, coeff=0.0): 77 | 78 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()]) 79 | prob /= prob.sum() 80 | prob = np.array([p**coeff for p in prob]) 81 | prob /= prob.sum() 82 | self.prob = prob 83 | 84 | 85 | class Dataset(torch.utils.data.Dataset): 86 | """Monolingual dataset based on a list of paths""" 87 | 88 | def __init__(self, data, chunk_length, tokenizer, opt): 89 | 90 | self.data = data 91 | self.chunk_length = chunk_length 92 | self.tokenizer = tokenizer 93 | self.opt = opt 94 | self.generate_offset() 95 | 96 | def __len__(self): 97 | return (self.data.size(0) - self.offset) // self.chunk_length 98 | 99 | def __getitem__(self, index): 100 | start_idx = self.offset + index * self.chunk_length 101 | end_idx = start_idx + self.chunk_length 102 | tokens = self.data[start_idx:end_idx] 103 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 104 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 105 | q_tokens = apply_augmentation(q_tokens, self.opt) 106 | q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 107 | k_tokens = apply_augmentation(k_tokens, self.opt) 108 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 109 | 110 | return {"q_tokens": q_tokens, "k_tokens": k_tokens} 111 | 112 | def generate_offset(self): 113 | self.offset = random.randint(0, self.chunk_length - 1) 114 | 115 | 116 | class Collator(object): 117 | def __init__(self, opt): 118 | self.opt = opt 119 | 120 | def __call__(self, batch_examples): 121 | 122 | batch = defaultdict(list) 123 | for example in batch_examples: 124 | for k, v in example.items(): 125 | batch[k].append(v) 126 | 127 | q_tokens, q_mask = build_mask(batch["q_tokens"]) 128 | k_tokens, k_mask = build_mask(batch["k_tokens"]) 129 | 130 | batch["q_tokens"] = q_tokens 131 | batch["q_mask"] = q_mask 132 | batch["k_tokens"] = k_tokens 133 | batch["k_mask"] = k_mask 134 | 135 | return batch 136 | 137 | 138 | def randomcrop(x, ratio_min, ratio_max): 139 | 140 | ratio = random.uniform(ratio_min, ratio_max) 141 | length = int(len(x) * ratio) 142 | start = random.randint(0, len(x) - length) 143 | end = start + length 144 | crop = x[start:end].clone() 145 | return crop 146 | 147 | 148 | def build_mask(tensors): 149 | shapes = [x.shape for x in tensors] 150 | maxlength = max([len(x) for x in tensors]) 151 | returnmasks = [] 152 | ids = [] 153 | for k, x in enumerate(tensors): 154 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x)))) 155 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x)))))) 156 | ids = torch.stack(ids, dim=0).long() 157 | returnmasks = torch.stack(returnmasks, dim=0).bool() 158 | return ids, returnmasks 159 | 160 | 161 | def add_token(x, token): 162 | x = torch.cat((torch.tensor([token]), x)) 163 | return x 164 | 165 | 166 | def deleteword(x, p=0.1): 167 | mask = np.random.rand(len(x)) 168 | x = [e for e, m in zip(x, mask) if m > p] 169 | return x 170 | 171 | 172 | def replaceword(x, min_random, max_random, p=0.1): 173 | mask = np.random.rand(len(x)) 174 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)] 175 | return x 176 | 177 | 178 | def maskword(x, mask_id, p=0.1): 179 | mask = np.random.rand(len(x)) 180 | x = [e if m > p else mask_id for e, m in zip(x, mask)] 181 | return x 182 | 183 | 184 | def shuffleword(x, p=0.1): 185 | count = 
(np.random.rand(len(x)) < p).sum() 186 | """Shuffles any n number of values in a list""" 187 | indices_to_shuffle = random.sample(range(len(x)), k=count) 188 | to_shuffle = [x[i] for i in indices_to_shuffle] 189 | random.shuffle(to_shuffle) 190 | for index, value in enumerate(to_shuffle): 191 | old_index = indices_to_shuffle[index] 192 | x[old_index] = value 193 | return x 194 | 195 | 196 | def apply_augmentation(x, opt): 197 | if opt.augmentation == "mask": 198 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation)) 199 | elif opt.augmentation == "replace": 200 | return torch.tensor( 201 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation) 202 | ) 203 | elif opt.augmentation == "delete": 204 | return torch.tensor(deleteword(x, p=opt.prob_augmentation)) 205 | elif opt.augmentation == "shuffle": 206 | return torch.tensor(shuffleword(x, p=opt.prob_augmentation)) 207 | else: 208 | if not isinstance(x, torch.Tensor): 209 | x = torch.Tensor(x) 210 | return x 211 | 212 | 213 | def add_bos_eos(x, bos_token_id, eos_token_id): 214 | if not isinstance(x, torch.Tensor): 215 | x = torch.Tensor(x) 216 | if bos_token_id is None and eos_token_id is not None: 217 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])]) 218 | elif bos_token_id is not None and eos_token_id is None: 219 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()]) 220 | elif bos_token_id is None and eos_token_id is None: 221 | pass 222 | else: 223 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])]) 224 | return x 225 | 226 | 227 | # Used for passage retrieval 228 | def load_passages(path): 229 | if not os.path.exists(path): 230 | logger.info(f"{path} does not exist") 231 | return 232 | logger.info(f"Loading passages from: {path}") 233 | passages = [] 234 | with open(path, encoding='utf-8') as fin: 235 | if path.endswith(".jsonl"): 236 | for k, line in enumerate(fin): 237 | ex = json.loads(line) 238 | passages.append(ex) 239 | else: 240 | data = fin.read() 241 | data = data.replace("\0", "") # get rid of null-bytes 242 | data_repr = io.StringIO(data) 243 | reader = csv.reader(data_repr, delimiter="\t") 244 | for k, row in enumerate(reader): 245 | if not row[0] == "id": 246 | ex = {"id": row[0], "title": row[2], "text": row[1]} if len(row) == 3 else {"id": row[0], "text": row[1]} 247 | passages.append(ex) 248 | if path.endswith(".jsonl") and "_id" in passages[0]: 249 | for item in passages: 250 | item["id"] = item["_id"] 251 | if type(item["text"]) is float: 252 | item["text"] = item["title"] 253 | if type(item["title"]) is float: 254 | item["title"] = "" 255 | return passages 256 | -------------------------------------------------------------------------------- /TART/src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class Gather(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x: torch.tensor): 10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 11 | dist.all_gather(output, x) 12 | return tuple(output) 13 | 14 | @staticmethod 15 | def backward(ctx, *grads): 16 | all_gradients = torch.stack(grads) 17 | dist.all_reduce(all_gradients) 18 | return all_gradients[dist.get_rank()] 19 | 20 | 21 | def gather(x: torch.tensor): 22 | if not dist.is_initialized(): 23 | return x 24 | x_gather = Gather.apply(x) 25 | x_gather = torch.cat(x_gather, dim=0) 26 | return x_gather 27 | 28 | 29 | @torch.no_grad() 30 | def gather_nograd(x: torch.tensor): 31 | if not dist.is_initialized(): 32 | return x 33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())] 34 | dist.all_gather(x_gather, x, async_op=False) 35 | 36 | x_gather = torch.cat(x_gather, dim=0) 37 | return x_gather 38 | 39 | 40 | @torch.no_grad() 41 | def varsize_gather_nograd(x: torch.Tensor): 42 | """gather tensors of different sizes along the first dimension""" 43 | if not dist.is_initialized(): 44 | return x 45 | 46 | # determine max size 47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 49 | dist.all_gather(allsizes, size) 50 | max_size = max([size.cpu().max() for size in allsizes]) 51 | 52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device) 53 | padded[: x.shape[0]] = x 54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())] 55 | dist.all_gather(output, padded) 56 | 57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)] 58 | output = torch.cat(output, dim=0) 59 | 60 | return output 61 | 62 | 63 | @torch.no_grad() 64 | def get_varsize(x: torch.Tensor): 65 | """gather tensors of different sizes along the first dimension""" 66 | if not dist.is_initialized(): 67 | return [x.shape[0]] 68 | 69 | # determine max size 70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 72 | dist.all_gather(allsizes, size) 73 | allsizes = torch.cat(allsizes) 74 | return allsizes 75 | 76 | 77 | def get_rank(): 78 | if not dist.is_available(): 79 | return 0 80 | if not dist.is_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main(): 86 | return get_rank() == 0 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_initialized(): 91 | return 1 92 | else: 93 | return dist.get_world_size() 94 | 95 | 96 | def barrier(): 97 | if dist.is_initialized(): 98 | dist.barrier() 99 | 100 | 101 | def average_main(x): 102 | if not dist.is_initialized(): 103 | return x 104 | if dist.is_initialized() and dist.get_world_size() > 1: 105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 106 | if is_main(): 107 | x = x / dist.get_world_size() 108 | return x 109 | 110 | 111 | def sum_main(x): 112 | if not dist.is_initialized(): 113 | return x 114 | if dist.is_initialized() and dist.get_world_size() > 1: 115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 116 | return x 117 | 118 | 119 | def weighted_average(x, count): 120 | if not dist.is_initialized(): 121 | if isinstance(x, torch.Tensor): 122 | x = x.item() 123 | return x, count 124 | t_loss = torch.tensor([x * count]).cuda() 125 | t_total = torch.tensor([count]).cuda() 126 | t_loss = sum_main(t_loss) 127 | t_total = 
sum_main(t_total) 128 | return (t_loss / t_total).item(), t_total.item() 129 | -------------------------------------------------------------------------------- /TART/src/evaluation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import collections 4 | import logging 5 | import regex 6 | import string 7 | import unicodedata 8 | from functools import partial 9 | from multiprocessing import Pool as ProcessPool 10 | from typing import Tuple, List, Dict 11 | import numpy as np 12 | from collections import Counter 13 | 14 | """ 15 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 16 | """ 17 | 18 | class SimpleTokenizer(object): 19 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 20 | NON_WS = r'[^\p{Z}\p{C}]' 21 | 22 | def __init__(self): 23 | """ 24 | Args: 25 | annotators: None or empty set (only tokenizes). 26 | """ 27 | self._regexp = regex.compile( 28 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 29 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 30 | ) 31 | 32 | def tokenize(self, text, uncased=False): 33 | matches = [m for m in self._regexp.finditer(text)] 34 | if uncased: 35 | tokens = [m.group().lower() for m in matches] 36 | else: 37 | tokens = [m.group() for m in matches] 38 | return tokens 39 | 40 | logger = logging.getLogger(__name__) 41 | 42 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 43 | 44 | def calculate_matches(data: List, workers_num: int): 45 | """ 46 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 47 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 48 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 49 | :param answers: list of answers's list. One list per question 50 | :param closest_docs: document ids of the top results along with their scores 51 | :param workers_num: amount of parallel threads to process data 52 | :param match_type: type of answer matching. Refer to has_answer code for available options 53 | :return: matching information tuple. 54 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 55 | valid matches across an entire dataset. 
56 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 57 | """ 58 | 59 | logger.info('Matching answers in top docs...') 60 | 61 | tokenizer = SimpleTokenizer() 62 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 63 | 64 | processes = ProcessPool(processes=workers_num) 65 | scores = processes.map(get_score_partial, data) 66 | 67 | logger.info('Per question validation results len=%d', len(scores)) 68 | 69 | n_docs = len(data[0]['ctxs']) 70 | top_k_hits = [0] * n_docs 71 | for question_hits in scores: 72 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 73 | if best_hit is not None: 74 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 75 | 76 | return QAMatchStats(top_k_hits, scores) 77 | 78 | def check_answer(example, tokenizer) -> List[bool]: 79 | """Search through all the top docs to see if they have any of the answers.""" 80 | answers = example['answers'] 81 | ctxs = example['ctxs'] 82 | 83 | hits = [] 84 | 85 | for i, doc in enumerate(ctxs): 86 | text = doc['text'] 87 | 88 | if text is None: # cannot find the document for some reason 89 | logger.warning("no doc in db") 90 | hits.append(False) 91 | continue 92 | 93 | hits.append(has_answer(answers, text, tokenizer)) 94 | 95 | return hits 96 | 97 | def has_answer(answers, text, tokenizer) -> bool: 98 | """Check if a document contains an answer string.""" 99 | text = _normalize(text) 100 | text = tokenizer.tokenize(text, uncased=True) 101 | 102 | for answer in answers: 103 | answer = _normalize(answer) 104 | answer = tokenizer.tokenize(answer, uncased=True) 105 | for i in range(0, len(text) - len(answer) + 1): 106 | if answer == text[i: i + len(answer)]: 107 | return True 108 | return False 109 | 110 | ################################################# 111 | ######## READER EVALUATION ######## 112 | ################################################# 113 | 114 | def _normalize(text): 115 | return unicodedata.normalize('NFD', text) 116 | 117 | #Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 118 | def normalize_answer(s): 119 | def remove_articles(text): 120 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 121 | 122 | def white_space_fix(text): 123 | return ' '.join(text.split()) 124 | 125 | def remove_punc(text): 126 | exclude = set(string.punctuation) 127 | return ''.join(ch for ch in text if ch not in exclude) 128 | 129 | def lower(text): 130 | return text.lower() 131 | 132 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 133 | 134 | def em(prediction, ground_truth): 135 | return normalize_answer(prediction) == normalize_answer(ground_truth) 136 | 137 | def f1(prediction, ground_truth): 138 | prediction_tokens = normalize_answer(prediction).split() 139 | ground_truth_tokens = normalize_answer(ground_truth).split() 140 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 141 | num_same = sum(common.values()) 142 | if num_same == 0: 143 | return 0 144 | precision = 1.0 * num_same / len(prediction_tokens) 145 | recall = 1.0 * num_same / len(ground_truth_tokens) 146 | f1 = (2 * precision * recall) / (precision + recall) 147 | return f1 148 | 149 | def f1_score(prediction, ground_truths): 150 | return max([f1(prediction, gt) for gt in ground_truths]) 151 | 152 | def exact_match_score(prediction, ground_truths): 153 | return max([em(prediction, gt) for gt in ground_truths]) 154 | 155 | 
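# --- Editor's illustrative sketch (not part of the original evaluation.py) ---
# The SQuAD-style helpers above take one prediction string and a list of gold
# answers, normalize both sides (lowercase, strip articles and punctuation),
# and return the best score over the gold answers. The example strings below
# are hypothetical and only show how the helpers are typically called.
if __name__ == "__main__":
    assert exact_match_score("The Eiffel Tower", ["Eiffel Tower"])
    # 2 overlapping tokens, 4 predicted, 2 gold: f1 = 2 * 0.5 * 1.0 / 1.5 ~ 0.67
    assert round(f1_score("the Eiffel Tower in Paris", ["Eiffel Tower"]), 2) == 0.67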
#################################################### 156 | ######## RETRIEVER EVALUATION ######## 157 | #################################################### 158 | 159 | def eval_batch(scores, inversions, avg_topk, idx_topk): 160 | for k, s in enumerate(scores): 161 | s = s.cpu().numpy() 162 | sorted_idx = np.argsort(-s) 163 | score(sorted_idx, inversions, avg_topk, idx_topk) 164 | 165 | def count_inversions(arr): 166 | inv_count = 0 167 | lenarr = len(arr) 168 | for i in range(lenarr): 169 | for j in range(i + 1, lenarr): 170 | if (arr[i] > arr[j]): 171 | inv_count += 1 172 | return inv_count 173 | 174 | def score(x, inversions, avg_topk, idx_topk): 175 | x = np.array(x) 176 | inversions.append(count_inversions(x)) 177 | for k in avg_topk: 178 | # ratio of passages in the predicted top-k that are 179 | # also in the topk given by gold score 180 | avg_pred_topk = (x[:k] 0: 45 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives) 46 | negatives += random_negatives 47 | if n_hard_negatives > 0: 48 | hard_negatives = random.sample( 49 | example["hard_negative_ctxs"][self.hard_negative_min_idx :], n_hard_negatives 50 | ) 51 | negatives += hard_negatives 52 | else: 53 | gold = example["positive_ctxs"][0] 54 | nidx = 0 55 | if "negative_ctxs" in example: 56 | negatives = [example["negative_ctxs"][nidx]] 57 | else: 58 | negatives = [] 59 | 60 | if "title" in gold and type(gold["title"]) is str and len(gold) > 0: 61 | gold = gold["title"] + " " + gold["text"] 62 | else: 63 | gold = gold["text"] 64 | 65 | negatives_new = [] 66 | for n in negatives: 67 | if "title" in n and type(n["title"]) is str and len(n["title"]) > 0: 68 | negatives_new.append(n["title"] + " " + n["text"]) 69 | else: 70 | negatives_new.append(n["text"]) 71 | 72 | negatives = negatives_new 73 | # negatives = [ 74 | # n["title"] + " " + n["text"] if ("title" in n and type(n["title"]) is str and len(n["title"]) > 0) else n["text"] for n in negatives 75 | # ] 76 | 77 | example = { 78 | "query": self.normalize_fn(question), 79 | "gold": self.normalize_fn(gold), 80 | "negatives": [self.normalize_fn(n) for n in negatives], 81 | } 82 | return example 83 | 84 | def _load_data(self, datapaths, global_rank, world_size, maxload): 85 | counter = 0 86 | self.data = [] 87 | for path in datapaths: 88 | path = str(path) 89 | if path.endswith(".jsonl"): 90 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 91 | elif path.endswith(".json"): 92 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 93 | self.data.extend(file_data) 94 | if maxload is not None and maxload > 0 and counter >= maxload: 95 | break 96 | 97 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 98 | examples = [] 99 | with open(path, "r") as fin: 100 | data = json.load(fin) 101 | for example in data: 102 | counter += 1 103 | if global_rank > -1 and not counter % world_size == global_rank: 104 | continue 105 | examples.append(example) 106 | if maxload is not None and maxload > 0 and counter == maxload: 107 | break 108 | 109 | return examples, counter 110 | 111 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None): 112 | examples = [] 113 | with open(path, "r") as fin: 114 | for line in fin: 115 | counter += 1 116 | if global_rank > -1 and not counter % world_size == global_rank: 117 | continue 118 | example = json.loads(line) 119 | examples.append(example) 120 | if maxload is not None and maxload > 0 and counter 
== maxload: 121 | break 122 | 123 | return examples, counter 124 | 125 | def sample_n_hard_negatives(self, ex): 126 | 127 | if "hard_negative_ctxs" in ex: 128 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)]) 129 | n_hard_negatives = min(n_hard_negatives, len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :])) 130 | else: 131 | n_hard_negatives = 0 132 | n_random_negatives = self.negative_ctxs - n_hard_negatives 133 | if "negative_ctxs" in ex: 134 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"])) 135 | else: 136 | n_random_negatives = 0 137 | # print(n_hard_negatives, n_random_negatives) 138 | return n_hard_negatives, n_random_negatives 139 | 140 | 141 | class Collator(object): 142 | def __init__(self, tokenizer, passage_maxlength=200): 143 | self.tokenizer = tokenizer 144 | self.passage_maxlength = passage_maxlength 145 | 146 | def __call__(self, batch): 147 | queries = [ex["query"] for ex in batch] 148 | golds = [ex["gold"] for ex in batch] 149 | negs = [item for ex in batch for item in ex["negatives"]] 150 | allpassages = golds + negs 151 | 152 | qout = self.tokenizer.batch_encode_plus( 153 | queries, 154 | max_length=self.passage_maxlength, 155 | truncation=True, 156 | padding=True, 157 | add_special_tokens=True, 158 | return_tensors="pt", 159 | ) 160 | kout = self.tokenizer.batch_encode_plus( 161 | allpassages, 162 | max_length=self.passage_maxlength, 163 | truncation=True, 164 | padding=True, 165 | add_special_tokens=True, 166 | return_tensors="pt", 167 | ) 168 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool() 169 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool() 170 | 171 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)] 172 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :] 173 | 174 | batch = { 175 | "q_tokens": q_tokens, 176 | "q_mask": q_mask, 177 | "k_tokens": k_tokens, 178 | "k_mask": k_mask, 179 | "g_tokens": g_tokens, 180 | "g_mask": g_mask, 181 | "n_tokens": n_tokens, 182 | "n_mask": n_mask, 183 | } 184 | 185 | return batch 186 | 187 | class DatasetKD(torch.utils.data.Dataset): 188 | def __init__(self, 189 | datapaths, 190 | n_context=50, 191 | question_prefix='', 192 | title_prefix='', 193 | passage_prefix='', 194 | global_rank=-1, 195 | world_size=-1, 196 | maxload=None, 197 | random_sort=False): 198 | 199 | self._load_data(datapaths, global_rank, world_size, maxload) 200 | self.n_context = n_context 201 | self.question_prefix = question_prefix 202 | self.title_prefix = title_prefix 203 | self.passage_prefix = passage_prefix 204 | self.random_sort = random_sort 205 | self.normalize_fn = normalize_text.normalize if normalize_text else lambda x: x 206 | self.sort_data() 207 | 208 | def __len__(self): 209 | return len(self.data) 210 | 211 | def __getitem__(self, index): 212 | example = self.data[index] 213 | question = self.question_prefix + " " + example['question'] 214 | 215 | if 'ctxs' in example and self.n_context is not None: 216 | f = self.title_prefix + " {} " + self.passage_prefix + " {}" 217 | if len(example['ctxs']) > self.n_context and self.random_sort is True: 218 | contexts = random.sample(example['ctxs'], k=self.n_context) 219 | else: 220 | contexts = example['ctxs'][:self.n_context] 221 | passages = [f.format(c['title'], c['text']) for c in contexts] 222 | scores = [float(c['gold_score']) for c in contexts] 223 | scores = torch.tensor(scores) 224 | if len(contexts) == 0: 225 | contexts = [question] 226 | 
else: 227 | passages, scores = None, None 228 | 229 | return { 230 | 'index' : index, 231 | 'question' : question, 232 | 'passages' : passages, 233 | 'scores' : scores 234 | } 235 | 236 | def sort_data(self): 237 | if self.n_context is None or not 'gold_score' in self.data[0]['ctxs'][0]: 238 | return 239 | for ex in self.data: 240 | ex['ctxs'].sort(key=lambda x: float(x['gold_score']), reverse=True) 241 | 242 | def get_example(self, index): 243 | return self.data[index] 244 | 245 | def _load_data(self, datapaths, global_rank, world_size, maxload): 246 | counter = 0 247 | self.data = [] 248 | for path in datapaths: 249 | path = str(path) 250 | if path.endswith(".jsonl"): 251 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 252 | elif path.endswith(".json"): 253 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 254 | self.data.extend(file_data) 255 | if maxload is not None and maxload > 0 and counter >= maxload: 256 | break 257 | 258 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 259 | examples = [] 260 | with open(path, "r") as fin: 261 | data = json.load(fin) 262 | for example in data: 263 | counter += 1 264 | if global_rank > -1 and not counter % world_size == global_rank: 265 | continue 266 | examples.append(example) 267 | if maxload is not None and maxload > 0 and counter == maxload: 268 | break 269 | 270 | return examples, counter 271 | 272 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None): 273 | examples = [] 274 | with open(path, "r") as fin: 275 | for line in fin: 276 | counter += 1 277 | if global_rank > -1 and not counter % world_size == global_rank: 278 | continue 279 | example = json.loads(line) 280 | examples.append(example) 281 | if maxload is not None and maxload > 0 and counter == maxload: 282 | break 283 | 284 | return examples, counter 285 | 286 | def encode_passages(batch_text_passages, tokenizer, max_length): 287 | passage_ids, passage_masks = [], [] 288 | for k, text_passages in enumerate(batch_text_passages): 289 | p = tokenizer.batch_encode_plus( 290 | text_passages, 291 | max_length=max_length, 292 | pad_to_max_length=True, 293 | return_tensors='pt', 294 | truncation=True 295 | ) 296 | passage_ids.append(p['input_ids'][None]) 297 | passage_masks.append(p['attention_mask'][None]) 298 | 299 | passage_ids = torch.cat(passage_ids, dim=0) 300 | passage_masks = torch.cat(passage_masks, dim=0) 301 | return passage_ids, passage_masks.bool() 302 | 303 | def load_data(data_path=None, global_rank=-1, world_size=-1): 304 | assert data_path 305 | if data_path.endswith('.jsonl'): 306 | data = open(data_path, 'r') 307 | elif data_path.endswith('.json'): 308 | with open(data_path, 'r') as fin: 309 | data = json.load(fin) 310 | examples = [] 311 | for k, example in enumerate(data): 312 | if global_rank > -1 and not k%world_size==global_rank: 313 | continue 314 | if data_path is not None and data_path.endswith('.jsonl'): 315 | example = json.loads(example) 316 | if not 'id' in example: 317 | example['id'] = k 318 | for c in example['ctxs']: 319 | if not 'score' in c: 320 | c['score'] = 1.0 / (k + 1) 321 | examples.append(example) 322 | if data_path is not None and data_path.endswith('.jsonl'): 323 | data.close() 324 | 325 | return examples 326 | 327 | class CollatorKD(object): 328 | def __init__(self, tokenizer, passage_maxlength=200, question_maxlength=40): 329 | self.tokenizer = tokenizer 330 | self.passage_maxlength = 
passage_maxlength 331 | self.question_maxlength = question_maxlength 332 | 333 | def __call__(self, batch): 334 | index = torch.tensor([ex['index'] for ex in batch]) 335 | 336 | question = [ex['question'] for ex in batch] 337 | question = self.tokenizer.batch_encode_plus( 338 | question, 339 | pad_to_max_length=True, 340 | return_tensors="pt", 341 | max_length=self.question_maxlength, 342 | truncation=True 343 | ) 344 | question_ids = question['input_ids'] 345 | question_mask = question['attention_mask'].bool() 346 | 347 | if batch[0]['scores'] is None or batch[0]['passages'] is None: 348 | return index, question_ids, question_mask, None, None, None 349 | 350 | scores = [ex['scores'] for ex in batch] 351 | scores = torch.stack(scores, dim=0) 352 | 353 | passages = [ex['passages'] for ex in batch] 354 | passage_ids, passage_masks = encode_passages( 355 | passages, 356 | self.tokenizer, 357 | self.passage_maxlength 358 | ) 359 | 360 | batch = { 361 | "question_ids": question_ids, 362 | "question_mask": question_mask, 363 | "passage_ids": passage_ids, 364 | "passage_mask": passage_masks, 365 | "gold_score": scores 366 | } 367 | 368 | return batch 369 | -------------------------------------------------------------------------------- /TART/src/inbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | import random 8 | import transformers 9 | import logging 10 | import torch.distributed as dist 11 | import copy 12 | from src import contriever, dist_utils, utils 13 | import torch.nn.functional as F 14 | 15 | logger = logging.getLogger(__name__) 16 | 17 | 18 | class InBatch(nn.Module): 19 | def __init__(self, opt, retriever=None, tokenizer=None): 20 | super(InBatch, self).__init__() 21 | 22 | self.opt = opt 23 | self.norm_doc = opt.norm_doc 24 | self.norm_query = opt.norm_query 25 | self.label_smoothing = opt.label_smoothing 26 | if retriever is None or tokenizer is None: 27 | retriever, tokenizer = self._load_retriever( 28 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 29 | ) 30 | self.tokenizer = tokenizer 31 | self.encoder = retriever 32 | 33 | def _load_retriever(self, model_id, pooling, random_init): 34 | print("load retrieval") 35 | print(model_id) 36 | if "xlm" in model_id: 37 | model_class = contriever.XLMRetriever 38 | elif "t5" in model_id or "T0" in model_id or "gtr" in model_id: 39 | print("loading t0") 40 | model_class = contriever.T5Contriever 41 | print(model_class) 42 | else: 43 | model_class = contriever.Contriever 44 | 45 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 47 | if random_init: 48 | retriever = model_class(cfg) 49 | else: 50 | retriever = utils.load_hf(model_class, model_id) 51 | 52 | if "bert-" in model_id: 53 | if tokenizer.bos_token_id is None: 54 | tokenizer.bos_token = "[CLS]" 55 | if tokenizer.eos_token_id is None: 56 | tokenizer.eos_token = "[SEP]" 57 | 58 | retriever.config.pooling = pooling 59 | 60 | return retriever, tokenizer 61 | 62 | def get_encoder(self): 63 | return self.encoder 64 | 65 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, gold_scores=None, stats_prefix="", iter_stats={}, **kwargs): 66 | 67 | bsz = len(q_tokens) 68 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 69 | 70 | qemb = 
self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 71 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 72 | 73 | gather_fn = dist_utils.gather 74 | 75 | gather_kemb = gather_fn(kemb) 76 | 77 | labels = labels + dist_utils.get_rank() * len(kemb) 78 | 79 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 80 | 81 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 82 | 83 | # log stats 84 | if len(stats_prefix) > 0: 85 | stats_prefix = stats_prefix + "/" 86 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 87 | 88 | predicted_idx = torch.argmax(scores, dim=-1) 89 | accuracy = 100 * (predicted_idx == labels).float().mean() 90 | stdq = torch.std(qemb, dim=0).mean().item() 91 | stdk = torch.std(kemb, dim=0).mean().item() 92 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 93 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 94 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 95 | 96 | return loss, iter_stats 97 | 98 | 99 | 100 | class ByInBatch(nn.Module): 101 | def __init__(self, opt, retriever=None, tokenizer=None): 102 | super(ByInBatch, self).__init__() 103 | 104 | self.opt = opt 105 | self.norm_doc = opt.norm_doc 106 | self.norm_query = opt.norm_query 107 | self.label_smoothing = opt.label_smoothing 108 | if retriever is None or tokenizer is None: 109 | retriever, tokenizer = self._load_retriever( 110 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 111 | ) 112 | self.tokenizer = tokenizer 113 | self.q_encoder = copy.deepcopy(retriever) 114 | self.p_encoder = copy.deepcopy(retriever) 115 | 116 | def _load_retriever(self, model_id, pooling, random_init): 117 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 118 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 119 | 120 | if "xlm" in model_id: 121 | model_class = contriever.XLMRetriever 122 | else: 123 | model_class = contriever.Contriever 124 | 125 | if random_init: 126 | retriever = model_class(cfg) 127 | else: 128 | retriever = utils.load_hf(model_class, model_id) 129 | 130 | if "bert-" in model_id: 131 | if tokenizer.bos_token_id is None: 132 | tokenizer.bos_token = "[CLS]" 133 | if tokenizer.eos_token_id is None: 134 | tokenizer.eos_token = "[SEP]" 135 | 136 | retriever.config.pooling = pooling 137 | 138 | return retriever, tokenizer 139 | 140 | def get_q_encoder(self): 141 | return self.q_encoder 142 | 143 | def get_p_encoder(self): 144 | return self.p_encoder 145 | 146 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 147 | 148 | bsz = len(q_tokens) 149 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 150 | 151 | qemb = self.q_encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 152 | kemb = self.p_encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 153 | 154 | gather_fn = dist_utils.gather 155 | 156 | gather_kemb = gather_fn(kemb) 157 | 158 | labels = labels + dist_utils.get_rank() * len(kemb) 159 | 160 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 161 | 162 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 163 | 164 | # log stats 165 | if len(stats_prefix) > 0: 166 | stats_prefix = stats_prefix + "/" 167 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 168 | 169 | predicted_idx = torch.argmax(scores, dim=-1) 170 | 
accuracy = 100 * (predicted_idx == labels).float().mean() 171 | stdq = torch.std(qemb, dim=0).mean().item() 172 | stdk = torch.std(kemb, dim=0).mean().item() 173 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 174 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 175 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 176 | 177 | return loss, iter_stats 178 | 179 | 180 | class InBatch_KD(nn.Module): 181 | def __init__(self, opt, retriever=None, tokenizer=None, loss_type="kl", temperature=1): 182 | super(InBatch_KD, self).__init__() 183 | 184 | self.opt = opt 185 | self.norm_doc = opt.norm_doc 186 | self.norm_query = opt.norm_query 187 | self.label_smoothing = opt.label_smoothing 188 | if retriever is None or tokenizer is None: 189 | retriever, tokenizer = self._load_retriever( 190 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 191 | ) 192 | self.tokenizer = tokenizer 193 | self.encoder = retriever 194 | self.loss_type = loss_type 195 | self.temperature = temperature 196 | if loss_type == "kl": 197 | self.loss_fct = torch.nn.KLDivLoss() 198 | elif loss_type == "mse": 199 | self.loss_fct = torch.nn.MSELoss() 200 | else: 201 | raise NotImplementedError 202 | 203 | def _load_retriever(self, model_id, pooling, random_init): 204 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 205 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 206 | 207 | if "xlm" in model_id: 208 | model_class = contriever.XLMRetriever 209 | elif "t5" in model_id or "T0" in model_id or "gtr" in model_id: 210 | model_class = contriever.T5Contriever 211 | else: 212 | model_class = contriever.Contriever 213 | 214 | if random_init: 215 | retriever = model_class(cfg) 216 | else: 217 | retriever = utils.load_hf(model_class, model_id) 218 | 219 | if "bert-" in model_id: 220 | if tokenizer.bos_token_id is None: 221 | tokenizer.bos_token = "[CLS]" 222 | if tokenizer.eos_token_id is None: 223 | tokenizer.eos_token = "[SEP]" 224 | 225 | retriever.config.pooling = pooling 226 | 227 | return retriever, tokenizer 228 | 229 | def get_encoder(self): 230 | return self.encoder 231 | 232 | def forward(self, question_ids, question_mask, passage_ids, passage_mask, gold_score, stats_prefix="", iter_stats={}, **kwargs): 233 | question_output = self.encoder(input_ids=question_ids, attention_mask=question_mask, normalize=self.norm_query) 234 | bsz, n_passages, plen = passage_ids.size() 235 | passage_ids = passage_ids.view(bsz * n_passages, plen) 236 | passage_mask = passage_mask.view(bsz * n_passages, plen) 237 | passage_output = self.encoder(input_ids=passage_ids, attention_mask=passage_mask, normalize=self.norm_doc) 238 | 239 | score = torch.einsum( 240 | 'bd,bid->bi', 241 | question_output, 242 | passage_output.view(bsz, n_passages, -1) 243 | ) 244 | 245 | score = score / np.sqrt(question_output.size(-1)) 246 | if gold_score is not None: 247 | if self.loss_type == "kl": 248 | loss = self.kldivloss(score, gold_score) 249 | else: 250 | loss = self.mseloss(score, gold_score) 251 | else: 252 | loss = None 253 | # log stats 254 | if len(stats_prefix) > 0: 255 | stats_prefix = stats_prefix + "/" 256 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 257 | 258 | # predicted_idx = torch.argmax(scores, dim=-1) 259 | # accuracy = 100 * (predicted_idx == labels).float().mean() 260 | # stdq = torch.std(qemb, dim=0).mean().item() 261 | # stdk = torch.std(kemb, dim=0).mean().item() 262 | # iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 263 | # iter_stats[f"{stats_prefix}stdq"] = (stdq, 
bsz) 264 | # iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 265 | # print(loss) 266 | return loss, iter_stats 267 | 268 | def kldivloss(self, score, gold_score): 269 | # print("scores") 270 | # print(gold_score[0,:10]) 271 | gold_score = torch.softmax(gold_score / self.temperature, dim=-1) 272 | # print(gold_score[0,:10]) 273 | # print(score[0,:10]) 274 | score = torch.nn.functional.log_softmax(score / self.temperature, dim=-1) 275 | # print(score[0,:10]) 276 | loss = self.loss_fct(score, gold_score) * (self.temperature**2) 277 | # loss = F.kl_div(score, gold_score, size_average=False) * (self.temperature**2) 278 | # print(loss) 279 | # print(loss.size()) 280 | 281 | return loss 282 | 283 | def mseloss(self, score, gold_score): 284 | # print("scores") 285 | # print(gold_score[0,:10]) 286 | gold_score = torch.softmax(gold_score, dim=-1) 287 | # print(gold_score[0,:10]) 288 | # print(score[0,:10]) 289 | score = torch.softmax(score, dim=-1) 290 | # print(score[0,:10]) 291 | loss = self.loss_fct(score, gold_score) 292 | # print(loss) 293 | # print(loss.size()) 294 | return self.loss_fct(score, gold_score) -------------------------------------------------------------------------------- /TART/src/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import pickle 5 | from typing import List, Tuple 6 | 7 | import faiss 8 | import numpy as np 9 | from tqdm import tqdm 10 | 11 | class Indexer(object): 12 | 13 | def __init__(self, vector_sz, n_subquantizers=0, n_bits=8): 14 | if n_subquantizers > 0: 15 | self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT) 16 | else: 17 | self.index = faiss.IndexFlatIP(vector_sz) 18 | #self.index_id_to_db_id = np.empty((0), dtype=np.int64) 19 | self.index_id_to_db_id = [] 20 | 21 | def index_data(self, ids, embeddings): 22 | self._update_id_mapping(ids) 23 | embeddings = embeddings.astype('float32') 24 | if not self.index.is_trained: 25 | self.index.train(embeddings) 26 | print(len(embeddings[0])) 27 | self.index.add(embeddings) 28 | 29 | print(f'Total data indexed {len(self.index_id_to_db_id)}') 30 | 31 | def search_knn(self, query_vectors: np.array, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]: 32 | query_vectors = query_vectors.astype('float32') 33 | result = [] 34 | nbatch = (len(query_vectors)-1) // index_batch_size + 1 35 | for k in tqdm(range(nbatch)): 36 | start_idx = k*index_batch_size 37 | end_idx = min((k+1)*index_batch_size, len(query_vectors)) 38 | q = query_vectors[start_idx: end_idx] 39 | scores, indexes = self.index.search(q, top_docs) 40 | # convert to external ids 41 | db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes] 42 | result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))]) 43 | return result 44 | 45 | def serialize(self, dir_path): 46 | index_file = os.path.join(dir_path, 'index.faiss') 47 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 48 | print(f'Serializing index to {index_file}, meta data to {meta_file}') 49 | 50 | faiss.write_index(self.index, index_file) 51 | with open(meta_file, mode='wb') as f: 52 | pickle.dump(self.index_id_to_db_id, f) 53 | 54 | def deserialize_from(self, dir_path): 55 | index_file = os.path.join(dir_path, 'index.faiss') 56 | meta_file = os.path.join(dir_path, 'index_meta.faiss') 57 | print(f'Loading index from {index_file}, meta data 
from {meta_file}') 58 | 59 | self.index = faiss.read_index(index_file) 60 | print('Loaded index of type %s and size %d', type(self.index), self.index.ntotal) 61 | 62 | with open(meta_file, "rb") as reader: 63 | self.index_id_to_db_id = pickle.load(reader) 64 | assert len( 65 | self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size' 66 | 67 | def _update_id_mapping(self, db_ids: List): 68 | #new_ids = np.array(db_ids, dtype=np.int64) 69 | #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0) 70 | self.index_id_to_db_id.extend(db_ids) -------------------------------------------------------------------------------- /TART/src/moco.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import logging 6 | import copy 7 | import transformers 8 | 9 | from src import contriever, dist_utils, utils 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class MoCo(nn.Module): 15 | def __init__(self, opt): 16 | super(MoCo, self).__init__() 17 | 18 | self.queue_size = opt.queue_size 19 | self.momentum = opt.momentum 20 | self.temperature = opt.temperature 21 | self.label_smoothing = opt.label_smoothing 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k # apply the encoder on keys in train mode 25 | 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | 30 | self.tokenizer = tokenizer 31 | self.encoder_q = retriever 32 | self.encoder_k = copy.deepcopy(retriever) 33 | 34 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 35 | param_k.data.copy_(param_q.data) 36 | param_k.requires_grad = False 37 | 38 | # create the queue 39 | self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size)) 40 | self.queue = nn.functional.normalize(self.queue, dim=0) 41 | 42 | self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long)) 43 | 44 | def _load_retriever(self, model_id, pooling, random_init): 45 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 46 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 47 | 48 | if "xlm" in model_id: 49 | model_class = contriever.XLMRetriever 50 | if "albert" in model_id: 51 | model_class = contriever.ALBERTRetriever 52 | elif "t5" in model_id or "T0" in model_id: 53 | model_class = contriever.T5Contriever 54 | else: 55 | model_class = contriever.Contriever 56 | 57 | if random_init: 58 | retriever = model_class(cfg) 59 | else: 60 | retriever = utils.load_hf(model_class, model_id) 61 | 62 | if "bert-" in model_id: 63 | if tokenizer.bos_token_id is None: 64 | tokenizer.bos_token = "[CLS]" 65 | if tokenizer.eos_token_id is None: 66 | tokenizer.eos_token = "[SEP]" 67 | 68 | retriever.config.pooling = pooling 69 | 70 | return retriever, tokenizer 71 | 72 | def get_encoder(self, return_encoder_k=False): 73 | if return_encoder_k: 74 | return self.encoder_k 75 | else: 76 | return self.encoder_q 77 | 78 | def _momentum_update_key_encoder(self): 79 | """ 80 | Update of the key encoder 81 | """ 82 | for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()): 83 | param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum) 84 | 85 | @torch.no_grad() 86 
| def _dequeue_and_enqueue(self, keys): 87 | # gather keys before updating queue 88 | keys = dist_utils.gather_nograd(keys.contiguous()) 89 | 90 | batch_size = keys.shape[0] 91 | 92 | ptr = int(self.queue_ptr) 93 | assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}" # for simplicity 94 | 95 | # replace the keys at ptr (dequeue and enqueue) 96 | self.queue[:, ptr : ptr + batch_size] = keys.T 97 | ptr = (ptr + batch_size) % self.queue_size # move pointer 98 | 99 | self.queue_ptr[0] = ptr 100 | 101 | def _compute_logits(self, q, k): 102 | l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1) 103 | l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()]) 104 | 105 | logits = torch.cat([l_pos, l_neg], dim=1) 106 | return logits 107 | 108 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 109 | bsz = q_tokens.size(0) 110 | 111 | q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 112 | # compute key features 113 | with torch.no_grad(): # no gradient to keys 114 | self._momentum_update_key_encoder() # update the key encoder 115 | 116 | if not self.encoder_k.training and not self.moco_train_mode_encoder_k: 117 | self.encoder_k.eval() 118 | 119 | k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 120 | logits = self._compute_logits(q, k) / self.temperature 121 | 122 | # labels: positive key indicators 123 | labels = torch.zeros(bsz, dtype=torch.long).cuda() 124 | 125 | loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing) 126 | 127 | self._dequeue_and_enqueue(k) 128 | 129 | # log stats 130 | if len(stats_prefix) > 0: 131 | stats_prefix = stats_prefix + "/" 132 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 133 | 134 | predicted_idx = torch.argmax(logits, dim=-1) 135 | accuracy = 100 * (predicted_idx == labels).float().mean() 136 | stdq = torch.std(q, dim=0).mean().item() 137 | stdk = torch.std(k, dim=0).mean().item() 138 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 139 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 140 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 141 | 142 | return loss, iter_stats 143 | -------------------------------------------------------------------------------- /TART/src/modeling_enc_t5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import copy 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss 8 | from transformers.modeling_outputs import SequenceClassifierOutput 9 | from transformers.models.t5.modeling_t5 import T5Config, T5PreTrainedModel, T5Stack 10 | from transformers.utils.model_parallel_utils import assert_device_map, get_device_map 11 | 12 | 13 | class EncT5ForSequenceClassification(T5PreTrainedModel): 14 | _keys_to_ignore_on_load_missing = [ 15 | r"encoder\.embed_tokens\.weight", 16 | ] 17 | 18 | def __init__(self, config: T5Config, dropout=0.1): 19 | super().__init__(config) 20 | self.num_labels = config.num_labels 21 | self.config = config 22 | 23 | self.shared = nn.Embedding(config.vocab_size, config.d_model) 24 | 25 | encoder_config = copy.deepcopy(config) 26 | encoder_config.use_cache = False 27 | encoder_config.is_encoder_decoder = False 28 | self.encoder = T5Stack(encoder_config, self.shared) 29 | 30 | self.dropout = nn.Dropout(dropout) 31 | self.classifier = nn.Linear(config.hidden_size, config.num_labels) 32 | 33 | # Initialize weights and apply final processing 34 | self.post_init() 35 | 36 | # Model parallel 37 | self.model_parallel = False 38 | self.device_map = None 39 | 40 | def parallelize(self, device_map=None): 41 | self.device_map = ( 42 | get_device_map(len(self.encoder.block), range(torch.cuda.device_count())) 43 | if device_map is None 44 | else device_map 45 | ) 46 | assert_device_map(self.device_map, len(self.encoder.block)) 47 | self.encoder.parallelize(self.device_map) 48 | self.classifier = self.classifier.to(self.encoder.first_device) 49 | self.model_parallel = True 50 | 51 | def deparallelize(self): 52 | self.encoder.deparallelize() 53 | self.encoder = self.encoder.to("cpu") 54 | self.model_parallel = False 55 | self.device_map = None 56 | torch.cuda.empty_cache() 57 | 58 | def get_input_embeddings(self): 59 | return self.shared 60 | 61 | def set_input_embeddings(self, new_embeddings): 62 | self.shared = new_embeddings 63 | self.encoder.set_input_embeddings(new_embeddings) 64 | 65 | def get_encoder(self): 66 | return self.encoder 67 | 68 | def _prune_heads(self, heads_to_prune): 69 | """ 70 | Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base 71 | class PreTrainedModel 72 | """ 73 | for layer, heads in heads_to_prune.items(): 74 | self.encoder.layer[layer].attention.prune_heads(heads) 75 | 76 | def forward( 77 | self, 78 | input_ids=None, 79 | attention_mask=None, 80 | head_mask=None, 81 | inputs_embeds=None, 82 | labels=None, 83 | output_attentions=None, 84 | output_hidden_states=None, 85 | return_dict=None, 86 | ): 87 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 88 | 89 | outputs = self.encoder( 90 | input_ids=input_ids, 91 | attention_mask=attention_mask, 92 | inputs_embeds=inputs_embeds, 93 | head_mask=head_mask, 94 | output_attentions=output_attentions, 95 | output_hidden_states=output_hidden_states, 96 | return_dict=return_dict, 97 | ) 98 | 99 | hidden_states = outputs[0] 100 | pooled_output = hidden_states[:, 0, :] # Take bos token (equiv. 
to ) 101 | 102 | pooled_output = self.dropout(pooled_output) 103 | logits = self.classifier(pooled_output) 104 | 105 | loss = None 106 | if labels is not None: 107 | if self.config.problem_type is None: 108 | if self.num_labels == 1: 109 | self.config.problem_type = "regression" 110 | elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): 111 | self.config.problem_type = "single_label_classification" 112 | else: 113 | self.config.problem_type = "multi_label_classification" 114 | 115 | if self.config.problem_type == "regression": 116 | loss_fct = MSELoss() 117 | if self.num_labels == 1: 118 | loss = loss_fct(logits.squeeze(), labels.squeeze()) 119 | else: 120 | loss = loss_fct(logits, labels) 121 | elif self.config.problem_type == "single_label_classification": 122 | loss_fct = CrossEntropyLoss() 123 | loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) 124 | elif self.config.problem_type == "multi_label_classification": 125 | loss_fct = BCEWithLogitsLoss() 126 | loss = loss_fct(logits, labels) 127 | if not return_dict: 128 | output = (logits,) + outputs[1:] 129 | return ((loss,) + output) if loss is not None else output 130 | 131 | return SequenceClassifierOutput( 132 | loss=loss, 133 | logits=logits, 134 | hidden_states=outputs.hidden_states, 135 | attentions=outputs.attentions, 136 | ) 137 | -------------------------------------------------------------------------------- /TART/src/normalize_text.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | #: Control characters. 4 | CONTROLS = { 5 | '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011', 6 | '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b', 7 | } 8 | # There are further control characters, but they are instead replaced with a space by unicode normalization 9 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f' 10 | 11 | 12 | #: Hyphen and dash characters. 13 | HYPHENS = { 14 | '-', # \u002d Hyphen-minus 15 | '‐', # \u2010 Hyphen 16 | '‑', # \u2011 Non-breaking hyphen 17 | '⁃', # \u2043 Hyphen bullet 18 | '‒', # \u2012 figure dash 19 | '–', # \u2013 en dash 20 | '—', # \u2014 em dash 21 | '―', # \u2015 horizontal bar 22 | } 23 | 24 | #: Minus characters. 25 | MINUSES = { 26 | '-', # \u002d Hyphen-minus 27 | '−', # \u2212 Minus 28 | '-', # \uff0d Full-width Hyphen-minus 29 | '⁻', # \u207b Superscript minus 30 | } 31 | 32 | #: Plus characters. 33 | PLUSES = { 34 | '+', # \u002b Plus 35 | '+', # \uff0b Full-width Plus 36 | '⁺', # \u207a Superscript plus 37 | } 38 | 39 | #: Slash characters. 40 | SLASHES = { 41 | '/', # \u002f Solidus 42 | '⁄', # \u2044 Fraction slash 43 | '∕', # \u2215 Division slash 44 | } 45 | 46 | #: Tilde characters. 47 | TILDES = { 48 | '~', # \u007e Tilde 49 | '˜', # \u02dc Small tilde 50 | '⁓', # \u2053 Swung dash 51 | '∼', # \u223c Tilde operator #in mbert vocab 52 | '∽', # \u223d Reversed tilde 53 | '∿', # \u223f Sine wave 54 | '〜', # \u301c Wave dash #in mbert vocab 55 | '~', # \uff5e Full-width tilde #in mbert vocab 56 | } 57 | 58 | #: Apostrophe characters. 59 | APOSTROPHES = { 60 | "'", # \u0027 61 | '’', # \u2019 62 | '՚', # \u055a 63 | 'Ꞌ', # \ua78b 64 | 'ꞌ', # \ua78c 65 | ''', # \uff07 66 | } 67 | 68 | #: Single quote characters. 
69 | SINGLE_QUOTES = { 70 | "'", # \u0027 71 | '‘', # \u2018 72 | '’', # \u2019 73 | '‚', # \u201a 74 | '‛', # \u201b 75 | 76 | } 77 | 78 | #: Double quote characters. 79 | DOUBLE_QUOTES = { 80 | '"', # \u0022 81 | '“', # \u201c 82 | '”', # \u201d 83 | '„', # \u201e 84 | '‟', # \u201f 85 | } 86 | 87 | #: Accent characters. 88 | ACCENTS = { 89 | '`', # \u0060 90 | '´', # \u00b4 91 | } 92 | 93 | #: Prime characters. 94 | PRIMES = { 95 | '′', # \u2032 96 | '″', # \u2033 97 | '‴', # \u2034 98 | '‵', # \u2035 99 | '‶', # \u2036 100 | '‷', # \u2037 101 | '⁗', # \u2057 102 | } 103 | 104 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes. 105 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES 106 | 107 | def normalize(text): 108 | for control in CONTROLS: 109 | text = text.replace(control, '') 110 | text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ') 111 | 112 | for hyphen in HYPHENS | MINUSES: 113 | text = text.replace(hyphen, '-') 114 | text = text.replace('\u00ad', '') 115 | 116 | for double_quote in DOUBLE_QUOTES: 117 | text = text.replace(double_quote, '"') # \u0022 118 | for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS): 119 | text = text.replace(single_quote, "'") # \u0027 120 | text = text.replace('′', "'") # \u2032 prime 121 | text = text.replace('‵', "'") # \u2035 reversed prime 122 | text = text.replace('″', "''") # \u2033 double prime 123 | text = text.replace('‶', "''") # \u2036 reversed double prime 124 | text = text.replace('‴', "'''") # \u2034 triple prime 125 | text = text.replace('‷', "'''") # \u2037 reversed triple prime 126 | text = text.replace('⁗', "''''") # \u2057 quadruple prime 127 | 128 | text = text.replace('…', '...').replace(' . . . ', ' ... ') # \u2026 129 | 130 | for slash in SLASHES: 131 | text = text.replace(slash, '/') 132 | 133 | #for tilde in TILDES: 134 | # text = text.replace(tilde, '~') 135 | 136 | return text -------------------------------------------------------------------------------- /TART/src/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import argparse 4 | import os 5 | 6 | 7 | class Options: 8 | def __init__(self): 9 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 10 | self.initialize() 11 | 12 | def initialize(self): 13 | # basic parameters 14 | self.parser.add_argument( 15 | "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here" 16 | ) 17 | self.parser.add_argument( 18 | "--train_data", 19 | nargs="+", 20 | default=[], 21 | help="Data used for training, passed as a list of directories splitted into tensor files.", 22 | ) 23 | self.parser.add_argument( 24 | "--eval_data", 25 | nargs="+", 26 | default=[], 27 | help="Data used for evaluation during finetuning, this option is not used during contrastive pre-training.", 28 | ) 29 | self.parser.add_argument( 30 | "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format" 31 | ) 32 | self.parser.add_argument( 33 | "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored" 34 | ) 35 | self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining") 36 | self.parser.add_argument("--continue_training", action="store_true") 37 | self.parser.add_argument("--num_workers", type=int, default=5) 38 | 39 | self.parser.add_argument("--chunk_length", type=int, default=256) 40 | self.parser.add_argument("--loading_mode", type=str, default="split") 41 | self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing") 42 | self.parser.add_argument( 43 | "--sampling_coefficient", 44 | type=float, 45 | default=0.0, 46 | help="coefficient used for sampling between different datasets during training, \ 47 | by default sampling is uniform over datasets", 48 | ) 49 | self.parser.add_argument("--augmentation", type=str, default="none") 50 | self.parser.add_argument("--prob_augmentation", type=float, default=0.0) 51 | 52 | self.parser.add_argument("--dropout", type=float, default=0.1) 53 | self.parser.add_argument("--rho", type=float, default=0.05) 54 | 55 | self.parser.add_argument("--contrastive_mode", type=str, default="moco") 56 | self.parser.add_argument("--queue_size", type=int, default=65536) 57 | self.parser.add_argument("--temperature", type=float, default=1.0) 58 | self.parser.add_argument("--momentum", type=float, default=0.999) 59 | self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true") 60 | self.parser.add_argument("--eval_normalize_text", action="store_true") 61 | self.parser.add_argument("--norm_query", action="store_true") 62 | self.parser.add_argument("--norm_doc", action="store_true") 63 | self.parser.add_argument("--projection_size", type=int, default=768) 64 | 65 | # bi encoder 66 | self.parser.add_argument("--bi_encoder", action="store_true", help="instead of sharing a single encoder, we use separate encoders.") 67 | self.parser.add_argument("--freeze_ctx_encoder", action="store_true", help="if we use bi encoder but only want to update the query encoder.") 68 | 69 | self.parser.add_argument("--ratio_min", type=float, default=0.1) 70 | self.parser.add_argument("--ratio_max", type=float, default=0.5) 71 | self.parser.add_argument("--score_function", type=str, default="dot") 72 | self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased") 73 | self.parser.add_argument("--pooling", type=str, default="average") 74 | self.parser.add_argument("--random_init", 
action="store_true", help="init model with random weights") 75 | 76 | # dataset parameters 77 | self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.") 78 | self.parser.add_argument( 79 | "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation." 80 | ) 81 | self.parser.add_argument("--total_steps", type=int, default=1000) 82 | self.parser.add_argument("--warmup_steps", type=int, default=-1) 83 | 84 | self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 85 | self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)") 86 | self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization") 87 | self.parser.add_argument("--hard_order", action="store_true", help="use the most related hard negatives.") 88 | 89 | # training parameters 90 | self.parser.add_argument("--optim", type=str, default="adamw") 91 | self.parser.add_argument("--scheduler", type=str, default="linear") 92 | self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") 93 | self.parser.add_argument( 94 | "--lr_min_ratio", 95 | type=float, 96 | default=0.0, 97 | help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate", 98 | ) 99 | self.parser.add_argument("--weight_decay", type=float, default=0.01, help="learning rate") 100 | self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1") 101 | self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2") 102 | self.parser.add_argument("--eps", type=float, default=1e-6, help="eps") 103 | self.parser.add_argument( 104 | "--log_freq", type=int, default=100, help="log train stats every steps during training" 105 | ) 106 | self.parser.add_argument( 107 | "--eval_freq", type=int, default=500, help="evaluate model every steps during training" 108 | ) 109 | self.parser.add_argument("--save_freq", type=int, default=50000) 110 | self.parser.add_argument("--maxload", type=int, default=None) 111 | self.parser.add_argument("--label_smoothing", type=float, default=0.0) 112 | 113 | # finetuning options 114 | self.parser.add_argument("--negative_ctxs", type=int, default=1) 115 | self.parser.add_argument("--negative_hard_min_idx", type=int, default=0) 116 | self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0) 117 | self.parser.add_argument("--kd", action="store_true") 118 | self.parser.add_argument("--loss_type", type=str, default="kl") 119 | self.parser.add_argument("--T", type=float, default=0.1, help="eps") 120 | self.parser.add_argument("--n_context", type=int, default=50) 121 | self.parser.add_argument("--random_sort", action="store_true", help="randomly sampling top N for distillation") 122 | 123 | def print_options(self, opt): 124 | message = "" 125 | for k, v in sorted(vars(opt).items()): 126 | comment = "" 127 | default = self.parser.get_default(k) 128 | if v != default: 129 | comment = f"\t[default: %s]" % str(default) 130 | message += f"{str(k):>40}: {str(v):<40}{comment}\n" 131 | print(message, flush=True) 132 | model_dir = os.path.join(opt.output_dir, "models") 133 | if not os.path.exists(model_dir): 134 | os.makedirs(os.path.join(opt.output_dir, "models")) 135 | file_name = os.path.join(opt.output_dir, "opt.txt") 136 | with open(file_name, "wt") as opt_file: 137 | opt_file.write(message) 138 | opt_file.write("\n") 139 | 140 | def parse(self): 141 | 
opt, _ = self.parser.parse_known_args() 142 | # opt = self.parser.parse_args() 143 | return opt 144 | -------------------------------------------------------------------------------- /TART/src/rerank.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import logging 4 | from typing import Dict, List 5 | from transformers import AutoTokenizer, AutoModelForSequenceClassification 6 | import torch 7 | import torch.nn.functional as F 8 | from tqdm import tqdm 9 | from src.modeling_enc_t5 import EncT5ForSequenceClassification 10 | from src.tokenization_enc_t5 import EncT5Tokenizer 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | #Parent class for any reranking model 15 | class Rerank: 16 | def __init__(self, model_name_or_path, batch_size: int = 128, **kwargs): 17 | if "t0" in model_name_or_path or "t5" in model_name_or_path: 18 | self.model = EncT5ForSequenceClassification.from_pretrained(model_name_or_path) 19 | self.tokenizer = EncT5Tokenizer.from_pretrained(model_name_or_path) 20 | else: 21 | self.model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path) 22 | self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 23 | self.batch_size = batch_size 24 | self.rerank_results = {} 25 | self.model.to('cuda') 26 | 27 | self.model.eval() 28 | 29 | def rerank(self, 30 | corpus: Dict[str, Dict[str, str]], 31 | queries: Dict[str, str], 32 | results: Dict[str, Dict[str, float]], 33 | top_k: int, 34 | prompt: str = None) -> Dict[str, Dict[str, float]]: 35 | 36 | sentence_pairs, pair_ids = [], [] 37 | 38 | self.rerank_results = {query_id: {} for query_id in results} 39 | for query_id in tqdm(results): 40 | docs = [] 41 | query= [] 42 | doc_ids = [] 43 | 44 | if len(results[query_id]) > top_k: 45 | for (doc_id, _) in sorted(results[query_id].items(), key=lambda item: item[1], reverse=True)[:top_k]: 46 | pair_ids.append([query_id, doc_id]) 47 | corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip() 48 | # print(corpus_text) 49 | # print(corpus_text) 50 | docs.append(corpus_text) 51 | doc_ids.append(doc_id) 52 | if prompt is None: 53 | # sentence_pairs.append([queries[query_id], corpus_text]) 54 | query.append(queries[query_id]) 55 | else: 56 | # sentence_pairs.append(["{0} [SEP] {1}".format(prompt, queries[query_id]), corpus_text]) 57 | query.append("{0} [SEP] {1}".format(prompt, queries[query_id])) 58 | # print(query[-1]) 59 | 60 | else: 61 | for doc_id in results[query_id]: 62 | pair_ids.append([query_id, doc_id]) 63 | corpus_text = (corpus[doc_id].get("title", "") + " " + corpus[doc_id].get("text", "")).strip() 64 | # print(corpus_text) 65 | docs.append(corpus_text) 66 | doc_ids.append(doc_id) 67 | if prompt is None: 68 | query.append(queries[query_id]) 69 | # sentence_pairs.append([queries[query_id], corpus_text]) 70 | else: 71 | query.append("{0} [SEP] {1}".format(prompt, queries[query_id])) 72 | # sentence_pairs.append(["{0} [SEP] {1}".format(prompt, queries[query_id]), corpus_text]) 73 | 74 | # run inference 75 | features = self.tokenizer(query, docs, padding=True, truncation=True, max_length=512, return_tensors="pt").to('cuda') 76 | with torch.no_grad(): 77 | scores = self.model(**features).logits 78 | normalized_scores = F.softmax(scores, dim=1) 79 | final_scores = [float(score[1]) for score in normalized_scores] 80 | # print(final_scores) 81 | for doc_id, score in zip(doc_ids, final_scores): 82 | 
self.rerank_results[query_id][doc_id] = score 83 | 84 | # #### Starting to Rerank using cross-attention 85 | # logging.info("Starting To Rerank Top-{}....".format(top_k)) 86 | 87 | # rerank_scores = [float(score[1]) for score in self.cross_encoder.predict(sentence_pairs, batch_size=self.batch_size)] 88 | # #### Reranking results 89 | # self.rerank_results = {query_id: {} for query_id in results} 90 | # for pair, score in zip(pair_ids, rerank_scores): 91 | # query_id, doc_id = pair[0], pair[1] 92 | # self.rerank_results[query_id][doc_id] = score 93 | # print(self.rerank_results) 94 | return self.rerank_results -------------------------------------------------------------------------------- /TART/src/slurm.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | from logging import getLogger 8 | import os 9 | import sys 10 | import torch 11 | import socket 12 | import signal 13 | import subprocess 14 | 15 | 16 | logger = getLogger() 17 | 18 | def sig_handler(signum, frame): 19 | logger.warning("Signal handler called with signal " + str(signum)) 20 | prod_id = int(os.environ['SLURM_PROCID']) 21 | logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id)) 22 | if prod_id == 0: 23 | logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID']) 24 | os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID']) 25 | else: 26 | logger.warning("Not the main process, no need to requeue.") 27 | sys.exit(-1) 28 | 29 | 30 | def term_handler(signum, frame): 31 | logger.warning("Signal handler called with signal " + str(signum)) 32 | logger.warning("Bypassing SIGTERM.") 33 | 34 | 35 | def init_signal_handler(): 36 | """ 37 | Handle signals sent by SLURM for time limit / pre-emption. 38 | """ 39 | signal.signal(signal.SIGUSR1, sig_handler) 40 | signal.signal(signal.SIGTERM, term_handler) 41 | 42 | 43 | def init_distributed_mode(params): 44 | """ 45 | Handle single and multi-GPU / multi-node / SLURM jobs. 46 | Initialize the following variables: 47 | - local_rank 48 | - global_rank 49 | - world_size 50 | """ 51 | is_slurm_job = 'SLURM_JOB_ID' in os.environ and not 'WORLD_SIZE' in os.environ 52 | print("is slurm job {}".format(is_slurm_job)) 53 | if 'WORLD_SIZE' in os.environ: 54 | print("world size: {}".format(os.environ['WORLD_SIZE'])) 55 | has_local_rank = hasattr(params, 'local_rank') 56 | print("has local rank? 
{}".format(has_local_rank)) 57 | 58 | # SLURM job without torch.distributed.launch 59 | if is_slurm_job and has_local_rank: 60 | 61 | assert params.local_rank == -1 # on the cluster, this is handled by SLURM 62 | 63 | # local rank on the current node / global rank 64 | params.local_rank = int(os.environ['SLURM_LOCALID']) 65 | params.global_rank = int(os.environ['SLURM_PROCID']) 66 | params.world_size = int(os.environ['SLURM_NTASKS']) 67 | 68 | # define master address and master port 69 | hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']]) 70 | params.main_addr = hostnames.split()[0].decode('utf-8') 71 | assert 10001 <= params.main_port <= 20000 or params.world_size == 1 72 | 73 | # set environment variables for 'env://' 74 | os.environ['MASTER_ADDR'] = params.main_addr 75 | os.environ['MASTER_PORT'] = str(params.main_port) 76 | os.environ['WORLD_SIZE'] = str(params.world_size) 77 | os.environ['RANK'] = str(params.global_rank) 78 | is_distributed = True 79 | 80 | 81 | # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch 82 | elif has_local_rank and params.local_rank != -1: 83 | 84 | assert params.main_port == -1 85 | 86 | # read environment variables 87 | params.global_rank = int(os.environ['RANK']) 88 | params.world_size = int(os.environ['WORLD_SIZE']) 89 | 90 | is_distributed = True 91 | 92 | # local job (single GPU) 93 | else: 94 | params.local_rank = 0 95 | params.global_rank = 0 96 | params.world_size = 1 97 | is_distributed = False 98 | 99 | # set GPU device 100 | torch.cuda.set_device(params.local_rank) 101 | 102 | # initialize multi-GPU 103 | if is_distributed: 104 | 105 | # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization 106 | # 'env://' will read these environment variables: 107 | # MASTER_PORT - required; has to be a free port on machine with rank 0 108 | # MASTER_ADDR - required (except for rank 0); address of rank 0 node 109 | # WORLD_SIZE - required; can be set either here, or in a call to init function 110 | # RANK - required; can be set either here, or in a call to init function 111 | 112 | #print("Initializing PyTorch distributed ...") 113 | torch.distributed.init_process_group( 114 | init_method='env://', 115 | backend='nccl', 116 | #world_size=params.world_size, 117 | #rank=params.global_rank, 118 | ) 119 | -------------------------------------------------------------------------------- /TART/src/tokenization_enc_t5.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | from typing import Any, Dict, List, Optional 4 | 5 | from transformers import T5Tokenizer 6 | 7 | 8 | class EncT5Tokenizer(T5Tokenizer): 9 | def __init__( 10 | self, 11 | vocab_file, 12 | bos_token="", 13 | eos_token="", 14 | unk_token="", 15 | pad_token="", 16 | extra_ids=100, 17 | additional_special_tokens=None, 18 | sp_model_kwargs: Optional[Dict[str, Any]] = None, 19 | **kwargs, 20 | ) -> None: 21 | sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs 22 | 23 | super().__init__( 24 | vocab_file=vocab_file, 25 | bos_token=bos_token, 26 | eos_token=eos_token, 27 | unk_token=unk_token, 28 | pad_token=pad_token, 29 | extra_ids=extra_ids, 30 | additional_special_tokens=additional_special_tokens, 31 | sp_model_kwargs=sp_model_kwargs, 32 | **kwargs, 33 | ) 34 | 35 | def get_special_tokens_mask( 36 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 37 | ) -> List[int]: 38 | """ 39 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 40 | special tokens using the tokenizer `prepare_for_model` method. 41 | Args: 42 | token_ids_0 (`List[int]`): 43 | List of IDs. 44 | token_ids_1 (`List[int]`, *optional*): 45 | Optional second list of IDs for sequence pairs. 46 | already_has_special_tokens (`bool`, *optional*, defaults to `False`): 47 | Whether or not the token list is already formatted with special tokens for the model. 48 | Returns: 49 | `List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 50 | """ 51 | if already_has_special_tokens: 52 | return super().get_special_tokens_mask( 53 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 54 | ) 55 | 56 | # normal case: some special tokens 57 | if token_ids_1 is None: 58 | return [1] + ([0] * len(token_ids_0)) + [1] 59 | return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1] 60 | 61 | def create_token_type_ids_from_sequences( 62 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 63 | ) -> List[int]: 64 | """ 65 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make 66 | use of token type ids, therefore a list of zeros is returned. 67 | Args: 68 | token_ids_0 (`List[int]`): 69 | List of IDs. 70 | token_ids_1 (`List[int]`, *optional*): 71 | Optional second list of IDs for sequence pairs. 72 | Returns: 73 | `List[int]`: List of zeros. 74 | """ 75 | bos = [self.bos_token_id] 76 | eos = [self.eos_token_id] 77 | 78 | if token_ids_1 is None: 79 | return len(bos + token_ids_0 + eos) * [0] 80 | return len(bos + token_ids_0 + eos + token_ids_1 + eos) * [0] 81 | 82 | def build_inputs_with_special_tokens( 83 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 84 | ) -> List[int]: 85 | """ 86 | Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and 87 | adding special tokens. A sequence has the following format: 88 | - single sequence: ` X ` 89 | - pair of sequences: ` A B ` 90 | Args: 91 | token_ids_0 (`List[int]`): 92 | List of IDs to which the special tokens will be added. 93 | token_ids_1 (`List[int]`, *optional*): 94 | Optional second list of IDs for sequence pairs. 95 | Returns: 96 | `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens. 
97 | """ 98 | if token_ids_1 is None: 99 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] 100 | else: 101 | return [self.bos_token_id] + token_ids_0 + [self.eos_token_id] + token_ids_1 + [self.eos_token_id] 102 | -------------------------------------------------------------------------------- /TART/src/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import sys 5 | import logging 6 | import torch 7 | import errno 8 | from typing import Union, Tuple, List, Dict 9 | from collections import defaultdict 10 | from transformers import T5EncoderModel 11 | 12 | from src import dist_utils 13 | 14 | Number = Union[float, int] 15 | 16 | logger = logging.getLogger(__name__) 17 | 18 | 19 | def init_logger(args, stdout_only=False): 20 | if torch.distributed.is_initialized(): 21 | torch.distributed.barrier() 22 | stdout_handler = logging.StreamHandler(sys.stdout) 23 | handlers = [stdout_handler] 24 | if not stdout_only: 25 | file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log")) 26 | handlers.append(file_handler) 27 | logging.basicConfig( 28 | datefmt="%m/%d/%Y %H:%M:%S", 29 | level=logging.INFO if dist_utils.is_main() else logging.WARN, 30 | format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s", 31 | handlers=handlers, 32 | ) 33 | return logger 34 | 35 | 36 | def symlink_force(target, link_name): 37 | try: 38 | os.symlink(target, link_name) 39 | except OSError as e: 40 | if e.errno == errno.EEXIST: 41 | os.remove(link_name) 42 | os.symlink(target, link_name) 43 | else: 44 | raise e 45 | 46 | 47 | def save(model, optimizer, scheduler, step, opt, dir_path, name): 48 | model_to_save = model.module if hasattr(model, "module") else model 49 | path = os.path.join(dir_path, "checkpoint") 50 | epoch_path = os.path.join(path, name) # "step-%s" % step) 51 | os.makedirs(epoch_path, exist_ok=True) 52 | cp = os.path.join(path, "latest") 53 | fp = os.path.join(epoch_path, "checkpoint.pth") 54 | checkpoint = { 55 | "step": step, 56 | "model": model_to_save.state_dict(), 57 | "optimizer": optimizer.state_dict(), 58 | "scheduler": scheduler.state_dict(), 59 | "opt": opt, 60 | } 61 | torch.save(checkpoint, fp) 62 | symlink_force(epoch_path, cp) 63 | if not name == "lastlog": 64 | logger.info(f"Saving model to {epoch_path}") 65 | 66 | 67 | def load(model_class, dir_path, opt, reset_params=False): 68 | epoch_path = os.path.realpath(dir_path) 69 | checkpoint_path = os.path.join(epoch_path, "checkpoint.pth") 70 | logger.info(f"loading checkpoint {checkpoint_path}") 71 | checkpoint = torch.load(checkpoint_path, map_location="cpu") 72 | opt_checkpoint = checkpoint["opt"] 73 | state_dict = checkpoint["model"] 74 | 75 | model = model_class(opt_checkpoint) 76 | model.load_state_dict(state_dict, strict=True) 77 | model = model.cuda() 78 | step = checkpoint["step"] 79 | if not reset_params: 80 | optimizer, scheduler = set_optim(opt_checkpoint, model) 81 | scheduler.load_state_dict(checkpoint["scheduler"]) 82 | optimizer.load_state_dict(checkpoint["optimizer"]) 83 | else: 84 | optimizer, scheduler = set_optim(opt, model) 85 | 86 | return model, optimizer, scheduler, opt_checkpoint, step 87 | 88 | 89 | ############ OPTIM 90 | 91 | 92 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR): 93 | def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1): 94 | self.warmup = warmup 95 | self.total = total 96 | 
self.ratio = ratio 97 | super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 98 | 99 | def lr_lambda(self, step): 100 | if step < self.warmup: 101 | return (1 - self.ratio) * step / float(max(1, self.warmup)) 102 | 103 | return max( 104 | 0.0, 105 | 1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)), 106 | ) 107 | 108 | 109 | class CosineScheduler(torch.optim.lr_scheduler.LambdaLR): 110 | def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1): 111 | self.warmup = warmup 112 | self.total = total 113 | self.ratio = ratio 114 | super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch) 115 | 116 | def lr_lambda(self, step): 117 | if step < self.warmup: 118 | return float(step) / self.warmup 119 | s = float(step - self.warmup) / (self.total - self.warmup) 120 | return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s) 121 | 122 | 123 | def set_optim(opt, model): 124 | if opt.optim == "adamw": 125 | optimizer = torch.optim.AdamW( 126 | model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay 127 | ) 128 | else: 129 | raise NotImplementedError("optimizer class not implemented") 130 | 131 | scheduler_args = { 132 | "warmup": opt.warmup_steps, 133 | "total": opt.total_steps, 134 | "ratio": opt.lr_min_ratio, 135 | } 136 | if opt.scheduler == "linear": 137 | scheduler_class = WarmupLinearScheduler 138 | elif opt.scheduler == "cosine": 139 | scheduler_class = CosineScheduler 140 | else: 141 | raise ValueError 142 | scheduler = scheduler_class(optimizer, **scheduler_args) 143 | return optimizer, scheduler 144 | 145 | 146 | def get_parameters(net, verbose=False): 147 | num_params = 0 148 | for param in net.parameters(): 149 | num_params += param.numel() 150 | message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6) 151 | return message 152 | 153 | 154 | class WeightedAvgStats: 155 | """provides an average over a bunch of stats""" 156 | 157 | def __init__(self): 158 | self.raw_stats: Dict[str, float] = defaultdict(float) 159 | self.total_weights: Dict[str, float] = defaultdict(float) 160 | 161 | def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None: 162 | for key, (value, weight) in vals.items(): 163 | self.raw_stats[key] += value * weight 164 | self.total_weights[key] += weight 165 | 166 | @property 167 | def stats(self) -> Dict[str, float]: 168 | return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()} 169 | 170 | @property 171 | def tuple_stats(self) -> Dict[str, Tuple[float, float]]: 172 | return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()} 173 | 174 | def reset(self) -> None: 175 | self.raw_stats = defaultdict(float) 176 | self.total_weights = defaultdict(float) 177 | 178 | @property 179 | def average_stats(self) -> Dict[str, float]: 180 | keys = sorted(self.raw_stats.keys()) 181 | if torch.distributed.is_initialized(): 182 | torch.distributed.broadcast_object_list(keys, src=0) 183 | global_dict = {} 184 | for k in keys: 185 | if not k in self.total_weights: 186 | v = 0.0 187 | else: 188 | v = self.raw_stats[k] / self.total_weights[k] 189 | v, _ = dist_utils.weighted_average(v, self.total_weights[k]) 190 | global_dict[k] = v 191 | return global_dict 192 | 193 | 194 | def load_hf(object_class, model_name): 195 | if "gtr-t5" in model_name: 196 | obj = T5EncoderModel.from_pretrained(model_name) 197 | try: 
198 | obj = object_class.from_pretrained(model_name, local_files_only=True) 199 | except: 200 | obj = object_class.from_pretrained(model_name, local_files_only=False) 201 | return obj 202 | 203 | 204 | def init_tb_logger(output_dir): 205 | try: 206 | from torch.utils import tensorboard 207 | 208 | if dist_utils.is_main(): 209 | tb_logger = tensorboard.SummaryWriter(output_dir) 210 | else: 211 | tb_logger = None 212 | except: 213 | logger.warning("Tensorboard is not available.") 214 | tb_logger = None 215 | 216 | return tb_logger 217 | -------------------------------------------------------------------------------- /TART/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import time 5 | import sys 6 | import torch 7 | import logging 8 | import json 9 | import numpy as np 10 | import random 11 | import pickle 12 | 13 | import torch.distributed as dist 14 | from torch.utils.data import DataLoader, RandomSampler 15 | 16 | from src.options import Options 17 | from src import data, beir_utils, slurm, dist_utils, utils 18 | from src import moco, inbatch 19 | 20 | 21 | logger = logging.getLogger(__name__) 22 | 23 | 24 | def train(opt, model, optimizer, scheduler, step): 25 | 26 | run_stats = utils.WeightedAvgStats() 27 | 28 | tb_logger = utils.init_tb_logger(opt.output_dir) 29 | 30 | logger.info("Data loading") 31 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 32 | tokenizer = model.module.tokenizer 33 | else: 34 | tokenizer = model.tokenizer 35 | collator = data.Collator(opt=opt) 36 | train_dataset = data.load_data(opt, tokenizer) 37 | logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}") 38 | 39 | train_sampler = RandomSampler(train_dataset) 40 | train_dataloader = DataLoader( 41 | train_dataset, 42 | sampler=train_sampler, 43 | batch_size=opt.per_gpu_batch_size, 44 | drop_last=True, 45 | num_workers=opt.num_workers, 46 | collate_fn=collator, 47 | ) 48 | 49 | epoch = 1 50 | 51 | model.train() 52 | while step < opt.total_steps: 53 | train_dataset.generate_offset() 54 | 55 | logger.info(f"Start epoch {epoch}") 56 | for i, batch in enumerate(train_dataloader): 57 | step += 1 58 | 59 | batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()} 60 | train_loss, iter_stats = model(**batch, stats_prefix="train") 61 | 62 | train_loss.backward() 63 | optimizer.step() 64 | 65 | scheduler.step() 66 | model.zero_grad() 67 | 68 | run_stats.update(iter_stats) 69 | 70 | if step % opt.log_freq == 0: 71 | log = f"{step} / {opt.total_steps}" 72 | for k, v in sorted(run_stats.average_stats.items()): 73 | log += f" | {k}: {v:.3f}" 74 | if tb_logger: 75 | tb_logger.add_scalar(k, v, step) 76 | log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}" 77 | log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GiB" 78 | 79 | logger.info(log) 80 | run_stats.reset() 81 | 82 | if step % opt.eval_freq == 0: 83 | if isinstance(model, torch.nn.parallel.DistributedDataParallel): 84 | encoder = model.module.get_encoder() 85 | else: 86 | encoder = model.get_encoder() 87 | eval_model( 88 | opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step 89 | ) 90 | 91 | if dist_utils.is_main() and step % opt.save_freq == 0: 92 | utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}") 93 | 94 | model.train() 95 | if step > opt.total_steps: 96 | 
break 97 | epoch += 1 98 | 99 | 100 | def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step): 101 | for datasetname in opt.eval_datasets: 102 | metrics = beir_utils.evaluate_model( 103 | query_encoder, 104 | doc_encoder, 105 | tokenizer, 106 | dataset=datasetname, 107 | batch_size=opt.per_gpu_eval_batch_size, 108 | norm_doc=opt.norm_doc, 109 | norm_query=opt.norm_query, 110 | beir_dir=opt.eval_datasets_dir, 111 | score_function=opt.score_function, 112 | lower_case=opt.lower_case, 113 | normalize_text=opt.eval_normalize_text, 114 | ) 115 | 116 | message = [] 117 | if dist_utils.is_main(): 118 | for metric in ["NDCG@10", "Recall@10", "Recall@100"]: 119 | message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}") 120 | if tb_logger is not None: 121 | tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step) 122 | logger.info(" | ".join(message)) 123 | 124 | 125 | if __name__ == "__main__": 126 | logger.info("Start") 127 | 128 | options = Options() 129 | opt = options.parse() 130 | 131 | torch.manual_seed(opt.seed) 132 | slurm.init_distributed_mode(opt) 133 | slurm.init_signal_handler() 134 | 135 | directory_exists = os.path.isdir(opt.output_dir) 136 | if dist.is_initialized(): 137 | dist.barrier() 138 | os.makedirs(opt.output_dir, exist_ok=True) 139 | if not directory_exists and dist_utils.is_main(): 140 | options.print_options(opt) 141 | if dist.is_initialized(): 142 | dist.barrier() 143 | utils.init_logger(opt) 144 | 145 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 146 | 147 | if opt.contrastive_mode == "moco": 148 | model_class = moco.MoCo 149 | elif opt.contrastive_mode == "inbatch": 150 | model_class = inbatch.InBatch 151 | else: 152 | raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised") 153 | 154 | if not directory_exists and opt.model_path == "none": 155 | model = model_class(opt) 156 | model = model.cuda() 157 | optimizer, scheduler = utils.set_optim(opt, model) 158 | step = 0 159 | elif directory_exists: 160 | model_path = os.path.join(opt.output_dir, "checkpoint", "latest") 161 | model, optimizer, scheduler, opt_checkpoint, step = utils.load( 162 | model_class, 163 | model_path, 164 | opt, 165 | reset_params=False, 166 | ) 167 | logger.info(f"Model loaded from {opt.output_dir}") 168 | else: 169 | model, optimizer, scheduler, opt_checkpoint, step = utils.load( 170 | model_class, 171 | opt.model_path, 172 | opt, 173 | reset_params=False if opt.continue_training else True, 174 | ) 175 | if not opt.continue_training: 176 | step = 0 177 | logger.info(f"Model loaded from {opt.model_path}") 178 | 179 | logger.info(utils.get_parameters(model)) 180 | 181 | if dist.is_initialized(): 182 | model = torch.nn.parallel.DistributedDataParallel( 183 | model, 184 | device_ids=[opt.local_rank], 185 | output_device=opt.local_rank, 186 | find_unused_parameters=False, 187 | ) 188 | dist.barrier() 189 | 190 | logger.info("Start training") 191 | train(opt, model, optimizer, scheduler, step) 192 | -------------------------------------------------------------------------------- /cross_task_cross_eval/create_cross_task_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import pandas as pd 4 | import collections 5 | import numpy 6 | import json 7 | import jsonlines 8 | import csv 9 | import os 10 | import random 11 | import tqdm 12 | import datasets 13 | wikiqa = datasets.load_dataset("wiki_qa") 14 | 15 | def load_jsonlines(file_name): 16 | with jsonlines.open(file_name, 'r') as jsonl_f: 17 | data = [obj for obj in jsonl_f] 18 | return data 19 | 20 | 21 | os.mkdir("ambig") 22 | os.mkdir("wikiqa") 23 | os.mkdir("gooqa_tech") 24 | os.mkdir("linkso_py") 25 | os.mkdir("codesearch_py") 26 | 27 | 28 | # WIKIQA 29 | corpus = {} 30 | for split in ["train", "test", "validation"]: 31 | for item in wikiqa[split]: 32 | corpus.setdefault(item["document_title"], []) 33 | if item["answer"] not in corpus[item["document_title"]]: 34 | corpus[item["document_title"]].append(item["answer"]) 35 | 36 | final_corpus = [] 37 | for title in corpus: 38 | for idx, doc in enumerate(corpus[title]): 39 | final_corpus.append({"title": title, "text": doc, "_id":"{0}_{1}".format(title, idx), "metadata": {}}) 40 | 41 | final_qrel_data = [] 42 | final_queries = [] 43 | for item in wikiqa["validation"]: 44 | question_id = item["question_id"] 45 | question = item["question"] 46 | if item["label"] == 1: 47 | corpus_id = corpus[item["document_title"]].index(item["answer"]) 48 | final_queries.append({"_id": "wikiqa_{}".format(question_id), "text": question, "metadata": {}}) 49 | final_qrel_data.append({"query-id": "wikiqa_{}".format(question_id), "corpus-id": "{0}_{1}".format(item["document_title"], corpus_id), "score": 1}) 50 | for item in wikiqa["test"]: 51 | question_id = item["question_id"] 52 | question = item["question"] 53 | if item["label"] == 1: 54 | corpus_id = corpus[item["document_title"]].index(item["answer"]) 55 | final_queries.append({"_id": "wikiqa_{}".format(question_id), "text": question, "metadata": {}}) 56 | final_qrel_data.append({"query-id": "wikiqa_{}".format(question_id), "corpus-id": "{0}_{1}".format(item["document_title"], corpus_id), "score": 1}) 57 | 58 | q2dic = {} 59 | for item in final_queries: 60 | q2dic[item["_id"]] = item 61 | final_queries = list(q2dic.values()) 62 | 63 | with jsonlines.open('wikiqa/queries.jsonl', 'w') as writer: 64 | writer.write_all(final_queries) 65 | with jsonlines.open('wikiqa/corpus.jsonl', 'w') as writer: 66 | writer.write_all(final_corpus) 67 | with open('wikiqa/qrels/test.tsv', 'wt') as out_file: 68 | tsv_writer = csv.writer(out_file, delimiter='\t') 69 | tsv_writer.writerow(['query-id', 'corpus-id', "score"]) 70 | for item in final_qrel_data: 71 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]]) 72 | 73 | 74 | # Ambig QA 75 | ambigqa_path = "data/ambignq_light/" 76 | ambigqa_dev_data = load_jsonlines(ambigqa_path + "dev_light.json")[0] 77 | ambigqa_train_data = load_jsonlines(ambigqa_path + "train_light.json")[0] 78 | 79 | ambig_evals_qq = {} 80 | final_queries = [] 81 | final_corpus = [] 82 | count = 0 83 | for item in ambigqa_train_data: 84 | for an in item["annotations"]: 85 | if an["type"] == "multipleQAs": 86 | qa_pairs = an["qaPairs"] 87 | for q in qa_pairs: 88 | final_corpus.append({"_id": "ambig_train_{0}".format(count), "text": q["question"], "title": "", "metadata": {}}) 89 | count += 1 90 | 91 | final_qrels = [] 92 | count = 0 93 | for item in ambigqa_dev_data: 94 | for an in item["annotations"]: 95 | if an["type"] == "multipleQAs": 96 | qa_pairs = an["qaPairs"] 97 | for qa in qa_pairs: 98 | target_id = "ambig_test_{0}".format(count) 99 | final_corpus.append({"_id": 
"ambig_test_{0}".format(count), "text": qa["question"], "title": "" , "metadata": {}}) 100 | final_queries.append({"_id": "ambig_nq_source_{}".format(item["id"]), "text": item["question"], "metadata": {}}) 101 | final_qrels.append({"corpus-id": "ambig_test_{0}".format(count), "query-id": "ambig_nq_source_{}".format(item["id"]), "score": 1}) 102 | count += 1 103 | 104 | q2dic = {} 105 | for item in final_queries: 106 | q2dic[item["_id"]] = item 107 | final_queries = list(q2dic.values()) 108 | 109 | with jsonlines.open('ambig/queries.jsonl', 'w') as writer: 110 | writer.write_all(final_queries) 111 | with jsonlines.open('ambig/corpus.jsonl', 'w') as writer: 112 | writer.write_all(final_corpus) 113 | with open('ambig/qrels/test.tsv', 'wt') as out_file: 114 | tsv_writer = csv.writer(out_file, delimiter='\t') 115 | tsv_writer.writerow(['query-id', 'corpus-id', "score"]) 116 | for item in final_qrels: 117 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]]) 118 | 119 | 120 | # GooAQ technical 121 | gooqa_technical = [] 122 | gooaq_data = load_jsonlines("data/gooaq.jsonl") 123 | 124 | for item in gooaq_data: 125 | if item["answer_type"] == "feat_snip" and item["answer_url"] is not None and "https://" in item["answer_url"] : 126 | url = item["answer_url"].split("https://")[1].split("/")[0] 127 | if url == "stackoverflow.com": 128 | item["url_processed"] = url 129 | gooqa_technical.append(item) 130 | 131 | random_sampled_qooaq_tech = random.sample(gooqa_technical, k=1000) 132 | 133 | full_corpus = [] 134 | full_queries = [] 135 | full_qrels = [] 136 | answer2id = {} 137 | 138 | for idx, item in enumerate(gooqa_technical): 139 | full_corpus.append({"_id": "{0}_{1}".format(item["url_processed"], idx), "text": item["answer"], "title": "", "metadata": {}}) 140 | answer2id[item["answer"]] = "{0}_{1}".format(item["url_processed"], idx) 141 | 142 | 143 | for item in random_sampled_qooaq_tech: 144 | full_queries.append({"_id": "gooaq_technical_{}".format(item["id"]), "text": item["question"], "metadata": {}}) 145 | corpus_id = answer2id[item["answer"]] 146 | full_qrels.append({"query-id": "gooaq_technical_{}".format(item["id"]), "corpus-id": corpus_id, "score": 1}) 147 | 148 | 149 | os.mkdir("gooaq_technical/qrels") 150 | with jsonlines.open('gooaq_technical/queries.jsonl', 'w') as writer: 151 | writer.write_all(full_queries) 152 | with jsonlines.open('gooaq_technical/corpus.jsonl', 'w') as writer: 153 | writer.write_all(full_corpus) 154 | with open('gooaq_technical/qrels/test.tsv', 'wt') as out_file: 155 | tsv_writer = csv.writer(out_file, delimiter='\t') 156 | tsv_writer.writerow(['query-id', 'corpus-id', "score"]) 157 | for item in full_qrels: 158 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]]) 159 | 160 | 161 | # LinkSO 162 | def find_duplicated_qestions(dir, lang): 163 | duplicated_q_pairs = [] 164 | qid2all = pd.read_csv(os.path.join(dir, "{}_qid2all.txt".format(lang)), sep="\t", header=None) 165 | qid2all_dic = {} 166 | for idx, row in qid2all.iterrows(): 167 | qid2all_dic[int(row[0])] = {"title": row[1], "body": row[2]} 168 | 169 | cosin = pd.read_csv(os.path.join(dir, "{}_cosidf.txt".format(lang)), sep="\t") 170 | dup_pair_ids = {} 171 | for idx, row in cosin.iterrows(): 172 | if row["label"] == 1: 173 | dup_pair_ids[int(row["qid1"])] = int(row["qid2"]) 174 | 175 | test_qs = open(os.path.join(dir, "{}_test_qid.txt".format(lang))).read().split("\n")[:-1] 176 | for q_id in test_qs: 177 | if int(q_id) in dup_pair_ids: 178 | dup_id = 
dup_pair_ids[int(q_id)] 179 | duplicated_q_pairs.append((qid2all_dic[int(q_id)], qid2all_dic[dup_id])) 180 | return duplicated_q_pairs 181 | 182 | linkso_data_python = "/private/home/akariasai/inst_dpr/preprocessing/linkso_data/topublish/python" 183 | full_corpus, full_queries, full_qrels = find_duplicated_qestions(linkso_data_python, "python") 184 | linkso_dups_python = find_duplicated_qestions(linkso_data_python, "python") 185 | qid2queries = {item["_id"]: item["text"] for item in full_queries} 186 | qid2corpus = {item["_id"]: item for item in full_corpus} 187 | 188 | with jsonlines.open('linkso_py/queries.jsonl', 'w') as writer: 189 | writer.write_all(full_queries) 190 | with jsonlines.open('linkso_py/corpus.jsonl', 'w') as writer: 191 | writer.write_all(full_corpus) 192 | with open('linkso_py/qrels/test.tsv', 'wt') as out_file: 193 | tsv_writer = csv.writer(out_file, delimiter='\t') 194 | tsv_writer.writerow(['query-id', 'corpus-id', "score"]) 195 | for item in full_qrels: 196 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]]) 197 | 198 | # CodeSearch Net Py 199 | python_code_serach_net = datasets.load_dataset("code_search_net", "python") 200 | python_short_descs = [item for item in python_code_serach_net["test"] if len(item["func_documentation_string"]) < 300 and len(item["func_documentation_string"]) > 50] 201 | 202 | full_corpus = [] 203 | full_queries = [] 204 | full_qrels = [] 205 | answer2id = {} 206 | 207 | for idx, item in tqdm(enumerate(python_code_serach_net["train"])): 208 | doc_id = "codeserachnet_python_train_{0}_{1}".format(idx, item["func_name"]) 209 | if '"""' in item["func_code_string"]: 210 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "") 211 | elif "'''" in item["func_code_string"]: 212 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "") 213 | else: 214 | code = item["func_code_string"] 215 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" }) 216 | answer2id[code] = doc_id 217 | 218 | for idx, item in tqdm(enumerate(python_code_serach_net["validation"])): 219 | doc_id = "codeserachnet_python_validation_{0}_{1}".format(idx, item["func_name"]) 220 | if '"""' in item["func_code_string"]: 221 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "") 222 | elif "'''" in item["func_code_string"]: 223 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "") 224 | else: 225 | code = item["func_code_string"] 226 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" }) 227 | answer2id[code] = doc_id 228 | 229 | for idx, item in tqdm(enumerate(python_code_serach_net["test"])): 230 | doc_id = "codeserachnet_python_test_{0}_{1}".format(idx, item["func_name"]) 231 | if '"""' in item["func_code_string"]: 232 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "") 233 | elif "'''" in item["func_code_string"]: 234 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "") 235 | else: 236 | code = item["func_code_string"] 237 | full_corpus.append({"_id": doc_id, "text": code, "metadata": {}, "title": "" }) 238 | answer2id[code] = doc_id 239 | 240 | random_sampled_python_short_descs= random.sample(python_short_descs, k = 1000) 241 | 242 | for idx, item in 
enumerate(random_sampled_python_short_descs): 243 | qid = "python_codesearch_{}".format(idx) 244 | 245 | query = item["func_documentation_string"] 246 | if '"""' in item["func_code_string"]: 247 | code = (item["func_code_string"].split('"""')[0] + item["func_code_string"].split('"""')[2]).replace("\n\n", "") 248 | elif "'''" in item["func_code_string"]: 249 | code = (item["func_code_string"].split("'''")[0] + item["func_code_string"].split("'''")[2]).replace("\n\n", "") 250 | else: 251 | code = item["func_code_string"] 252 | corpus_id = answer2id[code] 253 | full_queries.append({"_id": qid, "text": query, "metadata": {}}) 254 | full_qrels.append({"corpus-id": corpus_id, "query-id": qid, "score": 1}) 255 | 256 | os.mkdir("codesearch_py/qrels") 257 | with jsonlines.open('codesearch_py/queries.jsonl', 'w') as writer: 258 | writer.write_all(full_queries) 259 | with jsonlines.open('codesearch_py/corpus.jsonl', 'w') as writer: 260 | writer.write_all(full_corpus) 261 | with open('codesearch_py/qrels/test.tsv', 'wt') as out_file: 262 | tsv_writer = csv.writer(out_file, delimiter='\t') 263 | tsv_writer.writerow(['query-id', 'corpus-id', "score"]) 264 | for item in full_qrels: 265 | tsv_writer.writerow([item["query-id"], item["corpus-id"], item["score"]]) -------------------------------------------------------------------------------- /cross_task_cross_eval/download_create_data.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | mkdir data 4 | cd data 5 | wget https://github.com/allenai/gooaq/raw/main/data/gooaq.jsonl 6 | gdown 1X5GoVi_OcRxahXH1pRW7TSesZUeMH3ss -O linkso.tar.gz 7 | wget https://nlp.cs.washington.edu/ambigqa/data/ambignq_light.zip 8 | 9 | tar xvzf linkso.tar.gz 10 | unzip ambignq_light.zip 11 | cd .. 12 | 13 | python create_cross_task_data.py -------------------------------------------------------------------------------- /figures/intro.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/tart/4ed5fcb0ed0254b1062305adbc390617b296fd29/figures/intro.png --------------------------------------------------------------------------------
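
The following minimal sketches illustrate how a few of the modules listed above can be driven. Model names, file paths, and inputs in these sketches are illustrative assumptions rather than values taken from this repository.

`TART/src/rerank.py` exposes a `Rerank` class that takes BEIR-style dictionaries (corpus, queries, and first-stage scores) and re-scores the top documents with a cross-encoder, optionally prepending an instruction prompt to each query. A sketch with toy inputs and an assumed TART-full checkpoint name (a GPU is required, since the class moves the model to `cuda`):

```python
from src.rerank import Rerank

# Toy BEIR-style inputs; in practice these come from a first-stage retriever.
corpus = {"d1": {"title": "Python", "text": "Python is a programming language."}}
queries = {"q1": "how do I sort a dictionary by value?"}
results = {"q1": {"d1": 0.42}}  # first-stage scores, e.g. from Contriever

# "facebook/tart-full-flan-t5-xl" is an assumed checkpoint name; any sequence
# classification model accepted by the constructor works the same way.
reranker = Rerank("facebook/tart-full-flan-t5-xl", batch_size=32)
reranked = reranker.rerank(
    corpus, queries, results, top_k=100,
    prompt="Retrieve a passage that answers this technical question",
)
print(reranked["q1"])  # {"d1": <probability of the relevant class>}
```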
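`TART/src/slurm.py` provides `init_distributed_mode`, which inspects SLURM and `torch.distributed.launch` environment variables and falls back to a single-GPU setup when neither is present. A sketch of that single-process fallback path, assuming one visible GPU and no `SLURM_JOB_ID`/`WORLD_SIZE` variables in the environment:

```python
from types import SimpleNamespace

from src import slurm

# With no SLURM/WORLD_SIZE environment variables and local_rank == -1, the helper
# takes the "local job (single GPU)" branch, fills in the rank fields, and selects GPU 0.
params = SimpleNamespace(local_rank=-1, main_port=-1)
slurm.init_distributed_mode(params)
print(params.local_rank, params.global_rank, params.world_size)  # 0 0 1
```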
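`TART/src/tokenization_enc_t5.py` defines `EncT5Tokenizer`, a `T5Tokenizer` variant that encodes single sequences as `bos + tokens + eos` and pairs as `bos + A + eos + B + eos`, with all token-type ids set to zero. A sketch, assuming a checkpoint that ships this tokenizer's vocabulary (the checkpoint name is only an example):

```python
from src.tokenization_enc_t5 import EncT5Tokenizer

# Assumed checkpoint; any directory containing the EncT5 tokenizer files works.
tok = EncT5Tokenizer.from_pretrained("facebook/tart-full-flan-t5-xl")

ids_a = tok.encode("find a duplicated question", add_special_tokens=False)
ids_b = tok.encode("how do I sort a dict by value?", add_special_tokens=False)

pair = tok.build_inputs_with_special_tokens(ids_a, ids_b)           # bos A eos B eos
mask = tok.get_special_tokens_mask(ids_a, ids_b)                    # 1 marks the 3 added tokens
type_ids = tok.create_token_type_ids_from_sequences(ids_a, ids_b)   # all zeros

assert len(pair) == len(ids_a) + len(ids_b) + 3 == len(mask) == len(type_ids)
```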
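The learning-rate schedulers in `TART/src/utils.py` are thin `LambdaLR` wrappers, so they can be exercised in isolation. A sketch of `WarmupLinearScheduler` with arbitrary step counts and a throwaway model:

```python
import torch

from src.utils import WarmupLinearScheduler

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Warm up linearly for 100 steps, then decay towards ratio * lr over 1000 total steps.
scheduler = WarmupLinearScheduler(optimizer, warmup=100, total=1000, ratio=0.0)
for _ in range(5):
    optimizer.step()
    scheduler.step()
print(scheduler.get_last_lr())  # roughly [5e-6] after 5 of the 100 warmup steps
```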
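`cross_task_cross_eval/create_cross_task_data.py` writes each cross-task evaluation set in the standard BEIR layout (`corpus.jsonl`, `queries.jsonl`, `qrels/test.tsv`), so the output can be read back with BEIR's generic loader. A sketch, assuming the `beir` package is installed and the script has been run for the `codesearch_py` task:

```python
from beir.datasets.data_loader import GenericDataLoader

# "codesearch_py" is the folder produced by create_cross_task_data.py.
corpus, queries, qrels = GenericDataLoader("codesearch_py").load(split="test")

print(len(corpus), len(queries))   # number of code snippets / sampled NL queries
print(next(iter(qrels.items())))   # e.g. ('python_codesearch_0', {'<corpus-id>': 1})
```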