├── LICENSE ├── README.md ├── conda_environment.txt ├── scripts ├── convert_to_anserini.py ├── expand_queries.py └── quantize.py └── src ├── __init__.py ├── evaluation ├── __init__.py ├── loaders.py ├── metrics.py └── ranking.py ├── index.py ├── model.py ├── parameters.py ├── rerank.py ├── retrieve.py ├── test.py ├── train.py ├── training ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── data_reader.cpython-37.pyc │ └── data_reader.cpython-38.pyc └── data_reader.py ├── utils.py └── utils2.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SIGIR2021 2 | 3 | 4 | ## Train 5 | 6 | ``` 7 | python -m src.train --triples ../../triples.train.small.expanded.tsv 8 | ``` 9 | 10 | ## Index 11 | ``` 12 | python -m src.index 13 | ``` 14 | 15 | Training triples: 16 | https://drive.google.com/file/d/1SlVfdqdtAjbf7T0tnaZ7As3esLf9s1TZ/view?usp=sharing 17 | 18 | Expanded collection: 19 | https://drive.google.com/file/d/10PKQeTsfxQclVlQs6dYDPyg6vaMR4cUX/view?usp=sharing 20 | 21 | Model checkpoint: 22 | https://drive.google.com/file/d/1WQJcgWI5NRNQz8aNrWFx72SqSaM2XskH/view?usp=sharing 23 | 24 | -------------------------------------------------------------------------------- /conda_environment.txt: -------------------------------------------------------------------------------- 1 | # This file may be used to create an environment using: 2 | # $ conda create --name --file 3 | # platform: linux-64 4 | _libgcc_mutex=0.1=main 5 | boto3=1.13.21=pypi_0 6 | botocore=1.16.21=pypi_0 7 | ca-certificates=2020.1.1=0 8 | certifi=2020.4.5.1=py36_0 9 | chardet=3.0.4=pypi_0 10 | click=7.1.2=pypi_0 11 | docutils=0.15.2=pypi_0 12 | faiss-cpu=1.6.3=pypi_0 13 | idna=2.9=pypi_0 14 | jmespath=0.10.0=pypi_0 15 | joblib=0.15.1=pypi_0 16 | ld_impl_linux-64=2.33.1=h53a641e_7 17 | libedit=3.1.20181209=hc058e9b_0 18 | libffi=3.3=he6710b0_1 19 | libgcc-ng=9.1.0=hdf63c60_0 20 | libstdcxx-ng=9.1.0=hdf63c60_0 21 | ncurses=6.2=he6710b0_1 22 | numpy=1.18.4=pypi_0 23 | openssl=1.1.1g=h7b6447c_0 24 | pip=20.0.2=py36_3 25 | python=3.6.10=h7579374_2 26 | python-dateutil=2.8.1=pypi_0 27 | readline=8.0=h7b6447c_0 28 | regex=2020.5.14=pypi_0 29 | requests=2.23.0=pypi_0 30 | s3transfer=0.3.3=pypi_0 31 | sacremoses=0.0.43=pypi_0 32 | sentencepiece=0.1.91=pypi_0 33 | setuptools=46.4.0=py36_0 34 | six=1.15.0=pypi_0 35 | sqlite=3.31.1=h62c20be_1 36 | tk=8.6.8=hbc83047_0 37 | torch=1.4.0=pypi_0 38 | tqdm=4.46.0=pypi_0 39 | transformers=2.1.1=pypi_0 40 | ujson=1.35=pypi_0 41 | urllib3=1.25.9=pypi_0 42 | wheel=0.34.2=py36_0 43 | xz=5.2.5=h7b6447c_0 44 | zlib=1.2.11=h7b6447c_3 45 | -------------------------------------------------------------------------------- /scripts/convert_to_anserini.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from argparse import ArgumentParser 4 | 5 | 6 | def process(input_filename, output_filename): 7 | with open(input_filename) as input_file, open(output_filename, "w+") as output_file: 8 | for docid, line in tqdm(enumerate(input_file)): 9 | data = {} 10 | data["id"] = docid 11 | data["contents"] = "" 12 | data["vector"] = {} 13 | for t in line.strip().split(","): 14 | split_list = t.strip().split(":") 15 | if len(split_list) == 2: 16 | term, score = split_list 17 | data["vector"][term] = float(score) 18 | json.dump(data, output_file) 19 | output_file.write('\n') 20 | 21 | 22 | def main(): 23 | parser = ArgumentParser(description='Convert a DeepImpact collection into an Anserini JsonVectorCollection.') 24 | parser.add_argument('--input', dest='input', required=True) 25 | parser.add_argument('--output', dest='output', required=True) 26 | args = parser.parse_args() 27 | process(args.input, args.output) 28 | 29 | if __name__ == "__main__": 30 | main() 31 | 32 | -------------------------------------------------------------------------------- /scripts/expand_queries.py: 
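The script in this entry converts a weighted query file (one `query_id<TAB>term: weight, term: weight, …` line per query) into a plain bag-of-words query by repeating each term `weight` times. A minimal illustration of that transformation, with made-up terms and weights (the line format is inferred from the parsing code below):

```python
# Hypothetical input line: query id, then comma-separated "term: integer-weight" pairs.
line = "1048585\tparis: 3, hotel: 2, cheap: 1"

query_id, rest = line.split("\t")
terms = ""
for t in rest.strip().split(","):
    term, score = [x.strip() for x in t.strip().split(":")]
    terms += " ".join([term] * int(score)) + " "

# Prints: 1048585    paris paris paris hotel hotel cheap
print("{}\t{}".format(query_id, terms.strip()))
```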
-------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | from argparse import ArgumentParser 4 | 5 | 6 | def process(input_filename, output_filename): 7 | with open(input_filename) as input_file, open(output_filename, "w+") as output_file: 8 | for line in tqdm(input_file): 9 | query_id, line = line.split("\t") 10 | terms = "" 11 | for t in line.strip().split(","): 12 | split_list = t.strip().split(":") 13 | if len(split_list) == 2: 14 | term, score = split_list 15 | terms += " ".join([term] * int(score)) 16 | terms += " " 17 | output_file.write("{}\t{}\n".format(query_id, terms)) 18 | 19 | 20 | def main(): 21 | parser = ArgumentParser(description='Expand queries according to the query weights.') 22 | parser.add_argument('--input', dest='input', required=True) 23 | parser.add_argument('--output', dest='output', required=True) 24 | args = parser.parse_args() 25 | process(args.input, args.output) 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /scripts/quantize.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | from argparse import ArgumentParser 3 | 4 | import math 5 | 6 | QUANTIZATION_BITS = 8 7 | 8 | def quantize(value, scale): 9 | return int(math.ceil(value * scale)) 10 | 11 | def find_max_value(input_filename): 12 | max_val = 0 13 | with open(input_filename) as input_file: 14 | for docid, line in tqdm(enumerate(input_file)): 15 | for t in line.strip().split(","): 16 | split_list = t.strip().split(": ") 17 | if len(split_list) == 2: 18 | term, score = split_list 19 | max_val = max(max_val, float(score)) 20 | return max_val 21 | 22 | def process(input_filename, output_filename, max_val_provided = None): 23 | if max_val_provided: 24 | max_val= float(max_val_provided) 25 | print("We will use {} as max val".format(max_val_provided)) 26 | else: 27 | max_val = find_max_value(input_filename) 28 | print("Max Val is: {}".format(max_val)) 29 | scale = (1< Loading qrels from", qrels_path, "...") 11 | 12 | qrels = {} 13 | with open(qrels_path, mode='r', encoding="utf-8") as f: 14 | for line in f: 15 | qid, x, pid, y = map(int, line.strip().split('\t')) 16 | assert x == 0 and y == 1 17 | qrels[qid] = qrels.get(qid, []) 18 | qrels[qid].append(pid) 19 | 20 | assert all(len(qrels[qid]) == len(set(qrels[qid])) for qid in qrels) 21 | 22 | avg_positive = round(sum(len(qrels[qid]) for qid in qrels) / len(qrels), 2) 23 | 24 | print_message("#> Loaded qrels for", len(qrels), "unique queries with", 25 | avg_positive, "positives per query on average.\n") 26 | 27 | return qrels 28 | 29 | 30 | def load_topK(topK_path): 31 | queries = {} 32 | topK_docs = {} 33 | topK_pids = {} 34 | 35 | print_message("#> Loading the top-k per query from", topK_path, "...") 36 | 37 | with open(topK_path) as f: 38 | for line in f: 39 | qid, pid, query, passage = line.split('\t') 40 | qid, pid = int(qid), int(pid) 41 | 42 | assert (qid not in queries) or (queries[qid] == query) 43 | queries[qid] = query 44 | topK_docs[qid] = topK_docs.get(qid, []) 45 | topK_docs[qid].append(passage) 46 | topK_pids[qid] = topK_pids.get(qid, []) 47 | topK_pids[qid].append(pid) 48 | 49 | assert all(len(topK_pids[qid]) == len(set(topK_pids[qid])) for qid in topK_pids) 50 | 51 | Ks = [len(topK_pids[qid]) for qid in topK_pids] 52 | 53 | print_message("#> max(Ks) =", max(Ks), ", avg(Ks) =", round(sum(Ks) / len(Ks), 2)) 54 | print_message("#> 
Loaded the top-k per query for", len(queries), "unique queries.\n") 55 | 56 | return queries, topK_docs, topK_pids 57 | 58 | 59 | def load_colbert(args): 60 | print_message("#> Loading model checkpoint.") 61 | colbert = MultiBERT.from_pretrained('bert-base-uncased') 62 | colbert = colbert.to(DEVICE) 63 | checkpoint = load_checkpoint(args.checkpoint, colbert) 64 | colbert.eval() 65 | 66 | print('\n') 67 | 68 | return colbert, checkpoint 69 | -------------------------------------------------------------------------------- /src/evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | class Metrics: 2 | def __init__(self, mrr_depths: dict, recall_depths: dict, total_queries=None): 3 | self.results = {} 4 | self.mrr_sums = {depth: 0.0 for depth in mrr_depths} 5 | self.recall_sums = {depth: 0.0 for depth in recall_depths} 6 | self.total_queries = total_queries 7 | 8 | def add(self, query_idx, query_key, ranking, gold_positives): 9 | assert query_key not in self.results 10 | assert len(self.results) <= query_idx 11 | assert len(set(gold_positives)) == len(gold_positives) 12 | assert len(set([pid for _, pid, _ in ranking])) == len(ranking) 13 | 14 | self.results[query_key] = ranking 15 | 16 | positives = [i for i, (_, pid, _) in enumerate(ranking) if pid in gold_positives] 17 | 18 | if len(positives) == 0: 19 | return 20 | 21 | for depth in self.mrr_sums: 22 | first_positive = positives[0] 23 | self.mrr_sums[depth] += (1.0 / (first_positive+1.0)) if first_positive < depth else 0.0 24 | 25 | for depth in self.recall_sums: 26 | num_positives_up_to_depth = len([pos for pos in positives if pos < depth]) 27 | self.recall_sums[depth] += num_positives_up_to_depth / len(gold_positives) 28 | 29 | def print_metrics(self, query_idx): 30 | for depth in sorted(self.mrr_sums): 31 | print("MRR@" + str(depth), "=", self.mrr_sums[depth] / (query_idx+1.0)) 32 | 33 | for depth in sorted(self.recall_sums): 34 | print("Recall@" + str(depth), "=", self.recall_sums[depth] / (query_idx+1.0)) 35 | 36 | 37 | def evaluate_recall(qrels, queries, topK_pids): 38 | if qrels is None: 39 | return 40 | 41 | assert set(qrels.keys()) == set(queries.keys()) 42 | recall_at_k = [len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) / len(qrels[qid]) for qid in qrels] 43 | recall_at_k = sum(recall_at_k) / len(qrels) 44 | recall_at_k = round(recall_at_k, 3) 45 | print("Recall @ maximum depth =", recall_at_k) 46 | -------------------------------------------------------------------------------- /src/evaluation/ranking.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | import torch 5 | 6 | from src.utils import print_message, load_checkpoint, batch 7 | from src.evaluation.metrics import Metrics 8 | 9 | 10 | def rerank(args, query, pids, passages, index=None): 11 | colbert = args.colbert 12 | #tokenized_passages = list(args.pool.map(colbert.tokenizer.tokenize, passages)) 13 | scores = [colbert.forward([query] * len(D), D)[0].cpu() for D in batch(passages, args.bsize)] 14 | scores = torch.cat(scores).squeeze(1).sort(descending=True) 15 | ranked = scores.indices.tolist() 16 | ranked_scores = scores.values.tolist() 17 | ranked_pids = [pids[position] for position in ranked] 18 | ranked_passages = [passages[position] for position in ranked] 19 | 20 | assert len(ranked_pids) == len(set(ranked_pids)) 21 | 22 | return list(zip(ranked_scores, ranked_pids, ranked_passages)) 23 | 24 | 25 | def evaluate(args, 
index=None): 26 | qrels, queries, topK_docs, topK_pids = args.qrels, args.queries, args.topK_docs, args.topK_pids 27 | 28 | metrics = Metrics(mrr_depths={10}, recall_depths={50, 200, 1000}, total_queries=None) 29 | 30 | if index: 31 | args.buffer = torch.zeros(1000, args.doc_maxlen, args.dim, dtype=index[0].dtype) 32 | 33 | output_path = '.'.join([str(x) for x in [args.run_name, 'tsv', int(time.time())]]) 34 | output_path = os.path.join(args.output_dir, output_path) 35 | 36 | # TODO: Save an associated metadata file with the args.input_args 37 | 38 | with open(output_path, 'w') as outputfile: 39 | with torch.no_grad(): 40 | keys = sorted(list(queries.keys())) 41 | random.shuffle(keys) 42 | 43 | for query_idx, qid in enumerate(keys): 44 | query = queries[qid] 45 | print_message(query_idx, qid, query, '\n') 46 | 47 | if qrels and args.shortcircuit and len(set.intersection(set(qrels[qid]), set(topK_pids[qid]))) == 0: 48 | continue 49 | 50 | ranking = rerank(args, query, topK_pids[qid], topK_docs[qid], index) 51 | 52 | for i, (score, pid, passage) in enumerate(ranking): 53 | outputfile.write('\t'.join([str(x) for x in [qid, pid, i+1]]) + "\n") 54 | 55 | if i+1 in [1, 2, 5, 10, 20, 100]: 56 | print("#> " + str(i+1) + ") ", pid, ":", score, ' ', passage) 57 | 58 | if qrels: 59 | metrics.add(query_idx, qid, ranking, qrels[qid]) 60 | 61 | for i, (score, pid, passage) in enumerate(ranking): 62 | if pid in qrels[qid]: 63 | print("\n#> Found", pid, "at position", i+1, "with score", score) 64 | print(passage) 65 | 66 | metrics.print_metrics(query_idx) 67 | 68 | print_message("#> checkpoint['batch'] =", args.checkpoint['batch'], '\n') 69 | print("output_path =", output_path) 70 | print("\n\n") 71 | -------------------------------------------------------------------------------- /src/index.py: -------------------------------------------------------------------------------- 1 | import random 2 | import datetime 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | 8 | from time import time 9 | from math import ceil 10 | from src.model import * 11 | from multiprocessing import Pool 12 | from src.evaluation.loaders import load_checkpoint 13 | 14 | MB_SIZE = 1024 15 | 16 | def print_message(*s): 17 | s = ' '.join(map(str, s)) 18 | print("[{}] {}".format(datetime.datetime.utcnow().strftime("%b %d, %H:%M:%S"), s), flush=True) 19 | 20 | 21 | print_message("#> Loading model checkpoint.") 22 | net = MultiBERT.from_pretrained('bert-base-uncased') 23 | net = net.to(DEVICE) 24 | load_checkpoint("/scratch/am8949/MultiBERT/colbert-12layers-100000.dnn", net) 25 | net.eval() 26 | 27 | 28 | 29 | 30 | 31 | 32 | def tok(d): 33 | d = cleanD(d, join=False) 34 | content = ' '.join(d) 35 | tokenized_content = net.tokenizer.tokenize(content) 36 | 37 | terms = list(set([(t, d.index(t)) for t in d])) # Quadratic! 
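        # The line above pairs every unique term with its first position via d.index(t),
        # which rescans the list once per term (hence "Quadratic!").  A linear-time
        # sketch that yields the same (term, first_position) pairs, kept as a comment
        # so the original behaviour is unchanged:
        #
        #     first_position = {}
        #     for i, t in enumerate(d):
        #         first_position.setdefault(t, i)
        #     terms = list(first_position.items())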
38 | word_indexes = list(accumulate([-1] + tokenized_content, lambda a, b: a + int(not b.startswith('##')))) 39 | terms = [(t, word_indexes.index(idx)) for t, idx in terms] 40 | terms = [(t, idx) for (t, idx) in terms if idx < MAX_LENGTH] 41 | 42 | return tokenized_content, terms 43 | 44 | 45 | 46 | def process_batch(g, super_batch): 47 | print_message("Start process_batch()", "") 48 | 49 | with torch.no_grad(): 50 | super_batch = list(p.map(tok, super_batch)) 51 | 52 | sorted_super_batch = sorted([(v, idx) for idx, v in enumerate(super_batch)], key=lambda x: len(x[0][0])) 53 | super_batch = [v for v, _ in sorted_super_batch] 54 | super_batch_indices = [idx for _, idx in sorted_super_batch] 55 | 56 | print_message("Done sorting", "") 57 | 58 | every_term_score = [] 59 | 60 | for batch_idx in range(ceil(len(super_batch) / MB_SIZE)): 61 | D = super_batch[batch_idx * MB_SIZE: (batch_idx + 1) * MB_SIZE] 62 | IDXs = super_batch_indices[batch_idx * MB_SIZE: (batch_idx + 1) * MB_SIZE] 63 | all_term_scores = net.index(D, len(D[-1][0])+2) 64 | every_term_score += zip(IDXs, all_term_scores) 65 | 66 | every_term_score = sorted(every_term_score) 67 | 68 | lines = [] 69 | for _, term_scores in every_term_score: 70 | term_scores = ', '.join([term + ": " + str(round(score, 3)) for term, score in term_scores]) 71 | lines.append(term_scores) 72 | 73 | g.write('\n'.join(lines) + "\n") 74 | g.flush() 75 | 76 | 77 | p = Pool(16) 78 | start_time = time() 79 | 80 | COLLECTION = "/scratch/am8949" 81 | with open(COLLECTION + '/queries.dev.test.txt', 'w') as g: 82 | with open(COLLECTION + '/queries.dev.small.tsv') as f: 83 | for idx, passage in enumerate(f): 84 | if idx % (50*1024) == 0: 85 | if idx > 0: 86 | process_batch(g, super_batch) 87 | throughput = round(idx / (time() - start_time), 1) 88 | print_message("Processed", str(idx), "passages so far [rate:", str(throughput), "passages per second]") 89 | super_batch = [] 90 | 91 | passage = passage.strip() 92 | pid, passage = passage.split('\t') 93 | super_batch.append(passage) 94 | 95 | #assert int(pid) == idx 96 | 97 | process_batch(g, super_batch) 98 | 99 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from random import sample, shuffle, randint 5 | 6 | from itertools import accumulate 7 | from transformers import * 8 | import re 9 | from src.parameters import DEVICE 10 | from src.utils2 import cleanQ, cleanD 11 | 12 | MAX_LENGTH = 300 13 | 14 | 15 | class MultiBERT(BertPreTrainedModel): 16 | def __init__(self, config): 17 | super(MultiBERT, self).__init__(config) 18 | self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') 19 | self.bert = BertModel(config) 20 | self.impact_score_encoder = nn.Sequential( 21 | nn.Linear(config.hidden_size, config.hidden_size), 22 | nn.ReLU(), 23 | nn.Dropout(config.hidden_dropout_prob), 24 | nn.Linear(config.hidden_size, 1), 25 | nn.ReLU() 26 | ) 27 | self.init_weights() 28 | 29 | def convert_example(self, d, max_seq_length): 30 | max_length = min(MAX_LENGTH, max_seq_length) 31 | inputs = self.tokenizer.encode_plus(d, add_special_tokens=True, max_length=max_length, truncation=True) 32 | 33 | padding_length = max_length - len(inputs["input_ids"]) 34 | attention_mask = ([1] * len(inputs["input_ids"])) + ([0] * padding_length) 35 | input_ids = inputs["input_ids"] + ([0] * padding_length) 36 | token_type_ids = inputs["token_type_ids"] + 
([0] * padding_length) 37 | 38 | return {'input_ids': input_ids, 'attention_mask': attention_mask, 'token_type_ids': token_type_ids} 39 | 40 | def tokenize(self, q, d): 41 | query_tokens = list(set(cleanQ(q).strip().split())) # [:10] 42 | 43 | content = cleanD(d).strip() 44 | doc_tokens = content.split() 45 | 46 | # NOTE: The following line accounts for CLS! 47 | tokenized = self.tokenizer.tokenize(content) 48 | word_indexes = list(accumulate([-1] + tokenized, lambda a, b: a + int(not b.startswith('##')))) 49 | match_indexes = list(set([doc_tokens.index(t) for t in query_tokens if t in doc_tokens])) 50 | term_indexes = [word_indexes.index(idx) for idx in match_indexes] 51 | 52 | a = [idx for i, idx in enumerate(match_indexes) if term_indexes[i] < MAX_LENGTH] 53 | b = [idx for idx in term_indexes if idx < MAX_LENGTH] 54 | 55 | return content, tokenized, a, b, len(word_indexes) + 2 56 | 57 | def forward(self, Q, D): 58 | bsize = len(Q) 59 | pairs = [] 60 | X, pfx_sum, pfx_sumX = [], [], [] 61 | total_size, total_sizeX, max_seq_length = 0, 0, 0 62 | 63 | doc_partials = [] 64 | pre_pairs = [] 65 | 66 | for q, d in zip(Q, D): 67 | tokens, tokenized, term_idxs, token_idxs, seq_length = self.tokenize(q, d) 68 | max_seq_length = max(max_seq_length, seq_length) 69 | 70 | pfx_sumX.append(total_sizeX) 71 | total_sizeX += len(term_idxs) 72 | 73 | tokens_split = tokens.split() 74 | 75 | doc_partials.append([(total_size + idx, tokens_split[i]) for idx, i in enumerate(term_idxs)]) 76 | total_size += len(doc_partials[-1]) 77 | pfx_sum.append(total_size) 78 | 79 | pre_pairs.append((tokenized, token_idxs)) 80 | 81 | if max_seq_length % 10 == 0: 82 | print("#>>> max_seq_length = ", max_seq_length) 83 | 84 | for tokenized, token_idxs in pre_pairs: 85 | pairs.append(self.convert_example(tokenized, max_seq_length)) 86 | X.append(token_idxs) 87 | 88 | input_ids = torch.tensor([f['input_ids'] for f in pairs], dtype=torch.long).to(DEVICE) 89 | attention_mask = torch.tensor([f['attention_mask'] for f in pairs], dtype=torch.long).to(DEVICE) 90 | token_type_ids = torch.tensor([f['token_type_ids'] for f in pairs], dtype=torch.long).to(DEVICE) 91 | 92 | outputs = self.bert.forward(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 93 | 94 | hidden_state = outputs[0] 95 | 96 | def one(i): 97 | if len(X[i]) > 0: 98 | l = [hidden_state[i, j] for j in X[i]] # + [mismatch_scores[i, j] for j in all_mismatches[i]] 99 | return torch.stack(l) 100 | return torch.tensor([]).to(DEVICE) 101 | 102 | pooled_output = torch.cat([one(i) for i in range(bsize)]) 103 | 104 | bsize = len(pooled_output) 105 | 106 | if bsize == 0: 107 | term_scores = [] 108 | for doc in doc_partials: 109 | term_scores.append([]) 110 | for (idx, term) in doc: 111 | term_scores[-1].append((term, 0.0)) 112 | 113 | return torch.tensor([[0.0]] * len(Q)).to(DEVICE), term_scores 114 | 115 | y_score = self.impact_score_encoder(pooled_output) 116 | 117 | x = torch.arange(bsize).expand(len(pfx_sum), bsize) < torch.tensor(pfx_sum).unsqueeze(1) 118 | y = torch.arange(bsize).expand(len(pfx_sum), bsize) >= torch.tensor([0] + pfx_sum[:-1]).unsqueeze(1) 119 | mask = (x & y).to(DEVICE) 120 | 121 | y_scorex = list(y_score.cpu()) 122 | term_scores = [] 123 | for doc in doc_partials: 124 | term_scores.append([]) 125 | for (idx, term) in doc: 126 | term_scores[-1].append((term, y_scorex[idx])) 127 | 128 | return (mask.type(torch.float32) @ y_score), term_scores #, ordered_terms #, num_exceeding_fifth 129 | 130 | def index(self, D, max_seq_length): 131 | if 
max_seq_length % 10 == 0: 132 | print("#>>> max_seq_length = ", max_seq_length) 133 | 134 | bsize = len(D) 135 | offset = 0 136 | pairs, X = [], [] 137 | 138 | for tokenized_content, terms in D: 139 | terms = [(t, idx, offset + pos) for pos, (t, idx) in enumerate(terms)] 140 | offset += len(terms) 141 | pairs.append(self.convert_example(tokenized_content, max_seq_length)) 142 | X.append(terms) 143 | 144 | input_ids = torch.tensor([f['input_ids'] for f in pairs], dtype=torch.long).to(DEVICE) 145 | attention_mask = torch.tensor([f['attention_mask'] for f in pairs], dtype=torch.long).to(DEVICE) 146 | token_type_ids = torch.tensor([f['token_type_ids'] for f in pairs], dtype=torch.long).to(DEVICE) 147 | 148 | outputs = self.bert.forward(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids) 149 | 150 | hidden_state = outputs[0] 151 | pooled_output = torch.cat([hidden_state[i, list(map(lambda x: x[1], X[i]))] for i in range(bsize)]) 152 | 153 | y_score = self.impact_score_encoder(pooled_output) 154 | y_score = y_score.squeeze().cpu().numpy().tolist() 155 | term_scores = [[(term, y_score[pos]) for term, _, pos in terms] for terms in X] 156 | 157 | return term_scores 158 | 159 | -------------------------------------------------------------------------------- /src/parameters.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | DEVICE = torch.device("cuda:0") 4 | 5 | DEFAULT_DATA_DIR = './data_download/' 6 | 7 | SAVED_CHECKPOINTS = [32*1000, 100*1000, 150*1000, 200*1000, 300*1000, 400*1000] 8 | -------------------------------------------------------------------------------- /src/rerank.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from argparse import ArgumentParser 5 | from src.parameters import DEFAULT_DATA_DIR, DEVICE 6 | from src.utils import print_message, create_directory 7 | 8 | from src.evaluation.loaders import load_colbert, load_topK, load_qrels 9 | from src.indexing.loaders import load_document_encodings 10 | 11 | from src.evaluation.ranking import evaluate 12 | from src.evaluation.metrics import evaluate_recall 13 | 14 | 15 | def main(): 16 | random.seed(123456) 17 | 18 | parser = ArgumentParser(description='Exhaustive (non-index-based) evaluation of re-ranking with ColBERT.') 19 | 20 | parser.add_argument('--checkpoint', dest='checkpoint', required=True) 21 | parser.add_argument('--topk', dest='topK', required=True) 22 | parser.add_argument('--qrels', dest='qrels', default=None) 23 | 24 | parser.add_argument('--index', dest='index', required=True) 25 | parser.add_argument('--index_dir', dest='index_dir', default='outputs.index/') 26 | 27 | parser.add_argument('--data_dir', dest='data_dir', default=DEFAULT_DATA_DIR) 28 | parser.add_argument('--output_dir', dest='output_dir', default='outputs.rerank/') 29 | 30 | parser.add_argument('--bsize', dest='bsize', default=128, type=int) 31 | parser.add_argument('--subsample', dest='subsample', default=None) # TODO: Add this 32 | 33 | # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint. 
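    # One possible shape for that TODO (sketch only; it assumes the checkpoint dict,
    # returned later by load_colbert(), also stores the training-time values of these
    # four fields, which save_checkpoint() in src/utils.py does not do yet):
    #
    #     for key in ('similarity', 'dim', 'query_maxlen', 'doc_maxlen'):
    #         if getattr(args, key) is None:
    #             setattr(args, key, checkpoint[key])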
34 | parser.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2']) 35 | parser.add_argument('--dim', dest='dim', default=128, type=int) 36 | parser.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int) 37 | parser.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int) 38 | 39 | args = parser.parse_args() 40 | args.input_arguments = args 41 | args.run_name = args.topK 42 | args.shortcircuit = False 43 | 44 | create_directory(args.output_dir) 45 | 46 | args.topK = os.path.join(args.data_dir, args.topK) 47 | 48 | if args.qrels: 49 | args.qrels = os.path.join(args.data_dir, args.qrels) 50 | 51 | args.colbert, args.checkpoint = load_colbert(args) 52 | args.qrels = load_qrels(args.qrels) 53 | args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK) 54 | 55 | evaluate_recall(args.qrels, args.queries, args.topK_pids) 56 | 57 | args.index = os.path.join(args.index_dir, args.index) 58 | args.index = load_document_encodings(args.index) 59 | evaluate(args, args.index) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /src/retrieve.py: -------------------------------------------------------------------------------- 1 | # To be released soon. -------------------------------------------------------------------------------- /src/test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | from argparse import ArgumentParser 5 | from multiprocessing import Pool 6 | 7 | from src.parameters import DEFAULT_DATA_DIR, DEVICE 8 | from src.utils import print_message, create_directory 9 | 10 | from src.evaluation.loaders import load_colbert, load_topK, load_qrels 11 | from src.evaluation.ranking import evaluate 12 | from src.evaluation.metrics import evaluate_recall 13 | 14 | 15 | def main(): 16 | random.seed(123456) 17 | 18 | parser = ArgumentParser(description='Exhaustive (non-index-based) evaluation of re-ranking with ColBERT.') 19 | 20 | parser.add_argument('--checkpoint', dest='checkpoint', required=True) 21 | parser.add_argument('--topk', dest='topK', required=True) 22 | parser.add_argument('--qrels', dest='qrels', default=None) 23 | parser.add_argument('--shortcircuit', dest='shortcircuit', default=False, action='store_true') 24 | 25 | parser.add_argument('--data_dir', dest='data_dir', default=DEFAULT_DATA_DIR) 26 | parser.add_argument('--output_dir', dest='output_dir', default='outputs.test/') 27 | 28 | parser.add_argument('--bsize', dest='bsize', default=128, type=int) 29 | parser.add_argument('--subsample', dest='subsample', default=None) # TODO: Add this 30 | 31 | # TODO: For the following four arguments, default should be None. If None, they should be loaded from checkpoint. 32 | parser.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2']) 33 | parser.add_argument('--dim', dest='dim', default=128, type=int) 34 | parser.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int) 35 | parser.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int) 36 | 37 | args = parser.parse_args() 38 | args.input_arguments = args 39 | 40 | assert (not args.shortcircuit) or args.qrels, \ 41 | "Short-circuiting (i.e., applying minimal computation to queries with no positives [in the re-ranked set]) " \ 42 | "can only be applied if qrels is provided." 
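    # In other words, --shortcircuit only makes sense together with --qrels: evaluate()
    # then skips scoring queries whose top-k candidates contain no judged positive.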
43 | 44 | args.pool = Pool(10) 45 | args.run_name = args.topK 46 | 47 | create_directory(args.output_dir) 48 | 49 | args.topK = os.path.join(args.data_dir, args.topK) 50 | 51 | if args.qrels: 52 | args.qrels = os.path.join(args.data_dir, args.qrels) 53 | 54 | args.colbert, args.checkpoint = load_colbert(args) 55 | args.qrels = load_qrels(args.qrels) 56 | args.queries, args.topK_docs, args.topK_pids = load_topK(args.topK) 57 | 58 | evaluate_recall(args.qrels, args.queries, args.topK_pids) 59 | evaluate(args) 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | 5 | from argparse import ArgumentParser 6 | 7 | from src.training.data_reader import train 8 | from src.utils import print_message, create_directory 9 | 10 | 11 | def main(): 12 | random.seed(12345) 13 | torch.manual_seed(1) 14 | 15 | parser = ArgumentParser(description='Training ColBERT with triples.') 16 | 17 | parser.add_argument('--lr', dest='lr', default=3e-06, type=float) 18 | parser.add_argument('--maxsteps', dest='maxsteps', default=400000, type=int) 19 | parser.add_argument('--bsize', dest='bsize', default=32, type=int) 20 | parser.add_argument('--accum', dest='accumsteps', default=2, type=int) 21 | 22 | parser.add_argument('--triples', dest='triples', default='triples.train.small.tsv') 23 | parser.add_argument('--output_dir', dest='output_dir', default='outputs.train/') 24 | 25 | parser.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2']) 26 | parser.add_argument('--dim', dest='dim', default=128, type=int) 27 | parser.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int) 28 | parser.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int) 29 | 30 | # TODO: Add resume functionality 31 | # TODO: Save the configuration to the checkpoint. 32 | # TODO: Extract common parser arguments/behavior into a class. 
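    # A possible shape for the parser refactor (sketch only; the helper name is
    # illustrative, not part of the codebase):
    #
    #     def add_model_args(parser):
    #         parser.add_argument('--similarity', dest='similarity', default='cosine', choices=['cosine', 'l2'])
    #         parser.add_argument('--dim', dest='dim', default=128, type=int)
    #         parser.add_argument('--query_maxlen', dest='query_maxlen', default=32, type=int)
    #         parser.add_argument('--doc_maxlen', dest='doc_maxlen', default=180, type=int)
    #
    # train.py, test.py and rerank.py currently repeat these four arguments verbatim.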
33 | 34 | args = parser.parse_args() 35 | args.input_arguments = args 36 | 37 | create_directory(args.output_dir) 38 | 39 | assert args.bsize % args.accumsteps == 0, ((args.bsize, args.accumsteps), 40 | "The batch size must be divisible by the number of gradient accumulation steps.") 41 | assert args.query_maxlen <= 512 42 | assert args.doc_maxlen <= 512 43 | 44 | train(args) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DI4IR/SIGIR2021/a6b1ee4efaba7d0de75501f2f05a4b9353cdb673/src/training/__init__.py -------------------------------------------------------------------------------- /src/training/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DI4IR/SIGIR2021/a6b1ee4efaba7d0de75501f2f05a4b9353cdb673/src/training/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /src/training/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DI4IR/SIGIR2021/a6b1ee4efaba7d0de75501f2f05a4b9353cdb673/src/training/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/training/__pycache__/data_reader.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DI4IR/SIGIR2021/a6b1ee4efaba7d0de75501f2f05a4b9353cdb673/src/training/__pycache__/data_reader.cpython-37.pyc -------------------------------------------------------------------------------- /src/training/__pycache__/data_reader.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DI4IR/SIGIR2021/a6b1ee4efaba7d0de75501f2f05a4b9353cdb673/src/training/__pycache__/data_reader.cpython-38.pyc -------------------------------------------------------------------------------- /src/training/data_reader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import torch 4 | import torch.nn as nn 5 | 6 | from argparse import ArgumentParser 7 | from transformers import AdamW 8 | 9 | from src.parameters import DEVICE, SAVED_CHECKPOINTS 10 | 11 | from src.model import MultiBERT 12 | from src.utils import print_message, save_checkpoint 13 | import re 14 | import datetime 15 | class TrainReader: 16 | def __init__(self, data_file): 17 | print_message("#> Training with the triples in", data_file, "...\n\n") 18 | self.reader = open(data_file, mode='r', encoding="utf-8") 19 | 20 | def get_minibatch(self, bsize): 21 | return [self.reader.readline().split('\t') for _ in range(bsize)] 22 | 23 | 24 | def manage_checkpoints(colbert, optimizer, batch_idx): 25 | if batch_idx % 2000 == 0: 26 | save_checkpoint("colbert-test.dnn", 0, batch_idx, colbert, optimizer) 27 | 28 | if batch_idx in SAVED_CHECKPOINTS: 29 | save_checkpoint("colbert-test-" + str(batch_idx) + ".dnn", 0, batch_idx, colbert, optimizer) 30 | 31 | 32 | def train(args): 33 | colbert = MultiBERT.from_pretrained('bert-base-uncased') 34 | colbert = colbert.to(DEVICE) 35 | colbert.train() 36 | 37 | criterion = nn.CrossEntropyLoss() 38 | optimizer = 
AdamW(colbert.parameters(), lr=args.lr, eps=1e-8) 39 | 40 | optimizer.zero_grad() 41 | labels = torch.zeros(args.bsize, dtype=torch.long, device=DEVICE) 42 | 43 | reader = TrainReader(args.triples) 44 | train_loss = 0.0 45 | 46 | for batch_idx in range(args.maxsteps): 47 | Batch = reader.get_minibatch(args.bsize) 48 | Batch = sorted(Batch, key=lambda x: max(len(x[1]), len(x[2]))) 49 | 50 | for B_idx in range(args.accumsteps): 51 | size = args.bsize // args.accumsteps 52 | B = Batch[B_idx * size: (B_idx+1) * size] 53 | Q, D1, D2 = zip(*B) 54 | 55 | colbert_out, _ = colbert(Q + Q, D1 + D2) 56 | colbert_out= colbert_out.squeeze(1) 57 | 58 | colbert_out1, colbert_out2 = colbert_out[:len(Q)], colbert_out[len(Q):] 59 | 60 | out = torch.stack((colbert_out1, colbert_out2), dim=-1) 61 | 62 | positive_score, negative_score = round(colbert_out1.mean().item(), 2), round(colbert_out2.mean().item(), 2) 63 | print("#>>> ", positive_score, negative_score, '\t\t|\t\t', positive_score - negative_score) 64 | loss = criterion(out, labels[:out.size(0)]) 65 | loss = loss / args.accumsteps 66 | loss.backward() 67 | 68 | train_loss += loss.item() 69 | 70 | torch.nn.utils.clip_grad_norm_(colbert.parameters(), 2.0) 71 | 72 | optimizer.step() 73 | optimizer.zero_grad() 74 | 75 | print_message(batch_idx, train_loss / (batch_idx+1)) 76 | 77 | manage_checkpoints(colbert, optimizer, batch_idx+1) 78 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import datetime 4 | 5 | 6 | def print_message(*s): 7 | s = ' '.join([str(x) for x in s]) 8 | print("[{}] {}".format(datetime.datetime.utcnow().strftime("%b %d, %H:%M:%S"), s), flush=True) 9 | 10 | 11 | def save_checkpoint(path, epoch_idx, mb_idx, model, optimizer): 12 | print("#> Saving a checkpoint..") 13 | 14 | checkpoint = {} 15 | checkpoint['epoch'] = epoch_idx 16 | checkpoint['batch'] = mb_idx 17 | checkpoint['model_state_dict'] = model.state_dict() 18 | checkpoint['optimizer_state_dict'] = optimizer.state_dict() 19 | 20 | torch.save(checkpoint, path) 21 | 22 | 23 | def load_checkpoint(path, model, optimizer=None): 24 | print_message("#> Loading checkpoint", path) 25 | 26 | checkpoint = torch.load(path, map_location='cpu') 27 | model.load_state_dict(checkpoint['model_state_dict']) 28 | 29 | if optimizer: 30 | optimizer.load_state_dict(checkpoint['optimizer_state_dict']) 31 | 32 | print_message("#> checkpoint['epoch'] =", checkpoint['epoch']) 33 | print_message("#> checkpoint['batch'] =", checkpoint['batch']) 34 | 35 | return checkpoint 36 | 37 | 38 | def create_directory(path): 39 | if not os.path.exists(path): 40 | print_message("#> Creating", path) 41 | os.makedirs(path) 42 | 43 | 44 | def batch(group, bsize): 45 | offset = 0 46 | while offset < len(group): 47 | L = group[offset: offset + bsize] 48 | yield L 49 | offset += len(L) 50 | return 51 | -------------------------------------------------------------------------------- /src/utils2.py: -------------------------------------------------------------------------------- 1 | import string 2 | 3 | STOPLIST = ["a", "about", "also", "am", "an", "and", "another", "any", "anyone", "are", "aren't", "as", "at", "be", 4 | "been", "being", "but", "by", "despite", "did", "didn't", "do", "does", "doesn't", "doing", "done", "don't", 5 | "each", "etc", "every", "everyone", "for", "from", "further", "had", "hadn't", "has", "hasn't", "have", 6 | "haven't", "having", "he", 
"he'd", "he'll", "her", "here", "here's", "hers", "herself", "he's", 7 | "him", "himself", "his", "however", "i", "i'd", "if", "i'll", "i'm", "in", "into", "is", "isn't", "it", 8 | "its", "it's", "itself", "i've", "just", "let's", "like", "lot", "may", "me", "might", "mightn't", 9 | "my", "myself", "no", "nor", "not", "of", "on", "onto", "or", "other", "ought", "oughtn't", "our", "ours", 10 | "ourselves", "out", "over", "shall", "shan't", "she", "she'd", "she'll", "she's", "since", "so", "some", 11 | "something", "such", "than", "that", "that's", "the", "their", "theirs", "them", "themselves", "then", 12 | "there", "there's", "these", "they", "they'd", "they'll", "they're", "they've", "this", "those", "through", 13 | "tht", "to", "too", "usually", "very", "via", "was", "wasn't", "we", "we'd", "well", "we'll", "were", 14 | "we're", "weren't", "we've", "will", "with", "without", "won't", "would", "wouldn't", "yes", "yet", "you", 15 | "you'd", "you'll", "your", "you're", "yours", "yourself", "yourselves", "you've"] 16 | 17 | printable = set('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ ') 18 | printableX = set('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ-. ') 19 | printable3X = set('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ,.- ') 20 | 21 | printableD = set('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.- ') 22 | printable3D = set('0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ.- ') 23 | 24 | STOPLIST_ = list(map(lambda s: ''.join(filter(lambda x: x in printable, s)), STOPLIST)) 25 | 26 | STOPLIST = {} 27 | for w in STOPLIST_: 28 | STOPLIST[w] = True 29 | 30 | def cleanD(s, join=True): 31 | s = [(x.lower() if x in printable3X else ' ') for x in s] 32 | s = [(x if x in printableX else ' ' + x + ' ') for x in s] 33 | s = ''.join(s).split() 34 | s = [(w if '.' not in w else (' . ' if len(max(w.split('.'), key=len)) > 1 else '').join(w.split('.'))) for w in s] 35 | s = ' '.join(s).split() 36 | s = [(w if '-' not in w else w.replace('-', '') + ' ( ' + ' '.join(w.split('-')) + ' ) ') for w in s] 37 | s = ' '.join(s).split() 38 | # s = [w for w in s if w not in STOPLIST] 39 | 40 | return ' '.join(s) if join else s 41 | 42 | 43 | def cleanQ(s, join=True): 44 | s = [(x.lower() if x in printable3D else ' ') for x in s] 45 | s = [(x if x in printableD else ' ' + x + ' ') for x in s] 46 | s = ''.join(s).split() 47 | s = [(w if '.' not in w else (' ' if len(max(w.split('.'), key=len)) > 1 else '').join(w.split('.'))) for w in s] 48 | s = ' '.join(s).split() 49 | s = [(w if '-' not in w else (' ' if len(min(w.split('-'), key=len)) > 1 else '').join(w.split('-'))) for w in s] 50 | s = ' '.join(s).split() 51 | s = [w for w in s if w not in STOPLIST] 52 | 53 | return ' '.join(s) if join else s 54 | 55 | --------------------------------------------------------------------------------