├── LICENSE ├── LICENSE-CODE ├── README.md ├── conformer.py ├── data.py ├── data_utils.py ├── factory.py ├── images └── CK.png ├── learner.py ├── learner_utils.py ├── loss.py ├── model.py ├── model_utils.py ├── parallel.py ├── run.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. 
More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution 4.0 International Public License 58 | 59 | By exercising the Licensed Rights (defined below), You accept and agree 60 | to be bound by the terms and conditions of this Creative Commons 61 | Attribution 4.0 International Public License ("Public License"). To the 62 | extent this Public License may be interpreted as a contract, You are 63 | granted the Licensed Rights in consideration of Your acceptance of 64 | these terms and conditions, and the Licensor grants You such rights in 65 | consideration of benefits the Licensor receives from making the 66 | Licensed Material available under these terms and conditions. 67 | 68 | 69 | Section 1 -- Definitions. 70 | 71 | a. Adapted Material means material subject to Copyright and Similar 72 | Rights that is derived from or based upon the Licensed Material 73 | and in which the Licensed Material is translated, altered, 74 | arranged, transformed, or otherwise modified in a manner requiring 75 | permission under the Copyright and Similar Rights held by the 76 | Licensor. For purposes of this Public License, where the Licensed 77 | Material is a musical work, performance, or sound recording, 78 | Adapted Material is always produced where the Licensed Material is 79 | synched in timed relation with a moving image. 80 | 81 | b. Adapter's License means the license You apply to Your Copyright 82 | and Similar Rights in Your contributions to Adapted Material in 83 | accordance with the terms and conditions of this Public License. 84 | 85 | c. Copyright and Similar Rights means copyright and/or similar rights 86 | closely related to copyright including, without limitation, 87 | performance, broadcast, sound recording, and Sui Generis Database 88 | Rights, without regard to how the rights are labeled or 89 | categorized. For purposes of this Public License, the rights 90 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 91 | Rights. 92 | 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i.
Share means to provide material to the public by any means or 116 | process that requires permission under the Licensed Rights, such 117 | as reproduction, public display, public performance, distribution, 118 | dissemination, communication, or importation, and to make material 119 | available to the public including in ways that members of the 120 | public may access the material from a place and at a time 121 | individually chosen by them. 122 | 123 | j. Sui Generis Database Rights means rights other than copyright 124 | resulting from Directive 96/9/EC of the European Parliament and of 125 | the Council of 11 March 1996 on the legal protection of databases, 126 | as amended and/or succeeded, as well as other essentially 127 | equivalent rights anywhere in the world. 128 | 129 | k. You means the individual or entity exercising the Licensed Rights 130 | under this Public License. Your has a corresponding meaning. 131 | 132 | 133 | Section 2 -- Scope. 134 | 135 | a. License grant. 136 | 137 | 1. Subject to the terms and conditions of this Public License, 138 | the Licensor hereby grants You a worldwide, royalty-free, 139 | non-sublicensable, non-exclusive, irrevocable license to 140 | exercise the Licensed Rights in the Licensed Material to: 141 | 142 | a. reproduce and Share the Licensed Material, in whole or 143 | in part; and 144 | 145 | b. produce, reproduce, and Share Adapted Material. 146 | 147 | 2. Exceptions and Limitations. For the avoidance of doubt, where 148 | Exceptions and Limitations apply to Your use, this Public 149 | License does not apply, and You do not need to comply with 150 | its terms and conditions. 151 | 152 | 3. Term. The term of this Public License is specified in Section 153 | 6(a). 154 | 155 | 4. Media and formats; technical modifications allowed. The 156 | Licensor authorizes You to exercise the Licensed Rights in 157 | all media and formats whether now known or hereafter created, 158 | and to make technical modifications necessary to do so. The 159 | Licensor waives and/or agrees not to assert any right or 160 | authority to forbid You from making technical modifications 161 | necessary to exercise the Licensed Rights, including 162 | technical modifications necessary to circumvent Effective 163 | Technological Measures. For purposes of this Public License, 164 | simply making modifications authorized by this Section 2(a) 165 | (4) never produces Adapted Material. 166 | 167 | 5. Downstream recipients. 168 | 169 | a. Offer from the Licensor -- Licensed Material. Every 170 | recipient of the Licensed Material automatically 171 | receives an offer from the Licensor to exercise the 172 | Licensed Rights under the terms and conditions of this 173 | Public License. 174 | 175 | b. No downstream restrictions. You may not offer or impose 176 | any additional or different terms or conditions on, or 177 | apply any Effective Technological Measures to, the 178 | Licensed Material if doing so restricts exercise of the 179 | Licensed Rights by any recipient of the Licensed 180 | Material. 181 | 182 | 6. No endorsement. Nothing in this Public License constitutes or 183 | may be construed as permission to assert or imply that You 184 | are, or that Your use of the Licensed Material is, connected 185 | with, or sponsored, endorsed, or granted official status by, 186 | the Licensor or others designated to receive attribution as 187 | provided in Section 3(a)(1)(A)(i). 188 | 189 | b. Other rights. 190 | 191 | 1. 
Moral rights, such as the right of integrity, are not 192 | licensed under this Public License, nor are publicity, 193 | privacy, and/or other similar personality rights; however, to 194 | the extent possible, the Licensor waives and/or agrees not to 195 | assert any such rights held by the Licensor to the limited 196 | extent necessary to allow You to exercise the Licensed 197 | Rights, but not otherwise. 198 | 199 | 2. Patent and trademark rights are not licensed under this 200 | Public License. 201 | 202 | 3. To the extent possible, the Licensor waives any right to 203 | collect royalties from You for the exercise of the Licensed 204 | Rights, whether directly or through a collecting society 205 | under any voluntary or waivable statutory or compulsory 206 | licensing scheme. In all other cases the Licensor expressly 207 | reserves any right to collect such royalties. 208 | 209 | 210 | Section 3 -- License Conditions. 211 | 212 | Your exercise of the Licensed Rights is expressly made subject to the 213 | following conditions. 214 | 215 | a. Attribution. 216 | 217 | 1. If You Share the Licensed Material (including in modified 218 | form), You must: 219 | 220 | a. retain the following if it is supplied by the Licensor 221 | with the Licensed Material: 222 | 223 | i. identification of the creator(s) of the Licensed 224 | Material and any others designated to receive 225 | attribution, in any reasonable manner requested by 226 | the Licensor (including by pseudonym if 227 | designated); 228 | 229 | ii. a copyright notice; 230 | 231 | iii. a notice that refers to this Public License; 232 | 233 | iv. a notice that refers to the disclaimer of 234 | warranties; 235 | 236 | v. a URI or hyperlink to the Licensed Material to the 237 | extent reasonably practicable; 238 | 239 | b. indicate if You modified the Licensed Material and 240 | retain an indication of any previous modifications; and 241 | 242 | c. indicate the Licensed Material is licensed under this 243 | Public License, and include the text of, or the URI or 244 | hyperlink to, this Public License. 245 | 246 | 2. You may satisfy the conditions in Section 3(a)(1) in any 247 | reasonable manner based on the medium, means, and context in 248 | which You Share the Licensed Material. For example, it may be 249 | reasonable to satisfy the conditions by providing a URI or 250 | hyperlink to a resource that includes the required 251 | information. 252 | 253 | 3. If requested by the Licensor, You must remove any of the 254 | information required by Section 3(a)(1)(A) to the extent 255 | reasonably practicable. 256 | 257 | 4. If You Share Adapted Material You produce, the Adapter's 258 | License You apply must not prevent recipients of the Adapted 259 | Material from complying with this Public License. 260 | 261 | 262 | Section 4 -- Sui Generis Database Rights. 263 | 264 | Where the Licensed Rights include Sui Generis Database Rights that 265 | apply to Your use of the Licensed Material: 266 | 267 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 268 | to extract, reuse, reproduce, and Share all or a substantial 269 | portion of the contents of the database; 270 | 271 | b. if You include all or a substantial portion of the database 272 | contents in a database in which You have Sui Generis Database 273 | Rights, then the database in which You have Sui Generis Database 274 | Rights (but not its individual contents) is Adapted Material; and 275 | 276 | c. 
You must comply with the conditions in Section 3(a) if You Share 277 | all or a substantial portion of the contents of the database. 278 | 279 | For the avoidance of doubt, this Section 4 supplements and does not 280 | replace Your obligations under this Public License where the Licensed 281 | Rights include other Copyright and Similar Rights. 282 | 283 | 284 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 285 | 286 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 287 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 288 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 289 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 290 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 291 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 292 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 293 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 294 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 295 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 296 | 297 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 298 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 299 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 300 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 301 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 302 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 303 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 304 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 305 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 306 | 307 | c. The disclaimer of warranties and limitation of liability provided 308 | above shall be interpreted in a manner that, to the extent 309 | possible, most closely approximates an absolute disclaimer and 310 | waiver of all liability. 311 | 312 | 313 | Section 6 -- Term and Termination. 314 | 315 | a. This Public License applies for the term of the Copyright and 316 | Similar Rights licensed here. However, if You fail to comply with 317 | this Public License, then Your rights under this Public License 318 | terminate automatically. 319 | 320 | b. Where Your right to use the Licensed Material has terminated under 321 | Section 6(a), it reinstates: 322 | 323 | 1. automatically as of the date the violation is cured, provided 324 | it is cured within 30 days of Your discovery of the 325 | violation; or 326 | 327 | 2. upon express reinstatement by the Licensor. 328 | 329 | For the avoidance of doubt, this Section 6(b) does not affect any 330 | right the Licensor may have to seek remedies for Your violations 331 | of this Public License. 332 | 333 | c. For the avoidance of doubt, the Licensor may also offer the 334 | Licensed Material under separate terms or conditions or stop 335 | distributing the Licensed Material at any time; however, doing so 336 | will not terminate this Public License. 337 | 338 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 339 | License. 340 | 341 | 342 | Section 7 -- Other Terms and Conditions. 343 | 344 | a. The Licensor shall not be bound by any additional or different 345 | terms or conditions communicated by You unless expressly agreed. 346 | 347 | b. 
Any arrangements, understandings, or agreements regarding the 348 | Licensed Material not stated herein are separate from and 349 | independent of the terms and conditions of this Public License. 350 | 351 | 352 | Section 8 -- Interpretation. 353 | 354 | a. For the avoidance of doubt, this Public License does not, and 355 | shall not be interpreted to, reduce, limit, restrict, or impose 356 | conditions on any use of the Licensed Material that could lawfully 357 | be made without permission under this Public License. 358 | 359 | b. To the extent possible, if any provision of this Public License is 360 | deemed unenforceable, it shall be automatically reformed to the 361 | minimum extent necessary to make it enforceable. If the provision 362 | cannot be reformed, it shall be severed from this Public License 363 | without affecting the enforceability of the remaining terms and 364 | conditions. 365 | 366 | c. No term or condition of this Public License will be waived and no 367 | failure to comply consented to unless expressly agreed to by the 368 | Licensor. 369 | 370 | d. Nothing in this Public License constitutes or may be interpreted 371 | as a limitation upon, or waiver of, any privileges and immunities 372 | that apply to the Licensor or You, including from the legal 373 | processes of any jurisdiction or authority. 374 | 375 | 376 | ======================================================================= 377 | 378 | Creative Commons is not a party to its public 379 | licenses. Notwithstanding, Creative Commons may elect to apply one of 380 | its public licenses to material it publishes and in those instances 381 | will be considered the “Licensor.” The text of the Creative Commons 382 | public licenses is dedicated to the public domain under the CC0 Public 383 | Domain Dedication. Except for the limited purpose of indicating that 384 | material is shared under a Creative Commons public license or as 385 | otherwise permitted by the Creative Commons policies published at 386 | creativecommons.org/policies, Creative Commons does not authorize the 387 | use of the trademark "Creative Commons" or any other trademark or logo 388 | of Creative Commons without its prior written consent including, 389 | without limitation, in connection with any unauthorized modifications 390 | to any of its public licenses or any other arrangements, 391 | understandings, or agreements concerning use of licensed material. For 392 | the avoidance of doubt, this paragraph does not form part of the 393 | public licenses. 394 | 395 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /LICENSE-CODE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # TREC Deep Learning Quick Start 2 | 3 | This is a quick start guide for the document ranking task in the TREC Deep Learning (TREC-DL) benchmark. 4 | If you are new to TREC-DL, then this repository may make it more convenient for you to download all the required datasets and then train and evaluate a relatively efficient deep neural baseline on this benchmark, under both the rerank and the fullrank settings. 5 | 6 | If you are unfamiliar with the TREC-DL benchmark, then you may want to first go through the websites and overview paper corresponding to previous and current editions of the track. 7 | * TREC-DL 2019: [website](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2019) and [overview paper](https://arxiv.org/pdf/2003.07820.pdf) 8 | * TREC-DL 2020: [website](https://microsoft.github.io/msmarco/TREC-Deep-Learning-2020) and [overview paper](https://arxiv.org/pdf/2102.07662.pdf) 9 | * TREC-DL (current): [website](https://microsoft.github.io/msmarco/TREC-Deep-Learning) 10 | 11 | ### DISCLAIMER 12 | While some of the contributors to this repository also serve as organizers for TREC-DL, please note that this code is **NOT** officially associated in any way with the TREC track. 13 | Instead, this is a personal codebase that we have been using for our own experimentation and we are releasing it publicly with the hope that it may be useful for others who are just starting out on this benchmark. 14 | 15 | As with any research code, you may find some kinks or bugs. 16 | Please report any and all bugs and issues you discover, and we will try to get to them as soon as possible. 17 | If you have any questions or feedback, please reach out to us via [email](mailto:bmitra@microsoft.com) or [Twitter](https://twitter.com/UnderdogGeek). 18 | 19 | Also, please be aware that we may sometimes push new changes and model updates based on personal on-going research and experimentation. 20 | 21 | 22 | ## The Conformer-Kernel Model with Query Term Independence (QTI) 23 | 24 | The base model implements the Conformer-Kernel architecture with QTI, as described in this [paper](https://arxiv.org/pdf/2007.10434.pdf). 
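Under QTI, the query-document relevance score is assumed to decompose additively over the query terms, with each term scored against the document independently of the other terms; this is what makes it possible to precompute per-term scores offline and serve full-collection retrieval from an inverted index. A minimal sketch of the scoring shape this implies (illustrative only; ```qti_score``` and ```term_score``` are hypothetical names, not this repository's actual API):

```
def qti_score(query_terms, doc, term_score):
    # Query term independence: the score of the full query is just the
    # sum of per-term scores, each computed in isolation.
    return sum(term_score(t, doc) for t in query_terms)
```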
25 | 
26 | ![The Conformer-Kernel architecture with QTI](images/CK.png)
27 | 
28 | If you use this code for your research, please cite the [paper](https://arxiv.org/pdf/2007.10434.pdf) as follows:
29 | 
30 | ```
31 | @article{mitra2020conformer,
32 |     title={Conformer-Kernel with Query Term Independence for Document Retrieval},
33 |     author={Mitra, Bhaskar and Hofst\"{a}tter, Sebastian and Zamani, Hamed and Craswell, Nick},
34 |     journal={arXiv preprint arXiv:2007.10434},
35 |     year={2020}
36 | }
37 | ```
38 | 
39 | Specifically, the code provides a choice between three existing models:
40 | * **NDRM1**: A Conformer-Kernel architecture with QTI for latent representation learning and matching
41 | * **NDRM2**: A simple learned BM25-like ranking function with QTI for explicit term matching
42 | * **NDRM3** (default): A linear combination of **NDRM1** and **NDRM2**
43 | 
44 | You can also plug in your own neural model by simply replacing ```model.py``` and ```model_utils.py``` with appropriate implementations corresponding to your model.
45 | The full retrieval evaluation assumes query term independence.
46 | If that assumption does not hold for your new model, please comment out the calls to ```evaluate_full_retrieval``` in ```learner.py```.
47 | 
48 | You can also enable the use of the [ORCAS dataset](https://microsoft.github.io/TREC-2020-Deep-Learning/ORCAS) as training data or as an additional document field by using the ```--orcas_train``` and ```--orcas_field``` arguments, respectively.
49 | 
50 | ## Requirements
51 | 
52 | The code in this repository has been tested with:
53 | * **Python version**: 3.5.5
54 | * **PyTorch version**: 1.3.1
55 | * **CUDA version**: 10.0.130
56 | 
57 | The training and evaluation were performed using **4 Tesla P100 GPUs** with 16280 MiB memory each.
58 | Depending on your GPU availability, you may need to set the minibatch sizes accordingly for train (```--mb_size_train```), test (```--mb_size_test```), and inference (```--mb_size_infer```); see the example invocation at the end of the Getting Started section below.
59 | 
60 | In addition, the code assumes the following Python packages are installed:
61 | * numpy
62 | * fasttext
63 | * krovetzstemmer
64 | * clint
65 | 
66 | Using pip, you can install all of them by running the following from the command line:
67 | 
68 | ```
69 | pip install numpy fasttext krovetzstemmer clint
70 | ```
71 | 
72 | ## Getting Started
73 | 
74 | Please clone the repo and run ```python run.py```.
75 | 
76 | The script should automatically download all necessary data files, if missing, which can take a significant amount of time depending on your network speed.
77 | If the download fails for any particular file, then please delete the local incomplete copy and re-run the script.
78 | The script performs pretty aggressive text normalization that may not always be appropriate.
79 | Please be aware of this and modify the code if you desire a different behaviour.
80 | 
81 | After the download completes, the script should first pretrain a word2vec model for the input embeddings.
82 | Subsequently, it should train a simple neural document ranking model (NDRM) and report metrics on the TREC-DL 2019 test set for both the reranking and the fullranking tasks.
83 | The script should also prepare the run files corresponding to the TREC-DL 2020 test set for submission.
84 | 
85 | A couple of additional notes:
86 | * The code automatically downloads the whole ORCAS dataset.
87 | We plan to make this optional in the future but have not gotten around to implementing it yet.
88 | In the meantime, please feel free to disable the download directly in the code if you don't plan to use the ORCAS data.
89 | * The IDF file is generated conservatively only for terms that appear in the train, dev, validation, and test queries.
90 | So, if you change or add to the query files, then please delete the generated IDF file and rerun the script to regenerate it.
91 | 
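Finally, as referenced in the Requirements section above, here is an example invocation that overrides the default minibatch sizes (the argument names are the ones documented in this README; the specific values are only illustrative and should be sized to fit your GPU memory):

```
python run.py --mb_size_train 32 --mb_size_test 64 --mb_size_infer 256
```

If you hit out-of-memory errors, lower these values and re-run the script.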
92 | ## Legal Notices
93 | 
94 | Microsoft and any contributors grant you a license to the Microsoft documentation and other content in this repository under the [Creative Commons Attribution 4.0 International Public License](https://creativecommons.org/licenses/by/4.0/legalcode), see the [LICENSE](LICENSE) file, and grant you a license to any code in the repository under the [MIT License](https://opensource.org/licenses/MIT), see the [LICENSE-CODE](LICENSE-CODE) file.
95 | 
96 | Microsoft, Windows, Microsoft Azure and/or other Microsoft products and services referenced in the documentation
97 | may be either trademarks or registered trademarks of Microsoft in the United States and/or other countries.
98 | The licenses for this project do not grant you rights to use any Microsoft names, logos, or trademarks.
99 | Microsoft's general trademark guidelines can be found at http://go.microsoft.com/fwlink/?LinkID=254653.
100 | 
101 | Privacy information can be found at https://privacy.microsoft.com/en-us/.
102 | 
103 | Microsoft and any contributors reserve all other rights, whether under their respective copyrights, patents,
104 | or trademarks, whether by implication, estoppel or otherwise.
105 | 
--------------------------------------------------------------------------------
/conformer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import math
3 | import datetime
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | 
8 | 
9 | def multi_head_sep_attention_forward(query, # type: Tensor
10 |                                      key, # type: Tensor
11 |                                      value, # type: Tensor
12 |                                      embed_dim_to_check, # type: int
13 |                                      num_heads, # type: int
14 |                                      in_proj_weight, # type: Tensor
15 |                                      in_proj_bias, # type: Tensor
16 |                                      bias_k, # type: Optional[Tensor]
17 |                                      bias_v, # type: Optional[Tensor]
18 |                                      add_zero_attn, # type: bool
19 |                                      dropout_p, # type: float
20 |                                      out_proj_weight, # type: Tensor
21 |                                      out_proj_bias, # type: Tensor
22 |                                      training=True, # type: bool
23 |                                      key_padding_mask=None, # type: Optional[Tensor]
24 |                                      need_weights=True, # type: bool
25 |                                      attn_mask=None, # type: Optional[Tensor]
26 |                                      use_separate_proj_weight=False, # type: bool
27 |                                      q_proj_weight=None, # type: Optional[Tensor]
28 |                                      k_proj_weight=None, # type: Optional[Tensor]
29 |                                      v_proj_weight=None, # type: Optional[Tensor]
30 |                                      static_k=None, # type: Optional[Tensor]
31 |                                      static_v=None, # type: Optional[Tensor]
32 |                                      shared_qk=False, # type: bool
33 |                                      sep=False # type: bool
34 |                                      ):
35 |     if not torch.jit.is_scripting():
36 |         tens_ops = (query, key, value, in_proj_weight, in_proj_bias, bias_k, bias_v,
37 |                     out_proj_weight, out_proj_bias)
38 |         #if any([type(t) is not Tensor for t in tens_ops]) and has_torch_function(tens_ops):
39 |         #    return handle_torch_function(
40 |         #        multi_head_sep_attention_forward, tens_ops, query, key, value,
41 |         #        embed_dim_to_check, num_heads, in_proj_weight, in_proj_bias,
42 |         #        bias_k, bias_v, add_zero_attn, dropout_p, out_proj_weight,
43 |         #        out_proj_bias, training=training, key_padding_mask=key_padding_mask,
44 |         #        need_weights=need_weights, attn_mask=attn_mask,
45 |         #        use_separate_proj_weight=use_separate_proj_weight,
46 |         #        q_proj_weight=q_proj_weight,
k_proj_weight=k_proj_weight, 47 | # v_proj_weight=v_proj_weight, static_k=static_k, static_v=static_v) 48 | tgt_len, bsz, embed_dim = query.size() 49 | assert embed_dim == embed_dim_to_check 50 | assert key.size() == value.size() 51 | 52 | head_dim = embed_dim // num_heads 53 | assert head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads" 54 | scaling = float(head_dim) ** -0.5 55 | 56 | if not use_separate_proj_weight: 57 | if torch.equal(query, key) and torch.equal(key, value): 58 | # self-attention 59 | if shared_qk: 60 | k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(2, dim=-1) 61 | q = k 62 | else: 63 | q, k, v = F.linear(query, in_proj_weight, in_proj_bias).chunk(3, dim=-1) 64 | 65 | elif torch.equal(key, value): 66 | # encoder-decoder attention 67 | # This is inline in_proj function with in_proj_weight and in_proj_bias 68 | _b = in_proj_bias 69 | _start = 0 70 | _end = embed_dim 71 | _w = in_proj_weight[_start:_end, :] 72 | if _b is not None: 73 | _b = _b[_start:_end] 74 | q = F.linear(query, _w, _b) 75 | 76 | if key is None: 77 | assert value is None 78 | k = None 79 | v = None 80 | else: 81 | 82 | # This is inline in_proj function with in_proj_weight and in_proj_bias 83 | _b = in_proj_bias 84 | _start = embed_dim 85 | _end = None 86 | _w = in_proj_weight[_start:, :] 87 | if _b is not None: 88 | _b = _b[_start:] 89 | k, v = F.linear(key, _w, _b).chunk(2, dim=-1) 90 | 91 | else: 92 | # This is inline in_proj function with in_proj_weight and in_proj_bias 93 | _b = in_proj_bias 94 | _start = 0 95 | _end = embed_dim 96 | _w = in_proj_weight[_start:_end, :] 97 | if _b is not None: 98 | _b = _b[_start:_end] 99 | q = F.linear(query, _w, _b) 100 | 101 | # This is inline in_proj function with in_proj_weight and in_proj_bias 102 | _b = in_proj_bias 103 | _start = embed_dim 104 | _end = embed_dim * 2 105 | _w = in_proj_weight[_start:_end, :] 106 | if _b is not None: 107 | _b = _b[_start:_end] 108 | k = F.linear(key, _w, _b) 109 | 110 | # This is inline in_proj function with in_proj_weight and in_proj_bias 111 | _b = in_proj_bias 112 | _start = embed_dim * 2 113 | _end = None 114 | _w = in_proj_weight[_start:, :] 115 | if _b is not None: 116 | _b = _b[_start:] 117 | v = F.linear(value, _w, _b) 118 | else: 119 | q_proj_weight_non_opt = torch.jit._unwrap_optional(q_proj_weight) 120 | len1, len2 = q_proj_weight_non_opt.size() 121 | assert len1 == embed_dim and len2 == query.size(-1) 122 | 123 | k_proj_weight_non_opt = torch.jit._unwrap_optional(k_proj_weight) 124 | len1, len2 = k_proj_weight_non_opt.size() 125 | assert len1 == embed_dim and len2 == key.size(-1) 126 | 127 | v_proj_weight_non_opt = torch.jit._unwrap_optional(v_proj_weight) 128 | len1, len2 = v_proj_weight_non_opt.size() 129 | assert len1 == embed_dim and len2 == value.size(-1) 130 | 131 | if in_proj_bias is not None: 132 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias[0:embed_dim]) 133 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias[embed_dim:(embed_dim * 2)]) 134 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias[(embed_dim * 2):]) 135 | else: 136 | q = F.linear(query, q_proj_weight_non_opt, in_proj_bias) 137 | k = F.linear(key, k_proj_weight_non_opt, in_proj_bias) 138 | v = F.linear(value, v_proj_weight_non_opt, in_proj_bias) 139 | q = q * scaling 140 | 141 | if attn_mask is not None: 142 | if attn_mask.dim() == 2: 143 | attn_mask = attn_mask.unsqueeze(0) 144 | if list(attn_mask.size()) != [1, query.size(0), key.size(0)]: 145 | raise RuntimeError('The size of the 2D 
attn_mask is not correct.')
146 |         elif attn_mask.dim() == 3:
147 |             if list(attn_mask.size()) != [bsz * num_heads, query.size(0), key.size(0)]:
148 |                 raise RuntimeError('The size of the 3D attn_mask is not correct.')
149 |         else:
150 |             raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
151 |         # attn_mask's dim is 3 now.
152 | 
153 |     if bias_k is not None and bias_v is not None:
154 |         if static_k is None and static_v is None:
155 |             k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
156 |             v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
157 |             if attn_mask is not None:
158 |                 attn_mask = F.pad(attn_mask, (0, 1))
159 |             if key_padding_mask is not None:
160 |                 key_padding_mask = F.pad(key_padding_mask, (0, 1))
161 |         else:
162 |             assert static_k is None, "bias cannot be added to static key."
163 |             assert static_v is None, "bias cannot be added to static value."
164 |     else:
165 |         assert bias_k is None
166 |         assert bias_v is None
167 | 
168 |     q = q.contiguous().view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
169 |     if k is not None:
170 |         k = k.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
171 |     if v is not None:
172 |         v = v.contiguous().view(-1, bsz * num_heads, head_dim).transpose(0, 1)
173 | 
174 |     if static_k is not None:
175 |         assert static_k.size(0) == bsz * num_heads
176 |         assert static_k.size(2) == head_dim
177 |         k = static_k
178 | 
179 |     if static_v is not None:
180 |         assert static_v.size(0) == bsz * num_heads
181 |         assert static_v.size(2) == head_dim
182 |         v = static_v
183 | 
184 |     src_len = k.size(1)
185 | 
186 |     if key_padding_mask is not None:
187 |         assert key_padding_mask.size(0) == bsz
188 |         assert key_padding_mask.size(1) == src_len
189 | 
190 |     if add_zero_attn:
191 |         src_len += 1
192 |         k = torch.cat([k, torch.zeros((k.size(0), 1) + k.size()[2:], dtype=k.dtype, device=k.device)], dim=1)
193 |         v = torch.cat([v, torch.zeros((v.size(0), 1) + v.size()[2:], dtype=v.dtype, device=v.device)], dim=1)
194 |         if attn_mask is not None:
195 |             attn_mask = F.pad(attn_mask, (0, 1))
196 |         if key_padding_mask is not None:
197 |             key_padding_mask = F.pad(key_padding_mask, (0, 1))
198 | 
199 |     if sep:  # separable attention: pool the values over the keys into a head_dim x head_dim summary, then mix it per query position; cost is linear in sequence length instead of quadratic
200 |         attn_output_weights = k.transpose(1, 2)
201 |         assert list(attn_output_weights.size()) == [bsz * num_heads, head_dim, src_len]
202 | 
203 |         if key_padding_mask is not None:
204 |             attn_output_weights = attn_output_weights.view(bsz, num_heads, head_dim, src_len)
205 |             attn_output_weights = attn_output_weights.masked_fill(
206 |                 key_padding_mask.unsqueeze(1).unsqueeze(2),
207 |                 float('-inf'),
208 |             )
209 |             attn_output_weights = attn_output_weights.view(bsz * num_heads, head_dim, src_len)
210 | 
211 |         attn_output_weights = F.softmax(
212 |             attn_output_weights, dim=-1)
213 |         attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
214 | 
215 |         attn_output = torch.bmm(attn_output_weights, v)
216 |         assert list(attn_output.size()) == [bsz * num_heads, head_dim, head_dim]
217 | 
218 |         attn_output_weights = F.softmax(
219 |             q, dim=-1)
220 |         attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training)
221 | 
222 |         attn_output = torch.bmm(attn_output_weights, attn_output)
223 |         assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim]
224 |         attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
225 |         attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias)
226 |         return attn_output, None
227 |     else:
228 |         attn_output_weights = torch.bmm(q, k.transpose(1, 2))
229 |         assert
list(attn_output_weights.size()) == [bsz * num_heads, tgt_len, src_len] 230 | 231 | if attn_mask is not None: 232 | attn_output_weights += attn_mask 233 | 234 | if key_padding_mask is not None: 235 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 236 | attn_output_weights = attn_output_weights.masked_fill( 237 | key_padding_mask.unsqueeze(1).unsqueeze(2), 238 | float('-inf'), 239 | ) 240 | attn_output_weights = attn_output_weights.view(bsz * num_heads, tgt_len, src_len) 241 | 242 | attn_output_weights = F.softmax( 243 | attn_output_weights, dim=-1) 244 | attn_output_weights = F.dropout(attn_output_weights, p=dropout_p, training=training) 245 | 246 | attn_output = torch.bmm(attn_output_weights, v) 247 | assert list(attn_output.size()) == [bsz * num_heads, tgt_len, head_dim] 248 | attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim) 249 | attn_output = F.linear(attn_output, out_proj_weight, out_proj_bias) 250 | 251 | if need_weights: 252 | # average attention weights over heads 253 | attn_output_weights = attn_output_weights.view(bsz, num_heads, tgt_len, src_len) 254 | return attn_output, attn_output_weights.sum(dim=1) / num_heads 255 | else: 256 | return attn_output, None 257 | 258 | 259 | class MultiheadSeparableAttention(nn.Module): 260 | __annotations__ = { 261 | 'bias_k': torch._jit_internal.Optional[torch.Tensor], 262 | 'bias_v': torch._jit_internal.Optional[torch.Tensor], 263 | } 264 | __constants__ = ['q_proj_weight', 'k_proj_weight', 'v_proj_weight', 'in_proj_weight'] 265 | 266 | def __init__(self, embed_dim, num_heads, dropout=0., bias=True, add_bias_kv=False, add_zero_attn=False, kdim=None, vdim=None, shared_qk=False, sep=False): 267 | super(MultiheadSeparableAttention, self).__init__() 268 | self.embed_dim = embed_dim 269 | self.kdim = kdim if kdim is not None else embed_dim 270 | self.vdim = vdim if vdim is not None else embed_dim 271 | self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim 272 | 273 | self.num_heads = num_heads 274 | self.dropout = dropout 275 | self.head_dim = embed_dim // num_heads 276 | assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads" 277 | 278 | self.shared_qk = shared_qk 279 | self.sep = sep 280 | 281 | if self._qkv_same_embed_dim is False: 282 | self.q_proj_weight = nn.Parameter(torch.Tensor(embed_dim, embed_dim)) 283 | self.k_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.kdim)) 284 | self.v_proj_weight = nn.Parameter(torch.Tensor(embed_dim, self.vdim)) 285 | self.register_parameter('in_proj_weight', None) 286 | else: 287 | if shared_qk: 288 | self.in_proj_weight = nn.Parameter(torch.empty(2 * embed_dim, embed_dim)) 289 | else: 290 | self.in_proj_weight = nn.Parameter(torch.empty(3 * embed_dim, embed_dim)) 291 | self.register_parameter('q_proj_weight', None) 292 | self.register_parameter('k_proj_weight', None) 293 | self.register_parameter('v_proj_weight', None) 294 | 295 | if bias: 296 | if shared_qk: 297 | self.in_proj_bias = nn.Parameter(torch.empty(2 * embed_dim)) 298 | else: 299 | self.in_proj_bias = nn.Parameter(torch.empty(3 * embed_dim)) 300 | else: 301 | self.register_parameter('in_proj_bias', None) 302 | self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) 303 | 304 | if add_bias_kv: 305 | self.bias_k = nn.Parameter(torch.empty(1, 1, embed_dim)) 306 | self.bias_v = nn.Parameter(torch.empty(1, 1, embed_dim)) 307 | else: 308 | self.bias_k = self.bias_v = None 309 | 310 | self.add_zero_attn = 
add_zero_attn 311 | 312 | 313 | self._reset_parameters() 314 | 315 | def _reset_parameters(self): 316 | if self._qkv_same_embed_dim: 317 | nn.init.xavier_uniform_(self.in_proj_weight) 318 | else: 319 | nn.init.xavier_uniform_(self.q_proj_weight) 320 | nn.init.xavier_uniform_(self.k_proj_weight) 321 | nn.init.xavier_uniform_(self.v_proj_weight) 322 | 323 | if self.in_proj_bias is not None: 324 | nn.init.constant_(self.in_proj_bias, 0.) 325 | nn.init.constant_(self.out_proj.bias, 0.) 326 | if self.bias_k is not None: 327 | nn.init.xavier_normal_(self.bias_k) 328 | if self.bias_v is not None: 329 | nn.init.xavier_normal_(self.bias_v) 330 | 331 | def __setstate__(self, state): 332 | # Support loading old MultiheadSeparableAttention checkpoints generated by v1.1.0 333 | if '_qkv_same_embed_dim' not in state: 334 | state['_qkv_same_embed_dim'] = True 335 | 336 | super(MultiheadSeparableAttention, self).__setstate__(state) 337 | 338 | def forward(self, query, key, value, key_padding_mask=None, 339 | need_weights=True, attn_mask=None): 340 | if not self._qkv_same_embed_dim: 341 | return multi_head_sep_attention_forward( 342 | query, key, value, self.embed_dim, self.num_heads, 343 | self.in_proj_weight, self.in_proj_bias, 344 | self.bias_k, self.bias_v, self.add_zero_attn, 345 | self.dropout, self.out_proj.weight, self.out_proj.bias, 346 | training=self.training, 347 | key_padding_mask=key_padding_mask, need_weights=need_weights, 348 | attn_mask=attn_mask, use_separate_proj_weight=True, 349 | q_proj_weight=self.q_proj_weight, k_proj_weight=self.k_proj_weight, 350 | v_proj_weight=self.v_proj_weight, shared_qk=self.shared_qk, sep=self.sep) 351 | else: 352 | return multi_head_sep_attention_forward( 353 | query, key, value, self.embed_dim, self.num_heads, 354 | self.in_proj_weight, self.in_proj_bias, 355 | self.bias_k, self.bias_v, self.add_zero_attn, 356 | self.dropout, self.out_proj.weight, self.out_proj.bias, 357 | training=self.training, 358 | key_padding_mask=key_padding_mask, need_weights=need_weights, 359 | attn_mask=attn_mask, shared_qk=self.shared_qk, sep=self.sep) 360 | 361 | 362 | class ConformerEncoderLayer(nn.Module): 363 | 364 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu", attn=False, conv=False, convsz=1, shared_qk=False, sep=False): 365 | super(ConformerEncoderLayer, self).__init__() 366 | self.conv = conv 367 | if conv: 368 | assert convsz % 2 == 1 # kernel size should be an odd number 369 | self.conv_layer = nn.Conv1d(d_model, d_model, convsz, padding=int((convsz-1)/2), groups=nhead) 370 | self.attn = attn 371 | if attn: 372 | self.self_attn = MultiheadSeparableAttention(d_model, nhead, dropout=dropout, shared_qk=shared_qk, sep=sep) 373 | # Implementation of Feedforward model 374 | self.linear1 = nn.Linear(d_model, dim_feedforward) 375 | self.dropout = nn.Dropout(dropout) 376 | self.linear2 = nn.Linear(dim_feedforward, d_model) 377 | 378 | self.norm1 = nn.LayerNorm(d_model) 379 | self.norm2 = nn.LayerNorm(d_model) 380 | self.dropout1 = nn.Dropout(dropout) 381 | self.dropout2 = nn.Dropout(dropout) 382 | 383 | self.activation = _get_activation_fn(activation) 384 | 385 | def __setstate__(self, state): 386 | if 'activation' not in state: 387 | state['activation'] = F.relu 388 | super(ConformerEncoderLayer, self).__setstate__(state) 389 | 390 | def forward(self, src, src_mask=None, src_key_padding_mask=None): 391 | if self.conv: 392 | src2 = src.permute(1, 2, 0) 393 | src2 = self.conv_layer(src2) 394 | src2 = src2.permute(2, 0, 1) 395 | else: 396 
| src2 = src 397 | if self.attn: 398 | src2 = self.self_attn(src2, src2, src2, attn_mask=src_mask, 399 | key_padding_mask=src_key_padding_mask)[0] 400 | src = src + self.dropout1(src2) 401 | src = self.norm1(src) 402 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src)))) 403 | src = src + self.dropout2(src2) 404 | src = self.norm2(src) 405 | return src 406 | 407 | 408 | def _get_clones(module, N): 409 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 410 | 411 | 412 | def _get_activation_fn(activation): 413 | if activation == "relu": 414 | return F.relu 415 | elif activation == "gelu": 416 | return F.gelu 417 | 418 | raise RuntimeError("activation should be relu/gelu, not {}".format(activation)) 419 | 420 | 421 | class ConformerEncoder(nn.Module): 422 | __constants__ = ['norm'] 423 | 424 | def __init__(self, encoder_layer, num_layers, norm=None): 425 | super(ConformerEncoder, self).__init__() 426 | self.layers = _get_clones(encoder_layer, num_layers) 427 | self.num_layers = num_layers 428 | self.norm = norm 429 | 430 | def forward(self, src, mask=None, src_key_padding_mask=None): 431 | output = src 432 | 433 | for mod in self.layers: 434 | output = mod(output, src_mask=mask, src_key_padding_mask=src_key_padding_mask) 435 | 436 | if self.norm is not None: 437 | output = self.norm(output) 438 | 439 | return output 440 | 441 | 442 | class PositionalEncoding(nn.Module): 443 | 444 | def __init__(self, d_model, dropout=0.1, max_len=5000): 445 | super(PositionalEncoding, self).__init__() 446 | self.dropout = nn.Dropout(p=dropout) 447 | pe = torch.zeros(max_len, d_model) 448 | position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) 449 | div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) 450 | pe[:, 0::2] = torch.sin(position * div_term) 451 | pe[:, 1::2] = torch.cos(position * div_term) 452 | pe = pe.unsqueeze(0).transpose(0, 1) 453 | self.register_buffer('pe', pe) 454 | 455 | def forward(self, x): 456 | x = x + self.pe[:x.size(0), :] 457 | return self.dropout(x) 458 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import random 4 | import torch 5 | from torch.utils import data 6 | from collections import Counter 7 | import numpy as np 8 | 9 | 10 | class TRECSupervisedTrainMultiDataset(data.Dataset): 11 | 12 | def __init__(self, datasets): 13 | self.datasets = datasets 14 | self.num_datasets = len(datasets) 15 | self.max_dataset_len = max([len(dataset) for dataset in self.datasets]) 16 | 17 | def __len__(self): 18 | return self.max_dataset_len * self.num_datasets 19 | 20 | def __getitem__(self, idx): 21 | dataset = self.datasets[idx // self.max_dataset_len] 22 | idx = idx % self.max_dataset_len 23 | idx = idx % len(dataset) 24 | return dataset[idx] 25 | 26 | 27 | class TRECSupervisedTrainDataset(data.Dataset): 28 | 29 | def __init__(self, query_set, featurize, utils): 30 | self.featurize = featurize 31 | self.utils = utils 32 | self.f_docs = open(os.path.join(utils.args.local_dir, utils.args.file_in_docs), 'rt', encoding='utf8') 33 | self.f_orcas = open(os.path.join(utils.args.local_dir, utils.args.file_gen_orcas_docs), 'rt', encoding='utf8') 34 | self.qids = getattr(utils, 'qids_{}'.format(query_set)) 35 | self.cand = getattr(utils, 'cand_{}'.format(query_set)) 36 | self.qrels = getattr(utils, 'qrels_{}'.format(query_set)) 37 | 38 | def __len__(self): 39 | 
return len(self.qids) 40 | 41 | def __getitem__(self, idx): 42 | qid = self.qids[idx] 43 | q = self.utils.qs[qid] 44 | dids = self.cand[qid] 45 | qrels = self.qrels.get(qid, {}) 46 | labeled = {} 47 | for did in dids: 48 | label = qrels.get(did, 0) + 1 49 | if label not in labeled: 50 | labeled[label] = [] 51 | labeled[label].append(did) 52 | if len(labeled) > 1: 53 | sampled_labels = sorted(random.sample(labeled.keys(), 2), reverse=True) 54 | else: 55 | sampled_labels = list(labeled.keys()) 56 | cands = [(random.sample(labeled[label], 1)[0], label) for label in sampled_labels] 57 | num_rand_negs = self.utils.args.num_rand_negs + 2 - len(cands) 58 | if num_rand_negs > 0: 59 | cands += [(did, 0) for did in random.sample(self.utils.dids, num_rand_negs)] 60 | dids = [x[0] for x in cands] 61 | ds = [self.utils.get_doc_content(self.f_docs, self.f_orcas, did) for did in dids] 62 | labels = torch.FloatTensor(np.asarray([x[1] for x in cands], dtype=np.float32)) 63 | features = self.featurize(q, ds) 64 | return (qid, dids, labels, features) 65 | 66 | 67 | class TRECSupervisedTestDataset(data.Dataset): 68 | 69 | def __init__(self, query_set, featurize, utils): 70 | self.featurize = featurize 71 | self.utils = utils 72 | self.cand = getattr(utils, 'cand_{}'.format(query_set)) 73 | self.cand = [[(qid, did) for did in cands] for qid,cands in self.cand.items()] 74 | self.cand = [item for sublist in self.cand for item in sublist] 75 | self.qrels = getattr(utils, 'qrels_{}'.format(query_set)) if hasattr(utils, 'qrels_{}'.format(query_set)) else {} 76 | self.f_docs = open(os.path.join(utils.args.local_dir, utils.args.file_in_docs), 'rt', encoding='utf8') 77 | self.f_orcas = open(os.path.join(utils.args.local_dir, utils.args.file_gen_orcas_docs), 'rt', encoding='utf8') 78 | 79 | def __len__(self): 80 | return len(self.cand) 81 | 82 | def __getitem__(self, idx): 83 | qid, did = self.cand[idx] 84 | qrels = self.qrels.get(qid, {}) 85 | label = qrels.get(did, 0) 86 | cand = [(did, label)] 87 | labels = torch.FloatTensor(np.asarray([label], dtype=np.float32)) 88 | features = self.featurize(self.utils.qs[qid], [self.utils.get_doc_content(self.f_docs, self.f_orcas, did)]) 89 | return (qid, did, labels, features) 90 | 91 | 92 | class TRECInferenceDataset(data.IterableDataset): 93 | 94 | def __init__(self, query_set, featurize, tokenize, utils): 95 | self.featurize = featurize 96 | self.utils = utils 97 | self.qids = getattr(utils, 'qids_{}'.format(query_set)) 98 | self.qid_to_terms = {qid: tokenize(utils.qs[qid])[:utils.args.max_terms_query] for qid in self.qids} 99 | self.query_terms = list(set([item for sublist in self.qid_to_terms.values() for item in sublist])) 100 | self.query = ' '.join(self.query_terms) 101 | self.f_docs = open(os.path.join(utils.args.local_dir, utils.args.file_in_docs), 'rt', encoding='utf8') 102 | self.f_orcas = open(os.path.join(utils.args.local_dir, utils.args.file_gen_orcas_docs), 'rt', encoding='utf8') 103 | 104 | def __iter__(self): 105 | for did in self.utils.dids: 106 | features = self.featurize(self.query, [self.utils.get_doc_content(self.f_docs, self.f_orcas, did)], infer_mode=True) 107 | yield (did, features) 108 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import csv 3 | import sys 4 | import gzip 5 | import random 6 | import shutil 7 | import tarfile 8 | import requests 9 | import numpy as np 10 | from clint.textui import 
progress
11 | from requests.adapters import HTTPAdapter
12 | from requests.exceptions import ConnectionError
13 | 
14 | 
15 | class DataUtils:
16 | 
17 |     def __init__(self, printer):
18 |         self.printer = printer
19 |         self.prenorm_file_cols = {'file_in_docs': [1, 2, 3], 'file_in_orcas': [1], 'file_in_qs_train': [1], 'file_in_qs_dev': [1], 'file_in_qs_val': [1], 'file_in_qs_test': [1], 'file_in_qs_orcas': [1]}
20 |         csv.field_size_limit(sys.maxsize)
21 | 
22 |     def parser_add_args(self, parser):
23 |         parser.add_argument('--local_dir', default='data/', help='root directory for data files (default: data/)')
24 |         parser.add_argument('--web_dir', default='https://msmarco.blob.core.windows.net/msmarcoranking/', help='root URL for downloading data files (default: https://msmarco.blob.core.windows.net/msmarcoranking/)')
25 |         parser.add_argument('--file_in_docs', default='msmarco-docs.tsv', help='filename for document collection (default: msmarco-docs.tsv)')
26 |         parser.add_argument('--file_in_orcas', default='orcas.tsv', help='filename for orcas data (default: orcas.tsv)')
27 |         parser.add_argument('--file_in_qs_train', default='msmarco-doctrain-queries.tsv', help='filename for train queries (default: msmarco-doctrain-queries.tsv)')
28 |         parser.add_argument('--file_in_qs_dev', default='msmarco-docdev-queries.tsv', help='filename for development queries (default: msmarco-docdev-queries.tsv)')
29 |         parser.add_argument('--file_in_qs_val', default='msmarco-test2019-queries.tsv', help='filename for validation queries (default: msmarco-test2019-queries.tsv)')
30 |         parser.add_argument('--file_in_qs_test', default='msmarco-test2020-queries.tsv', help='filename for test queries (default: msmarco-test2020-queries.tsv)')
31 |         parser.add_argument('--file_in_qs_orcas', default='orcas-doctrain-queries.tsv', help='filename for orcas queries (default: orcas-doctrain-queries.tsv)')
32 |         parser.add_argument('--file_in_cnd_train', default='msmarco-doctrain-top100', help='filename for top 100 train candidates (default: msmarco-doctrain-top100)')
33 |         parser.add_argument('--file_in_cnd_dev', default='msmarco-docdev-top100', help='filename for top 100 dev candidates (default: msmarco-docdev-top100)')
34 |         parser.add_argument('--file_in_cnd_val', default='msmarco-doctest2019-top100', help='filename for top 100 validation candidates (default: msmarco-doctest2019-top100)')
35 |         parser.add_argument('--file_in_cnd_test', default='msmarco-doctest2020-top100', help='filename for top 100 test candidates (default: msmarco-doctest2020-top100)')
36 |         parser.add_argument('--file_in_cnd_orcas', default='orcas-doctrain-top100', help='filename for orcas candidates (default: orcas-doctrain-top100)')
37 |         parser.add_argument('--file_in_qrel_train', default='msmarco-doctrain-qrels.tsv', help='filename for train qrels (default: msmarco-doctrain-qrels.tsv)')
38 |         parser.add_argument('--file_in_qrel_dev', default='msmarco-docdev-qrels.tsv', help='filename for dev qrels (default: msmarco-docdev-qrels.tsv)')
39 |         parser.add_argument('--file_in_qrel_val', default='2019qrels-docs.txt', help='filename for validation qrels (default: 2019qrels-docs.txt)')
40 |         parser.add_argument('--file_in_qrel_orcas', default='orcas-doctrain-qrels.tsv', help='filename for orcas qrels (default: orcas-doctrain-qrels.tsv)')
41 |         parser.add_argument('--file_gen_docs_lookup', default='lookup-docs-norm.tsv', help='filename for document offsets for collection (default: lookup-docs-norm.tsv)')
42 |         parser.add_argument('--file_gen_orcas_docs', default='orcas-docs.tsv', help='filename for orcas field (default: orcas-docs.tsv)')
43 |         parser.add_argument('--file_gen_orcas_docs_lookup', default='lookup-orcas-docs-norm.tsv', help='filename for document offsets for orcas field (default: lookup-orcas-docs-norm.tsv)')
44 |         parser.add_argument('--num_fields', default=4, help='number of fields per document (default: 4)', type=int)
45 |         parser.add_argument('--num_dev_queries', default=100, help='number of queries to sample for dev set (default: 100)', type=int)
46 | 
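    # Note (illustrative, not from the original source): with the defaults
    # above, an input file such as msmarco-docs.tsv is expected under
    # --local_dir (data/) and, if missing, is fetched from the web
    # (presumably --web_dir + file name), e.g.
    #     https://msmarco.blob.core.windows.net/msmarcoranking/msmarco-docs.tsv
    # The file_gen_* files are checked the same way but are otherwise derived
    # locally from the inputs (see __generate_lookup and
    # __generate_orcas_field below).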
47 |     def parser_validate_args(self, args):
48 |         self.args = args
49 |         if not os.path.exists(args.local_dir):
50 |             os.makedirs(args.local_dir)
51 | 
52 |     def setup_and_verify(self):
53 |         self.__verify_in_data()
54 |         self.__verify_gen_data()
55 |         self.__preload_data_to_memory()
56 | 
57 |     def get_doc_content(self, f_docs, f_orcas, did):
58 |         if did == '':
59 |             return [''] * self.args.num_fields
60 |         f_docs.seek(self.doc_offsets[did])
61 |         line = f_docs.readline()
62 |         assert line.startswith(did + "\t"), 'looking for {} at position {}, found {}'.format(did, self.doc_offsets[did], line)
63 |         field_values = line.split('\t')[1:]
64 |         if did in self.orcas_docs_offsets:
65 |             f_orcas.seek(self.orcas_docs_offsets[did])
66 |             line = f_orcas.readline()
67 |             assert line.startswith(did + "\t"), 'looking for {} at position {}, found {}'.format(did, self.orcas_docs_offsets[did], line)
68 |             orcas_field = line.split('\t')[1]
69 |         else:
70 |             orcas_field = ''
71 |         field_values.append(orcas_field)
72 |         return field_values
73 | 
74 |     def __preload_data_to_memory(self):
75 |         self.printer.print('preloading data to memory')
76 |         self.doc_offsets = self.__get_doc_offsets(os.path.join(self.args.local_dir, self.args.file_gen_docs_lookup))
77 |         self.orcas_docs_offsets = self.__get_doc_offsets(os.path.join(self.args.local_dir, self.args.file_gen_orcas_docs_lookup))
78 |         self.dids = list(self.doc_offsets.keys())
79 |         qs_train = self.__load_set('train')
80 |         qs_dev = self.__load_set('dev', num_samples=self.args.num_dev_queries)
81 |         qs_val = self.__load_set('val')
82 |         qs_test = self.__load_set('test')
83 |         if self.args.orcas_train:
84 |             qs_orcas = self.__load_set('orcas')
85 |             self.qs = {**qs_train, **qs_orcas, **qs_dev, **qs_val, **qs_test}
86 |         else:
87 |             self.qs = {**qs_train, **qs_dev, **qs_val, **qs_test}
88 |         setattr(self.args, 'collection_size', len(self.doc_offsets))
89 |         setattr(self.args, 'num_train_queries', len(qs_train))
90 |         self.args.num_dev_queries = len(qs_dev)
91 | 
92 |     def __verify_in_data(self):
93 |         self.printer.print('verifying input data')
94 |         for k, file_name in vars(self.args).items():
95 |             if k.startswith('file_in_'):
96 |                 expect_prenorm = self.__should_prenorm_file(k)
97 |                 if expect_prenorm:
98 |                     file_norm = self.__get_post_norm_filename(file_name)
99 |                     if self.__verify_and_download_file(file_norm):
100 |                         setattr(self.args, k, file_norm)
101 |                         continue
102 |                 if self.__verify_and_download_file(file_name):
103 |                     if expect_prenorm:
104 |                         self.__prenorm_input_file(k, os.path.join(self.args.local_dir, file_name), os.path.join(self.args.local_dir, file_norm))
105 |                         setattr(self.args, k, file_norm)
106 |                 else:
107 |                     self.printer.print('error: cannot find file {}'.format(file_name))
108 |                     sys.exit(1)
109 | 
110 |     def __verify_gen_data(self):
111 |         self.printer.print('verifying intermediate data')
112 |         for k, file_name in vars(self.args).items():
113 |             if k.startswith('file_gen_'):
114 |                 if not self.__verify_and_download_file(file_name):
115 |                     if k == 'file_gen_docs_lookup':
116 |                         self.__generate_lookup()
117 |                     elif k == 'file_gen_orcas_docs' or k
== 'file_gen_orcas_docs_lookup': 118 | self.__generate_orcas_field() 119 | 120 | def __should_prenorm_file(self, file_key): 121 | return (file_key in self.prenorm_file_cols) 122 | 123 | def __get_post_norm_filename(self, file_name): 124 | return file_name + '.norm' 125 | 126 | def __prenorm_input_file(self, file_key, file_path, file_path_norm): 127 | self.printer.print('normalizing {}'.format(file_path)) 128 | with open(file_path, 'rt', encoding='utf8') as f_in: 129 | with open(file_path_norm, 'w', encoding='utf8') as f_out: 130 | reader = csv.reader(f_in, delimiter='\t', quoting=csv.QUOTE_NONE) 131 | cols_to_clean = self.prenorm_file_cols[file_key] 132 | for row in reader: 133 | clean_cols = [] 134 | for i in range(len(row)): 135 | clean_cols.append(self.parent.model_utils.clean_text(row[i]) if i in cols_to_clean else row[i]) 136 | clean_text = '\t'.join(clean_cols) 137 | f_out.write(clean_text) 138 | f_out.write('\n') 139 | os.remove(file_path) 140 | 141 | def __generate_lookup(self): 142 | self.printer.print('generating document offsets for collection') 143 | with open(os.path.join(self.args.local_dir, self.args.file_in_docs), 'rt', encoding='utf8') as f_in: 144 | with open(os.path.join(self.args.local_dir, self.args.file_gen_docs_lookup), 'w', encoding='utf8') as f_out: 145 | offset = 0 146 | line = f_in.readline() 147 | while line: 148 | did = line.split('\t')[0] 149 | f_out.write('{}\t{}\n'.format(did, offset)) 150 | offset = f_in.tell() 151 | line = f_in.readline() 152 | 153 | def __generate_orcas_field(self): 154 | self.printer.print('generating orcas field data') 155 | orcas_field = {} 156 | with open(os.path.join(self.args.local_dir, self.args.file_in_orcas), 'rt', encoding='utf8') as f_in: 157 | reader = csv.reader(f_in, delimiter='\t') 158 | for [qid, q, did, _] in reader: 159 | if did not in orcas_field: 160 | orcas_field[did] = [] 161 | orcas_field[did].append(q) 162 | orcas_field = {k: ' '.join(v) for k,v in orcas_field.items()} 163 | with open(os.path.join(self.args.local_dir, self.args.file_gen_orcas_docs), 'w', encoding='utf8') as f_out: 164 | with open(os.path.join(self.args.local_dir, self.args.file_gen_orcas_docs_lookup), 'w', encoding='utf8') as f_lookup: 165 | offset = 0 166 | for did, field in orcas_field.items(): 167 | f_out.write('{}\t{}\n'.format(did, field)) 168 | f_lookup.write('{}\t{}\n'.format(did, offset)) 169 | offset = f_out.tell() 170 | 171 | def __get_doc_offsets(self, lookup_file): 172 | offsets = {} 173 | with open(lookup_file, 'rt', encoding='utf8') as f: 174 | reader = csv.reader(f, delimiter='\t') 175 | for [did, offset] in reader: 176 | offsets[did] = int(offset) 177 | return offsets 178 | 179 | def __load_set(self, query_set, num_samples=0): 180 | file_in_qs = getattr(self.args, 'file_in_qs_{}'.format(query_set)) 181 | qs = self.__get_qs(file_in_qs) 182 | file_in_cnd = getattr(self.args, 'file_in_cnd_{}'.format(query_set)) 183 | cand = self.__get_candidates(file_in_cnd) 184 | qids = set(qs.keys()) & set(cand.keys()) 185 | if query_set != 'test': 186 | file_in_qrel = getattr(self.args, 'file_in_qrel_{}'.format(query_set)) 187 | qrels = self.__get_qrels(file_in_qrel) 188 | qids = qids & set(qrels.keys()) 189 | qids = list(qids) 190 | if num_samples > 0: 191 | qids = random.sample(qids, min(num_samples, len(qids))) 192 | setattr(self, 'qids_{}'.format(query_set), qids) 193 | if query_set != 'test': 194 | qrels = {qid: qrels[qid] for qid in qids} 195 | setattr(self, 'qrels_{}'.format(query_set), qrels) 196 | cand = {qid: cand[qid] for qid in qids} 197 
| setattr(self, 'cand_{}'.format(query_set), cand) 198 | qs = {qid: qs[qid] for qid in qids} 199 | return qs 200 | 201 | def __get_qrels(self, qrels_file): 202 | qrels = {} 203 | with open(os.path.join(self.args.local_dir, qrels_file), 'rt', encoding='utf8') as f: 204 | reader = csv.reader(f, delimiter=' ') 205 | for [qid, _, did, rating] in reader: 206 | rating = int(rating) 207 | if rating == 0: 208 | continue 209 | if qid not in qrels: 210 | qrels[qid] = {} 211 | qrels[qid][did] = rating 212 | return qrels 213 | 214 | def __get_candidates(self, cnd_file): 215 | cands = {} 216 | with open(os.path.join(self.args.local_dir, cnd_file), 'rt', encoding='utf8') as f: 217 | reader = csv.reader(f, delimiter=' ') 218 | for [qid, _, did, _, _, _] in reader: 219 | if qid not in cands: 220 | cands[qid] = [did] 221 | else: 222 | cands[qid].append(did) 223 | return cands 224 | 225 | def __get_qs(self, qs_file): 226 | qs = {} 227 | with open(os.path.join(self.args.local_dir, qs_file), 'rt', encoding='utf8') as f: 228 | reader = csv.reader(f, delimiter='\t') 229 | for [qid, q_txt] in reader: 230 | qs[qid] = q_txt 231 | return qs 232 | 233 | def __verify_and_download_file(self, file_name): 234 | file_local = os.path.join(self.args.local_dir, file_name) 235 | if not os.path.exists(file_local): 236 | file_local_tar = '{}.tar'.format(file_local) 237 | file_local_gz = '{}.gz'.format(file_local) 238 | file_local_tar_gz = '{}.tar.gz'.format(file_local) 239 | if os.path.exists(file_local_tar): 240 | self.__untar(file_local_tar) 241 | elif os.path.exists(file_local_gz): 242 | self.__uncompress(file_local_gz) 243 | elif os.path.exists(file_local_tar_gz): 244 | self.__untar(file_local_tar_gz) 245 | else: 246 | file_web = os.path.join(self.args.web_dir, file_name) 247 | file_web_tar = '{}.tar'.format(file_web) 248 | file_web_gz = '{}.gz'.format(file_web) 249 | file_web_tar_gz = '{}.tar.gz'.format(file_web) 250 | if self.__web_file_exists(file_web): 251 | self.__download_file(file_web, file_local) 252 | elif self.__web_file_exists(file_web_tar): 253 | self.__download_file(file_web_tar, file_local_tar) 254 | self.__untar(file_local_tar) 255 | elif self.__web_file_exists(file_web_gz): 256 | self.__download_file(file_web_gz, file_local_gz) 257 | self.__uncompress(file_local_gz) 258 | elif self.__web_file_exists(file_web_tar_gz): 259 | self.__download_file(file_web_tar_gz, file_local_tar_gz) 260 | self.__untar(file_local_tar_gz) 261 | else: 262 | return False 263 | return True 264 | 265 | def __untar(self, filename): 266 | self.printer.print('unpacking {}'.format(filename)) 267 | f = tarfile.open(filename) 268 | f.extractall(path=os.path.dirname(filename)) 269 | f.close() 270 | os.remove(filename) 271 | 272 | def __uncompress(self, filename, block_size=65536): 273 | self.printer.print('uncompressing {}'.format(filename)) 274 | with gzip.open(filename, 'rb') as s_file: 275 | with open(filename[:-3], 'wb') as d_file: 276 | shutil.copyfileobj(s_file, d_file, block_size) 277 | os.remove(filename) 278 | 279 | def __web_file_exists(self, url): 280 | return requests.head(url).status_code != 404 281 | 282 | def __download_file(self, filename_web, filename_local): 283 | self.printer.print('downloading {}'.format(filename_web)) 284 | chunk_size = 1048576 285 | adapter = HTTPAdapter(max_retries=10) 286 | session = requests.Session() 287 | session.mount(filename_web, adapter) 288 | try: 289 | r = session.get(filename_web, stream=True, timeout=5) 290 | with open(filename_local, 'wb') as f: 291 | total_length = 
int(r.headers.get('content-length')) 292 | for ch in progress.bar(r.iter_content(chunk_size=chunk_size), expected_size=(total_length / chunk_size) + 1): 293 | if ch: 294 | f.write(ch) 295 | except ConnectionError as ce: 296 | self.printer.print('error: {}'.format(ce)) 297 | -------------------------------------------------------------------------------- /factory.py: -------------------------------------------------------------------------------- 1 | from model import NDRM1, NDRM2, NDRM3 2 | from loss import SmoothMRRLoss, RankNetLoss, MarginLoss 3 | 4 | class Factory: 5 | 6 | models = {'ndrm1': NDRM1, 'ndrm2': NDRM2, 'ndrm3': NDRM3} 7 | losses = {'smoothmrr': SmoothMRRLoss, 'ranknet': RankNetLoss, 'margin': MarginLoss} 8 | 9 | def get_model(utils): 10 | return Factory.safe_get('model', utils.args.model, Factory.models)(utils) 11 | 12 | def get_loss(utils, device): 13 | return Factory.safe_get('loss', utils.args.loss, Factory.losses)().to(device) 14 | 15 | def safe_get(type, name, dict): 16 | assert name in dict, 'unknown {} {}'.format(type, name) 17 | return dict[name] 18 | -------------------------------------------------------------------------------- /images/CK.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/bmitra-msft/TREC-Deep-Learning-Quick-Start/b066780e2e856e6f4c4cd08b70b218bdf99506b8/images/CK.png -------------------------------------------------------------------------------- /learner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.optim as optim 4 | from torch.utils.data import DataLoader 5 | from parallel import DataParallelModel, DataParallelCriterion 6 | from data import TRECSupervisedTrainMultiDataset, TRECSupervisedTrainDataset, TRECSupervisedTestDataset, TRECInferenceDataset 7 | from clint.textui import progress 8 | from factory import Factory 9 | 10 | 11 | class Learner: 12 | 13 | def __init__(self, utils): 14 | self.utils = utils 15 | torch.manual_seed(self.utils.args.seed) 16 | self.device = torch.device(utils.args.device) 17 | if utils.args.device == 'cuda': 18 | self.num_devices = torch.cuda.device_count() 19 | if utils.args.single_gpu: 20 | self.num_devices = min(self.num_devices, 1) 21 | assert self.num_devices > 0, 'no gpus available for training' 22 | else: 23 | self.num_devices = 0 24 | self.model = self.__get_model_instance() 25 | self.model_parameter_count = self.model.parameter_count() 26 | if self.num_devices > 1: 27 | self.model = DataParallelModel(self.model) 28 | self.optimizer = optim.Adam(self.model.parameters(), lr=self.utils.args.lr) 29 | 30 | def train_and_evaluate(self): 31 | self.utils.printer.print('starting supervised model training and evaluation of model with {:,} parameters and {} loss'.format(self.model_parameter_count, self.utils.args.loss)) 32 | dataset = TRECSupervisedTrainDataset('train', self.utils.parent.model_utils.featurize, self.utils.parent.data_utils) 33 | if self.utils.args.orcas_train: 34 | dataset_orcas = TRECSupervisedTrainDataset('orcas', self.utils.parent.model_utils.featurize, self.utils.parent.data_utils) 35 | dataset = TRECSupervisedTrainMultiDataset([dataset, dataset_orcas]) 36 | dataloader = DataLoader(dataset, shuffle=True, batch_size=self.utils.args.mb_size_train, pin_memory=(self.utils.args.device == 'cuda')) 37 | mb_idx = 0 38 | ep_idx = 0 39 | loss_agg = 0 40 | best_dev_mrr = 0 41 | best_val_ndcg = 0 42 | self.criterion = Factory.get_loss(self.utils, 
self.device) 43 | if self.num_devices > 1: 44 | self.criterion = DataParallelCriterion(self.criterion) 45 | self.model.train() 46 | for _, _, labels, features in self.__enumerate_infinitely(dataloader): 47 | self.optimizer.zero_grad() 48 | features = self.__move_features_to_device(features) 49 | labels = self.__move_features_to_device(labels) 50 | if isinstance(features, torch.Tensor): 51 | out = self.model(features) 52 | else: 53 | out = self.model(*features) 54 | loss = self.criterion(out, labels) 55 | loss.backward() 56 | self.optimizer.step() 57 | loss_agg += loss.item() 58 | mb_idx += 1 59 | if mb_idx == self.utils.args.epoch_size: 60 | ep_idx += 1 61 | loss_agg /= self.utils.args.epoch_size 62 | self.utils.printer.print('epoch: {}\tloss: {:.5f}'.format(ep_idx, loss_agg), end='') 63 | dev_results = self.evaluate('dev', model=self.model) 64 | dev_mrr, _, _ = self.utils.parent.evaluate_results(dev_results, self.utils.parent.data_utils.qrels_dev) 65 | self.utils.printer.print('dev mrr: {:.3f}'.format(dev_mrr), end='', suppress_timestamp=True) 66 | val_results = self.evaluate('val', model=self.model) 67 | val_mrr, val_ncg, val_ndcg = self.utils.parent.evaluate_results(val_results, self.utils.parent.data_utils.qrels_val) 68 | self.utils.printer.print('val mrr: {:.3f}\tval ncg: {:.3f}\tval ndcg: {:.3f}'.format(val_mrr, val_ncg, val_ndcg), suppress_timestamp=True) 69 | if dev_mrr >= best_dev_mrr: 70 | self.__save_results(val_results, 'val', 'rerank') 71 | self.__save_model('dev') 72 | self.best_model_dev = self.__get_model_copy(self.model) 73 | best_dev_mrr = dev_mrr 74 | if val_ndcg >= best_val_ndcg: 75 | self.__save_model('val') 76 | self.best_model_val = self.__get_model_copy(self.model) 77 | best_val_ndcg = val_ndcg 78 | mb_idx = 0 79 | loss_agg = 0 80 | self.model.train() 81 | if ep_idx == self.utils.args.num_epochs_train: 82 | break 83 | val_results = self.evaluate_full_retrieval('val', model=self.best_model_dev) 84 | val_mrr, val_ncg, val_ndcg = self.utils.parent.evaluate_results(val_results, self.utils.parent.data_utils.qrels_val) 85 | self.utils.printer.print('full retrieval val mrr: {:.3f}\tfull retrieval val ncg: {:.3f}\tfull retrieval val ndcg: {:.3f}'.format(val_mrr, val_ncg, val_ndcg)) 86 | self.__save_results(val_results, 'val', 'fullrank') 87 | test_results = self.evaluate('test', model=self.best_model_val) 88 | self.__save_results(test_results, 'test', 'rerank') 89 | test_results = self.evaluate_full_retrieval('test', model=self.best_model_val) 90 | self.__save_results(test_results, 'test', 'fullrank') 91 | 92 | def evaluate(self, query_set, model=None): 93 | dataset = TRECSupervisedTestDataset(query_set, self.utils.parent.model_utils.featurize, self.utils.parent.data_utils) 94 | dataloader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=self.utils.args.mb_size_test, pin_memory=(self.utils.args.device == 'cuda')) 95 | if model == None: 96 | model = self.best_model 97 | if isinstance(model, DataParallelModel): 98 | model.module.eval() 99 | else: 100 | model.eval() 101 | results = {} 102 | with torch.no_grad(): 103 | for _, (qids, dids, _, features) in enumerate(dataloader): 104 | features = self.__move_features_to_device(features) 105 | if isinstance(features, torch.Tensor): 106 | out = model(features) 107 | else: 108 | out = model(*features) 109 | if self.num_devices > 1: 110 | out = torch.cat(tuple([out[i] for i in range(self.num_devices)]), dim=0) 111 | out = out.cpu().numpy() 112 | for i in range(len(qids)): 113 | if qids[i] not in results: 114 | 
results[qids[i]] = [] 115 | results[qids[i]].append((dids[i], out[i, 0])) 116 | results = {qid: sorted(docs, key=lambda x: (x[1], x[0]), reverse=True)[:self.utils.args.max_metric_pos_nodisc] for qid, docs in results.items()} 117 | return results 118 | 119 | def evaluate_full_retrieval(self, query_set, model=None): 120 | self.utils.printer.print('starting evaluation of model on {} under full retrieval setting'.format(query_set)) 121 | dataset = TRECInferenceDataset(query_set, self.utils.parent.model_utils.featurize, self.utils.parent.model_utils.tokenize, self.utils.parent.data_utils) 122 | dataloader = DataLoader(dataset, shuffle=False, drop_last=False, batch_size=self.utils.args.mb_size_infer, pin_memory=(self.utils.args.device == 'cuda')) 123 | if model == None: 124 | model = self.model 125 | if isinstance(model, DataParallelModel): 126 | model.module.eval() 127 | else: 128 | model.eval() 129 | num_query_terms = len(dataset.query_terms) 130 | impacts = [[] for i in range(num_query_terms)] 131 | with torch.no_grad(): 132 | for dids, features in progress.bar(dataloader, expected_size=(self.utils.args.collection_size / self.utils.args.mb_size_infer) + 1): 133 | num_queries = len(dids) 134 | if num_queries < self.num_devices: 135 | temp_model = self.__get_model_copy(model, num_devices_tgt=num_queries) 136 | del model 137 | torch.cuda.empty_cache() 138 | model = temp_model 139 | if isinstance(model, DataParallelModel): 140 | model.module.eval() 141 | else: 142 | model.eval() 143 | features = self.__move_features_to_device(features) 144 | if isinstance(features, torch.Tensor): 145 | out = model(features, qti_mode=True) 146 | else: 147 | out = model(*features, qti_mode=True) 148 | if self.num_devices > 1: 149 | out = torch.cat(tuple([out[i] for i in range(self.num_devices)]), dim=0) 150 | out = out.view(-1, num_query_terms).cpu().numpy() 151 | for i in range(len(dids)): 152 | for j in range(num_query_terms): 153 | score = out[i, j] 154 | if score != 0: 155 | impacts[j].append((dids[i], score)) 156 | results = {} 157 | for qid, terms in dataset.qid_to_terms.items(): 158 | results[qid] = {} 159 | for term in terms: 160 | term_idx = dataset.query_terms.index(term) 161 | for did, score in impacts[term_idx]: 162 | if did not in results[qid]: 163 | results[qid][did] = score 164 | else: 165 | results[qid][did] += score 166 | results = {qid: sorted([(did, score) for did, score in docs.items()], key=lambda x: (x[1], x[0]), reverse=True)[:self.utils.args.max_metric_pos_nodisc] for qid, docs in results.items()} 167 | return results 168 | 169 | def __enumerate_infinitely(self, dataloader): 170 | while True: 171 | for _, x in enumerate(dataloader): 172 | yield x 173 | 174 | def __move_features_to_device(self, features): 175 | if isinstance(features, torch.Tensor): 176 | return features.to(self.device) 177 | if isinstance(features, tuple) and len(features) > 0: 178 | return tuple(self.__move_features_to_device(feature) for feature in features) 179 | if isinstance(features, list) and len(features) > 0: 180 | return [self.__move_features_to_device(feature) for feature in features] 181 | if isinstance(features, dict) and len(features) > 0: 182 | return {k : self.__move_features_to_device(feature) for k, feature in features.items()} 183 | 184 | def __save_results(self, results, query_set, runtype): 185 | with open(os.path.join(self.utils.args.local_dir, self.utils.args.file_out_scores.format(query_set, runtype)), mode='w', encoding='utf-8') as f: 186 | for qid, docs in results.items(): 187 | for i in 
range(len(docs)):
188 | did = docs[i][0]
189 | score = docs[i][1]
190 | f.write('{} Q0 {} {} {} {}\n'.format(qid, did, i+1, score, self.utils.args.model))
191 | 
192 | def __save_model(self, marker):
193 | torch.save({'model_state_dict': self.model.state_dict(), 'optimizer_state_dict': self.optimizer.state_dict()}, os.path.join(self.utils.args.local_dir, self.utils.args.file_out_model.format(marker)))
194 | 
195 | def __load_model(self, path):
196 | checkpoint = torch.load(path)
197 | self.model.load_state_dict(checkpoint['model_state_dict'])
198 | self.best_model = self.model
199 | self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
200 | 
201 | def __get_model_instance(self):
202 | return Factory.get_model(self.utils.parent.model_utils).to(self.device).float()
203 | 
204 | def __get_model_copy(self, model_src, num_devices_tgt=0):
205 | model_tgt = self.__get_model_instance()
206 | if self.num_devices > 1:
207 | model_src = model_src.module
208 | model_tgt.load_state_dict(model_src.state_dict())
209 | if self.num_devices > 0:
210 | model_tgt = model_tgt.to(self.device)
211 | if self.num_devices > 1:
212 | if num_devices_tgt == 0:
213 | model_tgt = DataParallelModel(model_tgt)
214 | else:
215 | model_tgt = DataParallelModel(model_tgt, device_ids=list(range(num_devices_tgt)))
216 | return model_tgt
--------------------------------------------------------------------------------
/learner_utils.py:
--------------------------------------------------------------------------------
1 | class LearnerUtils:
2 | 
3 | def __init__(self, printer):
4 | self.printer = printer
5 | 
6 | def parser_add_args(self, parser):
7 | parser.add_argument('--seed', default=0, help='random seed (default: 0)', type=int)
8 | parser.add_argument('--device', default='cuda', help='device identifier (default: cuda)')
9 | parser.add_argument('--loss', default='ranknet', help='training loss (default: ranknet, also allowed: smoothmrr, margin)')
10 | parser.add_argument('--lr', default=0.0001, help='learning rate (default: 0.0001)', type=float)
11 | parser.add_argument('--mb_size_train', default=32, help='minibatch size for training (default: 32)', type=int)
12 | parser.add_argument('--mb_size_test', default=256, help='minibatch size for test (default: 256)', type=int)
13 | parser.add_argument('--mb_size_infer', default=16, help='minibatch size for inference (default: 16)', type=int)
14 | parser.add_argument('--num_rand_negs', default=2, help='number of random negative documents for training (default: 2)', type=int)
15 | parser.add_argument('--epoch_size', default=4096, help='epoch size (default: 4096)', type=int)
16 | parser.add_argument('--num_epochs_train', default=32, help='number of epochs to train (default: 32)', type=int)
17 | parser.add_argument('--max_metric_pos', default=10, help='rank cutoff for computing discounted metrics (default: 10)', type=int)
18 | parser.add_argument('--max_metric_pos_nodisc', default=100, help='rank cutoff for computing non-discounted metrics (default: 100)', type=int)
19 | parser.add_argument('--file_out_model', default='model-{}.pt', help='filename for trained model (default: model-{}.pt)')
20 | parser.add_argument('--file_out_scores', default='scores-{}-{}.txt', help='filename for eval run file in TREC submission format (default: scores-{}-{}.txt)')
21 | parser.add_argument('--orcas_train', help='use orcas data for model training', action='store_true')
22 | parser.add_argument('--single_gpu', help='disable multi GPU training', action='store_true')
23 | 
24 | def parser_validate_args(self,
args): 25 | self.args = args 26 | assert args.mb_size_train > 0 and args.mb_size_test > 0, 'minibatch size must be greater than zero' 27 | assert args.epoch_size > 0, 'epoch size must be greater than zero' 28 | assert args.num_epochs_train > 0, 'number of epochs must be greater than zero' 29 | assert args.device == 'cuda' or not args.single_gpu, 'can not disable multi gpu training when device is not set to cuda' 30 | 31 | def setup_and_verify(self): 32 | pass 33 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class SmoothRank(nn.Module): 6 | 7 | def __init__(self): 8 | super(SmoothRank, self).__init__() 9 | self.sigmoid = torch.nn.Sigmoid() 10 | 11 | def forward(self, scores): 12 | x_0 = scores.unsqueeze(dim=-1) # [Q x D] --> [Q x D x 1] 13 | x_1 = scores.unsqueeze(dim=-2) # [Q x D] --> [Q x 1 x D] 14 | diff = x_1 - x_0 # [Q x D x 1], [Q x 1 x D] --> [Q x D x D] 15 | is_lower = self.sigmoid(diff) # [Q x D x D] --> [Q x D x D] 16 | ranks = torch.sum(is_lower, dim=-1) + 0.5 # [Q x D x D] --> [Q x D] 17 | return ranks 18 | 19 | 20 | class SmoothMRRLoss(nn.Module): 21 | 22 | def __init__(self): 23 | super(SmoothMRRLoss, self).__init__() 24 | self.soft_ranker = SmoothRank() 25 | self.zero = nn.Parameter(torch.tensor([0], dtype=torch.float32), requires_grad=False) 26 | self.one = nn.Parameter(torch.tensor([1], dtype=torch.float32), requires_grad=False) 27 | 28 | def forward(self, scores, labels, agg=True): 29 | ranks = self.soft_ranker(scores) # [Q x D] --> [Q x D] 30 | labels = torch.where(labels > 0, self.one, self.zero) # [Q x D] --> [Q x D] 31 | rr = labels / ranks # [Q x D], [Q x D] --> [Q x D] 32 | rr_max, _ = rr.max(dim=-1) # [Q x D] --> [Q] 33 | loss = 1 - rr_max # [Q] --> [Q] 34 | if agg: 35 | loss = loss.mean() # [Q] --> [1] 36 | return loss 37 | 38 | 39 | class RankNetLoss(nn.Module): 40 | 41 | def __init__(self): 42 | super(RankNetLoss, self).__init__() 43 | self.sigmoid = nn.Sigmoid() 44 | self.zero = nn.Parameter(torch.tensor([0], dtype=torch.float32), requires_grad=False) 45 | self.one = nn.Parameter(torch.tensor([1], dtype=torch.float32), requires_grad=False) 46 | 47 | def forward(self, scores, labels, agg=True): 48 | x_0 = scores.unsqueeze(dim=-1) # [Q x D] --> [Q x D x 1] 49 | x_1 = scores.unsqueeze(dim=-2) # [Q x D] --> [Q x 1 x D] 50 | x = x_0 - x_1 # [Q x D x 1], [Q x 1 x D] --> [Q x D x D] 51 | x = self.sigmoid(x) # [Q x D x D] --> [Q x D x D] 52 | x = -torch.log(x + 1e-6) # [Q x D x D] --> [Q x D x D] 53 | y_0 = labels.unsqueeze(dim=-1) # [Q x D] --> [Q x D x 1] 54 | y_1 = labels.unsqueeze(dim=-2) # [Q x D] --> [Q x 1 x D] 55 | y = y_0 - y_1 # [Q x D x 1], [Q x 1 x D] --> [Q x D x D] 56 | pair_mask = torch.where(y > 0, self.one, self.zero) # [Q x D x D] --> [Q x D x D] 57 | num_pairs = pair_mask.sum(dim=-1) # [Q x D x D] --> [Q x D] 58 | num_pairs = num_pairs.sum(dim=-1) # [Q x D] --> [Q] 59 | num_pairs = torch.where(num_pairs > 0, num_pairs, self.one) # [Q] --> [Q] 60 | loss = x * pair_mask # [Q x D x D], [Q x D x D] --> [Q x D x D] 61 | loss = loss.sum(dim=-1) # [Q x D x D] --> [Q x D] 62 | loss = loss.sum(dim=-1) # [Q x D] --> [Q] 63 | loss = loss / num_pairs # [Q], [Q] --> [Q] 64 | if agg: 65 | loss = loss.mean() # [Q] --> [1] 66 | return loss 67 | 68 | 69 | class MarginLoss(nn.Module): 70 | 71 | def __init__(self): 72 | super(MarginLoss, self).__init__() 73 | self.zero = 
nn.Parameter(torch.tensor([0], dtype=torch.float32), requires_grad=False)
74 | self.one = nn.Parameter(torch.tensor([1], dtype=torch.float32), requires_grad=False)
75 | self.neg_one = nn.Parameter(torch.tensor([-1], dtype=torch.float32), requires_grad=False)
76 | self.margin = nn.Parameter(torch.tensor([0.1], dtype=torch.float32), requires_grad=False)
77 | 
78 | def forward(self, scores, labels, agg=True):
79 | x_0 = scores.unsqueeze(dim=-1) # [Q x D] --> [Q x D x 1]
80 | x_1 = scores.unsqueeze(dim=-2) # [Q x D] --> [Q x 1 x D]
81 | x = x_0 - x_1 # [Q x D x 1], [Q x 1 x D] --> [Q x D x D]
82 | y_0 = labels.unsqueeze(dim=-1) # [Q x D] --> [Q x D x 1]
83 | y_1 = labels.unsqueeze(dim=-2) # [Q x D] --> [Q x 1 x D]
84 | y = y_0 - y_1 # [Q x D x 1], [Q x 1 x D] --> [Q x D x D]
85 | y = torch.where(y > 0, self.one, y) # [Q x D x D] --> [Q x D x D]
86 | y = torch.where(y < 0, self.neg_one, y) # [Q x D x D] --> [Q x D x D]
87 | loss = y * x # [Q x D x D], [Q x D x D] --> [Q x D x D]
88 | loss = self.margin - loss # [1], [Q x D x D] --> [Q x D x D]
89 | loss = torch.where(loss < 0, self.zero, loss) # [Q x D x D] --> [Q x D x D]
90 | loss = loss.sum(dim=-1) # [Q x D x D] --> [Q x D]
91 | loss = loss.sum(dim=-1) # [Q x D] --> [Q]
92 | num_pairs = torch.where(y < 0, self.one, y) # [Q x D x D] --> [Q x D x D]
93 | num_pairs = num_pairs.sum(dim=-1) # [Q x D x D] --> [Q x D]
94 | num_pairs = num_pairs.sum(dim=-1) # [Q x D] --> [Q]
95 | loss = loss / torch.where(num_pairs > 0, num_pairs, self.one) # [Q], [Q] --> [Q] (guard against queries with no preference pairs, as in RankNetLoss)
96 | if agg:
97 | loss = loss.mean() # [Q] --> [1]
98 | return loss
99 | 
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | import random
3 | import torch
4 | import torch.nn as nn
5 | import numpy as np
6 | from conformer import PositionalEncoding, ConformerEncoderLayer, ConformerEncoder
7 | 
8 | class NDRM1(nn.Module):
9 | 
10 | def __init__(self, utils):
11 | super(NDRM1, self).__init__()
12 | self.utils = utils
13 | self.embed = nn.Embedding(self.utils.args.vocab_size, utils.args.num_hidden_nodes)
14 | self.embed.weight = nn.Parameter(self.utils.pretrained_embeddings, requires_grad=True)
15 | self.pos_encoder = PositionalEncoding(utils.args.num_hidden_nodes, dropout=utils.args.drop, max_len=utils.args.max_terms_doc)
16 | self.fc_qt = nn.Linear(utils.args.num_hidden_nodes, utils.args.num_hidden_nodes)
17 | enable_conformer = (not self.utils.args.no_conformer)
18 | encoder_layers = ConformerEncoderLayer(utils.args.num_hidden_nodes, utils.args.num_attn_heads, utils.args.num_hidden_nodes, utils.args.drop,
19 | attn=True, conv=enable_conformer, convsz=utils.args.conv_window_size, shared_qk=enable_conformer, sep=enable_conformer)
20 | self.contextualize = ConformerEncoder(encoder_layers, utils.args.num_encoder_layers)
21 | self.fc_ctx = nn.Linear(2, 1)
22 | self.cosine_sim = nn.CosineSimilarity(dim=-1)
23 | self.rbf_kernel = RBFKernel(utils)
24 | 
25 | def forward(self, q, d, mask_q, mask_d, qti_mode=False):
26 | q = self.embed(q) # [Q x Tq] --> [Q x Tq x H]
27 | q = self.fc_qt(q) # [Q x Tq x H] --> [Q x Tq x H]
28 | q = q.unsqueeze(dim=1) # [Q x Tq x H] --> [Q x 1 x Tq x H]
29 | q = q.unsqueeze(dim=-2) # [Q x 1 x Tq x H] --> [Q x 1 x Tq x 1 x H]
30 | d = self.embed(d) # [Q x D x Td] --> [Q x D x Td x H]
31 | shape_mask = mask_d.shape
32 | mask_d = mask_d.view(-1, shape_mask[-1]) # [Q x D x Td] --> [QD x Td]
33 | shape_d = d.shape
34 | d_ctx = d.view(-1, shape_d[2], shape_d[3]) # [Q x D x Td x H]
--> [QD x Td x H]
35 | d_ctx = d_ctx.permute(1, 0, 2) # [QD x Td x H] --> [Td x QD x H]
36 | d_ctx = self.pos_encoder(d_ctx) # [Td x QD x H] --> [Td x QD x H]
37 | d_ctx = self.contextualize(d_ctx, src_key_padding_mask=~mask_d.bool()) # [Td x QD x H], [QD x Td] --> [Td x QD x H]
38 | d_ctx = d_ctx.permute(1, 0, 2) # [Td x QD x H] --> [QD x Td x H]
39 | d_ctx = d_ctx.view(shape_d) # [QD x Td x H] --> [Q x D x Td x H]
40 | mask_d = mask_d.view(shape_mask) # [QD x Td] --> [Q x D x Td]
41 | d = torch.stack([d, d_ctx], dim=-1) # [Q x D x Td x H], [Q x D x Td x H] --> [Q x D x Td x H x 2]
42 | d = self.fc_ctx(d) # [Q x D x Td x H x 2] --> [Q x D x Td x H x 1]
43 | d = d.squeeze(dim=-1) # [Q x D x Td x H x 1] --> [Q x D x Td x H]
44 | d = d.unsqueeze(dim=2) # [Q x D x Td x H] --> [Q x D x 1 x Td x H]
45 | y = self.cosine_sim(q, d) # [Q x 1 x Tq x 1 x H], [Q x D x 1 x Td x H] --> [Q x D x Tq x Td]
46 | y = self.rbf_kernel(y, mask_d) # [Q x D x Tq x Td] --> [Q x D x Tq]
47 | mask_q = mask_q.unsqueeze(1) # [Q x Tq] --> [Q x 1 x Tq]
48 | y = y * mask_q # [Q x D x Tq], [Q x 1 x Tq] --> [Q x D x Tq]
49 | if not qti_mode:
50 | y = y.sum(dim=-1) # [Q x D x Tq] --> [Q x D]
51 | return y
52 | 
53 | def parameter_count(self):
54 | return sum(p.numel() for p in self.parameters() if p.requires_grad)
55 | 
56 | class NDRM2(nn.Module):
57 | 
58 | def __init__(self, utils):
59 | super(NDRM2, self).__init__()
60 | self.utils = utils
61 | self.norm_dlen = BatchScale(1)
62 | self.norm_tf = BatchScale(1)
63 | self.fc_dlen = nn.Sequential(nn.Linear(1, 1), nn.ReLU())
64 | with torch.no_grad():
65 | self.fc_dlen[0].weight = nn.Parameter(torch.ones((1, 1), dtype=torch.float32), requires_grad=True)
66 | self.fc_dlen[0].bias = nn.Parameter(torch.zeros(1, dtype=torch.float32), requires_grad=True)
67 | 
68 | def forward(self, qd, mask_q, q_idf, dlen, qti_mode=False):
69 | shape_dlen = dlen.shape
70 | dlen = dlen.view(-1, 1) # [Q x D] --> [QD x 1]
71 | dlen = self.norm_dlen(dlen) # [QD x 1] --> [QD x 1]
72 | dlen = dlen.view(shape_dlen + (1,)) # [QD x 1] --> [Q x D x 1]
73 | dlen = self.fc_dlen(dlen) # [Q x D x 1] --> [Q x D x 1]
74 | shape_qd = qd.shape
75 | y = qd.view(-1, 1) # [Q x D x Tq] --> [QDTq x 1]
76 | y = self.norm_tf(y) # [QDTq x 1] --> [QDTq x 1]
77 | y = y.view(shape_qd) # [QDTq x 1] --> [Q x D x Tq]
78 | y = y / (y + dlen + 1e-6) # [Q x D x Tq], [Q x D x 1] --> [Q x D x Tq]
79 | q_idf = q_idf.unsqueeze(dim=1) # [Q x Tq] --> [Q x 1 x Tq]
80 | y = y * q_idf # [Q x D x Tq], [Q x 1 x Tq] --> [Q x D x Tq]
81 | mask_q = mask_q.unsqueeze(1) # [Q x Tq] --> [Q x 1 x Tq]
82 | y = y * mask_q # [Q x D x Tq], [Q x 1 x Tq] --> [Q x D x Tq]
83 | if not qti_mode:
84 | y = y.sum(dim=-1) # [Q x D x Tq] --> [Q x D]
85 | return y
86 | 
87 | def parameter_count(self):
88 | return sum(p.numel() for p in self.parameters() if p.requires_grad)
89 | 
90 | 
91 | class NDRM3(nn.Module):
92 | 
93 | def __init__(self, utils):
94 | super(NDRM3, self).__init__()
95 | self.utils = utils
96 | self.ndrm1 = NDRM1(utils)
97 | self.ndrm2 = NDRM2(utils)
98 | self.fc = nn.Sequential(nn.BatchNorm1d(2),
99 | nn.Linear(2, 1))
100 | 
101 | def forward(self, q, d, qd, mask_q, mask_d, q_idf, dlen, qti_mode=False):
102 | y_lat = self.ndrm1(q, d, mask_q, mask_d, qti_mode=True) # [*] --> [Q x D x Tq]
103 | y_exp = self.ndrm2(qd, mask_q, q_idf, dlen, qti_mode=True) # [*] --> [Q x D x Tq]
104 | y_lat = y_lat.unsqueeze(dim=-1) # [Q x D x Tq] --> [Q x D x Tq x 1]
105 | y_exp = y_exp.unsqueeze(dim=-1) # [Q x D x Tq] --> [Q x D x Tq x 1]
106 | y = torch.cat([y_lat,
y_exp], dim=-1) # [Q x D x Tq x 1], [Q x D x Tq x 1] --> [Q x D x Tq x 2] 107 | shape_y = y.shape 108 | y = y.view(-1, 2) # [Q x D x Tq x 2] --> [QDTq x 2] 109 | y = self.fc(y) # [QDTq x 2] --> [QDTq x 1] 110 | y = y.view(shape_y[:-1]) # [QDTq x 1] --> [Q x D x Tq] 111 | mask_q = mask_q.unsqueeze(1) # [Q x Tq] --> [Q x 1 x Tq] 112 | y = y * mask_q # [Q x D x Tq], [Q x 1 x Tq] --> [Q x D x Tq] 113 | if not qti_mode: 114 | y = y.sum(dim=-1) # [Q x D x Tq] --> [Q x D] 115 | return y 116 | 117 | def parameter_count(self): 118 | return sum(p.numel() for p in self.parameters() if p.requires_grad) 119 | 120 | 121 | class RBFKernel(nn.Module): 122 | 123 | def __init__(self, utils): 124 | super(RBFKernel, self).__init__() 125 | self.utils = utils 126 | mus = nn.Parameter(torch.from_numpy(np.array([(1 - 2 * i / utils.args.rbf_kernel_dim) for i in range(utils.args.rbf_kernel_dim)], dtype=np.float32)).view(1, 1, 1, 1, -1), requires_grad=False) 127 | sigmas = nn.Parameter(torch.from_numpy(np.array([0.1], dtype=np.float32)).view(1, 1, 1, 1, 1), requires_grad=False) 128 | denom = 2 * torch.pow(sigmas, 2) 129 | self.avg_pool = nn.AvgPool1d(utils.args.rbf_kernel_pool_size, stride=utils.args.rbf_kernel_pool_stride) 130 | self.fc = nn.Sequential(nn.Linear(utils.args.rbf_kernel_dim, utils.args.num_hidden_nodes), 131 | nn.ReLU(), 132 | nn.Dropout(p=utils.args.drop), 133 | nn.Linear(utils.args.num_hidden_nodes, 1)) 134 | self.register_buffer('mus', mus) 135 | self.register_buffer('denom', denom) 136 | 137 | def forward(self, x, mask_d): 138 | y = x.unsqueeze(dim=-1) # [Q x D x Tq x Td] --> [Q x D x Tq x Td x 1] 139 | y = y - self.mus # [Q x D x Tq x Td x 1], [1 x 1 x 1 x 1 x K] --> [Q x D x Tq x Td x K] 140 | y = torch.pow(y, 2) # [Q x D x Tq x Td x K] --> [Q x D x Tq x Td x K] 141 | y = y / self.denom # [Q x D x Tq x Td x K], [1 x 1 x 1 x 1 x 1] --> [Q x D x Tq x Td x K] 142 | y = torch.exp(-y) # [Q x D x Tq x Td x K] --> [Q x D x Tq x Td x K] 143 | mask_d = mask_d.unsqueeze(dim=2) # [Q x D x Td] --> [Q x D x 1 x Td] 144 | mask_d = mask_d.unsqueeze(dim=-1) # [Q x D x 1 x Td] --> [Q x D x 1 x Td x 1] 145 | y = y * mask_d # [Q x D x Tq x Td x K], [Q x D x 1 x Td x 1] --> [Q x D x Tq x Td x K] 146 | shape_y = y.shape 147 | y = y.view(-1, shape_y[-2], shape_y[-1]) # [Q x D x Tq x Td x K] --> [QDTq x Td x K] 148 | y = y.permute(0, 2, 1) # [QDTq x Td x K] --> [QDTq x K x Td] 149 | y = self.avg_pool(y) # [QDTq x K x Td] --> [QDTq x K x N] 150 | y = y * self.utils.args.rbf_kernel_pool_size # [QDTq x K x N] --> [QDTq x K x N] 151 | y = torch.log(y + 1e-6) # [QDTq x K x N] --> [QDTq x K x N] 152 | y = y.permute(0, 2, 1) # [QDTq x K x N] --> [QDTq x N x K] 153 | y = self.fc(y) # [QDTq x N x K] --> [QDTq x N x 1] 154 | y, _ = y.max(dim=1) # [QDTq x N x 1] --> [QDTq x 1] 155 | y = y.view(shape_y[:3]) # [QDTq x 1] --> [Q x D x Tq] 156 | return y 157 | 158 | 159 | class BatchScale(nn.Module): 160 | 161 | def __init__(self, num_features): 162 | super(BatchScale, self).__init__() 163 | self.num_features = num_features 164 | self.register_buffer('running_mean', torch.zeros(num_features, dtype=torch.float32)) 165 | self.register_buffer('num_samples', torch.tensor(0, dtype=torch.float32)) 166 | 167 | def forward(self, x): 168 | if self.training: 169 | mb_size = x.shape[0] 170 | self.num_samples.add_(mb_size) # [1] --> [1] 171 | new_mean = x.detach().sum(dim=0) # [B x N] --> [N] 172 | delta = new_mean - (mb_size * self.running_mean) # [N], [N] --> [N] 173 | delta = delta / self.num_samples # [N], [N] --> [N] 174 | 
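# Streaming mean update: with num_samples = n rows seen in total and a
# minibatch of mb_size = m rows, running_mean += (sum(x) - m * running_mean) / n,
# which equals (n_prev * running_mean + sum(x)) / n, the exact mean of all
# rows seen so far. Unlike BatchNorm, BatchScale tracks only this scale
# statistic (no shift, no variance) and simply divides its input by it.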
self.running_mean.add_(delta) # [N], [N] --> [N] 175 | y = x / (self.running_mean + 1e-6) # [B x N], [N] --> [B x N] 176 | return y 177 | -------------------------------------------------------------------------------- /model_utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import os 3 | import csv 4 | import sys 5 | import math 6 | import torch 7 | import struct 8 | import fasttext 9 | import numpy as np 10 | import krovetzstemmer 11 | from clint.textui import progress 12 | 13 | 14 | class NDRMUtils: 15 | 16 | def __init__(self, printer): 17 | self.printer = printer 18 | self.regex_drop_char = re.compile('[^a-z0-9\s]+') 19 | self.regex_multi_space = re.compile('\s+') 20 | self.stemmer = krovetzstemmer.Stemmer() 21 | self.stop_words = ['a', 'able', 'about', 'above', 'according', 'accordingly', 'across', 'actually', 'after', 'afterwards', 'again', 'against', 'ain', 'all', 'allow', 'allows', 'almost', 'alone', 'along', 'already', 'also', 'although', 22 | 'always', 'am', 'among', 'amongst', 'an', 'and', 'another', 'any', 'anybody', 'anyhow', 'anyone', 'anything', 'anyway', 'anyways', 'anywhere', 'apart', 'appear', 'appreciate', 'appropriate', 'are', 'aren', 'around', 23 | 'as', 'aside', 'ask', 'asking', 'associated', 'at', 'available', 'away', 'awfully', 'b', 'be', 'became', 'because', 'become', 'becomes', 'becoming', 'been', 'before', 'beforehand', 'behind', 'being', 'believe', 'below', 24 | 'beside', 'besides', 'best', 'better', 'between', 'beyond', 'both', 'brief', 'but', 'by', 'c', 'came', 'can', 'cannot', 'cant', 'cause', 'causes', 'certain', 'certainly', 'changes', 'clearly', 'co', 'com', 'come', 'comes', 25 | 'concerning', 'consequently', 'consider', 'considering', 'contain', 'containing', 'contains', 'corresponding', 'could', 'couldn', 'course', 'currently', 'd', 'definitely', 'described', 'despite', 'did', 'didn', 'different', 26 | 'do', 'does', 'doesn', 'doing', 'don', 'done', 'down', 'downwards', 'during', 'e', 'each', 'edu', 'eg', 'eight', 'either', 'else', 'elsewhere', 'enough', 'entirely', 'especially', 'et', 'etc', 'even', 'ever', 'every', 27 | 'everybody', 'everyone', 'everything', 'everywhere', 'ex', 'exactly', 'example', 'except', 'f', 'far', 'few', 'fifth', 'first', 'five', 'followed', 'following', 'follows', 'for', 'former', 'formerly', 'forth', 'four', 'from', 28 | 'further', 'furthermore', 'g', 'get', 'gets', 'getting', 'given', 'gives', 'go', 'goes', 'going', 'gone', 'got', 'gotten', 'greetings', 'h', 'had', 'hadn', 'happens', 'hardly', 'has', 'hasn', 'have', 'haven', 'having', 'he', 29 | 'hello', 'help', 'hence', 'her', 'here', 'hereafter', 'hereby', 'herein', 'hereupon', 'hers', 'herself', 'hi', 'him', 'himself', 'his', 'hither', 'hopefully', 'how', 'howbeit', 'however', 'i', 'ie', 'if', 'ignored', 30 | 'immediate', 'in', 'inasmuch', 'inc', 'indeed', 'indicate', 'indicated', 'indicates', 'inner', 'insofar', 'instead', 'into', 'inward', 'is', 'isn', 'it', 'its', 'itself', 'j', 'just', 'k', 'keep', 'keeps', 'kept', 'know', 31 | 'knows', 'known', 'l', 'last', 'lately', 'later', 'latter', 'latterly', 'least', 'less', 'lest', 'let', 'like', 'liked', 'likely', 'little', 'll', 'look', 'looking', 'looks', 'ltd', 'm', 'mainly', 'many', 'may', 'maybe', 32 | 'me', 'mean', 'meanwhile', 'merely', 'might', 'more', 'moreover', 'most', 'mostly', 'much', 'must', 'my', 'myself', 'n', 'name', 'namely', 'nd', 'near', 'nearly', 'necessary', 'need', 'needs', 'neither', 'never', 33 | 'nevertheless', 'new', 'next', 'nine', 'no', 
'nobody', 'non', 'none', 'noone', 'nor', 'normally', 'not', 'nothing', 'novel', 'now', 'nowhere', 'o', 'obviously', 'of', 'off', 'often', 'oh', 'ok', 'okay', 'old', 'on', 'once', 34 | 'one', 'ones', 'only', 'onto', 'or', 'other', 'others', 'otherwise', 'ought', 'our', 'ours', 'ourselves', 'out', 'outside', 'over', 'overall', 'own', 'p', 'particular', 'particularly', 'per', 'perhaps', 'placed', 'please', 35 | 'plus', 'possible', 'presumably', 'probably', 'provides', 'q', 'que', 'quite', 'qv', 'r', 'rather', 'rd', 're', 'really', 'reasonably', 'regarding', 'regardless', 'regards', 'relatively', 'respectively', 'right', 's', 'said', 36 | 'same', 'saw', 'say', 'saying', 'says', 'second', 'secondly', 'see', 'seeing', 'seem', 'seemed', 'seeming', 'seems', 'seen', 'self', 'selves', 'sensible', 'sent', 'serious', 'seriously', 'seven', 'several', 'shall', 'she', 37 | 'should', 'shouldn', 'since', 'six', 'so', 'some', 'somebody', 'somehow', 'someone', 'something', 'sometime', 'sometimes', 'somewhat', 'somewhere', 'soon', 'sorry', 'specified', 'specify', 'specifying', 'still', 'sub', 38 | 'such', 'sup', 'sure', 't', 'take', 'taken', 'tell', 'tends', 'th', 'than', 'thank', 'thanks', 'thanx', 'that', 'thats', 'the', 'their', 'theirs', 'them', 'themselves', 'then', 'thence', 'there', 'thereafter', 'thereby', 39 | 'therefore', 'therein', 'theres', 'thereupon', 'these', 'they', 'think', 'third', 'this', 'thorough', 'thoroughly', 'those', 'though', 'three', 'through', 'throughout', 'thru', 'thus', 'to', 'together', 'too', 'took', 40 | 'toward', 'towards', 'tried', 'tries', 'truly', 'try', 'trying', 'twice', 'two', 'u', 'un', 'under', 'unfortunately', 'unless', 'unlikely', 'until', 'unto', 'up', 'upon', 'us', 'use', 'used', 'useful', 'uses', 'using', 41 | 'usually', 'uucp', 'v', 've', 'value', 'various', 'very', 'via', 'viz', 'vs', 'w', 'want', 'wants', 'was', 'wasn', 'way', 'we', 'welcome', 'well', 'went', 'were', 'weren', 'what', 'whatever', 'when', 'whence', 'whenever', 42 | 'where', 'whereafter', 'whereas', 'whereby', 'wherein', 'whereupon', 'wherever', 'whether', 'which', 'while', 'whither', 'who', 'whoever', 'whole', 'whom', 'whose', 'why', 'will', 'willing', 'wish', 'with', 'within', 43 | 'without', 'won', 'wonder', 'would', 'would', 'wouldn', 'x', 'y', 'yes', 'yet', 'you', 'youve', 'your', 'youre', 'yours', 'yourself', 'yourselves', 'z', 'zero'] 44 | 45 | def parser_add_args(self, parser): 46 | parser.add_argument('--model', default='ndrm3', help='model architecture (default: ndrm3)') 47 | parser.add_argument('--max_terms_query', default=20, help='maximum number of terms to consider for query (default: 20)', type=int) 48 | parser.add_argument('--max_terms_doc', default=4000, help='maximum number of terms to consider for long text (default: 4000)', type=int) 49 | parser.add_argument('--max_terms_orcas', default=2000, help='maximum number of terms to consider for long text (default: 2000)', type=int) 50 | parser.add_argument('--num_hidden_nodes', default=256, help='size of hidden layers (default: 256)', type=int) 51 | parser.add_argument('--num_encoder_layers', default=2, help='number of document encoder layers (default: 2)', type=int) 52 | parser.add_argument('--conv_window_size', default=31, help='window size for encoder convolution layer (default: 31)', type=int) 53 | parser.add_argument('--num_attn_heads', default=32, help='number of self-attention heads (default: 32)', type=int) 54 | parser.add_argument('--rbf_kernel_dim', default=10, help='number of RBF kernels (default: 10)', type=int) 55 | 
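# Note: the two pooling arguments below configure the AvgPool1d inside
# RBFKernel (model.py). Kernel activations are mean-pooled over windows of
# 300 document terms with stride 100, rescaled back to sums and
# log-compressed, so a document of max_terms_doc=4000 terms yields
# (4000 - 300) / 100 + 1 = 38 windows; an MLP scores each window and the
# maximum window score is kept per query term.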
parser.add_argument('--rbf_kernel_pool_size', default=300, help='window size for pooling layer in RBF kernels (default: 300)', type=int)
56 | parser.add_argument('--rbf_kernel_pool_stride', default=100, help='stride for pooling layer in RBF kernels (default: 100)', type=int)
57 | parser.add_argument('--drop', default=0.2, help='dropout rate (default: 0.2)', type=float)
58 | parser.add_argument('--file_gen_idfs', default='ndrm-idfs.tsv', help='filename for inverse document frequencies (default: ndrm-idfs.tsv)')
59 | parser.add_argument('--file_gen_embeddings', default='ndrm-embeddings.bin', help='filename for fasttext embeddings (default: ndrm-embeddings.bin)')
60 | parser.add_argument('--orcas_field', help='use orcas data as additional document field', action='store_true')
61 | parser.add_argument('--no_conformer', help='disable conformer layers in the document encoder', action='store_true')
62 | 
63 | def parser_validate_args(self, args):
64 | self.args = args
65 | assert args.max_terms_query > 0, 'maximum number of terms in query must be greater than zero'
66 | assert args.max_terms_doc > 0, 'maximum number of terms in document must be greater than zero'
67 | assert args.num_hidden_nodes % args.num_attn_heads == 0, 'number of hidden nodes should be divisible by the number of attention heads'
68 | assert args.drop >= 0 and args.drop < 1, 'dropout rate must be between 0 and 1'
69 | 
70 | def setup_and_verify(self):
71 | self.__verify_gen_data()
72 | self.__preload_data_to_memory()
73 | 
74 | def clean_text(self, s):
75 | s = self.regex_multi_space.sub(' ', self.regex_drop_char.sub(' ', s.lower())).strip()
76 | s = ' '.join([self.stemmer(t) for t in s.split() if t not in self.stop_words])
77 | return s
78 | 
79 | def tokenize(self, s):
80 | return s.split()
81 | 
82 | def featurize(self, q, ds, infer_mode=False):
83 | q = self.tokenize(q)
84 | max_q_terms = len(q) if infer_mode else self.args.max_terms_query
85 | for i in range(len(ds)):
86 | fields = ds[i]
87 | other_fields = self.tokenize(' '.join(fields[:-1]))
88 | if self.args.orcas_field:
89 | orcas_field = self.tokenize(fields[-1])[:self.args.max_terms_orcas]
90 | ds[i] = [''] + orcas_field + other_fields + ['']
91 | else:
92 | ds[i] = [''] + other_fields + ['']
93 | feat_q, feat_mask_q = self.__get_features_lat(q, max_q_terms)
94 | feat_q = np.asarray(feat_q, dtype=np.int64)
95 | feat_mask_q = np.asarray(feat_mask_q, dtype=np.float32)
96 | if self.args.model != 'ndrm2':
97 | features = [self.__get_features_lat(doc, self.args.max_terms_doc) for doc in ds]
98 | feat_d = [feat[0] for feat in features]
99 | feat_d = np.asarray(feat_d, dtype=np.int64)
100 | feat_mask_d = [feat[1] for feat in features]
101 | feat_mask_d = np.asarray(feat_mask_d, dtype=np.float32)
102 | if self.args.model != 'ndrm1':
103 | feat_qd = [self.__get_features_exp(q, doc, max_q_terms) for doc in ds]
104 | feat_qd = np.asarray(feat_qd, dtype=np.float32)
105 | feat_idf = self.__get_features_idf(q, max_q_terms)
106 | feat_idf = np.asarray(feat_idf, dtype=np.float32)
107 | feat_dlen = self.__get_features_dlen(ds)
108 | feat_dlen = np.asarray(feat_dlen, dtype=np.float32)
109 | if self.args.model == 'ndrm1':
110 | return feat_q, feat_d, feat_mask_q, feat_mask_d
111 | if self.args.model == 'ndrm2':
112 | return feat_qd, feat_mask_q, feat_idf, feat_dlen
113 | return feat_q, feat_d, feat_qd, feat_mask_q, feat_mask_d, feat_idf, feat_dlen
114 | 
115 | def __verify_gen_data(self):
116 | self.printer.print('verifying model specific input data')
117 | for k, file_name in vars(self.args).items():
118 | if
k.startswith('file_gen_'):
119 | file_path = os.path.join(self.args.local_dir, file_name)
120 | if not os.path.exists(file_path):
121 | if k == 'file_gen_embeddings':
122 | self.__generate_embeddings(file_path)
123 | elif k == 'file_gen_idfs':
124 | self.__generate_idfs(file_path)
125 | 
126 | def __generate_embeddings(self, file_path):
127 | self.printer.print('generating fasttext term embeddings')
128 | tmp_file = os.path.join(self.args.local_dir, 'tmp')
129 | with open(tmp_file, 'w', encoding='utf8') as f_out:
130 | with open(os.path.join(self.args.local_dir, self.args.file_in_qs_train), 'rt', encoding='utf8') as f_in:
131 | reader = csv.reader(f_in, delimiter='\t')
132 | for [_, q] in reader:
133 | f_out.write(q)
134 | f_out.write('\n')
135 | with open(os.path.join(self.args.local_dir, self.args.file_in_docs), 'rt', encoding='utf8') as f_in:
136 | reader = csv.reader(f_in, delimiter='\t')
137 | for row in reader:
138 | f_out.write('\n'.join(row[1:]))
139 | f_out.write('\n')
140 | self.printer.print('training fasttext term embeddings')
141 | embeddings = fasttext.train_unsupervised(tmp_file, model='skipgram', dim=self.args.num_hidden_nodes // 2, bucket=10000, minCount=100, minn=1, maxn=0, ws=10, epoch=5)
142 | embeddings.save_model(file_path)
143 | os.remove(tmp_file)
144 | 
145 | def __generate_idfs(self, file_path):
146 | self.printer.print('generating inverse document frequencies for query terms')
147 | terms_q = set([item for sublist in [self.tokenize(q)[:self.args.max_terms_query] for q in self.parent.data_utils.qs.values()] for item in sublist])
148 | dfs = {term: 0 for term in terms_q}
149 | n = 0
150 | with open(os.path.join(self.args.local_dir, self.args.file_in_docs), 'rt', encoding='utf8') as f:
151 | reader = csv.reader(f, delimiter='\t')
152 | for row in progress.bar(reader, expected_size=self.args.collection_size, every=(self.args.collection_size // 10000)):
153 | terms_d = set().union(*[field.split() for field in row[1:]])
154 | terms = terms_q & terms_d
155 | for term in terms:
156 | dfs[term] += 1
157 | n += 1
158 | idfs = {k: max(math.log((n - v + 0.5) / (v + 0.5)), 0) for k, v in dfs.items()}
159 | idfs = {k: v for k, v in idfs.items() if v > 0}
160 | idfs = sorted(idfs.items(), key=lambda kv: kv[1])
161 | with open(file_path, 'w', encoding='utf8') as f:
162 | for (k, v) in idfs:
163 | f.write('{}\t{}\n'.format(k, v))
164 | 
165 | def __preload_data_to_memory(self):
166 | self.printer.print('preloading model specific data to memory')
167 | self.vocab, self.pretrained_embeddings = self.__get_pretrained_embeddings()
168 | setattr(self.args, 'vocab_size', self.pretrained_embeddings.size()[0])
169 | self.idfs = self.__get_idfs()
170 | 
171 | def __get_pretrained_embeddings(self):
172 | model = fasttext.load_model(os.path.join(self.args.local_dir, self.args.file_gen_embeddings))
173 | embed_size = model.get_input_matrix().shape[1] * 2
174 | self.__clear_line_console()
175 | if self.args.num_hidden_nodes != embed_size:
176 | self.printer.print('error: pretrained embedding size ({}) does not match specified embedding size ({})'.format(embed_size, self.args.num_hidden_nodes))
177 | sys.exit(1)
178 | pretrained_embeddings = torch.cat([torch.FloatTensor(model.get_input_matrix()), torch.FloatTensor(model.get_output_matrix())], dim=1)
179 | pretrained_embeddings = torch.cat([torch.zeros([3, embed_size], dtype=torch.float32), pretrained_embeddings], dim=0)
180 | terms = model.get_words(include_freq=False)
181 | vocab = {'UNK': 0, '': 1, '': 2}
182 | for i in
range(len(terms)): 183 | vocab[terms[i]] = i + 3 184 | return vocab, pretrained_embeddings 185 | 186 | def __get_features_lat(self, terms, max_terms): 187 | terms = terms[:max_terms] 188 | num_terms = len(terms) 189 | num_pad = max_terms - num_terms 190 | features = [self.vocab.get(terms[i], self.vocab['UNK']) for i in range(num_terms)] + [0]*num_pad 191 | masks = [1]*num_terms + [0]*num_pad 192 | return features, masks 193 | 194 | def __get_features_exp(self, q, d, max_q_terms): 195 | q = q[:max_q_terms] 196 | features = [d.count(term) for term in q] 197 | pad_len = max_q_terms - len(q) 198 | features.extend([0]*pad_len) 199 | return features 200 | 201 | def __get_features_dlen(self, ds): 202 | features = [len(d) for d in ds] 203 | return features 204 | 205 | def __get_features_idf(self, terms, max_terms): 206 | terms = terms[:max_terms] 207 | num_terms = len(terms) 208 | num_pad = max_terms - num_terms 209 | features = [self.idfs.get(terms[i], 0) for i in range(num_terms)] + [0]*num_pad 210 | return features 211 | 212 | def __get_idfs(self): 213 | idfs = {} 214 | with open(os.path.join(self.args.local_dir, self.args.file_gen_idfs), 'rt', encoding = 'utf8') as f: 215 | reader = csv.reader(f, delimiter = '\t') 216 | for [term, idf] in reader: 217 | idfs[term] = float(idf) 218 | return idfs 219 | 220 | def __clear_line_console(self): 221 | sys.stdout.write("\033[F") 222 | sys.stdout.write("\033[K") 223 | -------------------------------------------------------------------------------- /parallel.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang, Rutgers University, Email: zhang.hang@rutgers.edu 3 | ## Modified by Thomas Wolf, HuggingFace Inc., Email: thomas@huggingface.co 4 | ## Copyright (c) 2017-2018 5 | ## 6 | ## This source code is licensed under the MIT-style license found in the 7 | ## LICENSE file in the root directory of this source tree 8 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 9 | 10 | """Encoding Data Parallel""" 11 | import threading 12 | import functools 13 | import torch 14 | from torch.autograd import Variable, Function 15 | import torch.cuda.comm as comm 16 | from torch.nn.parallel.data_parallel import DataParallel 17 | from torch.nn.parallel.parallel_apply import get_a_var 18 | from torch.nn.parallel.scatter_gather import gather 19 | from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast 20 | from torch.nn.parallel.distributed import DistributedDataParallel 21 | 22 | torch_ver = torch.__version__[:3] 23 | 24 | __all__ = ['allreduce', 'DataParallelModel', 'DataParallelCriterion', 25 | 'patch_replication_callback'] 26 | 27 | def allreduce(*inputs): 28 | """Cross GPU all reduce autograd operation for calculate mean and 29 | variance in SyncBN. 
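    Illustrative usage (hypothetical tensors; `num_inputs` tensors are passed
    per GPU, here two statistics on each of two devices):

        >>> outs = allreduce(2, sum_0, sqsum_0, sum_1, sqsum_1)
        >>> sum_0, sqsum_0, sum_1, sqsum_1 = outs  # reduced, then re-broadcast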
30 | """ 31 | return AllReduce.apply(*inputs) 32 | 33 | class AllReduce(Function): 34 | @staticmethod 35 | def forward(ctx, num_inputs, *inputs): 36 | ctx.num_inputs = num_inputs 37 | ctx.target_gpus = [inputs[i].get_device() for i in range(0, len(inputs), num_inputs)] 38 | inputs = [inputs[i:i + num_inputs] 39 | for i in range(0, len(inputs), num_inputs)] 40 | # sort before reduce sum 41 | inputs = sorted(inputs, key=lambda i: i[0].get_device()) 42 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 43 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 44 | return tuple([t for tensors in outputs for t in tensors]) 45 | 46 | @staticmethod 47 | def backward(ctx, *inputs): 48 | inputs = [i.data for i in inputs] 49 | inputs = [inputs[i:i + ctx.num_inputs] 50 | for i in range(0, len(inputs), ctx.num_inputs)] 51 | results = comm.reduce_add_coalesced(inputs, ctx.target_gpus[0]) 52 | outputs = comm.broadcast_coalesced(results, ctx.target_gpus) 53 | return (None,) + tuple([Variable(t) for tensors in outputs for t in tensors]) 54 | 55 | 56 | class Reduce(Function): 57 | @staticmethod 58 | def forward(ctx, *inputs): 59 | ctx.target_gpus = [inputs[i].get_device() for i in range(len(inputs))] 60 | inputs = sorted(inputs, key=lambda i: i.get_device()) 61 | return comm.reduce_add(inputs) 62 | 63 | @staticmethod 64 | def backward(ctx, gradOutput): 65 | return Broadcast.apply(ctx.target_gpus, gradOutput) 66 | 67 | class DistributedDataParallelModel(DistributedDataParallel): 68 | """Implements data parallelism at the module level for the DistributedDataParallel module. 69 | This container parallelizes the application of the given module by 70 | splitting the input across the specified devices by chunking in the 71 | batch dimension. 72 | In the forward pass, the module is replicated on each device, 73 | and each replica handles a portion of the input. During the backwards pass, 74 | gradients from each replica are summed into the original module. 75 | Note that the outputs are not gathered, please use compatible 76 | :class:`encoding.parallel.DataParallelCriterion`. 77 | The batch size should be larger than the number of GPUs used. It should 78 | also be an integer multiple of the number of GPUs so that each chunk is 79 | the same size (so that each GPU processes the same number of samples). 80 | Args: 81 | module: module to be parallelized 82 | device_ids: CUDA devices (default: all devices) 83 | Reference: 84 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi, 85 | Amit Agrawal. “Context Encoding for Semantic Segmentation. 86 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018* 87 | Example:: 88 | >>> net = encoding.nn.DistributedDataParallelModel(model, device_ids=[0, 1, 2]) 89 | >>> y = net(x) 90 | """ 91 | def gather(self, outputs, output_device): 92 | return outputs 93 | 94 | class DataParallelModel(DataParallel): 95 | """Implements data parallelism at the module level. 96 | 97 | This container parallelizes the application of the given module by 98 | splitting the input across the specified devices by chunking in the 99 | batch dimension. 100 | In the forward pass, the module is replicated on each device, 101 | and each replica handles a portion of the input. During the backwards pass, 102 | gradients from each replica are summed into the original module. 103 | Note that the outputs are not gathered, please use compatible 104 | :class:`encoding.parallel.DataParallelCriterion`. 
105 | 
106 | The batch size should be larger than the number of GPUs used. It should
107 | also be an integer multiple of the number of GPUs so that each chunk is
108 | the same size (so that each GPU processes the same number of samples).
109 | 
110 | Args:
111 | module: module to be parallelized
112 | device_ids: CUDA devices (default: all devices)
113 | 
114 | Reference:
115 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
116 | Amit Agrawal. "Context Encoding for Semantic Segmentation."
117 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
118 | 
119 | Example::
120 | 
121 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
122 | >>> y = net(x)
123 | """
124 | def gather(self, outputs, output_device):
125 | return outputs
126 | 
127 | def replicate(self, module, device_ids):
128 | modules = super(DataParallelModel, self).replicate(module, device_ids)
129 | execute_replication_callbacks(modules)
130 | return modules
131 | 
132 | 
133 | class DataParallelCriterion(DataParallel):
134 | """
135 | Calculates the loss on multiple GPUs, which balances the memory usage.
136 | The targets are split across the specified devices by chunking in
137 | the batch dimension. Please use together with :class:`encoding.parallel.DataParallelModel`.
138 | 
139 | Reference:
140 | Hang Zhang, Kristin Dana, Jianping Shi, Zhongyue Zhang, Xiaogang Wang, Ambrish Tyagi,
141 | Amit Agrawal. "Context Encoding for Semantic Segmentation."
142 | *The IEEE Conference on Computer Vision and Pattern Recognition (CVPR) 2018*
143 | 
144 | Example::
145 | 
146 | >>> net = encoding.nn.DataParallelModel(model, device_ids=[0, 1, 2])
147 | >>> criterion = encoding.nn.DataParallelCriterion(criterion, device_ids=[0, 1, 2])
148 | >>> y = net(x)
149 | >>> loss = criterion(y, target)
150 | """
151 | def forward(self, inputs, *targets, **kwargs):
152 | # input should already be scattered
153 | # scattering the targets instead
154 | if not self.device_ids:
155 | return self.module(inputs, *targets, **kwargs)
156 | targets, kwargs = self.scatter(targets, kwargs, self.device_ids)
157 | if len(self.device_ids) == 1:
158 | return self.module(inputs, *targets[0], **kwargs[0])
159 | replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
160 | outputs = _criterion_parallel_apply(replicas, inputs, targets, kwargs)
161 | #return Reduce.apply(*outputs) / len(outputs)
162 | return self.gather(outputs, self.output_device).mean()
163 | #return self.gather(outputs, self.output_device)
164 | 
165 | 
166 | def _criterion_parallel_apply(modules, inputs, targets, kwargs_tup=None, devices=None):
167 | assert len(modules) == len(inputs)
168 | assert len(targets) == len(inputs)
169 | if kwargs_tup:
170 | assert len(modules) == len(kwargs_tup)
171 | else:
172 | kwargs_tup = ({},) * len(modules)
173 | if devices is not None:
174 | assert len(modules) == len(devices)
175 | else:
176 | devices = [None] * len(modules)
177 | 
178 | lock = threading.Lock()
179 | results = {}
180 | if torch_ver != "0.3":
181 | grad_enabled = torch.is_grad_enabled()
182 | 
183 | def _worker(i, module, input, target, kwargs, device=None):
184 | if torch_ver != "0.3":
185 | torch.set_grad_enabled(grad_enabled)
186 | if device is None:
187 | device = get_a_var(input).get_device()
188 | try:
189 | with torch.cuda.device(device):
190 | # this also avoids accidental slicing of `input` if it is a Tensor
191 | if not isinstance(input, (list, tuple)):
192 | input = (input,)
193 | if not isinstance(target,
194 |                     target = (target,)
195 |                 output = module(*(input + target), **kwargs)
196 |             with lock:
197 |                 results[i] = output
198 |         except Exception as e:
199 |             with lock:
200 |                 results[i] = e
201 | 
202 |     if len(modules) > 1:
203 |         threads = [threading.Thread(target=_worker,
204 |                                     args=(i, module, input, target,
205 |                                           kwargs, device),)
206 |                    for i, (module, input, target, kwargs, device) in
207 |                    enumerate(zip(modules, inputs, targets, kwargs_tup, devices))]
208 | 
209 |         for thread in threads:
210 |             thread.start()
211 |         for thread in threads:
212 |             thread.join()
213 |     else:
214 |         _worker(0, modules[0], inputs[0], targets[0], kwargs_tup[0], devices[0])
215 | 
216 |     outputs = []
217 |     for i in range(len(inputs)):
218 |         output = results[i]
219 |         if isinstance(output, Exception):
220 |             raise output
221 |         outputs.append(output)
222 |     return outputs
223 | 
224 | 
225 | ###########################################################################
226 | # Adapted from Synchronized-BatchNorm-PyTorch.
227 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
228 | #
229 | class CallbackContext(object):
230 |     pass
231 | 
232 | 
233 | def execute_replication_callbacks(modules):
234 |     """
235 |     Execute a replication callback `__data_parallel_replicate__` on each module created
236 |     by the original replication.
237 | 
238 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
239 | 
240 |     Note that, as all replicated modules are isomorphic, we assign a context to each
241 |     sub-module (shared among the copies of this module on different devices).
242 |     Through this context, different copies can share some information.
243 | 
244 |     We guarantee that the callback on the master copy (the first copy) will be called
245 |     before the callbacks on any slave copies.
246 |     """
247 |     master_copy = modules[0]
248 |     nr_modules = len(list(master_copy.modules()))
249 |     ctxs = [CallbackContext() for _ in range(nr_modules)]
250 | 
251 |     for i, module in enumerate(modules):
252 |         for j, m in enumerate(module.modules()):
253 |             if hasattr(m, '__data_parallel_replicate__'):
254 |                 m.__data_parallel_replicate__(ctxs[j], i)
255 | 
256 | 
257 | def patch_replication_callback(data_parallel):
258 |     """
259 |     Monkey-patch an existing `DataParallel` object to add the replication callback.
260 |     Useful when you have a customized `DataParallel` implementation.
261 | 
262 |     Examples:
263 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
264 |         > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
265 |         > patch_replication_callback(sync_bn)
266 |         # this is equivalent to
267 |         > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
268 |         > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
269 |     """
270 | 
271 |     assert isinstance(data_parallel, DataParallel)
272 | 
273 |     old_replicate = data_parallel.replicate
274 | 
275 |     @functools.wraps(old_replicate)
276 |     def new_replicate(module, device_ids):
277 |         modules = old_replicate(module, device_ids)
278 |         execute_replication_callbacks(modules)
279 |         return modules
280 | 
281 |     data_parallel.replicate = new_replicate
282 | 
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
 1 | from learner import Learner
 2 | from utils import Utils, Printer
 3 | 
 4 | 
 5 | def main():
 6 |     printer = Printer('log.txt')
 7 |     utils = Utils(printer)
 8 |     utils.setup_and_verify()
 9 |     utils.evaluate_baseline()
10 |     learner = Learner(utils.learner_utils)
11 |     learner.train_and_evaluate()
12 |     utils.printer.print('finished!')
13 | 
14 | if __name__ == '__main__':
15 |     main()
16 | 
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import sys
  3 | import csv
  4 | import math
  5 | import torch
  6 | import argparse
  7 | import datetime
  8 | from factory import Factory
  9 | from data_utils import DataUtils
 10 | from model_utils import NDRMUtils
 11 | from learner_utils import LearnerUtils
 12 | 
 13 | 
 14 | class Printer:
 15 | 
 16 |     def __init__(self, file_path):
 17 |         self.log = open(file_path, mode='w', encoding='utf-8')
 18 | 
 19 |     def print(self, s, end='\n', suppress_timestamp=False):
 20 |         if not suppress_timestamp:
 21 |             msg = '[{}]\t{}{}'.format(datetime.datetime.now().strftime('%b %d, %H:%M:%S'), s, end)
 22 |         else:
 23 |             msg = '\t{}{}'.format(s, end)
 24 |         print(msg, flush=True, end='')
 25 |         self.log.write(msg)
 26 |         self.log.flush()
 27 | 
 28 |     def __enter__(self):
 29 |         return self
 30 | 
 31 |     def __exit__(self, exc_type, exc_value, traceback):
 32 |         self.log.close()
 33 | 
 34 | 
 35 | class Utils:
 36 | 
 37 |     def __init__(self, printer=None):
 38 |         if printer is None:
 39 |             self.printer = Printer('log.txt')
 40 |         else:
 41 |             self.printer = printer
 42 | 
 43 |     def setup_and_verify(self):
 44 |         parser = argparse.ArgumentParser(description='TREC 2019 Deep Learning Track (document re-ranking task)')
 45 |         torch.set_printoptions(threshold=500)
 46 |         self.data_utils = DataUtils(self.printer)
 47 |         self.model_utils = NDRMUtils(self.printer)
 48 |         self.learner_utils = LearnerUtils(self.printer)
 49 |         self.sub_utils = [self.data_utils, self.model_utils, self.learner_utils]
 50 |         for sub_utils in self.sub_utils:
 51 |             sub_utils.parent = self
 52 |             sub_utils.parser_add_args(parser)
 53 |         self.args = parser.parse_args()
 54 |         for sub_utils in self.sub_utils:
 55 |             sub_utils.parser_validate_args(self.args)
 56 |         self.__print_versions()
 57 |         for sub_utils in self.sub_utils:
 58 |             sub_utils.setup_and_verify()
 59 |         self.__print_args()
 60 | 
 61 |     def evaluate_baseline(self):
 62 |         results_dev = self.get_baseline_results(self.args.file_in_cnd_dev)
 63 |         mrr_dev, _, _ = self.evaluate_results(results_dev, self.data_utils.qrels_dev)
 64 |         results_val = self.get_baseline_results(self.args.file_in_cnd_val)
 65 |         mrr_val, ncg_val, ndcg_val = self.evaluate_results(results_val, self.data_utils.qrels_val)
 66 |         self.printer.print('baseline\tdev mrr: {:.3f}\tval mrr: {:.3f}\tval ncg: {:.3f}\tval ndcg: {:.3f}'.format(mrr_dev, mrr_val, ncg_val, ndcg_val))
 67 | 
 68 |     def get_baseline_results(self, cnd_file):
 69 |         results = {}
 70 |         with open(os.path.join(self.args.local_dir, cnd_file), 'rt', encoding='utf8') as f:
 71 |             reader = csv.reader(f, delimiter=' ')
 72 |             for [qid, _, did, rank, _, _] in reader:
 73 |                 rank = int(rank)
 74 |                 if qid not in results:
 75 |                     results[qid] = []
 76 |                 results[qid].append((did, -rank))
 77 |         results = {qid: sorted(docs, key=lambda x: x[1], reverse=True)[:self.args.max_metric_pos_nodisc] for qid, docs in results.items()}
 78 |         return results
 79 | 
 80 |     def evaluate_results(self, results, qrels):
 81 |         mrr = 0
 82 |         ncg = 0
 83 |         ndcg = 0
 84 |         for qid, docs in results.items():
 85 |             if qid not in qrels:
 86 |                 continue
 87 |             qrels_q = qrels[qid]
 88 |             gains = [qrels_q.get(doc[0], 0) for doc in docs]
 89 |             ideal_gains = sorted(qrels_q.values(), reverse=True)[:self.args.max_metric_pos_nodisc]
 90 |             max_metric_pos_disc = min(len(gains), self.args.max_metric_pos)
 91 |             max_metric_pos_disc_ideal = min(len(ideal_gains), self.args.max_metric_pos)
 92 |             cg = sum(gains)
 93 |             dcg = sum([gains[i] / math.log2(i + 2) for i in range(max_metric_pos_disc)])
 94 |             ideal_cg = sum(ideal_gains)
 95 |             ideal_dcg = sum([ideal_gains[i] / math.log2(i + 2) for i in range(max_metric_pos_disc_ideal)])
 96 |             ncg += cg / ideal_cg if ideal_cg > 0 else 0
 97 |             ndcg += dcg / ideal_dcg if ideal_dcg > 0 else 0
 98 |             try:
 99 |                 mrr += 1 / ([min(gain, 1) for gain in gains][:max_metric_pos_disc].index(1) + 1)
100 |             except ValueError:
101 |                 # no relevant document retrieved for this query
102 |                 pass
103 |         mrr /= len(qrels)
104 |         ncg /= len(qrels)
105 |         ndcg /= len(qrels)
106 |         return mrr, ncg, ndcg
107 | 
108 |     def __print_args(self):
109 |         self.printer.print('Running with the following specified and inferred arguments:')
110 |         for key, value in self.args._get_kwargs():
111 |             if value is not None:
112 |                 if isinstance(value, int) and not isinstance(value, bool):
113 |                     self.printer.print('\t{:<40}{:,}'.format(key, value))
114 |                 else:
115 |                     self.printer.print('\t{:<40}{}'.format(key, value))
116 | 
117 |     def __print_versions(self):
118 |         self.printer.print('Python version: {}'.format(sys.version.replace('\n', '')))
119 |         self.printer.print('PyTorch version: {}'.format(torch.__version__))
120 |         self.printer.print('CUDA version: {}'.format(torch.version.cuda))
121 | 
--------------------------------------------------------------------------------
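The metric arithmetic in Utils.evaluate_results above can be sanity-checked in isolation. Below is a minimal standalone sketch of the same per-query MRR/NCG/NDCG computation on a hypothetical toy query: qrels_q and ranked are invented stand-ins for qrels[qid] and a candidate list, and for brevity a single cutoff max_pos stands in for both max_metric_pos (discounted) and max_metric_pos_nodisc, which the class distinguishes.

import math

qrels_q = {'D1': 1, 'D2': 0, 'D3': 2}  # graded relevance for one query (hypothetical)
ranked = ['D9', 'D3', 'D1']            # system ranking, best first (hypothetical)
max_pos = 3                            # single cutoff standing in for max_metric_pos

gains = [qrels_q.get(d, 0) for d in ranked]               # [0, 2, 1]
ideal = sorted(qrels_q.values(), reverse=True)[:max_pos]  # [2, 1, 0]

dcg = sum(g / math.log2(i + 2) for i, g in enumerate(gains[:max_pos]))
idcg = sum(g / math.log2(i + 2) for i, g in enumerate(ideal))
ndcg = dcg / idcg if idcg > 0 else 0
ncg = sum(gains) / sum(ideal) if sum(ideal) > 0 else 0

# reciprocal rank of the first relevant document (gains clipped to 0/1)
try:
    mrr = 1 / ([min(g, 1) for g in gains[:max_pos]].index(1) + 1)
except ValueError:
    mrr = 0

print('mrr={:.3f}  ncg={:.3f}  ndcg={:.3f}'.format(mrr, ncg, ndcg))
# -> mrr=0.500  ncg=1.000  ndcg=0.670

These are the per-query values that evaluate_results would accumulate before dividing by len(qrels) to average over all judged queries.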