├── CHANGELOG.md ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── conf ├── README.md ├── biencoder_train_cfg.yaml ├── ctx_sources │ └── default_sources.yaml ├── datasets │ ├── encoder_train_default.yaml │ └── retriever_default.yaml ├── dense_retriever.yaml ├── encoder │ └── hf_bert.yaml ├── extractive_reader_train_cfg.yaml ├── gen_embs.yaml └── train │ ├── biencoder_default.yaml │ ├── biencoder_local.yaml │ ├── biencoder_nq.yaml │ └── extractive_reader_default.yaml ├── dense_retriever.py ├── dpr ├── __init__.py ├── data │ ├── __init__.py │ ├── biencoder_data.py │ ├── download_data.py │ ├── qa_validation.py │ ├── reader_data.py │ ├── retriever_data.py │ └── tables.py ├── indexer │ └── faiss_indexers.py ├── models │ ├── __init__.py │ ├── biencoder.py │ ├── fairseq_models.py │ ├── hf_models.py │ ├── pytext_models.py │ └── reader.py ├── options.py └── utils │ ├── __init__.py │ ├── conf_utils.py │ ├── data_utils.py │ ├── dist_utils.py │ ├── model_utils.py │ └── tokenizers.py ├── generate_dense_embeddings.py ├── setup.py ├── train_dense_encoder.py └── train_extractive_reader.py /CHANGELOG.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DPR/a31212dc0a54dfa85d8bfa01e1669f149ac832b7/CHANGELOG.md -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 
39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to DPR 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | TBD 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | ## Coding Style 29 | * 2 spaces for indentation rather than tabs 30 | * 120 character line length 31 | * ... 32 | 33 | ## License 34 | By contributing to Facebook AI Research Dense Passage Retriever toolkit, you agree that your contributions will be licensed 35 | under the LICENSE file in the root directory of this source tree. 
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. 
For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. 
Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. 
For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Dense Passage Retrieval 2 | 3 | Dense Passage Retrieval (`DPR`) - is a set of tools and models for state-of-the-art open-domain Q&A research. 
4 | It is based on the following paper: 5 | 6 | Vladimir Karpukhin, Barlas Oguz, Sewon Min, Patrick Lewis, Ledell Wu, Sergey Edunov, Danqi Chen, Wen-tau Yih. [Dense Passage Retrieval for Open-Domain Question Answering.](https://arxiv.org/abs/2004.04906) Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP), pages 6769–6781, 2020. 7 | 8 | If you find this work useful, please cite the following paper: 9 | 10 | ``` 11 | @inproceedings{karpukhin-etal-2020-dense, 12 | title = "Dense Passage Retrieval for Open-Domain Question Answering", 13 | author = "Karpukhin, Vladimir and Oguz, Barlas and Min, Sewon and Lewis, Patrick and Wu, Ledell and Edunov, Sergey and Chen, Danqi and Yih, Wen-tau", 14 | booktitle = "Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP)", 15 | month = nov, 16 | year = "2020", 17 | address = "Online", 18 | publisher = "Association for Computational Linguistics", 19 | url = "https://www.aclweb.org/anthology/2020.emnlp-main.550", 20 | doi = "10.18653/v1/2020.emnlp-main.550", 21 | pages = "6769--6781", 22 | } 23 | ``` 24 | 25 | If you're interested in reproducing the experimental results in the paper based on our model checkpoints (i.e., you don't want to train the encoders from scratch), you might consider using the [Pyserini toolkit](https://github.com/castorini/pyserini/blob/master/docs/experiments-dpr.md), which has the experiments nicely packaged via `pip`. 26 | Their toolkit also reports higher BM25 and hybrid scores. 27 | 28 | ## Features 29 | 1. The dense retriever model is based on a bi-encoder architecture. 30 | 2. An extractive Q&A reader & ranker joint model inspired by [this](https://arxiv.org/abs/1911.03868) paper. 31 | 3. Related data pre- and post-processing tools. 32 | 4. The dense retriever component for inference-time logic is based on a FAISS index. 33 | 34 | ## New (March 2021) release 35 | The DPR codebase is upgraded with a number of enhancements and new models. 36 | Major changes: 37 | 1. [Hydra](https://hydra.cc/)-based configuration for all the command line tools except the data loader (to be converted soon) 38 | 2. Pluggable data processing layer to support custom datasets 39 | 3. New retrieval model checkpoint with better performance. 40 | 41 | ## New (March 2021) retrieval model 42 | A new bi-encoder model trained on the NQ dataset only is now provided: a new checkpoint, training data, retrieval results and Wikipedia embeddings. 43 | It is trained on the original DPR NQ train set and a version of it in which hard negatives are mined using the DPR index itself, built from the previous NQ checkpoint. 44 | The bi-encoder model is trained from scratch using this new training data combined with our original NQ training data. This training scheme gives a nice retrieval performance boost. 45 | 46 | New vs. old top-k document retrieval accuracy on the NQ test set (3610 questions). 
47 | 48 | | Top-k passages | Original DPR NQ model | New DPR model | 49 | | ------------- |:-------------:| -----:| 50 | | 1 | 45.87 | 52.47 | 51 | | 5 | 68.14 | 72.24 | 52 | | 20 | 79.97 | 81.33 | 53 | | 100 | 85.87 | 87.29 | 54 | 55 | New model downloadable resource names (see how to use the download_data script below): 56 | 57 | Checkpoint: checkpoint.retriever.single-adv-hn.nq.bert-base-encoder 58 | 59 | New training data: data.retriever.nq-adv-hn-train 60 | 61 | Retriever results for NQ test set: data.retriever_results.nq.single-adv-hn.test 62 | 63 | Wikipedia embeddings: data.retriever_results.nq.single-adv-hn.wikipedia_passages 64 | 65 | 66 | ## Installation 67 | 68 | Installation from source. Python virtual or Conda environments are recommended. 69 | 70 | ```bash 71 | git clone git@github.com:facebookresearch/DPR.git 72 | cd DPR 73 | pip install . 74 | ``` 75 | 76 | DPR is tested on Python 3.6+ and PyTorch 1.2.0+. 77 | DPR relies on third-party libraries for encoder code implementations. 78 | It currently supports Huggingface (version <=3.1.0) BERT, Pytext BERT and Fairseq RoBERTa encoder models. 79 | Due to the generality of the tokenization process, DPR uses Huggingface tokenizers as of now. So Huggingface is the only required dependency; Pytext & Fairseq are optional. 80 | Install them separately if you want to use those encoders. 81 | 82 | 83 | ## Resources & Data formats 84 | First, you need to prepare data for either retriever or reader training. 85 | Each of the DPR components has its own input/output data formats. 86 | You can see format descriptions below. 87 | DPR provides NQ & Trivia preprocessed datasets (and model checkpoints) to be downloaded from the cloud using our dpr/data/download_data.py tool. One needs to specify the resource name to be downloaded. Run 'python data/download_data.py' to see all options. 88 | 89 | ```bash 90 | python data/download_data.py \ 91 | --resource {key from download_data.py's RESOURCES_MAP} \ 92 | [optional --output_dir {your location}] 93 | ``` 94 | The resource name matching is prefix-based. So if you need to download all data resources, just use --resource data. 95 | 96 | ## Retriever input data format 97 | The default data format of the Retriever training data is JSON. 98 | It contains pools of two types of negative passages per question, as well as positive passages and some additional information. 99 | 100 | ``` 101 | [ 102 | { 103 | "question": "....", 104 | "answers": ["...", "...", "..."], 105 | "positive_ctxs": [{ 106 | "title": "...", 107 | "text": "...." 108 | }], 109 | "negative_ctxs": ["..."], 110 | "hard_negative_ctxs": ["..."] 111 | }, 112 | ... 113 | ] 114 | ``` 115 | 116 | The elements' structure for negative_ctxs & hard_negative_ctxs is exactly the same as for positive_ctxs. 117 | The preprocessed data available for downloading also contains some extra attributes which may be useful for model modifications (like bm25 scores per passage). Still, they are not currently used by DPR. 118 | 119 | You can download the prepared NQ dataset used in the paper by using the 'data.retriever.nq' key prefix. Only dev & train subsets are available in this format. 120 | We also provide question & answer-only CSV data files for all train/dev/test splits. Those are used for the model evaluation since our NQ preprocessing step loses a part of the original samples set. 121 | Use 'data.retriever.qas.*' resource keys to get the respective sets for evaluation. 
122 | 123 | ```bash 124 | python data/download_data.py \ 125 | --resource data.retriever \ 126 | [optional --output_dir {your location}] 127 | ``` 128 | 129 | ## DPR data formats and custom processing 130 | One can use their own data format and custom data parsing & loading logic by inheriting from DPR's Dataset classes in the dpr/data/{biencoder|retriever|reader}_data.py files and implementing the load_data() and __getitem__() methods. See the [DPR hydra configuration](https://github.com/facebookresearch/DPR/blob/master/conf/README.md) instructions. 131 | 132 | 133 | ## Retriever training 134 | Retriever training quality depends on its effective batch size. The one reported in the paper used 8 x 32GB GPUs. 135 | In order to start training on one machine: 136 | ```bash 137 | python train_dense_encoder.py \ 138 | train_datasets=[list of train datasets, comma separated without spaces] \ 139 | dev_datasets=[list of dev datasets, comma separated without spaces] \ 140 | train=biencoder_local \ 141 | output_dir={path to checkpoints dir} 142 | ``` 143 | 144 | Example for the NQ dataset: 145 | 146 | ```bash 147 | python train_dense_encoder.py \ 148 | train_datasets=[nq_train] \ 149 | dev_datasets=[nq_dev] \ 150 | train=biencoder_local \ 151 | output_dir={path to checkpoints dir} 152 | ``` 153 | 154 | DPR uses HuggingFace BERT-base as the encoder by default. Other ready options include Fairseq's RoBERTa and Pytext BERT models. 155 | One can select them by either changing the encoder configuration files (conf/encoder/hf_bert.yaml) or providing a new configuration file in the conf/encoder dir and enabling it with the encoder={new file name} command line parameter. 156 | 157 | Notes: 158 | - If you want to use Pytext BERT or Fairseq RoBERTa, you will need to download pre-trained weights and specify the encoder.pretrained_file parameter. Specify the dir location of the downloaded files for the 'pretrained.fairseq.roberta-base' resource prefix for the RoBERTa model, or the file path for Pytext BERT (resource name 'pretrained.pytext.bert-base.model'). 159 | - Validation and checkpoint saving happen according to the train.eval_per_epoch parameter value. 160 | - There is no stop condition besides a specified number of epochs to train (train.num_train_epochs configuration parameter). 161 | - Every evaluation saves a model checkpoint. 162 | - The best checkpoint is logged in the train process output. 163 | - Regular NLL classification loss validation for bi-encoder training can be replaced with average rank evaluation. It aggregates passage and question vectors from the input data passages pools, computes a large similarity matrix for those representations and then averages the rank of the gold passage for each question. We found this metric to correlate better with the final retrieval performance than the NLL classification loss. Note however that this average rank validation works differently in DistributedDataParallel vs DataParallel PyTorch modes. See the train.val_av_rank_* set of parameters to enable this mode and modify its settings. 164 | 165 | See the section 'Best hyperparameter settings' below as an e2e example for our best setups. 166 | 167 | ## Retriever inference 168 | 169 | Generating representation vectors for the static documents dataset is a highly parallelizable process which can take up to a few days if computed on a single GPU. You might want to use multiple available GPU servers by running the script on each of them independently and specifying their own shards. 
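Conceptually, each shard is just a contiguous slice of the full passage collection, so independent servers can each encode one slice. A minimal sketch of that slicing logic (illustrative only; the variable names are assumptions, not DPR's actual code):

```python
# Illustrative sketch only - not DPR's actual sharding code.
# Splits a passage collection into num_shards contiguous slices so that each
# server/GPU can encode one slice independently.
from typing import List, Tuple


def shard_slice(passages: List[Tuple[str, str]], shard_id: int, num_shards: int) -> List[Tuple[str, str]]:
    """Return the contiguous slice of passages assigned to shard shard_id (0-based)."""
    shard_size = (len(passages) + num_shards - 1) // num_shards  # ceiling division
    start = shard_id * shard_size
    return passages[start : start + shard_size]


# Example: with ~21M passages and num_shards=50, shard 0 covers roughly the first 420k passages.
```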
170 | 171 | ```bash 172 | python generate_dense_embeddings.py \ 173 | model_file={path to biencoder checkpoint} \ 174 | ctx_src={name of the passages resource, set to dpr_wiki to use our original wikipedia split} \ 175 | shard_id={shard_num, 0-based} num_shards={total number of shards} \ 176 | out_file={result files location + name PREFIX} 177 | ``` 178 | 179 | The ctx_src parameter takes the name of the passages resource, 180 | i.e. the source name from the conf/ctx_sources/default_sources.yaml file. 181 | 182 | Note: you can use a much larger batch size here compared to training mode. For example, setting batch_size=128 for a 2-GPU (16GB) server should work fine. 183 | You can download the already generated Wikipedia embeddings from our original model (trained on the NQ dataset) using the resource key 'data.retriever_results.nq.single.wikipedia_passages'. 184 | The embeddings resource name for the new, better model is 'data.retriever_results.nq.single-adv-hn.wikipedia_passages'. 185 | 186 | We generally use the following params on 50 2-GPU nodes: batch_size=128 shard_id=0 num_shards=50 187 | 188 | 189 | 190 | ## Retriever validation against the entire set of documents 191 | 192 | ```bash 193 | 194 | python dense_retriever.py \ 195 | model_file={path to a checkpoint downloaded from our download_data.py as 'checkpoint.retriever.single.nq.bert-base-encoder'} \ 196 | qa_dataset={the name of the test source} \ 197 | ctx_datatsets=[{list of passage sources' names, comma separated without spaces}] \ 198 | encoded_ctx_files=[{list of encoded document files glob expression, comma separated without spaces}] \ 199 | out_file={path to output json file with results} 200 | 201 | ``` 202 | 203 | For example, if you generated embeddings for two passage sets as ~/myproject/embeddings_passages1/wiki_passages_* and ~/myproject/embeddings_passages2/wiki_passages_* files and want to evaluate on the NQ dataset: 204 | 205 | ```bash 206 | python dense_retriever.py \ 207 | model_file={path to a checkpoint file} \ 208 | qa_dataset=nq_test \ 209 | ctx_datatsets=[dpr_wiki] \ 210 | encoded_ctx_files=[\"~/myproject/embeddings_passages1/wiki_passages_*\",\"~/myproject/embeddings_passages2/wiki_passages_*\"] \ 211 | out_file={path to output json file with results} 212 | ``` 213 | 214 | 215 | The tool writes the retrieved results for subsequent reader model training into the specified out_file. 216 | It is a JSON file with the following format: 217 | 218 | ``` 219 | [ 220 | { 221 | "question": "...", 222 | "answers": ["...", "...", ... ], 223 | "ctxs": [ 224 | { 225 | "id": "...", # passage id from database tsv file 226 | "title": "", 227 | "text": "....", 228 | "score": "...", # retriever score 229 | "has_answer": true|false 230 | }, 231 | ] 232 | ``` 233 | Results are sorted by their similarity score, from most relevant to least relevant. 234 | 235 | By default, dense_retriever uses an exhaustive search process, but you can opt in to use lossy index types. 236 | We provide HNSW and HNSW_SQ index options. 237 | Enable them with the indexer=hnsw or indexer=hnsw_sq command line arguments. 238 | Note that using these indexes may be of limited use from the research point of view since their fast retrieval process comes at the cost of much longer indexing time and higher RAM usage. 239 | The similarity score provided is the dot product for the default case of exhaustive search (indexer=flat) and the L2 distance in a modified representation space in the case of the HNSW index. 
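Because the output is plain JSON with the structure shown above, it is easy to post-process. For example, a minimal sketch (the results.json path is a placeholder) that computes top-k retrieval accuracy from the has_answer flags:

```python
# Minimal sketch: compute top-k retrieval accuracy from a dense_retriever output file.
# Assumes the JSON structure documented above; "results.json" is a placeholder path.
import json

K = 20
with open("results.json") as f:
    results = json.load(f)

# A question counts as a hit if any of its top-K retrieved passages contains an answer.
hits = sum(1 for item in results if any(ctx["has_answer"] for ctx in item["ctxs"][:K]))
print(f"Top-{K} accuracy: {hits / len(results):.4f}")
```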
240 | 241 | 242 | ## Reader model training 243 | ```bash 244 | python train_extractive_reader.py \ 245 | encoder.sequence_length=350 \ 246 | train_files={path to the retriever train set results file} \ 247 | dev_files={path to the retriever dev set results file} \ 248 | output_dir={path to output dir} 249 | ``` 250 | Default hyperparameters are set for a single node with an 8-GPU setup. 251 | Modify them as needed in the conf/train/extractive_reader_default.yaml and conf/extractive_reader_train_cfg.yaml configuration files, or override specific parameters from the command line. 252 | The first run will preprocess train_files & dev_files, convert them into a serialized set of .pkl files in the same location, and use them on all subsequent runs. 253 | 254 | Notes: 255 | - If you want to use Pytext BERT or Fairseq RoBERTa, you will need to download pre-trained weights and specify the encoder.pretrained_file parameter. Specify the dir location of the downloaded files for the 'pretrained.fairseq.roberta-base' resource prefix for the RoBERTa model, or the file path for Pytext BERT (resource name 'pretrained.pytext.bert-base.model'). 256 | - The reader training pipeline does model validation every train.eval_step batches. 257 | - Like the bi-encoder, it saves model checkpoints on every validation. 258 | - Like the bi-encoder, there is no stop condition besides a specified number of epochs to train. 259 | - Like the bi-encoder, there is no best checkpoint selection logic, so one needs to select it based on the dev set validation performance which is logged in the train process output. 260 | - Our current code only calculates the Exact Match metric. 261 | 262 | ## Reader model inference 263 | 264 | In order to run inference, run `train_extractive_reader.py` without specifying `train_files`. Make sure to specify `model_file` with the path to the checkpoint, `passages_per_question_predict` with the number of passages per question (used when saving the prediction file), and `eval_top_docs` with a list of top-passages threshold values from which to choose the question's answer span (to be printed as logs). An example command line is as follows. 265 | 266 | ```bash 267 | python train_extractive_reader.py \ 268 | prediction_results_file={path to a file to write the results to} \ 269 | eval_top_docs=[10,20,40,50,80,100] \ 270 | dev_files={path to the retriever results file to evaluate} \ 271 | model_file={path to the reader checkpoint} \ 272 | train.dev_batch_size=80 \ 273 | passages_per_question_predict=100 \ 274 | encoder.sequence_length=350 275 | ``` 276 | 277 | ## Distributed training 278 | Use Pytorch's distributed training launcher tool: 279 | 280 | ```bash 281 | python -m torch.distributed.launch \ 282 | --nproc_per_node={WORLD_SIZE} {non-distributed script name & parameters} 283 | ``` 284 | Note: 285 | - All batch size related parameters are specified per GPU in distributed mode (DistributedDataParallel) and for all available GPUs in DataParallel (single node - multi GPU) mode. 286 | 287 | ## Best hyperparameter settings 288 | 289 | An e2e example with the best settings for the NQ dataset. 290 | 291 | ### 1. Download all retriever training and validation data: 292 | 293 | ```bash 294 | python data/download_data.py --resource data.wikipedia_split.psgs_w100 295 | python data/download_data.py --resource data.retriever.nq 296 | python data/download_data.py --resource data.retriever.qas.nq 297 | ``` 298 | 299 | ### 2. Biencoder (retriever) training in the single set mode. 
300 | 301 | We used the distributed training mode on a single 8 GPU x 32 GB server. 302 | 303 | ```bash 304 | python -m torch.distributed.launch --nproc_per_node=8 \ 305 | train_dense_encoder.py \ 306 | train=biencoder_nq \ 307 | train_datasets=[nq_train] \ 308 | dev_datasets=[nq_dev] \ 310 | output_dir={your output dir} 311 | ``` 312 | 313 | New model training combines two NQ datasets: 314 | 315 | ```bash 316 | python -m torch.distributed.launch --nproc_per_node=8 \ 317 | train_dense_encoder.py \ 318 | train=biencoder_nq \ 319 | train_datasets=[nq_train,nq_train_hn1] \ 320 | dev_datasets=[nq_dev] \ 322 | output_dir={your output dir} 323 | ``` 324 | 325 | This takes about a day to complete the training for 40 epochs. It switches to Average Rank validation on epoch 30, and the average rank should be around 25 or less at the end. 326 | The best checkpoint for the bi-encoder is usually the last one, but it should not be much different if you take any checkpoint after epoch ~25. 327 | 328 | ### 3. Generate embeddings for Wikipedia. 329 | Just use the instructions from the 'Retriever inference' section above. It takes about 40 minutes to produce 21 million passage representation vectors on 50 2-GPU servers. 330 | 331 | ### 4. Evaluate retrieval accuracy and generate top passage results for each of the train/dev/test datasets. 332 | 333 | ```bash 334 | 335 | python dense_retriever.py \ 336 | model_file={path to the best checkpoint or use our provided checkpoints (resource names like checkpoint.retriever.*)} \ 337 | qa_dataset=nq_test \ 338 | ctx_datatsets=[dpr_wiki] \ 339 | encoded_ctx_files=["{glob expression for generated embedding files}"] \ 340 | out_file={path to the output file} 341 | ``` 342 | 343 | Adjust batch_size based on the available number of GPUs; 64-128 should work for a 2-GPU server. 344 | 345 | ### 5. Reader training 346 | We trained the reader model for large datasets using a single 8 GPU x 32 GB server. All the default parameters are already set to our best NQ settings. 347 | Please also download the data.gold_passages_info.nq_train & data.gold_passages_info.nq_dev resources for the NQ dataset - they are used for special NQ-only heuristics when preprocessing the data for the NQ reader training. If you have already run reader training on NQ data without gold_passages_src & gold_passages_src_dev specified, please delete the corresponding .pkl files so that they will be re-generated. 348 | 349 | ```bash 350 | python train_extractive_reader.py \ 351 | encoder.sequence_length=350 \ 352 | train_files={path to the retriever train set results file} \ 353 | dev_files={path to the retriever dev set results file} \ 354 | gold_passages_src={path to data.gold_passages_info.nq_train file} \ 355 | gold_passages_src_dev={path to data.gold_passages_info.nq_dev file} \ 356 | output_dir={path to output dir} 357 | ``` 358 | 359 | We found that the learning rate above works best with a static schedule, so one needs to stop training manually based on evaluation performance dynamics. 360 | Our best results were achieved after 16-18 training epochs or after ~60k model updates. 361 | 362 | We provide all input and intermediate results for the e2e pipeline for the NQ dataset and most of the similar resources for Trivia. 363 | 364 | ## Misc. 365 | - TREC validation requires regexp-based matching. We support only retriever validation in the regexp mode; see the --match parameter option. A toy illustration of string vs. regexp answer matching is sketched after this list. 366 | - WebQ validation requires entity normalization, which is not included as of now. 
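The sketch below only illustrates the difference between string-based and regexp-based answer matching; it is not DPR's actual validation code, and the helper names are made up for this example.

```python
# Toy illustration only - not DPR's actual validation code.
# Shows the difference between plain string matching (default) and regexp
# matching (needed for CuratedTREC, whose gold answers are regex patterns).
import re
from typing import List


def has_answer_string(passage: str, answers: List[str]) -> bool:
    """Simple containment check, roughly what string-mode validation does."""
    passage = passage.lower()
    return any(answer.lower() in passage for answer in answers)


def has_answer_regex(passage: str, answer_patterns: List[str]) -> bool:
    """Regexp check: a passage matches if any answer pattern is found in it."""
    return any(re.search(pattern, passage, flags=re.IGNORECASE) for pattern in answer_patterns)


# Example: a TREC-style pattern covers several surface forms of the same answer.
print(has_answer_regex("The first U.S. president was George Washington.", [r"(George\s)?Washington"]))
```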
367 | 368 | ## License 369 | DPR is CC-BY-NC 4.0 licensed as of now. 370 | -------------------------------------------------------------------------------- /conf/README.md: -------------------------------------------------------------------------------- 1 | ## Hydra 2 | 3 | [Hydra](https://github.com/facebookresearch/hydra) is an open-source Python 4 | framework that simplifies the development of research and other complex 5 | applications. The key feature is the ability to dynamically create a 6 | hierarchical configuration by composition and override it through config files 7 | and the command line. 8 | 9 | ## DPR configuration 10 | All DPR tools' configuration parameters are now split between different config groups, and you can either modify them in the config files or override them from the command line. 11 | 12 | Each tool's (train_dense_encoder.py, generate_dense_embeddings.py, dense_retriever.py and train_extractive_reader.py) main method now has a Hydra @hydra.main decorator with the name of the configuration file in the conf/ dir. 13 | For example, dense_retriever.py takes all its parameters from the conf/dense_retriever.yaml file. 14 | Every tool's configuration file refers to other configuration files via the "defaults:" parameter. 15 | This is called a [configuration group](https://hydra.cc/docs/tutorials/structured_config/config_groups) in Hydra. 16 | 17 | Let's take a look at dense_retriever.py's configuration: 18 | 19 | 20 | ```yaml 21 | 22 | defaults: 23 | - encoder: hf_bert 24 | - datasets: retriever_default 25 | - ctx_sources: default_sources 26 | 27 | indexers: 28 | flat: 29 | _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer 30 | 31 | hnsw: 32 | _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer 33 | 34 | hnsw_sq: 35 | _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer 36 | 37 | ... 38 | qa_dataset: 39 | ... 40 | ctx_datatsets: 41 | ... 42 | indexer: flat 43 | ... 44 | 45 | ``` 46 | 47 | " - encoder: " - a configuration group that contains all parameters to instantiate the encoder. The actual parameters are located in the conf/encoder/hf_bert.yaml file. 48 | If you want to override some of them, you can either 49 | - Modify that config file 50 | - Create a new config group file under the conf/encoder/ folder and enable it by providing the encoder={your file name} command line argument 51 | - Override a specific parameter from the command line. For example: encoder.sequence_length=300 52 | 53 | " - datasets:" - a configuration group that contains a list of all possible sources of queries for evaluation. One can find them in the conf/datasets/retriever_default.yaml file. 54 | One should specify the dataset to use by providing the qa_dataset parameter in order to use one of them during evaluation. For example, if you want to run the retriever on the NQ test set, set qa_dataset=nq_test as a command line parameter. 55 | 56 | It is much easier now to use custom datasets, without the need to convert them to the DPR format. Just define your own class that provides the relevant __getitem__(), __len__() and load_data() methods (inherit from QASrc), as in the sketch below. 57 | 
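The following is illustrative only: the base-class constructor, the QASample record and the tab-separated file layout are assumptions for this example, so check dpr/data/retriever_data.py for the exact types in your DPR version.

```python
# Illustrative sketch of a custom query source - not an official DPR example.
# Assumes QASrc keeps its samples in self.data and that QASample is a simple
# (query, id, answers) record; verify against dpr/data/retriever_data.py.
import csv

from dpr.data.retriever_data import QASample, QASrc


class MyTsvQASrc(QASrc):
    def __init__(self, file: str):
        super().__init__(file)

    def load_data(self):
        # Assumed layout: tab-separated rows of "question<TAB>answer1|answer2|..."
        data = []
        with open(self.file) as f:
            for question, answers in csv.reader(f, delimiter="\t"):
                data.append(QASample(question, None, answers.split("|")))
        self.data = data

    def __getitem__(self, index: int) -> QASample:
        return self.data[index]

    def __len__(self):
        return len(self.data)
```

Such a source can then be registered in conf/datasets/retriever_default.yaml with a _target_ entry pointing to this class.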
58 | " - ctx_sources: " - a configuration group that contains a list of all possible passage sources. One can find them in the conf/ctx_sources/default_sources.yaml file. 59 | One should specify a list of names of the passage datasets as the ctx_datatsets parameter. For example, if you want to use DPR's old wikipedia passages, set ctx_datatsets=[dpr_wiki]. 60 | Please note that this parameter is a list and you can effectively concatenate different passage sources into one. In order to use multiple sources at once, one also needs to provide the relevant embeddings files in the encoded_ctx_files parameter, which is also a list. 61 | 62 | 63 | "indexers:" - a parameter map that defines various indexes. The actual index is selected by the indexer parameter, which is 'flat' by default, but you can use lossy index types by setting indexer=hnsw or indexer=hnsw_sq on the command line. 64 | 65 | Please refer to the configuration file comments for every parameter. 66 | -------------------------------------------------------------------------------- /conf/biencoder_train_cfg.yaml: -------------------------------------------------------------------------------- 1 | 2 | # configuration groups 3 | defaults: 4 | - encoder: hf_bert 5 | - train: biencoder_default 6 | - datasets: encoder_train_default 7 | 8 | train_datasets: 9 | dev_datasets: 10 | output_dir: 11 | train_sampling_rates: 12 | loss_scale_factors: 13 | 14 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones. 15 | do_lower_case: True 16 | 17 | val_av_rank_start_epoch: 30 18 | seed: 12345 19 | checkpoint_file_name: dpr_biencoder 20 | 21 | # A trained bi-encoder checkpoint file to initialize the model 22 | model_file: 23 | 24 | # TODO: move to a conf group 25 | # local_rank for distributed training on gpus 26 | 27 | # TODO: rename to distributed_rank 28 | local_rank: -1 29 | global_loss_buf_sz: 592000 30 | device: 31 | distributed_world_size: 32 | distributed_port: 33 | distributed_init_method: 34 | 35 | no_cuda: False 36 | n_gpu: 37 | fp16: False 38 | 39 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. 40 | # See details at https://nvidia.github.io/apex/amp.html 41 | fp16_opt_level: O1 42 | 43 | # tokens which won't be split by tokenizer 44 | special_tokens: 45 | 46 | ignore_checkpoint_offset: False 47 | ignore_checkpoint_optimizer: False 48 | ignore_checkpoint_lr: False 49 | 50 | # set to >1 to enable multiple query encoders 51 | multi_q_encoder: False 52 | 53 | # Set to True to reduce the memory footprint and lose a bit of the full train data randomization if you train in DDP mode 54 | local_shards_dataloader: False -------------------------------------------------------------------------------- /conf/ctx_sources/default_sources.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | dpr_wiki: 4 | _target_: dpr.data.retriever_data.CsvCtxSrc 5 | file: data.wikipedia_split.psgs_w100 6 | id_prefix: 'wiki:' -------------------------------------------------------------------------------- /conf/datasets/encoder_train_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | nq_train: 4 | _target_: dpr.data.biencoder_data.JsonQADataset 5 | file: data.retriever.nq-train 6 | 7 | nq_train_hn1: 8 | _target_: dpr.data.biencoder_data.JsonQADataset 9 | file: data.retriever.nq-adv-hn-train 10 | 11 | nq_dev: 12 | _target_: dpr.data.biencoder_data.JsonQADataset 13 | file: data.retriever.nq-dev 14 | 15 | trivia_train: 16 | _target_: dpr.data.biencoder_data.JsonQADataset 17 | file: data.retriever.trivia-train 18 | 19 | trivia_dev: 20 | _target_: dpr.data.biencoder_data.JsonQADataset 21 | file: data.retriever.trivia-dev 22 | 23 | squad1_train: 24 | _target_: dpr.data.biencoder_data.JsonQADataset 25 | file: data.retriever.squad1-train 26 | 27 | squad1_dev: 28 | _target_: dpr.data.biencoder_data.JsonQADataset 29 | file: data.retriever.squad1-dev 30 
| 31 | webq_train: 32 | _target_: dpr.data.biencoder_data.JsonQADataset 33 | file: data.retriever.webq-train 34 | 35 | webq_dev: 36 | _target_: dpr.data.biencoder_data.JsonQADataset 37 | file: data.retriever.webq-dev 38 | 39 | curatedtrec_train: 40 | _target_: dpr.data.biencoder_data.JsonQADataset 41 | file: data.retriever.curatedtrec-train 42 | 43 | curatedtrec_dev: 44 | _target_: dpr.data.biencoder_data.JsonQADataset 45 | file: data.retriever.curatedtrec-dev 46 | -------------------------------------------------------------------------------- /conf/datasets/retriever_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | nq_test: 4 | _target_: dpr.data.retriever_data.CsvQASrc 5 | file: data.retriever.qas.nq-test 6 | #query_special_suffix: '?' 7 | 8 | nq_train: 9 | _target_: dpr.data.retriever_data.CsvQASrc 10 | file: data.retriever.qas.nq-train 11 | 12 | nq_dev: 13 | _target_: dpr.data.retriever_data.CsvQASrc 14 | file: data.retriever.qas.nq-dev 15 | 16 | trivia_test: 17 | _target_: dpr.data.retriever_data.CsvQASrc 18 | file: data.retriever.qas.trivia-test 19 | 20 | trivia_train: 21 | _target_: dpr.data.retriever_data.CsvQASrc 22 | file: data.retriever.qas.trivia-train 23 | 24 | trivia_dev: 25 | _target_: dpr.data.retriever_data.CsvQASrc 26 | file: data.retriever.qas.trivia-dev 27 | 28 | webq_test: 29 | _target_: dpr.data.retriever_data.CsvQASrc 30 | file: data.retriever.qas.webq-test 31 | 32 | curatedtrec_test: 33 | _target_: dpr.data.retriever_data.CsvQASrc 34 | file: data.retriever.qas.curatedtrec-test 35 | 36 | -------------------------------------------------------------------------------- /conf/dense_retriever.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - encoder: hf_bert # defines encoder initialization parameters 3 | - datasets: retriever_default # contains a list of all possible sources of queries for evaluation. Specific set is selected by qa_dataset parameter 4 | - ctx_sources: default_sources # contains a list of all possible passage sources. Specific passages sources selected by ctx_datatsets parameter 5 | 6 | indexers: 7 | flat: 8 | _target_: dpr.indexer.faiss_indexers.DenseFlatIndexer 9 | 10 | hnsw: 11 | _target_: dpr.indexer.faiss_indexers.DenseHNSWFlatIndexer 12 | 13 | hnsw_sq: 14 | _target_: dpr.indexer.faiss_indexers.DenseHNSWSQIndexer 15 | 16 | # the name of the queries dataset from the 'datasets' config group 17 | qa_dataset: 18 | 19 | # a list of names of the passages datasets from the 'ctx_sources' config group 20 | ctx_datatsets: 21 | 22 | #Glob paths to encoded passages (from generate_dense_embeddings tool) 23 | encoded_ctx_files: [] 24 | 25 | out_file: 26 | # "regex" or "string" 27 | match: string 28 | n_docs: 100 29 | validation_workers: 16 30 | 31 | # Batch size to generate query embeddings 32 | batch_size: 128 33 | 34 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones. 35 | do_lower_case: True 36 | 37 | # The attribute name of encoder to use for queries. 
Options for the BiEncoder model: question_model, ctx_model 38 | # question_model is used if this param is empty 39 | encoder_path: 40 | 41 | # path to the FAISS index location - it is only needed if you want to serialize faiss index to files or read from them 42 | # (instead of using encoded_ctx_files) 43 | # it should point to either a directory or a common index file prefix name 44 | # if there is no index at the specified location, the index will be created from encoded_ctx_files 45 | index_path: 46 | 47 | kilt_out_file: 48 | 49 | # A trained bi-encoder checkpoint file to initialize the model 50 | model_file: 51 | 52 | validate_as_tables: False 53 | 54 | # RPC settings 55 | rpc_retriever_cfg_file: 56 | rpc_index_id: 57 | use_l2_conversion: False 58 | use_rpc_meta: False 59 | rpc_meta_compressed: False 60 | 61 | indexer: flat 62 | 63 | # tokens which won't be split by the tokenizer 64 | special_tokens: 65 | 66 | # TODO: move to a conf group 67 | # local_rank for distributed training on gpus 68 | local_rank: -1 69 | global_loss_buf_sz: 150000 70 | device: 71 | distributed_world_size: 72 | distributed_port: 73 | no_cuda: False 74 | n_gpu: 75 | fp16: False 76 | 77 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', 'O3'] 78 | # See details at https://nvidia.github.io/apex/amp.html 79 | fp16_opt_level: O1 80 | 81 | -------------------------------------------------------------------------------- /conf/encoder/hf_bert.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | # model type. One of [hf_bert, pytext_bert, fairseq_roberta] 4 | encoder_model_type: hf_bert 5 | 6 | # HuggingFace's config name for model initialization 7 | pretrained_model_cfg: bert-base-uncased 8 | 9 | # Some encoders need to be initialized from a file 10 | pretrained_file: 11 | 12 | # Extra linear layer on top of the standard bert/roberta encoder 13 | projection_dim: 0 14 | 15 | # Max length of the encoder input sequence 16 | sequence_length: 256 17 | 18 | dropout: 0.1 19 | 20 | # whether to fix (i.e., not update) the context encoder during training 21 | fix_ctx_encoder: False 22 | 23 | # if False, the model won't load pre-trained BERT weights 24 | pretrained: True -------------------------------------------------------------------------------- /conf/extractive_reader_train_cfg.yaml: -------------------------------------------------------------------------------- 1 | # extractive reader configuration 2 | 3 | defaults: 4 | - encoder: hf_bert 5 | - train: extractive_reader_default 6 | 7 | # A trained reader checkpoint file to initialize the model 8 | model_file: 9 | 10 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones. 11 | do_lower_case: True 12 | 13 | seed: 42 14 | 15 | # glob expression for train data files 16 | train_files: 17 | 18 | # glob expression for dev data files 19 | dev_files: 20 | 21 | # Total amount of positive and negative passages per question 22 | passages_per_question: 24 23 | 24 | # Total amount of positive and negative passages per question for evaluation 25 | passages_per_question_predict: 50 26 | 27 | # The output directory where the model checkpoints will be written to 28 | output_dir: 29 | 30 | # Max amount of answer spans to marginalize per single passage 31 | max_n_answers: 10 32 | 33 | # The maximum length of an answer that can be generated.
This is needed because the start 34 | # and end predictions are not conditioned on one another 35 | max_answer_length: 10 36 | 37 | # Top retrieval passages thresholds to analyze prediction results for 38 | eval_top_docs: 39 | - 50 40 | 41 | checkpoint_file_name: dpr_extractive_reader 42 | 43 | # Path to a file to write prediction results to 44 | prediction_results_file: 45 | 46 | # Enables fully resumable mode 47 | fully_resumable: False 48 | 49 | # File with the original train dataset passages (json format) 50 | gold_passages_src: 51 | 52 | # File with the original dataset passages (json format) 53 | gold_passages_src_dev: 54 | 55 | # num of threads to pre-process data. 56 | num_workers: 16 57 | 58 | # TODO: move to a conf group 59 | # local_rank for distributed training on gpus 60 | local_rank: -1 61 | global_loss_buf_sz: 150000 62 | device: 63 | distributed_world_size: 64 | distributed_port: 65 | no_cuda: False 66 | n_gpu: 67 | fp16: False 68 | 69 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 70 | # "See details at https://nvidia.github.io/apex/amp.html 71 | fp16_opt_level: O1 72 | 73 | # a list of tokens to avoid tokenization 74 | special_tokens: -------------------------------------------------------------------------------- /conf/gen_embs.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - encoder: hf_bert 3 | - ctx_sources: default_sources 4 | 5 | # A trained bi-encoder checkpoint file to initialize the model 6 | model_file: 7 | 8 | # Name of the all-passages resource 9 | ctx_src: 10 | 11 | # which (ctx or query) encoder to be used for embedding generation 12 | encoder_type: ctx 13 | 14 | # output .tsv file path to write results to 15 | out_file: 16 | 17 | # Whether to lower case the input text. Set True for uncased models, False for the cased ones. 18 | do_lower_case: True 19 | 20 | # Number(0-based) of data shard to process 21 | shard_id: 0 22 | 23 | # Total amount of data shards 24 | num_shards: 1 25 | 26 | # Batch size for the passage encoder forward pass (works in DataParallel mode) 27 | batch_size: 32 28 | 29 | tables_as_passages: False 30 | 31 | # tokens which won't be slit by tokenizer 32 | special_tokens: 33 | 34 | tables_chunk_sz: 100 35 | 36 | # TODO 37 | tables_split_type: type1 38 | 39 | 40 | # TODO: move to a conf group 41 | # local_rank for distributed training on gpus 42 | local_rank: -1 43 | device: 44 | distributed_world_size: 45 | distributed_port: 46 | no_cuda: False 47 | n_gpu: 48 | fp16: False 49 | 50 | # For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']." 51 | # "See details at https://nvidia.github.io/apex/amp.html 52 | fp16_opt_level: O1 -------------------------------------------------------------------------------- /conf/train/biencoder_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 2 4 | dev_batch_size: 4 5 | adam_eps: 1e-8 6 | adam_betas: (0.9, 0.999) 7 | max_grad_norm: 1.0 8 | log_batch_step: 100 9 | train_rolling_loss_step: 100 10 | weight_decay: 0.0 11 | learning_rate: 1e-5 12 | 13 | # Linear warmup over warmup_steps. 14 | warmup_steps: 100 15 | 16 | # Number of updates steps to accumulate before performing a backward/update pass. 17 | gradient_accumulation_steps: 1 18 | 19 | # Total number of training epochs to perform. 
20 | num_train_epochs: 40 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 30 25 | val_av_rank_other_neg: 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 -------------------------------------------------------------------------------- /conf/train/biencoder_local.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 1 4 | dev_batch_size: 16 5 | adam_eps: 1e-8 6 | adam_betas: (0.9, 0.999) 7 | max_grad_norm: 2.0 8 | log_batch_step: 1 9 | train_rolling_loss_step: 100 10 | weight_decay: 0.0 11 | learning_rate: 2e-5 12 | 13 | # Linear warmup over warmup_steps. 14 | warmup_steps: 1237 15 | 16 | # Number of updates steps to accumulate before performing a backward/update pass. 17 | gradient_accumulation_steps: 1 18 | 19 | # Total number of training epochs to perform. 20 | num_train_epochs: 40 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 30 25 | val_av_rank_other_neg: 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 28 | -------------------------------------------------------------------------------- /conf/train/biencoder_nq.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | batch_size: 16 4 | dev_batch_size: 64 5 | adam_eps: 1e-8 6 | adam_betas: (0.9, 0.999) 7 | max_grad_norm: 2.0 8 | log_batch_step: 10 9 | train_rolling_loss_step: 100 10 | weight_decay: 0.0 11 | learning_rate: 2e-5 12 | 13 | # Linear warmup over warmup_steps. 14 | warmup_steps: 1237 15 | 16 | # Number of updates steps to accumulate before performing a backward/update pass. 17 | gradient_accumulation_steps: 1 18 | 19 | # Total number of training epochs to perform. 20 | num_train_epochs: 40 21 | eval_per_epoch: 1 22 | hard_negatives: 1 23 | other_negatives: 0 24 | val_av_rank_hard_neg: 30 25 | val_av_rank_other_neg: 30 26 | val_av_rank_bsz: 128 27 | val_av_rank_max_qs: 10000 -------------------------------------------------------------------------------- /conf/train/extractive_reader_default.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_ 2 | 3 | eval_step: 2000 4 | batch_size: 16 5 | dev_batch_size: 72 6 | adam_eps: 1e-8 7 | adam_betas: (0.9, 0.999) 8 | max_grad_norm: 1.0 9 | log_batch_step: 100 10 | train_rolling_loss_step: 100 11 | weight_decay: 0.0 12 | learning_rate: 1e-5 13 | 14 | # Linear warmup over warmup_steps. 15 | warmup_steps: 0 16 | 17 | # Number of updates steps to accumulate before performing a backward/update pass. 18 | gradient_accumulation_steps: 1 19 | 20 | # Total number of training epochs to perform. 
21 | num_train_epochs: 100000 22 | -------------------------------------------------------------------------------- /dpr/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DPR/a31212dc0a54dfa85d8bfa01e1669f149ac832b7/dpr/__init__.py -------------------------------------------------------------------------------- /dpr/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DPR/a31212dc0a54dfa85d8bfa01e1669f149ac832b7/dpr/data/__init__.py -------------------------------------------------------------------------------- /dpr/data/biencoder_data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import glob 3 | import logging 4 | import os 5 | import random 6 | from typing import Dict, List, Tuple 7 | 8 | import jsonlines 9 | import numpy as np 10 | from omegaconf import DictConfig 11 | 12 | from dpr.data.tables import Table 13 | from dpr.utils.data_utils import read_data_from_json_files, Dataset 14 | 15 | logger = logging.getLogger(__name__) 16 | BiEncoderPassage = collections.namedtuple("BiEncoderPassage", ["text", "title"]) 17 | 18 | 19 | def get_dpr_files(source_name) -> List[str]: 20 | if os.path.exists(source_name) or glob.glob(source_name): 21 | return glob.glob(source_name) 22 | else: 23 | # try to use data downloader 24 | from dpr.data.download_data import download 25 | return download(source_name) 26 | 27 | 28 | class BiEncoderSample(object): 29 | query: str 30 | positive_passages: List[BiEncoderPassage] 31 | negative_passages: List[BiEncoderPassage] 32 | hard_negative_passages: List[BiEncoderPassage] 33 | 34 | 35 | class JsonQADataset(Dataset): 36 | def __init__( 37 | self, 38 | file: str, 39 | selector: DictConfig = None, 40 | special_token: str = None, 41 | encoder_type: str = None, 42 | shuffle_positives: bool = False, 43 | normalize: bool = False, 44 | query_special_suffix: str = None, 45 | # tmp: for cc-net results only 46 | exclude_gold: bool = False, 47 | ): 48 | super().__init__( 49 | selector, 50 | special_token=special_token, 51 | encoder_type=encoder_type, 52 | shuffle_positives=shuffle_positives, 53 | query_special_suffix=query_special_suffix, 54 | ) 55 | self.file = file 56 | self.data_files = [] 57 | self.normalize = normalize 58 | self.exclude_gold = exclude_gold 59 | 60 | def calc_total_data_len(self): 61 | if not self.data: 62 | logger.info("Loading all data") 63 | self._load_all_data() 64 | return len(self.data) 65 | 66 | def load_data(self, start_pos: int = -1, end_pos: int = -1): 67 | if not self.data: 68 | self._load_all_data() 69 | if start_pos >= 0 and end_pos >= 0: 70 | logger.info("Selecting subset range from %d to %d", start_pos, end_pos) 71 | self.data = self.data[start_pos:end_pos] 72 | 73 | def _load_all_data(self): 74 | self.data_files = get_dpr_files(self.file) 75 | logger.info("Data files: %s", self.data_files) 76 | data = read_data_from_json_files(self.data_files) 77 | # filter those without positive ctx 78 | self.data = [r for r in data if len(r["positive_ctxs"]) > 0] 79 | logger.info("Total cleaned data size: %d", len(self.data)) 80 | 81 | def __getitem__(self, index) -> BiEncoderSample: 82 | json_sample = self.data[index] 83 | r = BiEncoderSample() 84 | r.query = self._process_query(json_sample["question"]) 85 | 86 | positive_ctxs = json_sample["positive_ctxs"] 87 | if self.exclude_gold: 88 | ctxs = [ctx 
for ctx in positive_ctxs if "score" in ctx] 89 | if ctxs: 90 | positive_ctxs = ctxs 91 | 92 | negative_ctxs = json_sample["negative_ctxs"] if "negative_ctxs" in json_sample else [] 93 | hard_negative_ctxs = json_sample["hard_negative_ctxs"] if "hard_negative_ctxs" in json_sample else [] 94 | 95 | for ctx in positive_ctxs + negative_ctxs + hard_negative_ctxs: 96 | if "title" not in ctx: 97 | ctx["title"] = None 98 | 99 | def create_passage(ctx: dict): 100 | return BiEncoderPassage( 101 | normalize_passage(ctx["text"]) if self.normalize else ctx["text"], 102 | ctx["title"], 103 | ) 104 | 105 | r.positive_passages = [create_passage(ctx) for ctx in positive_ctxs] 106 | r.negative_passages = [create_passage(ctx) for ctx in negative_ctxs] 107 | r.hard_negative_passages = [create_passage(ctx) for ctx in hard_negative_ctxs] 108 | return r 109 | 110 | 111 | class JsonlQADataset(JsonQADataset): 112 | def __init__( 113 | self, 114 | file: str, 115 | selector: DictConfig = None, 116 | special_token: str = None, 117 | encoder_type: str = None, 118 | shuffle_positives: bool = False, 119 | normalize: bool = False, 120 | query_special_suffix: str = None, 121 | # tmp: for cc-net results only 122 | exclude_gold: bool = False, 123 | total_data_size: int = -1, 124 | ): 125 | super().__init__( 126 | file, 127 | selector, 128 | special_token, 129 | encoder_type, 130 | shuffle_positives, 131 | normalize, 132 | query_special_suffix, 133 | exclude_gold, 134 | ) 135 | self.total_data_size = total_data_size 136 | self.data_files = get_dpr_files(self.file) 137 | logger.info("Data files: %s", self.data_files) 138 | 139 | def calc_total_data_len(self): 140 | # TODO: optimize jsonl file read & results caching 141 | if self.total_data_size < 0: 142 | logger.info("Calculating data size") 143 | for file in self.data_files: 144 | with jsonlines.open(file, mode="r") as jsonl_reader: 145 | for _ in jsonl_reader: 146 | self.total_data_size += 1 147 | logger.info("total_data_size=%d", self.total_data_size) 148 | return self.total_data_size 149 | 150 | def load_data(self, start_pos: int = -1, end_pos: int = -1): 151 | if self.data: 152 | return 153 | logger.info("Jsonl loading subset range from %d to %d", start_pos, end_pos) 154 | if start_pos < 0 and end_pos < 0: 155 | for file in self.data_files: 156 | with jsonlines.open(file, mode="r") as jsonl_reader: 157 | self.data.extend([l for l in jsonl_reader]) 158 | return 159 | 160 | global_sample_id = 0 161 | for file in self.data_files: 162 | if global_sample_id >= end_pos: 163 | break 164 | with jsonlines.open(file, mode="r") as jsonl_reader: 165 | for jline in jsonl_reader: 166 | if start_pos <= global_sample_id < end_pos: 167 | self.data.append(jline) 168 | if global_sample_id >= end_pos: 169 | break 170 | global_sample_id += 1 171 | logger.info("Jsonl loaded data size %d ", len(self.data)) 172 | 173 | 174 | def normalize_passage(ctx_text: str): 175 | ctx_text = ctx_text.replace("\n", " ").replace("’", "'") 176 | if ctx_text.startswith('"'): 177 | ctx_text = ctx_text[1:] 178 | if ctx_text.endswith('"'): 179 | ctx_text = ctx_text[:-1] 180 | return ctx_text 181 | 182 | 183 | class Cell: 184 | def __init__(self): 185 | self.value_tokens: List[str] = [] 186 | self.type: str = "" 187 | self.nested_tables: List[Table] = [] 188 | 189 | def __str__(self): 190 | return " ".join(self.value_tokens) 191 | 192 | def to_dpr_json(self, cell_idx: int): 193 | r = {"col": cell_idx} 194 | r["value"] = str(self) 195 | return r 196 | 197 | 198 | class Row: 199 | def __init__(self): 200 | 
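# a Row is an ordered list of Cell objects; NQTableParser (below) appends a new Cell
# each time it enters a table cell and fills it token by token via _on_content()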
self.cells: List[Cell] = [] 201 | 202 | def __str__(self): 203 | return "| ".join([str(c) for c in self.cells]) 204 | 205 | def visit(self, tokens_function, row_idx: int): 206 | for i, c in enumerate(self.cells): 207 | if c.value_tokens: 208 | tokens_function(c.value_tokens, row_idx, i) 209 | 210 | def to_dpr_json(self, row_idx: int): 211 | r = {"row": row_idx} 212 | r["columns"] = [c.to_dpr_json(i) for i, c in enumerate(self.cells)] 213 | return r 214 | 215 | 216 | class Table(object): 217 | def __init__(self, caption=""): 218 | self.caption = caption 219 | self.body: List[Row] = [] 220 | self.key = None 221 | self.gold_match = False 222 | 223 | def __str__(self): 224 | table_str = ": {}\n".format(self.caption) 225 | table_str += " rows:\n" 226 | for i, r in enumerate(self.body): 227 | table_str += " row #{}: {}\n".format(i, str(r)) 228 | 229 | return table_str 230 | 231 | def get_key(self) -> str: 232 | if not self.key: 233 | self.key = str(self) 234 | return self.key 235 | 236 | def visit(self, tokens_function, include_caption: bool = False) -> bool: 237 | if include_caption: 238 | tokens_function(self.caption, -1, -1) 239 | for i, r in enumerate(self.body): 240 | r.visit(tokens_function, i) 241 | 242 | def to_dpr_json(self): 243 | r = { 244 | "caption": self.caption, 245 | "rows": [r.to_dpr_json(i) for i, r in enumerate(self.body)], 246 | } 247 | if self.gold_match: 248 | r["gold_match"] = 1 249 | return r 250 | 251 | 252 | class NQTableParser(object): 253 | def __init__(self, tokens, is_html_mask, title): 254 | self.tokens = tokens 255 | self.is_html_mask = is_html_mask 256 | self.max_idx = len(self.tokens) 257 | self.all_tables = [] 258 | 259 | self.current_table: Table = None 260 | self.tables_stack = collections.deque() 261 | self.title = title 262 | 263 | def parse(self) -> List[Table]: 264 | self.all_tables = [] 265 | self.tables_stack = collections.deque() 266 | 267 | for i in range(self.max_idx): 268 | 269 | t = self.tokens[i] 270 | 271 | if not self.is_html_mask[i]: 272 | # cell content 273 | self._on_content(t) 274 | continue 275 | 276 | if "": 279 | self._on_table_end() 280 | elif "": 283 | self._onRowEnd() 284 | elif "", ""]: 287 | self._on_cell_end() 288 | 289 | return self.all_tables 290 | 291 | def _on_table_start(self): 292 | caption = self.title 293 | parent_table = self.current_table 294 | if parent_table: 295 | self.tables_stack.append(parent_table) 296 | 297 | caption = parent_table.caption 298 | if parent_table.body and parent_table.body[-1].cells: 299 | current_cell = self.current_table.body[-1].cells[-1] 300 | caption += " | " + " ".join(current_cell.value_tokens) 301 | 302 | t = Table() 303 | t.caption = caption 304 | self.current_table = t 305 | self.all_tables.append(t) 306 | 307 | def _on_table_end(self): 308 | t = self.current_table 309 | if t: 310 | if self.tables_stack: # t is a nested table 311 | self.current_table = self.tables_stack.pop() 312 | if self.current_table.body: 313 | current_cell = self.current_table.body[-1].cells[-1] 314 | current_cell.nested_tables.append(t) 315 | else: 316 | logger.error("table end without table object") 317 | 318 | def _onRowStart(self): 319 | self.current_table.body.append(Row()) 320 | 321 | def _onRowEnd(self): 322 | pass 323 | 324 | def _onCellStart(self): 325 | current_row = self.current_table.body[-1] 326 | current_row.cells.append(Cell()) 327 | 328 | def _on_cell_end(self): 329 | pass 330 | 331 | def _on_content(self, token): 332 | if self.current_table.body: 333 | current_row = self.current_table.body[-1] 334 | 
current_cell = current_row.cells[-1] 335 | current_cell.value_tokens.append(token) 336 | else: # tokens outside of row/cells. Just append to the table caption. 337 | self.current_table.caption += " " + token 338 | 339 | 340 | def read_nq_tables_jsonl(path: str) -> Dict[str, Table]: 341 | tables_with_issues = 0 342 | single_row_tables = 0 343 | nested_tables = 0 344 | regular_tables = 0 345 | total_tables = 0 346 | total_rows = 0 347 | tables_dict = {} 348 | 349 | with jsonlines.open(path, mode="r") as jsonl_reader: 350 | for jline in jsonl_reader: 351 | tokens = jline["tokens"] 352 | 353 | if "( hide ) This section has multiple issues" in " ".join(tokens): 354 | tables_with_issues += 1 355 | continue 356 | 357 | mask = jline["html_mask"] 358 | # page_url = jline["doc_url"] 359 | title = jline["title"] 360 | p = NQTableParser(tokens, mask, title) 361 | tables = p.parse() 362 | 363 | # table = parse_table(tokens, mask) 364 | 365 | nested_tables += len(tables[1:]) 366 | 367 | for t in tables: 368 | total_tables += 1 369 | 370 | # calc amount of non empty rows 371 | non_empty_rows = sum([1 for r in t.body if r.cells and any([True for c in r.cells if c.value_tokens])]) 372 | 373 | if non_empty_rows <= 1: 374 | single_row_tables += 1 375 | else: 376 | regular_tables += 1 377 | total_rows += len(t.body) 378 | 379 | if t.get_key() not in tables_dict: 380 | tables_dict[t.get_key()] = t 381 | 382 | if len(tables_dict) % 1000 == 0: 383 | logger.info("tables_dict %d", len(tables_dict)) 384 | 385 | logger.info("regular tables %d", regular_tables) 386 | logger.info("tables_with_issues %d", tables_with_issues) 387 | logger.info("single_row_tables %d", single_row_tables) 388 | logger.info("nested_tables %d", nested_tables) 389 | return tables_dict 390 | 391 | 392 | def get_table_string_for_answer_check(table: Table): # this doesn't use caption 393 | table_text = "" 394 | for r in table.body: 395 | table_text += " . ".join([" ".join(c.value_tokens) for c in r.cells]) 396 | table_text += " . " 397 | return table_text 398 | 399 | 400 | # TODO: inherit from Jsonl 401 | class JsonLTablesQADataset(Dataset): 402 | def __init__( 403 | self, 404 | file: str, 405 | is_train_set: bool, 406 | selector: DictConfig = None, 407 | shuffle_positives: bool = False, 408 | max_negatives: int = 1, 409 | seed: int = 0, 410 | max_len=100, 411 | split_type: str = "type1", 412 | ): 413 | super().__init__(selector, shuffle_positives=shuffle_positives) 414 | self.data_files = glob.glob(file) 415 | self.data = [] 416 | self.is_train_set = is_train_set 417 | self.max_negatives = max_negatives 418 | self.rnd = random.Random(seed) 419 | self.max_len = max_len 420 | self.linearize_func = JsonLTablesQADataset.get_lin_func(split_type) 421 | 422 | def load_data(self, start_pos: int = -1, end_pos: int = -1): 423 | # TODO: use JsonlX super class load_data() ? 
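# expected jsonl record layout: {"question": ..., "positive_ctxs": [...], "hard_negative_ctxs": [...]},
# where every ctx is a table dict carrying "caption" and "rows" (each row holds "columns" -> {"value": ...}),
# and positive ctxs additionally carry "answer_pos" entries whose first element is a row index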
424 | data = [] 425 | for path in self.data_files: 426 | with jsonlines.open(path, mode="r") as jsonl_reader: 427 | data += [jline for jline in jsonl_reader] 428 | # filter those without positive ctx 429 | self.data = [r for r in data if len(r["positive_ctxs"]) > 0] 430 | logger.info("Total cleaned data size: {}".format(len(self.data))) 431 | if start_pos >= 0 and end_pos >= 0: 432 | logger.info("Selecting subset range from %d to %d", start_pos, end_pos) 433 | self.data = self.data[start_pos:end_pos] 434 | 435 | def __getitem__(self, index) -> BiEncoderSample: 436 | json_sample = self.data[index] 437 | r = BiEncoderSample() 438 | r.query = json_sample["question"] 439 | positive_ctxs = json_sample["positive_ctxs"] 440 | hard_negative_ctxs = json_sample["hard_negative_ctxs"] 441 | 442 | if self.shuffle_positives: 443 | self.rnd.shuffle(positive_ctxs) 444 | 445 | if self.is_train_set: 446 | self.rnd.shuffle(hard_negative_ctxs) 447 | positive_ctxs = positive_ctxs[0:1] 448 | hard_negative_ctxs = hard_negative_ctxs[0 : self.max_negatives] 449 | 450 | r.positive_passages = [ 451 | BiEncoderPassage(self.linearize_func(self, ctx, True), ctx["caption"]) for ctx in positive_ctxs 452 | ] 453 | r.negative_passages = [] 454 | r.hard_negative_passages = [ 455 | BiEncoderPassage(self.linearize_func(self, ctx, False), ctx["caption"]) for ctx in hard_negative_ctxs 456 | ] 457 | return r 458 | 459 | @classmethod 460 | def get_lin_func(cls, split_type: str): 461 | f = { 462 | "type1": JsonLTablesQADataset._linearize_table, 463 | } 464 | return f[split_type] 465 | 466 | @classmethod 467 | def split_table(cls, t: dict, max_length: int): 468 | rows = t["rows"] 469 | header = None 470 | header_len = 0 471 | start_row = 0 472 | 473 | # get the first non empty row as the "header" 474 | for i, r in enumerate(rows): 475 | row_lin, row_len = JsonLTablesQADataset._linearize_row(r) 476 | if len(row_lin) > 1: # TODO: change to checking cell value tokens 477 | header = row_lin 478 | header_len += row_len 479 | start_row = i 480 | break 481 | 482 | chunks = [] 483 | current_rows = [header] 484 | current_len = header_len 485 | 486 | for i in range(start_row + 1, len(rows)): 487 | row_lin, row_len = JsonLTablesQADataset._linearize_row(rows[i]) 488 | if len(row_lin) > 1: # TODO: change to checking cell value tokens 489 | current_rows.append(row_lin) 490 | current_len += row_len 491 | if current_len >= max_length: 492 | # linearize chunk 493 | linearized_str = "\n".join(current_rows) + "\n" 494 | chunks.append(linearized_str) 495 | current_rows = [header] 496 | current_len = header_len 497 | 498 | if len(current_rows) > 1: 499 | linearized_str = "\n".join(current_rows) + "\n" 500 | chunks.append(linearized_str) 501 | return chunks 502 | 503 | def _linearize_table(self, t: dict, is_positive: bool) -> str: 504 | rows = t["rows"] 505 | selected_rows = set() 506 | rows_linearized = [] 507 | total_words_len = 0 508 | 509 | # get the first non empty row as the "header" 510 | for i, r in enumerate(rows): 511 | row_lin, row_len = JsonLTablesQADataset._linearize_row(r) 512 | if len(row_lin) > 1: # TODO: change to checking cell value tokens 513 | selected_rows.add(i) 514 | rows_linearized.append(row_lin) 515 | total_words_len += row_len 516 | break 517 | 518 | # split to chunks 519 | if is_positive: 520 | row_idx_with_answers = [ap[0] for ap in t["answer_pos"]] 521 | 522 | if self.shuffle_positives: 523 | self.rnd.shuffle(row_idx_with_answers) 524 | for i in row_idx_with_answers: 525 | if i not in selected_rows: 526 | row_lin, row_len = 
JsonLTablesQADataset._linearize_row(rows[i]) 527 | selected_rows.add(i) 528 | rows_linearized.append(row_lin) 529 | total_words_len += row_len 530 | if total_words_len >= self.max_len: 531 | break 532 | 533 | if total_words_len < self.max_len: # append random rows 534 | 535 | if self.is_train_set: 536 | rows_indexes = np.random.permutation(range(len(rows))) 537 | else: 538 | rows_indexes = [*range(len(rows))] 539 | 540 | for i in rows_indexes: 541 | if i not in selected_rows: 542 | row_lin, row_len = JsonLTablesQADataset._linearize_row(rows[i]) 543 | if len(row_lin) > 1: # TODO: change to checking cell value tokens 544 | selected_rows.add(i) 545 | rows_linearized.append(row_lin) 546 | total_words_len += row_len 547 | if total_words_len >= self.max_len: 548 | break 549 | 550 | linearized_str = "" 551 | for r in rows_linearized: 552 | linearized_str += r + "\n" 553 | 554 | return linearized_str 555 | 556 | @classmethod 557 | def _linearize_row(cls, row: dict) -> Tuple[str, int]: 558 | cell_values = [c["value"] for c in row["columns"]] 559 | total_words = sum(len(c.split(" ")) for c in cell_values) 560 | return ", ".join([c["value"] for c in row["columns"]]), total_words 561 | 562 | 563 | def split_tables_to_chunks( 564 | tables_dict: Dict[str, Table], max_table_len: int, split_type: str = "type1" 565 | ) -> List[Tuple[int, str, str, int]]: 566 | tables_as_dicts = [t.to_dpr_json() for k, t in tables_dict.items()] 567 | chunks = [] 568 | chunk_id = 0 569 | for i, t in enumerate(tables_as_dicts): 570 | # TODO: support other types 571 | assert split_type == "type1" 572 | table_chunks = JsonLTablesQADataset.split_table(t, max_table_len) 573 | title = t["caption"] 574 | for c in table_chunks: 575 | # chunk id , text, title, external_id 576 | chunks.append((chunk_id, c, title, i)) 577 | chunk_id += 1 578 | if i % 1000 == 0: 579 | logger.info("Splitted %d tables to %d chunks", i, len(chunks)) 580 | return chunks 581 | -------------------------------------------------------------------------------- /dpr/data/qa_validation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Set of utilities for Q&A results validation tasks - Retriver passage validation and Reader predicted answer validation 10 | """ 11 | 12 | import collections 13 | import logging 14 | import string 15 | import unicodedata 16 | import zlib 17 | from functools import partial 18 | from multiprocessing import Pool as ProcessPool 19 | from typing import Tuple, List, Dict 20 | 21 | import regex as re 22 | 23 | from dpr.data.retriever_data import TableChunk 24 | from dpr.utils.tokenizers import SimpleTokenizer 25 | 26 | logger = logging.getLogger(__name__) 27 | 28 | QAMatchStats = collections.namedtuple("QAMatchStats", ["top_k_hits", "questions_doc_hits"]) 29 | 30 | QATableMatchStats = collections.namedtuple( 31 | "QAMatchStats", ["top_k_chunk_hits", "top_k_table_hits", "questions_doc_hits"] 32 | ) 33 | 34 | 35 | def calculate_matches( 36 | all_docs: Dict[object, Tuple[str, str]], 37 | answers: List[List[str]], 38 | closest_docs: List[Tuple[List[object], List[float]]], 39 | workers_num: int, 40 | match_type: str, 41 | ) -> QAMatchStats: 42 | """ 43 | Evaluates answers presence in the set of documents. 
This function is supposed to be used with a large collection of 44 | documents and results. It internally forks multiple sub-processes for evaluation and then merges results 45 | :param all_docs: dictionary of the entire documents database. doc_id -> (doc_text, title) 46 | :param answers: list of answers's list. One list per question 47 | :param closest_docs: document ids of the top results along with their scores 48 | :param workers_num: amount of parallel threads to process data 49 | :param match_type: type of answer matching. Refer to has_answer code for available options 50 | :return: matching information tuple. 51 | top_k_hits - a list where the index is the amount of top documents retrieved and the value is the total amount of 52 | valid matches across an entire dataset. 53 | questions_doc_hits - more detailed info with answer matches for every question and every retrieved document 54 | """ 55 | logger.info("all_docs size %d", len(all_docs)) 56 | global dpr_all_documents 57 | dpr_all_documents = all_docs 58 | logger.info("dpr_all_documents size %d", len(dpr_all_documents)) 59 | 60 | tok_opts = {} 61 | tokenizer = SimpleTokenizer(**tok_opts) 62 | 63 | processes = ProcessPool(processes=workers_num) 64 | logger.info("Matching answers in top docs...") 65 | get_score_partial = partial(check_answer, match_type=match_type, tokenizer=tokenizer) 66 | 67 | questions_answers_docs = zip(answers, closest_docs) 68 | scores = processes.map(get_score_partial, questions_answers_docs) 69 | 70 | logger.info("Per question validation results len=%d", len(scores)) 71 | 72 | n_docs = len(closest_docs[0][0]) 73 | top_k_hits = [0] * n_docs 74 | for question_hits in scores: 75 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 76 | if best_hit is not None: 77 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 78 | 79 | return QAMatchStats(top_k_hits, scores) 80 | 81 | 82 | def calculate_matches_from_meta( 83 | answers: List[List[str]], 84 | closest_docs: List[Tuple[List[object], List[float]]], 85 | workers_num: int, 86 | match_type: str, 87 | use_title: bool = False, 88 | meta_compressed: bool = False, 89 | ) -> QAMatchStats: 90 | 91 | tok_opts = {} 92 | tokenizer = SimpleTokenizer(**tok_opts) 93 | 94 | processes = ProcessPool(processes=workers_num) 95 | logger.info("Matching answers in top docs...") 96 | get_score_partial = partial( 97 | check_answer_from_meta, 98 | match_type=match_type, 99 | tokenizer=tokenizer, 100 | use_title=use_title, 101 | meta_compressed=meta_compressed, 102 | ) 103 | 104 | questions_answers_docs = zip(answers, closest_docs) 105 | scores = processes.map(get_score_partial, questions_answers_docs) 106 | 107 | logger.info("Per question validation results len=%d", len(scores)) 108 | 109 | n_docs = len(closest_docs[0][0]) 110 | top_k_hits = [0] * n_docs 111 | for question_hits in scores: 112 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 113 | if best_hit is not None: 114 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 115 | 116 | return QAMatchStats(top_k_hits, scores) 117 | 118 | 119 | def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]: 120 | """Search through all the top docs to see if they have any of the answers.""" 121 | answers, (doc_ids, doc_scores) = questions_answers_docs 122 | 123 | global dpr_all_documents 124 | hits = [] 125 | 126 | for i, doc_id in enumerate(doc_ids): 127 | doc = dpr_all_documents[doc_id] 128 | text = doc[0] 129 | 130 | answer_found = False 131 | if text is 
None: # cannot find the document for some reason 132 | logger.warning("no doc in db") 133 | hits.append(False) 134 | continue 135 | if match_type == "kilt": 136 | if has_answer_kilt(answers, text): 137 | answer_found = True 138 | elif has_answer(answers, text, tokenizer, match_type): 139 | answer_found = True 140 | hits.append(answer_found) 141 | return hits 142 | 143 | 144 | def check_answer_from_meta( 145 | questions_answers_docs, 146 | tokenizer, 147 | match_type, 148 | meta_body_idx: int = 1, 149 | meta_title_idx: int = 2, 150 | use_title: bool = False, 151 | meta_compressed: bool = False, 152 | ) -> List[bool]: 153 | """Search through all the top docs to see if they have any of the answers.""" 154 | answers, (docs_meta, doc_scores) = questions_answers_docs 155 | 156 | hits = [] 157 | 158 | for i, doc_meta in enumerate(docs_meta): 159 | 160 | text = doc_meta[meta_body_idx] 161 | title = doc_meta[meta_title_idx] if len(doc_meta) > meta_title_idx else "" 162 | if meta_compressed: 163 | text = zlib.decompress(text).decode() 164 | title = zlib.decompress(title).decode() 165 | 166 | if use_title: 167 | text = title + " . " + text 168 | answer_found = False 169 | if has_answer(answers, text, tokenizer, match_type): 170 | answer_found = True 171 | hits.append(answer_found) 172 | return hits 173 | 174 | 175 | def has_answer(answers, text, tokenizer, match_type) -> bool: 176 | """Check if a document contains an answer string. 177 | If `match_type` is string, token matching is done between the text and answer. 178 | If `match_type` is regex, we search the whole text with the regex. 179 | """ 180 | text = _normalize(text) 181 | 182 | if match_type == "string": 183 | # Answer is a list of possible strings 184 | text = tokenizer.tokenize(text).words(uncased=True) 185 | 186 | for single_answer in answers: 187 | single_answer = _normalize(single_answer) 188 | single_answer = tokenizer.tokenize(single_answer) 189 | single_answer = single_answer.words(uncased=True) 190 | 191 | for i in range(0, len(text) - len(single_answer) + 1): 192 | if single_answer == text[i : i + len(single_answer)]: 193 | return True 194 | 195 | elif match_type == "regex": 196 | # Answer is a regex 197 | for single_answer in answers: 198 | single_answer = _normalize(single_answer) 199 | if regex_match(text, single_answer): 200 | return True 201 | return False 202 | 203 | 204 | def regex_match(text, pattern): 205 | """Test if a regex pattern is contained within a text.""" 206 | try: 207 | pattern = re.compile(pattern, flags=re.IGNORECASE + re.UNICODE + re.MULTILINE) 208 | except BaseException: 209 | return False 210 | return pattern.search(text) is not None 211 | 212 | 213 | # function for the reader model answer validation 214 | def exact_match_score(prediction, ground_truth): 215 | return _normalize_answer(prediction) == _normalize_answer(ground_truth) 216 | 217 | 218 | def _normalize_answer(s): 219 | def remove_articles(text): 220 | return re.sub(r"\b(a|an|the)\b", " ", text) 221 | 222 | def white_space_fix(text): 223 | return " ".join(text.split()) 224 | 225 | def remove_punc(text): 226 | exclude = set(string.punctuation) 227 | return "".join(ch for ch in text if ch not in exclude) 228 | 229 | def lower(text): 230 | return text.lower() 231 | 232 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 233 | 234 | 235 | def _normalize(text): 236 | return unicodedata.normalize("NFD", text) 237 | 238 | 239 | def calculate_chunked_matches( 240 | all_docs: Dict[object, TableChunk], 241 | answers: List[List[str]], 242 | 
closest_docs: List[Tuple[List[object], List[float]]], 243 | workers_num: int, 244 | match_type: str, 245 | ) -> QATableMatchStats: 246 | global dpr_all_documents 247 | dpr_all_documents = all_docs 248 | 249 | global dpr_all_tables 250 | dpr_all_tables = {} 251 | 252 | for key, table_chunk in all_docs.items(): 253 | table_str, title, table_id = table_chunk 254 | table_chunks = dpr_all_tables.get(table_id, []) 255 | table_chunks.append((table_str, title)) 256 | dpr_all_tables[table_id] = table_chunks 257 | 258 | tok_opts = {} 259 | tokenizer = SimpleTokenizer(**tok_opts) 260 | 261 | processes = ProcessPool(processes=workers_num) 262 | 263 | logger.info("Matching answers in top docs...") 264 | get_score_partial = partial(check_chunked_docs_answer, match_type=match_type, tokenizer=tokenizer) 265 | questions_answers_docs = zip(answers, closest_docs) 266 | scores = processes.map(get_score_partial, questions_answers_docs) 267 | logger.info("Per question validation results len=%d", len(scores)) 268 | 269 | n_docs = len(closest_docs[0][0]) 270 | top_k_hits = [0] * n_docs 271 | top_k_orig_hits = [0] * n_docs 272 | for s in scores: 273 | question_hits, question_orig_doc_hits = s 274 | best_hit = next((i for i, x in enumerate(question_hits) if x), None) 275 | if best_hit is not None: 276 | top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]] 277 | 278 | best_hit = next((i for i, x in enumerate(question_orig_doc_hits) if x), None) 279 | if best_hit is not None: 280 | top_k_orig_hits[best_hit:] = [v + 1 for v in top_k_orig_hits[best_hit:]] 281 | 282 | return QATableMatchStats(top_k_hits, top_k_orig_hits, scores) 283 | 284 | 285 | # -------------------- KILT eval --------------------------------- 286 | 287 | 288 | def has_answer_kilt(answers, text) -> bool: 289 | text = normalize_kilt(text) 290 | for single_answer in answers: 291 | single_answer = normalize_kilt(single_answer) 292 | if single_answer in text: 293 | return True 294 | return False 295 | 296 | 297 | # answer normalization 298 | def normalize_kilt(s): 299 | """Lower text and remove punctuation, articles and extra whitespace.""" 300 | 301 | def remove_articles(text): 302 | return re.sub(r"\b(a|an|the)\b", " ", text) 303 | 304 | def white_space_fix(text): 305 | return " ".join(text.split()) 306 | 307 | def remove_punc(text): 308 | exclude = set(string.punctuation) 309 | return "".join(ch for ch in text if ch not in exclude) 310 | 311 | def lower(text): 312 | return text.lower() 313 | 314 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 315 | -------------------------------------------------------------------------------- /dpr/data/retriever_data.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import csv 3 | import json 4 | import logging 5 | import pickle 6 | from typing import Dict, List 7 | 8 | import hydra 9 | import jsonlines 10 | import torch 11 | from omegaconf import DictConfig 12 | 13 | from dpr.data.biencoder_data import ( 14 | BiEncoderPassage, 15 | normalize_passage, 16 | get_dpr_files, 17 | read_nq_tables_jsonl, 18 | split_tables_to_chunks, 19 | ) 20 | 21 | from dpr.utils.data_utils import normalize_question 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | TableChunk = collections.namedtuple("TableChunk", ["text", "title", "table_id"]) 26 | 27 | 28 | class QASample: 29 | def __init__(self, query: str, id, answers: List[str]): 30 | self.query = query 31 | self.id = id 32 | self.answers = answers 33 | 34 | 35 | class 
RetrieverData(torch.utils.data.Dataset): 36 | def __init__(self, file: str): 37 | """ 38 | :param file: - real file name or the resource name as they are defined in download_data.py 39 | """ 40 | self.file = file 41 | self.data_files = [] 42 | 43 | def load_data(self): 44 | self.data_files = get_dpr_files(self.file) 45 | assert ( 46 | len(self.data_files) == 1 47 | ), "RetrieverData source currently works with single files only. Files specified: {}".format(self.data_files) 48 | self.file = self.data_files[0] 49 | 50 | 51 | class QASrc(RetrieverData): 52 | def __init__( 53 | self, 54 | file: str, 55 | selector: DictConfig = None, 56 | special_query_token: str = None, 57 | query_special_suffix: str = None, 58 | ): 59 | super().__init__(file) 60 | self.data = None 61 | self.selector = hydra.utils.instantiate(selector) if selector else None 62 | self.special_query_token = special_query_token 63 | self.query_special_suffix = query_special_suffix 64 | 65 | def __getitem__(self, index) -> QASample: 66 | return self.data[index] 67 | 68 | def __len__(self): 69 | return len(self.data) 70 | 71 | def _process_question(self, question: str): 72 | # as of now, always normalize query 73 | question = normalize_question(question) 74 | if self.query_special_suffix and not question.endswith(self.query_special_suffix): 75 | question += self.query_special_suffix 76 | return question 77 | 78 | 79 | class CsvQASrc(QASrc): 80 | def __init__( 81 | self, 82 | file: str, 83 | question_col: int = 0, 84 | answers_col: int = 1, 85 | id_col: int = -1, 86 | selector: DictConfig = None, 87 | special_query_token: str = None, 88 | query_special_suffix: str = None, 89 | data_range_start: int = -1, 90 | data_size: int = -1, 91 | ): 92 | super().__init__(file, selector, special_query_token, query_special_suffix) 93 | self.question_col = question_col 94 | self.answers_col = answers_col 95 | self.id_col = id_col 96 | self.data_range_start = data_range_start 97 | self.data_size = data_size 98 | 99 | def load_data(self): 100 | super().load_data() 101 | data = [] 102 | start = self.data_range_start 103 | # size = self.data_size 104 | samples_count = 0 105 | # TODO: optimize 106 | with open(self.file) as ifile: 107 | reader = csv.reader(ifile, delimiter="\t") 108 | for row in reader: 109 | question = row[self.question_col] 110 | answers = eval(row[self.answers_col]) 111 | id = None 112 | if self.id_col >= 0: 113 | id = row[self.id_col] 114 | samples_count += 1 115 | # if start !=-1 and samples_count<=start: 116 | # continue 117 | data.append(QASample(self._process_question(question), id, answers)) 118 | 119 | if start != -1: 120 | end = start + self.data_size if self.data_size != -1 else -1 121 | logger.info("Selecting dataset range [%s,%s]", start, end) 122 | self.data = data[start:end] if end != -1 else data[start:] 123 | else: 124 | self.data = data 125 | 126 | 127 | class JsonlQASrc(QASrc): 128 | def __init__( 129 | self, 130 | file: str, 131 | selector: DictConfig = None, 132 | question_attr: str = "question", 133 | answers_attr: str = "answers", 134 | id_attr: str = "id", 135 | special_query_token: str = None, 136 | query_special_suffix: str = None, 137 | ): 138 | super().__init__(file, selector, special_query_token, query_special_suffix) 139 | self.question_attr = question_attr 140 | self.answers_attr = answers_attr 141 | self.id_attr = id_attr 142 | 143 | def load_data(self): 144 | super().load_data() 145 | data = [] 146 | with jsonlines.open(self.file, mode="r") as jsonl_reader: 147 | for jline in jsonl_reader: 148 | 
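# every JSONL record is expected to provide the question text under self.question_attr,
# optionally a list of answers under self.answers_attr and an id under self.id_attr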
question = jline[self.question_attr] 149 | answers = jline[self.answers_attr] if self.answers_attr in jline else [] 150 | id = None 151 | if self.id_attr in jline: 152 | id = jline[self.id_attr] 153 | data.append(QASample(self._process_question(question), id, answers)) 154 | self.data = data 155 | 156 | 157 | class KiltCsvQASrc(CsvQASrc): 158 | def __init__( 159 | self, 160 | file: str, 161 | kilt_gold_file: str, 162 | question_col: int = 0, 163 | answers_col: int = 1, 164 | id_col: int = -1, 165 | selector: DictConfig = None, 166 | special_query_token: str = None, 167 | query_special_suffix: str = None, 168 | data_range_start: int = -1, 169 | data_size: int = -1, 170 | ): 171 | super().__init__( 172 | file, 173 | question_col, 174 | answers_col, 175 | id_col, 176 | selector, 177 | special_query_token, 178 | query_special_suffix, 179 | data_range_start, 180 | data_size, 181 | ) 182 | self.kilt_gold_file = kilt_gold_file 183 | 184 | 185 | class KiltJsonlQASrc(JsonlQASrc): 186 | def __init__( 187 | self, 188 | file: str, 189 | kilt_gold_file: str, 190 | question_attr: str = "input", 191 | answers_attr: str = "answer", 192 | id_attr: str = "id", 193 | selector: DictConfig = None, 194 | special_query_token: str = None, 195 | query_special_suffix: str = None, 196 | ): 197 | super().__init__( 198 | file, 199 | selector, 200 | question_attr, 201 | answers_attr, 202 | id_attr, 203 | special_query_token, 204 | query_special_suffix, 205 | ) 206 | self.kilt_gold_file = kilt_gold_file 207 | 208 | def load_data(self): 209 | super().load_data() 210 | data = [] 211 | with jsonlines.open(self.file, mode="r") as jsonl_reader: 212 | for jline in jsonl_reader: 213 | question = jline[self.question_attr] 214 | out = jline["output"] 215 | answers = [o["answer"] for o in out if "answer" in o] 216 | id = None 217 | if self.id_attr in jline: 218 | id = jline[self.id_attr] 219 | data.append(QASample(self._process_question(question), id, answers)) 220 | self.data = data 221 | 222 | 223 | class TTS_ASR_QASrc(QASrc): 224 | def __init__(self, file: str, trans_file: str): 225 | super().__init__(file) 226 | self.trans_file = trans_file 227 | 228 | def load_data(self): 229 | super().load_data() 230 | orig_data_dict = {} 231 | with open(self.file, "r") as ifile: 232 | reader = csv.reader(ifile, delimiter="\t") 233 | id = 0 234 | for row in reader: 235 | question = row[0] 236 | answers = eval(row[1]) 237 | orig_data_dict[id] = (question, answers) 238 | id += 1 239 | data = [] 240 | with open(self.trans_file, "r") as tfile: 241 | reader = csv.reader(tfile, delimiter="\t") 242 | for r in reader: 243 | row_str = r[0] 244 | idx = row_str.index("(None-") 245 | q_id = int(row_str[idx + len("(None-") : -1]) 246 | orig_data = orig_data_dict[q_id] 247 | answers = orig_data[1] 248 | q = row_str[:idx].strip().lower() 249 | data.append(QASample(q, idx, answers)) 250 | self.data = data 251 | 252 | 253 | class CsvCtxSrc(RetrieverData): 254 | def __init__( 255 | self, 256 | file: str, 257 | id_col: int = 0, 258 | text_col: int = 1, 259 | title_col: int = 2, 260 | id_prefix: str = None, 261 | normalize: bool = False, 262 | ): 263 | super().__init__(file) 264 | self.text_col = text_col 265 | self.title_col = title_col 266 | self.id_col = id_col 267 | self.id_prefix = id_prefix 268 | self.normalize = normalize 269 | 270 | def load_data_to(self, ctxs: Dict[object, BiEncoderPassage]): 271 | super().load_data() 272 | logger.info("Reading file %s", self.file) 273 | with open(self.file) as ifile: 274 | reader = csv.reader(ifile, delimiter="\t") 
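# the passages TSV is expected to carry (id, text, title) columns at the positions given by
# id_col / text_col / title_col; a header row whose id field equals "id" is skipped below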
275 | for row in reader: 276 | # for row in ifile: 277 | # row = row.strip().split("\t") 278 | if row[self.id_col] == "id": 279 | continue 280 | if self.id_prefix: 281 | sample_id = self.id_prefix + str(row[self.id_col]) 282 | else: 283 | sample_id = row[self.id_col] 284 | passage = row[self.text_col].strip('"') 285 | if self.normalize: 286 | passage = normalize_passage(passage) 287 | ctxs[sample_id] = BiEncoderPassage(passage, row[self.title_col]) 288 | 289 | 290 | class KiltCsvCtxSrc(CsvCtxSrc): 291 | def __init__( 292 | self, 293 | file: str, 294 | mapping_file: str, 295 | id_col: int = 0, 296 | text_col: int = 1, 297 | title_col: int = 2, 298 | id_prefix: str = None, 299 | normalize: bool = False, 300 | ): 301 | super().__init__(file, id_col, text_col, title_col, id_prefix, normalize=normalize) 302 | self.mapping_file = mapping_file 303 | 304 | def convert_to_kilt(self, kilt_gold_file, dpr_output, kilt_out_file): 305 | logger.info("Converting to KILT format file: %s", dpr_output) 306 | 307 | with open(dpr_output, "rt") as fin: 308 | dpr_output = json.load(fin) 309 | 310 | with jsonlines.open(kilt_gold_file, "r") as reader: 311 | kilt_gold_file = list(reader) 312 | assert len(kilt_gold_file) == len(dpr_output) 313 | map_path = self.mapping_file 314 | with open(map_path, "rb") as fin: 315 | mapping = pickle.load(fin) 316 | 317 | with jsonlines.open(kilt_out_file, mode="w") as writer: 318 | for dpr_entry, kilt_gold_entry in zip(dpr_output, kilt_gold_file): 319 | # assert dpr_entry["question"] == kilt_gold_entry["input"] 320 | provenance = [] 321 | for ctx in dpr_entry["ctxs"]: 322 | wikipedia_id, end_paragraph_id = mapping[int(ctx["id"])] 323 | provenance.append( 324 | { 325 | "wikipedia_id": wikipedia_id, 326 | "end_paragraph_id": end_paragraph_id, 327 | } 328 | ) 329 | kilt_entry = { 330 | "id": kilt_gold_entry["id"], 331 | "input": kilt_gold_entry["input"], # dpr_entry["question"], 332 | "output": [{"provenance": provenance}], 333 | } 334 | writer.write(kilt_entry) 335 | 336 | logger.info("Saved KILT formatted results to: %s", kilt_out_file) 337 | 338 | 339 | class JsonlTablesCtxSrc(object): 340 | def __init__( 341 | self, 342 | file: str, 343 | tables_chunk_sz: int = 100, 344 | split_type: str = "type1", 345 | id_prefix: str = None, 346 | ): 347 | self.tables_chunk_sz = tables_chunk_sz 348 | self.split_type = split_type 349 | self.file = file 350 | self.id_prefix = id_prefix 351 | 352 | def load_data_to(self, ctxs: Dict): 353 | docs = {} 354 | logger.info("Parsing Tables data from: %s", self.file) 355 | tables_dict = read_nq_tables_jsonl(self.file) 356 | table_chunks = split_tables_to_chunks(tables_dict, self.tables_chunk_sz, split_type=self.split_type) 357 | for chunk in table_chunks: 358 | sample_id = self.id_prefix + str(chunk[0]) 359 | docs[sample_id] = TableChunk(chunk[1], chunk[2], chunk[3]) 360 | logger.info("Loaded %d tables chunks", len(docs)) 361 | ctxs.update(docs) 362 | -------------------------------------------------------------------------------- /dpr/indexer/faiss_indexers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
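The indexer classes defined in this file are what the "indexers" group of dense_retriever.yaml (shown earlier) points to. The sketch below is not part of the repository: it exercises the flat-index API with made-up ids and random 768-dimensional vectors standing in for real passage embeddings produced by generate_dense_embeddings.py.

```python
# Minimal sketch, assuming random vectors in place of real passage embeddings;
# ids such as "wiki:0" are arbitrary external identifiers chosen for illustration.
import numpy as np
from dpr.indexer.faiss_indexers import DenseFlatIndexer

vector_sz = 768  # hidden size of the default bert-base encoders
indexer = DenseFlatIndexer()
indexer.init_index(vector_sz)

# index_data expects a list of (external_id, embedding) pairs
data = [("wiki:%d" % i, np.random.rand(vector_sz).astype("float32")) for i in range(100)]
indexer.index_data(data)

# search_knn returns, per query, the external ids of the top documents and their scores
queries = np.random.rand(2, vector_sz).astype("float32")
results = indexer.search_knn(queries, top_docs=5)  # [( [ids...], [scores...] ), ...]
```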
7 | 8 | """ 9 | FAISS-based index components for dense retriever 10 | """ 11 | 12 | import faiss 13 | import logging 14 | import numpy as np 15 | import os 16 | import pickle 17 | 18 | from typing import List, Tuple 19 | 20 | logger = logging.getLogger() 21 | 22 | 23 | class DenseIndexer(object): 24 | def __init__(self, buffer_size: int = 50000): 25 | self.buffer_size = buffer_size 26 | self.index_id_to_db_id = [] 27 | self.index = None 28 | 29 | def init_index(self, vector_sz: int): 30 | raise NotImplementedError 31 | 32 | def index_data(self, data: List[Tuple[object, np.array]]): 33 | raise NotImplementedError 34 | 35 | def get_index_name(self): 36 | raise NotImplementedError 37 | 38 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 39 | raise NotImplementedError 40 | 41 | def serialize(self, file: str): 42 | logger.info("Serializing index to %s", file) 43 | 44 | if os.path.isdir(file): 45 | index_file = os.path.join(file, "index.dpr") 46 | meta_file = os.path.join(file, "index_meta.dpr") 47 | else: 48 | index_file = file + ".index.dpr" 49 | meta_file = file + ".index_meta.dpr" 50 | 51 | faiss.write_index(self.index, index_file) 52 | with open(meta_file, mode="wb") as f: 53 | pickle.dump(self.index_id_to_db_id, f) 54 | 55 | def get_files(self, path: str): 56 | if os.path.isdir(path): 57 | index_file = os.path.join(path, "index.dpr") 58 | meta_file = os.path.join(path, "index_meta.dpr") 59 | else: 60 | index_file = path + ".{}.dpr".format(self.get_index_name()) 61 | meta_file = path + ".{}_meta.dpr".format(self.get_index_name()) 62 | return index_file, meta_file 63 | 64 | def index_exists(self, path: str): 65 | index_file, meta_file = self.get_files(path) 66 | return os.path.isfile(index_file) and os.path.isfile(meta_file) 67 | 68 | def deserialize(self, path: str): 69 | logger.info("Loading index from %s", path) 70 | index_file, meta_file = self.get_files(path) 71 | 72 | self.index = faiss.read_index(index_file) 73 | logger.info("Loaded index of type %s and size %d", type(self.index), self.index.ntotal) 74 | 75 | with open(meta_file, "rb") as reader: 76 | self.index_id_to_db_id = pickle.load(reader) 77 | assert ( 78 | len(self.index_id_to_db_id) == self.index.ntotal 79 | ), "Deserialized index_id_to_db_id should match faiss index size" 80 | 81 | def _update_id_mapping(self, db_ids: List) -> int: 82 | self.index_id_to_db_id.extend(db_ids) 83 | return len(self.index_id_to_db_id) 84 | 85 | 86 | class DenseFlatIndexer(DenseIndexer): 87 | def __init__(self, buffer_size: int = 50000): 88 | super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size) 89 | 90 | def init_index(self, vector_sz: int): 91 | self.index = faiss.IndexFlatIP(vector_sz) 92 | 93 | def index_data(self, data: List[Tuple[object, np.array]]): 94 | n = len(data) 95 | # indexing in batches is beneficial for many faiss index types 96 | for i in range(0, n, self.buffer_size): 97 | db_ids = [t[0] for t in data[i : i + self.buffer_size]] 98 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + self.buffer_size]] 99 | vectors = np.concatenate(vectors, axis=0) 100 | total_data = self._update_id_mapping(db_ids) 101 | self.index.add(vectors) 102 | logger.info("data indexed %d", total_data) 103 | 104 | indexed_cnt = len(self.index_id_to_db_id) 105 | logger.info("Total data indexed %d", indexed_cnt) 106 | 107 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 108 | scores, indexes = self.index.search(query_vectors, 
top_docs) 109 | # convert to external ids 110 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 111 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 112 | return result 113 | 114 | def get_index_name(self): 115 | return "flat_index" 116 | 117 | 118 | class DenseHNSWFlatIndexer(DenseIndexer): 119 | """ 120 | Efficient index for retrieval. Note: default settings are for hugh accuracy but also high RAM usage 121 | """ 122 | 123 | def __init__( 124 | self, 125 | buffer_size: int = 1e9, 126 | store_n: int = 512, 127 | ef_search: int = 128, 128 | ef_construction: int = 200, 129 | ): 130 | super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size) 131 | self.store_n = store_n 132 | self.ef_search = ef_search 133 | self.ef_construction = ef_construction 134 | self.phi = 0 135 | 136 | def init_index(self, vector_sz: int): 137 | # IndexHNSWFlat supports L2 similarity only 138 | # so we have to apply DOT -> L2 similairy space conversion with the help of an extra dimension 139 | index = faiss.IndexHNSWFlat(vector_sz + 1, self.store_n) 140 | index.hnsw.efSearch = self.ef_search 141 | index.hnsw.efConstruction = self.ef_construction 142 | self.index = index 143 | 144 | def index_data(self, data: List[Tuple[object, np.array]]): 145 | n = len(data) 146 | 147 | # max norm is required before putting all vectors in the index to convert inner product similarity to L2 148 | if self.phi > 0: 149 | raise RuntimeError( 150 | "DPR HNSWF index needs to index all data at once," "results will be unpredictable otherwise." 151 | ) 152 | phi = 0 153 | for i, item in enumerate(data): 154 | id, doc_vector = item[0:2] 155 | norms = (doc_vector ** 2).sum() 156 | phi = max(phi, norms) 157 | logger.info("HNSWF DotProduct -> L2 space phi={}".format(phi)) 158 | self.phi = phi 159 | 160 | # indexing in batches is beneficial for many faiss index types 161 | bs = int(self.buffer_size) 162 | for i in range(0, n, bs): 163 | db_ids = [t[0] for t in data[i : i + bs]] 164 | vectors = [np.reshape(t[1], (1, -1)) for t in data[i : i + bs]] 165 | 166 | norms = [(doc_vector ** 2).sum() for doc_vector in vectors] 167 | aux_dims = [np.sqrt(phi - norm) for norm in norms] 168 | hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in enumerate(vectors)] 169 | hnsw_vectors = np.concatenate(hnsw_vectors, axis=0) 170 | self.train(hnsw_vectors) 171 | 172 | self._update_id_mapping(db_ids) 173 | self.index.add(hnsw_vectors) 174 | logger.info("data indexed %d", len(self.index_id_to_db_id)) 175 | indexed_cnt = len(self.index_id_to_db_id) 176 | logger.info("Total data indexed %d", indexed_cnt) 177 | 178 | def train(self, vectors: np.array): 179 | pass 180 | 181 | def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]: 182 | 183 | aux_dim = np.zeros(len(query_vectors), dtype="float32") 184 | query_nhsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1))) 185 | logger.info("query_hnsw_vectors %s", query_nhsw_vectors.shape) 186 | scores, indexes = self.index.search(query_nhsw_vectors, top_docs) 187 | # convert to external ids 188 | db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes] 189 | result = [(db_ids[i], scores[i]) for i in range(len(db_ids))] 190 | return result 191 | 192 | def deserialize(self, file: str): 193 | super(DenseHNSWFlatIndexer, self).deserialize(file) 194 | # to trigger exception on subsequent indexing 195 | self.phi = 1 196 | 197 | def 
get_index_name(self): 198 | return "hnsw_index" 199 | 200 | 201 | class DenseHNSWSQIndexer(DenseHNSWFlatIndexer): 202 | """ 203 | Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage 204 | """ 205 | 206 | def __init__( 207 | self, 208 | buffer_size: int = 1e10, 209 | store_n: int = 128, 210 | ef_search: int = 128, 211 | ef_construction: int = 200, 212 | ): 213 | super(DenseHNSWSQIndexer, self).__init__( 214 | buffer_size=buffer_size, 215 | store_n=store_n, 216 | ef_search=ef_search, 217 | ef_construction=ef_construction, 218 | ) 219 | 220 | def init_index(self, vector_sz: int): 221 | # IndexHNSWSQ supports L2 similarity only 222 | # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension 223 | index = faiss.IndexHNSWSQ(vector_sz + 1, faiss.ScalarQuantizer.QT_8bit, self.store_n) 224 | index.hnsw.efSearch = self.ef_search 225 | index.hnsw.efConstruction = self.ef_construction 226 | self.index = index 227 | 228 | def train(self, vectors: np.array): 229 | self.index.train(vectors) 230 | 231 | def get_index_name(self): 232 | return "hnswsq_index" 233 | -------------------------------------------------------------------------------- /dpr/models/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import importlib 9 | 10 | """ 11 | 'Router'-like set of methods for component initialization with lazy imports 12 | """ 13 | 14 | 15 | def init_hf_bert_biencoder(args, **kwargs): 16 | if importlib.util.find_spec("transformers") is None: 17 | raise RuntimeError("Please install transformers lib") 18 | from .hf_models import get_bert_biencoder_components 19 | 20 | return get_bert_biencoder_components(args, **kwargs) 21 | 22 | 23 | def init_hf_bert_reader(args, **kwargs): 24 | if importlib.util.find_spec("transformers") is None: 25 | raise RuntimeError("Please install transformers lib") 26 | from .hf_models import get_bert_reader_components 27 | 28 | return get_bert_reader_components(args, **kwargs) 29 | 30 | 31 | def init_pytext_bert_biencoder(args, **kwargs): 32 | if importlib.util.find_spec("pytext") is None: 33 | raise RuntimeError("Please install pytext lib") 34 | from .pytext_models import get_bert_biencoder_components 35 | 36 | return get_bert_biencoder_components(args, **kwargs) 37 | 38 | 39 | def init_fairseq_roberta_biencoder(args, **kwargs): 40 | if importlib.util.find_spec("fairseq") is None: 41 | raise RuntimeError("Please install fairseq lib") 42 | from .fairseq_models import get_roberta_biencoder_components 43 | 44 | return get_roberta_biencoder_components(args, **kwargs) 45 | 46 | 47 | def init_hf_bert_tenzorizer(args, **kwargs): 48 | if importlib.util.find_spec("transformers") is None: 49 | raise RuntimeError("Please install transformers lib") 50 | from .hf_models import get_bert_tensorizer 51 | 52 | return get_bert_tensorizer(args) 53 | 54 | 55 | def init_hf_roberta_tenzorizer(args, **kwargs): 56 | if importlib.util.find_spec("transformers") is None: 57 | raise RuntimeError("Please install transformers lib") 58 | from .hf_models import get_roberta_tensorizer 59 | return get_roberta_tensorizer(args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length) 60 | 61 | 62 | BIENCODER_INITIALIZERS = { 
63 | "hf_bert": init_hf_bert_biencoder, 64 | "pytext_bert": init_pytext_bert_biencoder, 65 | "fairseq_roberta": init_fairseq_roberta_biencoder, 66 | } 67 | 68 | READER_INITIALIZERS = { 69 | "hf_bert": init_hf_bert_reader, 70 | } 71 | 72 | TENSORIZER_INITIALIZERS = { 73 | "hf_bert": init_hf_bert_tenzorizer, 74 | "hf_roberta": init_hf_roberta_tenzorizer, 75 | "pytext_bert": init_hf_bert_tenzorizer, # using HF's code as of now 76 | "fairseq_roberta": init_hf_roberta_tenzorizer, # using HF's code as of now 77 | } 78 | 79 | 80 | def init_comp(initializers_dict, type, args, **kwargs): 81 | if type in initializers_dict: 82 | return initializers_dict[type](args, **kwargs) 83 | else: 84 | raise RuntimeError("unsupported model type: {}".format(type)) 85 | 86 | 87 | def init_biencoder_components(encoder_type: str, args, **kwargs): 88 | return init_comp(BIENCODER_INITIALIZERS, encoder_type, args, **kwargs) 89 | 90 | 91 | def init_reader_components(encoder_type: str, args, **kwargs): 92 | return init_comp(READER_INITIALIZERS, encoder_type, args, **kwargs) 93 | 94 | 95 | def init_tenzorizer(encoder_type: str, args, **kwargs): 96 | return init_comp(TENSORIZER_INITIALIZERS, encoder_type, args, **kwargs) 97 | -------------------------------------------------------------------------------- /dpr/models/biencoder.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | BiEncoder component + loss function for 'all-in-batch' training 10 | """ 11 | 12 | import collections 13 | import logging 14 | import random 15 | from typing import Tuple, List 16 | 17 | import numpy as np 18 | import torch 19 | import torch.nn.functional as F 20 | from torch import Tensor as T 21 | from torch import nn 22 | 23 | from dpr.data.biencoder_data import BiEncoderSample 24 | from dpr.utils.data_utils import Tensorizer 25 | from dpr.utils.model_utils import CheckpointState 26 | 27 | logger = logging.getLogger(__name__) 28 | 29 | BiEncoderBatch = collections.namedtuple( 30 | "BiENcoderInput", 31 | [ 32 | "question_ids", 33 | "question_segments", 34 | "context_ids", 35 | "ctx_segments", 36 | "is_positive", 37 | "hard_negatives", 38 | "encoder_type", 39 | ], 40 | ) 41 | # TODO: it is only used by _select_span_with_token. Move them to utils 42 | rnd = random.Random(0) 43 | 44 | 45 | def dot_product_scores(q_vectors: T, ctx_vectors: T) -> T: 46 | """ 47 | calculates q->ctx scores for every row in ctx_vector 48 | :param q_vector: 49 | :param ctx_vector: 50 | :return: 51 | """ 52 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 53 | r = torch.matmul(q_vectors, torch.transpose(ctx_vectors, 0, 1)) 54 | return r 55 | 56 | 57 | def cosine_scores(q_vector: T, ctx_vectors: T): 58 | # q_vector: n1 x D, ctx_vectors: n2 x D, result n1 x n2 59 | return F.cosine_similarity(q_vector, ctx_vectors, dim=1) 60 | 61 | 62 | class BiEncoder(nn.Module): 63 | """Bi-Encoder model component. 
Encapsulates query/question and context/passage encoders.""" 64 | 65 | def __init__( 66 | self, 67 | question_model: nn.Module, 68 | ctx_model: nn.Module, 69 | fix_q_encoder: bool = False, 70 | fix_ctx_encoder: bool = False, 71 | ): 72 | super(BiEncoder, self).__init__() 73 | self.question_model = question_model 74 | self.ctx_model = ctx_model 75 | self.fix_q_encoder = fix_q_encoder 76 | self.fix_ctx_encoder = fix_ctx_encoder 77 | 78 | @staticmethod 79 | def get_representation( 80 | sub_model: nn.Module, 81 | ids: T, 82 | segments: T, 83 | attn_mask: T, 84 | fix_encoder: bool = False, 85 | representation_token_pos=0, 86 | ) -> (T, T, T): 87 | sequence_output = None 88 | pooled_output = None 89 | hidden_states = None 90 | if ids is not None: 91 | if fix_encoder: 92 | with torch.no_grad(): 93 | sequence_output, pooled_output, hidden_states = sub_model( 94 | ids, 95 | segments, 96 | attn_mask, 97 | representation_token_pos=representation_token_pos, 98 | ) 99 | 100 | if sub_model.training: 101 | sequence_output.requires_grad_(requires_grad=True) 102 | pooled_output.requires_grad_(requires_grad=True) 103 | else: 104 | sequence_output, pooled_output, hidden_states = sub_model( 105 | ids, 106 | segments, 107 | attn_mask, 108 | representation_token_pos=representation_token_pos, 109 | ) 110 | 111 | return sequence_output, pooled_output, hidden_states 112 | 113 | def forward( 114 | self, 115 | question_ids: T, 116 | question_segments: T, 117 | question_attn_mask: T, 118 | context_ids: T, 119 | ctx_segments: T, 120 | ctx_attn_mask: T, 121 | encoder_type: str = None, 122 | representation_token_pos=0, 123 | ) -> Tuple[T, T]: 124 | q_encoder = self.question_model if encoder_type is None or encoder_type == "question" else self.ctx_model 125 | _q_seq, q_pooled_out, _q_hidden = self.get_representation( 126 | q_encoder, 127 | question_ids, 128 | question_segments, 129 | question_attn_mask, 130 | self.fix_q_encoder, 131 | representation_token_pos=representation_token_pos, 132 | ) 133 | 134 | ctx_encoder = self.ctx_model if encoder_type is None or encoder_type == "ctx" else self.question_model 135 | _ctx_seq, ctx_pooled_out, _ctx_hidden = self.get_representation( 136 | ctx_encoder, context_ids, ctx_segments, ctx_attn_mask, self.fix_ctx_encoder 137 | ) 138 | 139 | return q_pooled_out, ctx_pooled_out 140 | 141 | def create_biencoder_input( 142 | self, 143 | samples: List[BiEncoderSample], 144 | tensorizer: Tensorizer, 145 | insert_title: bool, 146 | num_hard_negatives: int = 0, 147 | num_other_negatives: int = 0, 148 | shuffle: bool = True, 149 | shuffle_positives: bool = False, 150 | hard_neg_fallback: bool = True, 151 | query_token: str = None, 152 | ) -> BiEncoderBatch: 153 | """ 154 | Creates a batch of the biencoder training tuple. 
155 | :param samples: list of BiEncoderSample-s to create the batch for 156 | :param tensorizer: components to create model input tensors from a text sequence 157 | :param insert_title: enables title insertion at the beginning of the context sequences 158 | :param num_hard_negatives: amount of hard negatives per question (taken from samples' pools) 159 | :param num_other_negatives: amount of other negatives per question (taken from samples' pools) 160 | :param shuffle: shuffles negative passages pools 161 | :param shuffle_positives: shuffles positive passages pools 162 | :return: BiEncoderBatch tuple 163 | """ 164 | question_tensors = [] 165 | ctx_tensors = [] 166 | positive_ctx_indices = [] 167 | hard_neg_ctx_indices = [] 168 | 169 | for sample in samples: 170 | # ctx+ & [ctx-] composition 171 | # as of now, take the first(gold) ctx+ only 172 | 173 | if shuffle and shuffle_positives: 174 | positive_ctxs = sample.positive_passages 175 | positive_ctx = positive_ctxs[np.random.choice(len(positive_ctxs))] 176 | else: 177 | positive_ctx = sample.positive_passages[0] 178 | 179 | neg_ctxs = sample.negative_passages 180 | hard_neg_ctxs = sample.hard_negative_passages 181 | question = sample.query 182 | # question = normalize_question(sample.query) 183 | 184 | if shuffle: 185 | random.shuffle(neg_ctxs) 186 | random.shuffle(hard_neg_ctxs) 187 | 188 | if hard_neg_fallback and len(hard_neg_ctxs) == 0: 189 | hard_neg_ctxs = neg_ctxs[0:num_hard_negatives] 190 | 191 | neg_ctxs = neg_ctxs[0:num_other_negatives] 192 | hard_neg_ctxs = hard_neg_ctxs[0:num_hard_negatives] 193 | 194 | all_ctxs = [positive_ctx] + neg_ctxs + hard_neg_ctxs 195 | hard_negatives_start_idx = 1 196 | hard_negatives_end_idx = 1 + len(hard_neg_ctxs) 197 | 198 | current_ctxs_len = len(ctx_tensors) 199 | 200 | sample_ctxs_tensors = [ 201 | tensorizer.text_to_tensor(ctx.text, title=ctx.title if (insert_title and ctx.title) else None) 202 | for ctx in all_ctxs 203 | ] 204 | 205 | ctx_tensors.extend(sample_ctxs_tensors) 206 | positive_ctx_indices.append(current_ctxs_len) 207 | hard_neg_ctx_indices.append( 208 | [ 209 | i 210 | for i in range( 211 | current_ctxs_len + hard_negatives_start_idx, 212 | current_ctxs_len + hard_negatives_end_idx, 213 | ) 214 | ] 215 | ) 216 | 217 | if query_token: 218 | # TODO: tmp workaround for EL, remove or revise 219 | if query_token == "[START_ENT]": 220 | query_span = _select_span_with_token(question, tensorizer, token_str=query_token) 221 | question_tensors.append(query_span) 222 | else: 223 | question_tensors.append(tensorizer.text_to_tensor(" ".join([query_token, question]))) 224 | else: 225 | question_tensors.append(tensorizer.text_to_tensor(question)) 226 | 227 | ctxs_tensor = torch.cat([ctx.view(1, -1) for ctx in ctx_tensors], dim=0) 228 | questions_tensor = torch.cat([q.view(1, -1) for q in question_tensors], dim=0) 229 | 230 | ctx_segments = torch.zeros_like(ctxs_tensor) 231 | question_segments = torch.zeros_like(questions_tensor) 232 | 233 | return BiEncoderBatch( 234 | questions_tensor, 235 | question_segments, 236 | ctxs_tensor, 237 | ctx_segments, 238 | positive_ctx_indices, 239 | hard_neg_ctx_indices, 240 | "question", 241 | ) 242 | 243 | def load_state(self, saved_state: CheckpointState, strict: bool = True): 244 | # TODO: make a long term HF compatibility fix 245 | # if "question_model.embeddings.position_ids" in saved_state.model_dict: 246 | # del saved_state.model_dict["question_model.embeddings.position_ids"] 247 | # del saved_state.model_dict["ctx_model.embeddings.position_ids"] 248 | 
self.load_state_dict(saved_state.model_dict, strict=strict) 249 | 250 | def get_state_dict(self): 251 | return self.state_dict() 252 | 253 | 254 | class BiEncoderNllLoss(object): 255 | def calc( 256 | self, 257 | q_vectors: T, 258 | ctx_vectors: T, 259 | positive_idx_per_question: list, 260 | hard_negative_idx_per_question: list = None, 261 | loss_scale: float = None, 262 | ) -> Tuple[T, int]: 263 | """ 264 | Computes nll loss for the given lists of question and ctx vectors. 265 | Note that although hard_negative_idx_per_question is not currently in use, one can use it for the 266 | loss modifications. For example - weighted NLL with different factors for hard vs regular negatives. 267 | :return: a tuple of loss value and amount of correct predictions per batch 268 | """ 269 | scores = self.get_scores(q_vectors, ctx_vectors) 270 | 271 | if len(q_vectors.size()) > 1: 272 | q_num = q_vectors.size(0) 273 | scores = scores.view(q_num, -1) 274 | 275 | softmax_scores = F.log_softmax(scores, dim=1) 276 | 277 | loss = F.nll_loss( 278 | softmax_scores, 279 | torch.tensor(positive_idx_per_question).to(softmax_scores.device), 280 | reduction="mean", 281 | ) 282 | 283 | max_score, max_idxs = torch.max(softmax_scores, 1) 284 | correct_predictions_count = (max_idxs == torch.tensor(positive_idx_per_question).to(max_idxs.device)).sum() 285 | 286 | if loss_scale: 287 | loss.mul_(loss_scale) 288 | 289 | return loss, correct_predictions_count 290 | 291 | @staticmethod 292 | def get_scores(q_vector: T, ctx_vectors: T) -> T: 293 | f = BiEncoderNllLoss.get_similarity_function() 294 | return f(q_vector, ctx_vectors) 295 | 296 | @staticmethod 297 | def get_similarity_function(): 298 | return dot_product_scores 299 | 300 | 301 | def _select_span_with_token(text: str, tensorizer: Tensorizer, token_str: str = "[START_ENT]") -> T: 302 | id = tensorizer.get_token_id(token_str) 303 | query_tensor = tensorizer.text_to_tensor(text) 304 | 305 | if id not in query_tensor: 306 | query_tensor_full = tensorizer.text_to_tensor(text, apply_max_len=False) 307 | token_indexes = (query_tensor_full == id).nonzero() 308 | if token_indexes.size(0) > 0: 309 | start_pos = token_indexes[0, 0].item() 310 | # add some randomization to avoid overfitting to a specific token position 311 | 312 | left_shift = int(tensorizer.max_length / 2) 313 | rnd_shift = int((rnd.random() - 0.5) * left_shift / 2) 314 | left_shift += rnd_shift 315 | 316 | query_tensor = query_tensor_full[start_pos - left_shift :] 317 | cls_id = tensorizer.tokenizer.cls_token_id 318 | if query_tensor[0] != cls_id: 319 | query_tensor = torch.cat([torch.tensor([cls_id]), query_tensor], dim=0) 320 | 321 | from dpr.models.reader import _pad_to_len 322 | 323 | query_tensor = _pad_to_len(query_tensor, tensorizer.get_pad_id(), tensorizer.max_length) 324 | query_tensor[-1] = tensorizer.tokenizer.sep_token_id 325 | # logger.info('aligned query_tensor %s', query_tensor) 326 | 327 | assert id in query_tensor, "query_tensor={}".format(query_tensor) 328 | return query_tensor 329 | else: 330 | raise RuntimeError("[START_ENT] token not found for Entity Linking sample query={}".format(text)) 331 | else: 332 | return query_tensor 333 | -------------------------------------------------------------------------------- /dpr/models/fairseq_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 
4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Encoder model wrappers based on Fairseq code 10 | """ 11 | import collections 12 | import logging 13 | from typing import Tuple 14 | 15 | from fairseq.models.roberta.hub_interface import RobertaHubInterface 16 | from fairseq.models.roberta.model import RobertaModel as FairseqRobertaModel 17 | from torch import Tensor as T 18 | from torch import nn 19 | 20 | from dpr.models.hf_models import get_roberta_tensorizer 21 | from dpr.utils.data_utils import Tensorizer 22 | from fairseq.optim.adam import FairseqAdam 23 | from .biencoder import BiEncoder 24 | 25 | logger = logging.getLogger(__name__) 26 | 27 | FairseqOptCfg = collections.namedtuple("FairseqOptCfg", ["lr", "adam_betas", "adam_eps", "weight_decay"]) 28 | 29 | 30 | def get_roberta_biencoder_components(args, inference_only: bool = False, **kwargs): 31 | question_encoder = RobertaEncoder.from_pretrained(args.encoder.pretrained_file) 32 | ctx_encoder = RobertaEncoder.from_pretrained(args.encoder.pretrained_file) 33 | biencoder = BiEncoder(question_encoder, ctx_encoder) 34 | optimizer = get_fairseq_adamw_optimizer(biencoder, args) if not inference_only else None 35 | tensorizer = get_roberta_tensorizer( 36 | args.encoder.pretrained_model_cfg, args.do_lower_case, args.encoder.sequence_length 37 | ) 38 | return tensorizer, biencoder, optimizer 39 | 40 | 41 | def get_fairseq_adamw_optimizer(model: nn.Module, args): 42 | cfg = FairseqOptCfg(args.train.learning_rate, args.train.adam_betas, args.train.adam_eps, args.train.weight_decay) 43 | return FairseqAdam(cfg, model.parameters()).optimizer 44 | 45 | 46 | class RobertaEncoder(nn.Module): 47 | def __init__(self, fairseq_roberta_hub: RobertaHubInterface): 48 | super(RobertaEncoder, self).__init__() 49 | self.fairseq_roberta = fairseq_roberta_hub 50 | 51 | @classmethod 52 | def from_pretrained(cls, pretrained_dir_path: str): 53 | model = FairseqRobertaModel.from_pretrained(pretrained_dir_path) 54 | return cls(model) 55 | 56 | def forward( 57 | self, 58 | input_ids: T, 59 | token_type_ids: T, 60 | attention_mask: T, 61 | representation_token_pos=0, 62 | ) -> Tuple[T, ...]: 63 | roberta_out = self.fairseq_roberta.extract_features(input_ids) 64 | cls_out = roberta_out[:, representation_token_pos, :] 65 | return roberta_out, cls_out, None 66 | 67 | def get_out_size(self): 68 | raise NotImplementedError 69 | 70 | 71 | def get_roberta_encoder_components( 72 | pretrained_file: str, pretrained_model_cfg: str, do_lower_case: bool, sequence_length: int 73 | ) -> Tuple[RobertaEncoder, Tensorizer]: 74 | encoder = RobertaEncoder.from_pretrained(pretrained_file) 75 | tensorizer = get_roberta_tensorizer(pretrained_model_cfg, do_lower_case, sequence_length) 76 | return encoder, tensorizer 77 | -------------------------------------------------------------------------------- /dpr/models/hf_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ 9 | Encoder model wrappers based on HuggingFace code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple, List 14 | 15 | import torch 16 | import transformers 17 | from torch import Tensor as T 18 | from torch import nn 19 | 20 | 21 | if transformers.__version__.startswith("4"): 22 | from transformers import BertConfig, BertModel 23 | from transformers import AdamW 24 | from transformers import BertTokenizer 25 | from transformers import RobertaTokenizer 26 | else: 27 | from transformers.modeling_bert import BertConfig, BertModel 28 | from transformers.optimization import AdamW 29 | from transformers.tokenization_bert import BertTokenizer 30 | from transformers.tokenization_roberta import RobertaTokenizer 31 | 32 | from dpr.utils.data_utils import Tensorizer 33 | from dpr.models.biencoder import BiEncoder 34 | from .reader import Reader 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def get_bert_biencoder_components(cfg, inference_only: bool = False, **kwargs): 40 | dropout = cfg.encoder.dropout if hasattr(cfg.encoder, "dropout") else 0.0 41 | question_encoder = HFBertEncoder.init_encoder( 42 | cfg.encoder.pretrained_model_cfg, 43 | projection_dim=cfg.encoder.projection_dim, 44 | dropout=dropout, 45 | pretrained=cfg.encoder.pretrained, 46 | **kwargs 47 | ) 48 | ctx_encoder = HFBertEncoder.init_encoder( 49 | cfg.encoder.pretrained_model_cfg, 50 | projection_dim=cfg.encoder.projection_dim, 51 | dropout=dropout, 52 | pretrained=cfg.encoder.pretrained, 53 | **kwargs 54 | ) 55 | 56 | fix_ctx_encoder = cfg.encoder.fix_ctx_encoder if hasattr(cfg.encoder, "fix_ctx_encoder") else False 57 | biencoder = BiEncoder(question_encoder, ctx_encoder, fix_ctx_encoder=fix_ctx_encoder) 58 | 59 | optimizer = ( 60 | get_optimizer( 61 | biencoder, 62 | learning_rate=cfg.train.learning_rate, 63 | adam_eps=cfg.train.adam_eps, 64 | weight_decay=cfg.train.weight_decay, 65 | ) 66 | if not inference_only 67 | else None 68 | ) 69 | 70 | tensorizer = get_bert_tensorizer(cfg) 71 | return tensorizer, biencoder, optimizer 72 | 73 | 74 | def get_bert_reader_components(cfg, inference_only: bool = False, **kwargs): 75 | dropout = cfg.encoder.dropout if hasattr(cfg.encoder, "dropout") else 0.0 76 | encoder = HFBertEncoder.init_encoder( 77 | cfg.encoder.pretrained_model_cfg, 78 | projection_dim=cfg.encoder.projection_dim, 79 | dropout=dropout, 80 | pretrained=cfg.encoder.pretrained, 81 | **kwargs 82 | ) 83 | 84 | hidden_size = encoder.config.hidden_size 85 | reader = Reader(encoder, hidden_size) 86 | 87 | optimizer = ( 88 | get_optimizer( 89 | reader, 90 | learning_rate=cfg.train.learning_rate, 91 | adam_eps=cfg.train.adam_eps, 92 | weight_decay=cfg.train.weight_decay, 93 | ) 94 | if not inference_only 95 | else None 96 | ) 97 | 98 | tensorizer = get_bert_tensorizer(cfg) 99 | return tensorizer, reader, optimizer 100 | 101 | 102 | # TODO: unify tensorizer init methods 103 | def get_bert_tensorizer(cfg): 104 | sequence_length = cfg.encoder.sequence_length 105 | pretrained_model_cfg = cfg.encoder.pretrained_model_cfg 106 | tokenizer = get_bert_tokenizer(pretrained_model_cfg, do_lower_case=cfg.do_lower_case) 107 | if cfg.special_tokens: 108 | _add_special_tokens(tokenizer, cfg.special_tokens) 109 | 110 | return BertTensorizer(tokenizer, sequence_length) 111 | 112 | 113 | def get_bert_tensorizer_p( 114 | pretrained_model_cfg: str, sequence_length: int, do_lower_case: bool = True, special_tokens: List[str] = [] 115 | ): 116 | tokenizer = get_bert_tokenizer(pretrained_model_cfg, 
do_lower_case=do_lower_case) 117 | if special_tokens: 118 | _add_special_tokens(tokenizer, special_tokens) 119 | return BertTensorizer(tokenizer, sequence_length) 120 | 121 | 122 | def _add_special_tokens(tokenizer, special_tokens): 123 | logger.info("Adding special tokens %s", special_tokens) 124 | logger.info("Tokenizer: %s", type(tokenizer)) 125 | special_tokens_num = len(special_tokens) 126 | # TODO: this is a hack-y logic that uses some private tokenizer structure which can be changed in HF code 127 | 128 | assert special_tokens_num < 500 129 | unused_ids = [tokenizer.vocab["[unused{}]".format(i)] for i in range(special_tokens_num)] 130 | logger.info("Utilizing the following unused token ids %s", unused_ids) 131 | 132 | for idx, id in enumerate(unused_ids): 133 | old_token = "[unused{}]".format(idx) 134 | del tokenizer.vocab[old_token] 135 | new_token = special_tokens[idx] 136 | tokenizer.vocab[new_token] = id 137 | tokenizer.ids_to_tokens[id] = new_token 138 | logging.debug("new token %s id=%s", new_token, id) 139 | 140 | tokenizer.additional_special_tokens = list(special_tokens) 141 | logger.info("additional_special_tokens %s", tokenizer.additional_special_tokens) 142 | logger.info("all_special_tokens_extended: %s", tokenizer.all_special_tokens_extended) 143 | logger.info("additional_special_tokens_ids: %s", tokenizer.additional_special_tokens_ids) 144 | logger.info("all_special_tokens %s", tokenizer.all_special_tokens) 145 | 146 | 147 | def get_roberta_tensorizer(pretrained_model_cfg: str, do_lower_case: bool, sequence_length: int): 148 | tokenizer = get_roberta_tokenizer(pretrained_model_cfg, do_lower_case=do_lower_case) 149 | return RobertaTensorizer(tokenizer, sequence_length) 150 | 151 | 152 | def get_optimizer( 153 | model: nn.Module, 154 | learning_rate: float = 1e-5, 155 | adam_eps: float = 1e-8, 156 | weight_decay: float = 0.0, 157 | ) -> torch.optim.Optimizer: 158 | optimizer_grouped_parameters = get_hf_model_param_grouping(model, weight_decay) 159 | return get_optimizer_grouped(optimizer_grouped_parameters, learning_rate, adam_eps) 160 | 161 | 162 | def get_hf_model_param_grouping( 163 | model: nn.Module, 164 | weight_decay: float = 0.0, 165 | ): 166 | no_decay = ["bias", "LayerNorm.weight"] 167 | 168 | return [ 169 | { 170 | "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 171 | "weight_decay": weight_decay, 172 | }, 173 | { 174 | "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 175 | "weight_decay": 0.0, 176 | }, 177 | ] 178 | 179 | 180 | def get_optimizer_grouped( 181 | optimizer_grouped_parameters: List, 182 | learning_rate: float = 1e-5, 183 | adam_eps: float = 1e-8, 184 | ) -> torch.optim.Optimizer: 185 | 186 | optimizer = AdamW(optimizer_grouped_parameters, lr=learning_rate, eps=adam_eps) 187 | return optimizer 188 | 189 | 190 | def get_bert_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): 191 | return BertTokenizer.from_pretrained(pretrained_cfg_name, do_lower_case=do_lower_case) 192 | 193 | 194 | def get_roberta_tokenizer(pretrained_cfg_name: str, do_lower_case: bool = True): 195 | # still uses HF code for tokenizer since they are the same 196 | return RobertaTokenizer.from_pretrained(pretrained_cfg_name, do_lower_case=do_lower_case) 197 | 198 | 199 | class HFBertEncoder(BertModel): 200 | def __init__(self, config, project_dim: int = 0): 201 | BertModel.__init__(self, config) 202 | assert config.hidden_size > 0, "Encoder hidden_size can't be zero" 203 | 
self.encode_proj = nn.Linear(config.hidden_size, project_dim) if project_dim != 0 else None 204 | self.init_weights() 205 | 206 | @classmethod 207 | def init_encoder( 208 | cls, cfg_name: str, projection_dim: int = 0, dropout: float = 0.1, pretrained: bool = True, **kwargs 209 | ) -> BertModel: 210 | logger.info("Initializing HF BERT Encoder. cfg_name=%s", cfg_name) 211 | cfg = BertConfig.from_pretrained(cfg_name if cfg_name else "bert-base-uncased") 212 | if dropout != 0: 213 | cfg.attention_probs_dropout_prob = dropout 214 | cfg.hidden_dropout_prob = dropout 215 | 216 | if pretrained: 217 | return cls.from_pretrained(cfg_name, config=cfg, project_dim=projection_dim, **kwargs) 218 | else: 219 | return HFBertEncoder(cfg, project_dim=projection_dim) 220 | 221 | def forward( 222 | self, 223 | input_ids: T, 224 | token_type_ids: T, 225 | attention_mask: T, 226 | representation_token_pos=0, 227 | ) -> Tuple[T, ...]: 228 | 229 | out = super().forward( 230 | input_ids=input_ids, 231 | token_type_ids=token_type_ids, 232 | attention_mask=attention_mask, 233 | ) 234 | 235 | # HF >4.0 version support 236 | if transformers.__version__.startswith("4") and isinstance( 237 | out, 238 | transformers.modeling_outputs.BaseModelOutputWithPoolingAndCrossAttentions, 239 | ): 240 | sequence_output = out.last_hidden_state 241 | pooled_output = None 242 | hidden_states = out.hidden_states 243 | 244 | elif self.config.output_hidden_states: 245 | sequence_output, pooled_output, hidden_states = out 246 | else: 247 | hidden_states = None 248 | out = super().forward( 249 | input_ids=input_ids, 250 | token_type_ids=token_type_ids, 251 | attention_mask=attention_mask, 252 | ) 253 | sequence_output, pooled_output = out 254 | 255 | if isinstance(representation_token_pos, int): 256 | pooled_output = sequence_output[:, representation_token_pos, :] 257 | else: # treat as a tensor 258 | bsz = sequence_output.size(0) 259 | assert representation_token_pos.size(0) == bsz, "query bsz={} while representation_token_pos bsz={}".format( 260 | bsz, representation_token_pos.size(0) 261 | ) 262 | pooled_output = torch.stack([sequence_output[i, representation_token_pos[i, 1], :] for i in range(bsz)]) 263 | 264 | if self.encode_proj: 265 | pooled_output = self.encode_proj(pooled_output) 266 | return sequence_output, pooled_output, hidden_states 267 | 268 | # TODO: make a super class for all encoders 269 | def get_out_size(self): 270 | if self.encode_proj: 271 | return self.encode_proj.out_features 272 | return self.config.hidden_size 273 | 274 | 275 | class BertTensorizer(Tensorizer): 276 | def __init__(self, tokenizer: BertTokenizer, max_length: int, pad_to_max: bool = True): 277 | self.tokenizer = tokenizer 278 | self.max_length = max_length 279 | self.pad_to_max = pad_to_max 280 | 281 | def text_to_tensor( 282 | self, 283 | text: str, 284 | title: str = None, 285 | add_special_tokens: bool = True, 286 | apply_max_len: bool = True, 287 | ): 288 | text = text.strip() 289 | # tokenizer automatic padding is explicitly disabled since its inconsistent behavior 290 | # TODO: move max len to methods params? 
291 | 292 | if title: 293 | token_ids = self.tokenizer.encode( 294 | title, 295 | text_pair=text, 296 | add_special_tokens=add_special_tokens, 297 | max_length=self.max_length if apply_max_len else 10000, 298 | pad_to_max_length=False, 299 | truncation=True, 300 | ) 301 | else: 302 | token_ids = self.tokenizer.encode( 303 | text, 304 | add_special_tokens=add_special_tokens, 305 | max_length=self.max_length if apply_max_len else 10000, 306 | pad_to_max_length=False, 307 | truncation=True, 308 | ) 309 | 310 | seq_len = self.max_length 311 | if self.pad_to_max and len(token_ids) < seq_len: 312 | token_ids = token_ids + [self.tokenizer.pad_token_id] * (seq_len - len(token_ids)) 313 | if len(token_ids) >= seq_len: 314 | token_ids = token_ids[0:seq_len] if apply_max_len else token_ids 315 | token_ids[-1] = self.tokenizer.sep_token_id 316 | 317 | return torch.tensor(token_ids) 318 | 319 | def get_pair_separator_ids(self) -> T: 320 | return torch.tensor([self.tokenizer.sep_token_id]) 321 | 322 | def get_pad_id(self) -> int: 323 | return self.tokenizer.pad_token_id 324 | 325 | def get_attn_mask(self, tokens_tensor: T) -> T: 326 | return tokens_tensor != self.get_pad_id() 327 | 328 | def is_sub_word_id(self, token_id: int): 329 | token = self.tokenizer.convert_ids_to_tokens([token_id])[0] 330 | return token.startswith("##") or token.startswith(" ##") 331 | 332 | def to_string(self, token_ids, skip_special_tokens=True): 333 | return self.tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) 334 | 335 | def set_pad_to_max(self, do_pad: bool): 336 | self.pad_to_max = do_pad 337 | 338 | def get_token_id(self, token: str) -> int: 339 | return self.tokenizer.vocab[token] 340 | 341 | 342 | class RobertaTensorizer(BertTensorizer): 343 | def __init__(self, tokenizer, max_length: int, pad_to_max: bool = True): 344 | super(RobertaTensorizer, self).__init__(tokenizer, max_length, pad_to_max=pad_to_max) 345 | -------------------------------------------------------------------------------- /dpr/models/pytext_models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ 9 | Encoder model wrappers based on HuggingFace code 10 | """ 11 | 12 | import logging 13 | from typing import Tuple 14 | 15 | import torch 16 | from pytext.models.representations.transformer_sentence_encoder import TransformerSentenceEncoder 17 | from pytext.optimizer.optimizers import AdamW 18 | from torch import Tensor as T 19 | from torch import nn 20 | 21 | from .biencoder import BiEncoder 22 | 23 | logger = logging.getLogger(__name__) 24 | 25 | 26 | def get_bert_biencoder_components(args, inference_only: bool = False): 27 | # since bert tokenizer is the same in HF and pytext/fairseq, just use HF's implementation here for now 28 | from .hf_models import get_tokenizer, BertTensorizer 29 | 30 | tokenizer = get_tokenizer(args.pretrained_model_cfg, do_lower_case=args.do_lower_case) 31 | 32 | question_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 33 | projection_dim=args.projection_dim, dropout=args.dropout, 34 | vocab_size=tokenizer.vocab_size, 35 | padding_idx=tokenizer.pad_token_type_id 36 | ) 37 | 38 | ctx_encoder = PytextBertEncoder.init_encoder(args.pretrained_file, 39 | projection_dim=args.projection_dim, dropout=args.dropout, 40 | vocab_size=tokenizer.vocab_size, 41 | padding_idx=tokenizer.pad_token_type_id 42 | ) 43 | 44 | biencoder = BiEncoder(question_encoder, ctx_encoder) 45 | 46 | optimizer = get_optimizer(biencoder, 47 | learning_rate=args.learning_rate, 48 | adam_eps=args.adam_eps, weight_decay=args.weight_decay, 49 | ) if not inference_only else None 50 | 51 | tensorizer = BertTensorizer(tokenizer, args.sequence_length) 52 | return tensorizer, biencoder, optimizer 53 | 54 | 55 | def get_optimizer(model: nn.Module, learning_rate: float = 1e-5, adam_eps: float = 1e-8, 56 | weight_decay: float = 0.0) -> torch.optim.Optimizer: 57 | cfg = AdamW.Config() 58 | cfg.lr = learning_rate 59 | cfg.weight_decay = weight_decay 60 | cfg.eps = adam_eps 61 | optimizer = AdamW.from_config(cfg, model) 62 | return optimizer 63 | 64 | 65 | def get_pytext_bert_base_cfg(): 66 | cfg = TransformerSentenceEncoder.Config() 67 | cfg.embedding_dim = 768 68 | cfg.ffn_embedding_dim = 3072 69 | cfg.num_encoder_layers = 12 70 | cfg.num_attention_heads = 12 71 | cfg.num_segments = 2 72 | cfg.use_position_embeddings = True 73 | cfg.offset_positions_by_padding = True 74 | cfg.apply_bert_init = True 75 | cfg.encoder_normalize_before = True 76 | cfg.activation_fn = "gelu" 77 | cfg.projection_dim = 0 78 | cfg.max_seq_len = 512 79 | cfg.multilingual = False 80 | cfg.freeze_embeddings = False 81 | cfg.n_trans_layers_to_freeze = 0 82 | cfg.use_torchscript = False 83 | return cfg 84 | 85 | 86 | class PytextBertEncoder(TransformerSentenceEncoder): 87 | 88 | def __init__(self, config: TransformerSentenceEncoder.Config, 89 | padding_idx: int, 90 | vocab_size: int, 91 | projection_dim: int = 0, 92 | *args, 93 | **kwarg 94 | ): 95 | 96 | TransformerSentenceEncoder.__init__(self, config, False, padding_idx, vocab_size, *args, **kwarg) 97 | 98 | assert config.embedding_dim > 0, 'Encoder hidden_size can\'t be zero' 99 | self.encode_proj = nn.Linear(config.embedding_dim, projection_dim) if projection_dim != 0 else None 100 | 101 | @classmethod 102 | def init_encoder(cls, pretrained_file: str = None, projection_dim: int = 0, dropout: float = 0.1, 103 | vocab_size: int = 0, 104 | padding_idx: int = 0, **kwargs): 105 | cfg = get_pytext_bert_base_cfg() 106 | 107 | if dropout != 0: 108 | cfg.dropout = dropout 109 | cfg.attention_dropout = dropout 110 | cfg.activation_dropout = dropout 111 | 112 | 
encoder = cls(cfg, padding_idx, vocab_size, projection_dim, **kwargs) 113 | 114 | if pretrained_file: 115 | logger.info('Loading pre-trained pytext encoder state from %s', pretrained_file) 116 | state = torch.load(pretrained_file) 117 | encoder.load_state_dict(state) 118 | return encoder 119 | 120 | def forward(self, input_ids: T, token_type_ids: T, attention_mask: T) -> Tuple[T, ...]: 121 | pooled_output = super().forward((input_ids, attention_mask, token_type_ids, None))[0] 122 | if self.encode_proj: 123 | pooled_output = self.encode_proj(pooled_output) 124 | 125 | return None, pooled_output, None 126 | 127 | def get_out_size(self): 128 | if self.encode_proj: 129 | return self.encode_proj.out_features 130 | return self.representation_dim 131 | -------------------------------------------------------------------------------- /dpr/models/reader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | The reader model code + its utilities (loss computation and input batch tensor generator) 10 | """ 11 | 12 | import collections 13 | import logging 14 | from typing import List 15 | 16 | import numpy as np 17 | import torch 18 | import torch.nn as nn 19 | from torch import Tensor as T 20 | from torch.nn import CrossEntropyLoss 21 | 22 | from dpr.data.reader_data import ReaderSample, ReaderPassage 23 | from dpr.utils.model_utils import init_weights 24 | logger = logging.getLogger() 25 | 26 | ReaderBatch = collections.namedtuple( 27 | "ReaderBatch", ["input_ids", "start_positions", "end_positions", "answers_mask", "token_type_ids"] 28 | ) 29 | 30 | 31 | class Reader(nn.Module): 32 | def __init__(self, encoder: nn.Module, hidden_size): 33 | super(Reader, self).__init__() 34 | self.encoder = encoder 35 | self.qa_outputs = nn.Linear(hidden_size, 2) 36 | self.qa_classifier = nn.Linear(hidden_size, 1) 37 | init_weights([self.qa_outputs, self.qa_classifier]) 38 | 39 | def forward( 40 | self, 41 | input_ids: T, 42 | attention_mask: T, 43 | toke_type_ids: T, 44 | start_positions=None, 45 | end_positions=None, 46 | answer_mask=None, 47 | ): 48 | # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length 49 | N, M, L = input_ids.size() 50 | start_logits, end_logits, relevance_logits = self._forward( 51 | input_ids.view(N * M, L), 52 | attention_mask.view(N * M, L), 53 | toke_type_ids.view(N * M, L), 54 | ) 55 | if self.training: 56 | return compute_loss( 57 | start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M 58 | ) 59 | 60 | return start_logits.view(N, M, L), end_logits.view(N, M, L), relevance_logits.view(N, M) 61 | 62 | def _forward(self, input_ids, attention_mask, toke_type_ids: T): 63 | sequence_output, _pooled_output, _hidden_states = self.encoder(input_ids, toke_type_ids, attention_mask) 64 | logits = self.qa_outputs(sequence_output) 65 | start_logits, end_logits = logits.split(1, dim=-1) 66 | start_logits = start_logits.squeeze(-1) 67 | end_logits = end_logits.squeeze(-1) 68 | rank_logits = self.qa_classifier(sequence_output[:, 0, :]) 69 | return start_logits, end_logits, rank_logits 70 | 71 | 72 | def compute_loss(start_positions, end_positions, answer_mask, start_logits, end_logits, relevance_logits, N, M): 73 | 
start_positions = start_positions.view(N * M, -1) 74 | end_positions = end_positions.view(N * M, -1) 75 | answer_mask = answer_mask.view(N * M, -1) 76 | 77 | start_logits = start_logits.view(N * M, -1) 78 | end_logits = end_logits.view(N * M, -1) 79 | relevance_logits = relevance_logits.view(N * M) 80 | 81 | answer_mask = answer_mask.type(torch.FloatTensor).cuda() 82 | 83 | ignored_index = start_logits.size(1) 84 | start_positions.clamp_(0, ignored_index) 85 | end_positions.clamp_(0, ignored_index) 86 | loss_fct = CrossEntropyLoss(reduce=False, ignore_index=ignored_index) 87 | 88 | # compute switch loss 89 | relevance_logits = relevance_logits.view(N, M) 90 | switch_labels = torch.zeros(N, dtype=torch.long).cuda() 91 | switch_loss = torch.sum(loss_fct(relevance_logits, switch_labels)) 92 | 93 | # compute span loss 94 | start_losses = [ 95 | (loss_fct(start_logits, _start_positions) * _span_mask) 96 | for (_start_positions, _span_mask) in zip( 97 | torch.unbind(start_positions, dim=1), torch.unbind(answer_mask, dim=1) 98 | ) 99 | ] 100 | 101 | end_losses = [ 102 | (loss_fct(end_logits, _end_positions) * _span_mask) 103 | for (_end_positions, _span_mask) in zip(torch.unbind(end_positions, dim=1), torch.unbind(answer_mask, dim=1)) 104 | ] 105 | loss_tensor = torch.cat([t.unsqueeze(1) for t in start_losses], dim=1) + torch.cat( 106 | [t.unsqueeze(1) for t in end_losses], dim=1 107 | ) 108 | 109 | loss_tensor = loss_tensor.view(N, M, -1).max(dim=1)[0] 110 | span_loss = _calc_mml(loss_tensor) 111 | return span_loss + switch_loss 112 | 113 | 114 | def create_reader_input( 115 | pad_token_id: int, 116 | samples: List[ReaderSample], 117 | passages_per_question: int, 118 | max_length: int, 119 | max_n_answers: int, 120 | is_train: bool, 121 | shuffle: bool, 122 | sep_token_id: int, 123 | ) -> ReaderBatch: 124 | """ 125 | Creates a reader batch instance out of a list of ReaderSample-s 126 | :param pad_token_id: id of the padding token 127 | :param samples: list of samples to create the batch for 128 | :param passages_per_question: amount of passages for every question in a batch 129 | :param max_length: max model input sequence length 130 | :param max_n_answers: max num of answers per single question 131 | :param is_train: if the samples are for a train set 132 | :param shuffle: should passages selection be randomized 133 | :return: ReaderBatch instance 134 | """ 135 | input_ids = [] 136 | start_positions = [] 137 | end_positions = [] 138 | answers_masks = [] 139 | token_type_ids = [] 140 | empty_sequence = torch.Tensor().new_full((max_length,), pad_token_id, dtype=torch.long) 141 | 142 | for sample in samples: 143 | positive_ctxs = sample.positive_passages 144 | negative_ctxs = sample.negative_passages if is_train else sample.passages 145 | 146 | sample_tensors = _create_question_passages_tensors( 147 | positive_ctxs, 148 | negative_ctxs, 149 | passages_per_question, 150 | empty_sequence, 151 | max_n_answers, 152 | pad_token_id, 153 | sep_token_id, 154 | is_train, 155 | is_random=shuffle, 156 | ) 157 | if not sample_tensors: 158 | logger.debug("No valid passages combination for question=%s ", sample.question) 159 | continue 160 | sample_input_ids, starts_tensor, ends_tensor, answer_mask, sample_ttids = sample_tensors 161 | input_ids.append(sample_input_ids) 162 | token_type_ids.append(sample_ttids) 163 | if is_train: 164 | start_positions.append(starts_tensor) 165 | end_positions.append(ends_tensor) 166 | answers_masks.append(answer_mask) 167 | input_ids = torch.cat([ids.unsqueeze(0) for ids in 
input_ids], dim=0) 168 | token_type_ids = torch.cat([ids.unsqueeze(0) for ids in token_type_ids], dim=0) # .unsqueeze(0) 169 | 170 | if is_train: 171 | start_positions = torch.stack(start_positions, dim=0) 172 | end_positions = torch.stack(end_positions, dim=0) 173 | answers_masks = torch.stack(answers_masks, dim=0) 174 | 175 | return ReaderBatch(input_ids, start_positions, end_positions, answers_masks, token_type_ids) 176 | 177 | 178 | def _calc_mml(loss_tensor): 179 | marginal_likelihood = torch.sum(torch.exp(-loss_tensor - 1e10 * (loss_tensor == 0).float()), 1) 180 | return -torch.sum( 181 | torch.log(marginal_likelihood + torch.ones(loss_tensor.size(0)).cuda() * (marginal_likelihood == 0).float()) 182 | ) 183 | 184 | 185 | def _pad_to_len(seq: T, pad_id: int, max_len: int): 186 | s_len = seq.size(0) 187 | if s_len > max_len: 188 | return seq[0:max_len] 189 | return torch.cat([seq, torch.Tensor().new_full((max_len - s_len,), pad_id, dtype=torch.long)], dim=0) 190 | 191 | 192 | def _get_answer_spans(idx, positives: List[ReaderPassage], max_len: int): 193 | positive_a_spans = positives[idx].answers_spans 194 | return [span for span in positive_a_spans if (span[0] < max_len and span[1] < max_len)] 195 | 196 | 197 | def _get_positive_idx(positives: List[ReaderPassage], max_len: int, is_random: bool): 198 | # select just one positive 199 | positive_idx = np.random.choice(len(positives)) if is_random else 0 200 | 201 | if not _get_answer_spans(positive_idx, positives, max_len): 202 | # question may be too long, find the first positive with at least one valid span 203 | positive_idx = next((i for i in range(len(positives)) if _get_answer_spans(i, positives, max_len)), None) 204 | return positive_idx 205 | 206 | 207 | def _create_question_passages_tensors( 208 | positives: List[ReaderPassage], 209 | negatives: List[ReaderPassage], 210 | total_size: int, 211 | empty_ids: T, 212 | max_n_answers: int, 213 | pad_token_id: int, 214 | sep_token_id: int, 215 | is_train: bool, 216 | is_random: bool = True, 217 | first_segment_ttid: int = 0, 218 | ): 219 | max_len = empty_ids.size(0) 220 | if is_train: 221 | # select just one positive 222 | positive_idx = _get_positive_idx(positives, max_len, is_random) 223 | if positive_idx is None: 224 | return None 225 | 226 | positive_a_spans = _get_answer_spans(positive_idx, positives, max_len)[0:max_n_answers] 227 | 228 | answer_starts = [span[0] for span in positive_a_spans] 229 | answer_ends = [span[1] for span in positive_a_spans] 230 | 231 | assert all(s < max_len for s in answer_starts) 232 | assert all(e < max_len for e in answer_ends) 233 | 234 | positive_input_ids = _pad_to_len(positives[positive_idx].sequence_ids, pad_token_id, max_len) 235 | 236 | answer_starts_tensor = torch.zeros((total_size, max_n_answers)).long() 237 | answer_starts_tensor[0, 0 : len(answer_starts)] = torch.tensor(answer_starts) 238 | 239 | answer_ends_tensor = torch.zeros((total_size, max_n_answers)).long() 240 | answer_ends_tensor[0, 0 : len(answer_ends)] = torch.tensor(answer_ends) 241 | 242 | answer_mask = torch.zeros((total_size, max_n_answers), dtype=torch.long) 243 | answer_mask[0, 0 : len(answer_starts)] = torch.tensor([1 for _ in range(len(answer_starts))]) 244 | 245 | positives_selected = [positive_input_ids] 246 | 247 | else: 248 | positives_selected = [] 249 | answer_starts_tensor = None 250 | answer_ends_tensor = None 251 | answer_mask = None 252 | 253 | positives_num = len(positives_selected) 254 | negative_idxs = np.random.permutation(range(len(negatives))) if 
is_random else range(len(negatives) - positives_num) 255 | 256 | negative_idxs = negative_idxs[: total_size - positives_num] 257 | 258 | negatives_selected = [_pad_to_len(negatives[i].sequence_ids, pad_token_id, max_len) for i in negative_idxs] 259 | negatives_num = len(negatives_selected) 260 | 261 | input_ids = torch.stack([t for t in positives_selected + negatives_selected], dim=0) 262 | 263 | toke_type_ids = _create_token_type_ids(input_ids, sep_token_id, first_segment_ttid) 264 | 265 | if positives_num + negatives_num < total_size: 266 | empty_negatives = [empty_ids.clone().view(1, -1) for _ in range(total_size - (positives_num + negatives_num))] 267 | empty_token_type_ids = [ 268 | empty_ids.clone().view(1, -1) for _ in range(total_size - (positives_num + negatives_num)) 269 | ] 270 | 271 | input_ids = torch.cat([input_ids, *empty_negatives], dim=0) 272 | toke_type_ids = torch.cat([toke_type_ids, *empty_token_type_ids], dim=0) 273 | 274 | return input_ids, answer_starts_tensor, answer_ends_tensor, answer_mask, toke_type_ids 275 | 276 | 277 | def _create_token_type_ids(input_ids: torch.Tensor, sep_token_id: int, first_segment_ttid: int = 0): 278 | 279 | token_type_ids = torch.full(input_ids.shape, fill_value=0) 280 | # return token_type_ids 281 | sep_tokens_indexes = torch.nonzero(input_ids == sep_token_id) 282 | bsz = input_ids.size(0) 283 | second_ttid = 0 if first_segment_ttid == 1 else 1 284 | 285 | for i in range(bsz): 286 | token_type_ids[i, 0 : sep_tokens_indexes[2 * i, 1] + 1] = first_segment_ttid 287 | token_type_ids[i, sep_tokens_indexes[2 * i, 1] + 1 :] = second_ttid 288 | return token_type_ids 289 | -------------------------------------------------------------------------------- /dpr/options.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | """ 9 | Command line arguments utils 10 | """ 11 | 12 | 13 | import logging 14 | import os 15 | import random 16 | import socket 17 | import subprocess 18 | from typing import Tuple 19 | 20 | import numpy as np 21 | import torch 22 | from omegaconf import DictConfig 23 | 24 | logger = logging.getLogger() 25 | 26 | # TODO: to be merged with conf_utils.py 27 | 28 | 29 | def set_cfg_params_from_state(state: dict, cfg: DictConfig): 30 | """ 31 | Overrides some of the encoder config parameters from a give state object 32 | """ 33 | if not state: 34 | return 35 | 36 | cfg.do_lower_case = state["do_lower_case"] 37 | 38 | if "encoder" in state: 39 | saved_encoder_params = state["encoder"] 40 | # TODO: try to understand why cfg.encoder = state["encoder"] doesn't work 41 | 42 | for k, v in saved_encoder_params.items(): 43 | 44 | # TODO: tmp fix 45 | if k == "q_wav2vec_model_cfg": 46 | k = "q_encoder_model_cfg" 47 | if k == "q_wav2vec_cp_file": 48 | k = "q_encoder_cp_file" 49 | if k == "q_wav2vec_cp_file": 50 | k = "q_encoder_cp_file" 51 | 52 | setattr(cfg.encoder, k, v) 53 | else: # 'old' checkpoints backward compatibility support 54 | pass 55 | # cfg.encoder.pretrained_model_cfg = state["pretrained_model_cfg"] 56 | # cfg.encoder.encoder_model_type = state["encoder_model_type"] 57 | # cfg.encoder.pretrained_file = state["pretrained_file"] 58 | # cfg.encoder.projection_dim = state["projection_dim"] 59 | # cfg.encoder.sequence_length = state["sequence_length"] 60 | 61 | 62 | def get_encoder_params_state_from_cfg(cfg: DictConfig): 63 | """ 64 | Selects the param values to be saved in a checkpoint, so that a trained model can be used for downstream 65 | tasks without the need to specify these parameter again 66 | :return: Dict of params to memorize in a checkpoint 67 | """ 68 | return { 69 | "do_lower_case": cfg.do_lower_case, 70 | "encoder": cfg.encoder, 71 | } 72 | 73 | 74 | def set_seed(args): 75 | seed = args.seed 76 | random.seed(seed) 77 | np.random.seed(seed) 78 | torch.manual_seed(seed) 79 | if args.n_gpu > 0: 80 | torch.cuda.manual_seed_all(seed) 81 | 82 | 83 | def setup_cfg_gpu(cfg): 84 | """ 85 | Setup params for CUDA, GPU & distributed training 86 | """ 87 | logger.info("CFG's local_rank=%s", cfg.local_rank) 88 | ws = os.environ.get("WORLD_SIZE") 89 | cfg.distributed_world_size = int(ws) if ws else 1 90 | logger.info("Env WORLD_SIZE=%s", ws) 91 | 92 | if cfg.distributed_port and cfg.distributed_port > 0: 93 | logger.info("distributed_port is specified, trying to init distributed mode from SLURM params ...") 94 | init_method, local_rank, world_size, device = _infer_slurm_init(cfg) 95 | 96 | logger.info( 97 | "Inferred params from SLURM: init_method=%s | local_rank=%s | world_size=%s", 98 | init_method, 99 | local_rank, 100 | world_size, 101 | ) 102 | 103 | cfg.local_rank = local_rank 104 | cfg.distributed_world_size = world_size 105 | cfg.n_gpu = 1 106 | 107 | torch.cuda.set_device(device) 108 | device = str(torch.device("cuda", device)) 109 | 110 | torch.distributed.init_process_group( 111 | backend="nccl", init_method=init_method, world_size=world_size, rank=local_rank 112 | ) 113 | 114 | elif cfg.local_rank == -1 or cfg.no_cuda: # single-node multi-gpu (or cpu) mode 115 | device = str(torch.device("cuda" if torch.cuda.is_available() and not cfg.no_cuda else "cpu")) 116 | cfg.n_gpu = torch.cuda.device_count() 117 | else: # distributed mode 118 | torch.cuda.set_device(cfg.local_rank) 119 | device = str(torch.device("cuda", cfg.local_rank)) 120 | 
torch.distributed.init_process_group(backend="nccl") 121 | cfg.n_gpu = 1 122 | 123 | cfg.device = device 124 | 125 | logger.info( 126 | "Initialized host %s as d.rank %d on device=%s, n_gpu=%d, world size=%d", 127 | socket.gethostname(), 128 | cfg.local_rank, 129 | cfg.device, 130 | cfg.n_gpu, 131 | cfg.distributed_world_size, 132 | ) 133 | logger.info("16-bits training: %s ", cfg.fp16) 134 | return cfg 135 | 136 | 137 | def _infer_slurm_init(cfg) -> Tuple[str, int, int, int]: 138 | 139 | node_list = os.environ.get("SLURM_STEP_NODELIST") 140 | if node_list is None: 141 | node_list = os.environ.get("SLURM_JOB_NODELIST") 142 | logger.info("SLURM_JOB_NODELIST: %s", node_list) 143 | 144 | if node_list is None: 145 | raise RuntimeError("Can't find SLURM node_list from env parameters") 146 | 147 | local_rank = None 148 | world_size = None 149 | distributed_init_method = None 150 | device_id = None 151 | try: 152 | hostnames = subprocess.check_output(["scontrol", "show", "hostnames", node_list]) 153 | distributed_init_method = "tcp://{host}:{port}".format( 154 | host=hostnames.split()[0].decode("utf-8"), 155 | port=cfg.distributed_port, 156 | ) 157 | nnodes = int(os.environ.get("SLURM_NNODES")) 158 | logger.info("SLURM_NNODES: %s", nnodes) 159 | ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") 160 | if ntasks_per_node is not None: 161 | ntasks_per_node = int(ntasks_per_node) 162 | logger.info("SLURM_NTASKS_PER_NODE: %s", ntasks_per_node) 163 | else: 164 | ntasks = int(os.environ.get("SLURM_NTASKS")) 165 | logger.info("SLURM_NTASKS: %s", ntasks) 166 | assert ntasks % nnodes == 0 167 | ntasks_per_node = int(ntasks / nnodes) 168 | 169 | if ntasks_per_node == 1: 170 | gpus_per_node = torch.cuda.device_count() 171 | node_id = int(os.environ.get("SLURM_NODEID")) 172 | local_rank = node_id * gpus_per_node 173 | world_size = nnodes * gpus_per_node 174 | logger.info("node_id: %s", node_id) 175 | else: 176 | world_size = ntasks_per_node * nnodes 177 | proc_id = os.environ.get("SLURM_PROCID") 178 | local_id = os.environ.get("SLURM_LOCALID") 179 | logger.info("SLURM_PROCID %s", proc_id) 180 | logger.info("SLURM_LOCALID %s", local_id) 181 | local_rank = int(proc_id) 182 | device_id = int(local_id) 183 | 184 | except subprocess.CalledProcessError as e: # scontrol failed 185 | raise e 186 | except FileNotFoundError: # Slurm is not installed 187 | pass 188 | return distributed_init_method, local_rank, world_size, device_id 189 | 190 | 191 | def setup_logger(logger): 192 | logger.setLevel(logging.INFO) 193 | if logger.hasHandlers(): 194 | logger.handlers.clear() 195 | log_formatter = logging.Formatter("[%(thread)s] %(asctime)s [%(levelname)s] %(name)s: %(message)s") 196 | console = logging.StreamHandler() 197 | console.setFormatter(log_formatter) 198 | logger.addHandler(console) 199 | -------------------------------------------------------------------------------- /dpr/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/DPR/a31212dc0a54dfa85d8bfa01e1669f149ac832b7/dpr/utils/__init__.py -------------------------------------------------------------------------------- /dpr/utils/conf_utils.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import logging 3 | import os 4 | 5 | import hydra 6 | from omegaconf import DictConfig 7 | 8 | from dpr.data.biencoder_data import JsonQADataset 9 | 10 | logger = logging.getLogger(__name__) 11 | 12 | 13 | class 
BiencoderDatasetsCfg(object): 14 | def __init__(self, cfg: DictConfig): 15 | ds_cfg = cfg.datasets 16 | self.train_datasets_names = cfg.train_datasets 17 | logger.info("train_datasets: %s", self.train_datasets_names) 18 | self.train_datasets = _init_datasets(self.train_datasets_names, ds_cfg) 19 | self.dev_datasets_names = cfg.dev_datasets 20 | logger.info("dev_datasets: %s", self.dev_datasets_names) 21 | self.dev_datasets = _init_datasets(self.dev_datasets_names, ds_cfg) 22 | self.sampling_rates = cfg.train_sampling_rates 23 | 24 | 25 | def _init_datasets(datasets_names, ds_cfg: DictConfig): 26 | if isinstance(datasets_names, str): 27 | return [_init_dataset(datasets_names, ds_cfg)] 28 | elif datasets_names: 29 | return [_init_dataset(ds_name, ds_cfg) for ds_name in datasets_names] 30 | else: 31 | return [] 32 | 33 | 34 | def _init_dataset(name: str, ds_cfg: DictConfig): 35 | if os.path.exists(name): 36 | # use default biencoder json class 37 | return JsonQADataset(name) 38 | elif glob.glob(name): 39 | files = glob.glob(name) 40 | return [_init_dataset(f, ds_cfg) for f in files] 41 | # try to find in cfg 42 | if name not in ds_cfg: 43 | raise RuntimeError("Can't find dataset location/config for: {}".format(name)) 44 | return hydra.utils.instantiate(ds_cfg[name]) 45 | -------------------------------------------------------------------------------- /dpr/utils/data_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for general purpose data processing 10 | """ 11 | import json 12 | import logging 13 | import pickle 14 | import random 15 | 16 | import itertools 17 | import math 18 | 19 | import hydra 20 | import jsonlines 21 | import torch 22 | from omegaconf import DictConfig 23 | from torch import Tensor as T 24 | from typing import List, Iterator, Callable, Tuple 25 | 26 | logger = logging.getLogger() 27 | 28 | 29 | def read_serialized_data_from_files(paths: List[str]) -> List: 30 | results = [] 31 | for i, path in enumerate(paths): 32 | with open(path, "rb") as reader: 33 | logger.info("Reading file %s", path) 34 | data = pickle.load(reader) 35 | results.extend(data) 36 | logger.info("Aggregated data size: {}".format(len(results))) 37 | logger.info("Total data size: {}".format(len(results))) 38 | return results 39 | 40 | 41 | def read_data_from_json_files(paths: List[str]) -> List: 42 | results = [] 43 | for i, path in enumerate(paths): 44 | with open(path, "r", encoding="utf-8") as f: 45 | logger.info("Reading file %s" % path) 46 | data = json.load(f) 47 | results.extend(data) 48 | logger.info("Aggregated data size: {}".format(len(results))) 49 | return results 50 | 51 | 52 | def read_data_from_jsonl_files(paths: List[str]) -> List: 53 | results = [] 54 | for i, path in enumerate(paths): 55 | logger.info("Reading file %s" % path) 56 | with jsonlines.open(path, mode="r") as jsonl_reader: 57 | data = [r for r in jsonl_reader] 58 | results.extend(data) 59 | logger.info("Aggregated data size: {}".format(len(results))) 60 | return results 61 | 62 | 63 | def normalize_question(question: str) -> str: 64 | question = question.replace("’", "'") 65 | return question 66 | 67 | 68 | class Tensorizer(object): 69 | """ 70 | Component for all text to model input data conversions and related utility 
methods 71 | """ 72 | 73 | # Note: title, if present, is supposed to be put before text (i.e. optional title + document body) 74 | def text_to_tensor( 75 | self, 76 | text: str, 77 | title: str = None, 78 | add_special_tokens: bool = True, 79 | apply_max_len: bool = True, 80 | ): 81 | raise NotImplementedError 82 | 83 | def get_pair_separator_ids(self) -> T: 84 | raise NotImplementedError 85 | 86 | def get_pad_id(self) -> int: 87 | raise NotImplementedError 88 | 89 | def get_attn_mask(self, tokens_tensor: T): 90 | raise NotImplementedError 91 | 92 | def is_sub_word_id(self, token_id: int): 93 | raise NotImplementedError 94 | 95 | def to_string(self, token_ids, skip_special_tokens=True): 96 | raise NotImplementedError 97 | 98 | def set_pad_to_max(self, pad: bool): 99 | raise NotImplementedError 100 | 101 | def get_token_id(self, token: str) -> int: 102 | raise NotImplementedError 103 | 104 | 105 | class RepTokenSelector(object): 106 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 107 | raise NotImplementedError 108 | 109 | 110 | class RepStaticPosTokenSelector(RepTokenSelector): 111 | def __init__(self, static_position: int = 0): 112 | self.static_position = static_position 113 | 114 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 115 | return self.static_position 116 | 117 | 118 | class RepSpecificTokenSelector(RepTokenSelector): 119 | def __init__(self, token: str = "[CLS]"): 120 | self.token = token 121 | self.token_id = None 122 | 123 | def get_positions(self, input_ids: T, tenzorizer: Tensorizer): 124 | if not self.token_id: 125 | self.token_id = tenzorizer.get_token_id(self.token) 126 | token_indexes = (input_ids == self.token_id).nonzero() 127 | # check if all samples in input_ids has index presence and out a default value otherwise 128 | bsz = input_ids.size(0) 129 | if bsz == token_indexes.size(0): 130 | return token_indexes 131 | 132 | token_indexes_result = [] 133 | found_idx_cnt = 0 134 | for i in range(bsz): 135 | if found_idx_cnt < token_indexes.size(0) and token_indexes[found_idx_cnt][0] == i: 136 | # this samples has the special token 137 | token_indexes_result.append(token_indexes[found_idx_cnt]) 138 | found_idx_cnt += 1 139 | else: 140 | logger.warning("missing special token %s", input_ids[i]) 141 | 142 | token_indexes_result.append( 143 | torch.tensor([i, 0]).to(input_ids.device) 144 | ) # setting 0-th token, i.e. 
CLS for BERT as the special one 145 | token_indexes_result = torch.stack(token_indexes_result, dim=0) 146 | return token_indexes_result 147 | 148 | 149 | DEFAULT_SELECTOR = RepStaticPosTokenSelector() 150 | 151 | 152 | class Dataset(torch.utils.data.Dataset): 153 | def __init__( 154 | self, 155 | selector: DictConfig = None, 156 | special_token: str = None, 157 | shuffle_positives: bool = False, 158 | query_special_suffix: str = None, 159 | encoder_type: str = None, 160 | ): 161 | if selector: 162 | self.selector = hydra.utils.instantiate(selector) 163 | else: 164 | self.selector = DEFAULT_SELECTOR 165 | self.special_token = special_token 166 | self.encoder_type = encoder_type 167 | self.shuffle_positives = shuffle_positives 168 | self.query_special_suffix = query_special_suffix 169 | self.data = [] 170 | 171 | def load_data(self, start_pos: int = -1, end_pos: int = -1): 172 | raise NotImplementedError 173 | 174 | def calc_total_data_len(self): 175 | raise NotImplementedError 176 | 177 | def __len__(self): 178 | return len(self.data) 179 | 180 | def __getitem__(self, index): 181 | raise NotImplementedError 182 | 183 | def _process_query(self, query: str): 184 | # as of now, always normalize query 185 | query = normalize_question(query) 186 | if self.query_special_suffix and not query.endswith(self.query_special_suffix): 187 | query += self.query_special_suffix 188 | 189 | return query 190 | 191 | 192 | # TODO: to be fully replaced with LocalSharded{...}. Keeping it only for old results reproduction compatibility 193 | class ShardedDataIterator(object): 194 | """ 195 | General purpose data iterator to be used for Pytorch's DDP mode where every node should handle its own part of 196 | the data. 197 | Instead of cutting data shards by their min size, it sets the amount of iterations by the maximum shard size. 198 | It fills the extra sample by just taking first samples in a shard. 199 | It can also optionally enforce identical batch size for all iterations (might be useful for DP mode). 
200 | """ 201 | 202 | def __init__( 203 | self, 204 | dataset: Dataset, 205 | shard_id: int = 0, 206 | num_shards: int = 1, 207 | batch_size: int = 1, 208 | shuffle=True, 209 | shuffle_seed: int = 0, 210 | offset: int = 0, 211 | strict_batch_size: bool = False, 212 | ): 213 | 214 | self.dataset = dataset 215 | self.shard_id = shard_id 216 | self.num_shards = num_shards 217 | self.iteration = offset # to track in-shard iteration status 218 | self.shuffle = shuffle 219 | self.batch_size = batch_size 220 | self.shuffle_seed = shuffle_seed 221 | self.strict_batch_size = strict_batch_size 222 | self.shard_start_idx = -1 223 | self.shard_end_idx = -1 224 | self.max_iterations = 0 225 | 226 | def calculate_shards(self): 227 | logger.info("Calculating shard positions") 228 | shards_num = max(self.num_shards, 1) 229 | shard_id = max(self.shard_id, 0) 230 | 231 | total_size = self.dataset.calc_total_data_len() 232 | samples_per_shard = math.ceil(total_size / shards_num) 233 | 234 | self.shard_start_idx = shard_id * samples_per_shard 235 | self.shard_end_idx = min(self.shard_start_idx + samples_per_shard, total_size) 236 | 237 | if self.strict_batch_size: 238 | self.max_iterations = math.ceil(samples_per_shard / self.batch_size) 239 | else: 240 | self.max_iterations = int(samples_per_shard / self.batch_size) 241 | 242 | logger.info( 243 | "samples_per_shard=%d, shard_start_idx=%d, shard_end_idx=%d, max_iterations=%d", 244 | samples_per_shard, 245 | self.shard_start_idx, 246 | self.shard_end_idx, 247 | self.max_iterations, 248 | ) 249 | 250 | def load_data(self): 251 | self.calculate_shards() 252 | self.dataset.load_data() 253 | logger.info("Sharded dataset data %d", len(self.dataset)) 254 | 255 | def total_data_len(self) -> int: 256 | return len(self.dataset) 257 | 258 | def iterations_num(self) -> int: 259 | return self.max_iterations - self.iteration 260 | 261 | def max_iterations_num(self) -> int: 262 | return self.max_iterations 263 | 264 | def get_iteration(self) -> int: 265 | return self.iteration 266 | 267 | def apply(self, visitor_func: Callable): 268 | for sample in self.dataset: 269 | visitor_func(sample) 270 | 271 | def get_shard_indices(self, epoch: int): 272 | indices = list(range(len(self.dataset))) 273 | if self.shuffle: 274 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 275 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 276 | epoch_rnd.shuffle(indices) 277 | shard_indices = indices[self.shard_start_idx : self.shard_end_idx] 278 | return shard_indices 279 | 280 | # TODO: merge with iterate_ds_sampled_data 281 | def iterate_ds_data(self, epoch: int = 0) -> Iterator[List]: 282 | # if resuming iteration somewhere in the middle of epoch, one needs to adjust max_iterations 283 | max_iterations = self.max_iterations - self.iteration 284 | shard_indices = self.get_shard_indices(epoch) 285 | 286 | for i in range(self.iteration * self.batch_size, len(shard_indices), self.batch_size): 287 | items_idxs = shard_indices[i : i + self.batch_size] 288 | if self.strict_batch_size and len(items_idxs) < self.batch_size: 289 | logger.debug("Extending batch to max size") 290 | items_idxs.extend(shard_indices[0 : self.batch_size - len(items_idxs)]) 291 | self.iteration += 1 292 | items = [self.dataset[idx] for idx in items_idxs] 293 | yield items 294 | 295 | # some shards may be done iterating while the others are at the last batch. 
Just return the first batch 296 | while self.iteration < max_iterations: 297 | logger.debug("Fulfilling non complete shard={}".format(self.shard_id)) 298 | self.iteration += 1 299 | items_idxs = shard_indices[0 : self.batch_size] 300 | items = [self.dataset[idx] for idx in items_idxs] 301 | yield items 302 | 303 | logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id)) 304 | # reset the iteration status 305 | self.iteration = 0 306 | 307 | def iterate_ds_sampled_data(self, num_iterations: int, epoch: int = 0) -> Iterator[List]: 308 | self.iteration = 0 309 | shard_indices = self.get_shard_indices(epoch) 310 | cycle_it = itertools.cycle(shard_indices) 311 | for i in range(num_iterations): 312 | items_idxs = [next(cycle_it) for _ in range(self.batch_size)] 313 | self.iteration += 1 314 | items = [self.dataset[idx] for idx in items_idxs] 315 | yield items 316 | 317 | logger.info("Finished iterating, iteration={}, shard={}".format(self.iteration, self.shard_id)) 318 | # TODO: reset the iteration status? 319 | self.iteration = 0 320 | 321 | def get_dataset(self) -> Dataset: 322 | return self.dataset 323 | 324 | 325 | class LocalShardedDataIterator(ShardedDataIterator): 326 | # uses only one shard after the initial dataset load to reduce memory footprint 327 | def load_data(self): 328 | self.calculate_shards() 329 | self.dataset.load_data(start_pos=self.shard_start_idx, end_pos=self.shard_end_idx) 330 | logger.info("Sharded dataset data %d", len(self.dataset)) 331 | 332 | def get_shard_indices(self, epoch: int): 333 | indices = list(range(len(self.dataset))) 334 | if self.shuffle: 335 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 336 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 337 | epoch_rnd.shuffle(indices) 338 | shard_indices = indices 339 | return shard_indices 340 | 341 | 342 | class MultiSetDataIterator(object): 343 | """ 344 | Iterator over multiple data sources. Useful when all samples forming a single batch should come from the same dataset. 
345 | """ 346 | 347 | def __init__( 348 | self, 349 | datasets: List[ShardedDataIterator], 350 | shuffle_seed: int = 0, 351 | shuffle=True, 352 | sampling_rates: List = [], 353 | rank: int = 0, 354 | ): 355 | # randomized data loading to avoid file system congestion 356 | ds_list_copy = [ds for ds in datasets] 357 | rnd = random.Random(rank) 358 | rnd.shuffle(ds_list_copy) 359 | [ds.load_data() for ds in ds_list_copy] 360 | 361 | self.iterables = datasets 362 | data_lengths = [it.total_data_len() for it in datasets] 363 | self.total_data = sum(data_lengths) 364 | logger.info("rank=%d; Multi set data sizes %s", rank, data_lengths) 365 | logger.info("rank=%d; Multi set total data %s", rank, self.total_data) 366 | logger.info("rank=%d; Multi set sampling_rates %s", rank, sampling_rates) 367 | self.shuffle_seed = shuffle_seed 368 | self.shuffle = shuffle 369 | self.iteration = 0 370 | self.rank = rank 371 | 372 | if sampling_rates: 373 | self.max_its_pr_ds = [int(ds.max_iterations_num() * sampling_rates[i]) for i, ds in enumerate(datasets)] 374 | else: 375 | self.max_its_pr_ds = [ds.max_iterations_num() for ds in datasets] 376 | 377 | self.max_iterations = sum(self.max_its_pr_ds) 378 | logger.info("rank=%d; Multi set max_iterations per dataset %s", rank, self.max_its_pr_ds) 379 | logger.info("rank=%d; Multi set max_iterations %d", rank, self.max_iterations) 380 | 381 | def total_data_len(self) -> int: 382 | return self.total_data 383 | 384 | def get_max_iterations(self): 385 | return self.max_iterations 386 | 387 | def iterate_ds_data(self, epoch: int = 0) -> Iterator[Tuple[List, int]]: 388 | 389 | logger.info("rank=%d; Iteration start", self.rank) 390 | logger.info( 391 | "rank=%d; Multi set iteration: iteration ptr per set: %s", 392 | self.rank, 393 | [it.get_iteration() for it in self.iterables], 394 | ) 395 | 396 | data_src_indices = [] 397 | iterators = [] 398 | for source, src_its in enumerate(self.max_its_pr_ds): 399 | logger.info( 400 | "rank=%d; Multi set iteration: source %d, batches to be taken: %s", 401 | self.rank, 402 | source, 403 | src_its, 404 | ) 405 | data_src_indices.extend([source] * src_its) 406 | 407 | iterators.append(self.iterables[source].iterate_ds_sampled_data(src_its, epoch=epoch)) 408 | 409 | if self.shuffle: 410 | # to be able to resume, same shuffling should be used when starting from a failed/stopped iteration 411 | epoch_rnd = random.Random(self.shuffle_seed + epoch) 412 | epoch_rnd.shuffle(data_src_indices) 413 | 414 | logger.info("rank=%d; data_src_indices len=%d", self.rank, len(data_src_indices)) 415 | for i, source_idx in enumerate(data_src_indices): 416 | it = iterators[source_idx] 417 | next_item = next(it, None) 418 | if next_item is not None: 419 | self.iteration += 1 420 | yield (next_item, source_idx) 421 | else: 422 | logger.warning("rank=%d; Next item in the source %s is None", self.rank, source_idx) 423 | 424 | logger.info("rank=%d; last iteration %d", self.rank, self.iteration) 425 | 426 | logger.info( 427 | "rank=%d; Multi set iteration finished: iteration per set: %s", 428 | self.rank, 429 | [it.iteration for it in self.iterables], 430 | ) 431 | [next(it, None) for it in iterators] 432 | 433 | # TODO: clear iterators in some non-hacky way 434 | for it in self.iterables: 435 | it.iteration = 0 436 | logger.info( 437 | "rank=%d; Multi set iteration finished after next: iteration per set: %s", 438 | self.rank, 439 | [it.iteration for it in self.iterables], 440 | ) 441 | # reset the iteration status 442 | self.iteration = 0 443 | 444 | def 
get_iteration(self) -> int: 445 | return self.iteration 446 | 447 | def get_dataset(self, ds_id: int) -> Dataset: 448 | return self.iterables[ds_id].get_dataset() 449 | 450 | def get_datasets(self) -> List[Dataset]: 451 | return [it.get_dataset() for it in self.iterables] 452 | -------------------------------------------------------------------------------- /dpr/utils/dist_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Utilities for distributed model training 10 | """ 11 | 12 | import pickle 13 | 14 | import torch 15 | import torch.distributed as dist 16 | 17 | 18 | def get_rank(): 19 | return dist.get_rank() 20 | 21 | 22 | def get_world_size(): 23 | return dist.get_world_size() 24 | 25 | 26 | def get_default_group(): 27 | return dist.group.WORLD 28 | 29 | 30 | def all_reduce(tensor, group=None): 31 | if group is None: 32 | group = get_default_group() 33 | return dist.all_reduce(tensor, group=group) 34 | 35 | 36 | def all_gather_list(data, group=None, max_size=16384): 37 | """Gathers arbitrary data from all nodes into a list. 38 | Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python 39 | data. Note that *data* must be picklable. 40 | Args: 41 | data (Any): data from the local worker to be gathered on other workers 42 | group (optional): group of the collective 43 | """ 44 | SIZE_STORAGE_BYTES = 4 # int32 to encode the payload size 45 | 46 | enc = pickle.dumps(data) 47 | enc_size = len(enc) 48 | 49 | if enc_size + SIZE_STORAGE_BYTES > max_size: 50 | raise ValueError( 51 | 'encoded data exceeds max_size, this can be fixed by increasing buffer size: {}'.format(enc_size)) 52 | 53 | rank = get_rank() 54 | world_size = get_world_size() 55 | buffer_size = max_size * world_size 56 | 57 | if not hasattr(all_gather_list, '_buffer') or \ 58 | all_gather_list._buffer.numel() < buffer_size: 59 | all_gather_list._buffer = torch.cuda.ByteTensor(buffer_size) 60 | all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() 61 | 62 | buffer = all_gather_list._buffer 63 | buffer.zero_() 64 | cpu_buffer = all_gather_list._cpu_buffer 65 | 66 | assert enc_size < 256 ** SIZE_STORAGE_BYTES, 'Encoded object size should be less than {} bytes'.format( 67 | 256 ** SIZE_STORAGE_BYTES) 68 | 69 | size_bytes = enc_size.to_bytes(SIZE_STORAGE_BYTES, byteorder='big') 70 | 71 | cpu_buffer[0:SIZE_STORAGE_BYTES] = torch.ByteTensor(list(size_bytes)) 72 | cpu_buffer[SIZE_STORAGE_BYTES: enc_size + SIZE_STORAGE_BYTES] = torch.ByteTensor(list(enc)) 73 | 74 | start = rank * max_size 75 | size = enc_size + SIZE_STORAGE_BYTES 76 | buffer[start: start + size].copy_(cpu_buffer[:size]) 77 | 78 | all_reduce(buffer, group=group) 79 | 80 | try: 81 | result = [] 82 | for i in range(world_size): 83 | out_buffer = buffer[i * max_size: (i + 1) * max_size] 84 | size = int.from_bytes(out_buffer[0:SIZE_STORAGE_BYTES], byteorder='big') 85 | if size > 0: 86 | result.append(pickle.loads(bytes(out_buffer[SIZE_STORAGE_BYTES: size + SIZE_STORAGE_BYTES].tolist()))) 87 | return result 88 | except pickle.UnpicklingError: 89 | raise Exception( 90 | 'Unable to unpickle data from other workers. 
all_gather_list requires all ' 91 | 'workers to enter the function together, so this error usually indicates ' 92 | 'that the workers have fallen out of sync somehow. Workers can fall out of ' 93 | 'sync if one of them runs out of memory, or if there are other conditions ' 94 | 'in your training script that can cause one worker to finish an epoch ' 95 | 'while other workers are still iterating over their portions of the data.' 96 | ) 97 | -------------------------------------------------------------------------------- /dpr/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | import collections 9 | import glob 10 | import logging 11 | import os 12 | from typing import List 13 | 14 | import torch 15 | from torch import nn 16 | from torch.optim.lr_scheduler import LambdaLR 17 | from torch.serialization import default_restore_location 18 | 19 | logger = logging.getLogger() 20 | 21 | CheckpointState = collections.namedtuple( 22 | "CheckpointState", 23 | [ 24 | "model_dict", 25 | "optimizer_dict", 26 | "scheduler_dict", 27 | "offset", 28 | "epoch", 29 | "encoder_params", 30 | ], 31 | ) 32 | 33 | 34 | def setup_for_distributed_mode( 35 | model: nn.Module, 36 | optimizer: torch.optim.Optimizer, 37 | device: object, 38 | n_gpu: int = 1, 39 | local_rank: int = -1, 40 | fp16: bool = False, 41 | fp16_opt_level: str = "O1", 42 | ) -> (nn.Module, torch.optim.Optimizer): 43 | model.to(device) 44 | if fp16: 45 | try: 46 | import apex 47 | from apex import amp 48 | 49 | apex.amp.register_half_function(torch, "einsum") 50 | except ImportError: 51 | raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.") 52 | 53 | model, optimizer = amp.initialize(model, optimizer, opt_level=fp16_opt_level) 54 | 55 | if n_gpu > 1: 56 | model = torch.nn.DataParallel(model) 57 | 58 | if local_rank != -1: 59 | model = torch.nn.parallel.DistributedDataParallel( 60 | model, 61 | device_ids=[device if device else local_rank], 62 | output_device=local_rank, 63 | find_unused_parameters=True, 64 | ) 65 | return model, optimizer 66 | 67 | 68 | def move_to_cuda(sample): 69 | if len(sample) == 0: 70 | return {} 71 | 72 | def _move_to_cuda(maybe_tensor): 73 | if torch.is_tensor(maybe_tensor): 74 | return maybe_tensor.cuda() 75 | elif isinstance(maybe_tensor, dict): 76 | return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()} 77 | elif isinstance(maybe_tensor, list): 78 | return [_move_to_cuda(x) for x in maybe_tensor] 79 | elif isinstance(maybe_tensor, tuple): 80 | return [_move_to_cuda(x) for x in maybe_tensor] 81 | else: 82 | return maybe_tensor 83 | 84 | return _move_to_cuda(sample) 85 | 86 | 87 | def move_to_device(sample, device): 88 | if len(sample) == 0: 89 | return {} 90 | 91 | def _move_to_device(maybe_tensor, device): 92 | if torch.is_tensor(maybe_tensor): 93 | return maybe_tensor.to(device) 94 | elif isinstance(maybe_tensor, dict): 95 | return {key: _move_to_device(value, device) for key, value in maybe_tensor.items()} 96 | elif isinstance(maybe_tensor, list): 97 | return [_move_to_device(x, device) for x in maybe_tensor] 98 | elif isinstance(maybe_tensor, tuple): 99 | return [_move_to_device(x, device) for x in maybe_tensor] 100 | else: 101 | return maybe_tensor 102 
| 103 | return _move_to_device(sample, device) 104 | 105 | 106 | def get_schedule_linear( 107 | optimizer, 108 | warmup_steps, 109 | total_training_steps, 110 | steps_shift=0, 111 | last_epoch=-1, 112 | ): 113 | 114 | """Create a schedule with a learning rate that decreases linearly after 115 | linearly increasing during a warmup period. 116 | """ 117 | 118 | def lr_lambda(current_step): 119 | current_step += steps_shift 120 | if current_step < warmup_steps: 121 | return float(current_step) / float(max(1, warmup_steps)) 122 | return max( 123 | 1e-7, 124 | float(total_training_steps - current_step) / float(max(1, total_training_steps - warmup_steps)), 125 | ) 126 | 127 | return LambdaLR(optimizer, lr_lambda, last_epoch) 128 | 129 | 130 | def init_weights(modules: List): 131 | for module in modules: 132 | if isinstance(module, (nn.Linear, nn.Embedding)): 133 | module.weight.data.normal_(mean=0.0, std=0.02) 134 | elif isinstance(module, nn.LayerNorm): 135 | module.bias.data.zero_() 136 | module.weight.data.fill_(1.0) 137 | if isinstance(module, nn.Linear) and module.bias is not None: 138 | module.bias.data.zero_() 139 | 140 | 141 | def get_model_obj(model: nn.Module): 142 | return model.module if hasattr(model, "module") else model 143 | 144 | 145 | def get_model_file(args, file_prefix) -> str: 146 | if args.model_file and os.path.exists(args.model_file): 147 | return args.model_file 148 | 149 | out_cp_files = glob.glob(os.path.join(args.output_dir, file_prefix + "*")) if args.output_dir else [] 150 | logger.info("Checkpoint files %s", out_cp_files) 151 | model_file = None 152 | 153 | if len(out_cp_files) > 0: 154 | model_file = max(out_cp_files, key=os.path.getctime) 155 | return model_file 156 | 157 | 158 | def load_states_from_checkpoint(model_file: str) -> CheckpointState: 159 | logger.info("Reading saved model from %s", model_file) 160 | state_dict = torch.load(model_file, map_location=lambda s, l: default_restore_location(s, "cpu")) 161 | logger.info("model_state_dict keys %s", state_dict.keys()) 162 | return CheckpointState(**state_dict) 163 | -------------------------------------------------------------------------------- /dpr/utils/tokenizers.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 
7 | 8 | 9 | """ 10 | Most of the tokenizers code here is copied from DrQA codebase to avoid adding extra dependency 11 | """ 12 | 13 | import copy 14 | import logging 15 | 16 | import regex 17 | import spacy 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class Tokens(object): 23 | """A class to represent a list of tokenized text.""" 24 | 25 | TEXT = 0 26 | TEXT_WS = 1 27 | SPAN = 2 28 | POS = 3 29 | LEMMA = 4 30 | NER = 5 31 | 32 | def __init__(self, data, annotators, opts=None): 33 | self.data = data 34 | self.annotators = annotators 35 | self.opts = opts or {} 36 | 37 | def __len__(self): 38 | """The number of tokens.""" 39 | return len(self.data) 40 | 41 | def slice(self, i=None, j=None): 42 | """Return a view of the list of tokens from [i, j).""" 43 | new_tokens = copy.copy(self) 44 | new_tokens.data = self.data[i:j] 45 | return new_tokens 46 | 47 | def untokenize(self): 48 | """Returns the original text (with whitespace reinserted).""" 49 | return "".join([t[self.TEXT_WS] for t in self.data]).strip() 50 | 51 | def words(self, uncased=False): 52 | """Returns a list of the text of each token 53 | 54 | Args: 55 | uncased: lower cases text 56 | """ 57 | if uncased: 58 | return [t[self.TEXT].lower() for t in self.data] 59 | else: 60 | return [t[self.TEXT] for t in self.data] 61 | 62 | def offsets(self): 63 | """Returns a list of [start, end) character offsets of each token.""" 64 | return [t[self.SPAN] for t in self.data] 65 | 66 | def pos(self): 67 | """Returns a list of part-of-speech tags of each token. 68 | Returns None if this annotation was not included. 69 | """ 70 | if "pos" not in self.annotators: 71 | return None 72 | return [t[self.POS] for t in self.data] 73 | 74 | def lemmas(self): 75 | """Returns a list of the lemmatized text of each token. 76 | Returns None if this annotation was not included. 77 | """ 78 | if "lemma" not in self.annotators: 79 | return None 80 | return [t[self.LEMMA] for t in self.data] 81 | 82 | def entities(self): 83 | """Returns a list of named-entity-recognition tags of each token. 84 | Returns None if this annotation was not included. 85 | """ 86 | if "ner" not in self.annotators: 87 | return None 88 | return [t[self.NER] for t in self.data] 89 | 90 | def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True): 91 | """Returns a list of all ngrams from length 1 to n. 
92 | 93 | Args: 94 | n: upper limit of ngram length 95 | uncased: lower cases text 96 | filter_fn: user function that takes in an ngram list and returns 97 | True or False to keep or not keep the ngram 98 | as_string: return the ngram as a string vs list 99 | """ 100 | 101 | def _skip(gram): 102 | if not filter_fn: 103 | return False 104 | return filter_fn(gram) 105 | 106 | words = self.words(uncased) 107 | ngrams = [ 108 | (s, e + 1) 109 | for s in range(len(words)) 110 | for e in range(s, min(s + n, len(words))) 111 | if not _skip(words[s : e + 1]) 112 | ] 113 | 114 | # Concatenate into strings 115 | if as_strings: 116 | ngrams = ["{}".format(" ".join(words[s:e])) for (s, e) in ngrams] 117 | 118 | return ngrams 119 | 120 | def entity_groups(self): 121 | """Group consecutive entity tokens with the same NER tag.""" 122 | entities = self.entities() 123 | if not entities: 124 | return None 125 | non_ent = self.opts.get("non_ent", "O") 126 | groups = [] 127 | idx = 0 128 | while idx < len(entities): 129 | ner_tag = entities[idx] 130 | # Check for entity tag 131 | if ner_tag != non_ent: 132 | # Chomp the sequence 133 | start = idx 134 | while idx < len(entities) and entities[idx] == ner_tag: 135 | idx += 1 136 | groups.append((self.slice(start, idx).untokenize(), ner_tag)) 137 | else: 138 | idx += 1 139 | return groups 140 | 141 | 142 | class Tokenizer(object): 143 | """Base tokenizer class. 144 | Tokenizers implement tokenize, which should return a Tokens class. 145 | """ 146 | 147 | def tokenize(self, text): 148 | raise NotImplementedError 149 | 150 | def shutdown(self): 151 | pass 152 | 153 | def __del__(self): 154 | self.shutdown() 155 | 156 | 157 | class SimpleTokenizer(Tokenizer): 158 | ALPHA_NUM = r"[\p{L}\p{N}\p{M}]+" 159 | NON_WS = r"[^\p{Z}\p{C}]" 160 | 161 | def __init__(self, **kwargs): 162 | """ 163 | Args: 164 | annotators: None or empty set (only tokenizes). 165 | """ 166 | self._regexp = regex.compile( 167 | "(%s)|(%s)" % (self.ALPHA_NUM, self.NON_WS), 168 | flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE, 169 | ) 170 | if len(kwargs.get("annotators", {})) > 0: 171 | logger.warning( 172 | "%s only tokenizes! Skipping annotators: %s" % (type(self).__name__, kwargs.get("annotators")) 173 | ) 174 | self.annotators = set() 175 | 176 | def tokenize(self, text): 177 | data = [] 178 | matches = [m for m in self._regexp.finditer(text)] 179 | for i in range(len(matches)): 180 | # Get text 181 | token = matches[i].group() 182 | 183 | # Get whitespace 184 | span = matches[i].span() 185 | start_ws = span[0] 186 | if i + 1 < len(matches): 187 | end_ws = matches[i + 1].span()[0] 188 | else: 189 | end_ws = span[1] 190 | 191 | # Format data 192 | data.append( 193 | ( 194 | token, 195 | text[start_ws:end_ws], 196 | span, 197 | ) 198 | ) 199 | return Tokens(data, self.annotators) 200 | 201 | 202 | class SpacyTokenizer(Tokenizer): 203 | def __init__(self, **kwargs): 204 | """ 205 | Args: 206 | annotators: set that can include pos, lemma, and ner. 207 | model: spaCy model to use (either path, or keyword like 'en'). 208 | """ 209 | model = kwargs.get("model", "en_core_web_sm") # TODO: replace with en ? 
210 | self.annotators = copy.deepcopy(kwargs.get("annotators", set())) 211 | nlp_kwargs = {"parser": False} 212 | if not any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 213 | nlp_kwargs["tagger"] = False 214 | if "ner" not in self.annotators: 215 | nlp_kwargs["entity"] = False 216 | self.nlp = spacy.load(model, **nlp_kwargs) 217 | 218 | def tokenize(self, text): 219 | # We don't treat new lines as tokens. 220 | clean_text = text.replace("\n", " ") 221 | tokens = self.nlp.tokenizer(clean_text) 222 | if any([p in self.annotators for p in ["lemma", "pos", "ner"]]): 223 | self.nlp.tagger(tokens) 224 | if "ner" in self.annotators: 225 | self.nlp.entity(tokens) 226 | 227 | data = [] 228 | for i in range(len(tokens)): 229 | # Get whitespace 230 | start_ws = tokens[i].idx 231 | if i + 1 < len(tokens): 232 | end_ws = tokens[i + 1].idx 233 | else: 234 | end_ws = tokens[i].idx + len(tokens[i].text) 235 | 236 | data.append( 237 | ( 238 | tokens[i].text, 239 | text[start_ws:end_ws], 240 | (tokens[i].idx, tokens[i].idx + len(tokens[i].text)), 241 | tokens[i].tag_, 242 | tokens[i].lemma_, 243 | tokens[i].ent_type_, 244 | ) 245 | ) 246 | 247 | # Set special option for non-entity tag: '' vs 'O' in spaCy 248 | return Tokens(data, self.annotators, opts={"non_ent": ""}) 249 | -------------------------------------------------------------------------------- /generate_dense_embeddings.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # All rights reserved. 4 | # 5 | # This source code is licensed under the license found in the 6 | # LICENSE file in the root directory of this source tree. 7 | 8 | """ 9 | Command line tool that produces embeddings for a large documents base based on the pretrained ctx & question encoders 10 | Supposed to be used in a 'sharded' way to speed up the process. 
11 | """ 12 | import logging 13 | import math 14 | import os 15 | import pathlib 16 | import pickle 17 | from typing import List, Tuple 18 | 19 | import hydra 20 | import numpy as np 21 | import torch 22 | from omegaconf import DictConfig, OmegaConf 23 | from torch import nn 24 | 25 | from dpr.data.biencoder_data import BiEncoderPassage 26 | from dpr.models import init_biencoder_components 27 | from dpr.options import set_cfg_params_from_state, setup_cfg_gpu, setup_logger 28 | 29 | from dpr.utils.data_utils import Tensorizer 30 | from dpr.utils.model_utils import ( 31 | setup_for_distributed_mode, 32 | get_model_obj, 33 | load_states_from_checkpoint, 34 | move_to_device, 35 | ) 36 | 37 | logger = logging.getLogger() 38 | setup_logger(logger) 39 | 40 | 41 | def gen_ctx_vectors( 42 | cfg: DictConfig, 43 | ctx_rows: List[Tuple[object, BiEncoderPassage]], 44 | model: nn.Module, 45 | tensorizer: Tensorizer, 46 | insert_title: bool = True, 47 | ) -> List[Tuple[object, np.array]]: 48 | n = len(ctx_rows) 49 | bsz = cfg.batch_size 50 | total = 0 51 | results = [] 52 | for j, batch_start in enumerate(range(0, n, bsz)): 53 | batch = ctx_rows[batch_start : batch_start + bsz] 54 | batch_token_tensors = [ 55 | tensorizer.text_to_tensor(ctx[1].text, title=ctx[1].title if insert_title else None) for ctx in batch 56 | ] 57 | 58 | ctx_ids_batch = move_to_device(torch.stack(batch_token_tensors, dim=0), cfg.device) 59 | ctx_seg_batch = move_to_device(torch.zeros_like(ctx_ids_batch), cfg.device) 60 | ctx_attn_mask = move_to_device(tensorizer.get_attn_mask(ctx_ids_batch), cfg.device) 61 | with torch.no_grad(): 62 | _, out, _ = model(ctx_ids_batch, ctx_seg_batch, ctx_attn_mask) 63 | out = out.cpu() 64 | 65 | ctx_ids = [r[0] for r in batch] 66 | extra_info = [] 67 | if len(batch[0]) > 3: 68 | extra_info = [r[3:] for r in batch] 69 | 70 | assert len(ctx_ids) == out.size(0) 71 | total += len(ctx_ids) 72 | 73 | # TODO: refactor to avoid 'if' 74 | if extra_info: 75 | results.extend([(ctx_ids[i], out[i].view(-1).numpy(), *extra_info[i]) for i in range(out.size(0))]) 76 | else: 77 | results.extend([(ctx_ids[i], out[i].view(-1).numpy()) for i in range(out.size(0))]) 78 | 79 | if total % 10 == 0: 80 | logger.info("Encoded passages %d", total) 81 | return results 82 | 83 | 84 | @hydra.main(config_path="conf", config_name="gen_embs") 85 | def main(cfg: DictConfig): 86 | 87 | assert cfg.model_file, "Please specify encoder checkpoint as model_file param" 88 | assert cfg.ctx_src, "Please specify passages source as ctx_src param" 89 | 90 | cfg = setup_cfg_gpu(cfg) 91 | 92 | saved_state = load_states_from_checkpoint(cfg.model_file) 93 | set_cfg_params_from_state(saved_state.encoder_params, cfg) 94 | 95 | logger.info("CFG:") 96 | logger.info("%s", OmegaConf.to_yaml(cfg)) 97 | 98 | tensorizer, encoder, _ = init_biencoder_components(cfg.encoder.encoder_model_type, cfg, inference_only=True) 99 | 100 | encoder = encoder.ctx_model if cfg.encoder_type == "ctx" else encoder.question_model 101 | 102 | encoder, _ = setup_for_distributed_mode( 103 | encoder, 104 | None, 105 | cfg.device, 106 | cfg.n_gpu, 107 | cfg.local_rank, 108 | cfg.fp16, 109 | cfg.fp16_opt_level, 110 | ) 111 | encoder.eval() 112 | 113 | # load weights from the model file 114 | model_to_load = get_model_obj(encoder) 115 | logger.info("Loading saved model state ...") 116 | logger.debug("saved model keys =%s", saved_state.model_dict.keys()) 117 | 118 | prefix_len = len("ctx_model.") 119 | ctx_state = { 120 | key[prefix_len:]: value for (key, value) in 
saved_state.model_dict.items() if key.startswith("ctx_model.") 121 | } 122 | model_to_load.load_state_dict(ctx_state, strict=False) 123 | 124 | logger.info("reading data source: %s", cfg.ctx_src) 125 | 126 | ctx_src = hydra.utils.instantiate(cfg.ctx_sources[cfg.ctx_src]) 127 | all_passages_dict = {} 128 | ctx_src.load_data_to(all_passages_dict) 129 | all_passages = [(k, v) for k, v in all_passages_dict.items()] 130 | 131 | shard_size = math.ceil(len(all_passages) / cfg.num_shards) 132 | start_idx = cfg.shard_id * shard_size 133 | end_idx = start_idx + shard_size 134 | 135 | logger.info( 136 | "Producing encodings for passages range: %d to %d (out of total %d)", 137 | start_idx, 138 | end_idx, 139 | len(all_passages), 140 | ) 141 | shard_passages = all_passages[start_idx:end_idx] 142 | 143 | data = gen_ctx_vectors(cfg, shard_passages, encoder, tensorizer, True) 144 | 145 | file = cfg.out_file + "_" + str(cfg.shard_id) 146 | pathlib.Path(os.path.dirname(file)).mkdir(parents=True, exist_ok=True) 147 | logger.info("Writing results to %s" % file) 148 | with open(file, mode="wb") as f: 149 | pickle.dump(data, f) 150 | 151 | logger.info("Total passages processed %d. Written to %s", len(data), file) 152 | 153 | 154 | if __name__ == "__main__": 155 | main() 156 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # Copyright (c) Facebook, Inc. and its affiliates. 3 | # 4 | # This source code is licensed under the MIT license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from setuptools import setup 9 | 10 | with open("README.md") as f: 11 | readme = f.read() 12 | 13 | setup( 14 | name="dpr", 15 | version="1.0.0", 16 | description="Facebook AI Research Open Domain Q&A Toolkit", 17 | url="https://github.com/facebookresearch/DPR/", 18 | classifiers=[ 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: MIT License", 21 | "Programming Language :: Python :: 3.6", 22 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 23 | ], 24 | long_description=readme, 25 | long_description_content_type="text/markdown", 26 | setup_requires=[ 27 | "setuptools>=18.0", 28 | ], 29 | install_requires=[ 30 | "faiss-cpu>=1.6.1", 31 | "filelock", 32 | "numpy", 33 | "regex", 34 | "torch>=1.5.0", 35 | "transformers>=4.3", 36 | "tqdm>=4.27", 37 | "wget", 38 | "spacy>=2.1.8", 39 | "hydra-core>=1.0.0", 40 | "omegaconf>=2.0.1", 41 | "jsonlines", 42 | "soundfile", 43 | "editdistance", 44 | ], 45 | ) 46 | --------------------------------------------------------------------------------
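The sharded data iterators in dpr/utils/data_utils.py are meant for DDP-style training where each process handles only its own slice of the data. Below is a minimal usage sketch, not part of the repository: the ToyDataset class and its integer "samples" are invented for illustration, and the snippet assumes the dpr package from this repo has been installed (e.g. via setup.py above).

from dpr.utils.data_utils import Dataset, ShardedDataIterator


class ToyDataset(Dataset):
    # invented stand-in for a real DPR dataset; exposes 10 integer "samples"
    def load_data(self, start_pos: int = -1, end_pos: int = -1):
        self.data = list(range(10))

    def calc_total_data_len(self):
        return 10

    def __getitem__(self, index):
        return self.data[index]


# pretend we run two DDP workers; shard_id would normally come from the process rank
for shard_id in (0, 1):
    it = ShardedDataIterator(ToyDataset(), shard_id=shard_id, num_shards=2, batch_size=3)
    it.load_data()  # computes shard boundaries, then loads the underlying dataset
    for batch in it.iterate_ds_data(epoch=0):
        print("shard", shard_id, "batch", batch)

Each iterator sees only its own half of the (shuffled) indices, which mirrors how train_dense_encoder.py partitions work across ranks.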