├── .gitignore
├── LICENSE.txt
├── README.md
├── bin
│   ├── gdeval.pl
│   └── trec_eval
├── cedr
│   ├── __init__.py
│   ├── data.py
│   ├── extract_docs_from_index.py
│   ├── modeling.py
│   ├── modeling_util.py
│   ├── rerank.py
│   └── train.py
├── data
│   ├── robust
│   │   ├── f1.test.run
│   │   ├── f1.train.pairs
│   │   ├── f1.valid.run
│   │   ├── f2.test.run
│   │   ├── f2.train.pairs
│   │   ├── f2.valid.run
│   │   ├── f3.test.run
│   │   ├── f3.train.pairs
│   │   ├── f3.valid.run
│   │   ├── f4.test.run
│   │   ├── f4.train.pairs
│   │   ├── f4.valid.run
│   │   ├── f5.test.run
│   │   ├── f5.train.pairs
│   │   ├── f5.valid.run
│   │   ├── qrels
│   │   └── queries.tsv
│   └── wt
│       ├── qrels
│       ├── queries.tsv
│       ├── test.wt12.run
│       ├── test.wt13.run
│       ├── test.wt14.run
│       ├── train.wt12.pairs
│       ├── train.wt13.pairs
│       ├── train.wt14.pairs
│       └── valid.wt11.run
├── requirements.txt
└── setup.py
/.gitignore: -------------------------------------------------------------------------------- 1 | build 2 | *.egg-info 3 | __pycache__ 4 | dist 5 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) 2 | 3 | Copyright (c) 2019 Georgetown Information Retrieval Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CEDR: Contextualized Embeddings for Document Ranking 2 | Sean MacAvaney, Andrew Yates, Arman Cohan, Nazli Goharian. *CEDR: Contextualized Embeddings for Document Ranking*. SIGIR 2019 (short). 3 | 4 | Paper at: https://arxiv.org/abs/1904.07094 5 | 6 | > tl;dr We demonstrate the effectiveness of using BERT classification for document ranking 7 | > (*"Vanilla BERT"*) and show that BERT embeddings can be used by prior neural ranking architectures 8 | > to further improve ranking performance (*"CEDR-\* models"*). 9 | 10 | If you use this work, please cite as: [bibtex](https://smac.pub/bib/sigir2019-contextuallms.bib) 11 | 12 | ``` 13 | @InProceedings{macavaney2019ContextualWR, 14 | author = {MacAvaney, Sean and Yates, Andrew and Cohan, Arman and Goharian, Nazli}, 15 | title = {CEDR: Contextualized Embeddings for Document Ranking}, 16 | booktitle = {SIGIR}, 17 | year = {2019} 18 | } 19 | ``` 20 | 21 | ## Getting started 22 | 23 | This code is tested on Python 3.6.
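A CUDA-capable GPU is effectively required as well: the training and re-ranking code moves models and batches to the GPU with unconditional `.cuda()` calls.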
Install dependencies using the following command: 24 | 25 | ``` 26 | pip install -r requirements.txt 27 | ``` 28 | 29 | You will need to prepare files for training and evaluation. Many of these files are available in 30 | `data/wt` (TREC WebTrack) and `data/robust` (TREC Robust 2004). 31 | 32 | **qrels**: a standard TREC-style relevance judgments file. Used for identifying relevant items for 33 | training pair generation and for validation (`data/wt/qrels`, `data/robust/qrels`). 34 | 35 | **train_pairs**: a tab-delimited file containing pairs used for training. The training process 36 | will only use query-document pairs found in this file. Samples are in `data/{wt,robust}/*.pairs`. 37 | File format: 38 | 39 | ``` 40 | [query_id] [doc_id] 41 | ``` 42 | 43 | **valid_run**: a standard TREC-style run file for re-ranking items for validation. The `.run` files used for re-ranking are available in `data/{wt,robust}/*.run`. Note that these runs use default retrieval parameters, so they do not match the tuned results shown in Table 1. 44 | 45 | **datafiles**: Files containing the text of queries and documents needed for training, validation, 46 | or testing. Should be in tab-delimited format as follows, where `[type]` is either `query` or `doc`, 47 | `[id]` is the identifier of the query or document (e.g., `132`, `clueweb12-0206wb-59-32292`), and 48 | `[text]` is the textual content of the query or document (no tabs or newline characters, 49 | tokenization done by `BertTokenizer`). 50 | 51 | ``` 52 | [type] [id] [text] 53 | ``` 54 | 55 | Queries for WebTrack and Robust are available in `data/wt/queries.tsv` and `data/robust/queries.tsv`. 56 | Document text can be extracted from an index using `extract_docs_from_index.py` (be sure to use an 57 | index that has appropriate pre-processing). The script supports both Indri and Lucene (via Anserini) 58 | indices. See instructions below for help installing pyndri or Anserini.
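Once assembled, a complete datafile contains one tab-separated record per line; for example, a query line and a document line look like the following (the document text here is a hypothetical stand-in for the extracted content):

```
query	1	obama family tree
doc	clueweb12-0206wb-59-32292	<text of the document, as extracted from the index>
```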
59 | 60 | Example extraction commands: 61 | 62 | ``` 63 | # Indri index 64 | awk '{print $3}' data/robust/*.run | python extract_docs_from_index.py indri PATH_TO_INDRI_INDEX > data/robust/documents.tsv 65 | # Lucene index (should be built with Anserini and the -storeTransformedDocs option) 66 | awk '{print $3}' data/robust/*.run | python extract_docs_from_index.py lucene PATH_TO_LUCENE_INDEX > data/robust/documents.tsv 67 | ``` 68 | 69 | ## Running Vanilla BERT 70 | 71 | To train a Vanilla BERT model, use the following command: 72 | 73 | ``` 74 | python train.py \ 75 | --model vanilla_bert \ 76 | --datafiles data/queries.tsv data/documents.tsv \ 77 | --qrels data/qrels \ 78 | --train_pairs data/train_pairs \ 79 | --valid_run data/valid_run \ 80 | --model_out_dir models/vbert 81 | ``` 82 | 83 | You can see the performance of Vanilla BERT by re-ranking a test run: 84 | 85 | ``` 86 | python rerank.py \ 87 | --model vanilla_bert \ 88 | --datafiles data/queries.tsv data/documents.tsv \ 89 | --run data/test_run \ 90 | --model_weights models/vbert/weights.p \ 91 | --out_path models/vbert/test.run 92 | ``` 93 | 94 | ## Running CEDR 95 | 96 | To train a CEDR model (`cedr_pacrr`, `cedr_knrm`, or `cedr_drmm`), first train a Vanilla BERT model, and then use the following command: 97 | 98 | ``` 99 | python train.py \ 100 | --model cedr_pacrr \ 101 | --datafiles data/queries.tsv data/documents.tsv \ 102 | --qrels data/qrels \ 103 | --train_pairs data/train_pairs \ 104 | --valid_run data/valid_run \ 105 | --initial_bert_weights models/vbert/weights.p \ 106 | --model_out_dir models/cedrpacrr 107 | ``` 108 | 109 | You can see the performance of CEDR by re-ranking a test run (use the same `--model` as during training): 110 | 111 | ``` 112 | python rerank.py \ 113 | --model cedr_pacrr \ 114 | --datafiles data/queries.tsv data/documents.tsv \ 115 | --run data/test_run \ 116 | --model_weights models/cedrpacrr/weights.p \ 117 | --out_path models/cedrpacrr/test.run 118 | ``` 119 | 120 | Note that this will calculate results using `bin/trec_eval` with P@20, whereas the nDCG@20 and ERR@20 results in Table 1 are calculated using `bin/gdeval.pl`. 121 | 122 | ## Misc 123 | 124 | These instructions are only needed if using the `extract_docs_from_index.py` script, and depend 125 | on the index from which you are extracting documents. 126 | 127 | ### Installing pyndri 128 | 129 | Here's what worked for me. Please refer to [cvangysel/pyndri](https://github.com/cvangysel/pyndri) 130 | for further assistance installing pyndri. 131 | 132 | ``` 133 | wget https://sourceforge.net/projects/lemur/files/lemur/indri-5.14/indri-5.14.tar.gz 134 | tar xvfz indri-5.14.tar.gz 135 | cd indri-5.14 136 | ./configure CXX="g++ -D_GLIBCXX_USE_CXX11_ABI=0" 137 | make 138 | sudo make install 139 | pip install pyndri==0.4 140 | ``` 141 | 142 | ### Installing Anserini 143 | 144 | Install pyjnius (refer to [kivy/pyjnius](https://github.com/kivy/pyjnius) for further assistance 145 | with pyjnius.) 146 | 147 | ``` 148 | pip install pyjnius==1.1.4 149 | ``` 150 | 151 | Build Anserini (refer to [castorini/anserini](https://github.com/castorini/anserini) for further 152 | assistance with Anserini.)
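Note: building Anserini this way assumes a JDK and Apache Maven are installed (inferred from the `mvn` step below, not an officially documented requirement).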
153 | 154 | ``` 155 | wget https://github.com/castorini/anserini/archive/anserini-0.4.0.tar.gz 156 | tar -xzvf anserini-0.4.0.tar.gz 157 | cd anserini-anserini-0.4.0/ 158 | mvn clean package appassembler:assemble 159 | mv target/anserini-0.4.0-fatjar.jar ~/cedr/bin/anserini.jar 160 | ``` 161 | -------------------------------------------------------------------------------- /bin/gdeval.pl: -------------------------------------------------------------------------------- 1 | #!/usr/bin/perl -w 2 | 3 | # Graded relevance assessment script for the TREC 2010 Web track 4 | # Evaluation measures are written to standard output in CSV format. 5 | # 6 | # Currently reports only NDCG and ERR 7 | # (see http://learningtorankchallenge.yahoo.com/instructions.php) 8 | 9 | use constant LOGBASEDIV => log(2.0); 10 | 11 | # globals 12 | my $QRELS; 13 | my $VERSION = "version 1.3 (Mon Apr 29 20:50:24 EDT 2013)"; 14 | my $MAX_JUDGMENT = 4; # Maximum gain value allowed in qrels file. 15 | my $K = 20; # Reporting depth for results. 16 | my $USAGE = "usage: $0 [options] qrels run\n 17 | options:\n 18 | -c 19 | Average over the complete set of topics in the relevance judgments 20 | instead of the topics in the intersection of relevance judgments 21 | and results.\n 22 | -k value 23 | Non-negative integer depth of ranking to evaluate in range [1,inf]. 24 | Default value is k=@{[($K)]}.\n 25 | -baseline BASELINE_RUN_FILE 26 | Baseline run to use for risk-sensitive evaluation\n 27 | -riskAlpha value 28 | Non-negative Risk sensitivity value to use when doing risk-sensitive 29 | evaluation. A baseline must still be specified. By default 0. 30 | The final weight to downside changes in performance is (1+value).\n"; 31 | 32 | 33 | 34 | use strict 'vars'; 35 | 36 | { # main block to scope variables 37 | if ($#ARGV >= 0 && ($ARGV[0] eq "-v" || $ARGV[0] eq "-version")) { 38 | print "$0: $VERSION\n"; 39 | exit 0; 40 | } 41 | 42 | my $baselineRun = undef; 43 | my $riskAlpha = 0; 44 | my $cflag = 0; 45 | while ($#ARGV != 1) # should probably replace this with perl's argument parsing 46 | { 47 | if ($#ARGV >= 0 && $ARGV[0] eq "-help") { 48 | print "$USAGE\n"; 49 | exit 0; 50 | } 51 | elsif ($#ARGV >= 2 and ("-c" eq $ARGV[0])) 52 | { 53 | $cflag = 1; 54 | shift @ARGV; 55 | } 56 | elsif ($#ARGV >= 3 and ("-k" eq $ARGV[0])) 57 | { 58 | $K = int($ARGV[1]); 59 | die $USAGE if ($K < 1); 60 | # print STDERR "k=$K\n"; 61 | shift @ARGV; shift @ARGV; 62 | } 63 | elsif ($#ARGV >= 3 and ("-baseline" eq $ARGV[0])) 64 | { 65 | $baselineRun = $ARGV[1]; 66 | shift @ARGV; shift @ARGV; 67 | } 68 | elsif ($#ARGV >= 3 and ("-riskAlpha" eq $ARGV[0])) 69 | { 70 | $riskAlpha = $ARGV[1]; 71 | die $USAGE if ($riskAlpha < 0.0); 72 | shift @ARGV; shift @ARGV; 73 | } 74 | else 75 | { 76 | die $USAGE; 77 | } 78 | } 79 | 80 | die $USAGE unless $#ARGV == 1; 81 | $QRELS = $ARGV[0]; 82 | my $run = $ARGV[1]; 83 | 84 | # Read qrels file, check format, and sort 85 | my @qrels = (); 86 | my %seen = (); 87 | open (QRELS,"<$QRELS") || die "$0: cannot open \"$QRELS\": $!\n"; 88 | while (<QRELS>) { 89 | s/[\r\n]//g; 90 | my ($topic, $zero, $docno, $judgment) = split (' '); 91 | $topic =~ s/^.*\-//; 92 | die "$0: format error on line $.
of \"$QRELS\"\n" 93 | unless 94 | $topic =~ /^[0-9]+$/ && $zero == 0 95 | && $judgment =~ /^-?[0-9]+$/ && $judgment <= $MAX_JUDGMENT; 96 | if ($judgment > 0) { 97 | $qrels[$#qrels + 1]= "$topic $docno $judgment"; 98 | $seen{$topic} = 1; 99 | } 100 | } 101 | close (QRELS); 102 | @qrels = sort qrelsOrder (@qrels); 103 | 104 | # Process qrels: store judgments and compute ideal gains 105 | my $topicCurrent = -1; 106 | my %ideal = (); 107 | my @gain = (); 108 | my %judgment = (); 109 | for (my $i = 0; $i <= $#qrels; $i++) { 110 | my ($topic, $docno, $judgment) = split (' ', $qrels[$i]); 111 | if ($topic != $topicCurrent) { 112 | if ($topicCurrent >= 0) { 113 | $ideal{$topicCurrent} = &dcg($K, @gain); 114 | $#gain = -1; 115 | } 116 | $topicCurrent = $topic; 117 | } 118 | next if $judgment < 0; 119 | $judgment{"$topic:$docno"} = $gain[$#gain + 1] = $judgment; 120 | } 121 | if ($topicCurrent >= 0) { 122 | $ideal{$topicCurrent} = &dcg($K, @gain); 123 | $#gain = -1; 124 | } 125 | 126 | # process baseline if doing risk sensitive 127 | my ($baseNDCGByTopic,$baseERRByTopic,$baserunname); 128 | if (defined $baselineRun) 129 | { 130 | ($baseNDCGByTopic,$baseERRByTopic,$baserunname) = processRun($baselineRun,0,\%seen,\%ideal,\%judgment,$cflag,0); 131 | } 132 | 133 | # process main run 134 | processRun($run,1,\%seen,\%ideal,\%judgment,$cflag,defined($baselineRun),$riskAlpha,$baserunname,$baseNDCGByTopic,$baseERRByTopic); 135 | 136 | exit 0; 137 | 138 | } # end main block 139 | 140 | # comparison function for qrels: by topic then judgment 141 | sub qrelsOrder { 142 | my ($topicA, $docnoA, $judgmentA) = split (' ', $a); 143 | my ($topicB, $docnoB, $judgmentB) = split (' ', $b); 144 | 145 | if ($topicA < $topicB) { 146 | return -1; 147 | } elsif ($topicA > $topicB) { 148 | return 1; 149 | } else { 150 | return $judgmentB <=> $judgmentA; 151 | } 152 | } 153 | 154 | # comparison function for runs: by topic then score then docno 155 | sub runOrder { 156 | my ($topicA, $docnoA, $scoreA) = split (' ', $a); 157 | my ($topicB, $docnoB, $scoreB) = split (' ', $b); 158 | 159 | if ($topicA < $topicB) { 160 | return -1; 161 | } elsif ($topicA > $topicB) { 162 | return 1; 163 | } elsif ($scoreA < $scoreB) { 164 | return 1; 165 | } elsif ($scoreA > $scoreB) { 166 | return -1; 167 | } elsif ($docnoA lt $docnoB) { 168 | return 1; 169 | } elsif ($docnoA gt $docnoB) { 170 | return -1; 171 | } else { 172 | return 0; 173 | } 174 | } 175 | 176 | # compute DCG over a sorted array of gain values, reporting at depth $k 177 | sub dcg { 178 | my ($k, @gain) = @_; 179 | my ($i, $score) = (0, 0); 180 | 181 | for ($i = 0; $i <= ($k <= $#gain ? $k - 1 : $#gain); $i++) { 182 | $score += (2**$gain[$i] - 1)/(log ($i + 2)/ +LOGBASEDIV); 183 | } 184 | return $score; 185 | } 186 | 187 | # compute ERR over a sorted array of gain values, reporting at depth $k 188 | sub err { 189 | my ($k, @gain) = @_; 190 | my ($i, $score, $decay, $r); 191 | 192 | $score = 0.0; 193 | $decay = 1.0; 194 | for ($i = 0; $i <= ($k <= $#gain ? 
$k - 1 : $#gain); $i++) { 195 | $r = (2**$gain[$i] - 1)/(2**$MAX_JUDGMENT); 196 | $score += $r*$decay/($i + 1); 197 | $decay *= (1 - $r); 198 | } 199 | return $score; 200 | } 201 | 202 | sub riskWeighted 203 | { 204 | my ($run,$base,$alpha) = @_; 205 | if ($run < $base) 206 | { 207 | $run = (1+$alpha) * ($run - $base); 208 | } 209 | else 210 | { 211 | $run = $run - $base; 212 | } 213 | return $run; 214 | } 215 | # compute and report information for current topic 216 | sub topicDone { 217 | my ($printTopic, $runid, $topic, $pndcgTotal, $perrTotal, $ptopics, $pseen, $pideal, 218 | $isRiskSensitive, $riskAlpha, $baseNDCG, $baseERR, @gain) = @_; 219 | my($ndcg, $err) = (0, 0); 220 | if (exists($$pseen{$topic}) and defined($$pseen{$topic}) and $$pseen{$topic}) { 221 | $ndcg = &dcg($K, @gain)/$$pideal{$topic}; 222 | $err = &err ($K, @gain); 223 | $ndcg = riskWeighted($ndcg,$baseNDCG,$riskAlpha) if ($isRiskSensitive); 224 | $err = riskWeighted($err,$baseERR,$riskAlpha) if ($isRiskSensitive); 225 | $$pndcgTotal += $ndcg; 226 | $$perrTotal += $err; 227 | $$ptopics++; 228 | printf("$runid,$topic,%.5f,%.5f\n",$ndcg,$err) if ($printTopic); 229 | return ($ndcg,$err); 230 | } 231 | } 232 | 233 | sub processRun 234 | { 235 | my ($run,$printTopics,$pseen,$pideal,$pjudgment,$avgOverAllTopics,$isRiskSensitive,$riskAlpha,$baserunname,$baseNDCGByTopic,$baseERRByTopic) = @_; 236 | my $ndcgByTopic = {()}; 237 | my $errByTopic = {()}; 238 | my $runid = "?????"; 239 | my @run = (); 240 | # Read run file, check format, and sort 241 | open (RUN,"<$run") || die "$0: cannot open \"$run\": $!\n"; 242 | while (<RUN>) { 243 | s/[\r\n]//g; 244 | my ($topic, $q0, $docno, $rank, $score); 245 | ($topic, $q0, $docno, $rank, $score, $runid) = split (' '); 246 | $topic =~ s/^.*\-//; 247 | die "$0: format error on line $. of \"$run\"\n" 248 | unless 249 | $topic =~ /^[0-9]+$/ && $q0 eq "Q0" && $rank =~ /^[0-9]+$/ && $runid; 250 | $run[$#run + 1] = "$topic $docno $score"; 251 | } 252 | close (RUN); 253 | @run = sort runOrder (@run); 254 | 255 | my %processed = (); 256 | foreach my $topic (keys %$pseen) 257 | { 258 | $processed{$topic} = 0; 259 | } 260 | 261 | if ($isRiskSensitive) 262 | { 263 | $runid = sprintf("%s (rel to.
%s, rs=1+a, a=%s)",$runid,$baserunname,$riskAlpha); 264 | } 265 | 266 | # Process runs: compute measures for each topic and average 267 | my $ndcgTotal = 0; 268 | my $errTotal = 0; 269 | my $topics = 0; 270 | print "runid,topic,ndcg\@$K,err\@$K\n" if ($printTopics); 271 | my $topicCurrent = -1; 272 | my @gain = (); 273 | for (my $i = 0; $i <= $#run; $i++) { 274 | my ($topic, $docno, $score) = split (' ', $run[$i]); 275 | if ($topic != $topicCurrent) { 276 | if ($topicCurrent >= 0) { 277 | my ($baseNDCG,$baseERR) = 0; 278 | if ($isRiskSensitive) 279 | { 280 | $baseNDCG = $$baseNDCGByTopic{$topicCurrent} if (exists($$baseNDCGByTopic{$topicCurrent}) and defined($$baseNDCGByTopic{$topicCurrent})); 281 | $baseERR = $$baseERRByTopic{$topicCurrent} if (exists($$baseERRByTopic{$topicCurrent}) and defined($$baseERRByTopic{$topicCurrent})); 282 | } 283 | my ($ndcg,$err) = &topicDone ($printTopics, $runid, $topicCurrent, \$ndcgTotal, \$errTotal, \$topics, 284 | $pseen, $pideal, $isRiskSensitive, $riskAlpha, $baseNDCG, $baseERR, @gain); 285 | $$ndcgByTopic{$topicCurrent} = $ndcg; 286 | $$errByTopic{$topicCurrent} = $err; 287 | $processed{$topicCurrent} = 1; 288 | $#gain = -1; 289 | } 290 | $topicCurrent = $topic; 291 | } 292 | my $j = $$pjudgment{"$topic:$docno"}; 293 | $j = 0 unless $j; 294 | $gain[$#gain + 1] = $j; 295 | } 296 | if ($topicCurrent >= 0) { 297 | my ($baseNDCG,$baseERR) = 0; 298 | if ($isRiskSensitive) 299 | { 300 | $baseNDCG = $$baseNDCGByTopic{$topicCurrent} if (exists($$baseNDCGByTopic{$topicCurrent}) and defined($$baseNDCGByTopic{$topicCurrent})); 301 | $baseERR = $$baseERRByTopic{$topicCurrent} if (exists($$baseERRByTopic{$topicCurrent}) and defined($$baseERRByTopic{$topicCurrent})); 302 | } 303 | my ($ndcg,$err) = &topicDone ($printTopics, $runid, $topicCurrent, \$ndcgTotal, \$errTotal, \$topics, 304 | $pseen, $pideal, $isRiskSensitive, $riskAlpha, $baseNDCG, $baseERR, @gain); 305 | $$ndcgByTopic{$topicCurrent} = $ndcg; 306 | $$errByTopic{$topicCurrent} = $err; 307 | $processed{$topicCurrent} = 1; 308 | $#gain = -1; 309 | } 310 | my $numTopics = $topics; # $topics has the number in the run (at this point) 311 | if ($avgOverAllTopics) 312 | { 313 | $numTopics = scalar(keys %$pseen); # we want denominator to change whenever flag is on but only need to compute differences for risk 314 | if ($isRiskSensitive) 315 | { # need to process any topics that were missing from run 316 | my ($baseNDCG,$baseERR) = 0; 317 | my @gain = (); 318 | foreach my $topicCurrent (sort {$a <=> $b} keys %processed) 319 | { 320 | next if ($processed{$topicCurrent}); 321 | $baseNDCG = $$baseNDCGByTopic{$topicCurrent} if (exists($$baseNDCGByTopic{$topicCurrent}) and defined($$baseNDCGByTopic{$topicCurrent})); 322 | $baseERR = $$baseERRByTopic{$topicCurrent} if (exists($$baseERRByTopic{$topicCurrent}) and defined($$baseERRByTopic{$topicCurrent})); 323 | my ($ndcg,$err) = &topicDone ($printTopics, $runid, $topicCurrent, \$ndcgTotal, \$errTotal, \$topics, 324 | $pseen, $pideal, $isRiskSensitive, $riskAlpha, $baseNDCG, $baseERR, @gain); 325 | } 326 | } 327 | } 328 | 329 | my $ndcgAvg = $ndcgTotal; 330 | my $errAvg = $errTotal; 331 | if ($numTopics > 0) 332 | { 333 | $ndcgAvg /= $numTopics; 334 | $errAvg /= $numTopics; 335 | } 336 | printf "$runid,amean,%.5f,%.5f\n",$ndcgAvg,$errAvg if ($printTopics); 337 | 338 | return ($ndcgByTopic,$errByTopic,$runid); 339 | } 340 | -------------------------------------------------------------------------------- /bin/trec_eval:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Georgetown-IR-Lab/cedr/a263f1e00b4b4dcfdc1aa1519793224138be44fc/bin/trec_eval -------------------------------------------------------------------------------- /cedr/__init__.py: -------------------------------------------------------------------------------- 1 | from . import data 2 | from . import extract_docs_from_index 3 | from . import modeling 4 | from . import modeling_util 5 | from . import rerank 6 | from . import train 7 | -------------------------------------------------------------------------------- /cedr/data.py: -------------------------------------------------------------------------------- 1 | import random 2 | from tqdm import tqdm 3 | import torch 4 | 5 | 6 | def read_datafiles(files): 7 | queries = {} 8 | docs = {} 9 | for file in files: 10 | for line in tqdm(file, desc='loading datafile (by line)', leave=False): 11 | cols = line.rstrip().split('\t') 12 | if len(cols) != 3: 13 | tqdm.write(f'skipping line: `{line.rstrip()}`') 14 | continue 15 | c_type, c_id, c_text = cols 16 | assert c_type in ('query', 'doc') 17 | if c_type == 'query': 18 | queries[c_id] = c_text 19 | if c_type == 'doc': 20 | docs[c_id] = c_text 21 | return queries, docs 22 | 23 | 24 | def read_qrels_dict(file): 25 | result = {} 26 | for line in tqdm(file, desc='loading qrels (by line)', leave=False): 27 | qid, _, docid, score = line.split() 28 | result.setdefault(qid, {})[docid] = int(score) 29 | return result 30 | 31 | 32 | def read_run_dict(file): 33 | result = {} 34 | for line in tqdm(file, desc='loading run (by line)', leave=False): 35 | qid, _, docid, rank, score, _ = line.split() 36 | result.setdefault(qid, {})[docid] = float(score) 37 | return result 38 | 39 | 40 | def read_pairs_dict(file): 41 | result = {} 42 | for line in tqdm(file, desc='loading pairs (by line)', leave=False): 43 | qid, docid = line.split() 44 | result.setdefault(qid, {})[docid] = 1 45 | return result 46 | 47 | 48 | def iter_train_pairs(model, dataset, train_pairs, qrels, batch_size): 49 | batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []} 50 | for qid, did, query_tok, doc_tok in _iter_train_pairs(model, dataset, train_pairs, qrels): 51 | batch['query_id'].append(qid) 52 | batch['doc_id'].append(did) 53 | batch['query_tok'].append(query_tok) 54 | batch['doc_tok'].append(doc_tok) 55 | if len(batch['query_id']) // 2 == batch_size: 56 | yield _pack_n_ship(batch) 57 | batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []} 58 | 59 | 60 | 61 | def _iter_train_pairs(model, dataset, train_pairs, qrels): 62 | ds_queries, ds_docs = dataset 63 | while True: 64 | qids = list(train_pairs.keys()) 65 | random.shuffle(qids) 66 | for qid in qids: 67 | pos_ids = [did for did in train_pairs[qid] if qrels.get(qid, {}).get(did, 0) > 0] 68 | if len(pos_ids) == 0: 69 | tqdm.write("no positive labels for query %s " % qid) 70 | continue 71 | pos_id = random.choice(pos_ids) 72 | pos_ids_lookup = set(pos_ids) 73 | pos_ids = set(pos_ids) 74 | neg_ids = [did for did in train_pairs[qid] if did not in pos_ids_lookup] 75 | if len(neg_ids) == 0: 76 | tqdm.write("no negative labels for query %s " % qid) 77 | continue 78 | neg_id = random.choice(neg_ids) 79 | query_tok = model.tokenize(ds_queries[qid]) 80 | pos_doc = ds_docs.get(pos_id) 81 | neg_doc = ds_docs.get(neg_id) 82 | if pos_doc is None: 83 | tqdm.write(f'missing doc {pos_id}! 
Skipping') 84 | continue 85 | if neg_doc is None: 86 | tqdm.write(f'missing doc {neg_id}! Skipping') 87 | continue 88 | yield qid, pos_id, query_tok, model.tokenize(pos_doc) 89 | yield qid, neg_id, query_tok, model.tokenize(neg_doc) 90 | 91 | 92 | def iter_valid_records(model, dataset, run, batch_size): 93 | batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []} 94 | for qid, did, query_tok, doc_tok in _iter_valid_records(model, dataset, run): 95 | batch['query_id'].append(qid) 96 | batch['doc_id'].append(did) 97 | batch['query_tok'].append(query_tok) 98 | batch['doc_tok'].append(doc_tok) 99 | if len(batch['query_id']) == batch_size: 100 | yield _pack_n_ship(batch) 101 | batch = {'query_id': [], 'doc_id': [], 'query_tok': [], 'doc_tok': []} 102 | # final batch 103 | if len(batch['query_id']) > 0: 104 | yield _pack_n_ship(batch) 105 | 106 | 107 | def _iter_valid_records(model, dataset, run): 108 | ds_queries, ds_docs = dataset 109 | for qid in run: 110 | query_tok = model.tokenize(ds_queries[qid]) 111 | for did in run[qid]: 112 | doc = ds_docs.get(did) 113 | if doc is None: 114 | tqdm.write(f'missing doc {did}! Skipping') 115 | continue 116 | doc_tok = model.tokenize(doc) 117 | yield qid, did, query_tok, doc_tok 118 | 119 | 120 | def _pack_n_ship(batch): 121 | QLEN = 20 122 | MAX_DLEN = 800 123 | DLEN = min(MAX_DLEN, max(len(b) for b in batch['doc_tok'])) 124 | return { 125 | 'query_id': batch['query_id'], 126 | 'doc_id': batch['doc_id'], 127 | 'query_tok': _pad_crop(batch['query_tok'], QLEN), 128 | 'doc_tok': _pad_crop(batch['doc_tok'], DLEN), 129 | 'query_mask': _mask(batch['query_tok'], QLEN), 130 | 'doc_mask': _mask(batch['doc_tok'], DLEN), 131 | } 132 | 133 | 134 | def _pad_crop(items, l): 135 | result = [] 136 | for item in items: 137 | if len(item) < l: 138 | item = item + [-1] * (l - len(item)) 139 | if len(item) > l: 140 | item = item[:l] 141 | result.append(item) 142 | return torch.tensor(result).long().cuda() 143 | 144 | 145 | def _mask(items, l): 146 | result = [] 147 | for item in items: 148 | # needs padding (masked) 149 | if len(item) < l: 150 | mask = [1. for _ in item] + ([0.] * (l - len(item))) 151 | # no padding (possible crop) 152 | if len(item) >= l: 153 | mask = [1. 
for _ in item[:l]] 154 | result.append(mask) 155 | return torch.tensor(result).float().cuda() 156 | -------------------------------------------------------------------------------- /cedr/extract_docs_from_index.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import argparse 4 | from tqdm import tqdm 5 | 6 | def indri_doc_extractor(path): 7 | import pyndri 8 | index = pyndri.Index(path) 9 | id2token = index.get_dictionary()[1] 10 | def wrapped(docid): 11 | doc_id_tuples = index.document_ids([docid]) 12 | if not doc_id_tuples: 13 | return None # not found 14 | int_docid = doc_id_tuples[0][1] 15 | _, doc_toks = index.document(int_docid) 16 | return ' '.join(id2token[tok] for tok in doc_toks if tok != 0) 17 | return wrapped 18 | 19 | 20 | def lucene_doc_extractor(path): 21 | import jnius_config 22 | if not os.path.exists('bin/anserini.jar'): 23 | sys.stderr.write('missing bin/anserini.jar') 24 | sys.exit(1) 25 | jnius_config.set_classpath("bin/anserini.jar") 26 | from jnius import autoclass 27 | index_utils = autoclass('io.anserini.index.IndexUtils')(path) 28 | def wrapped(docid): 29 | lucene_doc_id = index_utils.convertDocidToLuceneDocid(docid) 30 | if lucene_doc_id == -1: 31 | return None # not found 32 | return index_utils.getTransformedDocument(docid) 33 | return wrapped 34 | 35 | 36 | INDEX_MAP = { 37 | 'indri': indri_doc_extractor, 38 | 'lucene': lucene_doc_extractor 39 | } 40 | 41 | 42 | def main_cli(): 43 | parser = argparse.ArgumentParser('Extract documents from index (stdin: document IDs, ' 44 | 'stdout: datafile, stderr: progress and missing documents)') 45 | parser.add_argument('index_type', choices=INDEX_MAP.keys()) 46 | parser.add_argument('index_path') 47 | args = parser.parse_args() 48 | doc_extractor = INDEX_MAP[args.index_type](args.index_path) 49 | for docid in tqdm(sys.stdin): 50 | docid = docid.rstrip() 51 | doc = doc_extractor(docid) 52 | if doc is None: 53 | tqdm.write(f'[WARN] missing doc id: {docid}') 54 | else: 55 | doc = doc.replace('\t', ' ').replace('\r', ' ').replace('\n', ' ') 56 | sys.stdout.write(f'doc\t{docid}\t{doc}\n') 57 | 58 | 59 | if __name__ == '__main__': 60 | main_cli() 61 | -------------------------------------------------------------------------------- /cedr/modeling.py: -------------------------------------------------------------------------------- 1 | from pytools import memoize_method 2 | import torch 3 | import torch.nn.functional as F 4 | import pytorch_pretrained_bert 5 | from . 
import modeling_util 6 | 7 | 8 | class BertRanker(torch.nn.Module): 9 | def __init__(self): 10 | super().__init__() 11 | self.BERT_MODEL = 'bert-base-uncased' 12 | self.CHANNELS = 12 + 1 # from bert-base-uncased 13 | self.BERT_SIZE = 768 # from bert-base-uncased 14 | self.bert = CustomBertModel.from_pretrained(self.BERT_MODEL) 15 | self.tokenizer = pytorch_pretrained_bert.BertTokenizer.from_pretrained(self.BERT_MODEL) 16 | 17 | def forward(self, **inputs): 18 | raise NotImplementedError 19 | 20 | def save(self, path): 21 | state = self.state_dict(keep_vars=True) 22 | for key in list(state): 23 | if state[key].requires_grad: 24 | state[key] = state[key].data 25 | else: 26 | del state[key] 27 | torch.save(state, path) 28 | 29 | def load(self, path): 30 | self.load_state_dict(torch.load(path), strict=False) 31 | 32 | @memoize_method 33 | def tokenize(self, text): 34 | toks = self.tokenizer.tokenize(text) 35 | toks = [self.tokenizer.vocab[t] for t in toks] 36 | return toks 37 | 38 | def encode_bert(self, query_tok, query_mask, doc_tok, doc_mask): 39 | BATCH, QLEN = query_tok.shape 40 | DIFF = 3 # = [CLS] and 2x[SEP] 41 | maxlen = self.bert.config.max_position_embeddings 42 | MAX_DOC_TOK_LEN = maxlen - QLEN - DIFF 43 | 44 | doc_toks, sbcount = modeling_util.subbatch(doc_tok, MAX_DOC_TOK_LEN) 45 | doc_mask, _ = modeling_util.subbatch(doc_mask, MAX_DOC_TOK_LEN) 46 | 47 | query_toks = torch.cat([query_tok] * sbcount, dim=0) 48 | query_mask = torch.cat([query_mask] * sbcount, dim=0) 49 | 50 | CLSS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[CLS]']) 51 | SEPS = torch.full_like(query_toks[:, :1], self.tokenizer.vocab['[SEP]']) 52 | ONES = torch.ones_like(query_mask[:, :1]) 53 | NILS = torch.zeros_like(query_mask[:, :1]) 54 | 55 | # build BERT input sequences 56 | toks = torch.cat([CLSS, query_toks, SEPS, doc_toks, SEPS], dim=1) 57 | mask = torch.cat([ONES, query_mask, ONES, doc_mask, ONES], dim=1) 58 | segment_ids = torch.cat([NILS] * (2 + QLEN) + [ONES] * (1 + doc_toks.shape[1]), dim=1) 59 | toks[toks == -1] = 0 # remove padding (will be masked anyway) 60 | 61 | # execute BERT model 62 | result = self.bert(toks, segment_ids.long(), mask) 63 | 64 | # extract relevant subsequences for query and doc 65 | query_results = [r[:BATCH, 1:QLEN+1] for r in result] 66 | doc_results = [r[:, QLEN+2:-1] for r in result] 67 | doc_results = [modeling_util.un_subbatch(r, doc_tok, MAX_DOC_TOK_LEN) for r in doc_results] 68 | 69 | # build CLS representation 70 | cls_results = [] 71 | for layer in result: 72 | cls_output = layer[:, 0] 73 | cls_result = [] 74 | for i in range(cls_output.shape[0] // BATCH): 75 | cls_result.append(cls_output[i*BATCH:(i+1)*BATCH]) 76 | cls_result = torch.stack(cls_result, dim=2).mean(dim=2) 77 | cls_results.append(cls_result) 78 | 79 | return cls_results, query_results, doc_results 80 | 81 | 82 | class VanillaBertRanker(BertRanker): 83 | def __init__(self): 84 | super().__init__() 85 | self.dropout = torch.nn.Dropout(0.1) 86 | self.cls = torch.nn.Linear(self.BERT_SIZE, 1) 87 | 88 | def forward(self, query_tok, query_mask, doc_tok, doc_mask): 89 | cls_reps, _, _ = self.encode_bert(query_tok, query_mask, doc_tok, doc_mask) 90 | return self.cls(self.dropout(cls_reps[-1])) 91 | 92 | 93 | class CedrPacrrRanker(BertRanker): 94 | def __init__(self): 95 | super().__init__() 96 | QLEN = 20 97 | KMAX = 2 98 | NFILTERS = 32 99 | MINGRAM = 1 100 | MAXGRAM = 3 101 | self.simmat = modeling_util.SimmatModule() 102 | self.ngrams = torch.nn.ModuleList() 103 | self.rbf_bank = None 104 | for 
ng in range(MINGRAM, MAXGRAM+1): 105 | ng = modeling_util.PACRRConvMax2dModule(ng, NFILTERS, k=KMAX, channels=self.CHANNELS) 106 | self.ngrams.append(ng) 107 | qvalue_size = len(self.ngrams) * KMAX 108 | self.linear1 = torch.nn.Linear(self.BERT_SIZE + QLEN * qvalue_size, 32) 109 | self.linear2 = torch.nn.Linear(32, 32) 110 | self.linear3 = torch.nn.Linear(32, 1) 111 | 112 | def forward(self, query_tok, query_mask, doc_tok, doc_mask): 113 | cls_reps, query_reps, doc_reps = self.encode_bert(query_tok, query_mask, doc_tok, doc_mask) 114 | simmat = self.simmat(query_reps, doc_reps, query_tok, doc_tok) 115 | scores = [ng(simmat) for ng in self.ngrams] 116 | scores = torch.cat(scores, dim=2) 117 | scores = scores.reshape(scores.shape[0], scores.shape[1] * scores.shape[2]) 118 | scores = torch.cat([scores, cls_reps[-1]], dim=1) 119 | rel = F.relu(self.linear1(scores)) 120 | rel = F.relu(self.linear2(rel)) 121 | rel = self.linear3(rel) 122 | return rel 123 | 124 | 125 | class CedrKnrmRanker(BertRanker): 126 | def __init__(self): 127 | super().__init__() 128 | MUS = [-0.9, -0.7, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0] 129 | SIGMAS = [0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.001] 130 | self.bert_ranker = VanillaBertRanker() 131 | self.simmat = modeling_util.SimmatModule() 132 | self.kernels = modeling_util.KNRMRbfKernelBank(MUS, SIGMAS) 133 | self.combine = torch.nn.Linear(self.kernels.count() * self.CHANNELS + self.BERT_SIZE, 1) 134 | 135 | def forward(self, query_tok, query_mask, doc_tok, doc_mask): 136 | cls_reps, query_reps, doc_reps = self.encode_bert(query_tok, query_mask, doc_tok, doc_mask) 137 | simmat = self.simmat(query_reps, doc_reps, query_tok, doc_tok) 138 | kernels = self.kernels(simmat) 139 | BATCH, KERNELS, VIEWS, QLEN, DLEN = kernels.shape 140 | kernels = kernels.reshape(BATCH, KERNELS * VIEWS, QLEN, DLEN) 141 | simmat = simmat.reshape(BATCH, 1, VIEWS, QLEN, DLEN) \ 142 | .expand(BATCH, KERNELS, VIEWS, QLEN, DLEN) \ 143 | .reshape(BATCH, KERNELS * VIEWS, QLEN, DLEN) 144 | result = kernels.sum(dim=3) # sum over document 145 | mask = (simmat.sum(dim=3) != 0.) # which query terms are not padding? 
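# (descriptive note on the step below: the log of the per-term kernel sums is
# taken only where mask is True; padded query positions have mask == False and
# contribute mask.float() == 0 instead, so they add nothing to the query sum)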
146 | result = torch.where(mask, (result + 1e-6).log(), mask.float()) 147 | result = result.sum(dim=2) # sum over query terms 148 | result = torch.cat([result, cls_reps[-1]], dim=1) 149 | scores = self.combine(result) # linear combination over kernels 150 | return scores 151 | 152 | 153 | class CedrDrmmRanker(BertRanker): 154 | def __init__(self): 155 | super().__init__() 156 | NBINS = 11 157 | HIDDEN = 5 158 | self.bert_ranker = VanillaBertRanker() 159 | self.simmat = modeling_util.SimmatModule() 160 | self.histogram = modeling_util.DRMMLogCountHistogram(NBINS) 161 | self.hidden_1 = torch.nn.Linear(NBINS * self.CHANNELS + self.BERT_SIZE, HIDDEN) 162 | self.hidden_2 = torch.nn.Linear(HIDDEN, 1) 163 | 164 | def forward(self, query_tok, query_mask, doc_tok, doc_mask): 165 | cls_reps, query_reps, doc_reps = self.encode_bert(query_tok, query_mask, doc_tok, doc_mask) 166 | simmat = self.simmat(query_reps, doc_reps, query_tok, doc_tok) 167 | histogram = self.histogram(simmat, doc_tok, query_tok) 168 | BATCH, CHANNELS, QLEN, BINS = histogram.shape 169 | histogram = histogram.permute(0, 2, 3, 1) 170 | output = histogram.reshape(BATCH * QLEN, BINS * CHANNELS) 171 | # repeat cls representation for each query token 172 | cls_rep = cls_reps[-1].reshape(BATCH, 1, -1).expand(BATCH, QLEN, -1).reshape(BATCH * QLEN, -1) 173 | output = torch.cat([output, cls_rep], dim=1) 174 | term_scores = self.hidden_2(torch.relu(self.hidden_1(output))).reshape(BATCH, QLEN) 175 | return term_scores.sum(dim=1) 176 | 177 | 178 | class CustomBertModel(pytorch_pretrained_bert.BertModel): 179 | """ 180 | Based on pytorch_pretrained_bert.BertModel, but also outputs un-contextualized embeddings. 181 | """ 182 | def forward(self, input_ids, token_type_ids, attention_mask): 183 | """ 184 | Based on pytorch_pretrained_bert.BertModel 185 | """ 186 | embedding_output = self.embeddings(input_ids, token_type_ids) 187 | 188 | extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) 189 | extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility 190 | extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 191 | 192 | encoded_layers = self.encoder(embedding_output, extended_attention_mask, output_all_encoded_layers=True) 193 | 194 | return [embedding_output] + encoded_layers 195 | -------------------------------------------------------------------------------- /cedr/modeling_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | 4 | 5 | def subbatch(toks, maxlen): 6 | _, DLEN = toks.shape[:2] 7 | SUBBATCH = math.ceil(DLEN / maxlen) 8 | S = math.ceil(DLEN / SUBBATCH) if SUBBATCH > 0 else 0 # minimize the size given the number of subbatch 9 | stack = [] 10 | if SUBBATCH == 1: 11 | return toks, SUBBATCH 12 | else: 13 | for s in range(SUBBATCH): 14 | stack.append(toks[:, s*S:(s+1)*S]) 15 | if stack[-1].shape[1] != S: 16 | nulls = torch.zeros_like(toks[:, :S - stack[-1].shape[1]]) 17 | stack[-1] = torch.cat([stack[-1], nulls], dim=1) 18 | return torch.cat(stack, dim=0), SUBBATCH 19 | 20 | 21 | def un_subbatch(embed, toks, maxlen): 22 | BATCH, DLEN = toks.shape[:2] 23 | SUBBATCH = math.ceil(DLEN / maxlen) 24 | if SUBBATCH == 1: 25 | return embed 26 | else: 27 | embed_stack = [] 28 | for b in range(SUBBATCH): 29 | embed_stack.append(embed[b*BATCH:(b+1)*BATCH]) 30 | embed = torch.cat(embed_stack, dim=1) 31 | embed = embed[:, :DLEN] 32 | return embed 33 | 34 | 35 | class 
PACRRConvMax2dModule(torch.nn.Module): 36 | 37 | def __init__(self, shape, n_filters, k, channels): 38 | super().__init__() 39 | self.shape = shape 40 | if shape != 1: 41 | self.pad = torch.nn.ConstantPad2d((0, shape-1, 0, shape-1), 0) 42 | else: 43 | self.pad = None 44 | self.conv = torch.nn.Conv2d(channels, n_filters, shape) 45 | self.activation = torch.nn.ReLU() 46 | self.k = k 47 | self.shape = shape 48 | self.channels = channels 49 | 50 | def forward(self, simmat): 51 | BATCH, CHANNELS, QLEN, DLEN = simmat.shape 52 | if self.pad: 53 | simmat = self.pad(simmat) 54 | conv = self.activation(self.conv(simmat)) 55 | top_filters, _ = conv.max(dim=1) 56 | top_toks, _ = top_filters.topk(self.k, dim=2) 57 | result = top_toks.reshape(BATCH, QLEN, self.k) 58 | return result 59 | 60 | 61 | class SimmatModule(torch.nn.Module): 62 | 63 | def __init__(self, padding=-1): 64 | super().__init__() 65 | self.padding = padding 66 | self._hamming_index_loaded = None 67 | self._hamming_index = None 68 | 69 | def forward(self, query_embed, doc_embed, query_tok, doc_tok): 70 | simmat = [] 71 | 72 | for a_emb, b_emb in zip(query_embed, doc_embed): 73 | BAT, A, B = a_emb.shape[0], a_emb.shape[1], b_emb.shape[1] 74 | # embeddings -- cosine similarity matrix 75 | a_denom = a_emb.norm(p=2, dim=2).reshape(BAT, A, 1).expand(BAT, A, B) + 1e-9 # avoid 0div 76 | b_denom = b_emb.norm(p=2, dim=2).reshape(BAT, 1, B).expand(BAT, A, B) + 1e-9 # avoid 0div 77 | perm = b_emb.permute(0, 2, 1) 78 | sim = a_emb.bmm(perm) 79 | sim = sim / (a_denom * b_denom) 80 | 81 | # nullify padding (indicated by -1 by default) 82 | nul = torch.zeros_like(sim) 83 | sim = torch.where(query_tok.reshape(BAT, A, 1).expand(BAT, A, B) == self.padding, nul, sim) 84 | sim = torch.where(doc_tok.reshape(BAT, 1, B).expand(BAT, A, B) == self.padding, nul, sim) 85 | 86 | simmat.append(sim) 87 | return torch.stack(simmat, dim=1) 88 | 89 | 90 | class DRMMLogCountHistogram(torch.nn.Module): 91 | def __init__(self, bins): 92 | super().__init__() 93 | self.bins = bins 94 | 95 | def forward(self, simmat, dtoks, qtoks): 96 | # THIS IS SLOW ... Any way to make this faster? Maybe it's not worth doing on GPU? 97 | BATCH, CHANNELS, QLEN, DLEN = simmat.shape 98 | # +1e-5 to nudge scores of 1 to above threshold 99 | bins = ((simmat + 1.000001) / 2. * (self.bins - 1)).int() 100 | # set weights of 0 for padding (in both query and doc dims) 101 | weights = ((dtoks != -1).reshape(BATCH, 1, DLEN).expand(BATCH, QLEN, DLEN) * \ 102 | (qtoks != -1).reshape(BATCH, QLEN, 1).expand(BATCH, QLEN, DLEN)).float() 103 | 104 | # no way to batch this... loses gradients here. 
https://discuss.pytorch.org/t/histogram-function-in-pytorch/5350 105 | bins, weights = bins.cpu(), weights.cpu() 106 | histogram = [] 107 | for superbins, w in zip(bins, weights): 108 | result = [] 109 | for b in superbins: 110 | result.append(torch.stack([torch.bincount(q, x, self.bins) for q, x in zip(b, w)], dim=0)) 111 | result = torch.stack(result, dim=0) 112 | histogram.append(result) 113 | histogram = torch.stack(histogram, dim=0) 114 | 115 | # back to GPU 116 | histogram = histogram.to(simmat.device) 117 | return (histogram.float() + 1e-5).log() 118 | 119 | 120 | class KNRMRbfKernelBank(torch.nn.Module): 121 | def __init__(self, mus=None, sigmas=None, dim=1, requires_grad=True): 122 | super().__init__() 123 | self.dim = dim 124 | kernels = [KNRMRbfKernel(m, s, requires_grad=requires_grad) for m, s in zip(mus, sigmas)] 125 | self.kernels = torch.nn.ModuleList(kernels) 126 | 127 | def count(self): 128 | return len(self.kernels) 129 | 130 | def forward(self, data): 131 | return torch.stack([k(data) for k in self.kernels], dim=self.dim) 132 | 133 | 134 | class KNRMRbfKernel(torch.nn.Module): 135 | def __init__(self, initial_mu, initial_sigma, requires_grad=True): 136 | super().__init__() 137 | self.mu = torch.nn.Parameter(torch.tensor(initial_mu), requires_grad=requires_grad) 138 | self.sigma = torch.nn.Parameter(torch.tensor(initial_sigma), requires_grad=requires_grad) 139 | 140 | def forward(self, data): 141 | adj = data - self.mu 142 | return torch.exp(-0.5 * adj * adj / self.sigma / self.sigma) 143 | -------------------------------------------------------------------------------- /cedr/rerank.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from . import train 3 | from . import data 4 | 5 | 6 | def main_cli(): 7 | parser = argparse.ArgumentParser('CEDR model re-ranking') 8 | parser.add_argument('--model', choices=train.MODEL_MAP.keys(), default='vanilla_bert') 9 | parser.add_argument('--datafiles', type=argparse.FileType('rt'), nargs='+') 10 | parser.add_argument('--run', type=argparse.FileType('rt')) 11 | parser.add_argument('--model_weights', type=argparse.FileType('rb')) 12 | parser.add_argument('--out_path', type=argparse.FileType('wt')) 13 | args = parser.parse_args() 14 | model = train.MODEL_MAP[args.model]().cuda() 15 | dataset = data.read_datafiles(args.datafiles) 16 | run = data.read_run_dict(args.run) 17 | if args.model_weights is not None: 18 | model.load(args.model_weights.name) 19 | train.write_run(train.run_model(model, dataset, run, desc='rerank'), args.out_path.name) 20 | 21 | 22 | if __name__ == '__main__': 23 | main_cli() 24 | -------------------------------------------------------------------------------- /cedr/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import subprocess 4 | import random 5 | import tempfile 6 | from tqdm import tqdm 7 | import torch 8 | from . import modeling 9 | from .
import data 10 | import pytrec_eval 11 | from statistics import mean 12 | from collections import defaultdict 13 | 14 | 15 | 16 | SEED = 42 17 | LR = 0.001 18 | BERT_LR = 2e-5 19 | MAX_EPOCH = 100 20 | BATCH_SIZE = 16 21 | BATCHES_PER_EPOCH = 32 22 | GRAD_ACC_SIZE = 2 23 | # other possibilities: ndcg 24 | VALIDATION_METRIC = 'P_20' 25 | PATIENCE = 20 # how many epochs to wait for validation improvement 26 | 27 | torch.manual_seed(SEED) 28 | torch.cuda.manual_seed_all(SEED) 29 | random.seed(SEED) 30 | 31 | 32 | MODEL_MAP = { 33 | 'vanilla_bert': modeling.VanillaBertRanker, 34 | 'cedr_pacrr': modeling.CedrPacrrRanker, 35 | 'cedr_knrm': modeling.CedrKnrmRanker, 36 | 'cedr_drmm': modeling.CedrDrmmRanker 37 | } 38 | 39 | 40 | def main(model, dataset, train_pairs, qrels_train, valid_run, qrels_valid, model_out_dir=None): 41 | ''' 42 | Runs the training loop, controlled by the constants above. 43 | Args: 44 | model(torch.nn.Module or str): One of the models in modeling.py, 45 | or one of the keys of MODEL_MAP. 46 | dataset: A tuple containing two dictionaries, which contain the 47 | text of queries and documents in both training and validation sets: 48 | ({"q1" : "query text 1"}, {"d1" : "doc text 1"}) 49 | train_pairs: A dictionary containing query-document mappings for the training set 50 | (i.e., documents to generate pairs from). E.g.: 51 | {"q1" : ["d1", "d2", "d3"]} 52 | qrels_train(dict): A dictionary containing training qrels. Scores > 0 are considered 53 | relevant. Missing scores are considered non-relevant. E.g.: 54 | {"q1" : {"d1" : 2, "d2" : 0}} 55 | If you want to generate pairs from qrels, you can pass in the same object for qrels_train and train_pairs 56 | valid_run: Query-document mappings for the validation set, in the same format as train_pairs. 57 | qrels_valid: A dictionary containing validation qrels, in the same format as qrels_train. 58 | model_out_dir: Location where to write the models. If None, a temporary directory is used.
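Returns:
    A tuple (model, top_valid_score_epoch): the model with the weights from the
    best validation epoch loaded, and the epoch at which that score was reached.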
59 | ''' 60 | 61 | if isinstance(model,str): 62 | model = MODEL_MAP[model]().cuda() 63 | if model_out_dir is None: 64 | model_out_dir = tempfile.mkdtemp() 65 | 66 | params = [(k, v) for k, v in model.named_parameters() if v.requires_grad] 67 | non_bert_params = {'params': [v for k, v in params if not k.startswith('bert.')]} 68 | bert_params = {'params': [v for k, v in params if k.startswith('bert.')], 'lr': BERT_LR} 69 | optimizer = torch.optim.Adam([non_bert_params, bert_params], lr=LR) 70 | 71 | epoch = 0 72 | top_valid_score = None 73 | print(f'Starting training, upto {MAX_EPOCH} epochs, patience {PATIENCE} LR={LR} BERT_LR={BERT_LR}', flush=True) 74 | for epoch in range(MAX_EPOCH): 75 | 76 | loss = train_iteration(model, optimizer, dataset, train_pairs, qrels_train) 77 | print(f'train epoch={epoch} loss={loss}') 78 | 79 | valid_score = validate(model, dataset, valid_run, qrels_valid, epoch) 80 | print(f'validation epoch={epoch} score={valid_score}') 81 | 82 | if top_valid_score is None or valid_score > top_valid_score: 83 | top_valid_score = valid_score 84 | print('new top validation score, saving weights', flush=True) 85 | model.save(os.path.join(model_out_dir, 'weights.p')) 86 | top_valid_score_epoch = epoch 87 | if top_valid_score is not None and epoch - top_valid_score_epoch > PATIENCE: 88 | print(f'no validation improvement since {top_valid_score_epoch}, early stopping', flush=True) 89 | break 90 | 91 | #load the final selected model for returning 92 | if top_valid_score_epoch != epoch: 93 | model.load(os.path.join(model_out_dir, 'weights.p')) 94 | return (model, top_valid_score_epoch) 95 | 96 | 97 | def train_iteration(model, optimizer, dataset, train_pairs, qrels): 98 | 99 | total = 0 100 | model.train() 101 | total_loss = 0. 102 | with tqdm('training', total=BATCH_SIZE * BATCHES_PER_EPOCH, ncols=80, desc='train', leave=False) as pbar: 103 | for record in data.iter_train_pairs(model, dataset, train_pairs, qrels, GRAD_ACC_SIZE): 104 | scores = model(record['query_tok'], 105 | record['query_mask'], 106 | record['doc_tok'], 107 | record['doc_mask']) 108 | count = len(record['query_id']) // 2 109 | scores = scores.reshape(count, 2) 110 | loss = torch.mean(1. 
- scores.softmax(dim=1)[:, 0]) # pairwise softmax loss 111 | loss.backward() 112 | total_loss += loss.item() 113 | total += count 114 | if total % BATCH_SIZE == 0: 115 | optimizer.step() 116 | optimizer.zero_grad() 117 | pbar.update(count) 118 | if total >= BATCH_SIZE * BATCHES_PER_EPOCH: 119 | return total_loss 120 | 121 | 122 | def validate(model, dataset, run, valid_qrels, epoch): 123 | run_scores = run_model(model, dataset, run) 124 | metric = VALIDATION_METRIC 125 | if metric.startswith("P_"): 126 | metric = "P" 127 | trec_eval = pytrec_eval.RelevanceEvaluator(valid_qrels, {metric}) 128 | eval_scores = trec_eval.evaluate(run_scores) 129 | print(eval_scores) 130 | return mean([d[VALIDATION_METRIC] for d in eval_scores.values()]) 131 | 132 | 133 | def run_model(model, dataset, run, desc='valid'): 134 | rerank_run = defaultdict(dict) 135 | with torch.no_grad(), tqdm(total=sum(len(r) for r in run.values()), ncols=80, desc=desc, leave=False) as pbar: 136 | model.eval() 137 | for records in data.iter_valid_records(model, dataset, run, BATCH_SIZE): 138 | scores = model(records['query_tok'], 139 | records['query_mask'], 140 | records['doc_tok'], 141 | records['doc_mask']) 142 | for qid, did, score in zip(records['query_id'], records['doc_id'], scores): 143 | rerank_run[qid][did] = score.item() 144 | pbar.update(len(records['query_id'])) 145 | return rerank_run 146 | 147 | 148 | def write_run(rerank_run, runf): 149 | ''' 150 | Utility method to write a run file to disk (used by rerank.py) 151 | ''' 152 | with open(runf, 'wt') as runfile: 153 | for qid in rerank_run: 154 | scores = list(sorted(rerank_run[qid].items(), key=lambda x: (x[1], x[0]), reverse=True)) 155 | for i, (did, score) in enumerate(scores): 156 | runfile.write(f'{qid} 0 {did} {i+1} {score} run\n') 157 | 158 | def main_cli(): 159 | parser = argparse.ArgumentParser('CEDR model training and validation') 160 | parser.add_argument('--model', choices=MODEL_MAP.keys(), default='vanilla_bert') 161 | parser.add_argument('--datafiles', type=argparse.FileType('rt'), nargs='+') 162 | parser.add_argument('--qrels', type=argparse.FileType('rt')) 163 | parser.add_argument('--train_pairs', type=argparse.FileType('rt')) 164 | parser.add_argument('--valid_run', type=argparse.FileType('rt')) 165 | parser.add_argument('--initial_bert_weights', type=argparse.FileType('rb')) 166 | parser.add_argument('--model_out_dir') 167 | args = parser.parse_args() 168 | model = MODEL_MAP[args.model]().cuda() 169 | dataset = data.read_datafiles(args.datafiles) 170 | qrels = data.read_qrels_dict(args.qrels) 171 | train_pairs = data.read_pairs_dict(args.train_pairs) 172 | valid_run = data.read_run_dict(args.valid_run) 173 | 174 | if args.initial_bert_weights is not None: 175 | model.load(args.initial_bert_weights.name) 176 | os.makedirs(args.model_out_dir, exist_ok=True) 177 | # we use the same qrels object for both training and validation sets 178 | main(model, dataset, train_pairs, qrels, valid_run, qrels, args.model_out_dir) 179 | 180 | 181 | if __name__ == '__main__': 182 | main_cli() 183 | -------------------------------------------------------------------------------- /data/robust/queries.tsv: -------------------------------------------------------------------------------- 1 | query 301 International Organized Crime 2 | query 302 Poliomyelitis and Post-Polio 3 | query 303 Hubble Telescope Achievements 4 | query 304 Endangered Species (Mammals) 5 | query 305 Most Dangerous Vehicles 6 | query 306 African Civilian Deaths 7 | query 307 New Hydroelectric Projects 8 | query 308
Implant Dentistry
9 | query 309 Rap and Crime
10 | query 310 Radio Waves and Brain Cancer
11 | query 311 Industrial Espionage
12 | query 312 Hydroponics
13 | query 313 Magnetic Levitation-Maglev
14 | query 314 Marine Vegetation
15 | query 315 Unexplained Highway Accidents
16 | query 316 Polygamy Polyandry Polygyny
17 | query 317 Unsolicited Faxes
18 | query 318 Best Retirement Country
19 | query 319 New Fuel Sources
20 | query 320 Undersea Fiber Optic Cable
21 | query 321 Women in Parliaments
22 | query 322 International Art Crime
23 | query 323 Literary/Journalistic Plagiarism
24 | query 324 Argentine/British Relations
25 | query 325 Cult Lifestyles
26 | query 326 Ferry Sinkings
27 | query 327 Modern Slavery
28 | query 328 Pope Beatifications
29 | query 329 Mexican Air Pollution
30 | query 330 Iran-Iraq Cooperation
31 | query 331 World Bank Criticism
32 | query 332 Income Tax Evasion
33 | query 333 Antibiotics Bacteria Disease
34 | query 334 Export Controls Cryptography
35 | query 335 Adoptive Biological Parents
36 | query 336 Black Bear Attacks
37 | query 337 Viral Hepatitis
38 | query 338 Risk of Aspirin
39 | query 339 Alzheimer's Drug Treatment
40 | query 340 Land Mine Ban
41 | query 341 Airport Security
42 | query 342 Diplomatic Expulsion
43 | query 343 Police Deaths
44 | query 344 Abuses of E-Mail
45 | query 345 Overseas Tobacco Sales
46 | query 346 Educational Standards
47 | query 347 Wildlife Extinction
48 | query 348 Agoraphobia
49 | query 349 Metabolism
50 | query 350 Health and Computer Terminals
51 | query 351 Falkland petroleum exploration
52 | query 352 British Chunnel impact
53 | query 353 Antarctica exploration
54 | query 354 journalist risks
55 | query 355 ocean remote sensing
56 | query 356 postmenopausal estrogen Britain
57 | query 357 territorial waters dispute
58 | query 358 blood-alcohol fatalities
59 | query 359 mutual fund predictors
60 | query 360 drug legalization benefits
61 | query 361 clothing sweatshops
62 | query 362 human smuggling
63 | query 363 transportation tunnel disasters
64 | query 364 rabies
65 | query 365 El Nino
66 | query 366 commercial cyanide uses
67 | query 367 piracy
68 | query 368 in vitro fertilization
69 | query 369 anorexia nervosa bulimia
70 | query 370 food/drug laws
71 | query 371 health insurance holistic
72 | query 372 Native American casino
73 | query 373 encryption equipment export
74 | query 374 Nobel prize winners
75 | query 375 hydrogen energy
76 | query 376 World Court
77 | query 377 cigar smoking
78 | query 378 euro opposition
79 | query 379 mainstreaming
80 | query 380 obesity medical treatment
81 | query 381 alternative medicine
82 | query 382 hydrogen fuel automobiles
83 | query 383 mental illness drugs
84 | query 384 space station moon
85 | query 385 hybrid fuel cars
86 | query 386 teaching disabled children
87 | query 387 radioactive waste
88 | query 388 organic soil enhancement
89 | query 389 illegal technology transfer
90 | query 390 orphan drugs
91 | query 391 R&D drug prices
92 | query 392 robotics
93 | query 393 mercy killing
94 | query 394 home schooling
95 | query 395 tourism
96 | query 396 sick building syndrome
97 | query 397 automobile recalls
98 | query 398 dismantling Europe's arsenal
99 | query 399 oceanographic vessels
100 | query 400 Amazon rain forest
101 | query 401 foreign minorities, Germany
102 | query 402 behavioral genetics
103 | query 403 osteoporosis
104 | query 404 Ireland, peace talks
105 | query 405 cosmic events
106 | query 406 Parkinson's disease
107 | query 407 poaching, wildlife preserves
108 | query 408 tropical storms
109 | query 409 legal, Pan Am, 103
110 | query 410 Schengen agreement
111 | query 411 salvaging, shipwreck, treasure
112 | query 412 airport security
113 | query 413 steel production
114 | query 414 Cuba, sugar, exports
115 | query 415 drugs, Golden Triangle
116 | query 416 Three Gorges Project
117 | query 417 creativity
118 | query 418 quilts, income
119 | query 419 recycle, automobile tires
120 | query 420 carbon monoxide poisoning
121 | query 421 industrial waste disposal
122 | query 422 art, stolen, forged
123 | query 423 Milosevic, Mirjana Markovic
124 | query 424 suicides
125 | query 425 counterfeiting money
126 | query 426 law enforcement, dogs
127 | query 427 UV damage, eyes
128 | query 428 declining birth rates
129 | query 429 Legionnaires' disease
130 | query 430 killer bee attacks
131 | query 431 robotic technology
132 | query 432 profiling, motorists, police
133 | query 433 Greek, philosophy, stoicism
134 | query 434 Estonia, economy
135 | query 435 curbing population growth
136 | query 436 railway accidents
137 | query 437 deregulation, gas, electric
138 | query 438 tourism, increase
139 | query 439 inventions, scientific discoveries
140 | query 440 child labor
141 | query 441 Lyme disease
142 | query 442 heroic acts
143 | query 443 U.S., investment, Africa
144 | query 444 supercritical fluids
145 | query 445 women clergy
146 | query 446 tourists, violence
147 | query 447 Stirling engine
148 | query 448 ship losses
149 | query 449 antibiotics ineffectiveness
150 | query 450 King Hussein, peace
151 | query 601 Turkey Iraq water
152 | query 602 Czech, Slovak sovereignty
153 | query 603 Tobacco cigarette lawsuit
154 | query 604 Lyme disease arthritis
155 | query 605 Great Britain health care
156 | query 606 leg traps ban
157 | query 607 human genetic code
158 | query 608 taxing social security
159 | query 609 per capita alcohol consumption
160 | query 610 minimum wage adverse impact
161 | query 611 Kurds Germany violence
162 | query 612 Tibet protesters
163 | query 613 Berlin wall disposal
164 | query 614 Flavr Savr tomato
165 | query 615 timber exports Asia
166 | query 616 Volkswagen Mexico
167 | query 617 Russia Cuba economy
168 | query 618 Ayatollah Khomeini death
169 | query 619 Winnie Mandela scandal
170 | query 620 France nuclear testing
171 | query 621 women ordained Church of England
172 | query 622 price fixing
173 | query 623 toxic chemical weapon
174 | query 624 SDI Star Wars
175 | query 625 arrests bombing WTC
176 | query 626 human stampede
177 | query 627 Russian food crisis
178 | query 628 U.S. invasion of Panama
179 | query 629 abortion clinic attack
180 | query 630 Gulf War Syndrome
181 | query 631 Mandela South Africa President
182 | query 632 southeast Asia tin mining
183 | query 633 Welsh devolution
184 | query 634 L-tryptophan deaths
185 | query 635 doctor assisted suicides
186 | query 636 jury duty exemptions
187 | query 637 human growth hormone (HGH)
188 | query 638 wrongful convictions
189 | query 639 consumer on-line shopping
190 | query 640 maternity leave policies
191 | query 641 Valdez wildlife marine life
192 | query 642 Tiananmen Square protesters
193 | query 643 salmon dams Pacific northwest
194 | query 644 exotic animals import
195 | query 645 software piracy
196 | query 646 food stamps increase
197 | query 647 windmill electricity
198 | query 648 family leave law
199 | query 649 computer viruses
200 | query 650 tax evasion indicted
201 | query 651 U.S. ethnic population
202 | query 652 OIC Balkans 1990s
203 | query 653 ETA Basque terrorism
204 | query 654 same-sex schools
205 | query 655 ADD diagnosis treatment
206 | query 656 lead poisoning children
207 | query 657 school prayer banned
208 | query 658 teenage pregnancy
209 | query 659 cruise health safety
210 | query 660 whale watching California
211 | query 661 melanoma treatment causes
212 | query 662 telemarketer protection
213 | query 663 Agent Orange exposure
214 | query 664 American Indian Museum
215 | query 665 poverty Africa sub-Sahara
216 | query 666 Thatcher resignation impact
217 | query 667 unmarried-partner households
218 | query 668 poverty, disease
219 | query 669 Islamic Revolution
220 | query 670 U.S. elections apathy
221 | query 671 Salvation Army benefits
222 | query 672 NRA membership profile
223 | query 673 Soviet withdrawal Afghanistan
224 | query 674 Greenpeace prosecuted
225 | query 675 Olympics training swimming
226 | query 676 poppy cultivation
227 | query 677 Leaning Tower of Pisa
228 | query 678 joint custody impact
229 | query 679 opening adoption records
230 | query 680 immigrants Spanish school
231 | query 681 wind power location
232 | query 682 adult immigrants English
233 | query 683 Czechoslovakia breakup
234 | query 684 part-time benefits
235 | query 685 Oscar winner selection
236 | query 686 Argentina pegging dollar
237 | query 687 Northern Ireland industry
238 | query 688 non-U.S. media bias
239 | query 689 family-planning aid
240 | query 690 college education advantage
241 | query 691 clear-cutting forests
242 | query 692 prostate cancer detection treatment
243 | query 693 newspapers electronic media
244 | query 694 compost pile
245 | query 695 white collar crime sentence
246 | query 696 safety plastic surgery
247 | query 697 air traffic controller
248 | query 698 literacy rates Africa
249 | query 699 term limits
250 | query 700 gasoline tax U.S.
251 |
--------------------------------------------------------------------------------
/data/wt/queries.tsv:
--------------------------------------------------------------------------------
1 | query 1 obama family tree
2 | query 2 french lick resort casino
3 | query 3 organized
4 | query 4 toilet
5 | query 5 mitchell college
6 | query 6 kcs
7 | query 7 air travel
8 | query 8 appraisals
9 | query 9 car parts
10 | query 10 cheap internet
11 | query 11 gmat prep classes
12 | query 12 djs
13 | query 13 map
14 | query 14 dinosaurs
15 | query 15 espn sports
16 | query 16 arizona game fish
17 | query 17 poker tournaments
18 | query 18 wedding budget calculator
19 | query 19 current
20 | query 20 defender
21 | query 21 volvo
22 | query 22 rick warren
23 | query 23 yahoo
24 | query 24 diversity
25 | query 25 euclid
26 | query 26 lower heart rate
27 | query 27 starbucks
28 | query 28 inuyasha
29 | query 29 ps 2 games
30 | query 30 diabetes education
31 | query 31 atari
32 | query 32 website design hosting
33 | query 33 elliptical trainer
34 | query 34 cell phones
35 | query 35 hoboken
36 | query 36 gps
37 | query 37 pampered chef
38 | query 38 dogs adoption
39 | query 39 disneyland hotel
40 | query 40 michworks
41 | query 41 orange county convention center
42 | query 42 music man
43 | query 43 secret garden
44 | query 44 map united states
45 | query 45 solar panels
46 | query 46 alexian brothers hospital
47 | query 47 indexed annuity
48 | query 48 wilson antenna
49 | query 49 flame designs
50 | query 50 dog heat
51 | query 51 horse hooves
52 | query 52 avp
53 | query 53 discovery channel store
54 | query 54 president united states
55 | query 55 iron
56 | query 56 uss yorktown charleston sc
57 | query 57 ct jobs
58 | query 58 penguins
59 | query 59 build fence
60 | query 60 bellevue
61 | query 61 worm
62 | query 62 texas border patrol
63 | query 63 flushing
64 | query 64 moths
65 | query 65 korean language
66 | query 66 income tax return online
67 | query 67 vldl levels
68 | query 68 pvc
69 | query 69 sewing instructions
70 | query 70 question
71 | query 71 living india
72 | query 72 sun
73 | query 73 neil young
74 | query 74 kiwi
75 | query 75 tornadoes
76 | query 76 raised gardens
77 | query 77 bobcat
78 | query 78 dieting
79 | query 79 voyager
80 | query 80 keyboard reviews
81 | query 81 afghanistan
82 | query 82 joints
83 | query 83 memory
84 | query 84 continental plates
85 | query 85 milwaukee journal sentinel
86 | query 86 bart sf
87 | query 87 who invented music
88 | query 88 forearm pain
89 | query 89 ocd
90 | query 90 mgb
91 | query 91 er tv show
92 | query 92 wall
93 | query 93 raffles
94 | query 94 titan
95 | query 95 earn money home
96 | query 96 rice
97 | query 97 south africa
98 | query 98 sat
99 | query 99 satellite
100 | query 100 rincon puerto rico
101 | query 101 ritz carlton lake las vegas
102 | query 102 fickle creek farm
103 | query 103 madam cj walker
104 | query 104 indiana child support
105 | query 105 sonoma county medical services
106 | query 106 universal animal cuts reviews
107 | query 107 cass county missouri
108 | query 108 ralph owen brewster
109 | query 109 mayo clinic jacksonville fl
110 | query 110 map brazil
111 | query 111 lymphoma dogs
112 | query 112 kenmore gas water heater
113 | query 113 hp mini 2140
114 | query 114 adobe indian houses
115 | query 115 pacific northwest laboratory
116 | query 116 california franchise tax board
117 | query 117 dangers asbestos
118 | query 118 poem pocket day
119 | query 119 interview thank
120 | query 120 tv computer
121 | query 121 sit reach test
122 | query 122 culpeper national cemetery
123 | query 123 von willebrand disease
124 | query 124 bowflex power
125 | query 125 butter margarine
126 | query 126 capitol map
127 | query 127 dutchess county tourism
128 | query 128 atypical squamous cells
129 | query 129 iowa food stamp program
130 | query 130 uranus
131 | query 131 equal opportunity employer
132 | query 132 mothers day songs
133 | query 133 created equal
134 | query 134 electronic skeet shoot
135 | query 135 source nile
136 | query 136 american military university
137 | query 137 rock gem shows
138 | query 138 jax chemical company
139 | query 139 rocky mountain news
140 | query 140 east ridge high school
141 | query 141 va dmv registration
142 | query 142 illinois state tax
143 | query 143 arkadelphia health club
144 | query 144 trombone sale
145 | query 145 vines shade
146 | query 146 sherwood regional library
147 | query 147 tangible personal property tax
148 | query 148 martha stewart imclone
149 | query 149 uplift yellowstone national park
150 | query 150 tn highway patrol
151 | query 151 403b
152 | query 152 angular cheilitis
153 | query 153 pocono
154 | query 154 figs
155 | query 155 last supper painting
156 | query 156 university phoenix
157 | query 157 beatles rock band
158 | query 158 septic system design
159 | query 159 porterville
160 | query 160 grilling
161 | query 161 furniture small spaces
162 | query 162 dnr
163 | query 163 arkansas
164 | query 164 hobby stores
165 | query 165 blue throated hummingbird
166 | query 166 computer programming
167 | query 167 barbados
168 | query 168 lipoma
169 | query 169 battles civil war
170 | query 170 scooters
171 | query 171 ron howard
172 | query 172 paralegal
173 | query 173 hip fractures
174 | query 174 rock art
175 | query 175 signs heartattack
176 | query 176 weather strip
177 | query 177 term care insurance
178 | query 178 pork tenderloin
179 | query 179 black history
180 | query 180 newyork hotels
181 | query 181 old coins
182 | query 182 quit smoking
183 | query 183 kansas city mo
184 | query 184 civil right movement
185 | query 185 credit report
186 | query 186 unc
187 | query 187 vanuatu
188 | query 188 internet phone service
189 | query 189 gs pay rate
190 | query 190 brooks brothers clearance
191 | query 191 churchill
192 | query 192 condos florida
193 | query 193 dog clean bags
194 | query 194 designer dog breeds
195 | query 195 pressure washers
196 | query 196 sore throat
197 | query 197 idaho state flower
198 | query 198 indiana state fairgrounds
199 | query 199 fybromyalgia
200 | query 200 ontario california airport
201 | query 201 raspberry pi
202 | query 202 uss carl vinson
203 | query 203 reviews les miserables
204 | query 204 rules golf
205 | query 205 average charitable donation
206 | query 206 wind power
207 | query 207 bph treatment
208 | query 208 doctor zhivago
209 | query 209 land surveyor
210 | query 210 golf gps
211 | query 211 madagascar known
212 | query 212 home theater systems
213 | query 213 carpal tunnel syndrome
214 | query 214 capital gains tax rate
215 | query 215 maryland department natural resources
216 | query 216 nicolas cage movies
217 | query 217 kids earth day activities
218 | query 218 solar water fountains
219 | query 219 name elvis presley home
220 | query 220 nba records
221 | query 221 electoral college 2008 results
222 | query 222 male menopause
223 | query 223 usda food pyramid
224 | query 224 chicken soup scratch
225 | query 225 black gold
226 | query 226 traverse city
227 | query 227 will survive lyrics
228 | query 228 hawaiian volcano observatories
229 | query 229 beef stroganoff recipe
230 | query 230 world biggest dog
231 | query 231 deadly sins
232 | query 232 hurricane irene flooding manville nj
233 | query 233 hair dye
234 | query 234 dark chocolate health benefits
235 | query 235 ham radio
236 | query 236 symptoms mad cow disease humans
237 | query 237 lump throat
238 | query 238 george bush sr bio
239 | query 239 frank lloyd wright biography
240 | query 240 presidential middle names
241 | query 241 wiki
242 | query 242 cannellini beans
243 | query 243 afghanistan flag
244 | query 244 old town scottsdale
245 | query 245 roosevelt island
246 | query 246 civil war battles south carolina
247 | query 247 rain man
248 | query 248 eggs shelf life
249 | query 249 occupational therapist
250 | query 250 ford edge problems
251 | query 251 identifying spider bites
252 | query 252 history orcas island
253 | query 253 tooth abscess
254 | query 254 barrett esophagus
255 | query 255 teddy bears
256 | query 256 patron saint mental illness
257 | query 257 holes louis sachar
258 | query 258 hip roof
259 | query 259 carpenter bee
260 | query 260 american revolutionary
261 | query 261 folk remedies sore throat
262 | query 262 balding cure
263 | query 263 evidence evolution
264 | query 264 tribe living alabama
265 | query 265 f5 tornado
266 | query 266 symptoms heart attack
267 | query 267 feliz navidad lyrics
268 | query 268 benefits running
269 | query 269 marshall county schools
270 | query 270 sun tzu
271 | query 271 halloween activities middle school
272 | query 272 dreams interpretation
273 | query 273 wilson disease
274 | query 274 golf instruction
275 | query 275 uss cole
276 | query 276 african american music influence history
277 | query 277 bewitched cast
278 | query 278 mister rogers
279 | query 279 game theory
280 | query 280 view internet history
281 | query 281 ketogenic diet
282 | query 282 nasa interplanetary missions
283 | query 283 hayrides pa
284 | query 284 find morel mushrooms
285 | query 285 magnesium rich foods
286 | query 286 common schizophrenia drugs
287 | query 287 carotid cavernous fistula treatment
288 | query 288 fidel castro
289 | query 289 benefits yoga
290 | query 290 norway spruce
291 | query 291 sangre de cristo mountains
292 | query 292 history electronic medical record
293 | query 293 educational advantages social networking sites
294 | query 294 flowering plants
295 | query 295 tie windsor knot
296 | query 296 recycling lead acid batteries
297 | query 297 altitude sickness
298 | query 298 medical care jehovah witnesses
299 | query 299 pink slime ground beef
300 | query 300 find mean
301 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.1.0
2 | pytorch-pretrained-bert==0.6.2
3 | tqdm
4 | pytools==2018.5.2
5 | git+https://github.com/cvangysel/pytrec_eval.git
6 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 |
3 | requirements = []
4 | with open('requirements.txt', 'rt') as f:
5 |     for req in f.read().splitlines():
6 |         if req.startswith('git+'):
7 |             pkg_name = req.split('/')[-1].replace('.git', '')
8 |             requirements.append(f'{pkg_name} @ {req}')
9 |         else:
10 |             requirements.append(req)
11 |
12 | with open("README.md", "r") as fh:
13 |     long_description = fh.read()
14 |
15 | setuptools.setup(
16 |     name="cedr",
17 |     version="0.0.1",
18 |     author="Sean MacAvaney",
19 |     author_email="sean@ir.cs.georgetown.edu",
20 |     description="Code for CEDR: Contextualized Embeddings for Document Ranking, at SIGIR 2019.",
21 |     long_description=long_description,
22 |     long_description_content_type="text/markdown",
23 |     url="https://github.com/Georgetown-IR-Lab/cedr",
24 |     packages=setuptools.find_packages(),
25 |     install_requires=requirements,
26 |     python_requires='>=3.6',
27 | )
28 |
--------------------------------------------------------------------------------
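
A note on the requirements handling in `setup.py` above: the loop rewrites any VCS requirement (a line starting with `git+`, such as the `pytrec_eval` entry in `requirements.txt`) into a PEP 508 direct reference of the form `name @ url`, since `install_requires` does not accept bare `git+` URLs. The following is a minimal, self-contained sketch of that transformation; the `reqs` and `converted` names are illustrative, with `reqs` mirroring entries from `requirements.txt`:

```python
# Sketch of the requirement rewriting performed in setup.py.
# A bare VCS requirement such as
#   git+https://github.com/cvangysel/pytrec_eval.git
# becomes the PEP 508 direct reference
#   pytrec_eval @ git+https://github.com/cvangysel/pytrec_eval.git
reqs = [
    'torch>=1.1.0',
    'git+https://github.com/cvangysel/pytrec_eval.git',
]
converted = []
for req in reqs:
    if req.startswith('git+'):
        # Derive the package name from the last URL segment:
        # 'pytrec_eval.git' -> 'pytrec_eval'
        pkg_name = req.split('/')[-1].replace('.git', '')
        converted.append(f'{pkg_name} @ {req}')
    else:
        converted.append(req)

print(converted)
# ['torch>=1.1.0', 'pytrec_eval @ git+https://github.com/cvangysel/pytrec_eval.git']
```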