├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── data_scripts
│   ├── convertmrtydi2beir.py
│   ├── preprocess_xmkqa.py
│   └── tokenization_script.sh
├── eval_beir.py
├── evaluate_retrieved_passages.py
├── example_scripts
│   ├── contriever.sh
│   └── mcontriever.sh
├── finetuning.py
├── generate_passage_embeddings.py
├── passage_retrieval.py
├── preprocess.py
├── requirements.txt
├── src
│   ├── __init__.py
│   ├── beir_utils.py
│   ├── contriever.py
│   ├── data.py
│   ├── dist_utils.py
│   ├── evaluation.py
│   ├── finetuning_data.py
│   ├── inbatch.py
│   ├── index.py
│   ├── moco.py
│   ├── normalize_text.py
│   ├── options.py
│   ├── slurm.py
│   └── utils.py
└── train.py
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | Facebook has adopted a Code of Conduct that we expect project participants to adhere to. 4 | Please read the [full text](https://code.fb.com/codeofconduct/) 5 | so that you can understand what actions will and will not be tolerated. 6 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to this repo 2 | 3 | ## Pull Requests 4 | 5 | In order to accept your pull request, we need you to submit a CLA. You only need 6 | to do this once to work on any of Facebook's open source projects. 7 | 8 | Complete your CLA here: 9 | 10 | ## Issues 11 | We use GitHub issues to track public bugs. Please ensure your description is 12 | clear and has sufficient instructions to be able to reproduce the issue. 13 | 14 | ## License 15 | By contributing to this repo, you agree that your contributions will be licensed 16 | under the LICENSE file in the root directory of this source tree. 17 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. 
Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. 
For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. 
The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. 
a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 
314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. 
Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 400 | -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

## Contriever: Unsupervised Dense Information Retrieval with Contrastive Learning

This repository contains pre-trained models, and code for pre-training and evaluation, for our paper [Unsupervised Dense Information Retrieval with Contrastive Learning](https://arxiv.org/abs/2112.09118).

We use a simple contrastive learning framework to pre-train models for information retrieval. Contriever, trained without supervision, is competitive with BM25 for R@100 on the BEIR benchmark. After fine-tuning on MSMARCO, Contriever obtains strong performance, especially for R@100.

We also trained a multilingual version of Contriever, mContriever, achieving strong multilingual and cross-lingual retrieval performance.

## Getting started

Pre-trained models can be loaded through the HuggingFace transformers library:

```python
from src.contriever import Contriever
from transformers import AutoTokenizer

contriever = Contriever.from_pretrained("facebook/contriever")
tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")  # load the associated tokenizer
```

Embeddings for different sentences can then be obtained as follows:

```python
sentences = [
    "Where was Marie Curie born?",
    "Maria Sklodowska, later known as Marie Curie, was born on November 7, 1867.",
    "Born in Paris on 15 May 1859, Pierre Curie was the son of Eugène Curie, a doctor of French Catholic origin from Alsace."
]

inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
embeddings = contriever(**inputs)
```

Similarity scores between sentences are then given by the dot product of their embeddings:

```python
score01 = embeddings[0] @ embeddings[1]  # 1.0473
score02 = embeddings[0] @ embeddings[2]  # 1.0095
```
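`Contriever` is a thin wrapper that mean-pools the token embeddings of the underlying BERT encoder, masking out padding tokens. If you prefer to stay within the plain `transformers` API, the same embeddings can be computed with `AutoModel` and an explicit pooling step; the sketch below mirrors the snippet on the Hugging Face model card:

```python
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("facebook/contriever")
model = AutoModel.from_pretrained("facebook/contriever")

def mean_pooling(token_embeddings, mask):
    # Zero out embeddings of padding tokens, then average over the sequence dimension.
    token_embeddings = token_embeddings.masked_fill(~mask[..., None].bool(), 0.0)
    return token_embeddings.sum(dim=1) / mask.sum(dim=1)[..., None]

sentences = ["Where was Marie Curie born?", "Maria Sklodowska was born on November 7, 1867."]
inputs = tokenizer(sentences, padding=True, truncation=True, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
embeddings = mean_pooling(outputs[0], inputs["attention_mask"])  # same vectors as Contriever(...)
```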
## Pre-trained models

The following pre-trained models are available:

* *contriever*: pre-trained on CC-net and English Wikipedia without any supervised data,
* *contriever-msmarco*: contriever with fine-tuning on MSMARCO,
* *mcontriever*: pre-trained on 29 languages using data from CC-net,
* *mcontriever-msmarco*: mcontriever with fine-tuning on MSMARCO.

```python
from src.contriever import Contriever

contriever = Contriever.from_pretrained("facebook/contriever")
contriever_msmarco = Contriever.from_pretrained("facebook/contriever-msmarco")
mcontriever = Contriever.from_pretrained("facebook/mcontriever")
mcontriever_msmarco = Contriever.from_pretrained("facebook/mcontriever-msmarco")
```

## Evaluation

### Question answering retrieval

NaturalQuestions and TriviaQA data can be downloaded from the [FiD repository](https://github.com/facebookresearch/FiD). The NaturalQuestions data slightly differs from the data provided in the DPR repository: we use the answers provided in the original NaturalQuestions data, while DPR applies a post-processing step, which affects the tokenization of words.
Retrieval is performed over the set of Wikipedia passages used in DPR. Download and decompress the passages:

```bash
wget https://dl.fbaipublicfiles.com/dpr/wikipedia_split/psgs_w100.tsv.gz
gzip -d psgs_w100.tsv.gz  # the commands below expect the uncompressed psgs_w100.tsv
```
Generate passage embeddings:

```bash
python generate_passage_embeddings.py \
    --model_name_or_path facebook/contriever \
    --output_dir contriever_embeddings \
    --passages psgs_w100.tsv \
    --shard_id 0 --num_shards 1
```
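Embedding the full passage collection on a single GPU is slow; the `--shard_id`/`--num_shards` options split the collection into independent shards so that the work can be parallelized. A sketch, assuming four local GPUs:

```bash
# Encode four shards in parallel, one GPU each.
for i in 0 1 2 3; do
    CUDA_VISIBLE_DEVICES=$i python generate_passage_embeddings.py \
        --model_name_or_path facebook/contriever \
        --output_dir contriever_embeddings \
        --passages psgs_w100.tsv \
        --shard_id $i --num_shards 4 &
done
wait
```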
Alternatively, download passage embeddings pre-computed with Contriever or Contriever-msmarco:

```bash
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever/wikipedia_embeddings.tar
wget https://dl.fbaipublicfiles.com/contriever/embeddings/contriever-msmarco/wikipedia_embeddings.tar
```
Retrieve top-100 passages:

```bash
python passage_retrieval.py \
    --model_name_or_path facebook/contriever \
    --passages psgs_w100.tsv \
    --passages_embeddings "contriever_embeddings/*" \
    --data nq_dir/test.json \
    --output_dir contriever_nq
```
This leads to the following results:
| Model | NaturalQuestions R@5 | NaturalQuestions R@20 | NaturalQuestions R@100 | TriviaQA R@5 | TriviaQA R@20 | TriviaQA R@100 |
|---|---|---|---|---|---|---|
| Contriever | 47.8 | 67.8 | 82.1 | 59.4 | 67.8 | 83.2 |
| Contriever-msmarco | 65.7 | 79.6 | 88.0 | 71.3 | 80.4 | 85.7 |
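These recall values can also be recomputed offline from saved retrieval results with [`evaluate_retrieved_passages.py`](evaluate_retrieved_passages.py). A sketch, assuming `passage_retrieval.py` wrote its output files under `contriever_nq/`:

```bash
python evaluate_retrieved_passages.py \
    --data "contriever_nq/*" \
    --validation_workers 16
```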
### BEIR

Scores on the BEIR benchmark can be reproduced using [eval_beir.py](eval_beir.py).

```bash
python eval_beir.py --model_name_or_path facebook/contriever-msmarco --dataset scifact
```

The Touche-2020 dataset has been updated in BEIR, so results will differ if the current version of the dataset is used.
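To reproduce the full tables below, the same command can be looped over the benchmark. The dataset names in this sketch are the identifiers used by the BEIR library (an assumption to check against your BEIR version):

```bash
for dataset in msmarco trec-covid nfcorpus nq hotpotqa fiqa arguana webis-touche2020 \
               quora cqadupstack dbpedia-entity scidocs fever climate-fever scifact; do
    python eval_beir.py --model_name_or_path facebook/contriever-msmarco --dataset $dataset
done
```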
| nDCG@10 | Avg | MSMARCO | TREC-Covid | NFCorpus | NaturalQuestions | HotpotQA | FiQA | ArguAna | Touche-2020 | Quora | CQAdupstack | DBPedia | Scidocs | Fever | Climate-fever | Scifact |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Contriever | 37.7 | 20.6 | 27.4 | 31.7 | 25.4 | 48.1 | 24.5 | 37.9 | 19.3 | 83.5 | 28.4 | 29.2 | 14.9 | 68.2 | 15.5 | 64.9 |
| Contriever-msmarco | 46.6 | 40.7 | 59.6 | 32.8 | 49.8 | 63.8 | 32.9 | 44.6 | 23.0 | 86.5 | 34.5 | 41.3 | 16.5 | 75.8 | 23.7 | 67.7 |
| R@100 | Avg | MSMARCO | TREC-Covid | NFCorpus | NaturalQuestions | HotpotQA | FiQA | ArguAna | Touche-2020 | Quora | CQAdupstack | DBPedia | Scidocs | Fever | Climate-fever | Scifact |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Contriever | 59.6 | 67.2 | 17.2 | 29.4 | 77.1 | 70.4 | 56.2 | 90.1 | 22.5 | 98.7 | 61.4 | 45.3 | 36.0 | 93.6 | 44.1 | 92.6 |
| Contriever-msmarco | 67.0 | 89.1 | 40.7 | 30.0 | 92.5 | 77.7 | 65.6 | 97.7 | 29.4 | 99.3 | 66.3 | 54.1 | 37.8 | 94.9 | 57.4 | 94.7 |
## Multilingual evaluation

We evaluate mContriever on Mr. TyDi v1.1 and on a cross-lingual retrieval setting derived from MKQA. The steps below reproduce our results on these datasets.

### Mr. TyDi v1.1

For multilingual evaluation on Mr. TyDi v1.1, we download the datasets and convert them to the BEIR format using [data_scripts/convertmrtydi2beir.py](data_scripts/convertmrtydi2beir.py).
Evaluation on Swahili can be performed as follows.
Download data:

```bash
wget https://git.uwaterloo.ca/jimmylin/mr.tydi/-/raw/master/data/mrtydi-v1.1-swahili.tar.gz -P mrtydi
tar -xf mrtydi/mrtydi-v1.1-swahili.tar.gz -C mrtydi
gzip -d mrtydi/mrtydi-v1.1-swahili/collection/docs.jsonl.gz
```
Convert data:

```bash
python data_scripts/convertmrtydi2beir.py mrtydi/mrtydi-v1.1-swahili mrtydi/mrtydi-v1.1-swahili
```
Evaluation:

```bash
python eval_beir.py --model_name_or_path facebook/mcontriever --dataset mrtydi/mrtydi-v1.1-swahili --normalize_text
```
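The other languages follow the same pattern. A sketch looping over all eleven Mr. TyDi languages, assuming the archives follow the `mrtydi-v1.1-<language>` naming used above:

```bash
for lang in arabic bengali english finnish indonesian japanese korean russian swahili telugu thai; do
    wget https://git.uwaterloo.ca/jimmylin/mr.tydi/-/raw/master/data/mrtydi-v1.1-${lang}.tar.gz -P mrtydi
    tar -xf mrtydi/mrtydi-v1.1-${lang}.tar.gz -C mrtydi
    gzip -d mrtydi/mrtydi-v1.1-${lang}/collection/docs.jsonl.gz
    python data_scripts/convertmrtydi2beir.py mrtydi/mrtydi-v1.1-${lang} mrtydi/mrtydi-v1.1-${lang}
    python eval_beir.py --model_name_or_path facebook/mcontriever --dataset mrtydi/mrtydi-v1.1-${lang} --normalize_text
done
```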
| MRR@100 | ar | bn | en | fi | id | ja | ko | ru | sw | te | th | avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mContriever | 27.3 | 36.3 | 9.2 | 21.1 | 23.5 | 19.5 | 22.3 | 17.5 | 38.3 | 22.5 | 37.2 | 25.0 |
| mContriever-msmarco | 43.4 | 42.3 | 27.1 | 25.1 | 42.6 | 32.4 | 34.2 | 36.1 | 51.2 | 37.4 | 40.2 | 38.4 |
| + Mr. TyDi | 72.4 | 67.2 | 56.6 | 60.2 | 63.0 | 54.9 | 55.3 | 59.7 | 70.7 | 90.3 | 67.3 | 65.2 |
| R@100 | ar | bn | en | fi | id | ja | ko | ru | sw | te | th | avg |
|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mContriever | 82.0 | 89.6 | 48.8 | 79.6 | 81.4 | 72.8 | 66.2 | 68.5 | 88.7 | 80.8 | 90.3 | 77.2 |
| mContriever-msmarco | 88.7 | 91.4 | 77.2 | 88.1 | 89.8 | 81.7 | 78.2 | 83.8 | 91.4 | 96.6 | 90.5 | 87.0 |
| + Mr. TyDi | 94.0 | 98.6 | 92.2 | 92.7 | 94.5 | 88.8 | 88.9 | 92.4 | 93.7 | 98.9 | 95.2 | 93.6 |
### Cross-lingual MKQA

Here our goal is to measure how well retrievers retrieve relevant documents from the English Wikipedia given a query in another language.
For this we use MKQA and, following the DPR evaluation script, check whether the answer appears in the retrieved documents.
Download data:

```bash
wget https://raw.githubusercontent.com/apple/ml-mkqa/master/dataset/mkqa.jsonl.gz
```
Preprocess data:

```bash
gzip -d mkqa.jsonl.gz  # the preprocessing script expects the uncompressed mkqa.jsonl
python data_scripts/preprocess_xmkqa.py mkqa.jsonl xmkqa
```
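This writes one JSONL file per language to `xmkqa/`, where each line pairs a question in the target language with the English answers, using the fields `id`, `lang`, `question`, and `answers`. Schematically (placeholder values, not actual MKQA data):

```json
{"id": "<example_id>", "lang": "fi", "question": "<question in Finnish>", "answers": ["<English answer>", "<English alias>"]}
```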
Generate embeddings:

```bash
python generate_passage_embeddings.py \
    --model_name_or_path facebook/mcontriever \
    --output_dir mcontriever_embeddings \
    --passages psgs_w100.tsv \
    --shard_id 0 --num_shards 1 \
    --lowercase --normalize_text
```
Alternatively, download passage embeddings pre-computed with mContriever or mContriever-msmarco:

```bash
wget https://dl.fbaipublicfiles.com/contriever/embeddings/mcontriever/wikipedia_embeddings.tar
wget https://dl.fbaipublicfiles.com/contriever/embeddings/mcontriever-msmarco/wikipedia_embeddings.tar
```
Retrieve passages and compute retrieval accuracy:

```bash
python passage_retrieval.py \
    --model_name_or_path facebook/mcontriever \
    --passages psgs_w100.tsv \
    --passages_embeddings "mcontriever_embeddings/*" \
    --data "xmkqa/*.jsonl" \
    --output_dir mcontriever_xmkqa \
    --lowercase --normalize_text
```
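`passage_retrieval.py` reports retrieval accuracy as it runs; the per-language rows of the tables below can also be recomputed afterwards with [`evaluate_retrieved_passages.py`](evaluate_retrieved_passages.py), which prints tab-separated R@20 and R@100 values for every file matched by `--data`. A sketch, assuming one saved output file per language under `mcontriever_xmkqa/`:

```bash
python evaluate_retrieved_passages.py --data "mcontriever_xmkqa/*"
```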
| R@100 | avg | en | ar | fi | ja | ko | ru | es | sv | he | th | da | de | fr | it | nl | pl | pt | hu | vi | ms | km | no | tr | zh-cn | zh-hk | zh-tw |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mContriever | 49.2 | 65.3 | 43.0 | 43.1 | 47.1 | 44.8 | 51.8 | 37.2 | 54.5 | 44.7 | 51.4 | 49.3 | 49.0 | 50.2 | 56.7 | 61.7 | 44.4 | 54.5 | 47.7 | 45.1 | 56.7 | 27.8 | 50.2 | 44.3 | 54.3 | 51.9 | 52.5 |
| mContriever-msmarco | 65.6 | 75.6 | 53.3 | 66.6 | 60.4 | 55.4 | 64.7 | 70.0 | 70.8 | 59.6 | 63.5 | 72.0 | 66.6 | 70.1 | 70.3 | 71.4 | 68.8 | 68.5 | 66.7 | 67.8 | 71.6 | 37.8 | 71.5 | 68.7 | 64.1 | 64.5 | 64.3 |
| R@20 | avg | en | ar | fi | ja | ko | ru | es | sv | he | th | da | de | fr | it | nl | pl | pt | hu | vi | ms | km | no | tr | zh-cn | zh-hk | zh-tw |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| mContriever | 31.4 | 50.2 | 26.6 | 26.7 | 29.4 | 27.9 | 32.7 | 20.7 | 37.6 | 22.2 | 31.1 | 31.2 | 31.2 | 30.7 | 38.6 | 45.1 | 25.1 | 37.6 | 28.3 | 27.3 | 39.6 | 15.7 | 33.2 | 26.5 | 35.0 | 32.7 | 32.5 |
| mContriever-msmarco | 53.9 | 67.2 | 40.1 | 55.1 | 46.2 | 41.7 | 52.3 | 59.3 | 60.0 | 45.6 | 52.0 | 62.0 | 54.8 | 59.3 | 59.4 | 60.9 | 58.1 | 56.9 | 55.2 | 55.9 | 60.9 | 26.2 | 61.0 | 56.7 | 50.9 | 51.9 | 51.2 |
## Training

### Data pre-processing

We perform pre-training on data from CCNet and Wikipedia.
Contriever, the English monolingual model, is trained on English data from Wikipedia and CCNet.
mContriever, the multilingual model, is pre-trained on 29 languages using data from CCNet.
After converting the data into a text file, we tokenize it and chunk it into multiple sub-files using [`data_scripts/tokenization_script.sh`](data_scripts/tokenization_script.sh).
The different chunks are then loaded separately by the different processes of a distributed job.
For mContriever, we preprocess the data with the option `--normalize_text`, which normalizes certain common characters that are not covered by the mBERT tokenizer.
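Under the hood, [`data_scripts/tokenization_script.sh`](data_scripts/tokenization_script.sh) splits the input text file into `NSPLIT` chunks and runs [`preprocess.py`](preprocess.py) on each of them. A single chunk can also be encoded directly; a sketch using the script's own defaults:

```bash
# Tokenize one text file into the encoded-data layout expected by train.py.
python3 preprocess.py --tokenizer bert-base-uncased --datapath en_XX.txt --outdir encoded-data/bert-base-uncased/en_XX
# For mContriever, add --normalize_text (see the comment in tokenization_script.sh).
```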
### Training

[`train.py`](train.py) provides the code for the contrastive training phase of Contriever.

For Contriever, the English monolingual model, we use the following options on 32 GPUs:

```bash
python train.py \
    --retriever_model_id bert-base-uncased --pooling average \
    --augmentation delete --prob_augmentation 0.1 \
    --train_data "data/wiki/ data/cc-net/" --loading_mode split \
    --ratio_min 0.1 --ratio_max 0.5 --chunk_length 256 \
    --momentum 0.9995 --queue_size 131072 --temperature 0.05 \
    --warmup_steps 20000 --total_steps 500000 --lr 0.00005 \
    --scheduler linear --optim adamw --per_gpu_batch_size 64 \
    --output_dir /checkpoint/gizacard/contriever/xling/contriever
```
For mContriever, the multilingual model, we use the following options on 32 GPUs:

```bash
TDIR=encoded-data/bert-base-multilingual-cased/
TRAINDATASETS="${TDIR}fr_XX ${TDIR}en_XX ${TDIR}ar_AR ${TDIR}bn_IN ${TDIR}fi_FI ${TDIR}id_ID ${TDIR}ja_XX ${TDIR}ko_KR ${TDIR}ru_RU ${TDIR}sw_KE ${TDIR}hu_HU ${TDIR}he_IL ${TDIR}it_IT ${TDIR}km_KM ${TDIR}ms_MY ${TDIR}nl_XX ${TDIR}no_XX ${TDIR}pl_PL ${TDIR}pt_XX ${TDIR}sv_SE ${TDIR}te_IN ${TDIR}th_TH ${TDIR}tr_TR ${TDIR}vi_VN ${TDIR}zh_CN ${TDIR}zh_TW ${TDIR}es_XX ${TDIR}de_DE ${TDIR}da_DK"

python train.py \
    --retriever_model_id bert-base-multilingual-cased --pooling average \
    --train_data ${TRAINDATASETS} --loading_mode split \
    --ratio_min 0.1 --ratio_max 0.5 --chunk_length 256 \
    --momentum 0.999 --queue_size 32768 --temperature 0.05 \
    --warmup_steps 20000 --total_steps 500000 --lr 0.00005 \
    --scheduler linear --optim adamw --per_gpu_batch_size 64 \
    --output_dir /checkpoint/gizacard/contriever/xling/mcontriever
```
The full training scripts used on our SLURM cluster are available in the [`example_scripts`](example_scripts) folder.

## References

If you find this repository useful, please consider giving it a star and citing this work:

[1] G. Izacard, M. Caron, L. Hosseini, S. Riedel, P. Bojanowski, A. Joulin, E. Grave, [*Unsupervised Dense Information Retrieval with Contrastive Learning*](https://arxiv.org/abs/2112.09118)

```bibtex
@misc{izacard2021contriever,
      title={Unsupervised Dense Information Retrieval with Contrastive Learning},
      author={Gautier Izacard and Mathilde Caron and Lucas Hosseini and Sebastian Riedel and Piotr Bojanowski and Armand Joulin and Edouard Grave},
      year={2021},
      url = {https://arxiv.org/abs/2112.09118},
      doi = {10.48550/ARXIV.2112.09118},
}
```

## License

See the [LICENSE](LICENSE) file for more details.
-------------------------------------------------------------------------------- /data_scripts/convertmrtydi2beir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import sys 4 | import os 5 | import csv 6 | import json 7 | 8 | def convert2beir(data_path, output_path): 9 | 10 | splits = ['test', 'dev', 'train'] 11 | queries_path = os.path.join(output_path, "queries.jsonl") 12 | corpus_path = os.path.join(output_path, "corpus.jsonl") 13 | os.makedirs(os.path.dirname(corpus_path), exist_ok=True) 14 | queries = [] 15 | with open(queries_path, "w", encoding="utf-8") as fout: 16 | with open(os.path.join(data_path, f"topic.tsv"), "r", encoding="utf-8") as fin: 17 | reader = csv.reader(fin, delimiter="\t") 18 | for x in reader: 19 | qdict = { 20 | "_id": x[0], 21 | "text": x[1] 22 | } 23 | json.dump(qdict, fout, ensure_ascii=False) 24 | fout.write('\n') 25 | 26 | with open(os.path.join(data_path, "collection", "docs.jsonl"), "r") as fin: 27 | with open(corpus_path, "w", encoding="utf-8") as fout: 28 | for line in fin: 29 | x = json.loads(line) 30 | x["_id"] = x["id"] 31 | x["text"] = x["contents"] 32 | x["title"] = "" 33 | del x["id"] 34 | del x["contents"] 35 | json.dump(x, fout, ensure_ascii=False) 36 | fout.write('\n') 37 | 38 | 39 | for split in splits: 40 | 41 | qrels_path = os.path.join(output_path, "qrels", f"{split}.tsv") 42 | os.makedirs(os.path.dirname(qrels_path), exist_ok=True) 43 | 44 | with open(os.path.join(data_path, f"qrels.{split}.txt"), "r", encoding="utf-8") as fin: 45 | with open(qrels_path, "w", encoding="utf-8") as fout: 46 | writer = csv.writer(fout, delimiter='\t') 47 | writer.writerow(["query-id", "corpus-id", "score"]) 48 | for line in fin: 49 | line = line.strip() 50 | el = line.split() 51 | qid = el[0] 52 | i = el[2] 53 | s = el[3] 54 | writer.writerow([qid, i, s]) 55 | 56 | 57 | if __name__ == '__main__': 58 | convert2beir(sys.argv[1], sys.argv[2]) 59 | -------------------------------------------------------------------------------- /data_scripts/preprocess_xmkqa.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
All Rights Reserved 2 | 3 | import sys 4 | import os 5 | import json 6 | from collections import defaultdict 7 | 8 | def preprocess_xmkqa(input_path, output_dir): 9 | os.makedirs(output_dir, exist_ok=True) 10 | mkqa = [] 11 | with open(input_path, 'r') as fin: 12 | for line in fin: 13 | ex = json.loads(line) 14 | mkqa.append(ex) 15 | mkqadict = {ex['example_id']:ex for ex in mkqa} 16 | 17 | langs = ['en', 'ar', 'fi', 'ja', 'ko', 'ru', 'es', 'sv', 'he', 'th', \ 18 | 'da', 'de', 'fr', 'it', 'nl', 'pl', 'pt', 'hu', 'vi', 'ms', \ 19 | 'km', 'no', 'tr', 'zh_cn', 'zh_hk', 'zh_tw'] 20 | langdata = defaultdict(list) 21 | 22 | for ex in mkqa: 23 | answers = [] 24 | for a in ex['answers']['en']: 25 | flag = False 26 | if not (a['type'] == 'unanswerable' or a['type'] == 'binary' or a['type'] == 'long_answer'): 27 | flag = True 28 | answers.extend(a.get("aliases", [])) 29 | answers.append(a.get("text")) 30 | if flag: 31 | for lang in langs: 32 | langex = { 33 | 'id': ex['example_id'], 34 | 'lang': lang, 35 | 'question': ex['queries'][lang], #question in specific languages 36 | 'answers': answers #english answers 37 | } 38 | langdata[lang].append(langex) 39 | 40 | 41 | for lang, data in langdata.items(): 42 | with open(os.path.join(output_dir, f'{lang}.jsonl'), 'w') as fout: 43 | for ex in data: 44 | json.dump(ex, fout, ensure_ascii=False) 45 | fout.write('\n') 46 | 47 | if __name__ == '__main__': 48 | preprocess_xmkqa(sys.argv[1], sys.argv[2]) 49 | -------------------------------------------------------------------------------- /data_scripts/tokenization_script.sh: -------------------------------------------------------------------------------- 1 | NSPLIT=128 #Must be larger than the number of processes used during training 2 | FILENAME=en_XX.txt 3 | INFILE=./${FILENAME} 4 | TOKENIZER=bert-base-uncased 5 | #TOKENIZER=bert-base-multilingual-cased 6 | SPLITDIR=./tmp-tokenization-${TOKENIZER}-${FILENAME}/ 7 | OUTDIR=./encoded-data/${TOKENIZER}/$(echo "$FILENAME" | cut -f 1 -d '.') 8 | NPROCESS=8 9 | 10 | mkdir -p ${SPLITDIR} 11 | echo ${INFILE} 12 | split -a 3 -d -n l/${NSPLIT} ${INFILE} ${SPLITDIR} 13 | 14 | pids=() 15 | 16 | for ((i=0;i<$NSPLIT;i++)); do 17 | num=$(printf "%03d\n" $i); 18 | FILE=${SPLITDIR}${num}; 19 | #we used --normalize_text as an additional option for mContriever 20 | python3 preprocess.py --tokenizer ${TOKENIZER} --datapath ${FILE} --outdir ${OUTDIR} & 21 | pids+=($!); 22 | if (( $i % $NPROCESS == 0 )) 23 | then 24 | for pid in ${pids[@]}; do 25 | wait $pid 26 | done 27 | fi 28 | done 29 | 30 | for pid in ${pids[@]}; do 31 | wait $pid 32 | done 33 | 34 | echo ${SPLITDIR} 35 | 36 | rm -r ${SPLITDIR} 37 | -------------------------------------------------------------------------------- /eval_beir.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import sys 8 | import argparse 9 | import torch 10 | import logging 11 | import json 12 | import numpy as np 13 | import os 14 | 15 | import src.slurm 16 | import src.contriever 17 | import src.beir_utils 18 | import src.utils 19 | import src.dist_utils 20 | import src.contriever 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def main(args): 26 | 27 | src.slurm.init_distributed_mode(args) 28 | src.slurm.init_signal_handler() 29 | 30 | os.makedirs(args.output_dir, exist_ok=True) 31 | 32 | logger = src.utils.init_logger(args) 33 | 34 | model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path) 35 | model = model.cuda() 36 | model.eval() 37 | query_encoder = model 38 | doc_encoder = model 39 | 40 | logger.info("Start indexing") 41 | 42 | metrics = src.beir_utils.evaluate_model( 43 | query_encoder=query_encoder, 44 | doc_encoder=doc_encoder, 45 | tokenizer=tokenizer, 46 | dataset=args.dataset, 47 | batch_size=args.per_gpu_batch_size, 48 | norm_query=args.norm_query, 49 | norm_doc=args.norm_doc, 50 | is_main=src.dist_utils.is_main(), 51 | split="dev" if args.dataset == "msmarco" else "test", 52 | score_function=args.score_function, 53 | beir_dir=args.beir_dir, 54 | save_results_path=args.save_results_path, 55 | lower_case=args.lower_case, 56 | normalize_text=args.normalize_text, 57 | ) 58 | 59 | if src.dist_utils.is_main(): 60 | for key, value in metrics.items(): 61 | logger.info(f"{args.dataset} : {key}: {value:.1f}") 62 | 63 | 64 | if __name__ == "__main__": 65 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 66 | 67 | parser.add_argument("--dataset", type=str, help="Evaluation dataset from the BEIR benchmark") 68 | parser.add_argument("--beir_dir", type=str, default="./", help="Directory to save and load beir datasets") 69 | parser.add_argument("--text_maxlength", type=int, default=512, help="Maximum text length") 70 | 71 | parser.add_argument("--per_gpu_batch_size", default=128, type=int, help="Batch size per GPU/CPU for indexing.") 72 | parser.add_argument("--output_dir", type=str, default="./my_experiment", help="Output directory") 73 | parser.add_argument("--model_name_or_path", type=str, help="Model name or path") 74 | parser.add_argument( 75 | "--score_function", type=str, default="dot", help="Metric used to compute similarity between two embeddings" 76 | ) 77 | parser.add_argument("--norm_query", action="store_true", help="Normalize query representation") 78 | parser.add_argument("--norm_doc", action="store_true", help="Normalize document representation") 79 | parser.add_argument("--lower_case", action="store_true", help="lowercase query and document text") 80 | parser.add_argument( 81 | "--normalize_text", action="store_true", help="Apply function to normalize some common characters" 82 | ) 83 | parser.add_argument("--save_results_path", type=str, default=None, help="Path to save result object") 84 | 85 | parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") 86 | parser.add_argument("--main_port", type=int, default=-1, help="Main port (for multi-node SLURM jobs)") 87 | 88 | args, _ = parser.parse_known_args() 89 | main(args) 90 | -------------------------------------------------------------------------------- /evaluate_retrieved_passages.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 
3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import argparse 8 | import json 9 | import logging 10 | import glob 11 | 12 | import numpy as np 13 | import torch 14 | 15 | import src.utils 16 | 17 | from src.evaluation import calculate_matches 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | def validate(data, workers_num): 22 | match_stats = calculate_matches(data, workers_num) 23 | top_k_hits = match_stats.top_k_hits 24 | 25 | #logger.info('Validation results: top k documents hits %s', top_k_hits) 26 | top_k_hits = [v / len(data) for v in top_k_hits] 27 | #logger.info('Validation results: top k documents hits accuracy %s', top_k_hits) 28 | return top_k_hits 29 | 30 | 31 | def main(opt): 32 | logger = src.utils.init_logger(opt, stdout_only=True) 33 | datapaths = glob.glob(args.data) 34 | r20, r100 = [], [] 35 | for path in datapaths: 36 | data = [] 37 | with open(path, 'r') as fin: 38 | for line in fin: 39 | data.append(json.loads(line)) 40 | #data = json.load(fin) 41 | answers = [ex['answers'] for ex in data] 42 | top_k_hits = validate(data, args.validation_workers) 43 | message = f"Evaluate results from {path}:" 44 | for k in [5, 10, 20, 100]: 45 | if k <= len(top_k_hits): 46 | recall = 100 * top_k_hits[k-1] 47 | if k == 20: 48 | r20.append(f"{recall:.1f}") 49 | if k == 100: 50 | r100.append(f"{recall:.1f}") 51 | message += f' R@{k}: {recall:.1f}' 52 | logger.info(message) 53 | print(datapaths) 54 | print('\t'.join(r20)) 55 | print('\t'.join(r100)) 56 | 57 | 58 | if __name__ == '__main__': 59 | parser = argparse.ArgumentParser() 60 | 61 | parser.add_argument('--data', required=True, type=str, default=None) 62 | parser.add_argument('--validation_workers', type=int, default=16, 63 | help="Number of parallel processes to validate results") 64 | 65 | args = parser.parse_args() 66 | main(args) 67 | -------------------------------------------------------------------------------- /example_scripts/contriever.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --cpus-per-task=5 3 | #SBATCH --nodes=4 4 | #SBATCH --ntasks-per-node=8 5 | #SBATCH --gres=gpu:8 6 | #SBATCH --time=72:00:00 7 | #SBATCH --job-name=contriever 8 | #SBATCH --output=/private/home/gizacard/contriever/logtrain/%A 9 | #SBATCH --partition=learnlab 10 | #SBATCH --mem=450GB 11 | #SBATCH --signal=USR1@140 12 | #SBATCH --open-mode=append 13 | 14 | 15 | port=$(shuf -i 15000-16000 -n 1) 16 | TDIR="/private/home/gizacard/contriever/encoded-data" 17 | TRAINDATASETS="${TDIR}/wikisub/ ${TDIR}/cc-netsub/" 18 | 19 | rmin=0.05 20 | rmax=0.5 21 | T=0.05 22 | QSIZE=131072 23 | MOM=0.9995 24 | POOL=average 25 | AUG=delete 26 | PAUG=0.1 27 | LC=0. 
28 | mo=bert-base-uncased 29 | mp=none 30 | 31 | name=$SLURM_JOB_ID-$POOL-rmin$rmin-rmax$rmax-T$T-$QSIZE-$MOM-$mo-$AUG-$PAUG 32 | 33 | srun ~gizacard/anaconda3/envs/contriever/bin/python3 train.py \ 34 | --model_path $mp \ 35 | --sampling_coefficient $LC \ 36 | --retriever_model_id $mo --pooling $POOL \ 37 | --augmentation $AUG --prob_augmentation $PAUG \ 38 | --train_data $TRAINDATASETS --loading_mode split \ 39 | --ratio_min $rmin --ratio_max $rmax --chunk_length 256 \ 40 | --momentum $MOM --queue_size $QSIZE --temperature $T \ 41 | --warmup_steps 20000 --total_steps 500000 --lr 0.00005 \ 42 | --name $name \ 43 | --scheduler linear \ 44 | --optim adamw \ 45 | --per_gpu_batch_size 64 \ 46 | --output_dir /checkpoint/gizacard/contriever/xling/$name \ 47 | --main_port $port \ 48 | 49 | -------------------------------------------------------------------------------- /example_scripts/mcontriever.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --cpus-per-task=5 3 | #SBATCH --nodes=8 4 | #SBATCH --ntasks-per-node=8 5 | #SBATCH --gres=gpu:8 6 | #SBATCH --time=72:00:00 7 | #SBATCH --job-name=mcontriever 8 | #SBATCH --output=/private/home/gizacard/contriever/logtrain/%A 9 | #SBATCH --partition=learnlab 10 | #SBATCH --mem=450GB 11 | #SBATCH --signal=USR1@140 12 | #SBATCH --open-mode=append 13 | 14 | 15 | port=$(shuf -i 15000-16000 -n 1) 16 | 17 | TDIR=/private/home/gizacard/contriever/encoded-data/bert-base-multilingual-cased/ 18 | TRAINDATASETS="${TDIR}fr_XX ${TDIR}en_XX ${TDIR}ar_AR ${TDIR}bn_IN ${TDIR}fi_FI ${TDIR}id_ID ${TDIR}ja_XX ${TDIR}ko_KR ${TDIR}ru_RU ${TDIR}sw_KE ${TDIR}hu_HU ${TDIR}he_IL ${TDIR}it_IT ${TDIR}km_KM ${TDIR}ms_MY ${TDIR}nl_XX ${TDIR}no_XX ${TDIR}pl_PL ${TDIR}pt_XX ${TDIR}sv_SE ${TDIR}te_IN ${TDIR}th_TH ${TDIR}tr_TR ${TDIR}vi_VN ${TDIR}zh_CN ${TDIR}zh_TW ${TDIR}es_XX ${TDIR}de_DE ${TDIR}da_DK" 19 | 20 | rmin=0.1 21 | rmax=0.5 22 | T=0.05 23 | QSIZE=32768 24 | MOM=0.999 25 | POOL=average 26 | AUG=none 27 | PAUG=0. 28 | LC=0. 29 | mo=bert-base-multilingual-cased 30 | mp=none 31 | 32 | name=$SLURM_JOB_ID-$POOL-rmin$rmin-rmax$rmax-T$T-$QSIZE-$MOM-$mo-$AUG-$PAUG 33 | 34 | srun ~gizacard/anaconda3/envs/pytorch10/bin/python3 ~gizacard/contriever/train.py \ 35 | --model_path $mp \ 36 | --sampling_coefficient $LC \ 37 | --augmentation $AUG --prob_augmentation $PAUG \ 38 | --retriever_model_id $mo --pooling $POOL \ 39 | --train_data $TRAINDATASETS --loading_mode split \ 40 | --ratio_min $rmin --ratio_max $rmax --chunk_length 256 \ 41 | --momentum $MOM --queue_size $QSIZE --temperature $T \ 42 | --warmup_steps 20000 --total_steps 500000 --lr 0.00005 \ 43 | --name $name \ 44 | --scheduler linear \ 45 | --optim adamw \ 46 | --per_gpu_batch_size 64 \ 47 | --output_dir /checkpoint/gizacard/contriever/xling/$name \ 48 | --main_port $port \ 49 | -------------------------------------------------------------------------------- /finetuning.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import pdb 4 | import os 5 | import time 6 | import sys 7 | import torch 8 | from torch.utils.tensorboard import SummaryWriter 9 | import logging 10 | import json 11 | import numpy as np 12 | import torch.distributed as dist 13 | from torch.utils.data import DataLoader, RandomSampler, SequentialSampler 14 | 15 | from src.options import Options 16 | from src import data, beir_utils, slurm, dist_utils, utils, contriever, finetuning_data, inbatch 17 | 18 | import train 19 | 20 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 21 | 22 | logger = logging.getLogger(__name__) 23 | 24 | 25 | def finetuning(opt, model, optimizer, scheduler, tokenizer, step): 26 | 27 | run_stats = utils.WeightedAvgStats() 28 | 29 | tb_logger = utils.init_tb_logger(opt.output_dir) 30 | 31 | if hasattr(model, "module"): 32 | eval_model = model.module 33 | else: 34 | eval_model = model 35 | eval_model = eval_model.get_encoder() 36 | 37 | train_dataset = finetuning_data.Dataset( 38 | datapaths=opt.train_data, 39 | negative_ctxs=opt.negative_ctxs, 40 | negative_hard_ratio=opt.negative_hard_ratio, 41 | negative_hard_min_idx=opt.negative_hard_min_idx, 42 | normalize=opt.eval_normalize_text, 43 | global_rank=dist_utils.get_rank(), 44 | world_size=dist_utils.get_world_size(), 45 | maxload=opt.maxload, 46 | training=True, 47 | ) 48 | collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length) 49 | train_sampler = RandomSampler(train_dataset) 50 | train_dataloader = DataLoader( 51 | train_dataset, 52 | sampler=train_sampler, 53 | batch_size=opt.per_gpu_batch_size, 54 | drop_last=True, 55 | num_workers=opt.num_workers, 56 | collate_fn=collator, 57 | ) 58 | 59 | train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step) 60 | evaluate(opt, eval_model, tokenizer, tb_logger, step) 61 | 62 | epoch = 1 63 | 64 | model.train() 65 | prev_ids, prev_mask = None, None 66 | while step < opt.total_steps: 67 | logger.info(f"Start epoch {epoch}, number of batches: {len(train_dataloader)}") 68 | for i, batch in enumerate(train_dataloader): 69 | batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()} 70 | step += 1 71 | 72 | train_loss, iter_stats = model(**batch, stats_prefix="train") 73 | train_loss.backward() 74 | 75 | if opt.optim == "sam" or opt.optim == "asam": 76 | optimizer.first_step(zero_grad=True) 77 | 78 | sam_loss, _ = model(**batch, stats_prefix="train/sam_opt") 79 | sam_loss.backward() 80 | optimizer.second_step(zero_grad=True) 81 | else: 82 | optimizer.step() 83 | scheduler.step() 84 | optimizer.zero_grad() 85 | 86 | run_stats.update(iter_stats) 87 | 88 | if step % opt.log_freq == 0: 89 | log = f"{step} / {opt.total_steps}" 90 | for k, v in sorted(run_stats.average_stats.items()): 91 | log += f" | {k}: {v:.3f}" 92 | if tb_logger: 93 | tb_logger.add_scalar(k, v, step) 94 | log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}" 95 | log += f" | Memory: {torch.cuda.max_memory_allocated()//1e9} GiB" 96 | 97 | logger.info(log) 98 | run_stats.reset() 99 | 100 | if step % opt.eval_freq == 0: 101 | 102 | train.eval_model(opt, eval_model, None, tokenizer, tb_logger, step) 103 | evaluate(opt, eval_model, tokenizer, tb_logger, step) 104 | 105 | if step % opt.save_freq == 0 and dist_utils.get_rank() == 0: 106 | utils.save( 107 | eval_model, 108 | optimizer, 109 | scheduler, 110 | step, 111 | opt, 112 | opt.output_dir, 113 | f"step-{step}", 114 | ) 115 | model.train() 116 | 117 | if step >= opt.total_steps: 118 | break 119 | 120 | epoch 
+= 1 121 | 122 | 123 | def evaluate(opt, model, tokenizer, tb_logger, step): 124 | dataset = finetuning_data.Dataset( 125 | datapaths=opt.eval_data, 126 | normalize=opt.eval_normalize_text, 127 | global_rank=dist_utils.get_rank(), 128 | world_size=dist_utils.get_world_size(), 129 | maxload=opt.maxload, 130 | training=False, 131 | ) 132 | collator = finetuning_data.Collator(tokenizer, passage_maxlength=opt.chunk_length) 133 | sampler = SequentialSampler(dataset) 134 | dataloader = DataLoader( 135 | dataset, 136 | sampler=sampler, 137 | batch_size=opt.per_gpu_batch_size, 138 | drop_last=False, 139 | num_workers=opt.num_workers, 140 | collate_fn=collator, 141 | ) 142 | 143 | model.eval() 144 | if hasattr(model, "module"): 145 | model = model.module 146 | correct_samples, total_samples, total_step = 0, 0, 0 147 | all_q, all_g, all_n = [], [], [] 148 | with torch.no_grad(): 149 | for i, batch in enumerate(dataloader): 150 | batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()} 151 | 152 | all_tokens = torch.cat([batch["g_tokens"], batch["n_tokens"]], dim=0) 153 | all_mask = torch.cat([batch["g_mask"], batch["n_mask"]], dim=0) 154 | 155 | q_emb = model(input_ids=batch["q_tokens"], attention_mask=batch["q_mask"], normalize=opt.norm_query) 156 | all_emb = model(input_ids=all_tokens, attention_mask=all_mask, normalize=opt.norm_doc) 157 | 158 | g_emb, n_emb = torch.split(all_emb, [len(batch["g_tokens"]), len(batch["n_tokens"])]) 159 | 160 | all_q.append(q_emb) 161 | all_g.append(g_emb) 162 | all_n.append(n_emb) 163 | 164 | all_q = torch.cat(all_q, dim=0) 165 | all_g = torch.cat(all_g, dim=0) 166 | all_n = torch.cat(all_n, dim=0) 167 | 168 | labels = torch.arange(0, len(all_q), device=all_q.device, dtype=torch.long) 169 | 170 | all_sizes = dist_utils.get_varsize(all_g) 171 | all_g = dist_utils.varsize_gather_nograd(all_g) 172 | all_n = dist_utils.varsize_gather_nograd(all_n) 173 | labels = labels + sum(all_sizes[: dist_utils.get_rank()]) 174 | 175 | scores_pos = torch.einsum("id, jd->ij", all_q, all_g) 176 | scores_neg = torch.einsum("id, jd->ij", all_q, all_n) 177 | scores = torch.cat([scores_pos, scores_neg], dim=-1) 178 | 179 | argmax_idx = torch.argmax(scores, dim=1) 180 | sorted_scores, indices = torch.sort(scores, descending=True) 181 | isrelevant = indices == labels[:, None] 182 | rs = [r.cpu().numpy().nonzero()[0] for r in isrelevant] 183 | mrr = np.mean([1.0 / (r[0] + 1) if r.size else 0.0 for r in rs]) 184 | 185 | acc = (argmax_idx == labels).sum() / all_q.size(0) 186 | acc, total = dist_utils.weighted_average(acc, all_q.size(0)) 187 | mrr, _ = dist_utils.weighted_average(mrr, all_q.size(0)) 188 | acc = 100 * acc 189 | 190 | message = [] 191 | if dist_utils.is_main(): 192 | message = [f"eval acc: {acc:.2f}%", f"eval mrr: {mrr:.3f}"] 193 | logger.info(" | ".join(message)) 194 | if tb_logger is not None: 195 | tb_logger.add_scalar(f"eval_acc", acc, step) 196 | tb_logger.add_scalar(f"mrr", mrr, step) 197 | 198 | 199 | def main(): 200 | logger.info("Start") 201 | 202 | options = Options() 203 | opt = options.parse() 204 | 205 | torch.manual_seed(opt.seed) 206 | slurm.init_distributed_mode(opt) 207 | slurm.init_signal_handler() 208 | 209 | directory_exists = os.path.isdir(opt.output_dir) 210 | if dist.is_initialized(): 211 | dist.barrier() 212 | os.makedirs(opt.output_dir, exist_ok=True) 213 | if not directory_exists and dist_utils.is_main(): 214 | options.print_options(opt) 215 | if dist.is_initialized(): 216 | dist.barrier() 217 | 
utils.init_logger(opt) 218 | 219 | step = 0 220 | 221 | retriever, tokenizer, retriever_model_id = contriever.load_retriever(opt.model_path, opt.pooling, opt.random_init) 222 | opt.retriever_model_id = retriever_model_id 223 | model = inbatch.InBatch(opt, retriever, tokenizer) 224 | 225 | model = model.cuda() 226 | 227 | optimizer, scheduler = utils.set_optim(opt, model) 228 | # if dist_utils.is_main(): 229 | # utils.save(model, optimizer, scheduler, global_step, 0., opt, opt.output_dir, f"step-{0}") 230 | logger.info(utils.get_parameters(model)) 231 | 232 | for name, module in model.named_modules(): 233 | if isinstance(module, torch.nn.Dropout): 234 | module.p = opt.dropout 235 | 236 | if torch.distributed.is_initialized(): 237 | model = torch.nn.parallel.DistributedDataParallel( 238 | model, 239 | device_ids=[opt.local_rank], 240 | output_device=opt.local_rank, 241 | find_unused_parameters=False, 242 | ) 243 | 244 | logger.info("Start training") 245 | finetuning(opt, model, optimizer, scheduler, tokenizer, step) 246 | 247 | 248 | if __name__ == "__main__": 249 | main() 250 | -------------------------------------------------------------------------------- /generate_passage_embeddings.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | import os 8 | 9 | import argparse 10 | import csv 11 | import logging 12 | import pickle 13 | 14 | import numpy as np 15 | import torch 16 | 17 | import transformers 18 | 19 | import src.slurm 20 | import src.contriever 21 | import src.utils 22 | import src.data 23 | import src.normalize_text 24 | 25 | 26 | def embed_passages(args, passages, model, tokenizer): 27 | total = 0 28 | allids, allembeddings = [], [] 29 | batch_ids, batch_text = [], [] 30 | with torch.no_grad(): 31 | for k, p in enumerate(passages): 32 | batch_ids.append(p["id"]) 33 | if args.no_title or not "title" in p: 34 | text = p["text"] 35 | else: 36 | text = p["title"] + " " + p["text"] 37 | if args.lowercase: 38 | text = text.lower() 39 | if args.normalize_text: 40 | text = src.normalize_text.normalize(text) 41 | batch_text.append(text) 42 | 43 | if len(batch_text) == args.per_gpu_batch_size or k == len(passages) - 1: 44 | 45 | encoded_batch = tokenizer.batch_encode_plus( 46 | batch_text, 47 | return_tensors="pt", 48 | max_length=args.passage_maxlength, 49 | padding=True, 50 | truncation=True, 51 | ) 52 | 53 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 54 | embeddings = model(**encoded_batch) 55 | 56 | embeddings = embeddings.cpu() 57 | total += len(batch_ids) 58 | allids.extend(batch_ids) 59 | allembeddings.append(embeddings) 60 | 61 | batch_text = [] 62 | batch_ids = [] 63 | if k % 100000 == 0 and k > 0: 64 | print(f"Encoded passages {total}") 65 | 66 | allembeddings = torch.cat(allembeddings, dim=0).numpy() 67 | return allids, allembeddings 68 | 69 | 70 | def main(args): 71 | model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path) 72 | print(f"Model loaded from {args.model_name_or_path}.", flush=True) 73 | model.eval() 74 | model = model.cuda() 75 | if not args.no_fp16: 76 | model = model.half() 77 | 78 | passages = src.data.load_passages(args.passages) 79 | 80 | shard_size = len(passages) // args.num_shards 81 | start_idx = args.shard_id * shard_size 82 | end_idx = start_idx + shard_size 83 | 
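# The integer division above can leave a remainder of passages; the last shard
# absorbs it below, so every passage is embedded exactly once across all shards.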
if args.shard_id == args.num_shards - 1: 84 | end_idx = len(passages) 85 | 86 | passages = passages[start_idx:end_idx] 87 | print(f"Embedding generation for {len(passages)} passages from idx {start_idx} to {end_idx}.") 88 | 89 | allids, allembeddings = embed_passages(args, passages, model, tokenizer) 90 | 91 | save_file = os.path.join(args.output_dir, args.prefix + f"_{args.shard_id:02d}") 92 | os.makedirs(args.output_dir, exist_ok=True) 93 | print(f"Saving {len(allids)} passage embeddings to {save_file}.") 94 | with open(save_file, mode="wb") as f: 95 | pickle.dump((allids, allembeddings), f) 96 | 97 | print(f"Total passages processed {len(allids)}. Written to {save_file}.") 98 | 99 | 100 | if __name__ == "__main__": 101 | parser = argparse.ArgumentParser() 102 | 103 | parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)") 104 | parser.add_argument("--output_dir", type=str, default="wikipedia_embeddings", help="dir path to save embeddings") 105 | parser.add_argument("--prefix", type=str, default="passages", help="prefix path to save embeddings") 106 | parser.add_argument("--shard_id", type=int, default=0, help="Id of the current shard") 107 | parser.add_argument("--num_shards", type=int, default=1, help="Total number of shards") 108 | parser.add_argument( 109 | "--per_gpu_batch_size", type=int, default=512, help="Batch size for the passage encoder forward pass" 110 | ) 111 | parser.add_argument("--passage_maxlength", type=int, default=512, help="Maximum number of tokens in a passage") 112 | parser.add_argument( 113 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 114 | ) 115 | parser.add_argument("--no_fp16", action="store_true", help="inference in fp32") 116 | parser.add_argument("--no_title", action="store_true", help="title not added to the passage body") 117 | parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding") 118 | parser.add_argument("--normalize_text", action="store_true", help="lowercase text before encoding") 119 | 120 | args = parser.parse_args() 121 | 122 | src.slurm.init_distributed_mode(args) 123 | 124 | main(args) 125 | -------------------------------------------------------------------------------- /passage_retrieval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | import os 8 | import argparse 9 | import csv 10 | import json 11 | import logging 12 | import pickle 13 | import time 14 | import glob 15 | from pathlib import Path 16 | 17 | import numpy as np 18 | import torch 19 | import transformers 20 | 21 | import src.index 22 | import src.contriever 23 | import src.utils 24 | import src.slurm 25 | import src.data 26 | from src.evaluation import calculate_matches 27 | import src.normalize_text 28 | 29 | os.environ["TOKENIZERS_PARALLELISM"] = "true" 30 | 31 | 32 | def embed_queries(args, queries, model, tokenizer): 33 | model.eval() 34 | embeddings, batch_question = [], [] 35 | with torch.no_grad(): 36 | 37 | for k, q in enumerate(queries): 38 | if args.lowercase: 39 | q = q.lower() 40 | if args.normalize_text: 41 | q = src.normalize_text.normalize(q) 42 | batch_question.append(q) 43 | 44 | if len(batch_question) == args.per_gpu_batch_size or k == len(queries) - 1: 45 | 46 | encoded_batch = tokenizer.batch_encode_plus( 47 | batch_question, 48 | return_tensors="pt", 49 | max_length=args.question_maxlength, 50 | padding=True, 51 | truncation=True, 52 | ) 53 | encoded_batch = {k: v.cuda() for k, v in encoded_batch.items()} 54 | output = model(**encoded_batch) 55 | embeddings.append(output.cpu()) 56 | 57 | batch_question = [] 58 | 59 | embeddings = torch.cat(embeddings, dim=0) 60 | print(f"Questions embeddings shape: {embeddings.size()}") 61 | 62 | return embeddings.numpy() 63 | 64 | 65 | def index_encoded_data(index, embedding_files, indexing_batch_size): 66 | allids = [] 67 | allembeddings = np.array([]) 68 | for i, file_path in enumerate(embedding_files): 69 | print(f"Loading file {file_path}") 70 | with open(file_path, "rb") as fin: 71 | ids, embeddings = pickle.load(fin) 72 | 73 | allembeddings = np.vstack((allembeddings, embeddings)) if allembeddings.size else embeddings 74 | allids.extend(ids) 75 | while allembeddings.shape[0] > indexing_batch_size: 76 | allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size) 77 | 78 | while allembeddings.shape[0] > 0: 79 | allembeddings, allids = add_embeddings(index, allembeddings, allids, indexing_batch_size) 80 | 81 | print("Data indexing completed.") 82 | 83 | 84 | def add_embeddings(index, embeddings, ids, indexing_batch_size): 85 | end_idx = min(indexing_batch_size, embeddings.shape[0]) 86 | ids_toadd = ids[:end_idx] 87 | embeddings_toadd = embeddings[:end_idx] 88 | ids = ids[end_idx:] 89 | embeddings = embeddings[end_idx:] 90 | index.index_data(ids_toadd, embeddings_toadd) 91 | return embeddings, ids 92 | 93 | 94 | def validate(data, workers_num): 95 | match_stats = calculate_matches(data, workers_num) 96 | top_k_hits = match_stats.top_k_hits 97 | 98 | print("Validation results: top k documents hits %s", top_k_hits) 99 | top_k_hits = [v / len(data) for v in top_k_hits] 100 | message = "" 101 | for k in [5, 10, 20, 100]: 102 | if k <= len(top_k_hits): 103 | message += f"R@{k}: {top_k_hits[k-1]} " 104 | print(message) 105 | return match_stats.questions_doc_hits 106 | 107 | 108 | def add_passages(data, passages, top_passages_and_scores): 109 | # add passages to original data 110 | merged_data = [] 111 | assert len(data) == len(top_passages_and_scores) 112 | for i, d in enumerate(data): 113 | results_and_scores = top_passages_and_scores[i] 114 | docs = [passages[doc_id] for doc_id in results_and_scores[0]] 115 | scores = [str(score) for score in results_and_scores[1]] 116 | ctxs_num = len(docs) 117 | d["ctxs"] = [ 118 | { 119 | "id": results_and_scores[0][c], 120 | 
"title": docs[c]["title"], 121 | "text": docs[c]["text"], 122 | "score": scores[c], 123 | } 124 | for c in range(ctxs_num) 125 | ] 126 | 127 | 128 | def add_hasanswer(data, hasanswer): 129 | # add hasanswer to data 130 | for i, ex in enumerate(data): 131 | for k, d in enumerate(ex["ctxs"]): 132 | d["hasanswer"] = hasanswer[i][k] 133 | 134 | 135 | def load_data(data_path): 136 | if data_path.endswith(".json"): 137 | with open(data_path, "r") as fin: 138 | data = json.load(fin) 139 | elif data_path.endswith(".jsonl"): 140 | data = [] 141 | with open(data_path, "r") as fin: 142 | for k, example in enumerate(fin): 143 | example = json.loads(example) 144 | data.append(example) 145 | return data 146 | 147 | 148 | def main(args): 149 | 150 | print(f"Loading model from: {args.model_name_or_path}") 151 | model, tokenizer, _ = src.contriever.load_retriever(args.model_name_or_path) 152 | model.eval() 153 | model = model.cuda() 154 | if not args.no_fp16: 155 | model = model.half() 156 | 157 | index = src.index.Indexer(args.projection_size, args.n_subquantizers, args.n_bits) 158 | 159 | # index all passages 160 | input_paths = glob.glob(args.passages_embeddings) 161 | input_paths = sorted(input_paths) 162 | embeddings_dir = os.path.dirname(input_paths[0]) 163 | index_path = os.path.join(embeddings_dir, "index.faiss") 164 | if args.save_or_load_index and os.path.exists(index_path): 165 | index.deserialize_from(embeddings_dir) 166 | else: 167 | print(f"Indexing passages from files {input_paths}") 168 | start_time_indexing = time.time() 169 | index_encoded_data(index, input_paths, args.indexing_batch_size) 170 | print(f"Indexing time: {time.time()-start_time_indexing:.1f} s.") 171 | if args.save_or_load_index: 172 | index.serialize(embeddings_dir) 173 | 174 | # load passages 175 | passages = src.data.load_passages(args.passages) 176 | passage_id_map = {x["id"]: x for x in passages} 177 | 178 | data_paths = glob.glob(args.data) 179 | alldata = [] 180 | for path in data_paths: 181 | data = load_data(path) 182 | output_path = os.path.join(args.output_dir, os.path.basename(path)) 183 | 184 | queries = [ex["question"] for ex in data] 185 | questions_embedding = embed_queries(args, queries, model, tokenizer) 186 | 187 | # get top k results 188 | start_time_retrieval = time.time() 189 | top_ids_and_scores = index.search_knn(questions_embedding, args.n_docs) 190 | print(f"Search time: {time.time()-start_time_retrieval:.1f} s.") 191 | 192 | add_passages(data, passage_id_map, top_ids_and_scores) 193 | hasanswer = validate(data, args.validation_workers) 194 | add_hasanswer(data, hasanswer) 195 | os.makedirs(os.path.dirname(output_path), exist_ok=True) 196 | with open(output_path, "w") as fout: 197 | for ex in data: 198 | json.dump(ex, fout, ensure_ascii=False) 199 | fout.write("\n") 200 | print(f"Saved results to {output_path}") 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser() 205 | 206 | parser.add_argument( 207 | "--data", 208 | required=True, 209 | type=str, 210 | default=None, 211 | help=".json file containing question and answers, similar format to reader data", 212 | ) 213 | parser.add_argument("--passages", type=str, default=None, help="Path to passages (.tsv file)") 214 | parser.add_argument("--passages_embeddings", type=str, default=None, help="Glob path to encoded passages") 215 | parser.add_argument( 216 | "--output_dir", type=str, default=None, help="Results are written to outputdir with data suffix" 217 | ) 218 | parser.add_argument("--n_docs", type=int, default=100, 
help="Number of documents to retrieve per questions") 219 | parser.add_argument( 220 | "--validation_workers", type=int, default=32, help="Number of parallel processes to validate results" 221 | ) 222 | parser.add_argument("--per_gpu_batch_size", type=int, default=64, help="Batch size for question encoding") 223 | parser.add_argument( 224 | "--save_or_load_index", action="store_true", help="If enabled, save index and load index if it exists" 225 | ) 226 | parser.add_argument( 227 | "--model_name_or_path", type=str, help="path to directory containing model weights and config file" 228 | ) 229 | parser.add_argument("--no_fp16", action="store_true", help="inference in fp32") 230 | parser.add_argument("--question_maxlength", type=int, default=512, help="Maximum number of tokens in a question") 231 | parser.add_argument( 232 | "--indexing_batch_size", type=int, default=1000000, help="Batch size of the number of passages indexed" 233 | ) 234 | parser.add_argument("--projection_size", type=int, default=768) 235 | parser.add_argument( 236 | "--n_subquantizers", 237 | type=int, 238 | default=0, 239 | help="Number of subquantizer used for vector quantization, if 0 flat index is used", 240 | ) 241 | parser.add_argument("--n_bits", type=int, default=8, help="Number of bits per subquantizer") 242 | parser.add_argument("--lang", nargs="+") 243 | parser.add_argument("--dataset", type=str, default="none") 244 | parser.add_argument("--lowercase", action="store_true", help="lowercase text before encoding") 245 | parser.add_argument("--normalize_text", action="store_true", help="normalize text") 246 | 247 | args = parser.parse_args() 248 | src.slurm.init_distributed_mode(args) 249 | main(args) 250 | -------------------------------------------------------------------------------- /preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | import argparse 5 | import torch 6 | 7 | import transformers 8 | from src.normalize_text import normalize 9 | 10 | 11 | def save(tensor, split_path): 12 | if not os.path.exists(os.path.dirname(split_path)): 13 | os.makedirs(os.path.dirname(split_path)) 14 | with open(split_path, 'wb') as fout: 15 | torch.save(tensor, fout) 16 | 17 | def apply_tokenizer(path, tokenizer, normalize_text=False): 18 | alltokens = [] 19 | lines = [] 20 | with open(path, "r", encoding="utf-8") as fin: 21 | for k, line in enumerate(fin): 22 | if normalize_text: 23 | line = normalize(line) 24 | 25 | lines.append(line) 26 | if len(lines) > 1000000: 27 | tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids'] 28 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens] 29 | alltokens.extend(tokens) 30 | lines = [] 31 | 32 | tokens = tokenizer.batch_encode_plus(lines, add_special_tokens=False)['input_ids'] 33 | tokens = [torch.tensor(x, dtype=torch.int) for x in tokens] 34 | alltokens.extend(tokens) 35 | 36 | alltokens = torch.cat(alltokens) 37 | return alltokens 38 | 39 | def tokenize_file(args): 40 | filename = os.path.basename(args.datapath) 41 | savepath = os.path.join(args.outdir, f"{filename}.pkl") 42 | if os.path.exists(savepath): 43 | if args.overwrite: 44 | print(f"File {savepath} already exists, overwriting") 45 | else: 46 | print(f"File {savepath} already exists, exiting") 47 | return 48 | try: 49 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=True) 50 | except: 51 | tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer, local_files_only=False) 52 | print(f"Encoding {args.datapath}...") 53 | tokens = apply_tokenizer(args.datapath, tokenizer, normalize_text=args.normalize_text) 54 | 55 | print(f"Saving at {savepath}...") 56 | save(tokens, savepath) 57 | 58 | 59 | if __name__ == '__main__': 60 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 61 | parser.add_argument("--datapath", type=str) 62 | parser.add_argument("--outdir", type=str) 63 | parser.add_argument("--tokenizer", type=str) 64 | parser.add_argument("--overwrite", action="store_true") 65 | parser.add_argument("--normalize_text", action="store_true") 66 | 67 | args, _ = parser.parse_known_args() 68 | tokenize_file(args) 69 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | transformers==4.18.0 3 | beir==1.0.0 4 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/contriever/39fb2201450cdba1648183737ed56d6b1bc33778/src/__init__.py -------------------------------------------------------------------------------- /src/beir_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import os 4 | from collections import defaultdict 5 | from typing import List, Dict 6 | import numpy as np 7 | import torch 8 | import torch.distributed as dist 9 | 10 | import beir.util 11 | from beir.datasets.data_loader import GenericDataLoader 12 | from beir.retrieval.evaluation import EvaluateRetrieval 13 | from beir.retrieval.search.dense import DenseRetrievalExactSearch 14 | 15 | from beir.reranking.models import CrossEncoder 16 | from beir.reranking import Rerank 17 | 18 | import src.dist_utils as dist_utils 19 | from src import normalize_text 20 | 21 | 22 | class DenseEncoderModel: 23 | def __init__( 24 | self, 25 | query_encoder, 26 | doc_encoder=None, 27 | tokenizer=None, 28 | max_length=512, 29 | add_special_tokens=True, 30 | norm_query=False, 31 | norm_doc=False, 32 | lower_case=False, 33 | normalize_text=False, 34 | **kwargs, 35 | ): 36 | self.query_encoder = query_encoder 37 | self.doc_encoder = doc_encoder 38 | self.tokenizer = tokenizer 39 | self.max_length = max_length 40 | self.add_special_tokens = add_special_tokens 41 | self.norm_query = norm_query 42 | self.norm_doc = norm_doc 43 | self.lower_case = lower_case 44 | self.normalize_text = normalize_text 45 | 46 | def encode_queries(self, queries: List[str], batch_size: int, **kwargs) -> np.ndarray: 47 | 48 | if dist.is_initialized(): 49 | idx = np.array_split(range(len(queries)), dist.get_world_size())[dist.get_rank()] 50 | else: 51 | idx = range(len(queries)) 52 | 53 | queries = [queries[i] for i in idx] 54 | if self.normalize_text: 55 | queries = [normalize_text.normalize(q) for q in queries] 56 | if self.lower_case: 57 | queries = [q.lower() for q in queries] 58 | 59 | allemb = [] 60 | nbatch = (len(queries) - 1) // batch_size + 1 61 | with torch.no_grad(): 62 | for k in range(nbatch): 63 | start_idx = k * batch_size 64 | end_idx = min((k + 1) * batch_size, len(queries)) 65 | 66 | qencode = self.tokenizer.batch_encode_plus( 67 | queries[start_idx:end_idx], 68 | max_length=self.max_length, 69 | padding=True, 70 | truncation=True, 71 | add_special_tokens=self.add_special_tokens, 72 | return_tensors="pt", 73 | ) 74 | qencode = {key: value.cuda() for key, value in qencode.items()} 75 | emb = self.query_encoder(**qencode, normalize=self.norm_query) 76 | allemb.append(emb.cpu()) 77 | 78 | allemb = torch.cat(allemb, dim=0) 79 | allemb = allemb.cuda() 80 | if dist.is_initialized(): 81 | allemb = dist_utils.varsize_gather_nograd(allemb) 82 | allemb = allemb.cpu().numpy() 83 | return allemb 84 | 85 | def encode_corpus(self, corpus: List[Dict[str, str]], batch_size: int, **kwargs): 86 | 87 | if dist.is_initialized(): 88 | idx = np.array_split(range(len(corpus)), dist.get_world_size())[dist.get_rank()] 89 | else: 90 | idx = range(len(corpus)) 91 | corpus = [corpus[i] for i in idx] 92 | corpus = [c["title"] + " " + c["text"] if len(c["title"]) > 0 else c["text"] for c in corpus] 93 | if self.normalize_text: 94 | corpus = [normalize_text.normalize(c) for c in corpus] 95 | if self.lower_case: 96 | corpus = [c.lower() for c in corpus] 97 | 98 | allemb = [] 99 | nbatch = (len(corpus) - 1) // batch_size + 1 100 | with torch.no_grad(): 101 | for k in range(nbatch): 102 | start_idx = k * batch_size 103 | end_idx = min((k + 1) * batch_size, len(corpus)) 104 | 105 | cencode = self.tokenizer.batch_encode_plus( 106 | corpus[start_idx:end_idx], 107 | max_length=self.max_length, 108 | padding=True, 109 | truncation=True, 110 | add_special_tokens=self.add_special_tokens, 111 | return_tensors="pt", 112 | ) 113 | 
cencode = {key: value.cuda() for key, value in cencode.items()} 114 | emb = self.doc_encoder(**cencode, normalize=self.norm_doc) 115 | allemb.append(emb.cpu()) 116 | 117 | allemb = torch.cat(allemb, dim=0) 118 | allemb = allemb.cuda() 119 | if dist.is_initialized(): 120 | allemb = dist_utils.varsize_gather_nograd(allemb) 121 | allemb = allemb.cpu().numpy() 122 | return allemb 123 | 124 | 125 | def evaluate_model( 126 | query_encoder, 127 | doc_encoder, 128 | tokenizer, 129 | dataset, 130 | batch_size=128, 131 | add_special_tokens=True, 132 | norm_query=False, 133 | norm_doc=False, 134 | is_main=True, 135 | split="test", 136 | score_function="dot", 137 | beir_dir="BEIR/datasets", 138 | save_results_path=None, 139 | lower_case=False, 140 | normalize_text=False, 141 | ): 142 | 143 | metrics = defaultdict(list)  # store final results 144 | 145 | if hasattr(query_encoder, "module"): 146 | query_encoder = query_encoder.module 147 | query_encoder.eval() 148 | 149 | if doc_encoder is not None: 150 | if hasattr(doc_encoder, "module"): 151 | doc_encoder = doc_encoder.module 152 | doc_encoder.eval() 153 | else: 154 | doc_encoder = query_encoder 155 | 156 | dmodel = DenseRetrievalExactSearch( 157 | DenseEncoderModel( 158 | query_encoder=query_encoder, 159 | doc_encoder=doc_encoder, 160 | tokenizer=tokenizer, 161 | add_special_tokens=add_special_tokens, 162 | norm_query=norm_query, 163 | norm_doc=norm_doc, 164 | lower_case=lower_case, 165 | normalize_text=normalize_text, 166 | ), 167 | batch_size=batch_size, 168 | ) 169 | retriever = EvaluateRetrieval(dmodel, score_function=score_function) 170 | data_path = os.path.join(beir_dir, dataset) 171 | 172 | if not os.path.isdir(data_path) and is_main: 173 | url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset) 174 | data_path = beir.util.download_and_unzip(url, beir_dir) 175 | dist_utils.barrier() 176 | 177 | if dataset != "cqadupstack": 178 | corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split=split) 179 | results = retriever.retrieve(corpus, queries) 180 | if is_main: 181 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 182 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 183 | if isinstance(metric, str): 184 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 185 | for key, value in metric.items(): 186 | metrics[key].append(value) 187 | if save_results_path is not None: 188 | torch.save(results, f"{save_results_path}") 189 | else:  # cqadupstack: compute macroaverage over its sub-datasets 190 | paths = [entry.path for entry in os.scandir(data_path) if entry.is_dir()] 191 | for path in paths: 192 | corpus, queries, qrels = GenericDataLoader(data_folder=path).load(split=split) 193 | results = retriever.retrieve(corpus, queries) 194 | if is_main: 195 | ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values) 196 | for metric in (ndcg, _map, recall, precision, "mrr", "recall_cap", "hole"): 197 | if isinstance(metric, str): 198 | metric = retriever.evaluate_custom(qrels, results, retriever.k_values, metric=metric) 199 | for key, value in metric.items(): 200 | metrics[key].append(value) 201 | for key, value in metrics.items(): 202 | assert ( 203 | len(value) == 12 204 | ), f"cqadupstack includes 12 datasets, only {len(value)} values were computed for the {key} metric" 205 | 206 | metrics = {key: 100 * np.mean(value) for key, value in metrics.items()} 207 | 208 | return metrics 209 | 
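A minimal usage sketch for evaluate_model above. The checkpoint name "facebook/contriever" and the dataset "scifact" are illustrative choices, not values mandated by this file, and single-GPU execution is assumed:

import src.contriever
import src.beir_utils

# load_retriever returns (model, tokenizer, model_id); see src/contriever.py
model, tokenizer, _ = src.contriever.load_retriever("facebook/contriever")
model = model.cuda()

metrics = src.beir_utils.evaluate_model(
    query_encoder=model,
    doc_encoder=model,  # Contriever uses a single shared encoder for queries and docs
    tokenizer=tokenizer,
    dataset="scifact",  # downloaded into beir_dir on first use
    batch_size=128,
    beir_dir="BEIR/datasets",
)
print(metrics["NDCG@10"])  # values are reported as percentages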
-------------------------------------------------------------------------------- /src/contriever.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import torch 5 | import transformers 6 | from transformers import BertModel, XLMRobertaModel 7 | 8 | from src import utils 9 | 10 | 11 | class Contriever(BertModel): 12 | def __init__(self, config, pooling="average", **kwargs): 13 | super().__init__(config, add_pooling_layer=False) 14 | if not hasattr(config, "pooling"): 15 | self.config.pooling = pooling 16 | 17 | def forward( 18 | self, 19 | input_ids=None, 20 | attention_mask=None, 21 | token_type_ids=None, 22 | position_ids=None, 23 | head_mask=None, 24 | inputs_embeds=None, 25 | encoder_hidden_states=None, 26 | encoder_attention_mask=None, 27 | output_attentions=None, 28 | output_hidden_states=None, 29 | normalize=False, 30 | ): 31 | 32 | model_output = super().forward( 33 | input_ids=input_ids, 34 | attention_mask=attention_mask, 35 | token_type_ids=token_type_ids, 36 | position_ids=position_ids, 37 | head_mask=head_mask, 38 | inputs_embeds=inputs_embeds, 39 | encoder_hidden_states=encoder_hidden_states, 40 | encoder_attention_mask=encoder_attention_mask, 41 | output_attentions=output_attentions, 42 | output_hidden_states=output_hidden_states, 43 | ) 44 | 45 | last_hidden = model_output["last_hidden_state"] 46 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 47 | 48 | if self.config.pooling == "average": 49 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 50 | elif self.config.pooling == "cls": 51 | emb = last_hidden[:, 0] 52 | 53 | if normalize: 54 | emb = torch.nn.functional.normalize(emb, dim=-1) 55 | return emb 56 | 57 | 58 | class XLMRetriever(XLMRobertaModel): 59 | def __init__(self, config, pooling="average", **kwargs): 60 | super().__init__(config, add_pooling_layer=False) 61 | if not hasattr(config, "pooling"): 62 | self.config.pooling = pooling 63 | 64 | def forward( 65 | self, 66 | input_ids=None, 67 | attention_mask=None, 68 | token_type_ids=None, 69 | position_ids=None, 70 | head_mask=None, 71 | inputs_embeds=None, 72 | encoder_hidden_states=None, 73 | encoder_attention_mask=None, 74 | output_attentions=None, 75 | output_hidden_states=None, 76 | normalize=False, 77 | ): 78 | 79 | model_output = super().forward( 80 | input_ids=input_ids, 81 | attention_mask=attention_mask, 82 | token_type_ids=token_type_ids, 83 | position_ids=position_ids, 84 | head_mask=head_mask, 85 | inputs_embeds=inputs_embeds, 86 | encoder_hidden_states=encoder_hidden_states, 87 | encoder_attention_mask=encoder_attention_mask, 88 | output_attentions=output_attentions, 89 | output_hidden_states=output_hidden_states, 90 | ) 91 | 92 | last_hidden = model_output["last_hidden_state"] 93 | last_hidden = last_hidden.masked_fill(~attention_mask[..., None].bool(), 0.0) 94 | if self.config.pooling == "average": 95 | emb = last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None] 96 | elif self.config.pooling == "cls": 97 | emb = last_hidden[:, 0] 98 | if normalize: 99 | emb = torch.nn.functional.normalize(emb, dim=-1) 100 | return emb 101 | 102 | 103 | def load_retriever(model_path, pooling="average", random_init=False): 104 | # try: check if model exists locally 105 | path = os.path.join(model_path, "checkpoint.pth") 106 | if os.path.exists(path): 107 | pretrained_dict = torch.load(path, map_location="cpu") 108 | opt = 
pretrained_dict["opt"] 109 | if hasattr(opt, "retriever_model_id"): 110 | retriever_model_id = opt.retriever_model_id 111 | else: 112 | # retriever_model_id = "bert-base-uncased" 113 | retriever_model_id = "bert-base-multilingual-cased" 114 | tokenizer = utils.load_hf(transformers.AutoTokenizer, retriever_model_id) 115 | cfg = utils.load_hf(transformers.AutoConfig, retriever_model_id) 116 | if "xlm" in retriever_model_id: 117 | model_class = XLMRetriever 118 | else: 119 | model_class = Contriever 120 | retriever = model_class(cfg) 121 | pretrained_dict = pretrained_dict["model"] 122 | 123 | if any("encoder_q." in key for key in pretrained_dict.keys()): # test if model is defined with moco class 124 | pretrained_dict = {k.replace("encoder_q.", ""): v for k, v in pretrained_dict.items() if "encoder_q." in k} 125 | elif any("encoder." in key for key in pretrained_dict.keys()): # test if model is defined with inbatch class 126 | pretrained_dict = {k.replace("encoder.", ""): v for k, v in pretrained_dict.items() if "encoder." in k} 127 | retriever.load_state_dict(pretrained_dict, strict=False) 128 | else: 129 | retriever_model_id = model_path 130 | if "xlm" in retriever_model_id: 131 | model_class = XLMRetriever 132 | else: 133 | model_class = Contriever 134 | cfg = utils.load_hf(transformers.AutoConfig, model_path) 135 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_path) 136 | retriever = utils.load_hf(model_class, model_path) 137 | 138 | return retriever, tokenizer, retriever_model_id 139 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import os 4 | import glob 5 | import torch 6 | import random 7 | import json 8 | import csv 9 | import numpy as np 10 | import numpy.random 11 | import logging 12 | from collections import defaultdict 13 | import torch.distributed as dist 14 | 15 | from src import dist_utils 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | def load_data(opt, tokenizer): 21 | datasets = {} 22 | for path in opt.train_data: 23 | data = load_dataset(path, opt.loading_mode) 24 | if data is not None: 25 | datasets[path] = Dataset(data, opt.chunk_length, tokenizer, opt) 26 | dataset = MultiDataset(datasets) 27 | dataset.set_prob(coeff=opt.sampling_coefficient) 28 | return dataset 29 | 30 | 31 | def load_dataset(data_path, loading_mode): 32 | files = glob.glob(os.path.join(data_path, "*.p*")) 33 | files.sort() 34 | tensors = [] 35 | if loading_mode == "split": 36 | files_split = list(np.array_split(files, dist_utils.get_world_size()))[dist_utils.get_rank()] 37 | for filepath in files_split: 38 | try: 39 | tensors.append(torch.load(filepath, map_location="cpu")) 40 | except: 41 | logger.warning(f"Unable to load file {filepath}") 42 | elif loading_mode == "full": 43 | for fin in files: 44 | tensors.append(torch.load(fin, map_location="cpu")) 45 | elif loading_mode == "single": 46 | tensors.append(torch.load(files[0], map_location="cpu")) 47 | if len(tensors) == 0: 48 | return None 49 | tensor = torch.cat(tensors) 50 | return tensor 51 | 52 | 53 | class MultiDataset(torch.utils.data.Dataset): 54 | def __init__(self, datasets): 55 | 56 | self.datasets = datasets 57 | self.prob = [1 / len(self.datasets) for _ in self.datasets] 58 | self.dataset_ids = list(self.datasets.keys()) 59 | 60 | def __len__(self): 61 | return sum([len(dataset) for dataset in 
self.datasets.values()]) 62 | 63 | def __getitem__(self, index): 64 | dataset_idx = numpy.random.choice(range(len(self.prob)), 1, p=self.prob)[0] 65 | did = self.dataset_ids[dataset_idx] 66 | index = random.randint(0, len(self.datasets[did]) - 1) 67 | sample = self.datasets[did][index] 68 | sample["dataset_id"] = did 69 | return sample 70 | 71 | def generate_offset(self): 72 | for dataset in self.datasets.values(): 73 | dataset.generate_offset() 74 | 75 | def set_prob(self, coeff=0.0): 76 | 77 | prob = np.array([float(len(dataset)) for _, dataset in self.datasets.items()]) 78 | prob /= prob.sum() 79 | prob = np.array([p**coeff for p in prob]) 80 | prob /= prob.sum() 81 | self.prob = prob 82 | 83 | 84 | class Dataset(torch.utils.data.Dataset): 85 | """Monolingual dataset based on a list of paths""" 86 | 87 | def __init__(self, data, chunk_length, tokenizer, opt): 88 | 89 | self.data = data 90 | self.chunk_length = chunk_length 91 | self.tokenizer = tokenizer 92 | self.opt = opt 93 | self.generate_offset() 94 | 95 | def __len__(self): 96 | return (self.data.size(0) - self.offset) // self.chunk_length 97 | 98 | def __getitem__(self, index): 99 | start_idx = self.offset + index * self.chunk_length 100 | end_idx = start_idx + self.chunk_length 101 | tokens = self.data[start_idx:end_idx] 102 | q_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 103 | k_tokens = randomcrop(tokens, self.opt.ratio_min, self.opt.ratio_max) 104 | q_tokens = apply_augmentation(q_tokens, self.opt) 105 | q_tokens = add_bos_eos(q_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 106 | k_tokens = apply_augmentation(k_tokens, self.opt) 107 | k_tokens = add_bos_eos(k_tokens, self.tokenizer.bos_token_id, self.tokenizer.eos_token_id) 108 | 109 | return {"q_tokens": q_tokens, "k_tokens": k_tokens} 110 | 111 | def generate_offset(self): 112 | self.offset = random.randint(0, self.chunk_length - 1) 113 | 114 | 115 | class Collator(object): 116 | def __init__(self, opt): 117 | self.opt = opt 118 | 119 | def __call__(self, batch_examples): 120 | 121 | batch = defaultdict(list) 122 | for example in batch_examples: 123 | for k, v in example.items(): 124 | batch[k].append(v) 125 | 126 | q_tokens, q_mask = build_mask(batch["q_tokens"]) 127 | k_tokens, k_mask = build_mask(batch["k_tokens"]) 128 | 129 | batch["q_tokens"] = q_tokens 130 | batch["q_mask"] = q_mask 131 | batch["k_tokens"] = k_tokens 132 | batch["k_mask"] = k_mask 133 | 134 | return batch 135 | 136 | 137 | def randomcrop(x, ratio_min, ratio_max): 138 | 139 | ratio = random.uniform(ratio_min, ratio_max) 140 | length = int(len(x) * ratio) 141 | start = random.randint(0, len(x) - length) 142 | end = start + length 143 | crop = x[start:end].clone() 144 | return crop 145 | 146 | 147 | def build_mask(tensors): 148 | shapes = [x.shape for x in tensors] 149 | maxlength = max([len(x) for x in tensors]) 150 | returnmasks = [] 151 | ids = [] 152 | for k, x in enumerate(tensors): 153 | returnmasks.append(torch.tensor([1] * len(x) + [0] * (maxlength - len(x)))) 154 | ids.append(torch.cat((x, torch.tensor([0] * (maxlength - len(x)))))) 155 | ids = torch.stack(ids, dim=0).long() 156 | returnmasks = torch.stack(returnmasks, dim=0).bool() 157 | return ids, returnmasks 158 | 159 | 160 | def add_token(x, token): 161 | x = torch.cat((torch.tensor([token]), x)) 162 | return x 163 | 164 | 165 | def deleteword(x, p=0.1): 166 | mask = np.random.rand(len(x)) 167 | x = [e for e, m in zip(x, mask) if m > p] 168 | return x 169 | 170 | 171 | def replaceword(x, 
min_random, max_random, p=0.1): 172 | mask = np.random.rand(len(x)) 173 | x = [e if m > p else random.randint(min_random, max_random) for e, m in zip(x, mask)] 174 | return x 175 | 176 | 177 | def maskword(x, mask_id, p=0.1): 178 | mask = np.random.rand(len(x)) 179 | x = [e if m > p else mask_id for e, m in zip(x, mask)] 180 | return x 181 | 182 | 183 | def shuffleword(x, p=0.1): 184 | count = (np.random.rand(len(x)) < p).sum() 185 | """Shuffles any n number of values in a list""" 186 | indices_to_shuffle = random.sample(range(len(x)), k=count) 187 | to_shuffle = [x[i] for i in indices_to_shuffle] 188 | random.shuffle(to_shuffle) 189 | for index, value in enumerate(to_shuffle): 190 | old_index = indices_to_shuffle[index] 191 | x[old_index] = value 192 | return x 193 | 194 | 195 | def apply_augmentation(x, opt): 196 | if opt.augmentation == "mask": 197 | return torch.tensor(maskword(x, mask_id=opt.mask_id, p=opt.prob_augmentation)) 198 | elif opt.augmentation == "replace": 199 | return torch.tensor( 200 | replaceword(x, min_random=opt.start_id, max_random=opt.vocab_size - 1, p=opt.prob_augmentation) 201 | ) 202 | elif opt.augmentation == "delete": 203 | return torch.tensor(deleteword(x, p=opt.prob_augmentation)) 204 | elif opt.augmentation == "shuffle": 205 | return torch.tensor(shuffleword(x, p=opt.prob_augmentation)) 206 | else: 207 | if not isinstance(x, torch.Tensor): 208 | x = torch.Tensor(x) 209 | return x 210 | 211 | 212 | def add_bos_eos(x, bos_token_id, eos_token_id): 213 | if not isinstance(x, torch.Tensor): 214 | x = torch.Tensor(x) 215 | if bos_token_id is None and eos_token_id is not None: 216 | x = torch.cat([x.clone().detach(), torch.tensor([eos_token_id])]) 217 | elif bos_token_id is not None and eos_token_id is None: 218 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach()]) 219 | elif bos_token_id is None and eos_token_id is None: 220 | pass 221 | else: 222 | x = torch.cat([torch.tensor([bos_token_id]), x.clone().detach(), torch.tensor([eos_token_id])]) 223 | return x 224 | 225 | 226 | # Used for passage retrieval 227 | def load_passages(path): 228 | if not os.path.exists(path): 229 | logger.info(f"{path} does not exist") 230 | return 231 | logger.info(f"Loading passages from: {path}") 232 | passages = [] 233 | with open(path) as fin: 234 | if path.endswith(".jsonl"): 235 | for k, line in enumerate(fin): 236 | ex = json.loads(line) 237 | passages.append(ex) 238 | else: 239 | reader = csv.reader(fin, delimiter="\t") 240 | for k, row in enumerate(reader): 241 | if not row[0] == "id": 242 | ex = {"id": row[0], "title": row[2], "text": row[1]} 243 | passages.append(ex) 244 | return passages 245 | -------------------------------------------------------------------------------- /src/dist_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.distributed as dist 5 | 6 | 7 | class Gather(torch.autograd.Function): 8 | @staticmethod 9 | def forward(ctx, x: torch.tensor): 10 | output = [torch.zeros_like(x) for _ in range(dist.get_world_size())] 11 | dist.all_gather(output, x) 12 | return tuple(output) 13 | 14 | @staticmethod 15 | def backward(ctx, *grads): 16 | all_gradients = torch.stack(grads) 17 | dist.all_reduce(all_gradients) 18 | return all_gradients[dist.get_rank()] 19 | 20 | 21 | def gather(x: torch.tensor): 22 | if not dist.is_initialized(): 23 | return x 24 | x_gather = Gather.apply(x) 25 | x_gather = torch.cat(x_gather, dim=0) 26 | return x_gather 27 | 28 | 29 | @torch.no_grad() 30 | def gather_nograd(x: torch.tensor): 31 | if not dist.is_initialized(): 32 | return x 33 | x_gather = [torch.ones_like(x) for _ in range(dist.get_world_size())] 34 | dist.all_gather(x_gather, x, async_op=False) 35 | 36 | x_gather = torch.cat(x_gather, dim=0) 37 | return x_gather 38 | 39 | 40 | @torch.no_grad() 41 | def varsize_gather_nograd(x: torch.Tensor): 42 | """gather tensors of different sizes along the first dimension""" 43 | if not dist.is_initialized(): 44 | return x 45 | 46 | # determine max size 47 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 48 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 49 | dist.all_gather(allsizes, size) 50 | max_size = max([size.cpu().max() for size in allsizes]) 51 | 52 | padded = torch.empty(max_size, *x.shape[1:], dtype=x.dtype, device=x.device) 53 | padded[: x.shape[0]] = x 54 | output = [torch.zeros_like(padded) for _ in range(dist.get_world_size())] 55 | dist.all_gather(output, padded) 56 | 57 | output = [tensor[: allsizes[k]] for k, tensor in enumerate(output)] 58 | output = torch.cat(output, dim=0) 59 | 60 | return output 61 | 62 | 63 | @torch.no_grad() 64 | def get_varsize(x: torch.Tensor): 65 | """gather tensors of different sizes along the first dimension""" 66 | if not dist.is_initialized(): 67 | return [x.shape[0]] 68 | 69 | # determine max size 70 | size = torch.tensor([x.shape[0]], device=x.device, dtype=torch.int) 71 | allsizes = [torch.zeros_like(size) for _ in range(dist.get_world_size())] 72 | dist.all_gather(allsizes, size) 73 | allsizes = torch.cat(allsizes) 74 | return allsizes 75 | 76 | 77 | def get_rank(): 78 | if not dist.is_available(): 79 | return 0 80 | if not dist.is_initialized(): 81 | return 0 82 | return dist.get_rank() 83 | 84 | 85 | def is_main(): 86 | return get_rank() == 0 87 | 88 | 89 | def get_world_size(): 90 | if not dist.is_initialized(): 91 | return 1 92 | else: 93 | return dist.get_world_size() 94 | 95 | 96 | def barrier(): 97 | if dist.is_initialized(): 98 | dist.barrier() 99 | 100 | 101 | def average_main(x): 102 | if not dist.is_initialized(): 103 | return x 104 | if dist.is_initialized() and dist.get_world_size() > 1: 105 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 106 | if is_main(): 107 | x = x / dist.get_world_size() 108 | return x 109 | 110 | 111 | def sum_main(x): 112 | if not dist.is_initialized(): 113 | return x 114 | if dist.is_initialized() and dist.get_world_size() > 1: 115 | dist.reduce(x, 0, op=dist.ReduceOp.SUM) 116 | return x 117 | 118 | 119 | def weighted_average(x, count): 120 | if not dist.is_initialized(): 121 | if isinstance(x, torch.Tensor): 122 | x = x.item() 123 | return x, count 124 | t_loss = torch.tensor([x * count]).cuda() 125 | t_total = torch.tensor([count]).cuda() 126 | t_loss = sum_main(t_loss) 127 | t_total = 
sum_main(t_total) 128 | return (t_loss / t_total).item(), t_total.item() 129 | -------------------------------------------------------------------------------- /src/evaluation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import collections 9 | import logging 10 | import regex 11 | import string 12 | import unicodedata 13 | from functools import partial 14 | from multiprocessing import Pool as ProcessPool 15 | from typing import Tuple, List, Dict 16 | import numpy as np 17 | 18 | """ 19 | Evaluation code from DPR: https://github.com/facebookresearch/DPR 20 | """ 21 | 22 | class SimpleTokenizer(object): 23 | ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+' 24 | NON_WS = r'[^\p{Z}\p{C}]' 25 | 26 | def __init__(self): 27 | """ 28 | Args: 29 | annotators: None or empty set (only tokenizes). 30 | """ 31 | self._regexp = regex.compile( 32 | '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS), 33 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE 34 | ) 35 | 36 | def tokenize(self, text, uncased=False): 37 | matches = [m for m in self._regexp.finditer(text)] 38 | if uncased: 39 | tokens = [m.group().lower() for m in matches] 40 | else: 41 | tokens = [m.group() for m in matches] 42 | return tokens 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits', 'questions_doc_hits']) 47 | 48 | def calculate_matches(data: List, workers_num: int): 49 | """ 50 | Evaluates answers presence in the set of documents. This function is supposed to be used with a large collection of 51 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 52 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 53 | :param answers: list of answers's list. One list per question 54 | :param closest_docs: document ids of the top results along with their scores 55 | :param workers_num: amount of parallel threads to process data 56 | :param match_type: type of answer matching. Refer to has_answer code for available options 57 | :return: matching information tuple. 58 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 59 | valid matches across an entire dataset. 
60 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 61 | """ 62 | 63 | logger.info('Matching answers in top docs...') 64 | 65 | tokenizer = SimpleTokenizer() 66 | get_score_partial = partial(check_answer, tokenizer=tokenizer) 67 | 68 | processes = ProcessPool(processes=workers_num) 69 | scores = processes.map(get_score_partial, data) 70 | 71 | logger.info('Per question validation results len=%d', len(scores)) 72 | 73 | n_docs = len(data[0]['ctxs']) 74 | top_k_hits = [0] * n_docs 75 | for question_hits in scores: 76 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 77 | if best_hit is not None: 78 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 79 | 80 | return QAMatchStats(top_k_hits, scores) 81 | 82 | def check_answer(example, tokenizer) -> List[bool]: 83 | """Search through all the top docs to see if they have any of the answers.""" 84 | answers = example['answers'] 85 | ctxs = example['ctxs'] 86 | 87 | hits = [] 88 | 89 | for i, doc in enumerate(ctxs): 90 | text = doc['text'] 91 | 92 | if text is None:  # cannot find the document for some reason 93 | logger.warning("no doc in db") 94 | hits.append(False) 95 | continue 96 | 97 | hits.append(has_answer(answers, text, tokenizer)) 98 | 99 | return hits 100 | 101 | def has_answer(answers, text, tokenizer) -> bool: 102 | """Check if a document contains an answer string.""" 103 | text = _normalize(text) 104 | text = tokenizer.tokenize(text, uncased=True) 105 | 106 | for answer in answers: 107 | answer = _normalize(answer) 108 | answer = tokenizer.tokenize(answer, uncased=True) 109 | for i in range(0, len(text) - len(answer) + 1): 110 | if answer == text[i: i + len(answer)]: 111 | return True 112 | return False 113 | 114 | ################################################# 115 | ######## READER EVALUATION ######## 116 | ################################################# 117 | 118 | def _normalize(text): 119 | return unicodedata.normalize('NFD', text) 120 | 121 | # Normalization and score functions from SQuAD evaluation script https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/ 122 | def normalize_answer(s): 123 | def remove_articles(text): 124 | return regex.sub(r'\b(a|an|the)\b', ' ', text) 125 | 126 | def white_space_fix(text): 127 | return ' '.join(text.split()) 128 | 129 | def remove_punc(text): 130 | exclude = set(string.punctuation) 131 | return ''.join(ch for ch in text if ch not in exclude) 132 | 133 | def lower(text): 134 | return text.lower() 135 | 136 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 137 | 138 | def em(prediction, ground_truth): 139 | return normalize_answer(prediction) == normalize_answer(ground_truth) 140 | 141 | def f1(prediction, ground_truth): 142 | prediction_tokens = normalize_answer(prediction).split() 143 | ground_truth_tokens = normalize_answer(ground_truth).split() 144 | common = collections.Counter(prediction_tokens) & collections.Counter(ground_truth_tokens) 145 | num_same = sum(common.values()) 146 | if num_same == 0: 147 | return 0 148 | precision = 1.0 * num_same / len(prediction_tokens) 149 | recall = 1.0 * num_same / len(ground_truth_tokens) 150 | f1 = (2 * precision * recall) / (precision + recall) 151 | return f1 152 | 153 | def f1_score(prediction, ground_truths): 154 | return max([f1(prediction, gt) for gt in ground_truths]) 155 | 156 | def exact_match_score(prediction, ground_truths): 157 | return max([em(prediction, gt) for gt in ground_truths]) 158 | 159 | 
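# Illustration (not part of the original file): normalize_answer strips case,
# punctuation and articles, so the reader metrics above treat these as equal:
#   em("The Cat!", "a cat") -> True   (both normalize to "cat")
#   f1("cat", "a cat dog")  -> 0.667  (precision 1.0, recall 0.5)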
#################################################### 160 | ######## RETRIEVER EVALUATION ######## 161 | #################################################### 162 | 163 | def eval_batch(scores, inversions, avg_topk, idx_topk): 164 | for k, s in enumerate(scores): 165 | s = s.cpu().numpy() 166 | sorted_idx = np.argsort(-s) 167 | score(sorted_idx, inversions, avg_topk, idx_topk) 168 | 169 | def count_inversions(arr): 170 | inv_count = 0 171 | lenarr = len(arr) 172 | for i in range(lenarr): 173 | for j in range(i + 1, lenarr): 174 | if (arr[i] > arr[j]): 175 | inv_count += 1 176 | return inv_count 177 | 178 | def score(x, inversions, avg_topk, idx_topk): 179 | x = np.array(x) 180 | inversions.append(count_inversions(x)) 181 | for k in avg_topk: 182 | # ratio of passages in the predicted top-k that are 183 | # also in the topk given by gold score 184 | avg_pred_topk = (x[:k] < k).mean() 185 | avg_topk[k].append(avg_pred_topk) 186 | for k in idx_topk: 187 | below_k = (x < k) 188 | # number of passages required to obtain all passages from gold top-k 189 | idx_gold_topk = len(x) - np.argmax(below_k[::-1]) 190 | idx_topk[k].append(idx_gold_topk) 191 | -------------------------------------------------------------------------------- /src/finetuning_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 2 | 3 | import torch 4 | import random 5 | import json 6 | import numpy as np 7 | 8 | from src import normalize_text 9 | 10 | 11 | class Dataset(torch.utils.data.Dataset): 12 | def __init__( 13 | self, 14 | datapaths, 15 | negative_ctxs=1, 16 | negative_hard_ratio=0.0, 17 | negative_hard_min_idx=0, 18 | training=False, 19 | global_rank=-1, 20 | world_size=-1, 21 | maxload=None, 22 | normalize=False, 23 | ): 24 | self.negative_ctxs = negative_ctxs 25 | self.negative_hard_ratio = negative_hard_ratio 26 | self.negative_hard_min_idx = negative_hard_min_idx 27 | self.training = training 28 | self.normalize_fn = normalize_text.normalize if normalize else (lambda x: x) 29 | self._load_data(datapaths, global_rank, world_size, maxload) 30 | 31 | def __len__(self): 32 | return len(self.data) 33 | 34 | def __getitem__(self, index): 35 | example = self.data[index] 36 | question = example["question"] 37 | if self.training: 38 | gold = random.choice(example["positive_ctxs"]) 39 | 40 | n_hard_negatives, n_random_negatives = self.sample_n_hard_negatives(example) 41 | negatives = [] 42 | if n_random_negatives > 0: 43 | random_negatives = random.sample(example["negative_ctxs"], n_random_negatives) 44 | negatives += random_negatives 45 | if n_hard_negatives > 0: 46 | hard_negatives = random.sample( 47 | example["hard_negative_ctxs"][self.negative_hard_min_idx :], n_hard_negatives 48 | ) 49 | negatives += hard_negatives 50 | else: 51 | gold = example["positive_ctxs"][0] 52 | nidx = 0 53 | if "negative_ctxs" in example: 54 | negatives = [example["negative_ctxs"][nidx]] 55 | else: 56 | negatives = [] 57 | 58 | gold = gold["title"] + " " + gold["text"] if "title" in gold and len(gold["title"]) > 0 else gold["text"] 59 | 60 | negatives = [ 61 | n["title"] + " " + n["text"] if ("title" in n and len(n["title"]) > 0) else n["text"] for n in negatives 62 | ] 63 | 64 | example = { 65 | "query": self.normalize_fn(question), 66 | "gold": self.normalize_fn(gold), 67 | "negatives": [self.normalize_fn(n) for n in negatives], 68 | } 69 | return example 70 | 71 | def _load_data(self, datapaths, global_rank, world_size, maxload): 72 | counter = 0 73 | self.data = [] 74 | for path in datapaths: 75 | path = str(path) 76 | if path.endswith(".jsonl"): 77 | file_data, counter = self._load_data_jsonl(path, global_rank, world_size, counter, maxload) 78 | elif path.endswith(".json"): 79 | file_data, counter = self._load_data_json(path, global_rank, world_size, counter, maxload) 80 | self.data.extend(file_data) 81 | if maxload is not None and maxload > 0 and counter >= maxload: 82 | break 83 | 84 | def _load_data_json(self, path, global_rank, world_size, counter, maxload=None): 85 | examples = [] 86 | with open(path, "r") as fin: 87 | data = json.load(fin) 88 | for example in data: 89 | counter += 1 90 | if global_rank > -1 and not counter % world_size == global_rank: 91 | continue 92 | examples.append(example) 93 | if maxload is not None and maxload > 0 and counter == maxload: 94 | break 95 | 96 | return examples, counter 97 | 98 | def _load_data_jsonl(self, path, global_rank, world_size, counter, maxload=None): 99 | examples = [] 100 | with open(path, "r") as fin: 101 | for line in fin: 102 | counter += 1 103 | if global_rank > -1 and not counter % world_size == global_rank: 104 | continue 105 | example = json.loads(line) 106 | examples.append(example) 107 | if maxload is not None and maxload > 0 and counter == maxload: 108 | break 109 | 110 | return examples, counter 111 | 112 | def sample_n_hard_negatives(self, ex): 113 | 114 | if "hard_negative_ctxs" in ex: 115 | n_hard_negatives = sum([random.random() < self.negative_hard_ratio for _ in range(self.negative_ctxs)]) 116 | n_hard_negatives = min(n_hard_negatives, 
len(ex["hard_negative_ctxs"][self.negative_hard_min_idx :])) 117 | else: 118 | n_hard_negatives = 0 119 | n_random_negatives = self.negative_ctxs - n_hard_negatives 120 | if "negative_ctxs" in ex: 121 | n_random_negatives = min(n_random_negatives, len(ex["negative_ctxs"])) 122 | else: 123 | n_random_negatives = 0 124 | return n_hard_negatives, n_random_negatives 125 | 126 | 127 | class Collator(object): 128 | def __init__(self, tokenizer, passage_maxlength=200): 129 | self.tokenizer = tokenizer 130 | self.passage_maxlength = passage_maxlength 131 | 132 | def __call__(self, batch): 133 | queries = [ex["query"] for ex in batch] 134 | golds = [ex["gold"] for ex in batch] 135 | negs = [item for ex in batch for item in ex["negatives"]] 136 | allpassages = golds + negs 137 | 138 | qout = self.tokenizer.batch_encode_plus( 139 | queries, 140 | max_length=self.passage_maxlength, 141 | truncation=True, 142 | padding=True, 143 | add_special_tokens=True, 144 | return_tensors="pt", 145 | ) 146 | kout = self.tokenizer.batch_encode_plus( 147 | allpassages, 148 | max_length=self.passage_maxlength, 149 | truncation=True, 150 | padding=True, 151 | add_special_tokens=True, 152 | return_tensors="pt", 153 | ) 154 | q_tokens, q_mask = qout["input_ids"], qout["attention_mask"].bool() 155 | k_tokens, k_mask = kout["input_ids"], kout["attention_mask"].bool() 156 | 157 | g_tokens, g_mask = k_tokens[: len(golds)], k_mask[: len(golds)] 158 | n_tokens, n_mask = k_tokens[len(golds) :], k_mask[len(golds) :] 159 | 160 | batch = { 161 | "q_tokens": q_tokens, 162 | "q_mask": q_mask, 163 | "k_tokens": k_tokens, 164 | "k_mask": k_mask, 165 | "g_tokens": g_tokens, 166 | "g_mask": g_mask, 167 | "n_tokens": n_tokens, 168 | "n_mask": n_mask, 169 | } 170 | 171 | return batch 172 | -------------------------------------------------------------------------------- /src/inbatch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | import torch 4 | import torch.nn as nn 5 | import numpy as np 6 | import math 7 | import random 8 | import transformers 9 | import logging 10 | import torch.distributed as dist 11 | 12 | from src import contriever, dist_utils, utils 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class InBatch(nn.Module): 18 | def __init__(self, opt, retriever=None, tokenizer=None): 19 | super(InBatch, self).__init__() 20 | 21 | self.opt = opt 22 | self.norm_doc = opt.norm_doc 23 | self.norm_query = opt.norm_query 24 | self.label_smoothing = opt.label_smoothing 25 | if retriever is None or tokenizer is None: 26 | retriever, tokenizer = self._load_retriever( 27 | opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init 28 | ) 29 | self.tokenizer = tokenizer 30 | self.encoder = retriever 31 | 32 | def _load_retriever(self, model_id, pooling, random_init): 33 | cfg = utils.load_hf(transformers.AutoConfig, model_id) 34 | tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id) 35 | 36 | if "xlm" in model_id: 37 | model_class = contriever.XLMRetriever 38 | else: 39 | model_class = contriever.Contriever 40 | 41 | if random_init: 42 | retriever = model_class(cfg) 43 | else: 44 | retriever = utils.load_hf(model_class, model_id) 45 | 46 | if "bert-" in model_id: 47 | if tokenizer.bos_token_id is None: 48 | tokenizer.bos_token = "[CLS]" 49 | if tokenizer.eos_token_id is None: 50 | tokenizer.eos_token = "[SEP]" 51 | 52 | retriever.config.pooling = pooling 53 | 54 | return retriever, tokenizer 55 | 56 | def get_encoder(self): 57 | return self.encoder 58 | 59 | def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats={}, **kwargs): 60 | 61 | bsz = len(q_tokens) 62 | labels = torch.arange(0, bsz, dtype=torch.long, device=q_tokens.device) 63 | 64 | qemb = self.encoder(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query) 65 | kemb = self.encoder(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc) 66 | 67 | gather_fn = dist_utils.gather 68 | 69 | gather_kemb = gather_fn(kemb) 70 | 71 | labels = labels + dist_utils.get_rank() * len(kemb) 72 | 73 | scores = torch.einsum("id, jd->ij", qemb / self.opt.temperature, gather_kemb) 74 | 75 | loss = torch.nn.functional.cross_entropy(scores, labels, label_smoothing=self.label_smoothing) 76 | 77 | # log stats 78 | if len(stats_prefix) > 0: 79 | stats_prefix = stats_prefix + "/" 80 | iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz) 81 | 82 | predicted_idx = torch.argmax(scores, dim=-1) 83 | accuracy = 100 * (predicted_idx == labels).float().mean() 84 | stdq = torch.std(qemb, dim=0).mean().item() 85 | stdk = torch.std(kemb, dim=0).mean().item() 86 | iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz) 87 | iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz) 88 | iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz) 89 | 90 | return loss, iter_stats 91 | -------------------------------------------------------------------------------- /src/index.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # All rights reserved. 3 | # 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
--------------------------------------------------------------------------------
/src/index.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | import os
8 | import pickle
9 | from typing import List, Tuple
10 | 
11 | import faiss
12 | import numpy as np
13 | from tqdm import tqdm
14 | 
15 | class Indexer(object):
16 | 
17 |     def __init__(self, vector_sz, n_subquantizers=0, n_bits=8):
18 |         if n_subquantizers > 0:
19 |             self.index = faiss.IndexPQ(vector_sz, n_subquantizers, n_bits, faiss.METRIC_INNER_PRODUCT)
20 |         else:
21 |             self.index = faiss.IndexFlatIP(vector_sz)
22 |         #self.index_id_to_db_id = np.empty((0), dtype=np.int64)
23 |         self.index_id_to_db_id = []
24 | 
25 |     def index_data(self, ids, embeddings):
26 |         self._update_id_mapping(ids)
27 |         embeddings = embeddings.astype('float32')
28 |         if not self.index.is_trained:
29 |             self.index.train(embeddings)
30 |         self.index.add(embeddings)
31 | 
32 |         print(f'Total data indexed {len(self.index_id_to_db_id)}')
33 | 
34 |     def search_knn(self, query_vectors: np.ndarray, top_docs: int, index_batch_size: int = 2048) -> List[Tuple[List[object], List[float]]]:
35 |         query_vectors = query_vectors.astype('float32')
36 |         result = []
37 |         nbatch = (len(query_vectors)-1) // index_batch_size + 1
38 |         for k in tqdm(range(nbatch)):
39 |             start_idx = k*index_batch_size
40 |             end_idx = min((k+1)*index_batch_size, len(query_vectors))
41 |             q = query_vectors[start_idx: end_idx]
42 |             scores, indexes = self.index.search(q, top_docs)
43 |             # convert to external ids
44 |             db_ids = [[str(self.index_id_to_db_id[i]) for i in query_top_idxs] for query_top_idxs in indexes]
45 |             result.extend([(db_ids[i], scores[i]) for i in range(len(db_ids))])
46 |         return result
47 | 
48 |     def serialize(self, dir_path):
49 |         index_file = os.path.join(dir_path, 'index.faiss')
50 |         meta_file = os.path.join(dir_path, 'index_meta.faiss')
51 |         print(f'Serializing index to {index_file}, meta data to {meta_file}')
52 | 
53 |         faiss.write_index(self.index, index_file)
54 |         with open(meta_file, mode='wb') as f:
55 |             pickle.dump(self.index_id_to_db_id, f)
56 | 
57 |     def deserialize_from(self, dir_path):
58 |         index_file = os.path.join(dir_path, 'index.faiss')
59 |         meta_file = os.path.join(dir_path, 'index_meta.faiss')
60 |         print(f'Loading index from {index_file}, meta data from {meta_file}')
61 | 
62 |         self.index = faiss.read_index(index_file)
63 |         print(f'Loaded index of type {type(self.index)} and size {self.index.ntotal}')
64 | 
65 |         with open(meta_file, "rb") as reader:
66 |             self.index_id_to_db_id = pickle.load(reader)
67 |         assert len(
68 |             self.index_id_to_db_id) == self.index.ntotal, 'Deserialized index_id_to_db_id should match faiss index size'
69 | 
70 |     def _update_id_mapping(self, db_ids: List):
71 |         #new_ids = np.array(db_ids, dtype=np.int64)
72 |         #self.index_id_to_db_id = np.concatenate((self.index_id_to_db_id, new_ids), axis=0)
73 |         self.index_id_to_db_id.extend(db_ids)
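A small end-to-end sketch of this indexer with synthetic ids and vectors (requires `faiss` and `numpy`; the sizes are arbitrary):

import numpy as np
from src.index import Indexer

indexer = Indexer(vector_sz=768)  # exact inner-product search (IndexFlatIP)
indexer.index_data(list(range(100)), np.random.rand(100, 768).astype("float32"))

queries = np.random.rand(2, 768).astype("float32")
for db_ids, scores in indexer.search_knn(queries, top_docs=5):
    print(db_ids, scores)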
--------------------------------------------------------------------------------
/src/moco.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | import torch
4 | import torch.nn as nn
5 | import logging
6 | import copy
7 | import transformers
8 | 
9 | from src import contriever, dist_utils, utils
10 | 
11 | logger = logging.getLogger(__name__)
12 | 
13 | 
14 | class MoCo(nn.Module):
15 |     def __init__(self, opt):
16 |         super(MoCo, self).__init__()
17 | 
18 |         self.queue_size = opt.queue_size
19 |         self.momentum = opt.momentum
20 |         self.temperature = opt.temperature
21 |         self.label_smoothing = opt.label_smoothing
22 |         self.norm_doc = opt.norm_doc
23 |         self.norm_query = opt.norm_query
24 |         self.moco_train_mode_encoder_k = opt.moco_train_mode_encoder_k  # apply the encoder on keys in train mode
25 | 
26 |         retriever, tokenizer = self._load_retriever(
27 |             opt.retriever_model_id, pooling=opt.pooling, random_init=opt.random_init
28 |         )
29 | 
30 |         self.tokenizer = tokenizer
31 |         self.encoder_q = retriever
32 |         self.encoder_k = copy.deepcopy(retriever)
33 | 
34 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
35 |             param_k.data.copy_(param_q.data)
36 |             param_k.requires_grad = False
37 | 
38 |         # create the queue
39 |         self.register_buffer("queue", torch.randn(opt.projection_size, self.queue_size))
40 |         self.queue = nn.functional.normalize(self.queue, dim=0)
41 | 
42 |         self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
43 | 
44 |     def _load_retriever(self, model_id, pooling, random_init):
45 |         cfg = utils.load_hf(transformers.AutoConfig, model_id)
46 |         tokenizer = utils.load_hf(transformers.AutoTokenizer, model_id)
47 | 
48 |         if "xlm" in model_id:
49 |             model_class = contriever.XLMRetriever
50 |         else:
51 |             model_class = contriever.Contriever
52 | 
53 |         if random_init:
54 |             retriever = model_class(cfg)
55 |         else:
56 |             retriever = utils.load_hf(model_class, model_id)
57 | 
58 |         if "bert-" in model_id:
59 |             if tokenizer.bos_token_id is None:
60 |                 tokenizer.bos_token = "[CLS]"
61 |             if tokenizer.eos_token_id is None:
62 |                 tokenizer.eos_token = "[SEP]"
63 | 
64 |         retriever.config.pooling = pooling
65 | 
66 |         return retriever, tokenizer
67 | 
68 |     def get_encoder(self, return_encoder_k=False):
69 |         if return_encoder_k:
70 |             return self.encoder_k
71 |         else:
72 |             return self.encoder_q
73 | 
74 |     def _momentum_update_key_encoder(self):
75 |         """
76 |         Update of the key encoder
77 |         """
78 |         for param_q, param_k in zip(self.encoder_q.parameters(), self.encoder_k.parameters()):
79 |             param_k.data = param_k.data * self.momentum + param_q.data * (1.0 - self.momentum)
80 | 
81 |     @torch.no_grad()
82 |     def _dequeue_and_enqueue(self, keys):
83 |         # gather keys before updating queue
84 |         keys = dist_utils.gather_nograd(keys.contiguous())
85 | 
86 |         batch_size = keys.shape[0]
87 | 
88 |         ptr = int(self.queue_ptr)
89 |         assert self.queue_size % batch_size == 0, f"{batch_size}, {self.queue_size}"  # for simplicity
90 | 
91 |         # replace the keys at ptr (dequeue and enqueue)
92 |         self.queue[:, ptr : ptr + batch_size] = keys.T
93 |         ptr = (ptr + batch_size) % self.queue_size  # move pointer
94 | 
95 |         self.queue_ptr[0] = ptr
96 | 
97 |     def _compute_logits(self, q, k):
98 |         l_pos = torch.einsum("nc,nc->n", [q, k]).unsqueeze(-1)
99 |         l_neg = torch.einsum("nc,ck->nk", [q, self.queue.clone().detach()])
100 | 
101 |         logits = torch.cat([l_pos, l_neg], dim=1)
102 |         return logits
103 | 
104 |     def forward(self, q_tokens, q_mask, k_tokens, k_mask, stats_prefix="", iter_stats=None, **kwargs):
105 |         bsz = q_tokens.size(0)
106 |         iter_stats = {} if iter_stats is None else iter_stats  # avoid a shared mutable default argument
107 |         q = self.encoder_q(input_ids=q_tokens, attention_mask=q_mask, normalize=self.norm_query)
108 | 
109 |         # compute key features
110 |         with torch.no_grad():  # no gradient to keys
111 |             self._momentum_update_key_encoder()  # update the key encoder
112 | 
113 |             if not self.encoder_k.training and not self.moco_train_mode_encoder_k:
114 |                 self.encoder_k.eval()
115 | 
116 |             k = self.encoder_k(input_ids=k_tokens, attention_mask=k_mask, normalize=self.norm_doc)
117 | 
118 |         logits = self._compute_logits(q, k) / self.temperature
119 | 
120 |         # labels: positive key indicators
121 |         labels = torch.zeros(bsz, dtype=torch.long).cuda()
122 | 
123 |         loss = torch.nn.functional.cross_entropy(logits, labels, label_smoothing=self.label_smoothing)
124 | 
125 |         self._dequeue_and_enqueue(k)
126 | 
127 |         # log stats
128 |         if len(stats_prefix) > 0:
129 |             stats_prefix = stats_prefix + "/"
130 |         iter_stats[f"{stats_prefix}loss"] = (loss.item(), bsz)
131 | 
132 |         predicted_idx = torch.argmax(logits, dim=-1)
133 |         accuracy = 100 * (predicted_idx == labels).float().mean()
134 |         stdq = torch.std(q, dim=0).mean().item()
135 |         stdk = torch.std(k, dim=0).mean().item()
136 |         iter_stats[f"{stats_prefix}accuracy"] = (accuracy, bsz)
137 |         iter_stats[f"{stats_prefix}stdq"] = (stdq, bsz)
138 |         iter_stats[f"{stats_prefix}stdk"] = (stdk, bsz)
139 | 
140 |         return loss, iter_stats
141 | 
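The logits layout above is worth spelling out: each query has exactly one positive (its own key) plus queue_size queued negatives, so the cross-entropy target is always class 0. A toy shape check with random stand-ins:

import torch

n, dim, queue_size = 4, 768, 16
q, k = torch.randn(n, dim), torch.randn(n, dim)
queue = torch.randn(dim, queue_size)

l_pos = torch.einsum("nc,nc->n", q, k).unsqueeze(-1)  # (n, 1) score against own key
l_neg = torch.einsum("nc,ck->nk", q, queue)           # (n, queue_size) scores against the queue
logits = torch.cat([l_pos, l_neg], dim=1)             # (n, 1 + queue_size)
labels = torch.zeros(n, dtype=torch.long)             # the positive is always column 0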
--------------------------------------------------------------------------------
/src/normalize_text.py:
--------------------------------------------------------------------------------
1 | """
2 | adapted from chemdataextractor.text.normalize
3 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
4 | Tools for normalizing text.
5 | https://github.com/mcs07/ChemDataExtractor
6 | :copyright: Copyright 2016 by Matt Swain.
7 | :license: MIT
8 | 
9 | Permission is hereby granted, free of charge, to any person obtaining
10 | a copy of this software and associated documentation files (the
11 | 'Software'), to deal in the Software without restriction, including
12 | without limitation the rights to use, copy, modify, merge, publish,
13 | distribute, sublicense, and/or sell copies of the Software, and to
14 | permit persons to whom the Software is furnished to do so, subject to
15 | the following conditions:
16 | 
17 | The above copyright notice and this permission notice shall be
18 | included in all copies or substantial portions of the Software.
19 | 
20 | THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
21 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
22 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
23 | IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
24 | CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
25 | TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
26 | SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
27 | """
28 | 
29 | #: Control characters.
30 | CONTROLS = {
31 |     '\u0001', '\u0002', '\u0003', '\u0004', '\u0005', '\u0006', '\u0007', '\u0008', '\u000e', '\u000f', '\u0011',
32 |     '\u0012', '\u0013', '\u0014', '\u0015', '\u0016', '\u0017', '\u0018', '\u0019', '\u001a', '\u001b',
33 | }
34 | # There are further control characters, but they are instead replaced with a space by unicode normalization
35 | # '\u0009', '\u000a', '\u000b', '\u000c', '\u000d', '\u001c', '\u001d', '\u001e', '\u001f'
36 | 
37 | 
38 | #: Hyphen and dash characters.
39 | HYPHENS = {
40 |     '-',  # \u002d Hyphen-minus
41 |     '‐',  # \u2010 Hyphen
42 |     '‑',  # \u2011 Non-breaking hyphen
43 |     '⁃',  # \u2043 Hyphen bullet
44 |     '‒',  # \u2012 figure dash
45 |     '–',  # \u2013 en dash
46 |     '—',  # \u2014 em dash
47 |     '―',  # \u2015 horizontal bar
48 | }
49 | 
50 | #: Minus characters.
51 | MINUSES = {
52 |     '-',  # \u002d Hyphen-minus
53 |     '−',  # \u2212 Minus
54 |     '－',  # \uff0d Full-width Hyphen-minus
55 |     '⁻',  # \u207b Superscript minus
56 | }
57 | 
58 | #: Plus characters.
59 | PLUSES = {
60 |     '+',  # \u002b Plus
61 |     '＋',  # \uff0b Full-width Plus
62 |     '⁺',  # \u207a Superscript plus
63 | }
64 | 
65 | #: Slash characters.
66 | SLASHES = {
67 |     '/',  # \u002f Solidus
68 |     '⁄',  # \u2044 Fraction slash
69 |     '∕',  # \u2215 Division slash
70 | }
71 | 
72 | #: Tilde characters.
73 | TILDES = {
74 |     '~',  # \u007e Tilde
75 |     '˜',  # \u02dc Small tilde
76 |     '⁓',  # \u2053 Swung dash
77 |     '∼',  # \u223c Tilde operator #in mbert vocab
78 |     '∽',  # \u223d Reversed tilde
79 |     '∿',  # \u223f Sine wave
80 |     '〜',  # \u301c Wave dash #in mbert vocab
81 |     '～',  # \uff5e Full-width tilde #in mbert vocab
82 | }
83 | 
84 | #: Apostrophe characters.
85 | APOSTROPHES = {
86 |     "'",  # \u0027
87 |     '’',  # \u2019
88 |     '՚',  # \u055a
89 |     'Ꞌ',  # \ua78b
90 |     'ꞌ',  # \ua78c
91 |     '＇',  # \uff07
92 | }
93 | 
94 | #: Single quote characters.
95 | SINGLE_QUOTES = {
96 |     "'",  # \u0027
97 |     '‘',  # \u2018
98 |     '’',  # \u2019
99 |     '‚',  # \u201a
100 |     '‛',  # \u201b
101 | 
102 | }
103 | 
104 | #: Double quote characters.
105 | DOUBLE_QUOTES = {
106 |     '"',  # \u0022
107 |     '“',  # \u201c
108 |     '”',  # \u201d
109 |     '„',  # \u201e
110 |     '‟',  # \u201f
111 | }
112 | 
113 | #: Accent characters.
114 | ACCENTS = {
115 |     '`',  # \u0060
116 |     '´',  # \u00b4
117 | }
118 | 
119 | #: Prime characters.
120 | PRIMES = {
121 |     '′',  # \u2032
122 |     '″',  # \u2033
123 |     '‴',  # \u2034
124 |     '‵',  # \u2035
125 |     '‶',  # \u2036
126 |     '‷',  # \u2037
127 |     '⁗',  # \u2057
128 | }
129 | 
130 | #: Quote characters, including apostrophes, single quotes, double quotes, accents and primes.
131 | QUOTES = APOSTROPHES | SINGLE_QUOTES | DOUBLE_QUOTES | ACCENTS | PRIMES
132 | 
133 | def normalize(text):
134 |     for control in CONTROLS:
135 |         text = text.replace(control, '')
136 |     text = text.replace('\u000b', ' ').replace('\u000c', ' ').replace(u'\u0085', ' ')
137 | 
138 |     for hyphen in HYPHENS | MINUSES:
139 |         text = text.replace(hyphen, '-')
140 |     text = text.replace('\u00ad', '')
141 | 
142 |     for double_quote in DOUBLE_QUOTES:
143 |         text = text.replace(double_quote, '"')  # \u0022
144 |     for single_quote in (SINGLE_QUOTES | APOSTROPHES | ACCENTS):
145 |         text = text.replace(single_quote, "'")  # \u0027
146 |     text = text.replace('′', "'")     # \u2032 prime
147 |     text = text.replace('‵', "'")     # \u2035 reversed prime
148 |     text = text.replace('″', "''")    # \u2033 double prime
149 |     text = text.replace('‶', "''")    # \u2036 reversed double prime
150 |     text = text.replace('‴', "'''")   # \u2034 triple prime
151 |     text = text.replace('‷', "'''")   # \u2037 reversed triple prime
152 |     text = text.replace('⁗', "''''")  # \u2057 quadruple prime
153 | 
154 |     text = text.replace('…', '...').replace(' . . . ', ' ... ')  # \u2026
155 | 
156 |     for slash in SLASHES:
157 |         text = text.replace(slash, '/')
158 | 
159 |     #for tilde in TILDES:
160 |     #    text = text.replace(tilde, '~')
161 | 
162 |     return text
163 | 
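For example, normalize() folds typographic variants down to their ASCII counterparts (curly quotes, hyphen/dash variants and the ellipsis):

from src.normalize_text import normalize

print(normalize("“Let’s try cross‑lingual retrieval – again…”"))
# -> "Let's try cross-lingual retrieval - again..."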
--------------------------------------------------------------------------------
/src/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | import argparse
4 | import os
5 | 
6 | 
7 | class Options:
8 |     def __init__(self):
9 |         self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
10 |         self.initialize()
11 | 
12 |     def initialize(self):
13 |         # basic parameters
14 |         self.parser.add_argument(
15 |             "--output_dir", type=str, default="./checkpoint/my_experiments", help="models are saved here"
16 |         )
17 |         self.parser.add_argument(
18 |             "--train_data",
19 |             nargs="+",
20 |             default=[],
21 |             help="Data used for training, passed as a list of directories split into tensor files.",
22 |         )
23 |         self.parser.add_argument(
24 |             "--eval_data",
25 |             nargs="+",
26 |             default=[],
27 |             help="Data used for evaluation during finetuning; this option is not used during contrastive pre-training.",
28 |         )
29 |         self.parser.add_argument(
30 |             "--eval_datasets", nargs="+", default=[], help="List of datasets used for evaluation, in BEIR format"
31 |         )
32 |         self.parser.add_argument(
33 |             "--eval_datasets_dir", type=str, default="./", help="Directory where eval datasets are stored"
34 |         )
35 |         self.parser.add_argument("--model_path", type=str, default="none", help="path for retraining")
36 |         self.parser.add_argument("--continue_training", action="store_true")
37 |         self.parser.add_argument("--num_workers", type=int, default=5)
38 | 
39 |         self.parser.add_argument("--chunk_length", type=int, default=256)
40 |         self.parser.add_argument("--loading_mode", type=str, default="split")
41 |         self.parser.add_argument("--lower_case", action="store_true", help="perform evaluation after lowercasing")
42 |         self.parser.add_argument(
43 |             "--sampling_coefficient",
44 |             type=float,
45 |             default=0.0,
46 |             help="coefficient used for sampling between different datasets during training, \
47 |                 by default sampling is uniform over datasets",
48 |         )
49 |         self.parser.add_argument("--augmentation", type=str, default="none")
50 |         self.parser.add_argument("--prob_augmentation", type=float, default=0.0)
51 | 
52 |         self.parser.add_argument("--dropout", type=float, default=0.1)
53 |         self.parser.add_argument("--rho", type=float, default=0.05)
54 | 
55 |         self.parser.add_argument("--contrastive_mode", type=str, default="moco")
56 |         self.parser.add_argument("--queue_size", type=int, default=65536)
57 |         self.parser.add_argument("--temperature", type=float, default=1.0)
58 |         self.parser.add_argument("--momentum", type=float, default=0.999)
59 |         self.parser.add_argument("--moco_train_mode_encoder_k", action="store_true")
60 |         self.parser.add_argument("--eval_normalize_text", action="store_true")
61 |         self.parser.add_argument("--norm_query", action="store_true")
62 |         self.parser.add_argument("--norm_doc", action="store_true")
63 |         self.parser.add_argument("--projection_size", type=int, default=768)
64 | 
65 |         self.parser.add_argument("--ratio_min", type=float, default=0.1)
66 |         self.parser.add_argument("--ratio_max", type=float, default=0.5)
67 |         self.parser.add_argument("--score_function", type=str, default="dot")
68 |         self.parser.add_argument("--retriever_model_id", type=str, default="bert-base-uncased")
69 |         self.parser.add_argument("--pooling", type=str, default="average")
70 |         self.parser.add_argument("--random_init", action="store_true", help="init model with random weights")
71 | 
72 |         # dataset parameters
73 |         self.parser.add_argument("--per_gpu_batch_size", default=64, type=int, help="Batch size per GPU for training.")
74 |         self.parser.add_argument(
75 |             "--per_gpu_eval_batch_size", default=256, type=int, help="Batch size per GPU for evaluation."
76 |         )
77 |         self.parser.add_argument("--total_steps", type=int, default=1000)
78 |         self.parser.add_argument("--warmup_steps", type=int, default=-1)
79 | 
80 |         self.parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
81 |         self.parser.add_argument("--main_port", type=int, default=10001, help="Master port (for multi-node SLURM jobs)")
82 |         self.parser.add_argument("--seed", type=int, default=0, help="random seed for initialization")
83 |         # training parameters
84 |         self.parser.add_argument("--optim", type=str, default="adamw")
85 |         self.parser.add_argument("--scheduler", type=str, default="linear")
86 |         self.parser.add_argument("--lr", type=float, default=1e-4, help="learning rate")
87 |         self.parser.add_argument(
88 |             "--lr_min_ratio",
89 |             type=float,
90 |             default=0.0,
91 |             help="minimum learning rate at the end of the optimization schedule as a ratio of the learning rate",
92 |         )
93 |         self.parser.add_argument("--weight_decay", type=float, default=0.01, help="weight decay")
94 |         self.parser.add_argument("--beta1", type=float, default=0.9, help="beta1")
95 |         self.parser.add_argument("--beta2", type=float, default=0.98, help="beta2")
96 |         self.parser.add_argument("--eps", type=float, default=1e-6, help="eps")
97 |         self.parser.add_argument(
98 |             "--log_freq", type=int, default=100, help="log train stats every log_freq steps during training"
99 |         )
100 |         self.parser.add_argument(
101 |             "--eval_freq", type=int, default=500, help="evaluate model every eval_freq steps during training"
102 |         )
103 |         self.parser.add_argument("--save_freq", type=int, default=50000)
104 |         self.parser.add_argument("--maxload", type=int, default=None)
105 |         self.parser.add_argument("--label_smoothing", type=float, default=0.0)
106 | 
107 |         # finetuning options
108 |         self.parser.add_argument("--negative_ctxs", type=int, default=1)
109 |         self.parser.add_argument("--negative_hard_min_idx", type=int, default=0)
110 |         self.parser.add_argument("--negative_hard_ratio", type=float, default=0.0)
111 | 
112 |     def print_options(self, opt):
113 |         message = ""
114 |         for k, v in sorted(vars(opt).items()):
115 |             comment = ""
116 |             default = self.parser.get_default(k)
117 |             if v != default:
118 |                 comment = f"\t[default: {default}]"
119 |             message += f"{str(k):>40}: {str(v):<40}{comment}\n"
120 |         print(message, flush=True)
121 |         model_dir = os.path.join(opt.output_dir, "models")
122 |         if not os.path.exists(model_dir):
123 |             os.makedirs(model_dir)
124 |         file_name = os.path.join(opt.output_dir, "opt.txt")
125 |         with open(file_name, "wt") as opt_file:
126 |             opt_file.write(message)
127 |             opt_file.write("\n")
128 | 
129 |     def parse(self):
130 |         opt, _ = self.parser.parse_known_args()
131 |         # opt = self.parser.parse_args()
132 |         return opt
133 | 
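Since parse() uses parse_known_args, the options can also be exercised programmatically; the flag values here are only examples:

import sys
from src.options import Options

sys.argv = ["train.py", "--contrastive_mode", "moco", "--temperature", "0.05"]
opt = Options().parse()
print(opt.contrastive_mode, opt.temperature, opt.queue_size)  # moco 0.05 65536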
--------------------------------------------------------------------------------
/src/slurm.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # All rights reserved.
3 | #
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | from logging import getLogger
8 | import os
9 | import sys
10 | import torch
11 | import socket
12 | import signal
13 | import subprocess
14 | 
15 | 
16 | logger = getLogger()
17 | 
18 | def sig_handler(signum, frame):
19 |     logger.warning("Signal handler called with signal " + str(signum))
20 |     prod_id = int(os.environ['SLURM_PROCID'])
21 |     logger.warning("Host: %s - Global rank: %i" % (socket.gethostname(), prod_id))
22 |     if prod_id == 0:
23 |         logger.warning("Requeuing job " + os.environ['SLURM_JOB_ID'])
24 |         os.system('scontrol requeue ' + os.environ['SLURM_JOB_ID'])
25 |     else:
26 |         logger.warning("Not the main process, no need to requeue.")
27 |     sys.exit(-1)
28 | 
29 | 
30 | def term_handler(signum, frame):
31 |     logger.warning("Signal handler called with signal " + str(signum))
32 |     logger.warning("Bypassing SIGTERM.")
33 | 
34 | 
35 | def init_signal_handler():
36 |     """
37 |     Handle signals sent by SLURM for time limit / pre-emption.
38 |     """
39 |     signal.signal(signal.SIGUSR1, sig_handler)
40 |     signal.signal(signal.SIGTERM, term_handler)
41 | 
42 | 
43 | def init_distributed_mode(params):
44 |     """
45 |     Handle single and multi-GPU / multi-node / SLURM jobs.
46 |     Initialize the following variables:
47 |         - local_rank
48 |         - global_rank
49 |         - world_size
50 |     """
51 |     is_slurm_job = 'SLURM_JOB_ID' in os.environ and 'WORLD_SIZE' not in os.environ
52 |     has_local_rank = hasattr(params, 'local_rank')
53 | 
54 |     # SLURM job without torch.distributed.launch
55 |     if is_slurm_job and has_local_rank:
56 | 
57 |         assert params.local_rank == -1  # on the cluster, this is handled by SLURM
58 | 
59 |         # local rank on the current node / global rank
60 |         params.local_rank = int(os.environ['SLURM_LOCALID'])
61 |         params.global_rank = int(os.environ['SLURM_PROCID'])
62 |         params.world_size = int(os.environ['SLURM_NTASKS'])
63 | 
64 |         # define master address and master port
65 |         hostnames = subprocess.check_output(['scontrol', 'show', 'hostnames', os.environ['SLURM_JOB_NODELIST']])
66 |         params.main_addr = hostnames.split()[0].decode('utf-8')
67 |         assert 10001 <= params.main_port <= 20000 or params.world_size == 1
68 | 
69 |         # set environment variables for 'env://'
70 |         os.environ['MASTER_ADDR'] = params.main_addr
71 |         os.environ['MASTER_PORT'] = str(params.main_port)
72 |         os.environ['WORLD_SIZE'] = str(params.world_size)
73 |         os.environ['RANK'] = str(params.global_rank)
74 |         is_distributed = True
75 | 
76 | 
77 |     # multi-GPU job (local or multi-node) - jobs started with torch.distributed.launch
78 |     elif has_local_rank and params.local_rank != -1:
79 | 
80 |         assert params.main_port == -1
81 | 
82 |         # read environment variables
83 |         params.global_rank = int(os.environ['RANK'])
84 |         params.world_size = int(os.environ['WORLD_SIZE'])
85 | 
86 |         is_distributed = True
87 | 
88 |     # local job (single GPU)
89 |     else:
90 |         params.local_rank = 0
91 |         params.global_rank = 0
92 |         params.world_size = 1
93 |         is_distributed = False
94 | 
95 |     # set GPU device
96 |     torch.cuda.set_device(params.local_rank)
97 | 
98 |     # initialize multi-GPU
99 |     if is_distributed:
100 | 
101 |         # http://pytorch.apachecn.org/en/0.3.0/distributed.html#environment-variable-initialization
102 |         # 'env://' will read these environment variables:
103 |         # MASTER_PORT - required; has to be a free port on machine with rank 0
104 |         # MASTER_ADDR - required (except for rank 0); address of rank 0 node
105 |         # WORLD_SIZE - required; can be set either here, or in a call to init function
106 |         # RANK - required; can be set either here, or in a call to init function
107 | 
108 |         #print("Initializing PyTorch distributed ...")
109 |         torch.distributed.init_process_group(
110 |             init_method='env://',
111 |             backend='nccl',
112 |             #world_size=params.world_size,
113 |             #rank=params.global_rank,
114 |         )
--------------------------------------------------------------------------------
/src/utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | import os
4 | import sys
5 | import logging
6 | import torch
7 | import errno
8 | import math  # used by CosineScheduler.lr_lambda below
9 | from typing import Union, Tuple, List, Dict
10 | from collections import defaultdict
11 | from src import dist_utils
12 | 
13 | Number = Union[float, int]
14 | 
15 | logger = logging.getLogger(__name__)
16 | 
17 | 
18 | def init_logger(args, stdout_only=False):
19 |     if torch.distributed.is_initialized():
20 |         torch.distributed.barrier()
21 |     stdout_handler = logging.StreamHandler(sys.stdout)
22 |     handlers = [stdout_handler]
23 |     if not stdout_only:
24 |         file_handler = logging.FileHandler(filename=os.path.join(args.output_dir, "run.log"))
25 |         handlers.append(file_handler)
26 |     logging.basicConfig(
27 |         datefmt="%m/%d/%Y %H:%M:%S",
28 |         level=logging.INFO if dist_utils.is_main() else logging.WARN,
29 |         format="[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s",
30 |         handlers=handlers,
31 |     )
32 |     return logger
33 | 
34 | 
35 | def symlink_force(target, link_name):
36 |     try:
37 |         os.symlink(target, link_name)
38 |     except OSError as e:
39 |         if e.errno == errno.EEXIST:
40 |             os.remove(link_name)
41 |             os.symlink(target, link_name)
42 |         else:
43 |             raise e
44 | 
45 | 
46 | def save(model, optimizer, scheduler, step, opt, dir_path, name):
47 |     model_to_save = model.module if hasattr(model, "module") else model
48 |     path = os.path.join(dir_path, "checkpoint")
49 |     epoch_path = os.path.join(path, name)  # "step-%s" % step)
50 |     os.makedirs(epoch_path, exist_ok=True)
51 |     cp = os.path.join(path, "latest")
52 |     fp = os.path.join(epoch_path, "checkpoint.pth")
53 |     checkpoint = {
54 |         "step": step,
55 |         "model": model_to_save.state_dict(),
56 |         "optimizer": optimizer.state_dict(),
57 |         "scheduler": scheduler.state_dict(),
58 |         "opt": opt,
59 |     }
60 |     torch.save(checkpoint, fp)
61 |     symlink_force(epoch_path, cp)
62 |     if not name == "lastlog":
63 |         logger.info(f"Saving model to {epoch_path}")
64 | 
65 | 
66 | def load(model_class, dir_path, opt, reset_params=False):
67 |     epoch_path = os.path.realpath(dir_path)
68 |     checkpoint_path = os.path.join(epoch_path, "checkpoint.pth")
69 |     logger.info(f"loading checkpoint {checkpoint_path}")
70 |     checkpoint = torch.load(checkpoint_path, map_location="cpu")
71 |     opt_checkpoint = checkpoint["opt"]
72 |     state_dict = checkpoint["model"]
73 | 
74 |     model = model_class(opt_checkpoint)
75 |     model.load_state_dict(state_dict, strict=True)
76 |     model = model.cuda()
77 |     step = checkpoint["step"]
78 |     if not reset_params:
79 |         optimizer, scheduler = set_optim(opt_checkpoint, model)
80 |         scheduler.load_state_dict(checkpoint["scheduler"])
81 |         optimizer.load_state_dict(checkpoint["optimizer"])
82 |     else:
83 |         optimizer, scheduler = set_optim(opt, model)
84 | 
85 |     return model, optimizer, scheduler, opt_checkpoint, step
86 | 
87 | 
88 | ############ OPTIM
89 | 
90 | 
91 | class WarmupLinearScheduler(torch.optim.lr_scheduler.LambdaLR):
92 |     def __init__(self, optimizer, warmup, total, ratio, last_epoch=-1):
93 |         self.warmup = warmup
94 |         self.total = total
95 |         self.ratio = ratio
96 |         super(WarmupLinearScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
97 | 
98 |     def lr_lambda(self, step):
99 |         if step < self.warmup:
100 |             return (1 - self.ratio) * step / float(max(1, self.warmup))
101 | 
102 |         return max(
103 |             0.0,
104 |             1.0 + (self.ratio - 1) * (step - self.warmup) / float(max(1.0, self.total - self.warmup)),
105 |         )
106 | 
107 | 
108 | class CosineScheduler(torch.optim.lr_scheduler.LambdaLR):
109 |     def __init__(self, optimizer, warmup, total, ratio=0.1, last_epoch=-1):
110 |         self.warmup = warmup
111 |         self.total = total
112 |         self.ratio = ratio
113 |         super(CosineScheduler, self).__init__(optimizer, self.lr_lambda, last_epoch=last_epoch)
114 | 
115 |     def lr_lambda(self, step):
116 |         if step < self.warmup:
117 |             return float(step) / self.warmup
118 |         s = float(step - self.warmup) / (self.total - self.warmup)
119 |         return self.ratio + (1.0 - self.ratio) * math.cos(0.5 * math.pi * s)
120 | 
121 | 
122 | def set_optim(opt, model):
123 |     if opt.optim == "adamw":
124 |         optimizer = torch.optim.AdamW(
125 |             model.parameters(), lr=opt.lr, betas=(opt.beta1, opt.beta2), eps=opt.eps, weight_decay=opt.weight_decay
126 |         )
127 |     else:
128 |         raise NotImplementedError("optimizer class not implemented")
129 | 
130 |     scheduler_args = {
131 |         "warmup": opt.warmup_steps,
132 |         "total": opt.total_steps,
133 |         "ratio": opt.lr_min_ratio,
134 |     }
135 |     if opt.scheduler == "linear":
136 |         scheduler_class = WarmupLinearScheduler
137 |     elif opt.scheduler == "cosine":
138 |         scheduler_class = CosineScheduler
139 |     else:
140 |         raise ValueError
141 |     scheduler = scheduler_class(optimizer, **scheduler_args)
142 |     return optimizer, scheduler
143 | 
144 | 
145 | def get_parameters(net, verbose=False):
146 |     num_params = 0
147 |     for param in net.parameters():
148 |         num_params += param.numel()
149 |     message = "[Network] Total number of parameters : %.6f M" % (num_params / 1e6)
150 |     return message
151 | 
152 | 
153 | class WeightedAvgStats:
154 |     """provides an average over a bunch of stats"""
155 | 
156 |     def __init__(self):
157 |         self.raw_stats: Dict[str, float] = defaultdict(float)
158 |         self.total_weights: Dict[str, float] = defaultdict(float)
159 | 
160 |     def update(self, vals: Dict[str, Tuple[Number, Number]]) -> None:
161 |         for key, (value, weight) in vals.items():
162 |             self.raw_stats[key] += value * weight
163 |             self.total_weights[key] += weight
164 | 
165 |     @property
166 |     def stats(self) -> Dict[str, float]:
167 |         return {x: self.raw_stats[x] / self.total_weights[x] for x in self.raw_stats.keys()}
168 | 
169 |     @property
170 |     def tuple_stats(self) -> Dict[str, Tuple[float, float]]:
171 |         return {x: (self.raw_stats[x] / self.total_weights[x], self.total_weights[x]) for x in self.raw_stats.keys()}
172 | 
173 |     def reset(self) -> None:
174 |         self.raw_stats = defaultdict(float)
175 |         self.total_weights = defaultdict(float)
176 | 
177 |     @property
178 |     def average_stats(self) -> Dict[str, float]:
179 |         keys = sorted(self.raw_stats.keys())
180 |         if torch.distributed.is_initialized():
181 |             torch.distributed.broadcast_object_list(keys, src=0)
182 |         global_dict = {}
183 |         for k in keys:
184 |             if k not in self.total_weights:
185 |                 v = 0.0
186 |             else:
187 |                 v = self.raw_stats[k] / self.total_weights[k]
188 |             v, _ = dist_utils.weighted_average(v, self.total_weights[k])
189 |             global_dict[k] = v
190 |         return global_dict
191 | 
192 | 
193 | def load_hf(object_class, model_name):
194 |     try:
195 |         obj = object_class.from_pretrained(model_name, local_files_only=True)
196 |     except Exception:
197 |         obj = object_class.from_pretrained(model_name, local_files_only=False)
198 |     return obj
199 | 
200 | 
201 | def init_tb_logger(output_dir):
202 |     try:
203 |         from torch.utils import tensorboard
204 | 
205 |         if dist_utils.is_main():
206 |             tb_logger = tensorboard.SummaryWriter(output_dir)
207 |         else:
208 |             tb_logger = None
209 |     except Exception:
210 |         logger.warning("Tensorboard is not available.")
211 |         tb_logger = None
212 | 
213 |     return tb_logger
214 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
2 | 
3 | import os
4 | import time
5 | import sys
6 | import torch
7 | import logging
8 | import json
9 | import numpy as np
10 | import random
11 | import pickle
12 | 
13 | import torch.distributed as dist
14 | from torch.utils.data import DataLoader, RandomSampler
15 | 
16 | from src.options import Options
17 | from src import data, beir_utils, slurm, dist_utils, utils
18 | from src import moco, inbatch
19 | 
20 | 
21 | logger = logging.getLogger(__name__)
22 | 
23 | 
24 | def train(opt, model, optimizer, scheduler, step):
25 | 
26 |     run_stats = utils.WeightedAvgStats()
27 | 
28 |     tb_logger = utils.init_tb_logger(opt.output_dir)
29 | 
30 |     logger.info("Data loading")
31 |     if isinstance(model, torch.nn.parallel.DistributedDataParallel):
32 |         tokenizer = model.module.tokenizer
33 |     else:
34 |         tokenizer = model.tokenizer
35 |     collator = data.Collator(opt=opt)
36 |     train_dataset = data.load_data(opt, tokenizer)
37 |     logger.warning(f"Data loading finished for rank {dist_utils.get_rank()}")
38 | 
39 |     train_sampler = RandomSampler(train_dataset)
40 |     train_dataloader = DataLoader(
41 |         train_dataset,
42 |         sampler=train_sampler,
43 |         batch_size=opt.per_gpu_batch_size,
44 |         drop_last=True,
45 |         num_workers=opt.num_workers,
46 |         collate_fn=collator,
47 |     )
48 | 
49 |     epoch = 1
50 | 
51 |     model.train()
52 |     while step < opt.total_steps:
53 |         train_dataset.generate_offset()
54 | 
55 |         logger.info(f"Start epoch {epoch}")
56 |         for i, batch in enumerate(train_dataloader):
57 |             step += 1
58 | 
59 |             batch = {key: value.cuda() if isinstance(value, torch.Tensor) else value for key, value in batch.items()}
60 |             train_loss, iter_stats = model(**batch, stats_prefix="train")
61 | 
62 |             train_loss.backward()
63 |             optimizer.step()
64 | 
65 |             scheduler.step()
66 |             model.zero_grad()
67 | 
68 |             run_stats.update(iter_stats)
69 | 
70 |             if step % opt.log_freq == 0:
71 |                 log = f"{step} / {opt.total_steps}"
72 |                 for k, v in sorted(run_stats.average_stats.items()):
73 |                     log += f" | {k}: {v:.3f}"
74 |                     if tb_logger:
75 |                         tb_logger.add_scalar(k, v, step)
76 |                 log += f" | lr: {scheduler.get_last_lr()[0]:0.3g}"
77 |                 log += f" | Memory: {torch.cuda.max_memory_allocated() // 2**30} GiB"
78 | 
79 |                 logger.info(log)
80 |                 run_stats.reset()
81 | 
82 |             if step % opt.eval_freq == 0:
83 |                 if isinstance(model, torch.nn.parallel.DistributedDataParallel):
84 |                     encoder = model.module.get_encoder()
85 |                 else:
86 |                     encoder = model.get_encoder()
87 |                 eval_model(
88 |                     opt, query_encoder=encoder, doc_encoder=encoder, tokenizer=tokenizer, tb_logger=tb_logger, step=step
89 |                 )
90 | 
91 |                 if dist_utils.is_main():
92 |                     utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, "lastlog")
93 | 
94 |                 model.train()
95 | 
96 |             if dist_utils.is_main() and step % opt.save_freq == 0:
97 |                 utils.save(model, optimizer, scheduler, step, opt, opt.output_dir, f"step-{step}")
98 | 
99 |             if step >= opt.total_steps:
100 |                 break
101 |         epoch += 1
102 | 
103 | 
104 | def eval_model(opt, query_encoder, doc_encoder, tokenizer, tb_logger, step):
105 |     for datasetname in opt.eval_datasets:
106 |         metrics = beir_utils.evaluate_model(
107 |             query_encoder,
108 |             doc_encoder,
109 |             tokenizer,
110 |             dataset=datasetname,
111 |             batch_size=opt.per_gpu_eval_batch_size,
112 |             norm_doc=opt.norm_doc,
113 |             norm_query=opt.norm_query,
114 |             beir_dir=opt.eval_datasets_dir,
115 |             score_function=opt.score_function,
116 |             lower_case=opt.lower_case,
117 |             normalize_text=opt.eval_normalize_text,
118 |         )
119 | 
120 |         message = []
121 |         if dist_utils.is_main():
122 |             for metric in ["NDCG@10", "Recall@10", "Recall@100"]:
123 |                 message.append(f"{datasetname}/{metric}: {metrics[metric]:.2f}")
124 |                 if tb_logger is not None:
125 |                     tb_logger.add_scalar(f"{datasetname}/{metric}", metrics[metric], step)
126 |             logger.info(" | ".join(message))
127 | 
128 | 
129 | if __name__ == "__main__":
130 |     logger.info("Start")
131 | 
132 |     options = Options()
133 |     opt = options.parse()
134 | 
135 |     torch.manual_seed(opt.seed)
136 |     slurm.init_distributed_mode(opt)
137 |     slurm.init_signal_handler()
138 | 
139 |     directory_exists = os.path.isdir(opt.output_dir)
140 |     if dist.is_initialized():
141 |         dist.barrier()
142 |     os.makedirs(opt.output_dir, exist_ok=True)
143 |     if not directory_exists and dist_utils.is_main():
144 |         options.print_options(opt)
145 |     if dist.is_initialized():
146 |         dist.barrier()
147 |     utils.init_logger(opt)
148 | 
149 |     os.environ["TOKENIZERS_PARALLELISM"] = "false"
150 | 
151 |     if opt.contrastive_mode == "moco":
152 |         model_class = moco.MoCo
153 |     elif opt.contrastive_mode == "inbatch":
154 |         model_class = inbatch.InBatch
155 |     else:
156 |         raise ValueError(f"contrastive mode: {opt.contrastive_mode} not recognised")
157 | 
158 |     if not directory_exists and opt.model_path == "none":
159 |         model = model_class(opt)
160 |         model = model.cuda()
161 |         optimizer, scheduler = utils.set_optim(opt, model)
162 |         step = 0
163 |     elif directory_exists:
164 |         model_path = os.path.join(opt.output_dir, "checkpoint", "latest")
165 |         model, optimizer, scheduler, opt_checkpoint, step = utils.load(
166 |             model_class,
167 |             model_path,
168 |             opt,
169 |             reset_params=False,
170 |         )
171 |         logger.info(f"Model loaded from {opt.output_dir}")
172 |     else:
173 |         model, optimizer, scheduler, opt_checkpoint, step = utils.load(
174 |             model_class,
175 |             opt.model_path,
176 |             opt,
177 |             reset_params=False if opt.continue_training else True,
178 |         )
179 |         if not opt.continue_training:
180 |             step = 0
181 |         logger.info(f"Model loaded from {opt.model_path}")
182 | 
183 |     logger.info(utils.get_parameters(model))
184 | 
185 |     if dist.is_initialized():
186 |         model = torch.nn.parallel.DistributedDataParallel(
187 |             model,
188 |             device_ids=[opt.local_rank],
189 |             output_device=opt.local_rank,
190 |             find_unused_parameters=False,
191 |         )
192 |         dist.barrier()
193 | 
194 |     logger.info("Start training")
195 |     train(opt, model, optimizer, scheduler, step)
196 | 
--------------------------------------------------------------------------------
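Finally, a compressed single-process sketch of the __main__ wiring above for a fresh run (argument values are placeholders and a GPU is assumed; real multi-GPU runs go through src/slurm.py and DistributedDataParallel):

import sys
from src.options import Options
from src import inbatch, utils

sys.argv = ["train.py", "--contrastive_mode", "inbatch", "--output_dir", "./checkpoint/demo"]
opt = Options().parse()
model = inbatch.InBatch(opt).cuda()              # fresh model, as in the opt.model_path == "none" branch
optimizer, scheduler = utils.set_optim(opt, model)
# train(opt, model, optimizer, scheduler, step=0) would then drive the loop above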