├── .gitignore
├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── biencoder.py
├── docs
│   └── wsd_biencoder_architecture.jpg
├── finetune_pretrained_encoder.py
├── frozen_pretrained_encoder.py
└── wsd_models
    ├── models.py
    └── util.py

/.gitignore:
--------------------------------------------------------------------------------
1 | 
2 | # Byte-compiled / optimized / DLL files
3 | __pycache__/
4 | *.py[cod]
5 | *$py.class
6 | 
7 | # C extensions
8 | *.so
9 | 
10 | # Distribution / packaging
11 | .Python
12 | build/
13 | develop-eggs/
14 | dist/
15 | downloads/
16 | eggs/
17 | .eggs/
18 | lib/
19 | lib64/
20 | parts/
21 | sdist/
22 | var/
23 | wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 | 
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 | 
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 | 
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .coverage
43 | .coverage.*
44 | .cache
45 | nosetests.xml
46 | coverage.xml
47 | *.cover
48 | .hypothesis/
49 | .pytest_cache/
50 | 
51 | # Translations
52 | *.mo
53 | *.pot
54 | 
55 | # Django stuff:
56 | *.log
57 | local_settings.py
58 | db.sqlite3
59 | 
60 | # Flask stuff:
61 | instance/
62 | .webassets-cache
63 | 
64 | # Scrapy stuff:
65 | .scrapy
66 | 
67 | # Sphinx documentation
68 | docs/_build/
69 | 
70 | # PyBuilder
71 | target/
72 | 
73 | # Jupyter Notebook
74 | .ipynb_checkpoints
75 | 
76 | # pyenv
77 | .python-version
78 | 
79 | # celery beat schedule file
80 | celerybeat-schedule
81 | 
82 | # SageMath parsed files
83 | *.sage.py
84 | 
85 | # Environments
86 | .env
87 | .venv
88 | env/
89 | venv/
90 | ENV/
91 | env.bak/
92 | venv.bak/
93 | 
94 | # Spyder project settings
95 | .spyderproject
96 | .spyproject
97 | 
98 | # Rope project settings
99 | .ropeproject
100 | 
101 | # mkdocs documentation
102 | /site
103 | 
104 | # mypy
105 | .mypy_cache/
106 | 
107 | #Macs
108 | .DS_Store
--------------------------------------------------------------------------------
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | # Code of Conduct
2 | 
3 | ## Our Pledge
4 | 
5 | In the interest of fostering an open and welcoming environment, we as
6 | contributors and maintainers pledge to make participation in our project and
7 | our community a harassment-free experience for everyone, regardless of age, body
8 | size, disability, ethnicity, sex characteristics, gender identity and expression,
9 | level of experience, education, socio-economic status, nationality, personal
10 | appearance, race, religion, or sexual identity and orientation.
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `wsd-biencoders` 2 | We want to make contributing to this codebase as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. 
Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | ## License 26 | By contributing to `wsd-biencoders`, you agree that your contributions will be licensed 27 | under the LICENSE file in the root directory of this source tree. 28 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. 
Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More_considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | Section 1 -- Definitions. 71 | 72 | a. Adapted Material means material subject to Copyright and Similar 73 | Rights that is derived from or based upon the Licensed Material 74 | and in which the Licensed Material is translated, altered, 75 | arranged, transformed, or otherwise modified in a manner requiring 76 | permission under the Copyright and Similar Rights held by the 77 | Licensor. For purposes of this Public License, where the Licensed 78 | Material is a musical work, performance, or sound recording, 79 | Adapted Material is always produced where the Licensed Material is 80 | synched in timed relation with a moving image. 81 | 82 | b. Adapter's License means the license You apply to Your Copyright 83 | and Similar Rights in Your contributions to Adapted Material in 84 | accordance with the terms and conditions of this Public License. 85 | 86 | c. Copyright and Similar Rights means copyright and/or similar rights 87 | closely related to copyright including, without limitation, 88 | performance, broadcast, sound recording, and Sui Generis Database 89 | Rights, without regard to how the rights are labeled or 90 | categorized. For purposes of this Public License, the rights 91 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 92 | Rights. 93 | d. Effective Technological Measures means those measures that, in the 94 | absence of proper authority, may not be circumvented under laws 95 | fulfilling obligations under Article 11 of the WIPO Copyright 96 | Treaty adopted on December 20, 1996, and/or similar international 97 | agreements. 98 | 99 | e. Exceptions and Limitations means fair use, fair dealing, and/or 100 | any other exception or limitation to Copyright and Similar Rights 101 | that applies to Your use of the Licensed Material. 102 | 103 | f. Licensed Material means the artistic or literary work, database, 104 | or other material to which the Licensor applied this Public 105 | License. 106 | 107 | g. 
Licensed Rights means the rights granted to You subject to the 108 | terms and conditions of this Public License, which are limited to 109 | all Copyright and Similar Rights that apply to Your use of the 110 | Licensed Material and that the Licensor has authority to license. 111 | 112 | h. Licensor means the individual(s) or entity(ies) granting rights 113 | under this Public License. 114 | 115 | i. NonCommercial means not primarily intended for or directed towards 116 | commercial advantage or monetary compensation. For purposes of 117 | this Public License, the exchange of the Licensed Material for 118 | other material subject to Copyright and Similar Rights by digital 119 | file-sharing or similar means is NonCommercial provided there is 120 | no payment of monetary compensation in connection with the 121 | exchange. 122 | 123 | j. Share means to provide material to the public by any means or 124 | process that requires permission under the Licensed Rights, such 125 | as reproduction, public display, public performance, distribution, 126 | dissemination, communication, or importation, and to make material 127 | available to the public including in ways that members of the 128 | public may access the material from a place and at a time 129 | individually chosen by them. 130 | 131 | k. Sui Generis Database Rights means rights other than copyright 132 | resulting from Directive 96/9/EC of the European Parliament and of 133 | the Council of 11 March 1996 on the legal protection of databases, 134 | as amended and/or succeeded, as well as other essentially 135 | equivalent rights anywhere in the world. 136 | 137 | l. You means the individual or entity exercising the Licensed Rights 138 | under this Public License. Your has a corresponding meaning. 139 | 140 | Section 2 -- Scope. 141 | 142 | a. License grant. 143 | 144 | 1. Subject to the terms and conditions of this Public License, 145 | the Licensor hereby grants You a worldwide, royalty-free, 146 | non-sublicensable, non-exclusive, irrevocable license to 147 | exercise the Licensed Rights in the Licensed Material to: 148 | 149 | a. reproduce and Share the Licensed Material, in whole or 150 | in part, for NonCommercial purposes only; and 151 | 152 | b. produce, reproduce, and Share Adapted Material for 153 | NonCommercial purposes only. 154 | 155 | 2. Exceptions and Limitations. For the avoidance of doubt, where 156 | Exceptions and Limitations apply to Your use, this Public 157 | License does not apply, and You do not need to comply with 158 | its terms and conditions. 159 | 160 | 3. Term. The term of this Public License is specified in Section 161 | 6(a). 162 | 163 | 4. Media and formats; technical modifications allowed. The 164 | Licensor authorizes You to exercise the Licensed Rights in 165 | all media and formats whether now known or hereafter created, 166 | and to make technical modifications necessary to do so. The 167 | Licensor waives and/or agrees not to assert any right or 168 | authority to forbid You from making technical modifications 169 | necessary to exercise the Licensed Rights, including 170 | technical modifications necessary to circumvent Effective 171 | Technological Measures. For purposes of this Public License, 172 | simply making modifications authorized by this Section 2(a) 173 | (4) never produces Adapted Material. 174 | 175 | 5. Downstream recipients. 176 | 177 | a. Offer from the Licensor -- Licensed Material. 
Every 178 | recipient of the Licensed Material automatically 179 | receives an offer from the Licensor to exercise the 180 | Licensed Rights under the terms and conditions of this 181 | Public License. 182 | 183 | b. No downstream restrictions. You may not offer or impose 184 | any additional or different terms or conditions on, or 185 | apply any Effective Technological Measures to, the 186 | Licensed Material if doing so restricts exercise of the 187 | Licensed Rights by any recipient of the Licensed 188 | Material. 189 | 190 | 6. No endorsement. Nothing in this Public License constitutes or 191 | may be construed as permission to assert or imply that You 192 | are, or that Your use of the Licensed Material is, connected 193 | with, or sponsored, endorsed, or granted official status by, 194 | the Licensor or others designated to receive attribution as 195 | provided in Section 3(a)(1)(A)(i). 196 | 197 | b. Other rights. 198 | 199 | 1. Moral rights, such as the right of integrity, are not 200 | licensed under this Public License, nor are publicity, 201 | privacy, and/or other similar personality rights; however, to 202 | the extent possible, the Licensor waives and/or agrees not to 203 | assert any such rights held by the Licensor to the limited 204 | extent necessary to allow You to exercise the Licensed 205 | Rights, but not otherwise. 206 | 207 | 2. Patent and trademark rights are not licensed under this 208 | Public License. 209 | 210 | 3. To the extent possible, the Licensor waives any right to 211 | collect royalties from You for the exercise of the Licensed 212 | Rights, whether directly or through a collecting society 213 | under any voluntary or waivable statutory or compulsory 214 | licensing scheme. In all other cases the Licensor expressly 215 | reserves any right to collect such royalties, including when 216 | the Licensed Material is used other than for NonCommercial 217 | purposes. 218 | 219 | Section 3 -- License Conditions. 220 | 221 | Your exercise of the Licensed Rights is expressly made subject to the 222 | following conditions. 223 | 224 | a. Attribution. 225 | 226 | 1. If You Share the Licensed Material (including in modified 227 | form), You must: 228 | 229 | a. retain the following if it is supplied by the Licensor 230 | with the Licensed Material: 231 | 232 | i. identification of the creator(s) of the Licensed 233 | Material and any others designated to receive 234 | attribution, in any reasonable manner requested by 235 | the Licensor (including by pseudonym if 236 | designated); 237 | 238 | ii. a copyright notice; 239 | 240 | iii. a notice that refers to this Public License; 241 | 242 | iv. a notice that refers to the disclaimer of 243 | warranties; 244 | 245 | v. a URI or hyperlink to the Licensed Material to the 246 | extent reasonably practicable; 247 | 248 | b. indicate if You modified the Licensed Material and 249 | retain an indication of any previous modifications; and 250 | 251 | c. indicate the Licensed Material is licensed under this 252 | Public License, and include the text of, or the URI or 253 | hyperlink to, this Public License. 254 | 255 | 2. You may satisfy the conditions in Section 3(a)(1) in any 256 | reasonable manner based on the medium, means, and context in 257 | which You Share the Licensed Material. For example, it may be 258 | reasonable to satisfy the conditions by providing a URI or 259 | hyperlink to a resource that includes the required 260 | information. 261 | 262 | 3. 
If requested by the Licensor, You must remove any of the 263 | information required by Section 3(a)(1)(A) to the extent 264 | reasonably practicable. 265 | 266 | 4. If You Share Adapted Material You produce, the Adapter's 267 | License You apply must not prevent recipients of the Adapted 268 | Material from complying with this Public License. 269 | 270 | Section 4 -- Sui Generis Database Rights. 271 | 272 | Where the Licensed Rights include Sui Generis Database Rights that 273 | apply to Your use of the Licensed Material: 274 | 275 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 276 | to extract, reuse, reproduce, and Share all or a substantial 277 | portion of the contents of the database for NonCommercial purposes 278 | only; 279 | 280 | b. if You include all or a substantial portion of the database 281 | contents in a database in which You have Sui Generis Database 282 | Rights, then the database in which You have Sui Generis Database 283 | Rights (but not its individual contents) is Adapted Material; and 284 | 285 | c. You must comply with the conditions in Section 3(a) if You Share 286 | all or a substantial portion of the contents of the database. 287 | 288 | For the avoidance of doubt, this Section 4 supplements and does not 289 | replace Your obligations under this Public License where the Licensed 290 | Rights include other Copyright and Similar Rights. 291 | 292 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 293 | 294 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 295 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 296 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 297 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 298 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 299 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 300 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 301 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 302 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 303 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 304 | 305 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 306 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 307 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 308 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 309 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 310 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 311 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 312 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 313 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 314 | 315 | c. The disclaimer of warranties and limitation of liability provided 316 | above shall be interpreted in a manner that, to the extent 317 | possible, most closely approximates an absolute disclaimer and 318 | waiver of all liability. 319 | 320 | Section 6 -- Term and Termination. 321 | 322 | a. This Public License applies for the term of the Copyright and 323 | Similar Rights licensed here. However, if You fail to comply with 324 | this Public License, then Your rights under this Public License 325 | terminate automatically. 326 | 327 | b. Where Your right to use the Licensed Material has terminated under 328 | Section 6(a), it reinstates: 329 | 330 | 1. 
automatically as of the date the violation is cured, provided 331 | it is cured within 30 days of Your discovery of the 332 | violation; or 333 | 334 | 2. upon express reinstatement by the Licensor. 335 | 336 | For the avoidance of doubt, this Section 6(b) does not affect any 337 | right the Licensor may have to seek remedies for Your violations 338 | of this Public License. 339 | 340 | c. For the avoidance of doubt, the Licensor may also offer the 341 | Licensed Material under separate terms or conditions or stop 342 | distributing the Licensed Material at any time; however, doing so 343 | will not terminate this Public License. 344 | 345 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 346 | License. 347 | 348 | Section 7 -- Other Terms and Conditions. 349 | 350 | a. The Licensor shall not be bound by any additional or different 351 | terms or conditions communicated by You unless expressly agreed. 352 | 353 | b. Any arrangements, understandings, or agreements regarding the 354 | Licensed Material not stated herein are separate from and 355 | independent of the terms and conditions of this Public License. 356 | 357 | Section 8 -- Interpretation. 358 | 359 | a. For the avoidance of doubt, this Public License does not, and 360 | shall not be interpreted to, reduce, limit, restrict, or impose 361 | conditions on any use of the Licensed Material that could lawfully 362 | be made without permission under this Public License. 363 | 364 | b. To the extent possible, if any provision of this Public License is 365 | deemed unenforceable, it shall be automatically reformed to the 366 | minimum extent necessary to make it enforceable. If the provision 367 | cannot be reformed, it shall be severed from this Public License 368 | without affecting the enforceability of the remaining terms and 369 | conditions. 370 | 371 | c. No term or condition of this Public License will be waived and no 372 | failure to comply consented to unless expressly agreed to by the 373 | Licensor. 374 | 375 | d. Nothing in this Public License constitutes or may be interpreted 376 | as a limitation upon, or waiver of, any privileges and immunities 377 | that apply to the Licensor or You, including from the legal 378 | processes of any jurisdiction or authority. 379 | 380 | ======================================================================= 381 | 382 | Creative Commons is not a party to its public 383 | licenses. Notwithstanding, Creative Commons may elect to apply one of 384 | its public licenses to material it publishes and in those instances 385 | will be considered the “Licensor.” The text of the Creative Commons 386 | public licenses is dedicated to the public domain under the CC0 Public 387 | Domain Dedication. Except for the limited purpose of indicating that 388 | material is shared under a Creative Commons public license or as 389 | otherwise permitted by the Creative Commons policies published at 390 | creativecommons.org/policies, Creative Commons does not authorize the 391 | use of the trademark "Creative Commons" or any other trademark or logo 392 | of Creative Commons without its prior written consent including, 393 | without limitation, in connection with any unauthorized modifications 394 | to any of its public licenses or any other arrangements, 395 | understandings, or agreements concerning use of licensed material. For 396 | the avoidance of doubt, this paragraph does not form part of the 397 | public licenses. 398 | 399 | Creative Commons may be contacted at creativecommons.org. 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Gloss Informed Bi-encoders for WSD
2 | 
3 | This is the codebase for the paper [Moving Down the Long Tail of Word Sense Disambiguation with Gloss Informed Bi-encoders](https://blvns.github.io/papers/acl2020.pdf).
4 | 
5 | ![Architecture of the gloss informed bi-encoder model for WSD](https://github.com/facebookresearch/wsd-biencoders/blob/main/docs/wsd_biencoder_architecture.jpg)
6 | Our bi-encoder model consists of two independent transformer encoders: (1) a context encoder, which represents the target word (and its surrounding context), and (2) a gloss encoder, which embeds the definition text for each word sense. Each encoder is initialized with a pretrained model and optimized independently.
7 | 
8 | ## Dependencies
9 | To run this code, you'll need the following libraries:
10 | * [Python 3](https://www.python.org/)
11 | * [PyTorch 1.2.0](https://pytorch.org/)
12 | * [PyTorch Transformers 1.1.0](https://github.com/huggingface/transformers)
13 | * [NumPy 1.17.2](https://numpy.org/)
14 | * [NLTK 3.4.5](https://www.nltk.org/)
15 | * [tqdm](https://tqdm.github.io/)
16 | 
17 | We used the [WSD Evaluation Framework](http://lcl.uniroma1.it/wsdeval/) for training and evaluating our model.
18 | 
19 | ## How to Run
20 | To train a biencoder model, run `python biencoder.py --data-path $path_to_wsd_data --ckpt $path_to_checkpoint`. The required arguments are: `--data-path`, which is the filepath to the top-level directory of the WSD Evaluation Framework; and `--ckpt`, which is the filepath of the directory to which to save the trained model checkpoints and prediction files. The `Scorer.java` in the WSD Framework data files needs to be compiled, with the resulting `Scorer.class` file kept in the original directory of the Scorer file.
21 | 
22 | It is recommended you train this model using the `--multigpu` flag to enable model parallelism (note that this requires two available GPUs). More hyperparameter options are available as arguments; run `python biencoder.py -h` for all possible arguments.
23 | 
24 | To evaluate an existing biencoder, run `python biencoder.py --data-path $path_to_wsd_data --ckpt $path_to_model_checkpoint --eval --split $wsd_eval_set`. Without `--split`, this defaults to evaluating on the development set, semeval2007. The model weights and predictions for the biencoder reported in the paper can be found [here](https://drive.google.com/file/d/1NZX_eMHQfRHhJnoJwEx2GnbnYIQepIQj).
25 | 
26 | Similar commands can be used to run the frozen probe for WSD (`frozen_pretrained_encoder.py`) and to finetune a pretrained, single-encoder classifier for WSD (`finetune_pretrained_encoder.py`).
27 | 
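For concreteness, a typical train-then-evaluate session might look like the following sketch; the framework and checkpoint paths below are placeholders:

```
# train on SemCor, with the context and gloss encoders on separate GPUs
python biencoder.py --data-path ./WSD_Evaluation_Framework --ckpt ./ckpt/biencoder --multigpu

# evaluate the best saved checkpoint on the full evaluation suite
python biencoder.py --data-path ./WSD_Evaluation_Framework --ckpt ./ckpt/biencoder --eval --split ALL
```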
28 | ## Citation
29 | If you use this work, please cite the corresponding [paper](https://blvns.github.io/papers/acl2020.pdf):
30 | ```
31 | @inproceedings{
32 | blevins2020wsd,
33 | title={Moving Down the Long Tail of Word Sense Disambiguation with Gloss Informed Bi-encoders},
34 | author={Terra Blevins and Luke Zettlemoyer},
35 | booktitle={Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics},
36 | year={2020},
37 | url={https://blvns.github.io/papers/acl2020.pdf}
38 | }
39 | ```
40 | 
41 | ## Contact
42 | Please address any questions or comments about this codebase to blvns@cs.washington.edu. If you want to suggest changes or improvements, please check out the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out.
43 | 
44 | ## License
45 | This codebase is licensed under the Attribution-NonCommercial 4.0 International license, as found in the [LICENSE](https://github.com/facebookresearch/wsd-biencoders/blob/master/LICENSE) file.
46 | 
--------------------------------------------------------------------------------
/biencoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | All rights reserved.
4 | 
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | '''
8 | 
9 | import torch
10 | from torch.nn import functional as F
11 | from nltk.corpus import wordnet as wn
12 | import os
13 | import sys
14 | import time
15 | import math
16 | import copy
17 | import argparse
18 | from tqdm import tqdm
19 | import pickle
20 | from pytorch_transformers import *
21 | 
22 | import random
23 | import numpy as np
24 | 
25 | from wsd_models.util import *
26 | from wsd_models.models import BiEncoderModel
27 | 
28 | parser = argparse.ArgumentParser(description='Gloss Informed Bi-encoder for WSD')
29 | 
30 | #training arguments
31 | parser.add_argument('--rand_seed', type=int, default=42)
32 | parser.add_argument('--grad-norm', type=float, default=1.0)
33 | parser.add_argument('--silent', action='store_true',
34 |     help='Flag to suppress training progress bar for each epoch')
35 | parser.add_argument('--multigpu', action='store_true')
36 | parser.add_argument('--lr', type=float, default=0.00001)
37 | parser.add_argument('--warmup', type=int, default=10000)
38 | parser.add_argument('--context-max-length', type=int, default=128)
39 | parser.add_argument('--gloss-max-length', type=int, default=32)
40 | parser.add_argument('--epochs', type=int, default=20)
41 | parser.add_argument('--context-bsz', type=int, default=4)
42 | parser.add_argument('--gloss-bsz', type=int, default=256)
43 | parser.add_argument('--encoder-name', type=str, default='bert-base',
44 |     choices=['bert-base', 'bert-large', 'roberta-base', 'roberta-large'])
45 | parser.add_argument('--ckpt', type=str, required=True,
46 |     help='filepath of the directory in which to save the best model checkpoint (on dev set)')
47 | parser.add_argument('--data-path', type=str, required=True,
48 |     help='Location of top-level directory of the WSD Evaluation Framework')
49 | 
50 | #sets which parts of the model to freeze ❄️ during training for ablation
51 | parser.add_argument('--freeze-context', action='store_true')
52 | parser.add_argument('--freeze-gloss', action='store_true')
53 | parser.add_argument('--tie-encoders', action='store_true')
54 | 
55 | #other training settings flags
56 | parser.add_argument('--kshot', type=int, default=-1,
57 |     help='if set to k (1+), will filter training data to only have up to k examples per sense')
58 | parser.add_argument('--balanced', action='store_true',
59 |     help='flag for whether or not to reweight sense losses to be balanced wrt the target word')
60 | 
61 | #evaluation arguments
62 | parser.add_argument('--eval', action='store_true',
63 |     help='Flag to set script to evaluate probe (rather than train)')
64 | parser.add_argument('--split', type=str, default='semeval2007',
65 |     choices=['semeval2007', 'senseval2', 'senseval3', 'semeval2013', 'semeval2015', 'ALL', 'all-test'],
66 |     help='Which evaluation split on which to evaluate probe')
67 | 
68 | #uses these two gpus if training in multi-gpu
69 | context_device = "cuda:0"
70 | gloss_device = "cuda:1"
71 | 
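#note on --multigpu (model parallelism): the context encoder is placed on cuda:0 and the
#gloss encoder on cuda:1; in _train/_eval below, their outputs are moved to the CPU before
#the similarity matmul, so the two halves of the model never need to share a device.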
72 | def tokenize_glosses(gloss_arr, tokenizer, max_len):
73 |     glosses = []
74 |     masks = []
75 |     for gloss_text in gloss_arr:
76 |         g_ids = [torch.tensor([[x]]) for x in tokenizer.encode(tokenizer.cls_token)+tokenizer.encode(gloss_text)+tokenizer.encode(tokenizer.sep_token)]
77 |         g_attn_mask = [1]*len(g_ids)
78 |         g_fake_mask = [-1]*len(g_ids)
79 |         g_ids, g_attn_mask, _ = normalize_length(g_ids, g_attn_mask, g_fake_mask, max_len, pad_id=tokenizer.encode(tokenizer.pad_token)[0])
80 |         g_ids = torch.cat(g_ids, dim=-1)
81 |         g_attn_mask = torch.tensor(g_attn_mask)
82 |         glosses.append(g_ids)
83 |         masks.append(g_attn_mask)
84 | 
85 |     return glosses, masks
86 | 
87 | #creates a sense label / gloss dictionary for training/using the gloss encoder
88 | def load_and_preprocess_glosses(data, tokenizer, wn_senses, max_len=-1):
89 |     sense_glosses = {}
90 |     sense_weights = {}
91 | 
92 |     gloss_lengths = []
93 | 
94 |     for sent in data:
95 |         for _, lemma, pos, _, label in sent:
96 |             if label == -1:
97 |                 continue #ignore unlabeled words
98 |             else:
99 |                 key = generate_key(lemma, pos)
100 |                 if key not in sense_glosses:
101 |                     #get all sensekeys for the lemma/pos pair
102 |                     sensekey_arr = wn_senses[key]
103 |                     #get glosses for all candidate senses
104 |                     gloss_arr = [wn.lemma_from_key(s).synset().definition() for s in sensekey_arr]
105 | 
106 |                     #preprocess glosses into tensors
107 |                     gloss_ids, gloss_masks = tokenize_glosses(gloss_arr, tokenizer, max_len)
108 |                     gloss_ids = torch.cat(gloss_ids, dim=0)
109 |                     gloss_masks = torch.stack(gloss_masks, dim=0)
110 |                     sense_glosses[key] = (gloss_ids, gloss_masks, sensekey_arr)
111 | 
112 |                     #initialize weights for balancing senses
113 |                     sense_weights[key] = [0]*len(gloss_arr)
114 |                     w_idx = sensekey_arr.index(label)
115 |                     sense_weights[key][w_idx] += 1
116 |                 else:
117 |                     #update sense weight counts
118 |                     w_idx = sense_glosses[key][2].index(label)
119 |                     sense_weights[key][w_idx] += 1
120 | 
121 |                 #make sure that the gold label is among the retrieved senses
122 |                 assert label in sense_glosses[key][2]
123 | 
124 |     #normalize weights
125 |     for key in sense_weights:
126 |         total_w = sum(sense_weights[key])
127 |         sense_weights[key] = torch.FloatTensor([total_w/x if x !=0 else 0 for x in sense_weights[key]])
128 | 
129 |     return sense_glosses, sense_weights
130 | 
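#illustrative example of the weighting above (values are made up): for a lemma/pos key
#whose three candidate senses occur 3, 1, and 0 times in the training data, total_w = 4
#and the balanced weights become [4/3, 4, 0]; rarer senses are upweighted, and senses
#unseen in training get weight 0 and contribute nothing to the loss.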
131 | def preprocess_context(tokenizer, text_data, bsz=1, max_len=-1):
132 |     if max_len == -1: assert bsz==1 #otherwise need max_length for padding
133 | 
134 |     context_ids = []
135 |     context_attn_masks = []
136 | 
137 |     example_keys = []
138 | 
139 |     context_output_masks = []
140 |     instances = []
141 |     labels = []
142 | 
143 |     #tensorize data
144 |     for sent in text_data:
145 |         c_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])] #cls token aka sos token, returns a list with index
146 |         o_masks = [-1]
147 |         sent_insts = []
148 |         sent_keys = []
149 |         sent_labels = []
150 | 
151 |         #For each word in sentence...
152 |         for idx, (word, lemma, pos, inst, label) in enumerate(sent):
153 |             #tensorize word for context ids
154 |             word_ids = [torch.tensor([[x]]) for x in tokenizer.encode(word.lower())]
155 |             c_ids.extend(word_ids)
156 | 
157 |             #if word is labeled with WSD sense...
158 |             if inst != -1:
159 |                 #add word to bert output mask to be labeled
160 |                 o_masks.extend([idx]*len(word_ids))
161 |                 #track example instance id
162 |                 sent_insts.append(inst)
163 |                 #track example instance keys to get glosses
164 |                 ex_key = generate_key(lemma, pos)
165 |                 sent_keys.append(ex_key)
166 |                 sent_labels.append(label)
167 |             else:
168 |                 #mask out output of context encoder for WSD task (not labeled)
169 |                 o_masks.extend([-1]*len(word_ids))
170 | 
171 |             #break if we reach max len
172 |             if max_len != -1 and len(c_ids) >= (max_len-1):
173 |                 break
174 | 
175 |         c_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)])) #aka eos token
176 |         c_attn_mask = [1]*len(c_ids)
177 |         o_masks.append(-1)
178 |         c_ids, c_attn_masks, o_masks = normalize_length(c_ids, c_attn_mask, o_masks, max_len, pad_id=tokenizer.encode(tokenizer.pad_token)[0])
179 | 
180 |         y = torch.tensor([1]*len(sent_insts), dtype=torch.float)
181 |         #not including example sentences with no annotated sense data
182 |         if len(sent_insts) > 0:
183 |             context_ids.append(torch.cat(c_ids, dim=-1))
184 |             context_attn_masks.append(torch.tensor(c_attn_masks).unsqueeze(dim=0))
185 |             context_output_masks.append(torch.tensor(o_masks).unsqueeze(dim=0))
186 |             example_keys.append(sent_keys)
187 |             instances.append(sent_insts)
188 |             labels.append(sent_labels)
189 | 
190 |     #package data
191 |     data = list(zip(context_ids, context_attn_masks, context_output_masks, example_keys, instances, labels))
192 | 
193 |     #batch data if bsz > 1
194 |     if bsz > 1:
195 |         print('Batching data with bsz={}...'.format(bsz))
196 |         batched_data = []
197 |         for idx in range(0, len(data), bsz):
198 |             if idx+bsz <=len(data): b = data[idx:idx+bsz]
199 |             else: b = data[idx:]
200 |             context_ids = torch.cat([x for x,_,_,_,_,_ in b], dim=0)
201 |             context_attn_mask = torch.cat([x for _,x,_,_,_,_ in b], dim=0)
202 |             context_output_mask = torch.cat([x for _,_,x,_,_,_ in b], dim=0)
203 |             example_keys = []
204 |             for _,_,_,x,_,_ in b: example_keys.extend(x)
205 |             instances = []
206 |             for _,_,_,_,x,_ in b: instances.extend(x)
207 |             labels = []
208 |             for _,_,_,_,_,x in b: labels.extend(x)
209 |             batched_data.append((context_ids, context_attn_mask, context_output_mask, example_keys, instances, labels))
210 |         return batched_data
211 |     else:
212 |         return data
213 | 
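#each batch yielded above is a 6-tuple: context_ids and context_attn_mask of shape
#[bsz x max_len]; context_output_mask marking, for each subword, the index of the
#sense-tagged word it belongs to (-1 for untagged tokens); plus flat lists of the
#lemma/pos keys, instance ids, and gold sense labels of the tagged words.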
214 | def _train(train_data, model, gloss_dict, optim, schedule, criterion, gloss_bsz=-1, max_grad_norm=1.0, multigpu=False, silent=False, train_steps=-1):
215 |     model.train()
216 |     total_loss = 0.
217 | 
218 |     start_time = time.time()
219 | 
220 |     train_data = enumerate(train_data)
221 |     if not silent: train_data = tqdm(list(train_data))
222 | 
223 |     for i, (context_ids, context_attn_mask, context_output_mask, example_keys, _, labels) in train_data:
224 | 
225 |         #reset model
226 |         model.zero_grad()
227 |         #run example sentence(s) through context encoder
228 |         if multigpu:
229 |             context_ids = context_ids.to(context_device)
230 |             context_attn_mask = context_attn_mask.to(context_device)
231 |         else:
232 |             context_ids = context_ids.cuda()
233 |             context_attn_mask = context_attn_mask.cuda()
234 |         context_output = model.context_forward(context_ids, context_attn_mask, context_output_mask)
235 | 
236 |         loss = 0.
237 |         gloss_sz = 0
238 |         context_sz = len(labels)
239 |         for j, (key, label) in enumerate(zip(example_keys, labels)):
240 |             output = context_output.split(1,dim=0)[j]
241 | 
242 |             #run example's glosses through gloss encoder
243 |             gloss_ids, gloss_attn_mask, sense_keys = gloss_dict[key]
244 |             if multigpu:
245 |                 gloss_ids = gloss_ids.to(gloss_device)
246 |                 gloss_attn_mask = gloss_attn_mask.to(gloss_device)
247 |             else:
248 |                 gloss_ids = gloss_ids.cuda()
249 |                 gloss_attn_mask = gloss_attn_mask.cuda()
250 | 
251 |             gloss_output = model.gloss_forward(gloss_ids, gloss_attn_mask)
252 |             gloss_output = gloss_output.transpose(0,1)
253 | 
254 |             #score example from context encoder against gloss embeddings (dot product)
255 |             if multigpu:
256 |                 output = output.cpu()
257 |                 gloss_output = gloss_output.cpu()
258 | 
259 |             output = torch.mm(output, gloss_output)
260 | 
261 |             #get label and calculate loss
262 |             idx = sense_keys.index(label)
263 |             label_tensor = torch.tensor([idx])
264 |             if not multigpu: label_tensor = label_tensor.cuda()
265 | 
266 |             #look up the criterion for this word's candidate senses;
267 |             #needed if balancing classes within the candidate senses of a target word
268 |             loss += criterion[key](output, label_tensor)
269 |             gloss_sz += gloss_output.size(-1)
270 | 
271 |             if gloss_bsz != -1 and gloss_sz >= gloss_bsz:
272 |                 #update model
273 |                 total_loss += loss.item()
274 |                 loss=loss/gloss_sz
275 |                 loss.backward()
276 |                 torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
277 |                 optim.step()
278 |                 schedule.step() # Update learning rate schedule
279 | 
280 |                 #reset loss and gloss_sz
281 |                 loss = 0.
282 |                 gloss_sz = 0
283 | 
284 |                 #reset model
285 |                 model.zero_grad()
286 | 
287 |                 #rerun context through model
288 |                 context_output = model.context_forward(context_ids, context_attn_mask, context_output_mask)
289 | 
290 |         #update model after finishing context batch
291 |         if gloss_bsz != -1: loss_sz = gloss_sz
292 |         else: loss_sz = context_sz
293 |         if loss_sz > 0:
294 |             total_loss += loss.item()
295 |             loss=loss/loss_sz
296 |             loss.backward()
297 |             torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
298 |             optim.step()
299 |             schedule.step() # Update learning rate schedule
300 | 
301 |         #stop epoch early if number of training steps is reached
302 |         if train_steps > 0 and i+1 == train_steps: break
303 | 
304 |     return model, optim, schedule, total_loss
305 | 
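#note on the gloss batching above: with --gloss-bsz set, per-word losses accumulate until
#roughly gloss_bsz gloss candidates have been scored; the optimizer then steps and the
#context sentence is re-encoded so that the remaining words in the same context batch are
#scored against the updated context encoder rather than a stale one.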
306 | def _eval(eval_data, model, gloss_dict, multigpu=False):
307 |     model.eval()
308 |     eval_preds = []
309 |     for context_ids, context_attn_mask, context_output_mask, example_keys, insts, _ in eval_data:
310 |         with torch.no_grad():
311 |             #run example through model
312 |             if multigpu:
313 |                 context_ids = context_ids.to(context_device)
314 |                 context_attn_mask = context_attn_mask.to(context_device)
315 |             else:
316 |                 context_ids = context_ids.cuda()
317 |                 context_attn_mask = context_attn_mask.cuda()
318 |             context_output = model.context_forward(context_ids, context_attn_mask, context_output_mask)
319 | 
320 |             for output, key, inst in zip(context_output.split(1,dim=0), example_keys, insts):
321 |                 #run example's glosses through gloss encoder
322 |                 gloss_ids, gloss_attn_mask, sense_keys = gloss_dict[key]
323 |                 if multigpu:
324 |                     gloss_ids = gloss_ids.to(gloss_device)
325 |                     gloss_attn_mask = gloss_attn_mask.to(gloss_device)
326 |                 else:
327 |                     gloss_ids = gloss_ids.cuda()
328 |                     gloss_attn_mask = gloss_attn_mask.cuda()
329 |                 gloss_output = model.gloss_forward(gloss_ids, gloss_attn_mask)
330 |                 gloss_output = gloss_output.transpose(0,1)
331 | 
332 |                 #score example from context encoder against gloss embeddings (dot product)
333 |                 if multigpu:
334 |                     output = output.cpu()
335 |                     gloss_output = gloss_output.cpu()
336 |                 output = torch.mm(output, gloss_output)
337 |                 pred_idx = output.topk(1, dim=-1)[1].squeeze().item()
338 |                 pred_label = sense_keys[pred_idx]
339 |                 eval_preds.append((inst, pred_label))
340 | 
341 |     return eval_preds
342 | 
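#_eval returns (instance id, predicted sense key) pairs; the callers below write them out
#one per line as "<instance_id> <sense_key>", the format scored by the WSD Evaluation
#Framework's Scorer.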
343 | def train_model(args):
344 |     print('Training WSD bi-encoder model...')
345 |     if args.freeze_gloss: assert args.gloss_bsz == -1 #no gloss bsz if not training gloss encoder, memory concerns
346 | 
347 |     #create the passed-in ckpt dir if it doesn't exist
348 |     if not os.path.exists(args.ckpt): os.mkdir(args.ckpt)
349 | 
350 |     '''
351 |     LOAD PRETRAINED TOKENIZER, TRAIN AND DEV DATA
352 |     '''
353 |     print('Loading data + preprocessing...')
354 |     sys.stdout.flush()
355 | 
356 |     tokenizer = load_tokenizer(args.encoder_name)
357 | 
358 |     #loading WSD (semcor) data
359 |     train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
360 |     train_data = load_data(train_path, 'semcor')
361 | 
362 |     #filter train data for k-shot learning
363 |     if args.kshot > 0: train_data = filter_k_examples(train_data, args.kshot)
364 | 
365 |     #dev set = semeval2007
366 |     semeval2007_path = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/')
367 |     semeval2007_data = load_data(semeval2007_path, 'semeval2007')
368 | 
369 |     #load gloss dictionary (all senses from wordnet for each lemma/pos pair that occur in data)
370 |     wn_path = os.path.join(args.data_path, 'Data_Validation/candidatesWN30.txt')
371 |     wn_senses = load_wn_senses(wn_path)
372 |     train_gloss_dict, train_gloss_weights = load_and_preprocess_glosses(train_data, tokenizer, wn_senses, max_len=args.gloss_max_length)
373 |     semeval2007_gloss_dict, _ = load_and_preprocess_glosses(semeval2007_data, tokenizer, wn_senses, max_len=args.gloss_max_length)
374 | 
375 |     #preprocess and batch data (context + glosses)
376 |     train_data = preprocess_context(tokenizer, train_data, bsz=args.context_bsz, max_len=args.context_max_length)
377 |     semeval2007_data = preprocess_context(tokenizer, semeval2007_data, bsz=1, max_len=args.context_max_length)
378 | 
379 |     epochs = args.epochs
380 |     overflow_steps = -1
381 |     t_total = len(train_data)*epochs
382 | 
383 |     #if few-shot training, override epochs to calculate num. epochs + steps for equal training signal
384 |     if args.kshot > 0:
385 |         #hard-coded number of steps for fair kshot evaluation against the full model on the default number of epochs
386 |         NUM_STEPS = 181500 #num batches in full train data (9075) * 20 epochs
387 |         num_batches = len(train_data)
388 |         epochs = NUM_STEPS//num_batches #recalculate number of epochs
389 |         overflow_steps = NUM_STEPS%num_batches #num steps in last overflow epoch (if there is one, otherwise 0)
390 |         t_total = NUM_STEPS #manually set number of steps for lr schedule
391 |         if overflow_steps > 0: epochs+=1 #add extra epoch for overflow steps
392 |         print('Overriding args.epochs and training for {} epochs...'.format(epochs))
393 | 
394 |     '''
395 |     SET UP FINETUNING MODEL, OPTIMIZER, AND LR SCHEDULE
396 |     '''
397 |     model = BiEncoderModel(args.encoder_name, freeze_gloss=args.freeze_gloss, freeze_context=args.freeze_context, tie_encoders=args.tie_encoders)
398 | 
399 |     #speeding up training by putting the two encoders on separate gpus (instead of data parallel)
400 |     if args.multigpu:
401 |         model.gloss_encoder = model.gloss_encoder.to(gloss_device)
402 |         model.context_encoder = model.context_encoder.to(context_device)
403 |     else:
404 |         model = model.cuda()
405 | 
406 |     criterion = {}
407 |     if args.balanced:
408 |         for key in train_gloss_dict:
409 |             criterion[key] = torch.nn.CrossEntropyLoss(reduction='none', weight=train_gloss_weights[key])
410 |     else:
411 |         for key in train_gloss_dict:
412 |             criterion[key] = torch.nn.CrossEntropyLoss(reduction='none')
413 | 
414 |     #optimizer + scheduler from pytorch_transformers package
415 |     #this taken from pytorch_transformers finetuning code
416 |     weight_decay = 0.0 #this could be a parameter
417 |     no_decay = ['bias', 'LayerNorm.weight']
418 |     optimizer_grouped_parameters = [
419 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
420 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
421 |     ]
422 |     adam_epsilon = 1e-8
423 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=adam_epsilon)
424 |     schedule = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup, t_total=t_total)
425 | 
426 |     '''
427 |     TRAIN MODEL ON SEMCOR DATA
428 |     '''
429 | 
430 |     best_dev_f1 = 0.
431 |     print('Training probe...')
432 |     sys.stdout.flush()
433 | 
434 |     for epoch in range(1, epochs+1):
435 |         #if last epoch, pass in overflow steps to stop epoch early
436 |         train_steps = -1
437 |         if epoch == epochs and overflow_steps > 0: train_steps = overflow_steps
438 | 
439 |         #train model for one epoch or given number of training steps
440 |         model, optimizer, schedule, train_loss = _train(train_data, model, train_gloss_dict, optimizer, schedule, criterion, gloss_bsz=args.gloss_bsz, max_grad_norm=args.grad_norm, silent=args.silent, multigpu=args.multigpu, train_steps=train_steps)
441 | 
442 |         #eval model on dev set (semeval2007)
443 |         eval_preds = _eval(semeval2007_data, model, semeval2007_gloss_dict, multigpu=args.multigpu)
444 | 
445 |         #generate predictions file
446 |         pred_filepath = os.path.join(args.ckpt, 'tmp_predictions.txt')
447 |         with open(pred_filepath, 'w') as f:
448 |             for inst, prediction in eval_preds:
449 |                 f.write('{} {}\n'.format(inst, prediction))
450 | 
451 |         #run predictions through scorer
452 |         gold_filepath = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt')
453 |         scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
454 |         _, _, dev_f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
455 |         print('Dev f1 after {} epochs = {}'.format(epoch, dev_f1))
456 |         sys.stdout.flush()
457 | 
458 |         if dev_f1 >= best_dev_f1:
459 |             print('updating best model at epoch {}...'.format(epoch))
460 |             sys.stdout.flush()
461 |             best_dev_f1 = dev_f1
462 |             #save to file if best probe so far on dev set
463 |             model_fname = os.path.join(args.ckpt, 'best_model.ckpt')
464 |             with open(model_fname, 'wb') as f:
465 |                 torch.save(model.state_dict(), f)
466 |             sys.stdout.flush()
467 | 
468 |         #shuffle train set ordering after every epoch
469 |         random.shuffle(train_data)
470 |     return
471 | 
472 | def evaluate_model(args):
473 |     print('Evaluating WSD model on {}...'.format(args.split))
474 | 
475 |     '''
476 |     LOAD TRAINED MODEL
477 |     '''
478 |     model = BiEncoderModel(args.encoder_name, freeze_gloss=args.freeze_gloss, freeze_context=args.freeze_context)
479 |     model_path = os.path.join(args.ckpt, 'best_model.ckpt')
480 |     model.load_state_dict(torch.load(model_path))
481 |     model = model.cuda()
482 | 
483 | 
484 |     '''
485 |     LOAD TOKENIZER
486 |     '''
487 |     tokenizer = load_tokenizer(args.encoder_name)
488 | 
489 |     '''
490 |     LOAD EVAL SET
491 |     '''
492 |     eval_path = os.path.join(args.data_path, 'Evaluation_Datasets/{}/'.format(args.split))
493 |     eval_data = load_data(eval_path, args.split)
494 | 
495 |     #load gloss dictionary (all senses from wordnet for each lemma/pos pair that occur in data)
496 |     wn_path = os.path.join(args.data_path, 'Data_Validation/candidatesWN30.txt')
497 |     wn_senses = load_wn_senses(wn_path)
498 |     gloss_dict, _ = load_and_preprocess_glosses(eval_data, tokenizer, wn_senses, max_len=32)
499 | 
500 |     eval_data = preprocess_context(tokenizer, eval_data, bsz=1, max_len=-1)
501 | 
502 |     '''
503 |     EVALUATE MODEL
504 |     '''
505 |     eval_preds = _eval(eval_data, model, gloss_dict, multigpu=False)
506 | 
507 |     #generate predictions file
508 |     pred_filepath = os.path.join(args.ckpt, './{}_predictions.txt'.format(args.split))
509 |     with open(pred_filepath, 'w') as f:
510 |         for inst, prediction in eval_preds:
511 |             f.write('{} {}\n'.format(inst, prediction))
512 | 
513 |     #run predictions through scorer
514 |     gold_filepath = os.path.join(eval_path, '{}.gold.key.txt'.format(args.split))
515 |     scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
516 |     p, r, f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
517 |     print('f1 of BERT probe on {} test set = {}'.format(args.split, f1))
518 | 
519 |     return
520 | 
521 | if __name__ == "__main__":
522 |     if not torch.cuda.is_available():
523 |         print("Need available GPU(s) to run this model...")
524 |         quit()
525 | 
526 |     #parse args
527 |     args = parser.parse_args()
528 |     print(args)
529 | 
530 |     #set random seeds
531 |     torch.manual_seed(args.rand_seed)
532 |     os.environ['PYTHONHASHSEED'] = str(args.rand_seed)
533 |     torch.cuda.manual_seed(args.rand_seed)
534 |     torch.cuda.manual_seed_all(args.rand_seed)
535 |     np.random.seed(args.rand_seed)
536 |     random.seed(args.rand_seed)
537 |     torch.backends.cudnn.benchmark = False
538 |     torch.backends.cudnn.deterministic=True
539 | 
540 |     #evaluate model saved at checkpoint or...
541 |     if args.eval: evaluate_model(args)
542 |     #train model
543 |     else: train_model(args)
544 | 
545 | #EOF
--------------------------------------------------------------------------------
/docs/wsd_biencoder_architecture.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/facebookresearch/wsd-biencoders/dc06aa8f027b7e69f907f955c89bdda73afdd4c3/docs/wsd_biencoder_architecture.jpg
--------------------------------------------------------------------------------
/finetune_pretrained_encoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | All rights reserved.
4 | 
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | '''
8 | 
9 | import torch
10 | from torch.nn import functional as F
11 | import os
12 | import sys
13 | import math
14 | import copy
15 | import argparse
16 | from tqdm import tqdm
17 | import pickle
18 | from pytorch_transformers import *
19 | 
20 | import random
21 | import numpy as np
22 | 
23 | from wsd_models.util import *
24 | from wsd_models.models import PretrainedClassifier
25 | 
26 | parser = argparse.ArgumentParser(description='Finetuning Pretrained Encoders for WSD')
27 | parser.add_argument('--rand_seed', type=int, default=42)
28 | parser.add_argument('--grad-norm', type=float, default=1.0)
29 | parser.add_argument('--silent', action='store_true',
30 |     help='Flag to suppress training progress bar for each epoch')
31 | parser.add_argument('--multigpu', action='store_true')
32 | parser.add_argument('--lr', type=float, default=0.0001)
33 | parser.add_argument('--warmup', type=int, default=2000)
34 | parser.add_argument('--max-length', type=int, default=128)
35 | parser.add_argument('--epochs', type=int, default=10)
36 | parser.add_argument('--bsz', type=int, default=8)
37 | parser.add_argument('--encoder-name', type=str, default='bert-base',
38 |     choices=['bert-base', 'bert-large', 'roberta-base', 'roberta-large'])
39 | parser.add_argument('--ckpt', type=str, required=True,
40 |     help='filepath at which to save best probing model (on dev set)')
41 | parser.add_argument('--proj-ckpt', type=str, default='',
42 |     help='filepath to a pretrained projection layer/probe (trained with frozen_pretrained_encoder.py) to optionally use as the projection layer initialization')
43 | parser.add_argument('--data-path', type=str, required=True,
44 |     help='Location of top-level directory of the WSD Evaluation Framework')
45 | 
46 | parser.add_argument('--eval', action='store_true',
47 |     help='Flag to set script to evaluate probe (rather than train)')
48 | parser.add_argument('--split', type=str, default='semeval2007',
49 |     choices=['semeval2007', 'senseval2', 'senseval3', 'semeval2013', 'semeval2015', 'ALL', 'all-test'],
50 |     help='Which evaluation split on which to evaluate probe')
51 | 
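#unlike the bi-encoder, this baseline finetunes a single encoder with a classification
#head over a fixed sense vocabulary (label_space); per-word candidate masks restrict
#predictions to each target word's senses, and evaluation can back off for forms unseen
#in training (see _eval_with_backoff below).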
52 | #updated to organize keys by sentence to work with pretrained model
53 | def wn_keys(data):
54 |     keys = []
55 |     for sent in data:
56 |         sent_keys = []
57 |         for form, lemma, pos, inst, _ in sent:
58 |             if inst != -1:
59 |                 key = generate_key(lemma, pos)
60 |                 sent_keys.append(key)
61 |         if len(sent_keys) > 0: keys.append(sent_keys)
62 |     return keys
63 | 
64 | #takes in text data and indexes it for pretrained encoder + batching
65 | #updated to return data organized by sentence to work with pretrained model
66 | def preprocess(tokenizer, text_data, label_space, label_map, bsz=1, max_len=-1):
67 |     if max_len == -1:
68 |         assert bsz==1 #otherwise need max_len for padding
69 | 
70 |     input_ids = []
71 |     input_masks = []
72 |     bert_masks = []
73 |     output_masks = []
74 |     instances = []
75 |     label_indexes = []
76 | 
77 |     #tensorize data
78 |     for sent in text_data:
79 |         sent_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])] #cls token aka sos token, returns a list with index
80 |         b_masks = [-1]
81 |         o_masks = []
82 |         sent_insts = []
83 |         sent_labels = []
84 | 
85 |         ex_count = 0 #DEBUGGING
86 |         for idx, (word, lemma, pos, inst, label) in enumerate(sent):
87 |             word_ids = [torch.tensor([[x]]) for x in tokenizer.encode(word.lower())]
88 |             sent_ids.extend(word_ids)
89 | 
90 |             if inst != -1:
91 |                 ex_count += 1 #DEBUGGING
92 |                 b_masks.extend([idx]*len(word_ids))
93 | 
94 |                 sent_insts.append(inst)
95 |                 if label in label_space:
96 |                     sent_labels.append(torch.tensor([label_space.index(label)]))
97 |                 else:
98 |                     sent_labels.append(torch.tensor([label_space.index('n/a')]))
99 | 
100 |                 #adding appropriate label space for sense-labeled word (we only use this for wsd task)
101 |                 key = generate_key(lemma, pos)
102 |                 if key in label_map:
103 |                     l_space = label_map[key]
104 |                     mask = torch.zeros(len(label_space))
105 |                     for l in l_space: mask[l] = 1
106 |                     o_masks.append(mask)
107 |                 else:
108 |                     o_masks.append(torch.ones(len(label_space))) #let this predict whatever -- should not use this (default to backoff for unseen forms)
109 | 
110 |             else:
111 |                 b_masks.extend([-1]*len(word_ids))
112 | 
113 |             #break if we reach max len so we don't keep overflowing examples
114 |             if max_len != -1 and len(sent_ids) >= (max_len-1):
115 |                 break
116 | 
117 |         #pad inputs + add eos token
118 |         sent_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)])) #aka eos token
119 |         input_mask = [1]*len(sent_ids)
120 |         b_masks.append(-1)
121 |         sent_ids, input_mask, b_masks = normalize_length(sent_ids, input_mask, b_masks, max_len, pad_id=tokenizer.encode(tokenizer.pad_token)[0])
122 | 
123 |         #not including example sentences with no annotated sense data
124 |         if len(sent_insts) > 0:
125 |             input_ids.append(torch.cat(sent_ids, dim=-1))
126 |             input_masks.append(torch.tensor(input_mask).unsqueeze(dim=0))
127 |             bert_masks.append(torch.tensor(b_masks).unsqueeze(dim=0))
128 |             output_masks.append(torch.stack(o_masks, dim=0))
129 |             instances.append(sent_insts)
130 |             label_indexes.append(torch.cat(sent_labels, dim=0))
131 | 
132 |     #batch data now that we pad it
133 |     data = list(zip(input_ids, input_masks, bert_masks, output_masks, instances, label_indexes))
134 |     if bsz > 1:
135 |         print('Batching data with bsz={}...'.format(bsz))
136 |         batched_data = []
137 |         for idx in range(0, len(data), bsz):
138 |             if idx+bsz <=len(data): b = data[idx:idx+bsz]
139 |             else: b = data[idx:]
140 |             input_ids = torch.cat([x for x,_,_,_,_,_ in b], dim=0)
141 |             input_mask = torch.cat([x for _,x,_,_,_,_ in b], dim=0)
142 |             bert_mask = torch.cat([x for _,_,x,_,_,_ in b], dim=0)
143 |             output_mask = torch.cat([x for _,_,_,x,_,_ in b], dim=0)
144 |             instances = []
145 |             for _,_,_,_,x,_ in b: instances.extend(x)
146 |             labels = torch.cat([x for _,_,_,_,_,x in b], dim=0)
147 |             batched_data.append((input_ids, input_mask, bert_mask, output_mask, instances, labels))
148 |         return batched_data
149 |     else: return data
150 | 
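#masking trick used in _train and _eval below: logits are multiplied by the 0/1 candidate
#mask, the zeroed entries are set to -inf, and softmax then places probability mass only
#on the target word's candidate senses.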
139 |             else: b = data[idx:]
140 |             input_ids = torch.cat([x for x,_,_,_,_,_ in b], dim=0)
141 |             input_mask = torch.cat([x for _,x,_,_,_,_ in b], dim=0)
142 |             bert_mask = torch.cat([x for _,_,x,_,_,_ in b], dim=0)
143 |             output_mask = torch.cat([x for _,_,_,x,_,_ in b], dim=0)
144 |             instances = []
145 |             for _,_,_,_,x,_ in b: instances.extend(x)
146 |             labels = torch.cat([x for _,_,_,_,_,x in b], dim=0)
147 |             batched_data.append((input_ids, input_mask, bert_mask, output_mask, instances, labels))
148 |         return batched_data
149 |     else: return data
150 | 
151 | def _train(train_data, model, optim, schedule, criterion, max_grad_norm=1.0, silent=False):
152 |     model.train()
153 |     total_loss = 0.
154 | 
155 |     if not silent: train_data = tqdm(train_data)
156 |     for input_ids, input_mask, bert_mask, output_mask, _, label in train_data:
157 |         input_ids = input_ids.cuda()
158 |         input_mask = input_mask.cuda()
159 |         output_mask = output_mask.cuda()
160 |         label = label.cuda()
161 | 
162 |         model.zero_grad()
163 |         output = model.forward(input_ids, input_mask, bert_mask)
164 |         #mask output to appropriate senses for target word
165 |         output = torch.mul(output, output_mask)
166 |         #set masked out items to -inf to get proper probabilities over the candidate senses
167 |         output[output == 0] = float('-inf')
168 | 
169 |         output = F.softmax(output, dim=-1)
170 | 
171 |         loss = criterion(output, label)
172 |         total_loss += loss.sum().item()
173 |         loss_sz = loss.size(0)
174 |         loss = loss.sum()/loss_sz
175 |         loss.backward()
176 | 
177 |         torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
178 |         optim.step()
179 |         schedule.step() # Update learning rate schedule
180 | 
181 |     return model, optim, schedule, total_loss
182 | 
183 | def _eval(eval_data, model, label_space):
184 |     model.eval()
185 |     eval_preds = []
186 |     for input_ids, input_mask, bert_mask, output_mask, insts, labels in eval_data:
187 |         input_ids = input_ids.cuda()
188 |         input_mask = input_mask.cuda()
189 |         output_mask = output_mask.cuda()
190 |         labels = labels.cuda()
191 | 
192 |         #run example through model
193 |         with torch.no_grad():
194 |             outputs = model.forward(input_ids, input_mask, bert_mask)
195 |             #mask to candidate senses for target word
196 |             outputs = torch.mul(outputs, output_mask)
197 |             #set masked out items to -inf to get proper probabilities over the candidate senses
198 |             outputs[outputs == 0] = float('-inf')
199 | 
200 |         outputs = F.softmax(outputs, dim=-1)
201 | 
202 |         for i, output in enumerate(outputs):
203 |             inst = insts[i]
204 |             #get predicted label
205 |             pred_id = output.topk(1, dim=-1)[1].squeeze().item()
206 |             pred_label = label_space[pred_id]
207 |             eval_preds.append((inst, pred_label))
208 | 
209 |     return eval_preds
210 | 
211 | def _eval_with_backoff(eval_data, model, label_space, wn_senses, coverage, keys):
212 |     model.eval()
213 |     eval_preds = []
214 | 
215 |     for sent_keys, (input_ids, input_mask, bert_mask, output_mask, insts, _) in zip(keys, eval_data):
216 |         input_ids = input_ids.cuda()
217 |         input_mask = input_mask.cuda()
218 |         output_mask = output_mask.cuda()
219 |         #run example through model
220 |         with torch.no_grad():
221 |             outputs = model.forward(input_ids, input_mask, bert_mask)
222 |             #mask to candidate senses for target word
223 |             outputs = torch.mul(outputs, output_mask)
224 |             #set masked out items to -inf to get proper probabilities over the candidate senses
225 |             outputs[outputs == 0] = float('-inf')
226 | 
227 |         outputs = F.softmax(outputs, dim=-1)
228 | 
229 |         for i, output in enumerate(outputs):
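            #each row of outputs is one sense-labeled instance in the sentence; sent_keys (from wn_keys) is aligned with insts, so sent_keys[i] is the lemma+pos key for instance i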
230 |             k = sent_keys[i]
231 |             inst = insts[i]
232 |             if k in coverage:
233 |                 #get predicted label
234 |                 pred_id = output.topk(1, dim=-1)[1].squeeze().item()
235 |                 pred_label = label_space[pred_id]
236 |                 eval_preds.append((inst, pred_label))
237 | 
238 |             #back off to the WordNet first sense for lemma+pos keys unseen in training
239 |             else:
240 |                 #wn_senses[k][0] is the WordNet first sense (WFS) for the given key
241 |                 pred_label = wn_senses[k][0]
242 |                 eval_preds.append((inst, pred_label))
243 | 
244 |     return eval_preds
245 | 
246 | def train_model(args):
247 |     print('Finetuning pretrained model on WSD...')
248 |     #create the passed-in ckpt dir if it doesn't exist
249 |     if not os.path.exists(args.ckpt): os.mkdir(args.ckpt)
250 | 
251 |     '''
252 |     LOAD PRETRAINED MODEL'S TOKENIZER
253 |     '''
254 |     #model loading code based on pytorch_transformers README example
255 |     tokenizer = load_tokenizer(args.encoder_name)
256 | 
257 |     '''
258 |     LOADING IN TRAINING AND EVAL DATA
259 |     '''
260 |     print('Loading data + preprocessing...')
261 |     sys.stdout.flush()
262 |     #loading WSD (semcor) training data
263 |     train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
264 |     train_data = load_data(train_path, 'semcor')
265 | 
266 |     #calculate label space
267 |     label_space, label_map = get_label_space(train_data)
268 |     print('num labels = {} + 1 unknown label'.format(len(label_space)-1))
269 | 
270 |     train_data = preprocess(tokenizer, train_data, label_space, label_map, bsz=args.bsz, max_len=args.max_length)
271 | 
272 |     #dev set = semeval2007
273 |     semeval2007_path = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/')
274 |     semeval2007_data = load_data(semeval2007_path, 'semeval2007')
275 |     semeval2007_data = preprocess(tokenizer, semeval2007_data, label_space, label_map, bsz=1, max_len=-1)
276 | 
277 |     '''
278 |     SET UP FINETUNING MODEL, OPTIMIZER, AND LR SCHEDULE
279 |     '''
280 | 
281 |     model = PretrainedClassifier(len(label_space), args.encoder_name, args.proj_ckpt)
282 |     if args.multigpu: model = torch.nn.DataParallel(model)
283 |     model = model.cuda()
284 | 
285 |     criterion = torch.nn.CrossEntropyLoss(reduction='none')
286 | 
287 |     #optimizer + scheduler from the pytorch_transformers package
288 |     #this is from pytorch_transformers finetuning code
289 |     weight_decay = 0.0 #this could be a parameter
290 |     no_decay = ['bias', 'LayerNorm.weight']
291 |     optimizer_grouped_parameters = [
292 |         {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': weight_decay},
293 |         {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
294 |     ]
295 |     adam_epsilon = 1e-8
296 |     t_total = len(train_data)*args.epochs
297 |     optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr, eps=adam_epsilon)
298 |     schedule = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup, t_total=t_total)
299 | 
300 |     '''
301 |     TRAIN FINETUNING MODEL ON SEMCOR DATA
302 |     '''
303 | 
304 |     best_dev_f1 = 0.
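    #note: schedule.step() is called once per batch in _train, so the linear warmup/decay above runs over all t_total = len(train_data)*args.epochs steps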
305 |     print('Training model...')
306 |     sys.stdout.flush()
307 | 
308 |     for epoch in range(1, args.epochs+1):
309 |         #train on full dataset
310 | 
311 |         model, optimizer, schedule, train_loss = _train(train_data, model, optimizer, schedule, criterion, max_grad_norm=args.grad_norm, silent=args.silent)
312 |         #eval model on dev set (semeval2007)
313 |         eval_preds = _eval(semeval2007_data, model, label_space)
314 | 
315 |         #generate predictions file
316 |         pred_filepath = os.path.join(args.ckpt, 'tmp_predictions.txt')
317 |         with open(pred_filepath, 'w') as f:
318 |             for inst, prediction in eval_preds:
319 |                 f.write('{} {}\n'.format(inst, prediction))
320 | 
321 |         #run predictions through scorer
322 |         gold_filepath = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt')
323 |         scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
324 |         _, _, dev_f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
325 |         print('Dev f1 after {} epochs = {}'.format(epoch, dev_f1))
326 |         sys.stdout.flush()
327 | 
328 |         if dev_f1 >= best_dev_f1:
329 |             print('updating best model at epoch {}...'.format(epoch))
330 |             sys.stdout.flush()
331 |             best_dev_f1 = dev_f1
332 |             #save to file if best model so far on dev set
333 |             model_fname = os.path.join(args.ckpt, 'best_model.ckpt')
334 |             with open(model_fname, 'wb') as f:
335 |                 torch.save(model.state_dict(), f)
336 |             sys.stdout.flush()
337 | 
338 |         #shuffle train data after every epoch
339 |         random.shuffle(train_data)
340 | 
341 | 
342 |     return
343 | 
344 | def evaluate_model(args):
345 |     print('Evaluating model on {} for WSD...'.format(args.split))
346 | 
347 |     '''
348 |     LOAD LABEL SPACE
349 |     '''
350 |     train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
351 |     train_data = load_data(train_path, 'semcor')
352 |     coverage = set([k for keys in wn_keys(train_data) for k in keys])
353 |     task_labels, label_map = get_label_space(train_data)
354 | 
355 |     '''
356 |     LOAD TRAINED MODEL
357 |     '''
358 |     model = PretrainedClassifier(len(task_labels), args.encoder_name, '')
359 |     model_path = os.path.join(args.ckpt, 'best_model.ckpt')
360 |     model.load_state_dict(torch.load(model_path))
361 |     model = model.cuda()
362 | 
363 |     '''
364 |     LOAD TOKENIZER
365 |     '''
366 |     tokenizer = load_tokenizer(args.encoder_name)
367 | 
368 |     '''
369 |     LOAD EVAL SET
370 |     '''
371 |     eval_path = os.path.join(args.data_path, 'Evaluation_Datasets/{}/'.format(args.split))
372 |     eval_data = load_data(eval_path, args.split)
373 |     #get keys to perform evaluation with backoff
374 |     eval_keys = wn_keys(eval_data)
375 |     eval_data = preprocess(tokenizer, eval_data, task_labels, label_map, bsz=1, max_len=-1)
376 | 
377 |     '''
378 |     EVALUATE MODEL w/o backoff
379 |     '''
380 |     eval_preds = _eval(eval_data, model, task_labels)
381 | 
382 |     #generate predictions file
383 |     pred_filepath = os.path.join(args.ckpt, '{}_predictions.txt'.format(args.split))
384 |     with open(pred_filepath, 'w') as f:
385 |         for inst, prediction in eval_preds:
386 |             f.write('{} {}\n'.format(inst, prediction))
387 | 
388 |     #run predictions through scorer
389 |     gold_filepath = os.path.join(eval_path, '{}.gold.key.txt'.format(args.split))
390 |     scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
391 |     p, r, f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
392 |     print('F1 on {} test set = {}'.format(args.split, f1))
393 | 
394 |     wn_path = os.path.join(args.data_path, 'Data_Validation/candidatesWN30.txt')
395 |     wn_senses = load_wn_senses(wn_path)
396 |     eval_preds = _eval_with_backoff(eval_data, model, task_labels, wn_senses, coverage, eval_keys)
397 | 
398 |     #generate predictions file
399 |     pred_filepath = os.path.join(args.ckpt, '{}_backoff_predictions.txt'.format(args.split))
400 |     with open(pred_filepath, 'w') as f:
401 |         for inst, prediction in eval_preds:
402 |             f.write('{} {}\n'.format(inst, prediction))
403 | 
404 |     #run predictions through scorer
405 |     gold_filepath = os.path.join(eval_path, '{}.gold.key.txt'.format(args.split))
406 |     scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
407 |     p, r, f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
408 |     print('F1 (with backoff) = {}'.format(f1))
409 | 
410 |     return
411 | 
412 | if __name__ == "__main__":
413 |     if not torch.cuda.is_available():
414 |         print("Need available GPU(s) to run this model...")
415 |         quit()
416 | 
417 |     #parse args
418 |     args = parser.parse_args()
419 |     print(args)
420 | 
421 |     #set random seeds
422 |     torch.manual_seed(args.rand_seed)
423 |     os.environ['PYTHONHASHSEED'] = str(args.rand_seed)
424 |     torch.cuda.manual_seed(args.rand_seed)
425 |     torch.cuda.manual_seed_all(args.rand_seed)
426 |     np.random.seed(args.rand_seed)
427 |     random.seed(args.rand_seed)
428 |     torch.backends.cudnn.benchmark = False
429 |     torch.backends.cudnn.deterministic = True
430 | 
431 |     #evaluate model saved at checkpoint or...
432 |     if args.eval: evaluate_model(args)
433 |     #finetune pretrained model
434 |     else: train_model(args)
435 | 
436 | 
437 | 
--------------------------------------------------------------------------------
/frozen_pretrained_encoder.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | All rights reserved.
4 | 
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | '''
8 | 
9 | import torch
10 | from torch.nn import functional as F
11 | import os
12 | import sys
13 | import copy
14 | import argparse
15 | from tqdm import tqdm
16 | import pickle
17 | from pytorch_transformers import *
18 | 
19 | import numpy as np
20 | import random
21 | 
22 | from wsd_models.util import *
23 | 
24 | parser = argparse.ArgumentParser(description='BERT Frozen Probing Model for WSD')
25 | parser.add_argument('--rand_seed', type=int, default=42)
26 | parser.add_argument('--silent', action='store_true',
27 |     help='Flag to suppress training progress bar for each epoch')
28 | parser.add_argument('--lr', type=float, default=0.0001)
29 | parser.add_argument('--epochs', type=int, default=100)
30 | parser.add_argument('--bsz', type=int, default=128)
31 | parser.add_argument('--ckpt', type=str, required=True,
32 |     help='directory in which to save the best probing model (selected on the dev set)')
33 | parser.add_argument('--encoder-name', type=str, default='bert-base',
34 |     choices=['bert-base', 'bert-large', 'roberta-base', 'roberta-large'])
35 | parser.add_argument('--kshot', type=int, default=-1,
36 |     help='if set to k (1+), will filter training data to only have up to k examples per sense')
37 | parser.add_argument('--data-path', type=str, required=True,
38 |     help='Location of top-level directory for the Unified WSD Framework')
39 | 
40 | parser.add_argument('--eval', action='store_true',
41 |     help='Flag to set script to evaluate probe (rather than train)')
42 | parser.add_argument('--split', type=str, default='semeval2007',
43 |     choices=['semeval2007', 'senseval2', 'senseval3', 'semeval2013', 'semeval2015', 'ALL', 'all-test'],
44 |     help='Evaluation split on which to evaluate the probe')
45 | 
46 | def wn_keys(data):
47 |     keys = []
48 |     for sent in data:
49 |         for form, lemma, pos, inst, _ in sent:
50 |             if inst != -1:
51 |                 key = generate_key(lemma, pos)
52 |                 keys.append(key)
53 |     return keys
54 | 
55 | def batchify(data, bsz=1):
56 |     print('Batching data with bsz={}...'.format(bsz))
57 |     batched_data = []
58 |     for i in range(0, len(data), bsz):
59 |         if i+bsz < len(data): d_arr = data[i:i+bsz]
60 |         else: d_arr = data[i:] #get remainder examples
61 |         batched_ids = torch.cat([ids for ids, _, _, _ in d_arr], dim=0)
62 |         batched_masks = torch.stack([mask for _, mask, _, _ in d_arr], dim=0)
63 |         batched_insts = [inst for _, _, inst, _ in d_arr]
64 |         batched_labels = torch.cat([label for _, _, _, label in d_arr], dim=0)
65 |         batched_data.append((batched_ids, batched_masks, batched_insts, batched_labels))
66 |     return batched_data
67 | 
68 | #takes in text data, tensorizes it for BERT, runs through BERT,
69 | #filters out the context words (not labeled), and averages
70 | #the representation(s) for words/phrases to be disambiguated
71 | #output is tuples of (input tensor prepared for linear probing model,
72 | #output mask over candidate senses, instance ids (for dataset), tensor of label indexes)
73 | def preprocess(tokenizer, context_model, text_data, label_space, label_map):
74 |     processed_examples = []
75 |     output_masks = []
76 |     instances = []
77 |     label_indexes = []
78 | 
79 |     #tensorize data
80 |     for sent in tqdm(text_data):
81 |         sent_ids = [torch.tensor([tokenizer.encode(tokenizer.cls_token)])] #cls token aka sos token; encode returns a list with a single index
82 |         bert_mask = [-1]
83 |         for idx, (word, lemma, pos, inst, label) in enumerate(sent):
84 |             word_ids = torch.tensor([tokenizer.encode(word.lower())])
85 |             sent_ids.append(word_ids)
86 | 
87 |             if inst != -1:
88 |                 #masking for averaging of bert outputs
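                #e.g. if the word at idx 3 splits into two subword pieces, bert_mask gets [3, 3]; unlabeled context words get -1 below, so process_encoder_outputs can drop them and average the pieces of each labeled word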
89 |                 bert_mask.extend([idx]*word_ids.size(-1))
90 | 
91 |                 #tracking instance for sense-labeled word
92 |                 instances.append(inst)
93 | 
94 |                 #adding label tensor for sense-labeled word
95 |                 if label in label_space:
96 |                     label_indexes.append(torch.tensor([label_space.index(label)]))
97 |                 else:
98 |                     label_indexes.append(torch.tensor([label_space.index('n/a')]))
99 | 
100 |                 #adding appropriate label space for sense-labeled word (we only use this for wsd task)
101 |                 key = generate_key(lemma, pos)
102 |                 if key in label_map:
103 |                     l_space = label_map[key]
104 |                     o_mask = torch.zeros(len(label_space))
105 |                     for l in l_space: o_mask[l] = 1
106 |                     output_masks.append(o_mask)
107 |                 else:
108 |                     output_masks.append(torch.ones(len(label_space))) #let this predict whatever -- should not use this (default to backoff for unseen forms)
109 | 
110 |             else:
111 |                 bert_mask.extend([-1]*word_ids.size(-1))
112 | 
113 |         #add eos token
114 |         sent_ids.append(torch.tensor([tokenizer.encode(tokenizer.sep_token)])) #aka eos token
115 |         bert_mask.append(-1)
116 | 
117 |         sent_ids = torch.cat(sent_ids, dim=-1)
118 | 
119 |         #run inputs through frozen bert
120 |         sent_ids = sent_ids.cuda()
121 | 
122 |         with torch.no_grad():
123 |             output = context_model(sent_ids)[0].squeeze().cpu()
124 | 
125 |         #average outputs for subword units in same word/phrase, drop unlabeled words
126 |         combined_outputs = process_encoder_outputs(output, bert_mask)
127 |         processed_examples.extend(combined_outputs)
128 | 
129 |     #package preprocessed data together + return
130 |     data = list(zip(processed_examples, output_masks, instances, label_indexes))
131 |     return data
132 | 
133 | def _train(train_data, probe, optim, criterion, bsz=1, silent=False):
134 |     if not silent: train_data = tqdm(train_data)
135 |     for input_ids, output_mask, _, label in train_data:
136 |         input_ids = input_ids.cuda()
137 |         output_mask = output_mask.cuda()
138 |         label = label.cuda()
139 | 
140 |         optim.zero_grad()
141 |         output = probe(input_ids)
142 |         #mask to candidate senses for target word
143 |         output = torch.mul(output, output_mask)
144 |         #set masked out items to -inf to get proper probabilities over the candidate senses
145 |         output[output == 0] = float('-inf')
146 | 
147 |         output = F.softmax(output, dim=-1)
148 |         loss = criterion(output, label)
149 | 
150 |         batch_sz = loss.size(0)
151 |         loss = loss.sum()/batch_sz
152 |         loss.backward()
153 |         optim.step()
154 | 
155 |     return probe, optim
156 | 
157 | def _eval(eval_data, probe, label_space):
158 |     eval_preds = []
159 |     for input_ids, output_mask, inst, _ in eval_data:
160 |         input_ids = input_ids.cuda()
161 |         output_mask = output_mask.cuda()
162 | 
163 |         #run example through model
164 |         with torch.no_grad():
165 |             output = probe(input_ids)
166 |             #mask to candidate senses for target word
167 |             output = torch.mul(output, output_mask)
168 |             #set masked out items to -inf to get proper probabilities over the candidate senses
169 |             output[output == 0] = float('-inf')
170 |             output = F.softmax(output, dim=-1)
171 | 
172 |             #get predicted label
173 |             pred_id = output.topk(1, dim=-1)[1].squeeze().item()
174 |             pred_label = label_space[pred_id]
175 |             eval_preds.append((inst[0], pred_label))
176 |     return eval_preds
177 | 
178 | def _eval_with_backoff(eval_data, probe, label_space, wn_senses, coverage, keys):
179 |     eval_preds = []
180 |     for key, (input_ids, output_mask, inst, _) in zip(keys, eval_data):
181 |         input_ids = input_ids.cuda()
182 |         output_mask = output_mask.cuda()
183 | 
184 |         if key in coverage:
185 |             #run example through model
186 |             with torch.no_grad():
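                #the probe can only score senses observed in the SemCor training data; keys outside its coverage skip the model entirely and fall through to the WordNet-first-sense backoff below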
187 |                 output = probe(input_ids)
188 |                 output = torch.mul(output, output_mask)
189 |                 #set masked out items to -inf to get proper probabilities over the candidate senses
190 |                 output[output == 0] = float('-inf')
191 | 
192 |                 output = F.softmax(output, dim=-1)
193 |                 #get predicted label
194 |                 pred_id = output.topk(1, dim=-1)[1].squeeze().item()
195 |                 pred_label = label_space[pred_id]
196 |                 eval_preds.append((inst[0], pred_label))
197 |         #back off to the WordNet first sense for lemma+pos keys unseen in training
198 |         else:
199 |             #wn_senses[key][0] is the WordNet first sense (WFS) for the given key
200 |             pred_label = wn_senses[key][0]
201 |             eval_preds.append((inst[0], pred_label))
202 | 
203 |     return eval_preds
204 | 
205 | def train_probe(args):
206 |     lr = args.lr
207 |     bsz = args.bsz
208 | 
209 |     #create the passed-in ckpt dir if it doesn't exist
210 |     if not os.path.exists(args.ckpt): os.mkdir(args.ckpt)
211 | 
212 |     '''
213 |     LOAD PRETRAINED BERT MODEL
214 |     '''
215 | 
216 |     #model loading code based on pytorch_transformers README example
217 |     tokenizer = load_tokenizer(args.encoder_name)
218 |     pretrained_model, output_dim = load_pretrained_model(args.encoder_name)
219 |     pretrained_model = pretrained_model.cuda()
220 | 
221 |     '''
222 |     LOADING IN TRAINING AND EVAL DATA
223 |     '''
224 |     print('Loading data + preprocessing...')
225 |     sys.stdout.flush()
226 |     #loading WSD (semcor) training data
227 |     train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
228 |     train_data = load_data(train_path, 'semcor')
229 | 
230 |     #filter train data for k-shot learning
231 |     if args.kshot > 0:
232 |         train_data = filter_k_examples(train_data, args.kshot)
233 | 
234 |     task_labels, label_map = get_label_space(train_data)
235 |     print('num labels = {} + 1 unknown label'.format(len(task_labels)-1))
236 | 
237 |     train_data = preprocess(tokenizer, pretrained_model, train_data, task_labels, label_map)
238 |     train_data = batchify(train_data, bsz=args.bsz)
239 | 
240 |     num_epochs = args.epochs
241 |     if args.kshot > 0:
242 |         NUM_STEPS = 176600 #hard coded for fair comparison with the full model on the default num. of epochs
243 |         num_batches = len(train_data)
244 |         num_epochs = NUM_STEPS//num_batches #recalculate number of epochs
245 |         overflow_steps = NUM_STEPS%num_batches #num steps in last overflow epoch (if there is one, otherwise 0)
246 |         t_total = NUM_STEPS #total number of steps (unused below -- the probe's Adam optimizer has no lr schedule)
247 |         if overflow_steps > 0: num_epochs += 1 #add extra epoch for overflow steps
248 |         print('Overriding args.epochs and training for {} epochs...'.format(num_epochs))
249 | 
250 |     #loading eval data
251 |     #dev set = semeval2007
252 |     semeval2007_path = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/')
253 |     semeval2007_data = load_data(semeval2007_path, 'semeval2007')
254 |     semeval2007_data = preprocess(tokenizer, pretrained_model, semeval2007_data, task_labels, label_map)
255 |     semeval2007_data = batchify(semeval2007_data, bsz=1)
256 | 
257 |     '''
258 |     SET UP PROBING MODEL FOR TASK
259 |     '''
260 | 
261 |     #probing model = projection layer to label space, loss function, and optimizer
262 |     probe = torch.nn.Linear(output_dim, len(task_labels))
263 |     probe = probe.cuda()
264 |     criterion = torch.nn.CrossEntropyLoss(reduction='none')
265 |     optim = torch.optim.Adam(probe.parameters(), lr=lr)
266 | 
267 |     '''
268 |     TRAIN PROBING MODEL ON SEMCOR DATA
269 |     '''
270 | 
271 |     best_dev_f1 = 0.
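    #note: BERT runs only once, inside preprocess() above; each epoch below just refits the linear probe on the cached encoder outputs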
272 |     print('Training probe...')
273 |     sys.stdout.flush()
274 |     for epoch in range(1, num_epochs+1):
275 |         #train on full dataset
276 |         probe, optim = _train(train_data, probe, optim, criterion, bsz=bsz, silent=args.silent)
277 | 
278 |         #eval probe on dev set (semeval2007)
279 |         eval_preds = _eval(semeval2007_data, probe, task_labels)
280 | 
281 |         #generate predictions file
282 |         pred_filepath = os.path.join(args.ckpt, 'tmp_predictions.txt')
283 |         with open(pred_filepath, 'w') as f:
284 |             for inst, prediction in eval_preds:
285 |                 f.write('{} {}\n'.format(inst, prediction))
286 | 
287 |         #run predictions through scorer
288 |         gold_filepath = os.path.join(args.data_path, 'Evaluation_Datasets/semeval2007/semeval2007.gold.key.txt')
289 |         scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
290 |         _, _, dev_f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
291 |         print('Dev f1 after {} epochs = {}'.format(epoch, dev_f1))
292 |         sys.stdout.flush()
293 | 
294 |         if dev_f1 >= best_dev_f1:
295 |             print('updating best model at epoch {}...'.format(epoch))
296 |             sys.stdout.flush()
297 |             best_dev_f1 = dev_f1
298 |             #save to file if best probe so far on dev set
299 |             probe_fname = os.path.join(args.ckpt, 'best_model.ckpt')
300 |             with open(probe_fname, 'wb') as f:
301 |                 torch.save(probe.state_dict(), f)
302 |             sys.stdout.flush()
303 | 
304 |         #shuffle train data after every epoch
305 |         random.shuffle(train_data)
306 | 
307 |     return
308 | 
309 | def evaluate_probe(args):
310 |     print('Evaluating WSD probe on {}...'.format(args.split))
311 | 
312 |     '''
313 |     LOAD TOKENIZER + BERT MODEL
314 |     '''
315 |     tokenizer = load_tokenizer(args.encoder_name)
316 |     pretrained_model, output_dim = load_pretrained_model(args.encoder_name)
317 |     pretrained_model = pretrained_model.cuda()
318 | 
319 |     '''
320 |     GET LABEL SPACE
321 |     '''
322 |     train_path = os.path.join(args.data_path, 'Training_Corpora/SemCor/')
323 |     train_data = load_data(train_path, 'semcor')
324 |     task_labels, label_map = get_label_space(train_data)
325 |     #for backoff eval
326 |     train_keys = wn_keys(train_data)
327 |     coverage = set(train_keys)
328 | 
329 |     '''
330 |     LOAD TRAINED PROBE
331 |     '''
332 |     probe = torch.nn.Linear(output_dim, len(task_labels))
333 |     probe_path = os.path.join(args.ckpt, 'best_model.ckpt')
334 |     probe.load_state_dict(torch.load(probe_path))
335 |     probe = probe.cuda()
336 | 
337 |     '''
338 |     LOAD EVAL SET
339 |     '''
340 |     eval_path = os.path.join(args.data_path, 'Evaluation_Datasets/{}/'.format(args.split))
341 |     eval_data = load_data(eval_path, args.split)
342 |     #for backoff
343 |     eval_keys = wn_keys(eval_data)
344 |     eval_data = preprocess(tokenizer, pretrained_model, eval_data, task_labels, label_map)
345 |     eval_data = batchify(eval_data, bsz=1)
346 | 
347 |     '''
348 |     EVALUATE PROBE w/o backoff
349 |     '''
350 |     eval_preds = _eval(eval_data, probe, task_labels)
351 | 
352 |     #generate predictions file
353 |     pred_filepath = os.path.join(args.ckpt, '{}_predictions.txt'.format(args.split))
354 |     with open(pred_filepath, 'w') as f:
355 |         for inst, prediction in eval_preds:
356 |             f.write('{} {}\n'.format(inst, prediction))
357 | 
358 |     #run predictions through scorer
359 |     gold_filepath = os.path.join(eval_path, '{}.gold.key.txt'.format(args.split))
360 |     scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
361 |     p, r, f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
362 |     print('f1 of WSD probe on {} test set = {}'.format(args.split, f1))
363 | 
364 |     '''
365 |     EVALUATE PROBE with backoff
366 |     '''
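    #re-score all instances, backing off to the WordNet first sense for lemma+pos keys never seen in the SemCor training data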
367 |     wn_path = os.path.join(args.data_path, 'Data_Validation/candidatesWN30.txt')
368 |     wn_senses = load_wn_senses(wn_path)
369 |     eval_preds = _eval_with_backoff(eval_data, probe, task_labels, wn_senses, coverage, eval_keys)
370 | 
371 |     #generate predictions file
372 |     pred_filepath = os.path.join(args.ckpt, '{}_backoff_predictions.txt'.format(args.split))
373 |     with open(pred_filepath, 'w') as f:
374 |         for inst, prediction in eval_preds:
375 |             f.write('{} {}\n'.format(inst, prediction))
376 | 
377 |     #run predictions through scorer
378 |     gold_filepath = os.path.join(eval_path, '{}.gold.key.txt'.format(args.split))
379 |     scorer_path = os.path.join(args.data_path, 'Evaluation_Datasets')
380 |     p, r, f1 = evaluate_output(scorer_path, gold_filepath, pred_filepath)
381 |     print('f1 of BERT probe (with backoff) = {}'.format(f1))
382 | 
383 |     return
384 | 
385 | if __name__ == "__main__":
386 |     if not torch.cuda.is_available():
387 |         print("Need available GPU(s) to run this model...")
388 |         quit()
389 | 
390 |     args = parser.parse_args()
391 |     print(args)
392 | 
393 |     #set random seeds
394 |     torch.manual_seed(args.rand_seed)
395 |     os.environ['PYTHONHASHSEED'] = str(args.rand_seed)
396 | 
397 |     torch.cuda.manual_seed(args.rand_seed)
398 |     torch.cuda.manual_seed_all(args.rand_seed)
399 | 
400 |     np.random.seed(args.rand_seed)
401 |     random.seed(args.rand_seed)
402 |     torch.backends.cudnn.benchmark = False
403 |     torch.backends.cudnn.deterministic = True
404 | 
405 |     if args.eval:
406 |         evaluate_probe(args)
407 |     else:
408 |         train_probe(args)
409 | 
410 | 
411 | 
--------------------------------------------------------------------------------
/wsd_models/models.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | All rights reserved.
4 | 
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | '''
8 | 
9 | import torch
10 | from torch.nn import functional as F
11 | import math
12 | import os
13 | import sys
14 | from pytorch_transformers import *
15 | 
16 | from wsd_models.util import *
17 | #note: frozen_pretrained_encoder.py saves the probe's state_dict as 'best_model.ckpt', so rebuild the Linear layer from its weight shape before loading
18 | def load_projection(path):
19 |     with open(os.path.join(path, 'best_model.ckpt'), 'rb') as f: state_dict = torch.load(f)
20 |     proj_layer = torch.nn.Linear(state_dict['weight'].size(1), state_dict['weight'].size(0))
21 |     proj_layer.load_state_dict(state_dict); return proj_layer
22 | 
23 | class PretrainedClassifier(torch.nn.Module):
24 |     def __init__(self, num_labels, encoder_name, proj_ckpt_path):
25 |         super(PretrainedClassifier, self).__init__()
26 | 
27 |         self.encoder, self.encoder_hdim = load_pretrained_model(encoder_name)
28 | 
29 |         if proj_ckpt_path and len(proj_ckpt_path) > 0:
30 |             self.proj_layer = load_projection(proj_ckpt_path)
31 |             #assert to make sure correct dims
32 |             assert self.proj_layer.in_features == self.encoder_hdim
33 |             assert self.proj_layer.out_features == num_labels
34 |         else:
35 |             self.proj_layer = torch.nn.Linear(self.encoder_hdim, num_labels)
36 | 
37 |     def forward(self, input_ids, input_mask, example_mask):
38 |         output = self.encoder(input_ids, attention_mask=input_mask)[0]
39 | 
40 |         example_arr = []
41 |         for i in range(output.size(0)):
42 |             example_arr.append(process_encoder_outputs(output[i], example_mask[i], as_tensor=True))
43 |         output = torch.cat(example_arr, dim=0)
44 |         output = self.proj_layer(output)
45 |         return output
46 | 
47 | class GlossEncoder(torch.nn.Module):
48 |     def __init__(self, encoder_name, freeze_gloss, tied_encoder=None):
49 |         super(GlossEncoder, self).__init__()
50 | 
51 |         #load pretrained model as base for context encoder and gloss encoder
52 |         if tied_encoder:
53 |             self.gloss_encoder = tied_encoder
54 |             _, self.gloss_hdim = load_pretrained_model(encoder_name)
55 |         else:
56 |             self.gloss_encoder, self.gloss_hdim = load_pretrained_model(encoder_name)
57 |         self.is_frozen = freeze_gloss
58 | 
59 |     def forward(self, input_ids, attn_mask):
60 |         #encode gloss text
61 |         if self.is_frozen:
62 |             with torch.no_grad():
63 |                 gloss_output = self.gloss_encoder(input_ids, attention_mask=attn_mask)[0]
64 |         else:
65 |             gloss_output = self.gloss_encoder(input_ids, attention_mask=attn_mask)[0]
66 |         #training model to put all sense information on CLS token
67 |         gloss_output = gloss_output[:,0,:].squeeze(dim=1) #now bsz*gloss_hdim
68 |         return gloss_output
69 | 
70 | class ContextEncoder(torch.nn.Module):
71 |     def __init__(self, encoder_name, freeze_context):
72 |         super(ContextEncoder, self).__init__()
73 | 
74 |         #load pretrained model as base for context encoder
75 |         self.context_encoder, self.context_hdim = load_pretrained_model(encoder_name)
76 |         self.is_frozen = freeze_context
77 | 
78 |     def forward(self, input_ids, attn_mask, output_mask):
79 |         #encode context
80 |         if self.is_frozen:
81 |             with torch.no_grad():
82 |                 context_output = self.context_encoder(input_ids, attention_mask=attn_mask)[0]
83 |         else:
84 |             context_output = self.context_encoder(input_ids, attention_mask=attn_mask)[0]
85 | 
86 |         #average representations over target word(s)
87 |         example_arr = []
88 |         for i in range(context_output.size(0)):
89 |             example_arr.append(process_encoder_outputs(context_output[i], output_mask[i], as_tensor=True))
90 |         context_output = torch.cat(example_arr, dim=0)
91 | 
92 |         return context_output
93 | 
94 | class BiEncoderModel(torch.nn.Module):
95 |     def __init__(self, encoder_name, freeze_gloss=False, freeze_context=False, tie_encoders=False):
96 |         super(BiEncoderModel, self).__init__()
97 | 
98 |         #tying encoders for ablation
99 |         self.tie_encoders = tie_encoders
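        #when tie_encoders is set, the gloss encoder below reuses the context encoder's transformer, so a single set of encoder weights serves both sides of the biencoder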
100 | 
101 |         #load pretrained model as base for context encoder and gloss encoder
102 |         self.context_encoder = ContextEncoder(encoder_name, freeze_context)
103 |         if self.tie_encoders:
104 |             self.gloss_encoder = GlossEncoder(encoder_name, freeze_gloss, tied_encoder=self.context_encoder.context_encoder)
105 |         else:
106 |             self.gloss_encoder = GlossEncoder(encoder_name, freeze_gloss)
107 |         assert self.context_encoder.context_hdim == self.gloss_encoder.gloss_hdim
108 | 
109 |     def context_forward(self, context_input, context_input_mask, context_example_mask):
110 |         return self.context_encoder.forward(context_input, context_input_mask, context_example_mask)
111 | 
112 |     def gloss_forward(self, gloss_input, gloss_mask):
113 |         return self.gloss_encoder.forward(gloss_input, gloss_mask)
114 | 
115 | #EOF
--------------------------------------------------------------------------------
/wsd_models/util.py:
--------------------------------------------------------------------------------
1 | '''
2 | Copyright (c) Facebook, Inc. and its affiliates.
3 | All rights reserved.
4 | 
5 | This source code is licensed under the license found in the
6 | LICENSE file in the root directory of this source tree.
7 | '''
8 | 
9 | import os
10 | import re
11 | import torch
12 | import subprocess
13 | from pytorch_transformers import *
14 | import random
15 | 
16 | pos_converter = {'NOUN':'n', 'PROPN':'n', 'VERB':'v', 'AUX':'v', 'ADJ':'a', 'ADV':'r'}
17 | 
18 | def generate_key(lemma, pos):
19 |     if pos in pos_converter:
20 |         pos = pos_converter[pos]
21 |     key = '{}+{}'.format(lemma, pos)
22 |     return key
23 | 
24 | def load_pretrained_model(name):
25 |     if name == 'roberta-base':
26 |         model = RobertaModel.from_pretrained('roberta-base')
27 |         hdim = 768
28 |     elif name == 'roberta-large':
29 |         model = RobertaModel.from_pretrained('roberta-large')
30 |         hdim = 1024
31 |     elif name == 'bert-large':
32 |         model = BertModel.from_pretrained('bert-large-uncased')
33 |         hdim = 1024
34 |     else: #bert base
35 |         model = BertModel.from_pretrained('bert-base-uncased')
36 |         hdim = 768
37 |     return model, hdim
38 | 
39 | def load_tokenizer(name):
40 |     if name == 'roberta-base':
41 |         tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
42 |     elif name == 'roberta-large':
43 |         tokenizer = RobertaTokenizer.from_pretrained('roberta-large')
44 |     elif name == 'bert-large':
45 |         tokenizer = BertTokenizer.from_pretrained('bert-large-uncased')
46 |     else: #bert base
47 |         tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
48 |     return tokenizer
49 | 
50 | def load_wn_senses(path):
51 |     wn_senses = {}
52 |     with open(path, 'r', encoding="utf8") as f:
53 |         for line in f:
54 |             line = line.strip().split('\t')
55 |             lemma = line[0]
56 |             pos = line[1]
57 |             senses = line[2:]
58 | 
59 |             key = generate_key(lemma, pos)
60 |             wn_senses[key] = senses
61 |     return wn_senses
62 | 
63 | def get_label_space(data):
64 |     #get set of labels from dataset
65 |     labels = set()
66 | 
67 |     for sent in data:
68 |         for _, _, _, _, label in sent:
69 |             if label != -1:
70 |                 labels.add(label)
71 | 
72 |     labels = list(labels)
73 |     labels.sort()
74 |     labels.append('n/a')
75 | 
76 |     label_map = {}
77 |     for sent in data:
78 |         for _, lemma, pos, _, label in sent:
79 |             if label != -1:
80 |                 key = generate_key(lemma, pos)
81 |                 label_idx = labels.index(label)
82 |                 if key not in label_map: label_map[key] = set()
83 |                 label_map[key].add(label_idx)
84 | 
85 |     return labels, label_map
86 | 
87 | def process_encoder_outputs(output, mask, as_tensor=False):
88 |     combined_outputs = []
89 |     position = -1
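    #walk the mask in parallel with the per-token outputs, grouping consecutive tokens that share a word index and averaging each group into one representation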
90 |     avg_arr = []
91 |     for idx, rep in zip(mask, torch.split(output, 1, dim=0)):
92 |         #ignore unlabeled words
93 |         if idx == -1: continue
94 |         #average representations for units in same example
95 |         elif position < idx:
96 |             position = idx
97 |             if len(avg_arr) > 0: combined_outputs.append(torch.mean(torch.stack(avg_arr, dim=-1), dim=-1))
98 |             avg_arr = [rep]
99 |         else:
100 |             assert position == idx
101 |             avg_arr.append(rep)
102 |     #get last example from avg_arr
103 |     if len(avg_arr) > 0: combined_outputs.append(torch.mean(torch.stack(avg_arr, dim=-1), dim=-1))
104 |     if as_tensor: return torch.cat(combined_outputs, dim=0)
105 |     else: return combined_outputs
106 | 
107 | #run WSD Evaluation Framework scorer within python
108 | def evaluate_output(scorer_path, gold_filepath, out_filepath):
109 |     eval_cmd = ['java', '-cp', scorer_path, 'Scorer', gold_filepath, out_filepath]
110 |     output = subprocess.Popen(eval_cmd, stdout=subprocess.PIPE).communicate()[0]
111 |     output = [x.decode("utf-8") for x in output.splitlines()]
112 |     p, r, f1 = [float(output[i].split('=')[-1].strip()[:-1]) for i in range(3)]
113 |     return p, r, f1
114 | 
115 | def load_data(datapath, name):
116 |     text_path = os.path.join(datapath, '{}.data.xml'.format(name))
117 |     gold_path = os.path.join(datapath, '{}.gold.key.txt'.format(name))
118 | 
119 |     #load gold labels
120 |     gold_labels = {}
121 |     with open(gold_path, 'r', encoding="utf8") as f:
122 |         for line in f:
123 |             line = line.strip().split(' ')
124 |             instance = line[0]
125 |             #this means we are ignoring other senses if labeled with more than one
126 |             #(happens at least in SemCor data)
127 |             key = line[1]
128 |             gold_labels[instance] = key
129 | 
130 |     #load examples + annotate sense instances with gold labels
131 |     sentences = []
132 |     s = []
133 |     with open(text_path, 'r', encoding="utf8") as f:
134 |         for line in f:
135 |             line = line.strip()
136 |             if line == '</sentence>':
137 |                 sentences.append(s)
138 |                 s = []
139 | 
140 |             elif line.startswith('<instance') or line.startswith('<wf'):
141 |                 word = re.search('>(.+?)<', line).group(1)
142 |                 lemma = re.search('lemma="(.+?)"', line).group(1)
143 |                 pos = re.search('pos="(.+?)"', line).group(1)
144 | 
145 |                 #clean up data
146 |                 word = re.sub('&apos;', '\'', word)
147 |                 lemma = re.sub('&apos;', '\'', lemma)
148 | 
149 |                 sense_inst = -1
150 |                 sense_label = -1
151 |                 if line.startswith('