├── requirements.txt
├── data
│   ├── id_remap.py
│   ├── preprocess_orquac.py
│   ├── process_fn.py
│   ├── preprocess_cast19.py
│   ├── preprocess_cast20.py
│   ├── preprocess_cast21.py
│   ├── tokenizing.py
│   └── gen_ranking_data.py
├── LICENSE
├── .gitignore
├── drivers
│   ├── gen_passage_embeddings.py
│   ├── run_convdr_inference.py
│   └── run_convdr_train.py
├── model
│   └── models.py
├── utils
│   ├── dpr_utils.py
│   └── util.py
└── README.md
/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==2.3.0 2 | torch 3 | pytrec-eval 4 | faiss-gpu 5 | scikit-learn 6 | tensorboardX 7 | trec-car-tools 8 | -------------------------------------------------------------------------------- /data/id_remap.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | if __name__ == "__main__": 5 | parser = argparse.ArgumentParser() 6 | parser.add_argument("--convdr_trec", type=str) 7 | parser.add_argument("--doc_idx_to_id", type=str) 8 | parser.add_argument("--out_trec", type=str) 9 | args = parser.parse_args() 10 | 11 | print("Loading doc_idx_to_id...") 12 | with open(args.doc_idx_to_id, "rb") as f: 13 | doc_idx_to_id = pickle.load(f) 14 | 15 | with open(args.convdr_trec, "r") as f, open(args.out_trec, "w") as g: 16 | for line in f: 17 | qid, _, pid, rank, score, label = line.strip().split() 18 | g.write("{} Q0 {} {} {} {}\n".format(qid, doc_idx_to_id[int(pid)], rank, score, label)) -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 THUNLP 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
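A usage sketch for data/id_remap.py above; every path and id in it is hypothetical, chosen only for illustration. The script rewrites the integer passage indices in a ConvDR run file back to the original document ids through the pickled doc_idx_to_id list, preserving the TREC run format (qid, Q0, docid, rank, score, tag):

import pickle

# doc_idx_to_id is the list pickled by data/preprocess_cast21.py (hypothetical path)
with open("collections/cast-21/doc_idx_to_id.pickle", "rb") as f:
    doc_idx_to_id = pickle.load(f)

# A run line "106_1 Q0 4891 1 11.25 convdr" is rewritten using doc_idx_to_id[4891],
# yielding e.g. "106_1 Q0 MARCO_D59219-2 1 11.25 convdr" in --out_trec.
print(doc_idx_to_id[4891])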
22 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | commands/runs/ 4 | runs/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | pip-wheel-metadata/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | target/ 78 | 79 | # Jupyter Notebook 80 | .ipynb_checkpoints 81 | 82 | # IPython 83 | profile_default/ 84 | ipython_config.py 85 | 86 | # pyenv 87 | .python-version 88 | 89 | # pipenv 90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 93 | # install all needed dependencies. 94 | #Pipfile.lock 95 | 96 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 97 | __pypackages__/ 98 | 99 | # Celery stuff 100 | celerybeat-schedule 101 | celerybeat.pid 102 | 103 | # SageMath parsed files 104 | *.sage.py 105 | 106 | # Environments 107 | .env 108 | .venv 109 | env/ 110 | venv/ 111 | ENV/ 112 | env.bak/ 113 | venv.bak/ 114 | 115 | # Spyder project settings 116 | .spyderproject 117 | .spyproject 118 | 119 | # Rope project settings 120 | .ropeproject 121 | 122 | # mkdocs documentation 123 | /site 124 | 125 | # mypy 126 | .mypy_cache/ 127 | .dmypy.json 128 | dmypy.json 129 | 130 | # Pyre type checker 131 | .pyre/ 132 | 133 | logs/ 134 | checkpoints/ 135 | datasets/ 136 | results/ 137 | 138 | .vscode/ -------------------------------------------------------------------------------- /data/preprocess_orquac.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import json 4 | from tqdm import tqdm 5 | import pickle 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser() 9 | parser.add_argument("--orquac_dir", type=str) 10 | parser.add_argument("--output_dir", type=str) 11 | args = parser.parse_args() 12 | 13 | if not os.path.exists(args.output_dir): 14 | os.mkdir(args.output_dir) 15 | 16 | # all_blocks.txt -> collection.tsv 17 | print("Processing all_blocks.txt...") 18 | all_blocks = os.path.join(args.orquac_dir, "all_blocks.txt") 19 | collection = os.path.join(args.output_dir, "collection.jsonl") 20 | passage_id_to_idx = {} 21 | with open(all_blocks, "r") as f, open(collection, "w") as g: 22 | idx = 0 23 | for line in tqdm(f): 24 | obj = json.loads(line) 25 | passage = obj['text'].replace('\n', ' ').replace('\t', ' ') 26 | pid = obj['id'] 27 | g.write( 28 | json.dumps({ 29 | "id": idx, 30 | "title": obj["title"], 31 | "text": passage 32 | }) + "\n") 33 | passage_id_to_idx[pid] = idx 34 | idx += 1 35 | 36 | # train/dev/test.txt -> queries.train/dev/test.manual/raw.tsv 37 | targets = ['train', 'dev', 'test'] 38 | qids_set = {"train": set(), "dev": set(), "test": set()} 39 | idx = 0 40 | for target in targets: 41 | print(f"Processing {target}.txt") 42 | train = os.path.join(args.orquac_dir, "preprocessed", f"{target}.txt") 43 | queries_manual = os.path.join(args.output_dir, 44 | f"queries.{target}.manual.tsv") 45 | queries_raw = os.path.join(args.output_dir, 46 | f"queries.{target}.raw.tsv") 47 | cqr = os.path.join(args.output_dir, "{}.jsonl".format(target)) 48 | with open(train, "r") as f, open(queries_manual, "w") as g, open( 49 | cqr, "w") as h, open(queries_raw, "w") as i: 50 | responses = [] 51 | last_dialog_id = None 52 | for line in f: 53 | obj = json.loads(line) 54 | qid, query = obj['qid'], obj['rewrite'] 55 | raw_query = obj["question"] 56 | dialog_id = qid[:qid.rfind('#')] 57 | if dialog_id != last_dialog_id: 58 | last_dialog_id = dialog_id 59 | responses.clear() 60 | cur_response = obj["answer"]["text"] 61 | responses.append(cur_response) 62 | input_sents = [] 63 | for his in obj["history"]: 64 | input_sents.append(his["question"]) 65 | input_sents.append(obj["question"]) 66 | h.write( 67 | json.dumps({ 68 | "qid": qid, 69 | "input": input_sents, 70 | "target": query, 71 | "manual_response": responses 72 | }) + "\n") 73 | g.write(f"{qid}\t{query}\n") 74 | i.write(f"{qid}\t{raw_query}\n") 75 | qids_set[target].add(qid) 76 | idx += 1 77 | 78 | # qrels.txt -> qrels.train.tsv 79 | print("Processing qrels.txt...") 80 | qrels = os.path.join(args.orquac_dir, "qrels.txt") 81 | with open(qrels, "r") as f: 82 | qrels_dict = json.load(f) 83 
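# A sketch of the structure converted below, with made-up ids for illustration:
# OR-QuAC's qrels.txt holds one JSON object mapping each question id to its
# relevant passages, e.g. {"C_1a2b_q#2": {"WIKI_Some_Page_3": 1}}; the loop
# flattens it into TREC-style "qid\t0\t<passage_idx>\t1" lines via the
# passage_id_to_idx mapping built from all_blocks.txt above.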
| target_qrels_file = open(os.path.join(args.output_dir, "qrels.tsv"), "w") 84 | for qid, v in qrels_dict.items(): 85 | for pid in v.keys(): 86 | passage_idx = passage_id_to_idx[pid] 87 | target_qrels_file.write(f"{qid}\t0\t{passage_idx}\t1\n") 88 | target_qrels_file.close() 89 | 90 | print("End") -------------------------------------------------------------------------------- /data/process_fn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pad_ids(input_ids, attention_mask, token_type_ids, max_length, pad_token, mask_padding_with_zero, pad_token_segment_id, pad_on_left=False): 5 | padding_length = max_length - len(input_ids) 6 | if pad_on_left: 7 | input_ids = ([pad_token] * padding_length) + input_ids 8 | attention_mask = ([0 if mask_padding_with_zero else 1] 9 | * padding_length) + attention_mask 10 | token_type_ids = ([pad_token_segment_id] * 11 | padding_length) + token_type_ids 12 | else: 13 | input_ids += [pad_token] * padding_length 14 | attention_mask += [0 if mask_padding_with_zero else 1] * padding_length 15 | token_type_ids += [pad_token_segment_id] * padding_length 16 | 17 | return input_ids, attention_mask, token_type_ids 18 | 19 | 20 | def dual_process_fn(line, i, tokenizer, args): 21 | features = [] 22 | cells = line.split("\t") 23 | if len(cells) == 2: 24 | # this is for training and validation 25 | # id, passage = line 26 | mask_padding_with_zero = True 27 | pad_token_segment_id = 0 28 | pad_on_left = False 29 | 30 | text = cells[1].strip() 31 | input_id_a = tokenizer.encode( 32 | text, add_special_tokens=True, max_length=args.max_seq_length,) 33 | token_type_ids_a = [0] * len(input_id_a) 34 | attention_mask_a = [ 35 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 36 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 37 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left) 38 | features += [torch.tensor(input_id_a, dtype=torch.int), torch.tensor( 39 | attention_mask_a, dtype=torch.bool), torch.tensor(token_type_ids_a, dtype=torch.uint8)] 40 | qid = int(cells[0]) 41 | features.append(qid) 42 | else: 43 | raise Exception( 44 | "Line doesn't have correct length: {0}. Expected 2.".format(str(len(cells)))) 45 | return [features] 46 | 47 | 48 | def triple_process_fn(line, i, tokenizer, args): 49 | features = [] 50 | cells = line.split("\t") 51 | if len(cells) == 3: 52 | # this is for training and validation 53 | # query, positive_passage, negative_passage = line 54 | mask_padding_with_zero = True 55 | pad_token_segment_id = 0 56 | pad_on_left = False 57 | 58 | for text in cells: 59 | input_id_a = tokenizer.encode( 60 | text.strip(), add_special_tokens=True, max_length=args.max_seq_length,) 61 | token_type_ids_a = [0] * len(input_id_a) 62 | attention_mask_a = [ 63 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 64 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 65 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left) 66 | features += [torch.tensor(input_id_a, dtype=torch.int), 67 | torch.tensor(attention_mask_a, dtype=torch.bool)] 68 | else: 69 | raise Exception( 70 | "Line doesn't have correct length: {0}. 
Expected 3.".format(str(len(cells)))) 71 | return [features] 72 | 73 | 74 | def triple2dual_process_fn(line, i, tokenizer, args): 75 | ret = [] 76 | cells = line.split("\t") 77 | if len(cells) == 3: 78 | # this is for training and validation 79 | # query, positive_passage, negative_passage = line 80 | # return 2 entries per line, 1 pos + 1 neg 81 | mask_padding_with_zero = True 82 | pad_token_segment_id = 0 83 | pad_on_left = False 84 | pos_feats = [] 85 | neg_feats = [] 86 | 87 | for i, text in enumerate(cells): 88 | input_id_a = tokenizer.encode( 89 | text.strip(), add_special_tokens=True, max_length=args.max_seq_length,) 90 | token_type_ids_a = [0] * len(input_id_a) 91 | attention_mask_a = [ 92 | 1 if mask_padding_with_zero else 0] * len(input_id_a) 93 | input_id_a, attention_mask_a, token_type_ids_a = pad_ids( 94 | input_id_a, attention_mask_a, token_type_ids_a, args.max_seq_length, tokenizer.pad_token_id, mask_padding_with_zero, pad_token_segment_id, pad_on_left) 95 | if i == 0: 96 | pos_feats += [torch.tensor(input_id_a, dtype=torch.int), 97 | torch.tensor(attention_mask_a, dtype=torch.bool)] 98 | neg_feats += [torch.tensor(input_id_a, dtype=torch.int), 99 | torch.tensor(attention_mask_a, dtype=torch.bool)] 100 | elif i == 1: 101 | pos_feats += [torch.tensor(input_id_a, dtype=torch.int), 102 | torch.tensor(attention_mask_a, dtype=torch.bool), 1] 103 | else: 104 | neg_feats += [torch.tensor(input_id_a, dtype=torch.int), 105 | torch.tensor(attention_mask_a, dtype=torch.bool), 0] 106 | ret = [pos_feats, neg_feats] 107 | else: 108 | raise Exception( 109 | "Line doesn't have correct length: {0}. Expected 3.".format(str(len(cells)))) 110 | return ret 111 | 112 | -------------------------------------------------------------------------------- /data/preprocess_cast19.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from trec_car import read_data 3 | from tqdm import tqdm 4 | import pickle 5 | import os 6 | import json 7 | import copy 8 | from utils.util import NUM_FOLD 9 | 10 | 11 | def parse_sim_file(filename): 12 | """ 13 | Reads the deduplicated documents file and stores the 14 | duplicate passage ids into a dictionary 15 | """ 16 | 17 | sim_dict = {} 18 | lines = open(filename).readlines() 19 | for line in lines: 20 | data = line.strip().split(':') 21 | if len(data[1]) > 0: 22 | sim_docs = data[-1].split(',') 23 | for docs in sim_docs: 24 | sim_dict[docs] = 1 25 | 26 | return sim_dict 27 | 28 | 29 | if __name__ == "__main__": 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument("--car_cbor", type=str) 32 | parser.add_argument("--msmarco_collection", type=str) 33 | parser.add_argument("--duplicate_file", type=str) 34 | parser.add_argument("--cast_dir", type=str) 35 | 36 | parser.add_argument("--out_data_dir", type=str) 37 | parser.add_argument("--out_collection_dir", type=str) 38 | args = parser.parse_args() 39 | 40 | # INPUT 41 | sim_file = args.duplicate_file 42 | cast_topics_raw_file = os.path.join(args.cast_dir, 43 | "evaluation_topics_v1.0.json") 44 | cast_topics_manual_file = os.path.join( 45 | args.cast_dir, "evaluation_topics_annotated_resolved_v1.0.tsv") 46 | cast_qrels_file = os.path.join(args.cast_dir, "2019qrels.txt") 47 | 48 | # OUTPUT 49 | out_topics_file = os.path.join(args.out_data_dir, "eval_topics.jsonl") 50 | out_raw_queries_file = os.path.join(args.out_data_dir, "queries.raw.tsv") 51 | out_manual_queries_file = os.path.join(args.out_data_dir, 52 | "queries.manual.tsv") 53 | out_qrels_file = 
os.path.join(args.out_data_dir, "qrels.tsv") 54 | car_id_to_idx_file = os.path.join(args.out_collection_dir, 55 | "car_id_to_idx.pickle") 56 | car_idx_to_id_file = os.path.join(args.out_collection_dir, 57 | "car_idx_to_id.pickle") 58 | out_collection_file = os.path.join(args.out_collection_dir, 59 | "collection.tsv") 60 | 61 | # 1. Combine TREC-CAR & MS MARCO, remove duplicate passages, assign new ids 62 | car_id_to_idx = {} 63 | car_idx_to_id = [] 64 | if os.path.exists(out_collection_file) and os.path.exists( 65 | car_id_to_idx_file) and os.path.exists(car_idx_to_id_file): 66 | print("Preprocessed collection found. Loading car_id_to_idx...") 67 | with open(car_id_to_idx_file, "rb") as f: 68 | car_id_to_idx = pickle.load(f) 69 | else: 70 | sim_dict = parse_sim_file(sim_file) 71 | car_base_id = 10000000 72 | i = 0 73 | with open(out_collection_file, "w") as f: 74 | print("Processing TREC-CAR...") 75 | for para in tqdm( 76 | read_data.iter_paragraphs(open(args.car_cbor, 'rb'))): 77 | car_id = "CAR_" + para.para_id 78 | text = para.get_text() 79 | text = text.replace("\t", " ").replace("\n", 80 | " ").replace("\r", " ") 81 | idx = car_base_id + i 82 | car_id_to_idx[ 83 | car_id] = idx # e.g. CAR_76a4a716d4b1b01995c6663ee16e94b4ca35fdd3 -> 10000044 84 | car_idx_to_id.append(car_id) 85 | f.write("{}\t{}\n".format(idx, text)) 86 | i += 1 87 | print("Processing MS MARCO...") 88 | removed = 0 89 | with open(args.msmarco_collection, "r") as m: 90 | for line in tqdm(m): 91 | marco_id, text = line.strip().split("\t") 92 | if ("MARCO_" + marco_id) in sim_dict: 93 | removed += 1 94 | continue 95 | f.write("{}\t{}\n".format(marco_id, text)) 96 | print("Removed " + str(removed) + " passages") 97 | print("Dumping id mappings...") 98 | with open(car_id_to_idx_file, "wb") as f: 99 | pickle.dump(car_id_to_idx, f) 100 | with open(car_idx_to_id_file, "wb") as f: 101 | pickle.dump(car_idx_to_id, f) 102 | 103 | # 2. 
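# (A note on the id scheme from step 1, before the query processing below:
# MS MARCO passages keep their original integer ids, while TREC-CAR passages
# get car_base_id + i, so any pid >= 10000000 is a CAR passage and, as a
# hypothetical lookup, car_idx_to_id[pid - 10000000] recovers its "CAR_<sha1>"
# id; everything below that threshold maps back to "MARCO_<pid>".)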
Process queries 104 | print("Processing CAsT utterances...") 105 | with open(cast_topics_raw_file, "r") as fin: 106 | raw_data = json.load(fin) 107 | 108 | with open(cast_topics_manual_file, "r") as fin: 109 | annonated_lines = fin.readlines() 110 | 111 | out_raw_queries = open(out_raw_queries_file, "w") 112 | out_manual_queries = open(out_manual_queries_file, "w") 113 | 114 | all_annonated = {} 115 | for line in annonated_lines: 116 | splitted = line.split('\t') 117 | out_manual_queries.write(line) 118 | topic_query = splitted[0] 119 | query = splitted[1].strip() 120 | topic_id = topic_query.split('_')[0] 121 | query_id = topic_query.split('_')[1] 122 | if topic_id not in all_annonated: 123 | all_annonated[topic_id] = {} 124 | all_annonated[topic_id][query_id] = query 125 | out_manual_queries.close() 126 | 127 | topic_number_dict = {} 128 | data = [] 129 | for group in raw_data: 130 | topic_number, description, turn, title = str( 131 | group['number']), group.get('description', 132 | ''), group['turn'], group.get( 133 | 'title', '') 134 | queries = [] 135 | for query in turn: 136 | query_number, raw_utterance = str( 137 | query['number']), query['raw_utterance'] 138 | queries.append(raw_utterance) 139 | record = {} 140 | record['topic_number'] = topic_number 141 | record['query_number'] = query_number 142 | record['description'] = description 143 | record['title'] = title 144 | record['input'] = copy.deepcopy(queries) 145 | record['target'] = all_annonated[topic_number][query_number] 146 | out_raw_queries.write("{}_{}\t{}\n".format(topic_number, 147 | query_number, 148 | raw_utterance)) 149 | if not topic_number in topic_number_dict: 150 | topic_number_dict[topic_number] = len(topic_number_dict) 151 | data.append(record) 152 | out_raw_queries.close() 153 | 154 | with open(out_topics_file, 'w') as fout: 155 | for item in data: 156 | json_str = json.dumps(item) 157 | fout.write(json_str + '\n') 158 | 159 | # Split eval data into K-fold 160 | topic_per_fold = len(topic_number_dict) // NUM_FOLD 161 | for i in range(NUM_FOLD): 162 | with open(out_topics_file + "." + str(i), 'w') as fout: 163 | for item in data: 164 | idx = topic_number_dict[item['topic_number']] 165 | if idx // topic_per_fold == i: 166 | json_str = json.dumps(item) 167 | fout.write(json_str + '\n') 168 | 169 | # 3. 
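# (What the conversion below does: "CAR_<sha1>" ids go through car_id_to_idx,
# "MARCO_<n>" ids become the bare integer n, and ids with any other prefix are
# skipped. For example, a hypothetical qrels line "31_1 0 MARCO_955948 2"
# comes out as "31_1\t0\t955948\t2".)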
Process and convert qrels 170 | print("Processing qrels...") 171 | with open(cast_qrels_file, "r") as oq, open(out_qrels_file, "w") as nq: 172 | for line in oq: 173 | qid, _, pid, rel = line.strip().split() 174 | if pid.startswith("CAR_"): 175 | assert car_id_to_idx[pid] != -1 176 | pid = car_id_to_idx[pid] 177 | elif pid.startswith("MARCO_"): 178 | pid = int(pid[6:]) 179 | else: 180 | continue 181 | nq.write(qid + "\t0\t" + str(pid) + "\t" + rel + "\n") 182 | 183 | print("End") -------------------------------------------------------------------------------- /data/preprocess_cast20.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from trec_car import read_data 3 | from tqdm import tqdm 4 | import pickle 5 | import os 6 | import json 7 | import copy 8 | from utils.util import NUM_FOLD 9 | 10 | topic_range = range(81, 106) 11 | fold_dict = {x: (x - 81) // NUM_FOLD for x in topic_range} 12 | 13 | 14 | def parse_sim_file(filename): 15 | """ 16 | Reads the deduplicated documents file and stores the 17 | duplicate passage ids into a dictionary 18 | """ 19 | 20 | sim_dict = {} 21 | lines = open(filename).readlines() 22 | for line in lines: 23 | data = line.strip().split(':') 24 | if len(data[1]) > 0: 25 | sim_docs = data[-1].split(',') 26 | for docs in sim_docs: 27 | sim_dict[docs] = 1 28 | 29 | return sim_dict 30 | 31 | 32 | if __name__ == "__main__": 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--car_cbor", type=str) 35 | parser.add_argument("--msmarco_collection", type=str) 36 | parser.add_argument("--duplicate_file", type=str) 37 | parser.add_argument("--cast_dir", type=str) 38 | 39 | parser.add_argument("--out_data_dir", type=str) 40 | parser.add_argument("--out_collection_dir", type=str) 41 | args = parser.parse_args() 42 | 43 | # INPUT 44 | sim_file = args.duplicate_file 45 | cast_topics_automatic_file = os.path.join( 46 | args.cast_dir, "2020_automatic_evaluation_topics_v1.0.json") 47 | cast_topics_manual_file = os.path.join( 48 | args.cast_dir, "2020_manual_evaluation_topics_v1.0.json") 49 | cast_qrels_file = os.path.join(args.cast_dir, "2020qrels.txt") 50 | 51 | # OUTPUT 52 | out_topics_file = os.path.join(args.out_data_dir, "eval_topics.jsonl") 53 | out_raw_queries_file = os.path.join(args.out_data_dir, "queries.raw.tsv") 54 | out_manual_queries_file = os.path.join(args.out_data_dir, 55 | "queries.manual.tsv") 56 | out_qrels_file = os.path.join(args.out_data_dir, "qrels.tsv") 57 | car_id_to_idx_file = os.path.join(args.out_collection_dir, 58 | "car_id_to_idx.pickle") 59 | car_idx_to_id_file = os.path.join(args.out_collection_dir, 60 | "car_idx_to_id.pickle") 61 | out_collection_file = os.path.join(args.out_collection_dir, 62 | "collection.tsv") 63 | 64 | # 1. Combine TREC-CAR & MS MARCO, remove duplicate passages, assign new ids 65 | car_id_to_idx = {} 66 | car_idx_to_id = [] 67 | collection = ["xx"] * 4000_0000 68 | if os.path.exists(out_collection_file) and os.path.exists( 69 | car_id_to_idx_file) and os.path.exists(car_idx_to_id_file): 70 | print("Preprocessed collection found. 
Loading car_id_to_idx...") 71 | with open(car_id_to_idx_file, "rb") as f: 72 | car_id_to_idx = pickle.load(f) 73 | print("Loading processed document collection...") 74 | with open(out_collection_file, "r") as f: 75 | for line in f: 76 | try: 77 | line = line.strip() 78 | obj = line.split("\t") 79 | pid = obj[0] 80 | text = obj[1] 81 | pid = int(pid) 82 | collection[pid] = text 83 | except IndexError: 84 | print(line) 85 | else: 86 | sim_dict = parse_sim_file(sim_file) 87 | car_base_id = 10000000 88 | i = 0 89 | with open(out_collection_file, "w") as f:  # write the merged collection here; the id mappings are pickled below 90 | print("Processing TREC-CAR...") 91 | for para in tqdm( 92 | read_data.iter_paragraphs(open(args.car_cbor, 'rb'))): 93 | car_id = "CAR_" + para.para_id 94 | text = para.get_text() 95 | text = text.replace("\t", " ").replace("\n", 96 | " ").replace("\r", " ") 97 | idx = car_base_id + i 98 | car_id_to_idx[ 99 | car_id] = idx  # e.g. CAR_76a4a716d4b1b01995c6663ee16e94b4ca35fdd3 -> 10000044 100 | collection[idx] = text 101 | car_idx_to_id.append(car_id) 102 | f.write("{}\t{}\n".format(idx, text)) 103 | i += 1 104 | print("Processing MS MARCO...") 105 | removed = 0 106 | with open(args.msmarco_collection, "r") as m: 107 | for line in tqdm(m): 108 | marco_id, text = line.strip().split("\t") 109 | if ("MARCO_" + marco_id) in sim_dict: 110 | removed += 1 111 | continue 112 | collection[int(marco_id)] = text 113 | f.write("{}\t{}\n".format(marco_id, text)) 114 | print("Removed " + str(removed) + " passages") 115 | print("Dumping id mappings...") 116 | with open(car_id_to_idx_file, "wb") as f: 117 | pickle.dump(car_id_to_idx, f) 118 | with open(car_idx_to_id_file, "wb") as f: 119 | pickle.dump(car_idx_to_id, f) 120 | 121 | # 2. Process queries 122 | print("Processing CAsT utterances...") 123 | 124 | def get_text_by_raw_id(raw_id): 125 | new_id = None 126 | if raw_id.startswith("MARCO_"): 127 | new_id = int(raw_id[6:]) 128 | elif raw_id.startswith("CAR_"): 129 | new_id = car_id_to_idx[raw_id] 130 | else: 131 | raise ValueError("Invalid document id") 132 | text = collection[new_id] 133 | if text == "xx": 134 | raise ValueError("Unknown document") 135 | return text 136 | 137 | with open(cast_topics_automatic_file, "r") as f: 138 | auto_raw = json.load(f) 139 | with open(cast_topics_manual_file, "r") as f: 140 | manual_raw = json.load(f) 141 | out_topics = open(out_topics_file, "w") 142 | out_topics_fold = open(out_topics_file + ".0", "w") 143 | out_raw_queries = open(out_raw_queries_file, "w") 144 | out_manual_queries = open(out_manual_queries_file, "w") 145 | cur_fold = 0 146 | for auto_topic, manual_topic in zip(auto_raw, manual_raw): 147 | topic_number = auto_topic["number"] 148 | assert topic_number == manual_topic["number"] 149 | auto_turns = auto_topic["turn"] 150 | manual_turns = manual_topic["turn"] 151 | assert len(auto_turns) == len(manual_turns) 152 | inputs = [] 153 | manual_responses = [] 154 | auto_responses = [] 155 | manual_res_ids = [] 156 | auto_res_ids = [] 157 | for auto_turn, manual_turn in zip(auto_turns, manual_turns): 158 | query_number = auto_turn["number"] 159 | 160 | raw = auto_turn["raw_utterance"] 161 | inputs.append(raw) 162 | target = manual_turn["manual_rewritten_utterance"] 163 | 164 | manual_res_ids.append(manual_turn["manual_canonical_result_id"]) 165 | response = get_text_by_raw_id( 166 | manual_turn["manual_canonical_result_id"]) 167 | manual_responses.append(response) 168 | 169 | auto_res_ids.append(auto_turn["automatic_canonical_result_id"]) 170 | response = get_text_by_raw_id( 171 | 
auto_turn["automatic_canonical_result_id"]) 172 | auto_responses.append(response) 173 | 174 | output_dict = { 175 | "topic_number": topic_number, 176 | "query_number": query_number, 177 | "input": copy.deepcopy(inputs), 178 | "automatic_response_id": copy.deepcopy(auto_res_ids), 179 | "automatic_response": copy.deepcopy(auto_responses), 180 | "manual_response_id": copy.deepcopy(manual_res_ids), 181 | "manual_response": copy.deepcopy(manual_responses), 182 | "target": target 183 | } 184 | 185 | dumped_str = json.dumps(output_dict) + "\n" 186 | out_topics.write(dumped_str) 187 | if fold_dict[topic_number] != cur_fold: 188 | out_topics_fold.close() 189 | out_topics_fold = open( 190 | out_topics_file + "." + str(fold_dict[topic_number]), "w") 191 | cur_fold = fold_dict[topic_number] 192 | out_topics_fold.write(dumped_str) 193 | 194 | out_raw_queries.write( 195 | str(topic_number) + "_" + str(query_number) + "\t" + raw + 196 | "\n") 197 | out_manual_queries.write( 198 | str(topic_number) + "_" + str(query_number) + "\t" + target + 199 | "\n") 200 | 201 | # 3. Process and convert qrels 202 | print("Processing qrels...") 203 | with open(cast_qrels_file, "r") as oq, open(out_qrels_file, "w") as nq: 204 | for line in oq: 205 | qid, _, pid, rel = line.strip().split() 206 | if pid.startswith("CAR_"): 207 | assert car_id_to_idx[pid] != -1 208 | pid = car_id_to_idx[pid] 209 | elif pid.startswith("MARCO_"): 210 | pid = int(pid[6:]) 211 | else: 212 | continue 213 | nq.write(qid + "\t0\t" + str(pid) + "\t" + rel + "\n") 214 | 215 | print("End") -------------------------------------------------------------------------------- /data/preprocess_cast21.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from trec_car import read_data 3 | from tqdm import tqdm 4 | import pickle 5 | import os 6 | import json 7 | import copy 8 | from utils.util import NUM_FOLD 9 | 10 | topic_range = range(106, 132) 11 | fold_dict = {x: (x - 106) // NUM_FOLD for x in topic_range} 12 | 13 | 14 | def parse_sim_file(filename): 15 | """ 16 | Reads the deduplicated documents file and stores the 17 | duplicate passage ids into a dictionary 18 | """ 19 | 20 | sim_dict = {} 21 | lines = open(filename).readlines() 22 | for line in lines: 23 | data = line.strip().split(':') 24 | if len(data[1]) > 0: 25 | sim_docs = data[-1].split(',') 26 | for docs in sim_docs: 27 | sim_dict[docs] = 1 28 | 29 | return sim_dict 30 | 31 | 32 | def main(): 33 | parser = argparse.ArgumentParser() 34 | parser.add_argument("--kilt", type=str) 35 | parser.add_argument("--msmarco", type=str) 36 | parser.add_argument("--wapo", type=str) 37 | # parser.add_argument("--duplicate_file", type=str) 38 | parser.add_argument("--cast_dir", type=str) 39 | 40 | parser.add_argument("--out_data_dir", type=str) 41 | parser.add_argument("--out_collection_dir", type=str) 42 | args = parser.parse_args() 43 | 44 | # INPUT 45 | cast_topics_manual_file = os.path.join( 46 | args.cast_dir, "2021_manual_evaluation_topics_v1.0.json") 47 | 48 | # OUTPUT 49 | out_topics_file = os.path.join(args.out_data_dir, "eval_topics.jsonl") 50 | out_raw_queries_file = os.path.join(args.out_data_dir, "queries.raw.tsv") 51 | out_manual_queries_file = os.path.join(args.out_data_dir, 52 | "queries.manual.tsv") 53 | doc_id_to_idx_file = os.path.join(args.out_collection_dir, 54 | "doc_id_to_idx.pickle") 55 | doc_idx_to_id_file = os.path.join(args.out_collection_dir, 56 | "doc_idx_to_id.pickle") 57 | out_collection_file = 
os.path.join(args.out_collection_dir, 58 | "collection.tsv") 59 | out_psuedo_qrels_file = os.path.join(args.out_data_dir, "qrels.tsv") 60 | 61 | # 1. Combine KILT, MS MARCO * WaPo, remove duplicate passages, assign new ids 62 | doc_id_to_idx = {} 63 | doc_idx_to_id = [] 64 | collection = ["xx"] * 8000_0000 65 | if os.path.exists(out_collection_file) and os.path.exists( 66 | doc_id_to_idx_file) and os.path.exists(doc_idx_to_id_file): 67 | print("Preprocessed collection found. Loading car_id_to_idx...") 68 | with open(doc_id_to_idx_file, "rb") as f: 69 | doc_id_to_idx = pickle.load(f) 70 | print("Loading processed document collection...") 71 | with open(out_collection_file, "r") as f: 72 | for line in f: 73 | try: 74 | line = line.strip() 75 | obj = line.split("\t") 76 | pid = obj[0] 77 | text = obj[1] 78 | pid = int(pid) 79 | collection[pid] = text 80 | except IndexError: 81 | print(line) 82 | except ValueError: 83 | print(line) 84 | else: 85 | with open(out_collection_file, "w") as f: 86 | print("Processing KILT...") 87 | with open(args.kilt, "r") as k: 88 | all_content = k.read() 89 | pidx = parse_documents(all_content, doc_id_to_idx, doc_idx_to_id, collection, f, pidx=0) 90 | print("Processing MS MARCO...") 91 | with open(args.msmarco, "r") as m: 92 | all_content = m.read() 93 | pidx = parse_documents(all_content, doc_id_to_idx, doc_idx_to_id, collection, f, pidx) 94 | print("Processing WaPo...") 95 | with open(args.wapo, "r") as w: 96 | all_content = w.read() 97 | pidx = parse_documents(all_content, doc_id_to_idx, doc_idx_to_id, collection, f, pidx) 98 | print("Total document num: {}".format(pidx)) 99 | print("Dumping id mappings...") 100 | with open(doc_id_to_idx_file, "wb") as f: 101 | pickle.dump(doc_id_to_idx, f) 102 | with open(doc_idx_to_id_file, "wb") as f: 103 | pickle.dump(doc_idx_to_id, f) 104 | 105 | # 2. 
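# (Step 1 above gave every <DOCNO>/<passage id> pair a sequential integer pidx,
# recorded in doc_id_to_idx and doc_idx_to_id. The query processing below joins
# each turn's canonical_result_id with its passage_id into a key such as the
# hypothetical "MARCO_D59219-2" and resolves it through doc_id_to_idx to fetch
# the canonical passage text from the in-memory collection.)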
Process queries 106 | print("Processing CAsT utterances...") 107 | 108 | def get_text_by_raw_id(raw_id): 109 | new_id = doc_id_to_idx[raw_id] 110 | text = collection[new_id] 111 | if text == "xx": 112 | raise ValueError("Unknown document") 113 | return text, new_id 114 | 115 | with open(cast_topics_manual_file, "r") as f: 116 | manual_raw = json.load(f) 117 | out_topics = open(out_topics_file, "w") 118 | out_topics_fold = open(out_topics_file + ".0", "w") 119 | out_raw_queries = open(out_raw_queries_file, "w") 120 | out_manual_queries = open(out_manual_queries_file, "w") 121 | out_psuedo_qrels = open(out_psuedo_qrels_file, "w") 122 | cur_fold = 0 123 | for manual_topic in manual_raw: 124 | topic_number = manual_topic["number"] 125 | manual_turns = manual_topic["turn"] 126 | inputs = [] 127 | manual_responses = [] 128 | auto_responses = [] 129 | manual_res_ids = [] 130 | auto_res_ids = [] 131 | for manual_turn in manual_turns: 132 | query_number = manual_turn["number"] 133 | 134 | raw = manual_turn["raw_utterance"] 135 | inputs.append(raw) 136 | target = manual_turn["manual_rewritten_utterance"] 137 | 138 | res_id = manual_turn["canonical_result_id"] + "-" + str(manual_turn["passage_id"]) 139 | manual_res_ids.append(res_id) 140 | response, new_id = get_text_by_raw_id( 141 | res_id) 142 | manual_responses.append(response) 143 | 144 | output_dict = { 145 | "topic_number": topic_number, 146 | "query_number": query_number, 147 | "input": copy.deepcopy(inputs), 148 | "automatic_response_id": copy.deepcopy(auto_res_ids), 149 | "automatic_response": copy.deepcopy(auto_responses), 150 | "manual_response_id": copy.deepcopy(manual_res_ids), 151 | "manual_response": copy.deepcopy(manual_responses), 152 | "target": target 153 | } 154 | 155 | dumped_str = json.dumps(output_dict) + "\n" 156 | out_topics.write(dumped_str) 157 | if fold_dict[topic_number] != cur_fold: 158 | out_topics_fold.close() 159 | out_topics_fold = open( 160 | out_topics_file + "." 
+ str(fold_dict[topic_number]), "w") 161 | cur_fold = fold_dict[topic_number] 162 | out_topics_fold.write(dumped_str) 163 | 164 | out_psuedo_qrels.write(str(topic_number) + "_" + str(query_number) + "\t0\t" + str(new_id) + "\t1\n") 165 | 166 | out_raw_queries.write( 167 | str(topic_number) + "_" + str(query_number) + "\t" + raw + 168 | "\n") 169 | out_manual_queries.write( 170 | str(topic_number) + "_" + str(query_number) + "\t" + target + 171 | "\n") 172 | 173 | print("End") 174 | 175 | def parse_documents(all_content, doc_id_to_idx, doc_idx_to_id, collection, f, pidx=0):  # linear scan over <DOCNO>/<TITLE>/<passage id=...> markup 176 | docid = None 177 | title = None 178 | passage = None 179 | pid = None 180 | char_id = 0 181 | last_char_id = 0 182 | with tqdm(total=len(all_content)) as pbar: 183 | while char_id < len(all_content): 184 | last_char_id = char_id 185 | if all_content[char_id] == "<": 186 | char_id += 1 187 | if all_content[char_id] not in ["D", "T", "p"]: 188 | continue 189 | if all_content[char_id:char_id+len("DOCNO>")] == "DOCNO>": 190 | char_id += len("DOCNO>") 191 | end_pos = all_content.find("</DOCNO>", char_id) 192 | assert end_pos != -1 193 | docid = all_content[char_id:end_pos] 194 | char_id = end_pos + len("</DOCNO>") 195 | elif all_content[char_id:char_id+len("TITLE>")] == "TITLE>": 196 | char_id += len("TITLE>") 197 | end_pos = all_content.find("</TITLE>", char_id) 198 | assert end_pos != -1 199 | title = all_content[char_id:end_pos] 200 | char_id = end_pos + len("</TITLE>") 201 | elif all_content[char_id:char_id+len("passage id=")] == "passage id=": 202 | char_id += len("passage id=") 203 | end_pos = all_content.find(">", char_id) 204 | assert end_pos != -1 205 | pid = str(int(all_content[char_id:end_pos])) 206 | char_id = end_pos + 1 207 | end_pos = all_content.find("</passage>", char_id) 208 | assert end_pos != -1 209 | passage = all_content[char_id:end_pos].strip().replace("\n", " ").replace("\t", " ").strip() 210 | text = title + " " + passage 211 | char_id = end_pos + len("</passage>") 212 | doc_id_to_idx[docid + "-" + pid] = pidx 213 | doc_idx_to_id.append(docid + "-" + pid) 214 | collection[pidx] = text 215 | f.write("{}\t{}\n".format(pidx, text)) 216 | pidx += 1 217 | else: 218 | char_id += 1 219 | pbar.update(char_id - last_char_id) 220 | 221 | return pidx 222 | 223 | 224 | if __name__ == "__main__": 225 | main() -------------------------------------------------------------------------------- /data/tokenizing.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | import torch 4 | sys.path += ['../'] 5 | import pickle 6 | from utils.util import pad_input_ids, multi_file_process, numbered_byte_file_generator, EmbeddingCache 7 | from model.models import MSMarcoConfigDict, ALL_MODELS 8 | from torch.utils.data import TensorDataset 9 | import numpy as np 10 | import argparse 11 | import json 12 | 13 | 14 | def preprocess(args): 15 | 16 | pid2offset = {} 17 | offset2pid = [] 18 | in_passage_path = args.collection 19 | 20 | out_passage_path = os.path.join( 21 | args.out_data_dir, 22 | "passages", 23 | ) 24 | 25 | if os.path.exists(out_passage_path): 26 | print("preprocessed data already exists, exiting preprocessing") 27 | return 28 | 29 | out_line_count = 0 30 | 31 | print('start passage file split processing') 32 | multi_file_process( 33 | args, 34 | 32, 35 | in_passage_path, 36 | out_passage_path, 37 | PassagePreprocessingFn) 38 | 39 | print('start merging splits') 40 | with open(out_passage_path, 'wb') as f: 41 | for idx, record in enumerate(numbered_byte_file_generator( 42 | out_passage_path, 32, 8 + 4 + 
args.max_seq_length * 4)): 43 | p_id = int.from_bytes(record[:8], 'big') 44 | f.write(record[8:]) 45 | pid2offset[p_id] = idx 46 | offset2pid.append(p_id) 47 | if idx < 3: 48 | print(str(idx) + " " + str(p_id)) 49 | out_line_count += 1 50 | 51 | print("Total lines written: " + str(out_line_count)) 52 | meta = { 53 | 'type': 'int32', 54 | 'total_number': out_line_count, 55 | 'embedding_size': args.max_seq_length} 56 | with open(out_passage_path + "_meta", 'w') as f: 57 | json.dump(meta, f) 58 | embedding_cache = EmbeddingCache(out_passage_path) 59 | print("First line") 60 | with embedding_cache as emb: 61 | print(emb[0]) 62 | 63 | pid2offset_path = os.path.join( 64 | args.out_data_dir, 65 | "pid2offset.pickle", 66 | ) 67 | offset2pid_path = os.path.join( 68 | args.out_data_dir, 69 | "offset2pid.pickle", 70 | ) 71 | with open(pid2offset_path, 'wb') as handle: 72 | pickle.dump(pid2offset, handle, protocol=4) 73 | with open(offset2pid_path, "wb") as handle: 74 | pickle.dump(offset2pid, handle, protocol=4) 75 | print("done saving pid2offset") 76 | 77 | 78 | def PassagePreprocessingFn(args, line, tokenizer): 79 | line = line.strip() 80 | ext = args.collection[args.collection.rfind("."):] 81 | passage = None 82 | if ext == ".jsonl": 83 | obj = json.loads(line) 84 | p_id = int(obj["id"]) 85 | p_text = obj["text"] 86 | p_title = obj["title"] 87 | 88 | full_text = p_text[:args.max_doc_character] 89 | 90 | passage = tokenizer.encode( 91 | p_title, 92 | text_pair=full_text, 93 | add_special_tokens=True, 94 | max_length=args.max_seq_length, 95 | ) 96 | elif ext == ".tsv": 97 | try: 98 | line_arr = line.split('\t') 99 | p_id = int(line_arr[0]) 100 | p_text = line_arr[1].rstrip() 101 | except IndexError: # split error 102 | raise ValueError # empty passage 103 | else: 104 | full_text = p_text[:args.max_doc_character] 105 | passage = tokenizer.encode( 106 | full_text, 107 | add_special_tokens=True, 108 | max_length=args.max_seq_length, 109 | ) 110 | else: 111 | raise TypeError("Unrecognized file type") 112 | 113 | passage_len = min(len(passage), args.max_seq_length) 114 | input_id_b = pad_input_ids(passage, args.max_seq_length) 115 | 116 | return p_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes() 117 | 118 | 119 | def QueryPreprocessingFn(args, line, tokenizer): 120 | line_arr = line.split('\t') 121 | q_id = int(line_arr[0]) 122 | 123 | passage = tokenizer.encode( 124 | line_arr[1].rstrip(), 125 | add_special_tokens=True, 126 | max_length=args.max_query_length) 127 | passage_len = min(len(passage), args.max_query_length) 128 | input_id_b = pad_input_ids(passage, args.max_query_length) 129 | 130 | return q_id.to_bytes(8,'big') + passage_len.to_bytes(4,'big') + np.array(input_id_b,np.int32).tobytes() 131 | 132 | 133 | def GetProcessingFn(args, query=False): 134 | def fn(vals, i): 135 | passage_len, passage = vals 136 | max_len = args.max_query_length if query else args.max_seq_length 137 | 138 | pad_len = max(0, max_len - passage_len) 139 | token_type_ids = ([0] if query else [1]) * passage_len + [0] * pad_len 140 | attention_mask = [1] * passage_len + [0] * pad_len 141 | 142 | passage_collection = [(i, passage, attention_mask, token_type_ids)] 143 | 144 | query2id_tensor = torch.tensor( 145 | [f[0] for f in passage_collection], dtype=torch.long) 146 | all_input_ids_a = torch.tensor( 147 | [f[1] for f in passage_collection], dtype=torch.int) 148 | all_attention_mask_a = torch.tensor( 149 | [f[2] for f in passage_collection], dtype=torch.bool) 150 | 
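# (token_type_ids mark real tokens with 0 for queries and 1 for passages, and
# padding with 0 in both cases; storing them as uint8, the attention mask as
# bool, and input ids as int keeps the cached TensorDataset compact.)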
all_token_type_ids_a = torch.tensor( 151 | [f[3] for f in passage_collection], dtype=torch.uint8) 152 | 153 | dataset = TensorDataset( 154 | all_input_ids_a, 155 | all_attention_mask_a, 156 | all_token_type_ids_a, 157 | query2id_tensor) 158 | 159 | return [ts for ts in dataset] 160 | 161 | return fn 162 | 163 | 164 | def GetTrainingDataProcessingFn(args, query_cache, passage_cache): 165 | def fn(line, i): 166 | line_arr = line.split('\t') 167 | qid = int(line_arr[0]) 168 | pos_pid = int(line_arr[1]) 169 | neg_pids = line_arr[2].split(',') 170 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 171 | 172 | all_input_ids_a = [] 173 | all_attention_mask_a = [] 174 | 175 | query_data = GetProcessingFn( 176 | args, query=True)( 177 | query_cache[qid], qid)[0] 178 | pos_data = GetProcessingFn( 179 | args, query=False)( 180 | passage_cache[pos_pid], pos_pid)[0] 181 | 182 | pos_label = torch.tensor(1, dtype=torch.long) 183 | neg_label = torch.tensor(0, dtype=torch.long) 184 | 185 | for neg_pid in neg_pids: 186 | neg_data = GetProcessingFn( 187 | args, query=False)( 188 | passage_cache[neg_pid], neg_pid)[0] 189 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], pos_label) 190 | yield (query_data[0], query_data[1], query_data[2], neg_data[0], neg_data[1], neg_data[2], neg_label) 191 | 192 | return fn 193 | 194 | 195 | def GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache): 196 | def fn(line, i): 197 | line_arr = line.split('\t') 198 | qid = int(line_arr[0]) 199 | pos_pid = int(line_arr[1]) 200 | neg_pids = line_arr[2].split(',') 201 | neg_pids = [int(neg_pid) for neg_pid in neg_pids] 202 | 203 | all_input_ids_a = [] 204 | all_attention_mask_a = [] 205 | 206 | query_data = GetProcessingFn( 207 | args, query=True)( 208 | query_cache[qid], qid)[0] 209 | pos_data = GetProcessingFn( 210 | args, query=False)( 211 | passage_cache[pos_pid], pos_pid)[0] 212 | 213 | for neg_pid in neg_pids: 214 | neg_data = GetProcessingFn( 215 | args, query=False)( 216 | passage_cache[neg_pid], neg_pid)[0] 217 | yield (query_data[0], query_data[1], query_data[2], pos_data[0], pos_data[1], pos_data[2], 218 | neg_data[0], neg_data[1], neg_data[2]) 219 | 220 | return fn 221 | 222 | 223 | def get_arguments(): 224 | parser = argparse.ArgumentParser() 225 | 226 | parser.add_argument( 227 | "--collection", 228 | default=None, 229 | type=str, 230 | required=True, 231 | help="The input data dir", 232 | ) 233 | parser.add_argument( 234 | "--out_data_dir", 235 | default=None, 236 | type=str, 237 | required=True, 238 | help="The output data dir", 239 | ) 240 | parser.add_argument( 241 | "--model_name_or_path", 242 | default=None, 243 | type=str, 244 | required=True, 245 | help="Path to pre-trained model or shortcut name selected in the list: " + 246 | ", ".join(ALL_MODELS), 247 | ) 248 | parser.add_argument( 249 | "--model_type", 250 | default=None, 251 | type=str, 252 | required=True, 253 | help="Model type selected in the list: " + 254 | ", ".join(MSMarcoConfigDict.keys()), 255 | ) 256 | parser.add_argument( 257 | "--max_seq_length", 258 | default=512, 259 | type=int, 260 | help="The maximum total input sequence length after tokenization. 
Sequences longer " 261 | "than this will be truncated, sequences shorter will be padded.", 262 | ) 263 | parser.add_argument( 264 | "--max_doc_character", 265 | default=10000, 266 | type=int, 267 | help="used before tokenizer to save tokenizer latency", 268 | ) 269 | 270 | args = parser.parse_args() 271 | 272 | return args 273 | 274 | 275 | def main(): 276 | args = get_arguments() 277 | 278 | if not os.path.exists(args.out_data_dir): 279 | os.makedirs(args.out_data_dir) 280 | preprocess(args) 281 | 282 | 283 | if __name__ == '__main__': 284 | main() 285 | -------------------------------------------------------------------------------- /drivers/gen_passage_embeddings.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path += ['../'] 4 | import torch 5 | import os 6 | from utils.util import ( 7 | barrier_array_merge, 8 | StreamingDataset, 9 | EmbeddingCache, 10 | ) 11 | from data.tokenizing import GetProcessingFn 12 | from model.models import MSMarcoConfigDict 13 | from torch import nn 14 | import torch.distributed as dist 15 | from tqdm import tqdm 16 | from torch.utils.data import DataLoader 17 | import numpy as np 18 | import argparse 19 | import logging 20 | from utils.dpr_utils import load_states_from_checkpoint, get_model_obj 21 | import re 22 | 23 | torch.multiprocessing.set_sharing_strategy('file_system') 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | 28 | def load_model(args, checkpoint_path): 29 | label_list = ["0", "1"] 30 | num_labels = len(label_list) 31 | args.model_type = args.model_type.lower() 32 | configObj = MSMarcoConfigDict[args.model_type] 33 | args.model_name_or_path = checkpoint_path 34 | 35 | config, tokenizer, model = None, None, None 36 | if args.model_type != "dpr": 37 | config = configObj.config_class.from_pretrained( 38 | args.model_name_or_path, 39 | num_labels=num_labels, 40 | finetuning_task="MSMarco", 41 | cache_dir=args.cache_dir if args.cache_dir else None, 42 | ) 43 | tokenizer = configObj.tokenizer_class.from_pretrained( 44 | args.model_name_or_path, 45 | do_lower_case=True, 46 | cache_dir=args.cache_dir if args.cache_dir else None, 47 | ) 48 | model = configObj.model_class.from_pretrained( 49 | args.model_name_or_path, 50 | from_tf=bool(".ckpt" in args.model_name_or_path), 51 | config=config, 52 | cache_dir=args.cache_dir if args.cache_dir else None, 53 | ) 54 | else: # dpr 55 | model = configObj.model_class(args) 56 | saved_state = load_states_from_checkpoint(checkpoint_path) 57 | model_to_load = get_model_obj(model) 58 | logger.info('Loading saved model state ...') 59 | model_to_load.load_state_dict(saved_state.model_dict) 60 | 61 | model.to(args.device) 62 | logger.info("Inference parameters %s", args) 63 | if args.local_rank != -1: 64 | model = torch.nn.parallel.DistributedDataParallel( 65 | model, 66 | device_ids=[args.local_rank], 67 | output_device=args.local_rank, 68 | find_unused_parameters=True, 69 | ) 70 | return config, tokenizer, model 71 | 72 | 73 | def InferenceEmbeddingFromStreamDataLoader( 74 | args, 75 | model, 76 | train_dataloader, 77 | is_query_inference=True, 78 | ): 79 | # expect dataset from ReconstructTrainingSet 80 | results = {} 81 | eval_batch_size = args.per_gpu_eval_batch_size 82 | 83 | # Inference! 
84 | logger.info("***** Running ANN Embedding Inference *****") 85 | logger.info(" Batch size = %d", eval_batch_size) 86 | 87 | embedding = [] 88 | embedding2id = [] 89 | 90 | if args.local_rank != -1: 91 | dist.barrier() 92 | model.eval() 93 | 94 | for batch in tqdm(train_dataloader, 95 | desc="Inferencing", 96 | disable=args.local_rank not in [-1, 0], 97 | position=0, 98 | leave=True): 99 | 100 | idxs = batch[3].detach().numpy() # [#B] 101 | 102 | batch = tuple(t.to(args.device) for t in batch) 103 | 104 | with torch.no_grad(): 105 | inputs = { 106 | "input_ids": batch[0].long(), 107 | "attention_mask": batch[1].long() 108 | } 109 | if is_query_inference: 110 | embs = model.module.query_emb(**inputs) 111 | else: 112 | embs = model.module.body_emb(**inputs) 113 | 114 | embs = embs.detach().cpu().numpy() 115 | 116 | # check for multi chunk output for long sequence 117 | if len(embs.shape) == 3: 118 | for chunk_no in range(embs.shape[1]): 119 | embedding2id.append(idxs) 120 | embedding.append(embs[:, chunk_no, :]) 121 | else: 122 | embedding2id.append(idxs) 123 | embedding.append(embs) 124 | 125 | embedding = np.concatenate(embedding, axis=0) 126 | embedding2id = np.concatenate(embedding2id, axis=0) 127 | return embedding, embedding2id 128 | 129 | 130 | # streaming inference 131 | def StreamInferenceDoc(args, 132 | model, 133 | fn, 134 | prefix, 135 | f, 136 | is_query_inference=True, 137 | merge=True): 138 | inference_batch_size = args.per_gpu_eval_batch_size # * max(1, args.n_gpu) 139 | inference_dataset = StreamingDataset(f, fn) 140 | inference_dataloader = DataLoader(inference_dataset, 141 | batch_size=inference_batch_size) 142 | 143 | if args.local_rank != -1: 144 | dist.barrier() # directory created 145 | 146 | _embedding, _embedding2id = InferenceEmbeddingFromStreamDataLoader( 147 | args, 148 | model, 149 | inference_dataloader, 150 | is_query_inference=is_query_inference, 151 | ) 152 | 153 | logger.info("merging embeddings") 154 | 155 | # preserve to memory 156 | full_embedding = barrier_array_merge(args, 157 | _embedding, 158 | prefix=prefix + "_emb_p_", 159 | load_cache=False, 160 | only_load_in_master=True, 161 | merge=merge) 162 | full_embedding2id = barrier_array_merge(args, 163 | _embedding2id, 164 | prefix=prefix + "_embid_p_", 165 | load_cache=False, 166 | only_load_in_master=True, 167 | merge=merge) 168 | 169 | return full_embedding, full_embedding2id 170 | 171 | 172 | def generate_new_ann( 173 | args, 174 | checkpoint_path, 175 | ): 176 | 177 | _, __, model = load_model(args, checkpoint_path) 178 | merge = False 179 | 180 | logger.info("***** inference of passages *****") 181 | passage_collection_path = os.path.join(args.data_dir, 182 | "passages") 183 | passage_cache = EmbeddingCache(passage_collection_path) 184 | with passage_cache as emb: 185 | passage_embedding, passage_embedding2id = StreamInferenceDoc( 186 | args, 187 | model, 188 | GetProcessingFn(args, query=False), 189 | "passage_", 190 | emb, 191 | is_query_inference=False, 192 | merge=merge) 193 | logger.info("***** Done passage inference *****") 194 | 195 | 196 | def get_arguments(): 197 | parser = argparse.ArgumentParser() 198 | 199 | # Required parameters 200 | parser.add_argument( 201 | "--data_dir", 202 | default=None, 203 | type=str, 204 | required=True, 205 | help="The input data dir. 
Should contain the tokenized passage ids.", 206 | ) 207 | 208 | parser.add_argument( 209 | "--checkpoint", 210 | default=None, 211 | type=str, 212 | required=True, 213 | help="Checkpoint of the ad hoc retriever", 214 | ) 215 | 216 | parser.add_argument( 217 | "--model_type", 218 | default=None, 219 | type=str, 220 | required=True, 221 | help="Model type selected in the list: " + 222 | ", ".join(MSMarcoConfigDict.keys()), 223 | ) 224 | 225 | parser.add_argument( 226 | "--output_dir", 227 | default=None, 228 | type=str, 229 | required=True, 230 | help="The output directory where the training data will be written", 231 | ) 232 | 233 | parser.add_argument( 234 | "--cache_dir", 235 | default=None, 236 | type=str, 237 | required=True, 238 | help="The directory where cached data will be written", 239 | ) 240 | 241 | parser.add_argument( 242 | "--max_seq_length", 243 | default=512, 244 | type=int, 245 | help= 246 | "The maximum total input sequence length after tokenization. Sequences longer " 247 | "than this will be truncated, sequences shorter will be padded.", 248 | ) 249 | 250 | parser.add_argument( 251 | "--max_query_length", 252 | default=64, 253 | type=int, 254 | help= 255 | "The maximum total input sequence length after tokenization. Sequences longer " 256 | "than this will be truncated, sequences shorter will be padded.", 257 | ) 258 | 259 | parser.add_argument( 260 | "--max_doc_character", 261 | default=10000, 262 | type=int, 263 | help="used before tokenizer to save tokenizer latency", 264 | ) 265 | 266 | parser.add_argument( 267 | "--per_gpu_eval_batch_size", 268 | default=64, 269 | type=int, 270 | help="The starting output file number", 271 | ) 272 | 273 | parser.add_argument( 274 | "--no_cuda", 275 | action="store_true", 276 | help="Avoid using CUDA when available", 277 | ) 278 | 279 | parser.add_argument( 280 | "--local_rank", 281 | type=int, 282 | default=-1, 283 | help="For distributed training: local_rank", 284 | ) 285 | 286 | parser.add_argument( 287 | "--server_ip", 288 | type=str, 289 | default="", 290 | help="For distant debugging.", 291 | ) 292 | 293 | parser.add_argument( 294 | "--server_port", 295 | type=str, 296 | default="", 297 | help="For distant debugging.", 298 | ) 299 | 300 | args = parser.parse_args() 301 | 302 | return args 303 | 304 | 305 | def set_env(args): 306 | # Setup CUDA, GPU & distributed training 307 | if args.local_rank == -1 or args.no_cuda: 308 | device = torch.device("cuda" if torch.cuda.is_available() 309 | and not args.no_cuda else "cpu") 310 | args.n_gpu = torch.cuda.device_count() 311 | else: # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 312 | torch.cuda.set_device(args.local_rank) 313 | device = torch.device("cuda", args.local_rank) 314 | torch.distributed.init_process_group(backend="nccl") 315 | args.n_gpu = 1 316 | args.device = device 317 | 318 | # store args 319 | if args.local_rank != -1: 320 | args.world_size = torch.distributed.get_world_size() 321 | args.rank = dist.get_rank() 322 | 323 | # Setup logging 324 | logging.basicConfig( 325 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 326 | datefmt="%m/%d/%Y %H:%M:%S", 327 | level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN, 328 | ) 329 | logger.warning( 330 | "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s", 331 | args.local_rank, 332 | device, 333 | args.n_gpu, 334 | bool(args.local_rank != -1), 335 | ) 336 | 337 | 338 | def ann_data_gen(args): 339 | 340 | logger.info("start generate ann data") 
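# (generate_new_ann loads the ad hoc retriever checkpoint and embeds every
# record in the tokenized passage cache under --data_dir/passages; since it
# passes merge=False, each rank keeps its own shard written with the
# "passage__emb_p_" / "passage__embid_p_" prefixes, presumably under
# --output_dir, for a later FAISS search, e.g. in run_convdr_inference.py.)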
341 | generate_new_ann( 342 | args, 343 | args.checkpoint, 344 | ) 345 | 346 | if args.local_rank != -1: 347 | dist.barrier() 348 | 349 | 350 | def main(): 351 | args = get_arguments() 352 | set_env(args) 353 | ann_data_gen(args) 354 | 355 | 356 | if __name__ == "__main__": 357 | main() 358 | -------------------------------------------------------------------------------- /model/models.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path += ['../'] 4 | import torch 5 | from torch import nn 6 | from transformers import (RobertaConfig, RobertaModel, 7 | RobertaForSequenceClassification, RobertaTokenizer, 8 | BertModel, BertTokenizer, BertConfig) 9 | import torch.nn.functional as F 10 | from data.process_fn import triple_process_fn, triple2dual_process_fn 11 | 12 | 13 | class EmbeddingMixin: 14 | """ 15 | Mixin for common functions in most embedding models. Each model should define its own bert-like backbone and forward. 16 | We inherit from RobertaModel to use from_pretrained 17 | """ 18 | def __init__(self, model_argobj): 19 | if model_argobj is None: 20 | self.use_mean = False 21 | else: 22 | self.use_mean = model_argobj.use_mean 23 | print("Using mean:", self.use_mean) 24 | 25 | def _init_weights(self, module): 26 | """ Initialize the weights """ 27 | if isinstance(module, (nn.Linear, nn.Embedding, nn.Conv1d)): 28 | # Slightly different from the TF version which uses truncated_normal for initialization 29 | # cf https://github.com/pytorch/pytorch/pull/5617 30 | module.weight.data.normal_(mean=0.0, std=0.02) 31 | 32 | def masked_mean(self, t, mask): 33 | s = torch.sum(t * mask.unsqueeze(-1).float(), axis=1) 34 | d = mask.sum(axis=1, keepdim=True).float() 35 | return s / d 36 | 37 | def masked_mean_or_first(self, emb_all, mask): 38 | # emb_all is a tuple from bert - sequence output, pooler 39 | assert isinstance(emb_all, tuple) 40 | if self.use_mean: 41 | return self.masked_mean(emb_all[0], mask) 42 | else: 43 | return emb_all[0][:, 0] 44 | 45 | def query_emb(self, input_ids, attention_mask): 46 | raise NotImplementedError("Please Implement this method") 47 | 48 | def body_emb(self, input_ids, attention_mask): 49 | raise NotImplementedError("Please Implement this method") 50 | 51 | 52 | class NLL(EmbeddingMixin): 53 | def forward(self, 54 | query_ids, 55 | attention_mask_q, 56 | input_ids_a=None, 57 | attention_mask_a=None, 58 | input_ids_b=None, 59 | attention_mask_b=None, 60 | is_query=True): 61 | if input_ids_b is None and is_query: 62 | return self.query_emb(query_ids, attention_mask_q) 63 | elif input_ids_b is None: 64 | return self.body_emb(query_ids, attention_mask_q) 65 | 66 | q_embs = self.query_emb(query_ids, attention_mask_q) 67 | a_embs = self.body_emb(input_ids_a, attention_mask_a) 68 | b_embs = self.body_emb(input_ids_b, attention_mask_b) 69 | 70 | logit_matrix = torch.cat([(q_embs * a_embs).sum(-1).unsqueeze(1), 71 | (q_embs * b_embs).sum(-1).unsqueeze(1)], 72 | dim=1) # [B, 2] 73 | lsm = F.log_softmax(logit_matrix, dim=1) 74 | loss = -1.0 * lsm[:, 0] 75 | return (loss.mean(), ) 76 | 77 | 78 | class NLL_MultiChunk(EmbeddingMixin): 79 | def forward(self, 80 | query_ids, 81 | attention_mask_q, 82 | input_ids_a=None, 83 | attention_mask_a=None, 84 | input_ids_b=None, 85 | attention_mask_b=None, 86 | is_query=True): 87 | if input_ids_b is None and is_query: 88 | return self.query_emb(query_ids, attention_mask_q) 89 | elif input_ids_b is None: 90 | return self.body_emb(query_ids, attention_mask_q) 91 | 92 | q_embs = 
self.query_emb(query_ids, attention_mask_q)
93 |         a_embs = self.body_emb(input_ids_a, attention_mask_a)
94 |         b_embs = self.body_emb(input_ids_b, attention_mask_b)
95 | 
96 |         [batchS, full_length] = input_ids_a.size()
97 |         chunk_factor = full_length // self.base_len
98 | 
99 |         # special handling of attention mask -----
100 |         attention_mask_body = attention_mask_a.reshape(
101 |             batchS, chunk_factor, -1)[:, :, 0]  # [batchS, chunk_factor]
102 |         inverted_bias = ((1 - attention_mask_body) * (-9999)).float()
103 | 
104 |         a12 = torch.matmul(q_embs.unsqueeze(1),
105 |                            a_embs.transpose(1, 2))  # [batch, 1, chunk_factor]
106 |         logits_a = (a12[:, 0, :] + inverted_bias).max(
107 |             dim=-1, keepdim=False).values  # [batch]
108 |         # -------------------------------------
109 | 
110 |         # special handling of attention mask -----
111 |         attention_mask_body = attention_mask_b.reshape(
112 |             batchS, chunk_factor, -1)[:, :, 0]  # [batchS, chunk_factor]
113 |         inverted_bias = ((1 - attention_mask_body) * (-9999)).float()
114 | 
115 |         a12 = torch.matmul(q_embs.unsqueeze(1),
116 |                            b_embs.transpose(1, 2))  # [batch, 1, chunk_factor]
117 |         logits_b = (a12[:, 0, :] + inverted_bias).max(
118 |             dim=-1, keepdim=False).values  # [batch]
119 |         # -------------------------------------
120 | 
121 |         logit_matrix = torch.cat(
122 |             [logits_a.unsqueeze(1),
123 |              logits_b.unsqueeze(1)], dim=1)  # [B, 2]
124 |         lsm = F.log_softmax(logit_matrix, dim=1)
125 |         loss = -1.0 * lsm[:, 0]
126 |         return (loss.mean(), )
127 | 
128 | 
129 | class RobertaDot_NLL_LN(NLL, RobertaForSequenceClassification):
130 |     """Projects the encoder output through a 768-d linear head with
131 |     LayerNorm, then computes the NLL loss.
132 |     """
133 |     def __init__(self, config, model_argobj=None):
134 |         NLL.__init__(self, model_argobj)
135 |         RobertaForSequenceClassification.__init__(self, config)
136 |         self.embeddingHead = nn.Linear(config.hidden_size, 768)
137 |         self.norm = nn.LayerNorm(768)
138 |         self.apply(self._init_weights)
139 | 
140 |     def query_emb(self, input_ids, attention_mask):
141 |         outputs1 = self.roberta(input_ids=input_ids,
142 |                                 attention_mask=attention_mask)
143 |         full_emb = self.masked_mean_or_first(outputs1, attention_mask)
144 |         query1 = self.norm(self.embeddingHead(full_emb))
145 |         return query1
146 | 
147 |     def body_emb(self, input_ids, attention_mask):
148 |         return self.query_emb(input_ids, attention_mask)
149 | 
150 | 
151 | class RobertaDot_NLL_LN_Inference(RobertaDot_NLL_LN):
152 |     def __init__(self, config, model_argobj=None):
153 |         RobertaDot_NLL_LN.__init__(self, config, model_argobj=model_argobj)
154 | 
155 |     def forward(self, input_ids, attention_mask):
156 |         return self.query_emb(input_ids, attention_mask)
157 | 
158 | 
159 | class RobertaDot_CLF_ANN_NLL_MultiChunk(NLL_MultiChunk, RobertaDot_NLL_LN):
160 |     def __init__(self, config):
161 |         RobertaDot_NLL_LN.__init__(self, config)
162 |         self.base_len = 512
163 | 
164 |     def body_emb(self, input_ids, attention_mask):
165 |         [batchS, full_length] = input_ids.size()
166 |         chunk_factor = full_length // self.base_len
167 | 
168 |         input_seq = input_ids.reshape(batchS, chunk_factor,
169 |                                       full_length // chunk_factor).reshape(
170 |                                           batchS * chunk_factor,
171 |                                           full_length // chunk_factor)
172 |         attention_mask_seq = attention_mask.reshape(
173 |             batchS, chunk_factor,
174 |             full_length // chunk_factor).reshape(batchS * chunk_factor,
175 |                                                  full_length // chunk_factor)
176 | 
177 |         outputs_k = self.roberta(input_ids=input_seq,
178 |                                  attention_mask=attention_mask_seq)
179 | 
180 |         compressed_output_k = self.embeddingHead(
181 |             outputs_k[0])  # [batch, len, dim]
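        # below: keep only each chunk's first-token ([CLS]) embedding and
        # regroup the flattened chunks back into [batchS, chunk_factor, dim]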
182 |         compressed_output_k = self.norm(compressed_output_k[:, 0, :])
183 | 
184 |         [batch_expand, embeddingS] = compressed_output_k.size()
185 |         complex_emb_k = compressed_output_k.reshape(batchS, chunk_factor,
186 |                                                     embeddingS)
187 | 
188 |         return complex_emb_k  # size [batchS, chunk_factor, embeddingS]
189 | 
190 | 
191 | class HFBertEncoder(BertModel):
192 |     def __init__(self, config):
193 |         BertModel.__init__(self, config)
194 |         assert config.hidden_size > 0, 'Encoder hidden_size can\'t be zero'
195 |         self.init_weights()
196 | 
197 |     @classmethod
198 |     def init_encoder(cls, args, dropout: float = 0.1):
199 |         cfg = BertConfig.from_pretrained("bert-base-uncased")
200 |         if dropout != 0:
201 |             cfg.attention_probs_dropout_prob = dropout
202 |             cfg.hidden_dropout_prob = dropout
203 |         return cls.from_pretrained("bert-base-uncased", config=cfg)  # pass cfg so the dropout overrides take effect
204 | 
205 |     def forward(self, input_ids, attention_mask):
206 |         hidden_states = None
207 | 
208 |         sequence_output, pooled_output = super().forward(
209 |             input_ids=input_ids, attention_mask=attention_mask)
210 |         pooled_output = sequence_output[:, 0, :]
211 |         return sequence_output, pooled_output, hidden_states
212 | 
213 |     def get_out_size(self):
214 |         if getattr(self, 'encode_proj', None):  # optional projection head (not defined in this port)
215 |             return self.encode_proj.out_features
216 |         return self.config.hidden_size
217 | 
218 | 
219 | class BiEncoder(nn.Module):
220 |     """ Bi-Encoder model component. Encapsulates query/question and context/passage encoders.
221 |     """
222 |     def __init__(self, args):
223 |         super(BiEncoder, self).__init__()
224 |         self.question_model = HFBertEncoder.init_encoder(args)
225 |         self.ctx_model = HFBertEncoder.init_encoder(args)
226 | 
227 |     def query_emb(self, input_ids, attention_mask):
228 |         sequence_output, pooled_output, hidden_states = self.question_model(
229 |             input_ids, attention_mask)
230 |         return pooled_output
231 | 
232 |     def body_emb(self, input_ids, attention_mask):
233 |         sequence_output, pooled_output, hidden_states = self.ctx_model(
234 |             input_ids, attention_mask)
235 |         return pooled_output
236 | 
237 |     def forward(self,
238 |                 query_ids,
239 |                 attention_mask_q,
240 |                 input_ids_a=None,
241 |                 attention_mask_a=None,
242 |                 input_ids_b=None,
243 |                 attention_mask_b=None,
244 |                 is_query=True):
245 |         if input_ids_b is None:
246 |             if input_ids_a is None:
247 |                 return self.query_emb(
248 |                     query_ids,
249 |                     attention_mask_q) if is_query else self.body_emb(
250 |                         query_ids, attention_mask_q)
251 |             q_embs = self.query_emb(query_ids, attention_mask_q)
252 |             a_embs = self.body_emb(input_ids_a, attention_mask_a)
253 |             return (q_embs, a_embs)
254 |         q_embs = self.query_emb(query_ids, attention_mask_q)
255 |         a_embs = self.body_emb(input_ids_a, attention_mask_a)
256 |         b_embs = self.body_emb(input_ids_b, attention_mask_b)
257 |         logit_matrix = torch.cat([(q_embs * a_embs).sum(-1).unsqueeze(1),
258 |                                   (q_embs * b_embs).sum(-1).unsqueeze(1)],
259 |                                  dim=1)  # [B, 2]
260 |         lsm = F.log_softmax(logit_matrix, dim=1)
261 |         loss = -1.0 * lsm[:, 0]
262 |         return (loss.mean(), )
263 | 
264 | 
265 | # --------------------------------------------------
266 | ALL_MODELS = sum(
267 |     (tuple(conf.pretrained_config_archive_map.keys())
268 |      for conf in (RobertaConfig, )),
269 |     (),
270 | )
271 | 
272 | default_process_fn = triple_process_fn
273 | 
274 | 
275 | class MSMarcoConfig:
276 |     def __init__(self,
277 |                  name,
278 |                  model,
279 |                  process_fn=default_process_fn,
280 |                  use_mean=True,
281 |                  tokenizer_class=RobertaTokenizer,
282 |                  config_class=RobertaConfig):
283 |         self.name = name
284 |         self.process_fn = process_fn
285 |         self.model_class = model
286 |         self.use_mean = use_mean
287 |         self.tokenizer_class = tokenizer_class
288 |         self.config_class = config_class
289 | 
290 | 
291 | configs = [
292 |     MSMarcoConfig(
293 |         name="rdot_nll",
294 |         model=RobertaDot_NLL_LN,
295 |         use_mean=False,
296 |     ),
297 |     MSMarcoConfig(
298 |         name="rdot_nll_multi_chunk",
299 |         model=RobertaDot_CLF_ANN_NLL_MultiChunk,
300 |         use_mean=False,
301 |     ),
302 |     MSMarcoConfig(
303 |         name="dpr",
304 |         model=BiEncoder,
305 |         tokenizer_class=BertTokenizer,
306 |         config_class=BertConfig,
307 |         use_mean=False,
308 |     ),
309 | ]
310 | 
311 | MSMarcoConfigDict = {cfg.name: cfg for cfg in configs}
312 | 
--------------------------------------------------------------------------------
/utils/dpr_utils.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import sys
3 | sys.path += ['../']
4 | import glob
5 | import logging
6 | import os
7 | from typing import List, Tuple, Dict
8 | import faiss
9 | import pickle
10 | import numpy as np
11 | import unicodedata
12 | import torch
13 | import torch.distributed as dist
14 | from torch import nn
15 | from torch.serialization import default_restore_location
16 | import regex
17 | from transformers import AdamW
18 | # from utils.lamb import Lamb
19 | 
20 | 
21 | logger = logging.getLogger()
22 | 
23 | CheckpointState = collections.namedtuple("CheckpointState",
24 |                                          ['model_dict', 'optimizer_dict', 'scheduler_dict', 'offset', 'epoch',
25 |                                           'encoder_params'])
26 | 
27 | def get_encoder_checkpoint_params_names():
28 |     return ['do_lower_case', 'pretrained_model_cfg', 'encoder_model_type',
29 |             'pretrained_file',
30 |             'projection_dim', 'sequence_length']
31 | 
32 | def get_encoder_params_state(args):
33 |     """
34 |     Selects the param values to be saved in a checkpoint, so that a trained model file can be used for downstream
35 |     tasks without the need to specify these parameters again
36 |     :return: Dict of params to memorize in a checkpoint
37 |     """
38 |     params_to_save = get_encoder_checkpoint_params_names()
39 | 
40 |     r = {}
41 |     for param in params_to_save:
42 |         r[param] = getattr(args, param)
43 |     return r
44 | 
45 | def set_encoder_params_from_state(state, args):
46 |     if not state:
47 |         return
48 |     params_to_save = get_encoder_checkpoint_params_names()
49 | 
50 |     override_params = [(param, state[param]) for param in params_to_save if param in state and state[param]]
51 |     for param, value in override_params:
52 |         if hasattr(args, param):
53 |             logger.warning('Overriding args parameter value from checkpoint state. 
Param = %s, value = %s', param, 54 | value) 55 | setattr(args, param, value) 56 | return args 57 | 58 | def get_model_obj(model: nn.Module): 59 | return model.module if hasattr(model, 'module') else model 60 | 61 | 62 | def get_model_file(args, file_prefix) -> str: 63 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + '*')) if args.output_dir else [] 64 | logger.info('Checkpoint files %s', out_cp_files) 65 | model_file = None 66 | 67 | if args.model_file and os.path.exists(args.model_file): 68 | model_file = args.model_file 69 | elif len(out_cp_files) > 0: 70 | model_file = max(out_cp_files, key=os.path.getctime) 71 | return model_file 72 | 73 | 74 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 75 | logger.info('Reading saved model from %s', model_file) 76 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, 'cpu')) 77 | logger.info('model_state_dict keys %s', state_dict.keys()) 78 | return CheckpointState(**state_dict) 79 | 80 | def get_optimizer(args, model: nn.Module, weight_decay: float = 0.0, ) -> torch.optim.Optimizer: 81 | no_decay = ['bias', 'LayerNorm.weight'] 82 | optimizer_grouped_parameters = [ 83 | {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 84 | 'weight_decay': weight_decay}, 85 | {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0} 86 | ] 87 | return AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon) 88 | 89 | 90 | def all_gather_list(data, group=None, max_size=16384): 91 | """Gathers arbitrary data from all nodes into a list. 92 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 93 | data. Note that *data* must be picklable. 
94 | Args: 95 | data (Any): data from the local worker to be gathered on other workers 96 | group (optional): group of the collective 97 | """ 98 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 99 | 100 | enc = pickle.dumps(data) 101 | enc_size = len(enc) 102 | 103 | if enc_size + SIZE_STORAGE_BYTES > max_size: 104 | raise ValueError( 105 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 106 | 107 | rank = dist.get_rank() 108 | world_size = dist.get_world_size() 109 | buffer_size = max_size * world_size 110 | 111 | if not hasattr(all_gather_list, '_buffer') or \ 112 | all_gather_list._buffer.numel() < buffer_size: 113 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 114 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 115 | 116 | buffer = all_gather_list._buffer 117 | buffer.zero_() 118 | cpu_buffer = all_gather_list._cpu_buffer 119 | 120 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 121 | 256 ** SIZE_STORAGE_BYTES) 122 | 123 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 124 | 125 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 126 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 127 | 128 | start = rank * max_size 129 | size = enc_size + SIZE_STORAGE_BYTES 130 | buffer[start: start + size].copy_(cpu_buffer[:size]) 131 | 132 | if group is None: 133 | group = dist.group.WORLD 134 | dist.all_reduce(buffer, group=group) 135 | 136 | try: 137 | result = [] 138 | for i in range(world_size): 139 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 140 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 141 | if size > 0: 142 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 143 | return result 144 | except pickle.UnpicklingError: 145 | raise Exception( 146 | 'Unable to unpickle data from other workers. all_gather_list requires all ' 147 | 'workers to enter the function together, so this error usually indicates ' 148 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 149 | 'sync if one of them runs out of memory, or if there are other conditions ' 150 | 'in your training script that can cause one worker to finish an epoch ' 151 | 'while other workers are still iterating over their portions of the data.' 152 | ) 153 | 154 | 155 | 156 | class DenseHNSWFlatIndexer(object): 157 | """ 158 | Efficient index for retrieval. 
Note: default settings are for high accuracy but also high RAM usage
159 |     """
160 | 
161 |     def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512
162 |                  , ef_search: int = 128, ef_construction: int = 200):
163 |         self.buffer_size = buffer_size
164 |         self.index_id_to_db_id = []
165 |         self.index = None
166 | 
167 |         # IndexHNSWFlat supports L2 similarity only,
168 |         # so we have to apply a DOT -> L2 similarity space conversion with the help of an extra dimension
169 |         index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)  # docs become [x, sqrt(phi - |x|^2)], queries [q, 0]
170 |         index.hnsw.efSearch = ef_search
171 |         index.hnsw.efConstruction = ef_construction
172 |         self.index = index
173 |         self.phi = 0
174 | 
175 |     def index_data(self, data: List[Tuple[object, np.array]]):
176 |         n = len(data)
177 | 
178 |         # max norm is required before putting all vectors in the index to convert inner product similarity to L2
179 |         if self.phi > 0:
180 |             raise RuntimeError('DPR HNSWF index needs to index all data at once, '
181 |                                'results will be unpredictable otherwise.')
182 |         phi = 0
183 |         for i, item in enumerate(data):
184 |             id, doc_vector = item
185 |             norms = (doc_vector ** 2).sum()
186 |             phi = max(phi, norms)
187 |         logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
188 |         self.phi = phi  # remember the max norm so a second index_data call trips the guard above
189 | 
190 |         # indexing in batches is beneficial for many faiss index types
191 |         for i in range(0, n, self.buffer_size):
192 |             db_ids = [t[0] for t in data[i:i + self.buffer_size]]
193 |             vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
194 | 
195 |             norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
196 |             aux_dims = [np.sqrt(phi - norm) for norm in norms]
197 |             hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
198 |                             enumerate(vectors)]
199 |             hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)
200 | 
201 |             self._update_id_mapping(db_ids)
202 |             self.index.add(hnsw_vectors)
203 |             logger.info('data indexed %d', len(self.index_id_to_db_id))
204 | 
205 |         indexed_cnt = len(self.index_id_to_db_id)
206 |         logger.info('Total data indexed %d', indexed_cnt)
207 | 
208 |     def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
209 | 
210 |         aux_dim = np.zeros(len(query_vectors), dtype='float32')
211 |         query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
212 |         logger.info('query_hnsw_vectors %s', query_hnsw_vectors.shape)
213 |         scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
214 |         # convert to external ids
215 |         db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
216 |         result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
217 |         return result
218 | 
219 |     def _update_id_mapping(self, db_ids: List):
220 |         self.index_id_to_db_id.extend(db_ids)
221 | 
222 | 
223 | 
224 | def check_answer(passages, answers, doc_ids, tokenizer):
225 |     """Search through all the top docs to see if they have any of the answers."""
226 |     hits = []
227 |     for i, doc_id in enumerate(doc_ids):
228 |         text = passages[doc_id][0]
229 |         hits.append(has_answer(answers, text, tokenizer))
230 |     return hits
231 | 
232 | 
233 | def has_answer(answers, text, tokenizer):
234 |     """Check if a document contains an answer string.
235 |     Token matching is done between the tokenized, lowercased text and each
236 |     candidate answer string (regex matching is not supported in this port).
237 | """ 238 | 239 | if text is None: 240 | logger.warning("no doc in db") 241 | return False 242 | 243 | text = _normalize(text) 244 | 245 | # Answer is a list of possible strings 246 | text = tokenizer.tokenize(text).words(uncased=True) 247 | 248 | for single_answer in answers: 249 | single_answer = _normalize(single_answer) 250 | single_answer = tokenizer.tokenize(single_answer) 251 | single_answer = single_answer.words(uncased=True) 252 | 253 | for i in range(0, len(text) - len(single_answer) + 1): 254 | if single_answer == text[i: i + len(single_answer)]: 255 | return True 256 | return False 257 | 258 | 259 | class SimpleTokenizer: 260 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 261 | NON_WS = r'[^\p{Z}\p{C}]' 262 | 263 | def __init__(self, **kwargs): 264 | """ 265 | Args: 266 | annotators: None or empty set (only tokenizes). 267 | """ 268 | self._regexp = regex.compile( 269 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 270 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 271 | ) 272 | if len(kwargs.get('annotators', {})) > 0: 273 | logger.warning('%s only tokenizes! Skipping annotators: %s' % 274 | (type(self).__name__, kwargs.get('annotators'))) 275 | self.annotators = set() 276 | 277 | def tokenize(self, text): 278 | data = [] 279 | matches = [m for m in self._regexp.finditer(text)] 280 | for i in range(len(matches)): 281 | # Get text 282 | token = matches[i].group() 283 | 284 | # Get whitespace 285 | span = matches[i].span() 286 | start_ws = span[0] 287 | if i + 1 < len(matches): 288 | end_ws = matches[i + 1].span()[0] 289 | else: 290 | end_ws = span[1] 291 | 292 | # Format data 293 | data.append(( 294 | token, 295 | text[start_ws: end_ws], 296 | span, 297 | )) 298 | return Tokens(data, self.annotators) 299 | 300 | 301 | def _normalize(text): 302 | return unicodedata.normalize('NFD', text) 303 | 304 | 305 | class Tokens(object): 306 | """A class to represent a list of tokenized text.""" 307 | TEXT = 0 308 | TEXT_WS = 1 309 | SPAN = 2 310 | POS = 3 311 | LEMMA = 4 312 | NER = 5 313 | 314 | def __init__(self, data, annotators, opts=None): 315 | self.data = data 316 | self.annotators = annotators 317 | self.opts = opts or {} 318 | 319 | def __len__(self): 320 | """The number of tokens.""" 321 | return len(self.data) 322 | 323 | def words(self, uncased=False): 324 | """Returns a list of the text of each token 325 | 326 | Args: 327 | uncased: lower cases text 328 | """ 329 | if uncased: 330 | return [t[self.TEXT].lower() for t in self.data] 331 | else: 332 | return [t[self.TEXT] for t in self.data] 333 | -------------------------------------------------------------------------------- /data/gen_ranking_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import json 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | from utils.util import load_collection 9 | 10 | # For CAsT cross-validation, I manually split the data into five folds to ensure the balance of judged queries in each fold. 
11 | qid_to_fold_test = { 12 | "31_1": 0, 13 | "31_2": 0, 14 | "31_3": 0, 15 | "31_4": 0, 16 | "31_5": 0, 17 | "31_6": 0, 18 | "31_7": 0, 19 | "31_8": 0, 20 | "31_9": 0, 21 | "32_1": 0, 22 | "32_2": 0, 23 | "32_3": 0, 24 | "32_4": 0, 25 | "32_5": 0, 26 | "32_6": 0, 27 | "32_7": 0, 28 | "32_8": 0, 29 | "32_9": 0, 30 | "32_10": 0, 31 | "32_11": 0, 32 | "33_1": 0, 33 | "33_2": 0, 34 | "33_3": 0, 35 | "33_4": 0, 36 | "33_5": 0, 37 | "33_6": 0, 38 | "33_7": 0, 39 | "33_8": 0, 40 | "33_9": 0, 41 | "33_10": 0, 42 | "34_1": 0, 43 | "34_2": 0, 44 | "34_3": 0, 45 | "34_4": 0, 46 | "34_5": 0, 47 | "34_6": 0, 48 | "34_7": 0, 49 | "34_8": 0, 50 | "34_9": 0, 51 | "35_1": 0, 52 | "35_2": 0, 53 | "35_3": 0, 54 | "35_4": 0, 55 | "35_5": 0, 56 | "35_6": 0, 57 | "35_7": 0, 58 | "35_8": 0, 59 | "35_9": 0, 60 | "36_1": 0, 61 | "36_2": 0, 62 | "36_3": 0, 63 | "36_4": 0, 64 | "36_5": 0, 65 | "36_6": 0, 66 | "36_7": 0, 67 | "36_8": 0, 68 | "36_9": 0, 69 | "36_10": 0, 70 | "36_11": 0, 71 | "37_1": 1, 72 | "37_2": 1, 73 | "37_3": 1, 74 | "37_4": 1, 75 | "37_5": 1, 76 | "37_6": 1, 77 | "37_7": 1, 78 | "37_8": 1, 79 | "37_9": 1, 80 | "37_10": 1, 81 | "37_11": 1, 82 | "37_12": 1, 83 | "38_1": 0, 84 | "38_2": 0, 85 | "38_3": 0, 86 | "38_4": 0, 87 | "38_5": 0, 88 | "38_6": 0, 89 | "38_7": 0, 90 | "38_8": 0, 91 | "39_1": 0, 92 | "39_2": 0, 93 | "39_3": 0, 94 | "39_4": 0, 95 | "39_5": 0, 96 | "39_6": 0, 97 | "39_7": 0, 98 | "39_8": 0, 99 | "39_9": 0, 100 | "40_1": 1, 101 | "40_2": 1, 102 | "40_3": 1, 103 | "40_4": 1, 104 | "40_5": 1, 105 | "40_6": 1, 106 | "40_7": 1, 107 | "40_8": 1, 108 | "40_9": 1, 109 | "40_10": 1, 110 | "41_1": 1, 111 | "41_2": 1, 112 | "41_3": 1, 113 | "41_4": 1, 114 | "41_5": 1, 115 | "41_6": 1, 116 | "41_7": 1, 117 | "41_8": 1, 118 | "41_9": 1, 119 | "42_1": 1, 120 | "42_2": 1, 121 | "42_3": 1, 122 | "42_4": 1, 123 | "42_5": 1, 124 | "42_6": 1, 125 | "42_7": 1, 126 | "42_8": 1, 127 | "43_1": 1, 128 | "43_2": 1, 129 | "43_3": 1, 130 | "43_4": 1, 131 | "43_5": 1, 132 | "43_6": 1, 133 | "43_7": 1, 134 | "43_8": 1, 135 | "44_1": 1, 136 | "44_2": 1, 137 | "44_3": 1, 138 | "44_4": 1, 139 | "44_5": 1, 140 | "44_6": 1, 141 | "44_7": 1, 142 | "44_8": 1, 143 | "45_1": 1, 144 | "45_2": 1, 145 | "45_3": 1, 146 | "45_4": 1, 147 | "45_5": 1, 148 | "45_6": 1, 149 | "45_7": 1, 150 | "45_8": 1, 151 | "46_1": 1, 152 | "46_2": 1, 153 | "46_3": 1, 154 | "46_4": 1, 155 | "46_5": 1, 156 | "46_6": 1, 157 | "46_7": 1, 158 | "46_8": 1, 159 | "46_9": 1, 160 | "46_10": 1, 161 | "47_1": 1, 162 | "47_2": 1, 163 | "47_3": 1, 164 | "47_4": 1, 165 | "47_5": 1, 166 | "47_6": 1, 167 | "47_7": 1, 168 | "48_1": 1, 169 | "48_2": 1, 170 | "48_3": 1, 171 | "48_4": 1, 172 | "48_5": 1, 173 | "48_6": 1, 174 | "48_7": 1, 175 | "48_8": 1, 176 | "48_9": 1, 177 | "49_1": 1, 178 | "49_2": 1, 179 | "49_3": 1, 180 | "49_4": 1, 181 | "49_5": 1, 182 | "49_6": 1, 183 | "49_7": 1, 184 | "49_8": 1, 185 | "49_9": 1, 186 | "49_10": 1, 187 | "50_1": 1, 188 | "50_2": 1, 189 | "50_3": 1, 190 | "50_4": 1, 191 | "50_5": 1, 192 | "50_6": 1, 193 | "50_7": 1, 194 | "50_8": 1, 195 | "50_9": 1, 196 | "50_10": 1, 197 | "51_1": 2, 198 | "51_2": 2, 199 | "51_3": 2, 200 | "51_4": 2, 201 | "51_5": 2, 202 | "51_6": 2, 203 | "51_7": 2, 204 | "51_8": 2, 205 | "51_9": 2, 206 | "51_10": 2, 207 | "52_1": 2, 208 | "52_2": 2, 209 | "52_3": 2, 210 | "52_4": 2, 211 | "52_5": 2, 212 | "52_6": 2, 213 | "52_7": 2, 214 | "52_8": 2, 215 | "52_9": 2, 216 | "52_10": 2, 217 | "53_1": 2, 218 | "53_2": 2, 219 | "53_3": 2, 220 | "53_4": 2, 221 | "53_5": 2, 222 | "53_6": 2, 223 | "53_7": 
2, 224 | "53_8": 2, 225 | "53_9": 2, 226 | "54_1": 2, 227 | "54_2": 2, 228 | "54_3": 2, 229 | "54_4": 2, 230 | "54_5": 2, 231 | "54_6": 2, 232 | "54_7": 2, 233 | "54_8": 2, 234 | "54_9": 2, 235 | "55_1": 2, 236 | "55_2": 2, 237 | "55_3": 2, 238 | "55_4": 2, 239 | "55_5": 2, 240 | "55_6": 2, 241 | "55_7": 2, 242 | "55_8": 2, 243 | "55_9": 2, 244 | "55_10": 2, 245 | "56_1": 2, 246 | "56_2": 2, 247 | "56_3": 2, 248 | "56_4": 2, 249 | "56_5": 2, 250 | "56_6": 2, 251 | "56_7": 2, 252 | "56_8": 2, 253 | "57_1": 2, 254 | "57_2": 2, 255 | "57_3": 2, 256 | "57_4": 2, 257 | "57_5": 2, 258 | "57_6": 2, 259 | "57_7": 2, 260 | "57_8": 2, 261 | "57_9": 2, 262 | "57_10": 2, 263 | "58_1": 2, 264 | "58_2": 2, 265 | "58_3": 2, 266 | "58_4": 2, 267 | "58_5": 2, 268 | "58_6": 2, 269 | "58_7": 2, 270 | "58_8": 2, 271 | "59_1": 2, 272 | "59_2": 2, 273 | "59_3": 2, 274 | "59_4": 2, 275 | "59_5": 2, 276 | "59_6": 2, 277 | "59_7": 2, 278 | "59_8": 2, 279 | "60_1": 2, 280 | "60_2": 2, 281 | "60_3": 2, 282 | "60_4": 2, 283 | "60_5": 2, 284 | "60_6": 2, 285 | "60_7": 2, 286 | "61_1": 4, 287 | "61_2": 4, 288 | "61_3": 4, 289 | "61_4": 4, 290 | "61_5": 4, 291 | "61_6": 4, 292 | "61_7": 4, 293 | "61_8": 4, 294 | "61_9": 4, 295 | "62_1": 3, 296 | "62_2": 3, 297 | "62_3": 3, 298 | "62_4": 3, 299 | "62_5": 3, 300 | "62_6": 3, 301 | "62_7": 3, 302 | "62_8": 3, 303 | "62_9": 3, 304 | "62_10": 3, 305 | "62_11": 3, 306 | "63_1": 3, 307 | "63_2": 3, 308 | "63_3": 3, 309 | "63_4": 3, 310 | "63_5": 3, 311 | "63_6": 3, 312 | "63_7": 3, 313 | "63_8": 3, 314 | "63_9": 3, 315 | "63_10": 3, 316 | "64_1": 3, 317 | "64_2": 3, 318 | "64_3": 3, 319 | "64_4": 3, 320 | "64_5": 3, 321 | "64_6": 3, 322 | "64_7": 3, 323 | "64_8": 3, 324 | "64_9": 3, 325 | "64_10": 3, 326 | "64_11": 3, 327 | "65_1": 3, 328 | "65_2": 3, 329 | "65_3": 3, 330 | "65_4": 3, 331 | "65_5": 3, 332 | "65_6": 3, 333 | "65_7": 3, 334 | "65_8": 3, 335 | "65_9": 3, 336 | "65_10": 3, 337 | "66_1": 3, 338 | "66_2": 3, 339 | "66_3": 3, 340 | "66_4": 3, 341 | "66_5": 3, 342 | "66_6": 3, 343 | "66_7": 3, 344 | "66_8": 3, 345 | "66_9": 3, 346 | "67_1": 3, 347 | "67_2": 3, 348 | "67_3": 3, 349 | "67_4": 3, 350 | "67_5": 3, 351 | "67_6": 3, 352 | "67_7": 3, 353 | "67_8": 3, 354 | "67_9": 3, 355 | "67_10": 3, 356 | "67_11": 3, 357 | "68_1": 3, 358 | "68_2": 3, 359 | "68_3": 3, 360 | "68_4": 3, 361 | "68_5": 3, 362 | "68_6": 3, 363 | "68_7": 3, 364 | "68_8": 3, 365 | "68_9": 3, 366 | "68_10": 3, 367 | "68_11": 3, 368 | "69_1": 3, 369 | "69_2": 3, 370 | "69_3": 3, 371 | "69_4": 3, 372 | "69_5": 3, 373 | "69_6": 3, 374 | "69_7": 3, 375 | "69_8": 3, 376 | "69_9": 3, 377 | "69_10": 3, 378 | "70_1": 3, 379 | "70_2": 3, 380 | "70_3": 3, 381 | "70_4": 3, 382 | "70_5": 3, 383 | "70_6": 3, 384 | "70_7": 3, 385 | "70_8": 3, 386 | "70_9": 3, 387 | "70_10": 3, 388 | "71_1": 4, 389 | "71_2": 4, 390 | "71_3": 4, 391 | "71_4": 4, 392 | "71_5": 4, 393 | "71_6": 4, 394 | "71_7": 4, 395 | "71_8": 4, 396 | "71_9": 4, 397 | "71_10": 4, 398 | "71_11": 4, 399 | "71_12": 4, 400 | "72_1": 4, 401 | "72_2": 4, 402 | "72_3": 4, 403 | "72_4": 4, 404 | "72_5": 4, 405 | "72_6": 4, 406 | "72_7": 4, 407 | "72_8": 4, 408 | "72_9": 4, 409 | "72_10": 4, 410 | "73_1": 4, 411 | "73_2": 4, 412 | "73_3": 4, 413 | "73_4": 4, 414 | "73_5": 4, 415 | "73_6": 4, 416 | "73_7": 4, 417 | "73_8": 4, 418 | "73_9": 4, 419 | "73_10": 4, 420 | "74_1": 4, 421 | "74_2": 4, 422 | "74_3": 4, 423 | "74_4": 4, 424 | "74_5": 4, 425 | "74_6": 4, 426 | "74_7": 4, 427 | "74_8": 4, 428 | "74_9": 4, 429 | "74_10": 4, 430 | "74_11": 4, 431 | 
"74_12": 4, 432 | "75_1": 4, 433 | "75_2": 4, 434 | "75_3": 4, 435 | "75_4": 4, 436 | "75_5": 4, 437 | "75_6": 4, 438 | "75_7": 4, 439 | "75_8": 4, 440 | "75_9": 4, 441 | "75_10": 4, 442 | "76_1": 4, 443 | "76_2": 4, 444 | "76_3": 4, 445 | "76_4": 4, 446 | "76_5": 4, 447 | "76_6": 4, 448 | "76_7": 4, 449 | "76_8": 4, 450 | "76_9": 4, 451 | "76_10": 4, 452 | "77_1": 4, 453 | "77_2": 4, 454 | "77_3": 4, 455 | "77_4": 4, 456 | "77_5": 4, 457 | "77_6": 4, 458 | "77_7": 4, 459 | "77_8": 4, 460 | "77_9": 4, 461 | "77_10": 4, 462 | "78_1": 4, 463 | "78_2": 4, 464 | "78_3": 4, 465 | "78_4": 4, 466 | "78_5": 4, 467 | "78_6": 4, 468 | "78_7": 4, 469 | "78_8": 4, 470 | "78_9": 4, 471 | "78_10": 4, 472 | "79_1": 4, 473 | "79_2": 4, 474 | "79_3": 4, 475 | "79_4": 4, 476 | "79_5": 4, 477 | "79_6": 4, 478 | "79_7": 4, 479 | "79_8": 4, 480 | "79_9": 4, 481 | "80_1": 4, 482 | "80_2": 4, 483 | "80_3": 4, 484 | "80_4": 4, 485 | "80_5": 4, 486 | "80_6": 4, 487 | "80_7": 4, 488 | "80_8": 4, 489 | "80_9": 4, 490 | "80_10": 4 491 | } 492 | 493 | if __name__ == "__main__": 494 | parser = argparse.ArgumentParser() 495 | parser.add_argument("--train", type=str) 496 | parser.add_argument("--run", type=str) 497 | parser.add_argument("--qrels", type=str) 498 | parser.add_argument("--output", type=str) 499 | parser.add_argument("--collection", type=str) 500 | parser.add_argument( 501 | "--cast", 502 | action="store_true", 503 | help= 504 | "Set this flag if you are working on the TREC CAsT dataset to enable cross-validation" 505 | ) 506 | parser.add_argument("--num_negs", type=int, default=9) 507 | args = parser.parse_args() 508 | 509 | print("Selecting negative documents...") 510 | query_positive_id = {} 511 | query_negative_id = {} 512 | with open(args.qrels, 'r', encoding='utf8') as f: 513 | tsvreader = csv.reader(f, delimiter="\t") 514 | for [topicid, _, docid, rel] in tsvreader: 515 | docid = int(docid) 516 | rel = int(rel) 517 | if rel > 0: 518 | if topicid not in query_positive_id: 519 | query_positive_id[topicid] = {} 520 | query_positive_id[topicid][docid] = rel 521 | else: 522 | query_positive_id[topicid][docid] = rel 523 | else: 524 | if topicid not in query_negative_id: 525 | query_negative_id[topicid] = [] 526 | query_negative_id[topicid].append(docid) 527 | else: 528 | query_negative_id[topicid].append(docid) 529 | 530 | with open(args.train, "r") as f: 531 | cqr = {} 532 | for line in f: 533 | obj = json.loads(line) 534 | qid = ( 535 | obj["topic_number"] + "_" + 536 | obj["query_number"]) if "topic_number" in obj else obj["qid"] 537 | cqr[qid] = obj 538 | 539 | # find negatives documents first in the retrieved & annotated negatives 540 | negatives = {} 541 | with open(args.run, "r") as f: 542 | for line in f: 543 | qid, _, pid, _, _, _ = line.strip().split() 544 | pid = int(pid) 545 | positive_ids = query_positive_id[ 546 | qid] if qid in query_positive_id else {} 547 | if positive_ids != {} and pid not in positive_ids: 548 | if qid in query_negative_id and pid in query_negative_id[qid]: 549 | if qid not in negatives: 550 | negatives[qid] = [pid] 551 | else: 552 | negatives[qid].append(pid) 553 | 554 | # not annotated negatives may be false negatives 555 | with open(args.run, "r") as f: 556 | for line in f: 557 | qid, _, pid, _, _, _ = line.strip().split() 558 | pid = int(pid) 559 | if qid in negatives and len(negatives[qid]) >= 20: 560 | continue 561 | positive_ids = query_positive_id[ 562 | qid] if qid in query_positive_id else {} 563 | if positive_ids != {} and pid not in positive_ids: 564 | if qid 
not in negatives:
565 |                     negatives[qid] = [pid]
566 |                 else:
567 |                     negatives[qid].append(pid)
568 |     print("Queries with negatives:", len(negatives))
569 | 
570 |     print("Loading document collection...")
571 |     all_passages = load_collection(args.collection)
572 | 
573 |     print("Writing to file...")
574 |     items = copy.deepcopy(list(negatives.items()))
575 |     random.shuffle(items)
576 |     file_id = 0
577 |     fs = None
578 |     if args.cast:
579 |         fs = [open(args.output + "." + str(x), "w") for x in range(5)]
580 |     f = open(args.output, "w")
581 |     for qid, negs in items:
582 |         if qid not in query_positive_id:
583 |             continue
584 |         positives = query_positive_id[qid]
585 |         max_positive = -1
586 |         max_rel = -1
587 |         for pos_id, rel in positives.items():
588 |             if rel > max_rel:
589 |                 max_rel = rel
590 |                 max_positive = pos_id
591 |         sampled_negs = random.sample(
592 |             negs, args.num_negs) if len(negs) > args.num_negs else negs
593 |         cqr_record = cqr[qid]
594 |         target_obj = copy.deepcopy(cqr_record)
595 |         target_obj.update({
596 |             "doc_pos": all_passages[max_positive],
597 |             "doc_pos_id": max_positive,
598 |             "doc_negs": [all_passages[x] for x in sampled_negs],
599 |             "doc_negs_id": [x for x in sampled_negs]
600 |         })
601 |         to_write = json.dumps(target_obj) + "\n"
602 |         if args.cast:
603 |             fs[qid_to_fold_test[qid]].write(to_write)
604 |         f.write(to_write)
605 | 
606 |     if args.cast:
607 |         for x in range(5):
608 |             fs[x].close()
609 |     f.close()
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # ConvDR
2 | 
3 | This repo contains code and data for the SIGIR 2021 paper ["Few-Shot Conversational Dense Retrieval"](https://arxiv.org/pdf/2105.04166.pdf).
4 | 
5 | ## Prerequisites
6 | 
7 | Install dependencies:
8 | 
9 | ```bash
10 | git clone https://github.com/thunlp/ConvDR.git
11 | cd ConvDR
12 | pip install -r requirements.txt
13 | ```
14 | 
15 | We recommend setting `PYTHONPATH` before running the code:
16 | 
17 | ```bash
18 | export PYTHONPATH=${PYTHONPATH}:`pwd`
19 | ```
20 | 
21 | To train ConvDR, we need trained ad hoc dense retrievers. We use [ANCE](https://github.com/microsoft/ANCE) for both tasks. Please download the checkpoints here: [TREC CAsT](https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Checkpoint.zip) and [OR-QuAC](https://data.thunlp.org/convdr/ad-hoc-ance-orquac.cp). For TREC CAsT, we directly use the official model trained on the MS MARCO Passage Retrieval task. For OR-QuAC, we initialize the retriever from the official model trained on NQ and TriviaQA, and continue training on OR-QuAC with manually reformulated questions using the ANCE codebase.
22 | 
23 | The following code downloads those checkpoints and stores them in `./checkpoints`.
24 | 
25 | ```bash
26 | mkdir checkpoints
27 | wget https://webdatamltrainingdiag842.blob.core.windows.net/semistructstore/OpenSource/Passage_ANCE_FirstP_Checkpoint.zip -P checkpoints
28 | wget https://data.thunlp.org/convdr/ad-hoc-ance-orquac.cp -P checkpoints
29 | unzip checkpoints/Passage_ANCE_FirstP_Checkpoint.zip -d checkpoints
30 | mv "checkpoints/Passage ANCE(FirstP) Checkpoint" checkpoints/ad-hoc-ance-msmarco
31 | ```
32 | 
33 | ## Data Preparation
34 | 
35 | By default, we expect raw data to be stored in `./datasets/raw` and processed data to be stored in `./datasets`:
36 | 
37 | ```bash
38 | mkdir datasets
39 | mkdir datasets/raw
40 | ```
41 | 
42 | ### TREC CAsT
43 | 
44 | #### CAsT shared files download
45 | 
46 | Use the following commands to download the document collection for CAsT-19 & CAsT-20 as well as the MS MARCO duplicate file (the archives need to be extracted so that `msmarco.tsv` and `dedup.articles-paragraphs.cbor` end up directly under `datasets/raw`):
47 | 
48 | ```bash
49 | cd datasets/raw
50 | wget https://msmarco.blob.core.windows.net/msmarcoranking/collection.tar.gz && tar -zxf collection.tar.gz && mv collection.tsv msmarco.tsv
51 | wget http://trec-car.cs.unh.edu/datareleases/v2.0/paragraphCorpus.v2.0.tar.xz && tar -xf paragraphCorpus.v2.0.tar.xz && mv paragraphCorpus/dedup.articles-paragraphs.cbor .
52 | wget http://boston.lti.cs.cmu.edu/Services/treccast19/duplicate_list_v1.0.txt
53 | ```
54 | 
55 | #### CAsT-19 files download
56 | 
57 | Download necessary files for CAsT-19 and store them into `./datasets/raw/cast-19`:
58 | 
59 | ```bash
60 | mkdir datasets/raw/cast-19
61 | cd datasets/raw/cast-19
62 | wget https://raw.githubusercontent.com/daltonj/treccastweb/master/2019/data/evaluation/evaluation_topics_v1.0.json
63 | wget https://raw.githubusercontent.com/daltonj/treccastweb/master/2019/data/evaluation/evaluation_topics_annotated_resolved_v1.0.tsv
64 | wget https://trec.nist.gov/data/cast/2019qrels.txt
65 | ```
66 | 
67 | #### CAsT-20 files download
68 | 
69 | Download necessary files for CAsT-20 and store them into `./datasets/raw/cast-20`:
70 | 
71 | ```bash
72 | mkdir datasets/raw/cast-20
73 | cd datasets/raw/cast-20
74 | wget https://raw.githubusercontent.com/daltonj/treccastweb/master/2020/2020_automatic_evaluation_topics_v1.0.json
75 | wget https://raw.githubusercontent.com/daltonj/treccastweb/master/2020/2020_manual_evaluation_topics_v1.0.json
76 | wget https://trec.nist.gov/data/cast/2020qrels.txt
77 | ```
78 | 
79 | #### CAsT preprocessing
80 | 
81 | Use the scripts `./data/preprocess_cast19.py` and `./data/preprocess_cast20.py` to preprocess the raw CAsT files:
82 | 
83 | ```bash
84 | mkdir datasets/cast-19
85 | mkdir datasets/cast-shared
86 | python data/preprocess_cast19.py --car_cbor=datasets/raw/dedup.articles-paragraphs.cbor --msmarco_collection=datasets/raw/msmarco.tsv --duplicate_file=datasets/raw/duplicate_list_v1.0.txt --cast_dir=datasets/raw/cast-19/ --out_data_dir=datasets/cast-19 --out_collection_dir=datasets/cast-shared
87 | ```
88 | 
89 | ```bash
90 | mkdir datasets/cast-20
91 | mkdir datasets/cast-shared
92 | python data/preprocess_cast20.py --car_cbor=datasets/raw/dedup.articles-paragraphs.cbor --msmarco_collection=datasets/raw/msmarco.tsv --duplicate_file=datasets/raw/duplicate_list_v1.0.txt --cast_dir=datasets/raw/cast-20/ --out_data_dir=datasets/cast-20 --out_collection_dir=datasets/cast-shared
93 | ```
94 | 
95 | ### OR-QuAC
96 | 
97 | #### OR-QuAC files download
98 | 
99 | Download necessary OR-QuAC files and store them into `./datasets/raw/or-quac`:
100 | 
101 | ```bash
102 | mkdir datasets/raw/or-quac
103 | cd datasets/raw/or-quac
104 | wget https://ciir.cs.umass.edu/downloads/ORConvQA/all_blocks.txt.gz
105 | wget https://ciir.cs.umass.edu/downloads/ORConvQA/qrels.txt.gz
106 | gzip -d *.txt.gz
107 | mkdir preprocessed
108 | cd preprocessed
109 | wget https://ciir.cs.umass.edu/downloads/ORConvQA/preprocessed/train.txt
110 | wget https://ciir.cs.umass.edu/downloads/ORConvQA/preprocessed/test.txt
111 | wget https://ciir.cs.umass.edu/downloads/ORConvQA/preprocessed/dev.txt
112 | ```
113 | 
114 | #### OR-QuAC preprocessing
115 | 
116 | Use the script `./data/preprocess_orquac.py` to preprocess the OR-QuAC files:
117 | 
118 | ```bash
119 | mkdir datasets/or-quac
120 | python data/preprocess_orquac.py --orquac_dir=datasets/raw/or-quac --output_dir=datasets/or-quac
121 | ```
122 | 
123 | ### Generate Document Embeddings
124 | 
125 | Our code is based on ANCE and we have a similar embedding inference pipeline, where the documents are first tokenized and converted to token ids, and the token ids are then used for embedding inference. We create sub-directories `tokenized` and `embeddings` inside `./datasets/cast-shared` and `./datasets/or-quac` to store the tokenized documents and document embeddings, respectively:
126 | 
127 | ```bash
128 | mkdir datasets/cast-shared/tokenized
129 | mkdir datasets/cast-shared/embeddings
130 | mkdir datasets/or-quac/tokenized
131 | mkdir datasets/or-quac/embeddings
132 | ```
133 | 
134 | Run `./data/tokenizing.py` to tokenize documents in parallel:
135 | 
136 | ```bash
137 | # CAsT
138 | python data/tokenizing.py --collection=datasets/cast-shared/collection.tsv --out_data_dir=datasets/cast-shared/tokenized --model_name_or_path=checkpoints/ad-hoc-ance-msmarco --model_type=rdot_nll
139 | # OR-QuAC
140 | python data/tokenizing.py --collection=datasets/or-quac/collection.tsv --out_data_dir=datasets/or-quac/tokenized --model_name_or_path=bert-base-uncased --model_type=dpr
141 | ```
142 | 
143 | After tokenization, run `./drivers/gen_passage_embeddings.py` to generate document embeddings:
144 | 
145 | ```bash
146 | # CAsT
147 | python -m torch.distributed.launch --nproc_per_node=$gpu_no drivers/gen_passage_embeddings.py --data_dir=datasets/cast-shared/tokenized --checkpoint=checkpoints/ad-hoc-ance-msmarco --output_dir=datasets/cast-shared/embeddings --model_type=rdot_nll
148 | # OR-QuAC
149 | python -m torch.distributed.launch --nproc_per_node=$gpu_no drivers/gen_passage_embeddings.py --data_dir=datasets/or-quac/tokenized --checkpoint=checkpoints/ad-hoc-ance-orquac.cp --output_dir=datasets/or-quac/embeddings --model_type=dpr
150 | ```
151 | 
152 | Note that we follow the ANCE implementation and this step takes up a lot of memory. To generate all 38M CAsT document embeddings safely, the machine should have at least 200GB of memory. It's possible to save memory by generating a part at a time, and we may update the implementation in the future.
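As a rough sanity check on that requirement (our estimate, not a measured figure): 38M passages × 768 float32 dimensions × 4 bytes ≈ 117 GB for the raw embedding matrix alone, before the intermediate copies made while pickling, so 200GB leaves a sensible margin.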
153 | 
154 | ## ConvDR Training
155 | 
156 | Now we are all prepared: we have downloaded & preprocessed the data, and we have obtained document embeddings. Simply run `./drivers/run_convdr_train.py` to train a ConvDR using the KD (MSE) loss:
157 | 
158 | ```bash
159 | # CAsT-19, KD loss only, five-fold cross-validation
160 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-kd-cast19 --model_name_or_path=checkpoints/ad-hoc-ance-msmarco --train_file=datasets/cast-19/eval_topics.jsonl --query=no_res --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_kd_cast19 --num_train_epochs=8 --model_type=rdot_nll --cross_validate
161 | # CAsT-20, KD loss only, five-fold cross-validation, use automatic canonical responses and a longer input length
162 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-kd-cast20 --model_name_or_path=checkpoints/ad-hoc-ance-msmarco --train_file=datasets/cast-20/eval_topics.jsonl --query=auto_can --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_kd_cast20 --num_train_epochs=8 --model_type=rdot_nll --cross_validate --max_concat_length=512
163 | # OR-QuAC, KD loss only
164 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-kd-orquac.cp --model_name_or_path=checkpoints/ad-hoc-ance-orquac.cp --train_file=datasets/or-quac/train.jsonl --query=no_res --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_kd_orquac --num_train_epochs=1 --model_type=dpr --log_steps=100
165 | ```
166 | 
167 | Note that for CAsT-20, it's better to first pretrain the model on CANARD and then do cross-validation:
168 | 
169 | ```bash
170 | # Pretrain on CANARD (use preprocessed OR-QuAC)
171 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-kd-cast20-warmup --model_name_or_path=checkpoints/ad-hoc-ance-msmarco --train_file=datasets/or-quac/train.jsonl --query=man_can --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_kd_cast20_warmup --num_train_epochs=1 --model_type=rdot_nll --log_steps=100 --max_concat_length=512
172 | # Do cross-validation on CAsT-20; set model_name_or_path to the pretrained model and set teacher_model to the ad hoc model
173 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-kd-cast20 --model_name_or_path=checkpoints/convdr-kd-cast20-warmup --teacher_model=checkpoints/ad-hoc-ance-msmarco --train_file=datasets/cast-20/eval_topics.jsonl --query=auto_can --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_kd_cast20 --num_train_epochs=8 --model_type=rdot_nll --cross_validate --max_concat_length=512
174 | ```
175 | 
176 | To use the ranking loss, we need to find negative documents for each query. We use the top retrieved negative documents from the ranking results of **manual** queries (see the sketch below for how the two losses combine).
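For intuition, here is a minimal sketch of the multi-task objective (our illustration with made-up tensor names, not the repo's exact code): the KD term pulls the ConvDR query embedding toward the frozen teacher's embedding of the manual query, and the ranking term is the same dot-product NLL used in `model/models.py`, here over one positive and `num_negs` sampled negatives.

```python
import torch
import torch.nn.functional as F

def convdr_multitask_loss(student_q, teacher_q, doc_pos, doc_negs):
    # student_q, teacher_q, doc_pos: [B, H]; doc_negs: [B, K, H]
    kd = F.mse_loss(student_q, teacher_q.detach())         # KD (MSE) term
    pos = (student_q * doc_pos).sum(-1, keepdim=True)      # [B, 1]
    neg = torch.einsum("bh,bkh->bk", student_q, doc_negs)  # [B, K]
    logits = torch.cat([pos, neg], dim=1)                  # [B, 1 + K]
    ranking = -F.log_softmax(logits, dim=1)[:, 0].mean()   # NLL of the positive
    return kd + ranking                                    # --no_mse would drop `kd`
```

In the actual driver, the `--ranking_task` flag switches the second term on, and `--no_mse` drops the first.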
So we need to first perform retrieval using the manual queries: 177 | 178 | ```bash 179 | # CAsT-19 180 | python drivers/run_convdr_inference.py --model_path=checkpoints/ad-hoc-ance-msmarco --eval_file=datasets/cast-19/eval_topics.jsonl --query=target --per_gpu_eval_batch_size=8 --ann_data_dir=datasets/cast-19/embeddings --qrels=datasets/cast-19/qrels.tsv --processed_data_dir=datasets/cast-19/tokenized --raw_data_dir=datasets/cast-19 --output_file=results/cast-19/manual_ance.jsonl --output_trec_file=results/cast-19/manual_ance.trec --model_type=rdot_nll --output_query_type=manual --use_gpu 181 | # OR-QuAC, inference on train, set query to "target" to use manual queries directly 182 | python drivers/run_convdr_inference.py --model_path=checkpoints/ad-hoc-ance-orquac.cp --eval_file=datasets/or-quac/train.jsonl --query=target --per_gpu_eval_batch_size=8 --ann_data_dir=datasets/or-quac/embeddings --qrels=datasets/or-quac/qrels.tsv --processed_data_dir=datasets/or-quac/tokenized --raw_data_dir=datasets/or-quac --output_file=results/or-quac/manual_ance_train.jsonl --output_trec_file=results/or-quac/manual_ance_train.trec --model_type=dpr --output_query_type=train.manual --use_gpu 183 | ``` 184 | 185 | After the retrieval finishes, we can select negative documents from manual runs and supplement the original training files with them: 186 | 187 | ```bash 188 | # CAsT-19 189 | python data/gen_ranking_data.py --train=datasets/cast-19/eval_topics.jsonl --run=results/cast-19/manual_ance.trec --output=datasets/cast-19/eval_topics.rank.jsonl --qrels=datasets/cast-19/qrels.tsv --collection=datasets/cast-shared/collection.tsv --cast 190 | # OR-QuAC 191 | python data/gen_ranking_data.py --train=datasets/or-quac/train.jsonl --run=results/or-quac/manual_ance_train.trec --output=datasets/or-quac/train.rank.jsonl --qrels=datasets/or-quac/qrels.tsv --collection=datasets/or-quac/collection.jsonl 192 | ``` 193 | 194 | Now we are able to use the ranking loss, with the `--ranking_task` flag on: 195 | 196 | ```bash 197 | # CAsT-19, Multi-task 198 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-multi-cast19 --model_name_or_path=checkpoints/ad-hoc-ance-msmarco --train_file=datasets/cast-19/eval_topics.rank.jsonl --query=no_res --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_multi_cast19 --num_train_epochs=8 --model_type=rdot_nll --cross_validate --ranking_task 199 | # OR-QuAC, Multi-task 200 | python drivers/run_convdr_train.py --output_dir=checkpoints/convdr-multi-orquac.cp --model_name_or_path=checkpoints/ad-hoc-ance-orquac.cp --train_file=datasets/or-quac/train.rank.jsonl --query=no_res --per_gpu_train_batch_size=4 --learning_rate=1e-5 --log_dir=logs/convdr_multi_orquac --num_train_epochs=1 --model_type=dpr --log_steps=100 --ranking_task 201 | ``` 202 | 203 | To disable the KD loss, simply set the `--no_mse` flag. 204 | 205 | ## ConvDR Inference 206 | 207 | Run `./drivers/run_convdr_inference.py` to get inference results. `output_file` is the [OpenMatch](https://github.com/thunlp/OpenMatch)-format file for reranking, and `output_trec_file` is the TREC-style run file which can be evaluated by the [trec_eval](https://github.com/usnistgov/trec_eval) tool. 
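For reference, a run file can then be scored with something like `trec_eval -m ndcg_cut.3 datasets/cast-19/qrels.tsv results/cast-19/kd.trec` (NDCG@3 is the usual CAsT headline metric; adjust the paths to match the commands below).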
208 | 
209 | ```bash
210 | # OR-QuAC
211 | python drivers/run_convdr_inference.py --model_path=checkpoints/convdr-multi-orquac.cp --eval_file=datasets/or-quac/test.jsonl --query=no_res --per_gpu_eval_batch_size=8 --cache_dir=../ann_cache_dir --ann_data_dir=datasets/or-quac/embeddings --qrels=datasets/or-quac/qrels.tsv --processed_data_dir=datasets/or-quac/tokenized --raw_data_dir=datasets/or-quac --output_file=results/or-quac/multi_task.jsonl --output_trec_file=results/or-quac/multi_task.trec --model_type=dpr --output_query_type=test.raw --use_gpu
212 | # CAsT-19
213 | python drivers/run_convdr_inference.py --model_path=checkpoints/convdr-kd-cast19 --eval_file=datasets/cast-19/eval_topics.jsonl --query=no_res --per_gpu_eval_batch_size=8 --cache_dir=../ann_cache_dir --ann_data_dir=datasets/cast-19/embeddings --qrels=datasets/cast-19/qrels.tsv --processed_data_dir=datasets/cast-19/tokenized --raw_data_dir=datasets/cast-19 --output_file=results/cast-19/kd.jsonl --output_trec_file=results/cast-19/kd.trec --model_type=rdot_nll --output_query_type=raw --use_gpu --cross_validate
214 | ```
215 | 
216 | The query embedding inference always takes the first GPU. If you set the `--use_gpu` flag (recommended), retrieval is performed on the remaining GPUs. The retrieval process consumes a lot of GPU resources, so we split the document embeddings into several blocks, search them one by one, and finally merge the results. If you have enough GPU resources, you can modify the code to search everything at once.
217 | 
218 | ## Download Trained Models
219 | 
220 | Three trained models can be downloaded via the following links: [CAsT19-KD-CV-Fold1](https://data.thunlp.org/convdr/convdr-kd-cast19-1.zip), [CAsT20-KD-Warmup-CV-Fold2](https://data.thunlp.org/convdr/convdr-kd-cast20-2.zip) and [ORQUAC-Multi](https://data.thunlp.org/convdr/convdr-multi-orquac.cp).
221 | 
222 | ## Results
223 | 
224 | [Download ConvDR and baseline runs on CAsT](https://drive.google.com/file/d/1F0RwA9sZscUAyE0IyQ7PMrgzNVqDnho5/view?usp=sharing)
225 | 
226 | ## Contact
227 | 
228 | Please send email to ~~yus17@mails.tsinghua.edu.cn~~ yushi17@foxmail.com.
229 | -------------------------------------------------------------------------------- /drivers/run_convdr_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import csv 3 | import logging 4 | import json 5 | from model.models import MSMarcoConfigDict 6 | import os 7 | import pickle 8 | import time 9 | import copy 10 | import faiss 11 | import torch 12 | import numpy as np 13 | from torch.utils.data.sampler import SequentialSampler 14 | from torch.utils.data import DataLoader 15 | 16 | from utils.util import ConvSearchDataset, NUM_FOLD, set_seed, load_model, load_collection 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | def EvalDevQuery(query_embedding2id, 22 | merged_D, 23 | dev_query_positive_id, 24 | I_nearest_neighbor, 25 | topN, 26 | output_file, 27 | output_trec_file, 28 | offset2pid, 29 | raw_data_dir, 30 | output_query_type, 31 | raw_sequences=None): 32 | prediction = {} 33 | 34 | qids_to_ranked_candidate_passages = {} 35 | qids_to_ranked_candidate_passages_ori = {} 36 | qids_to_raw_sequences = {} 37 | for query_idx in range(len(I_nearest_neighbor)): 38 | seen_pid = set() 39 | inputs = raw_sequences[query_idx] 40 | query_id = query_embedding2id[query_idx] 41 | prediction[query_id] = {} 42 | 43 | top_ann_pid = I_nearest_neighbor[query_idx].copy() 44 | top_ann_score = merged_D[query_idx].copy() 45 | selected_ann_idx = top_ann_pid[:topN] 46 | selected_ann_score = top_ann_score[:topN].tolist() 47 | rank = 0 48 | 49 | if query_id in qids_to_ranked_candidate_passages: 50 | pass 51 | else: 52 | tmp = [(0, 0)] * topN 53 | tmp_ori = [0] * topN 54 | qids_to_ranked_candidate_passages[query_id] = tmp 55 | qids_to_ranked_candidate_passages_ori[query_id] = tmp_ori 56 | qids_to_raw_sequences[query_id] = inputs 57 | 58 | for idx, score in zip(selected_ann_idx, selected_ann_score): 59 | pred_pid = offset2pid[idx] 60 | 61 | if not pred_pid in seen_pid: 62 | qids_to_ranked_candidate_passages[query_id][rank] = (pred_pid, 63 | score) 64 | qids_to_ranked_candidate_passages_ori[query_id][ 65 | rank] = pred_pid 66 | 67 | rank += 1 68 | prediction[query_id][pred_pid] = -rank 69 | seen_pid.add(pred_pid) 70 | 71 | logger.info("Reading queries and passages...") 72 | queries = {} 73 | with open( 74 | os.path.join(raw_data_dir, 75 | "queries." + output_query_type + ".tsv"), "r") as f: 76 | for line in f: 77 | qid, query = line.strip().split("\t") 78 | queries[qid] = query 79 | collection = os.path.join(raw_data_dir, "collection.jsonl") 80 | if not os.path.exists(collection): 81 | collection = os.path.join(raw_data_dir, "collection.tsv") 82 | if not os.path.exists(collection): 83 | raise FileNotFoundError( 84 | "Neither collection.tsv nor collection.jsonl found in {}". 
85 | format(raw_data_dir)) 86 | all_passages = load_collection(collection) 87 | 88 | # Write to file 89 | with open(output_file, "w") as f, open(output_trec_file, "w") as g: 90 | for qid, passages in qids_to_ranked_candidate_passages.items(): 91 | ori_qid = qid 92 | query_text = queries[ori_qid] 93 | sequences = qids_to_raw_sequences[ori_qid] 94 | for i in range(topN): 95 | pid, score = passages[i] 96 | ori_pid = pid 97 | passage_text = all_passages[ori_pid] 98 | label = 0 if qid not in dev_query_positive_id else ( 99 | dev_query_positive_id[qid][ori_pid] 100 | if ori_pid in dev_query_positive_id[qid] else 0) 101 | f.write( 102 | json.dumps({ 103 | "query": query_text, 104 | "doc": passage_text, 105 | "label": label, 106 | "query_id": str(ori_qid), 107 | "doc_id": str(ori_pid), 108 | "retrieval_score": score, 109 | "input": sequences 110 | }) + "\n") 111 | g.write( 112 | str(ori_qid) + " Q0 " + str(ori_pid) + " " + str(i + 1) + 113 | " " + str(-i - 1 + 200) + " ance\n") 114 | 115 | 116 | def evaluate(args, eval_dataset, model, logger): 117 | args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu) 118 | eval_sampler = SequentialSampler(eval_dataset) 119 | eval_dataloader = DataLoader(eval_dataset, 120 | sampler=eval_sampler, 121 | batch_size=args.eval_batch_size, 122 | collate_fn=eval_dataset.get_collate_fn( 123 | args, "inference")) 124 | 125 | logger.info("***** Running evaluation *****") 126 | logger.info(" Num examples = %d", len(eval_dataset)) 127 | logger.info(" Instantaneous batch size per GPU = %d", 128 | args.per_gpu_eval_batch_size) 129 | 130 | model.zero_grad() 131 | set_seed( 132 | args) # Added here for reproducibility (even between python 2 and 3) 133 | embedding = [] 134 | embedding2id = [] 135 | raw_sequences = [] 136 | epoch_iterator = eval_dataloader 137 | for batch in epoch_iterator: 138 | qids = batch["qid"] 139 | ids, id_mask = ( 140 | ele.to(args.device) 141 | for ele in [batch["concat_ids"], batch["concat_id_mask"]]) 142 | model.eval() 143 | with torch.no_grad(): 144 | embs = model(ids, id_mask) 145 | embs = embs.detach().cpu().numpy() 146 | embedding.append(embs) 147 | for qid in qids: 148 | embedding2id.append(qid) 149 | 150 | sequences = batch["history_utterances"] 151 | raw_sequences.extend(sequences) 152 | 153 | embedding = np.concatenate(embedding, axis=0) 154 | return embedding, embedding2id, raw_sequences 155 | 156 | 157 | def search_one_by_one(ann_data_dir, gpu_index, query_embedding, topN): 158 | merged_candidate_matrix = None 159 | for block_id in range(8): 160 | logger.info("Loading passage reps " + str(block_id)) 161 | passage_embedding = None 162 | passage_embedding2id = None 163 | try: 164 | with open( 165 | os.path.join( 166 | ann_data_dir, 167 | "passage__emb_p__data_obj_" + str(block_id) + ".pb"), 168 | 'rb') as handle: 169 | passage_embedding = pickle.load(handle) 170 | with open( 171 | os.path.join( 172 | ann_data_dir, 173 | "passage__embid_p__data_obj_" + str(block_id) + ".pb"), 174 | 'rb') as handle: 175 | passage_embedding2id = pickle.load(handle) 176 | except: 177 | break 178 | print('passage embedding shape: ' + str(passage_embedding.shape)) 179 | print("query embedding shape: " + str(query_embedding.shape)) 180 | gpu_index.add(passage_embedding) 181 | ts = time.time() 182 | D, I = gpu_index.search(query_embedding, topN) 183 | te = time.time() 184 | elapsed_time = te - ts 185 | print({ 186 | "total": elapsed_time, 187 | "data": query_embedding.shape[0], 188 | "per_query": elapsed_time / query_embedding.shape[0] 189 | }) 190 | 
candidate_id_matrix = passage_embedding2id[ 191 | I] # passage_idx -> passage_id 192 | D = D.tolist() 193 | candidate_id_matrix = candidate_id_matrix.tolist() 194 | candidate_matrix = [] 195 | for score_list, passage_list in zip(D, candidate_id_matrix): 196 | candidate_matrix.append([]) 197 | for score, passage in zip(score_list, passage_list): 198 | candidate_matrix[-1].append((score, passage)) 199 | assert len(candidate_matrix[-1]) == len(passage_list) 200 | assert len(candidate_matrix) == I.shape[0] 201 | 202 | gpu_index.reset() 203 | del passage_embedding 204 | del passage_embedding2id 205 | 206 | if merged_candidate_matrix == None: 207 | merged_candidate_matrix = candidate_matrix 208 | continue 209 | 210 | # merge 211 | merged_candidate_matrix_tmp = copy.deepcopy(merged_candidate_matrix) 212 | merged_candidate_matrix = [] 213 | for merged_list, cur_list in zip(merged_candidate_matrix_tmp, 214 | candidate_matrix): 215 | p1, p2 = 0, 0 216 | merged_candidate_matrix.append([]) 217 | while p1 < topN and p2 < topN: 218 | if merged_list[p1][0] >= cur_list[p2][0]: 219 | merged_candidate_matrix[-1].append(merged_list[p1]) 220 | p1 += 1 221 | else: 222 | merged_candidate_matrix[-1].append(cur_list[p2]) 223 | p2 += 1 224 | while p1 < topN: 225 | merged_candidate_matrix[-1].append(merged_list[p1]) 226 | p1 += 1 227 | while p2 < topN: 228 | merged_candidate_matrix[-1].append(cur_list[p2]) 229 | p2 += 1 230 | 231 | merged_D, merged_I = [], [] 232 | for merged_list in merged_candidate_matrix: 233 | merged_D.append([]) 234 | merged_I.append([]) 235 | for candidate in merged_list: 236 | merged_D[-1].append(candidate[0]) 237 | merged_I[-1].append(candidate[1]) 238 | merged_D, merged_I = np.array(merged_D), np.array(merged_I) 239 | 240 | print(merged_I) 241 | 242 | return merged_D, merged_I 243 | 244 | 245 | def main(): 246 | parser = argparse.ArgumentParser() 247 | parser.add_argument("--model_path", type=str, help="The model checkpoint.") 248 | parser.add_argument("--eval_file", 249 | type=str, 250 | help="The evaluation dataset.") 251 | parser.add_argument( 252 | "--max_concat_length", 253 | default=256, 254 | type=int, 255 | help="Max input concatenated query length after tokenization.") 256 | parser.add_argument("--max_query_length", 257 | default=64, 258 | type=int, 259 | help="Max input query length after tokenization." 
260 | "This option is for single query input.") 261 | parser.add_argument("--cross_validate", 262 | action='store_true', 263 | help="Set when doing cross validation.") 264 | parser.add_argument("--per_gpu_eval_batch_size", 265 | default=4, 266 | type=int, 267 | help="Batch size per GPU/CPU.") 268 | parser.add_argument("--no_cuda", 269 | action='store_true', 270 | help="Avoid using CUDA when available (for pytorch).") 271 | parser.add_argument('--seed', 272 | type=int, 273 | default=42, 274 | help="Random seed for initialization.") 275 | parser.add_argument("--cache_dir", type=str) 276 | parser.add_argument("--ann_data_dir", 277 | type=str, 278 | help="Path to ANCE embeddings.") 279 | parser.add_argument("--use_gpu", 280 | action='store_true', 281 | help="Whether to use GPU for Faiss.") 282 | parser.add_argument("--qrels", type=str, help="The qrels file.") 283 | parser.add_argument("--processed_data_dir", 284 | type=str, 285 | help="Path to tokenized documents.") 286 | parser.add_argument("--raw_data_dir", type=str, help="Path to dataset.") 287 | parser.add_argument("--output_file", 288 | type=str, 289 | help="Output file for OpenMatch reranking.") 290 | parser.add_argument( 291 | "--output_trec_file", 292 | type=str, 293 | help="TREC-style run file, to be evaluated by the trec_eval tool.") 294 | parser.add_argument( 295 | "--query", 296 | type=str, 297 | default="no_res", 298 | choices=["no_res", "man_can", "auto_can", "target", "output", "raw"], 299 | help="Input query format.") 300 | parser.add_argument("--output_query_type", 301 | type=str, 302 | help="Query to be written in the OpenMatch file.") 303 | parser.add_argument( 304 | "--fold", 305 | type=int, 306 | default=-1, 307 | help="Fold to evaluate on; set to -1 to evaluate all folds.") 308 | parser.add_argument( 309 | "--model_type", 310 | default=None, 311 | type=str, 312 | required=True, 313 | help="Model type selected in the list: " + 314 | ", ".join(MSMarcoConfigDict.keys()), 315 | ) 316 | parser.add_argument("--top_n", 317 | default=100, 318 | type=int, 319 | help="Number of retrieved documents for each query.") 320 | args = parser.parse_args() 321 | 322 | device = torch.device( 323 | "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu") 324 | args.n_gpu = 1 325 | args.device = device 326 | 327 | ngpu = faiss.get_num_gpus() 328 | gpu_resources = [] 329 | tempmem = -1 330 | 331 | for i in range(ngpu): 332 | res = faiss.StandardGpuResources() 333 | if tempmem >= 0: 334 | res.setTempMemory(tempmem) 335 | gpu_resources.append(res) 336 | 337 | # Setup logging 338 | logging.basicConfig( 339 | format='%(asctime)s - %(levelname)s - %(name)s - %(message)s', 340 | datefmt='%m/%d/%Y %H:%M:%S', 341 | level=logging.INFO) 342 | logger.warning("device: %s, n_gpu: %s", device, args.n_gpu) 343 | 344 | # Set seed 345 | set_seed(args) 346 | 347 | with open(os.path.join(args.processed_data_dir, "offset2pid.pickle"), 348 | "rb") as f: 349 | offset2pid = pickle.load(f) 350 | 351 | logger.info("Building index") 352 | # faiss.omp_set_num_threads(16) 353 | cpu_index = faiss.IndexFlatIP(768) 354 | index = None 355 | if args.use_gpu: 356 | co = faiss.GpuMultipleClonerOptions() 357 | co.shard = True 358 | co.usePrecomputed = False 359 | # gpu_vector_resources, gpu_devices_vector 360 | vres = faiss.GpuResourcesVector() 361 | vdev = faiss.Int32Vector() 362 | for i in range(0, ngpu): 363 | vdev.push_back(i) 364 | vres.push_back(gpu_resources[i]) 365 | gpu_index = faiss.index_cpu_to_gpu_multiple(vres, 366 | vdev, 367 | cpu_index, co) 368 | index 
= gpu_index 369 | else: 370 | index = cpu_index 371 | 372 | dev_query_positive_id = {} 373 | if args.qrels is not None: 374 | with open(args.qrels, 'r', encoding='utf8') as f: 375 | tsvreader = csv.reader(f, delimiter="\t") 376 | for [topicid, _, docid, rel] in tsvreader: 377 | topicid = str(topicid) 378 | docid = int(docid) 379 | rel = int(rel) 380 | if topicid not in dev_query_positive_id: 381 | if rel > 0: 382 | dev_query_positive_id[topicid] = {} 383 | dev_query_positive_id[topicid][docid] = rel 384 | else: 385 | dev_query_positive_id[topicid][docid] = rel 386 | 387 | total_embedding = [] 388 | total_embedding2id = [] 389 | total_raw_sequences = [] 390 | 391 | if not args.cross_validate: 392 | 393 | config, tokenizer, model = load_model(args, args.model_path) 394 | 395 | if args.max_concat_length <= 0: 396 | args.max_concat_length = tokenizer.max_len_single_sentence 397 | args.max_concat_length = min(args.max_concat_length, 398 | tokenizer.max_len_single_sentence) 399 | 400 | # eval 401 | logger.info("Training/evaluation parameters %s", args) 402 | eval_dataset = ConvSearchDataset([args.eval_file], 403 | tokenizer, 404 | args, 405 | mode="inference") 406 | total_embedding, total_embedding2id, raw_sequences = evaluate( 407 | args, eval_dataset, model, logger) 408 | total_raw_sequences.extend(raw_sequences) 409 | del model 410 | torch.cuda.empty_cache() 411 | 412 | else: 413 | # K-Fold Cross Validation 414 | 415 | for i in range(NUM_FOLD): 416 | if args.fold != -1 and i != args.fold: 417 | continue 418 | 419 | logger.info("Testing Fold #{}".format(i)) 420 | suffix = ('-' + str(i)) 421 | config, tokenizer, model = load_model(args, 422 | args.model_path + suffix) 423 | 424 | if args.max_concat_length <= 0: 425 | args.max_concat_length = tokenizer.max_len_single_sentence 426 | args.max_concat_length = min(args.max_concat_length, 427 | tokenizer.max_len_single_sentence) 428 | 429 | logger.info("Training/evaluation parameters %s", args) 430 | eval_file = "%s.%d" % (args.eval_file, i) 431 | logger.info("eval_file: {}".format(eval_file)) 432 | eval_dataset = ConvSearchDataset([eval_file], 433 | tokenizer, 434 | args, 435 | mode="inference") 436 | embedding, embedding2id, raw_sequences = evaluate( 437 | args, eval_dataset, model, logger) 438 | total_embedding.append(embedding) 439 | total_embedding2id.extend(embedding2id) 440 | total_raw_sequences.extend(raw_sequences) 441 | 442 | del model 443 | torch.cuda.empty_cache() 444 | 445 | total_embedding = np.concatenate(total_embedding, axis=0) 446 | 447 | merged_D, merged_I = search_one_by_one(args.ann_data_dir, index, 448 | total_embedding, args.top_n) 449 | logger.info("start EvalDevQuery...") 450 | EvalDevQuery(total_embedding2id, 451 | merged_D, 452 | dev_query_positive_id=dev_query_positive_id, 453 | I_nearest_neighbor=merged_I, 454 | topN=args.top_n, 455 | output_file=args.output_file, 456 | output_trec_file=args.output_trec_file, 457 | offset2pid=offset2pid, 458 | raw_data_dir=args.raw_data_dir, 459 | output_query_type=args.output_query_type, 460 | raw_sequences=total_raw_sequences) 461 | 462 | 463 | if __name__ == "__main__": 464 | main() 465 | -------------------------------------------------------------------------------- /drivers/run_convdr_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import logging 3 | import os 4 | import torch 5 | import random 6 | from tensorboardX import SummaryWriter 7 | 8 | from utils.util import pad_input_ids_with_mask, getattr_recursive 9 | 10 
| from torch.utils.data import DataLoader, RandomSampler 11 | from tqdm import tqdm, trange 12 | from transformers import get_linear_schedule_with_warmup 13 | from torch import nn 14 | 15 | from model.models import MSMarcoConfigDict 16 | from utils.util import ConvSearchDataset, NUM_FOLD, set_seed, load_model 17 | from utils.dpr_utils import CheckpointState, get_model_obj, get_optimizer 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | def _save_checkpoint(args, 23 | model, 24 | output_dir, 25 | optimizer=None, 26 | scheduler=None, 27 | step=0) -> str: 28 | offset = step 29 | epoch = 0 30 | model_to_save = get_model_obj(model) 31 | cp = os.path.join(output_dir, 'checkpoint-' + str(offset)) 32 | 33 | meta_params = {} 34 | state = CheckpointState(model_to_save.state_dict(), optimizer.state_dict(), 35 | scheduler.state_dict(), offset, epoch, meta_params) 36 | torch.save(state._asdict(), cp) 37 | logger.info('Saved checkpoint at %s', cp) 38 | return cp 39 | 40 | 41 | def train(args, 42 | train_dataset, 43 | model, 44 | teacher_model, 45 | loss_fn, 46 | logger, 47 | writer: SummaryWriter, 48 | cross_validate_id=-1, 49 | loss_fn_2=None, 50 | tokenizer=None): 51 | args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu) 52 | train_sampler = RandomSampler(train_dataset) 53 | train_dataloader = DataLoader(train_dataset, 54 | sampler=train_sampler, 55 | batch_size=args.train_batch_size, 56 | collate_fn=train_dataset.get_collate_fn( 57 | args, "train")) 58 | 59 | if args.max_steps > 0: 60 | t_total = args.max_steps 61 | args.num_train_epochs = args.max_steps // ( 62 | len(train_dataloader) // args.gradient_accumulation_steps) + 1 63 | else: 64 | t_total = len( 65 | train_dataloader 66 | ) // args.gradient_accumulation_steps * args.num_train_epochs 67 | 68 | # Prepare optimizer and schedule (linear warmup and decay) 69 | optimizer = get_optimizer(args, model, weight_decay=args.weight_decay) 70 | 71 | scheduler = get_linear_schedule_with_warmup( 72 | optimizer, 73 | num_warmup_steps=args.warmup_steps, 74 | num_training_steps=t_total) 75 | 76 | # multi-gpu training (should be after apex fp16 initialization) 77 | if args.n_gpu > 1: 78 | model = torch.nn.DataParallel(model) 79 | 80 | # Train! 81 | logger.info("***** Running training *****") 82 | logger.info(" Num examples = %d", len(train_dataset)) 83 | logger.info(" Num Epochs = %d", args.num_train_epochs) 84 | logger.info(" Instantaneous batch size per GPU = %d", 85 | args.per_gpu_train_batch_size) 86 | logger.info(" Total train batch size (w. 
parallel & accumulation) = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    tr_loss1, tr_loss2 = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)  # Added here for reproducibility
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            concat_ids, concat_id_mask, target_ids, target_id_mask = (ele.to(
                args.device) for ele in [
                    batch["concat_ids"], batch["concat_id_mask"],
                    batch["target_ids"], batch["target_id_mask"]
                ])
            model.train()
            teacher_model.eval()
            # Knowledge distillation: the student encodes the concatenated
            # conversation; the frozen teacher encodes the oracle rewritten query.
            embs = model(concat_ids, concat_id_mask)
            with torch.no_grad():
                teacher_embs = teacher_model(target_ids,
                                             target_id_mask).detach()
            loss1 = None
            if not args.no_mse:
                loss1 = loss_fn(embs, teacher_embs)
            loss = loss1
            loss2 = None
            if args.ranking_task:
                bs = concat_ids.shape[0]  # real batch size
                pos_and_negs = batch["documents"]  # (B, K): positive doc first
                total_token_list = []
                for group in pos_and_negs:
                    sampled = random.sample(group[1:], args.num_negatives)
                    this_pos_neg = [group[0]] + sampled
                    for doc in this_pos_neg:
                        try:
                            title, text = doc.split("[SEP]")
                            doc_ids = tokenizer.encode(title,
                                                       text_pair=text,
                                                       add_special_tokens=True,
                                                       max_length=512)
                        except ValueError:
                            doc_ids = tokenizer.encode(doc,
                                                       add_special_tokens=True,
                                                       max_length=512)
                        doc_ids, doc_mask = pad_input_ids_with_mask(
                            doc_ids, 512)
                        total_token_list.append((doc_ids, doc_mask))
                doc_batch_size = 8
                pos_and_negs_embeddings = []
                for i in range(0, len(total_token_list), doc_batch_size):
                    # batchify
                    batch_ids = []
                    batch_mask = []
                    for j in range(
                            i, min(i + doc_batch_size, len(total_token_list))):
                        batch_ids.append(total_token_list[j][0])
                        batch_mask.append(total_token_list[j][1])
                    batch_ids = torch.tensor(batch_ids,
                                             dtype=torch.long).to(args.device)
                    batch_mask = torch.tensor(batch_mask,
                                              dtype=torch.long).to(args.device)
                    with torch.no_grad():
                        pos_and_negs_embeddings_tmp = teacher_model(
                            batch_ids, batch_mask, is_query=False).detach()
                    pos_and_negs_embeddings.append(
                        pos_and_negs_embeddings_tmp)
                pos_and_negs_embeddings = torch.cat(pos_and_negs_embeddings,
                                                    dim=0)  # (B * K, E)
                pos_and_negs_embeddings = pos_and_negs_embeddings.view(
                    bs, args.num_negatives + 1, -1)  # (B, K, E)
                embs_for_ranking = embs.unsqueeze(-1)  # (B, E, 1)
                embs_for_ranking = embs_for_ranking.expand(
                    bs, 768, args.num_negatives + 1)  # (B, E, K)
                embs_for_ranking = embs_for_ranking.transpose(1, 2)  # (B, K, E)
                logits = embs_for_ranking * pos_and_negs_embeddings
                logits = torch.sum(logits, dim=-1)  # (B, K)
                # The positive document sits at index 0 of every group.
                labels = torch.zeros(bs, dtype=torch.long).to(args.device)
                loss2 = loss_fn_2(logits, labels)
                loss = loss1 + loss2 if loss1 is not None else loss2

            if args.n_gpu > 1:
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
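            # Shape walk-through of the objective above (illustrative, with
            # B = 2 queries and K = 1 + num_negatives = 3 candidate documents):
            #   embs                    -> (2, 768)    student conv-query embeddings
            #   teacher_embs            -> (2, 768)    teacher oracle-query embeddings
            #   pos_and_negs_embeddings -> (2, 3, 768) teacher document embeddings
            #   logits                  -> (2, 3)      per-(query, doc) dot products
            #   loss = MSE(embs, teacher_embs) + CrossEntropy(logits, zeros(2))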
| loss.backward() 179 | tr_loss += loss.item() 180 | if not args.no_mse: 181 | tr_loss1 += loss1.item() 182 | if args.ranking_task: 183 | tr_loss2 += loss2.item() 184 | del loss 185 | torch.cuda.empty_cache() 186 | 187 | if (step + 1) % args.gradient_accumulation_steps == 0: 188 | torch.nn.utils.clip_grad_norm_(model.parameters(), 189 | args.max_grad_norm) 190 | optimizer.step() 191 | scheduler.step() # Update learning rate schedule 192 | model.zero_grad() 193 | global_step += 1 194 | 195 | if global_step % args.log_steps == 0: 196 | writer.add_scalar( 197 | str(cross_validate_id) + "/loss", 198 | tr_loss / args.log_steps, global_step) 199 | if not args.no_mse: 200 | writer.add_scalar( 201 | str(cross_validate_id) + "/mse_loss", 202 | tr_loss1 / args.log_steps, global_step) 203 | if args.ranking_task: 204 | writer.add_scalar( 205 | str(cross_validate_id) + "/ranking_loss", 206 | tr_loss2 / args.log_steps, global_step) 207 | tr_loss = 0.0 208 | tr_loss1 = 0.0 209 | tr_loss2 = 0.0 210 | 211 | if args.save_steps > 0 and global_step % args.save_steps == 0: 212 | checkpoint_prefix = 'checkpoint' 213 | output_dir = args.output_dir + ( 214 | ('-' + str(cross_validate_id)) 215 | if cross_validate_id != -1 else "") 216 | if args.model_type == "rdot_nll": 217 | output_dir = os.path.join( 218 | output_dir, '{}-{}'.format(checkpoint_prefix, 219 | global_step)) 220 | if not os.path.exists(output_dir): 221 | os.makedirs(output_dir) 222 | model_to_save = model.module if hasattr( 223 | model, 'module') else model 224 | model_to_save.save_pretrained(output_dir) 225 | torch.save( 226 | args, os.path.join(output_dir, 227 | 'training_args.bin')) 228 | else: 229 | if not os.path.exists(output_dir): 230 | os.makedirs(output_dir) 231 | _save_checkpoint(args, model, output_dir, optimizer, 232 | scheduler, global_step) 233 | logger.info("Saving model checkpoint to %s", output_dir) 234 | 235 | if args.max_steps > 0 and global_step > args.max_steps: 236 | epoch_iterator.close() 237 | break 238 | 239 | if args.max_steps > 0 and global_step > args.max_steps: 240 | train_iterator.close() 241 | break 242 | 243 | if args.model_type == "dpr": 244 | output_dir = args.output_dir + ( 245 | ('-' + str(cross_validate_id)) if cross_validate_id != -1 else "") 246 | # output_dir = os.path.join(output_dir, '{}-{}'.format(checkpoint_prefix, global_step)) 247 | if not os.path.exists(output_dir): 248 | os.makedirs(output_dir) 249 | _save_checkpoint(args, model, output_dir, optimizer, scheduler, 250 | global_step) 251 | 252 | return global_step, tr_loss / global_step 253 | 254 | 255 | def main(): 256 | parser = argparse.ArgumentParser() 257 | 258 | parser.add_argument( 259 | "--output_dir", 260 | default=None, 261 | type=str, 262 | required=True, 263 | help= 264 | "The output directory where the model predictions and checkpoints will be written." 265 | ) 266 | parser.add_argument( 267 | "--model_name_or_path", 268 | type=str, 269 | help="The model checkpoint for weights initialization.") 270 | parser.add_argument( 271 | "--max_concat_length", 272 | default=256, 273 | type=int, 274 | help="Optional input sequence length after tokenization." 275 | "The training dataset will be truncated in block of this size for training." 276 | "Default to the model max input length for single sentence inputs (take into account special tokens)." 277 | ) 278 | parser.add_argument( 279 | "--max_query_length", 280 | default=64, 281 | type=int, 282 | help="Max input query length after tokenization." 283 | "This option is for single query input." 
    )
    parser.add_argument(
        "--train_file",
        default=None,
        type=str,
        required=True,
        help=
        "Path of the training file. Do not add the fold suffix when cross-validating, i.e. use 'data/eval_topics.jsonl' instead of 'data/eval_topics.jsonl.0'."
    )
    parser.add_argument(
        "--cross_validate",
        action='store_true',
        help="Set when doing cross validation."
    )
    parser.add_argument(
        "--init_from_multiple_models",
        action='store_true',
        help=
        "Set to initialize from different models during cross validation (Model-based+CV)."
    )
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " +
        ", ".join(MSMarcoConfigDict.keys()),
    )

    parser.add_argument(
        "--ranking_task",
        action='store_true',
        help="Whether to use the ranking loss."
    )
    parser.add_argument(
        "--no_mse",
        action="store_true",
        help="Whether to disable the KD (MSE) loss."
    )
    parser.add_argument(
        "--num_negatives",
        type=int,
        default=9,
        help="Number of negative documents per query."
    )

    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training."
    )
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of update steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--learning_rate",
        default=1e-5,
        type=float,
        help="The initial learning rate for Adam."
    )
    parser.add_argument(
        "--weight_decay",
        default=0.0,
        type=float,
        help="Weight decay if we apply some."
    )
    parser.add_argument(
        "--adam_epsilon",
        default=1e-8,
        type=float,
        help="Epsilon for the Adam optimizer."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--num_train_epochs",
        default=1.0,
        type=float,
        help="Total number of training epochs to perform."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set the total number of training steps to perform. Overrides num_train_epochs."
    )
    parser.add_argument(
        "--warmup_steps",
        default=0,
        type=int,
        help="Linear warmup over warmup_steps."
    )
    parser.add_argument(
        '--save_steps',
        type=int,
        default=-1,
        help="Save a checkpoint every X update steps."
    )
    parser.add_argument(
        "--no_cuda",
        action='store_true',
        help="Avoid using CUDA when available."
    )
    parser.add_argument(
        '--overwrite_output_dir',
        action='store_true',
        help="Overwrite the content of the output directory."
    )
    parser.add_argument(
        '--seed',
        type=int,
        default=42,
        help="Random seed for initialization."
    )
    parser.add_argument(
        "--log_dir",
        type=str,
        help="Directory for tensorboard logging."
    )
    parser.add_argument(
        "--log_steps",
        type=int,
        default=1,
        help="Log the loss every X steps."
    )
    parser.add_argument(
        "--cache_dir",
        type=str
    )
    parser.add_argument(
        "--teacher_model",
        type=str,
        help="The teacher model. If None, use `model_name_or_path` as the teacher."
    )
    parser.add_argument(
        "--query",
        type=str,
        default="no_res",
        choices=["no_res", "man_can", "auto_can", "target", "output", "raw"],
        help="Input query format."
    )
    args = parser.parse_args()

    tb_writer = SummaryWriter(log_dir=args.log_dir)

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and not args.overwrite_output_dir:
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overwrite it."
            .format(args.output_dir))

    device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = torch.cuda.device_count()
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO)
    logger.warning("device: %s, n_gpu: %s", device, args.n_gpu)

    # Set seed
    set_seed(args)

    # KD loss (MSE between student and teacher embeddings) and optional
    # ranking loss (cross-entropy over one positive plus sampled negatives).
    loss_fn = nn.MSELoss()
    loss_fn.to(args.device)
    loss_fn_2 = nn.CrossEntropyLoss()
    loss_fn_2.to(args.device)

    if args.teacher_model is None:
        args.teacher_model = args.model_name_or_path
    _, __, teacher_model = load_model(args, args.teacher_model)

    if not args.cross_validate:

        config, tokenizer, model = load_model(args, args.model_name_or_path)
        if args.query in ["man_can", "auto_can"]:
            # Special marker token prepended to system responses.
            tokenizer.add_tokens(["<response>"])
            model.resize_token_embeddings(len(tokenizer))
        if args.max_concat_length <= 0:
            args.max_concat_length = tokenizer.max_len_single_sentence
        args.max_concat_length = min(args.max_concat_length,
                                     tokenizer.max_len_single_sentence)

        # Training
        logger.info("Training/evaluation parameters %s", args)
        train_dataset = ConvSearchDataset([args.train_file],
                                          tokenizer,
                                          args,
                                          mode="train")
        global_step, tr_loss = train(args,
                                     train_dataset,
                                     model,
                                     teacher_model,
                                     loss_fn,
                                     logger,
                                     tb_writer,
                                     cross_validate_id=0,
                                     loss_fn_2=loss_fn_2,
                                     tokenizer=tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step,
                    tr_loss)

        # Saving
        # Create the output directory if needed
        if args.model_type == "rdot_nll":
            if not os.path.exists(args.output_dir):
                os.makedirs(args.output_dir)
            logger.info("Saving model checkpoint to %s", args.output_dir)
            model_to_save = model.module if hasattr(model, 'module') else model
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)
            torch.save(args, os.path.join(args.output_dir,
                                          'training_args.bin'))

    else:
        # K-Fold Cross Validation
        for i in range(NUM_FOLD):
            logger.info("Training Fold #{}".format(i))
            suffix = ('-' + str(i)) if args.init_from_multiple_models else ''
            config, tokenizer, model = load_model(
                args, args.model_name_or_path + suffix)
            if args.query in ["man_can", "auto_can"]:
                tokenizer.add_tokens(["<response>"])
                model.resize_token_embeddings(len(tokenizer))

            if args.max_concat_length <= 0:
                args.max_concat_length
= tokenizer.max_len_single_sentence 525 | args.max_concat_length = min(args.max_concat_length, 526 | tokenizer.max_len_single_sentence) 527 | 528 | logger.info("Training/evaluation parameters %s", args) 529 | train_files = [ 530 | "%s.%d" % (args.train_file, j) for j in range(NUM_FOLD) 531 | if j != i 532 | ] 533 | logger.info("train_files: {}".format(train_files)) 534 | train_dataset = ConvSearchDataset(train_files, 535 | tokenizer, 536 | args, 537 | mode="train") 538 | global_step, tr_loss = train(args, 539 | train_dataset, 540 | model, 541 | teacher_model, 542 | loss_fn, 543 | logger, 544 | tb_writer, 545 | cross_validate_id=i, 546 | loss_fn_2=loss_fn_2, 547 | tokenizer=tokenizer) 548 | logger.info(" global_step = %s, average loss = %s", global_step, 549 | tr_loss) 550 | 551 | if args.model_type == "rdot_nll": 552 | output_dir = args.output_dir + '-' + str(i) 553 | if not os.path.exists(output_dir): 554 | os.makedirs(output_dir) 555 | 556 | logger.info("Saving model checkpoint to %s", output_dir) 557 | model_to_save = model.module if hasattr(model, 558 | 'module') else model 559 | model_to_save.save_pretrained(output_dir) 560 | tokenizer.save_pretrained(output_dir) 561 | torch.save(args, os.path.join(output_dir, 'training_args.bin')) 562 | 563 | del model 564 | torch.cuda.empty_cache() 565 | 566 | tb_writer.close() 567 | 568 | 569 | if __name__ == "__main__": 570 | main() 571 | -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path += ['../'] 4 | # import pandas as pd 5 | # from sklearn.metrics import roc_curve, auc 6 | import gzip 7 | import copy 8 | import torch 9 | from torch import nn 10 | import torch.distributed as dist 11 | from tqdm import tqdm, trange 12 | import os 13 | from os import listdir 14 | from os.path import isfile, join 15 | import json 16 | import logging 17 | import random 18 | import pytrec_eval 19 | import pickle 20 | import numpy as np 21 | import torch 22 | 23 | torch.multiprocessing.set_sharing_strategy('file_system') 24 | from multiprocessing import Process 25 | from torch.utils.data import DataLoader, Dataset, TensorDataset, IterableDataset 26 | from utils.dpr_utils import get_model_obj, load_states_from_checkpoint 27 | import re 28 | from model.models import MSMarcoConfigDict, ALL_MODELS 29 | from typing import List, Set, Dict, Tuple, Callable, Iterable, Any 30 | 31 | logger = logging.getLogger(__name__) 32 | NUM_FOLD = 5 33 | 34 | 35 | class InputFeaturesPair(object): 36 | """ 37 | A single set of features of data. 38 | 39 | Args: 40 | input_ids: Indices of input sequence tokens in the vocabulary. 41 | attention_mask: Mask to avoid performing attention on padding token indices. 42 | Mask values selected in ``[0, 1]``: 43 | Usually ``1`` for tokens that are NOT MASKED, ``0`` for MASKED (padded) tokens. 44 | token_type_ids: Segment token indices to indicate first and second portions of the inputs. 
45 | label: Label corresponding to the input 46 | """ 47 | def __init__(self, 48 | input_ids_a, 49 | attention_mask_a=None, 50 | token_type_ids_a=None, 51 | input_ids_b=None, 52 | attention_mask_b=None, 53 | token_type_ids_b=None, 54 | label=None): 55 | 56 | self.input_ids_a = input_ids_a 57 | self.attention_mask_a = attention_mask_a 58 | self.token_type_ids_a = token_type_ids_a 59 | 60 | self.input_ids_b = input_ids_b 61 | self.attention_mask_b = attention_mask_b 62 | self.token_type_ids_b = token_type_ids_b 63 | 64 | self.label = label 65 | 66 | def __repr__(self): 67 | return str(self.to_json_string()) 68 | 69 | def to_dict(self): 70 | """Serializes this instance to a Python dictionary.""" 71 | output = copy.deepcopy(self.__dict__) 72 | return output 73 | 74 | def to_json_string(self): 75 | """Serializes this instance to a JSON string.""" 76 | return json.dumps(self.to_dict(), indent=2, sort_keys=True) + "\n" 77 | 78 | 79 | def getattr_recursive(obj, name): 80 | for layer in name.split("."): 81 | if hasattr(obj, layer): 82 | obj = getattr(obj, layer) 83 | else: 84 | return None 85 | return obj 86 | 87 | 88 | def barrier_array_merge(args, 89 | data_array, 90 | merge_axis=0, 91 | prefix="", 92 | load_cache=False, 93 | only_load_in_master=False, 94 | merge=True): 95 | # data array: [B, any dimension] 96 | # merge alone one axis 97 | 98 | if args.local_rank == -1: 99 | return data_array 100 | 101 | if not load_cache: 102 | rank = args.rank 103 | if is_first_worker(): 104 | if not os.path.exists(args.output_dir): 105 | os.makedirs(args.output_dir) 106 | 107 | dist.barrier() # directory created 108 | pickle_path = os.path.join( 109 | args.output_dir, "{1}_data_obj_{0}.pb".format(str(rank), prefix)) 110 | with open(pickle_path, 'wb') as handle: 111 | pickle.dump(data_array, handle, protocol=4) 112 | 113 | # make sure all processes wrote their data before first process 114 | # collects it 115 | dist.barrier() 116 | 117 | data_array = None 118 | 119 | data_list = [] 120 | 121 | if not merge: 122 | return None 123 | 124 | # return empty data 125 | if only_load_in_master: 126 | if not is_first_worker(): 127 | dist.barrier() 128 | return None 129 | 130 | for i in range(args.world_size 131 | ): # TODO: dynamically find the max instead of HardCode 132 | pickle_path = os.path.join( 133 | args.output_dir, "{1}_data_obj_{0}.pb".format(str(i), prefix)) 134 | try: 135 | with open(pickle_path, 'rb') as handle: 136 | b = pickle.load(handle) 137 | data_list.append(b) 138 | except BaseException: 139 | continue 140 | 141 | data_array_agg = np.concatenate(data_list, axis=merge_axis) 142 | dist.barrier() 143 | return data_array_agg 144 | 145 | 146 | def pad_input_ids(input_ids, max_length, pad_on_left=False, pad_token=0): 147 | padding_length = max_length - len(input_ids) 148 | padding_id = [pad_token] * padding_length 149 | 150 | # attention_mask = [1] * len(input_ids) + [0] * padding_length 151 | 152 | if padding_length <= 0: 153 | input_ids = input_ids[:max_length] 154 | else: 155 | if pad_on_left: 156 | input_ids = padding_id + input_ids 157 | else: 158 | input_ids = input_ids + padding_id 159 | 160 | return input_ids 161 | 162 | 163 | def pad_input_ids_with_mask(input_ids, 164 | max_length, 165 | pad_on_left=False, 166 | pad_token=0): 167 | padding_length = max_length - len(input_ids) 168 | padding_id = [pad_token] * padding_length 169 | 170 | attention_mask = [] 171 | 172 | if padding_length <= 0: 173 | input_ids = input_ids[:max_length] 174 | attention_mask = [1] * max_length 175 | else: 176 | if 
pad_on_left:
            # Build the attention mask before prepending the pad ids so the
            # length assertions below also hold for left-side padding.
            attention_mask = [0] * padding_length + [1] * len(input_ids)
            input_ids = padding_id + input_ids
        else:
            attention_mask = [1] * len(input_ids) + [0] * padding_length
            input_ids = input_ids + padding_id

    assert len(input_ids) == max_length
    assert len(attention_mask) == max_length

    return input_ids, attention_mask


def pad_ids(input_ids,
            attention_mask,
            token_type_ids,
            max_length,
            pad_on_left=False,
            pad_token=0,
            pad_token_segment_id=0,
            mask_padding_with_zero=True):
    padding_length = max_length - len(input_ids)
    padding_id = [pad_token] * padding_length
    padding_type = [pad_token_segment_id] * padding_length
    padding_attention = [0 if mask_padding_with_zero else 1] * padding_length

    if padding_length <= 0:
        input_ids = input_ids[:max_length]
        attention_mask = attention_mask[:max_length]
        token_type_ids = token_type_ids[:max_length]
    else:
        if pad_on_left:
            input_ids = padding_id + input_ids
            attention_mask = padding_attention + attention_mask
            token_type_ids = padding_type + token_type_ids
        else:
            input_ids = input_ids + padding_id
            attention_mask = attention_mask + padding_attention
            token_type_ids = token_type_ids + padding_type

    return input_ids, attention_mask, token_type_ids


# To reuse pytrec_eval, ids must be strings.
def convert_to_string_id(result_dict):
    string_id_dict = {}

    # format [string, dict[string, val]]
    for k, v in result_dict.items():
        _temp_v = {}
        for inner_k, inner_v in v.items():
            _temp_v[str(inner_k)] = inner_v

        string_id_dict[str(k)] = _temp_v

    return string_id_dict


def set_seed(args):
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def load_model(args, checkpoint_path):
    label_list = ["0", "1"]
    num_labels = len(label_list)
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    args.model_path = checkpoint_path

    config, tokenizer, model = None, None, None
    if args.model_type != "dpr":
        config = configObj.config_class.from_pretrained(
            args.model_path,
            num_labels=num_labels,
            finetuning_task="MSMarco",
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        tokenizer = configObj.tokenizer_class.from_pretrained(
            args.model_path,
            do_lower_case=True,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
        model = configObj.model_class.from_pretrained(
            args.model_path,
            from_tf=bool(".ckpt" in args.model_path),
            config=config,
            cache_dir=args.cache_dir if args.cache_dir else None,
        )
    else:  # dpr
        model = configObj.model_class(args)
        saved_state = load_states_from_checkpoint(checkpoint_path)
        model_to_load = get_model_obj(model)
        logger.info('Loading saved model state ...')
        model_to_load.load_state_dict(saved_state.model_dict)
        tokenizer = configObj.tokenizer_class.from_pretrained(
            "bert-base-uncased",
            do_lower_case=True,
            cache_dir=None,
        )

    model.to(args.device)
    return config, tokenizer, model


def is_first_worker():
    return (not dist.is_available() or not dist.is_initialized()
            or dist.get_rank() == 0)
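# Quick sanity example for the padding helpers above (illustrative values, not
# part of the pipeline):
#
#     >>> pad_input_ids_with_mask([101, 2023, 102], 8)
#     ([101, 2023, 102, 0, 0, 0, 0, 0], [1, 1, 1, 0, 0, 0, 0, 0])
#     >>> pad_input_ids_with_mask(list(range(10)), 8)[0]
#     [0, 1, 2, 3, 4, 5, 6, 7]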
def concat_key(all_list, key, axis=0):
    return np.concatenate([ele[key] for ele in all_list], axis=axis)


def get_checkpoint_no(checkpoint_path):
    return int(re.findall(r'\d+', checkpoint_path)[-1])


def get_latest_ann_data(ann_data_path):
    ANN_PREFIX = "ann_ndcg_"
    if not os.path.exists(ann_data_path):
        return -1, None, None
    files = list(next(os.walk(ann_data_path))[2])
    num_start_pos = len(ANN_PREFIX)
    data_no_list = [
        int(s[num_start_pos:]) for s in files
        if s[:num_start_pos] == ANN_PREFIX
    ]
    if len(data_no_list) > 0:
        data_no = max(data_no_list)
        with open(os.path.join(ann_data_path, ANN_PREFIX + str(data_no)),
                  'r') as f:
            ndcg_json = json.load(f)
        return data_no, os.path.join(ann_data_path, "ann_training_data_" +
                                     str(data_no)), ndcg_json
    return -1, None, None


def numbered_byte_file_generator(base_path, file_no, record_size):
    for i in range(file_no):
        with open('{}_split{}'.format(base_path, i), 'rb') as f:
            while True:
                b = f.read(record_size)
                if not b:
                    # eof
                    break
                yield b


def load_collection(collection_file):
    # Pre-allocate 50M slots; valid pids overwrite the placeholder.
    all_passages = ["[INVALID DOC ID]"] * 50_000_000
    ext = collection_file[collection_file.rfind(".") + 1:]
    if ext not in ["jsonl", "tsv"]:
        raise TypeError("Unrecognized file type: expected .jsonl or .tsv")
    with open(collection_file, "r") as f:
        if ext == "jsonl":
            for line in f:
                line = line.strip()
                obj = json.loads(line)
                pid = int(obj["id"])
                passage = obj["title"] + "[SEP]" + obj["text"]
                all_passages[pid] = passage
        else:
            for line in f:
                line = line.strip()
                try:
                    line_arr = line.split("\t")
                    pid = int(line_arr[0])
                    passage = line_arr[1].rstrip()
                    all_passages[pid] = passage
                except IndexError:
                    print("bad passage")
                except ValueError:
                    print("bad pid")
    return all_passages


class EmbeddingCache:
    def __init__(self, base_path, seed=-1):
        self.base_path = base_path
        with open(base_path + '_meta', 'r') as f:
            meta = json.load(f)
            self.dtype = np.dtype(meta['type'])
            self.total_number = meta['total_number']
            # Each record: 4-byte big-endian length prefix + fixed-size embedding.
            self.record_size = int(
                meta['embedding_size']) * self.dtype.itemsize + 4
        if seed >= 0:
            self.ix_array = np.random.RandomState(seed).permutation(
                self.total_number)
        else:
            self.ix_array = np.arange(self.total_number)
        self.f = None

    def open(self):
        self.f = open(self.base_path, 'rb')

    def close(self):
        self.f.close()

    def read_single_record(self):
        record_bytes = self.f.read(self.record_size)
        passage_len = int.from_bytes(record_bytes[:4], 'big')
        passage = np.frombuffer(record_bytes[4:], dtype=self.dtype)
        return passage_len, passage

    def __enter__(self):
        self.open()
        return self

    def __exit__(self, type, value, traceback):
        self.close()

    def __getitem__(self, key):
        if key < 0 or key >= self.total_number:
            raise IndexError(
                "Index {} is out of bounds for cached embeddings of size {}".
                format(key, self.total_number))
        self.f.seek(key * self.record_size)
        return self.read_single_record()

    def __iter__(self):
        self.f.seek(0)
        for i in range(self.total_number):
            new_ix = self.ix_array[i]
            yield self.__getitem__(new_ix)

    def __len__(self):
        return self.total_number


class StreamingDataset(IterableDataset):
    def __init__(self, elements, fn):
        super().__init__()
        self.elements = elements
        self.fn = fn
        self.num_replicas = -1

    def __iter__(self):
        if dist.is_initialized():
            self.num_replicas = dist.get_world_size()
            self.rank = dist.get_rank()
            print("Rank:", self.rank, "world:", self.num_replicas)
        else:
            print("Not running in distributed mode")
        for i, element in enumerate(self.elements):
            # Shard the stream: each rank keeps every num_replicas-th element.
            if self.num_replicas != -1 and i % self.num_replicas != self.rank:
                continue
            records = self.fn(element, i)
            for rec in records:
                yield rec


class ConvSearchExample:
    def __init__(self,
                 qid,
                 concat_ids,
                 concat_id_mask,
                 target_ids,
                 target_id_mask,
                 doc_pos=None,
                 doc_negs=None,
                 raw_sequences=None):
        self.qid = qid
        self.concat_ids = concat_ids
        self.concat_id_mask = concat_id_mask
        self.target_ids = target_ids
        self.target_id_mask = target_id_mask
        self.doc_pos = doc_pos
        self.doc_negs = doc_negs
        self.raw_sequences = raw_sequences


class ConvSearchDataset(Dataset):
    def __init__(self, filenames, tokenizer, args, mode="train"):
        self.examples = []
        for filename in filenames:
            with open(filename, encoding="utf-8") as f:
                for line in f:
                    record = json.loads(line)
                    input_sents = record['input']
                    target_sent = record['target']
                    auto_sent = record.get('output', "no")
                    raw_sent = record["input"][-1]
                    responses = record[
                        "manual_response"] if args.query == "man_can" else (
                            record["automatic_response"]
                            if args.query == "auto_can" else [])
                    topic_number = record.get('topic_number', None)
                    query_number = record.get('query_number', None)
                    qid = str(topic_number) + "_" + str(
                        query_number) if topic_number is not None else str(
                            record["qid"])
                    sequences = record['input']
                    concat_ids = []
                    concat_id_mask = []
                    target_ids = None
                    target_id_mask = None
                    doc_pos = None
                    doc_negs = None
                    if mode == "train" and args.ranking_task:
                        doc_pos = record["doc_pos"]
                        doc_negs = record["doc_negs"]

                    if mode == "train" or args.query in [
                            "no_res", "man_can", "auto_can"
                    ]:
                        if args.model_type == "dpr":
                            # dpr (on OR-QuAC) uses a BERT-style sequence:
                            # [CLS] q1 [SEP] q2 [SEP] ...
                            concat_ids.append(tokenizer.cls_token_id)
                        for sent in input_sents[:-1]:  # exclude the last turn
                            if args.model_type != "dpr":
                                # RoBERTa-style sequence: <s> q1 </s> <s> q2 </s> ...
                                concat_ids.append(tokenizer.cls_token_id)
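                            # With three turns ["q1", "q2", "q3"] and a
                            # RoBERTa-style model, this loop plus the
                            # current-turn block after it yield (illustrative):
                            #   <s> q1 </s> <s> q2 </s> <s> q3 </s>
                            # padded/truncated to max_concat_length.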
                            concat_ids.extend(
                                tokenizer.convert_tokens_to_ids(
                                    tokenizer.tokenize(sent)))
                            concat_ids.append(tokenizer.sep_token_id)

                        if args.query in [
                                "man_can", "auto_can"
                        ] and len(responses) >= 2:  # add the last system response
                            if args.model_type != "dpr":
                                concat_ids.append(tokenizer.cls_token_id)
                            concat_ids.extend(
                                tokenizer.convert_tokens_to_ids(["<response>"]))
                            concat_ids.extend(
                                tokenizer.convert_tokens_to_ids(
                                    tokenizer.tokenize(responses[-2])))
                            concat_ids.append(tokenizer.sep_token_id)
                            sequences.insert(-1, responses[-2])

                        if args.model_type != "dpr":
                            concat_ids.append(tokenizer.cls_token_id)
                        concat_ids.extend(
                            tokenizer.convert_tokens_to_ids(
                                tokenizer.tokenize(input_sents[-1])))
                        concat_ids.append(tokenizer.sep_token_id)

                        # We do not use token type ids for BERT (RoBERTa has none anyway).
                        concat_ids, concat_id_mask = pad_input_ids_with_mask(
                            concat_ids, args.max_concat_length)
                        assert len(concat_ids) == args.max_concat_length

                    elif args.query == "target":  # the manual (oracle) rewrite

                        concat_ids = tokenizer.encode(
                            target_sent,
                            add_special_tokens=True,
                            max_length=args.max_query_length)
                        concat_ids, concat_id_mask = pad_input_ids_with_mask(
                            concat_ids, args.max_query_length)
                        assert len(concat_ids) == args.max_query_length

                    elif args.query == "output":  # reserved for query rewriter output

                        concat_ids = tokenizer.encode(
                            auto_sent,
                            add_special_tokens=True,
                            max_length=args.max_query_length)
                        concat_ids, concat_id_mask = pad_input_ids_with_mask(
                            concat_ids, args.max_query_length)
                        assert len(concat_ids) == args.max_query_length

                    elif args.query == "raw":

                        concat_ids = tokenizer.encode(
                            raw_sent,
                            add_special_tokens=True,
                            max_length=args.max_query_length)
                        concat_ids, concat_id_mask = pad_input_ids_with_mask(
                            concat_ids, args.max_query_length)
                        assert len(concat_ids) == args.max_query_length

                    else:
                        raise KeyError("Unsupported query type")

                    if mode == "train":
                        target_ids = tokenizer.encode(
                            target_sent,
                            add_special_tokens=True,
                            max_length=args.max_query_length)
                        target_ids, target_id_mask = pad_input_ids_with_mask(
                            target_ids, args.max_query_length)
                        assert len(target_ids) == args.max_query_length

                    self.examples.append(
                        ConvSearchExample(qid, concat_ids, concat_id_mask,
                                          target_ids, target_id_mask, doc_pos,
                                          doc_negs, sequences))

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, item):
        return self.examples[item]

    @staticmethod
    def get_collate_fn(args, mode):
        def collate_fn(batch_dataset: list):
            collated_dict = {
                "qid": [],
                "concat_ids": [],
                "concat_id_mask": [],
            }
            if mode == "train":
                collated_dict.update({"target_ids": [], "target_id_mask": []})
                if args.ranking_task:
                    collated_dict.update({"documents": []})
            else:
                collated_dict.update({"history_utterances": []})
            for example in batch_dataset:
                collated_dict["qid"].append(example.qid)
                collated_dict["concat_ids"].append(example.concat_ids)
                collated_dict["concat_id_mask"].append(example.concat_id_mask)
                if mode == "train":
                    collated_dict["target_ids"].append(example.target_ids)
collated_dict["target_id_mask"].append( 600 | example.target_id_mask) 601 | if args.ranking_task: 602 | collated_dict["documents"].append([example.doc_pos] + 603 | example.doc_negs) 604 | else: 605 | collated_dict["history_utterances"].append( 606 | example.raw_sequences) 607 | should_be_tensor = [ 608 | "concat_ids", "concat_id_mask", "target_ids", "target_id_mask" 609 | ] 610 | for key in should_be_tensor: 611 | if key in collated_dict: 612 | collated_dict[key] = torch.tensor(collated_dict[key], 613 | dtype=torch.long) 614 | 615 | return collated_dict 616 | 617 | return collate_fn 618 | 619 | 620 | def tokenize_to_file(args, i, num_process, in_path, out_path, line_fn): 621 | 622 | configObj = MSMarcoConfigDict[args.model_type] 623 | tokenizer = configObj.tokenizer_class.from_pretrained( 624 | args.model_name_or_path, 625 | do_lower_case=True, 626 | cache_dir=None, 627 | ) 628 | 629 | with open(in_path, 'r', encoding='utf-8') if in_path[-2:] != "gz" else gzip.open(in_path, 'rt', encoding='utf8') as in_f,\ 630 | open('{}_split{}'.format(out_path, i), 'wb') as out_f: 631 | for idx, line in enumerate(in_f): 632 | if idx % num_process != i: 633 | continue 634 | try: 635 | res = line_fn(args, line, tokenizer) 636 | except ValueError: 637 | print("Bad passage.") 638 | else: 639 | out_f.write(res) 640 | 641 | 642 | # args, 32, , collection.tsv, passages, 643 | def multi_file_process(args, num_process, in_path, out_path, line_fn): 644 | processes = [] 645 | for i in range(num_process): 646 | p = Process(target=tokenize_to_file, 647 | args=( 648 | args, 649 | i, 650 | num_process, 651 | in_path, 652 | out_path, 653 | line_fn, 654 | )) 655 | processes.append(p) 656 | p.start() 657 | for p in processes: 658 | p.join() 659 | 660 | 661 | def all_gather(data): 662 | """ 663 | Run all_gather on arbitrary picklable data (not necessarily tensors) 664 | Args: 665 | data: any picklable object 666 | Returns: 667 | list[data]: list of data gathered from each rank 668 | """ 669 | if not dist.is_initialized() or dist.get_world_size() == 1: 670 | return [data] 671 | 672 | world_size = dist.get_world_size() 673 | # serialized to a Tensor 674 | buffer = pickle.dumps(data) 675 | storage = torch.ByteStorage.from_buffer(buffer) 676 | tensor = torch.ByteTensor(storage).to("cuda") 677 | 678 | # obtain Tensor size of each rank 679 | local_size = torch.LongTensor([tensor.numel()]).to("cuda") 680 | size_list = [torch.LongTensor([0]).to("cuda") for _ in range(world_size)] 681 | dist.all_gather(size_list, local_size) 682 | size_list = [int(size.item()) for size in size_list] 683 | max_size = max(size_list) 684 | 685 | # receiving Tensor from all ranks 686 | # we pad the tensor because torch all_gather does not support 687 | # gathering tensors of different shapes 688 | tensor_list = [] 689 | for _ in size_list: 690 | tensor_list.append(torch.ByteTensor(size=(max_size, )).to("cuda")) 691 | if local_size != max_size: 692 | padding = torch.ByteTensor(size=(max_size - local_size, )).to("cuda") 693 | tensor = torch.cat((tensor, padding), dim=0) 694 | dist.all_gather(tensor_list, tensor) 695 | 696 | data_list = [] 697 | for size, tensor in zip(size_list, tensor_list): 698 | buffer = tensor.cpu().numpy().tobytes()[:size] 699 | data_list.append(pickle.loads(buffer)) 700 | 701 | return data_list 702 | --------------------------------------------------------------------------------