├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── __init__.py ├── data_processing ├── create_amazoncat_ov.py ├── create_pel_data.py ├── create_wned_data.py ├── data_analysis.py └── extract_data.py ├── data_utils.py ├── decode_utils.py ├── eval.py ├── eval_psp.py ├── finetune_s2s.py ├── load_and_eval_s2s.py ├── local_configs.py ├── params.py ├── s2s_model.py ├── tests ├── __init__.py ├── test_data │ ├── __init__.py │ ├── gold.jsonl │ ├── gold_multi.jsonl │ ├── pred_all.jsonl │ ├── pred_miss.jsonl │ ├── pred_multi.jsonl │ └── pred_part.jsonl └── test_eval.py └── utils.py /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to GROOV 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. 
Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to GROOV, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Attribution-NonCommercial 4.0 International 3 | 4 | ======================================================================= 5 | 6 | Creative Commons Corporation ("Creative Commons") is not a law firm and 7 | does not provide legal services or legal advice. Distribution of 8 | Creative Commons public licenses does not create a lawyer-client or 9 | other relationship. Creative Commons makes its licenses and related 10 | information available on an "as-is" basis. Creative Commons gives no 11 | warranties regarding its licenses, any material licensed under their 12 | terms and conditions, or any related information. Creative Commons 13 | disclaims all liability for damages resulting from their use to the 14 | fullest extent possible. 
15 | 16 | Using Creative Commons Public Licenses 17 | 18 | Creative Commons public licenses provide a standard set of terms and 19 | conditions that creators and other rights holders may use to share 20 | original works of authorship and other material subject to copyright 21 | and certain other rights specified in the public license below. The 22 | following considerations are for informational purposes only, are not 23 | exhaustive, and do not form part of our licenses. 24 | 25 | Considerations for licensors: Our public licenses are 26 | intended for use by those authorized to give the public 27 | permission to use material in ways otherwise restricted by 28 | copyright and certain other rights. Our licenses are 29 | irrevocable. Licensors should read and understand the terms 30 | and conditions of the license they choose before applying it. 31 | Licensors should also secure all rights necessary before 32 | applying our licenses so that the public can reuse the 33 | material as expected. Licensors should clearly mark any 34 | material not subject to the license. This includes other CC- 35 | licensed material, or material used under an exception or 36 | limitation to copyright. More considerations for licensors: 37 | wiki.creativecommons.org/Considerations_for_licensors 38 | 39 | Considerations for the public: By using one of our public 40 | licenses, a licensor grants the public permission to use the 41 | licensed material under specified terms and conditions. If 42 | the licensor's permission is not necessary for any reason--for 43 | example, because of any applicable exception or limitation to 44 | copyright--then that use is not regulated by the license. Our 45 | licenses grant only permissions under copyright and certain 46 | other rights that a licensor has authority to grant. Use of 47 | the licensed material may still be restricted for other 48 | reasons, including because others have copyright or other 49 | rights in the material. 
A licensor may make special requests, 50 | such as asking that all changes be marked or described. 51 | Although not required by our licenses, you are encouraged to 52 | respect those requests where reasonable. More_considerations 53 | for the public: 54 | wiki.creativecommons.org/Considerations_for_licensees 55 | 56 | ======================================================================= 57 | 58 | Creative Commons Attribution-NonCommercial 4.0 International Public 59 | License 60 | 61 | By exercising the Licensed Rights (defined below), You accept and agree 62 | to be bound by the terms and conditions of this Creative Commons 63 | Attribution-NonCommercial 4.0 International Public License ("Public 64 | License"). To the extent this Public License may be interpreted as a 65 | contract, You are granted the Licensed Rights in consideration of Your 66 | acceptance of these terms and conditions, and the Licensor grants You 67 | such rights in consideration of benefits the Licensor receives from 68 | making the Licensed Material available under these terms and 69 | conditions. 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. 
Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. 
Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | Section 2 -- Scope. 142 | 143 | a. License grant. 144 | 145 | 1. Subject to the terms and conditions of this Public License, 146 | the Licensor hereby grants You a worldwide, royalty-free, 147 | non-sublicensable, non-exclusive, irrevocable license to 148 | exercise the Licensed Rights in the Licensed Material to: 149 | 150 | a. reproduce and Share the Licensed Material, in whole or 151 | in part, for NonCommercial purposes only; and 152 | 153 | b. produce, reproduce, and Share Adapted Material for 154 | NonCommercial purposes only. 155 | 156 | 2. Exceptions and Limitations. For the avoidance of doubt, where 157 | Exceptions and Limitations apply to Your use, this Public 158 | License does not apply, and You do not need to comply with 159 | its terms and conditions. 160 | 161 | 3. Term. The term of this Public License is specified in Section 162 | 6(a). 163 | 164 | 4. Media and formats; technical modifications allowed. 
The 165 | Licensor authorizes You to exercise the Licensed Rights in 166 | all media and formats whether now known or hereafter created, 167 | and to make technical modifications necessary to do so. The 168 | Licensor waives and/or agrees not to assert any right or 169 | authority to forbid You from making technical modifications 170 | necessary to exercise the Licensed Rights, including 171 | technical modifications necessary to circumvent Effective 172 | Technological Measures. For purposes of this Public License, 173 | simply making modifications authorized by this Section 2(a) 174 | (4) never produces Adapted Material. 175 | 176 | 5. Downstream recipients. 177 | 178 | a. Offer from the Licensor -- Licensed Material. Every 179 | recipient of the Licensed Material automatically 180 | receives an offer from the Licensor to exercise the 181 | Licensed Rights under the terms and conditions of this 182 | Public License. 183 | 184 | b. No downstream restrictions. You may not offer or impose 185 | any additional or different terms or conditions on, or 186 | apply any Effective Technological Measures to, the 187 | Licensed Material if doing so restricts exercise of the 188 | Licensed Rights by any recipient of the Licensed 189 | Material. 190 | 191 | 6. No endorsement. Nothing in this Public License constitutes or 192 | may be construed as permission to assert or imply that You 193 | are, or that Your use of the Licensed Material is, connected 194 | with, or sponsored, endorsed, or granted official status by, 195 | the Licensor or others designated to receive attribution as 196 | provided in Section 3(a)(1)(A)(i). 197 | 198 | b. Other rights. 199 | 200 | 1. 
Moral rights, such as the right of integrity, are not 201 | licensed under this Public License, nor are publicity, 202 | privacy, and/or other similar personality rights; however, to 203 | the extent possible, the Licensor waives and/or agrees not to 204 | assert any such rights held by the Licensor to the limited 205 | extent necessary to allow You to exercise the Licensed 206 | Rights, but not otherwise. 207 | 208 | 2. Patent and trademark rights are not licensed under this 209 | Public License. 210 | 211 | 3. To the extent possible, the Licensor waives any right to 212 | collect royalties from You for the exercise of the Licensed 213 | Rights, whether directly or through a collecting society 214 | under any voluntary or waivable statutory or compulsory 215 | licensing scheme. In all other cases the Licensor expressly 216 | reserves any right to collect such royalties, including when 217 | the Licensed Material is used other than for NonCommercial 218 | purposes. 219 | 220 | Section 3 -- License Conditions. 221 | 222 | Your exercise of the Licensed Rights is expressly made subject to the 223 | following conditions. 224 | 225 | a. Attribution. 226 | 227 | 1. If You Share the Licensed Material (including in modified 228 | form), You must: 229 | 230 | a. retain the following if it is supplied by the Licensor 231 | with the Licensed Material: 232 | 233 | i. identification of the creator(s) of the Licensed 234 | Material and any others designated to receive 235 | attribution, in any reasonable manner requested by 236 | the Licensor (including by pseudonym if 237 | designated); 238 | 239 | ii. a copyright notice; 240 | 241 | iii. a notice that refers to this Public License; 242 | 243 | iv. a notice that refers to the disclaimer of 244 | warranties; 245 | 246 | v. a URI or hyperlink to the Licensed Material to the 247 | extent reasonably practicable; 248 | 249 | b. 
indicate if You modified the Licensed Material and 250 | retain an indication of any previous modifications; and 251 | 252 | c. indicate the Licensed Material is licensed under this 253 | Public License, and include the text of, or the URI or 254 | hyperlink to, this Public License. 255 | 256 | 2. You may satisfy the conditions in Section 3(a)(1) in any 257 | reasonable manner based on the medium, means, and context in 258 | which You Share the Licensed Material. For example, it may be 259 | reasonable to satisfy the conditions by providing a URI or 260 | hyperlink to a resource that includes the required 261 | information. 262 | 263 | 3. If requested by the Licensor, You must remove any of the 264 | information required by Section 3(a)(1)(A) to the extent 265 | reasonably practicable. 266 | 267 | 4. If You Share Adapted Material You produce, the Adapter's 268 | License You apply must not prevent recipients of the Adapted 269 | Material from complying with this Public License. 270 | 271 | Section 4 -- Sui Generis Database Rights. 272 | 273 | Where the Licensed Rights include Sui Generis Database Rights that 274 | apply to Your use of the Licensed Material: 275 | 276 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 277 | to extract, reuse, reproduce, and Share all or a substantial 278 | portion of the contents of the database for NonCommercial purposes 279 | only; 280 | 281 | b. if You include all or a substantial portion of the database 282 | contents in a database in which You have Sui Generis Database 283 | Rights, then the database in which You have Sui Generis Database 284 | Rights (but not its individual contents) is Adapted Material; and 285 | 286 | c. You must comply with the conditions in Section 3(a) if You Share 287 | all or a substantial portion of the contents of the database. 
288 | 289 | For the avoidance of doubt, this Section 4 supplements and does not 290 | replace Your obligations under this Public License where the Licensed 291 | Rights include other Copyright and Similar Rights. 292 | 293 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 294 | 295 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 296 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 297 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 298 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 299 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 300 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 301 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 302 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 303 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 304 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 305 | 306 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 307 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 308 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 309 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 310 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 311 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 312 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 313 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 314 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 315 | 316 | c. The disclaimer of warranties and limitation of liability provided 317 | above shall be interpreted in a manner that, to the extent 318 | possible, most closely approximates an absolute disclaimer and 319 | waiver of all liability. 320 | 321 | Section 6 -- Term and Termination. 322 | 323 | a. 
This Public License applies for the term of the Copyright and 324 | Similar Rights licensed here. However, if You fail to comply with 325 | this Public License, then Your rights under this Public License 326 | terminate automatically. 327 | 328 | b. Where Your right to use the Licensed Material has terminated under 329 | Section 6(a), it reinstates: 330 | 331 | 1. automatically as of the date the violation is cured, provided 332 | it is cured within 30 days of Your discovery of the 333 | violation; or 334 | 335 | 2. upon express reinstatement by the Licensor. 336 | 337 | For the avoidance of doubt, this Section 6(b) does not affect any 338 | right the Licensor may have to seek remedies for Your violations 339 | of this Public License. 340 | 341 | c. For the avoidance of doubt, the Licensor may also offer the 342 | Licensed Material under separate terms or conditions or stop 343 | distributing the Licensed Material at any time; however, doing so 344 | will not terminate this Public License. 345 | 346 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 347 | License. 348 | 349 | Section 7 -- Other Terms and Conditions. 350 | 351 | a. The Licensor shall not be bound by any additional or different 352 | terms or conditions communicated by You unless expressly agreed. 353 | 354 | b. Any arrangements, understandings, or agreements regarding the 355 | Licensed Material not stated herein are separate from and 356 | independent of the terms and conditions of this Public License. 357 | 358 | Section 8 -- Interpretation. 359 | 360 | a. For the avoidance of doubt, this Public License does not, and 361 | shall not be interpreted to, reduce, limit, restrict, or impose 362 | conditions on any use of the Licensed Material that could lawfully 363 | be made without permission under this Public License. 364 | 365 | b. 
To the extent possible, if any provision of this Public License is 366 | deemed unenforceable, it shall be automatically reformed to the 367 | minimum extent necessary to make it enforceable. If the provision 368 | cannot be reformed, it shall be severed from this Public License 369 | without affecting the enforceability of the remaining terms and 370 | conditions. 371 | 372 | c. No term or condition of this Public License will be waived and no 373 | failure to comply consented to unless expressly agreed to by the 374 | Licensor. 375 | 376 | d. Nothing in this Public License constitutes or may be interpreted 377 | as a limitation upon, or waiver of, any privileges and immunities 378 | that apply to the Licensor or You, including from the legal 379 | processes of any jurisdiction or authority. 380 | 381 | ======================================================================= 382 | 383 | Creative Commons is not a party to its public 384 | licenses. Notwithstanding, Creative Commons may elect to apply one of 385 | its public licenses to material it publishes and in those instances 386 | will be considered the “Licensor.” The text of the Creative Commons 387 | public licenses is dedicated to the public domain under the CC0 Public 388 | Domain Dedication. Except for the limited purpose of indicating that 389 | material is shared under a Creative Commons public license or as 390 | otherwise permitted by the Creative Commons policies published at 391 | creativecommons.org/policies, Creative Commons does not authorize the 392 | use of the trademark "Creative Commons" or any other trademark or logo 393 | of Creative Commons without its prior written consent including, 394 | without limitation, in connection with any unauthorized modifications 395 | to any of its public licenses or any other arrangements, 396 | understandings, or agreements concerning use of licensed material. For 397 | the avoidance of doubt, this paragraph does not form part of the 398 | public licenses. 
399 | 400 | Creative Commons may be contacted at creativecommons.org. -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # GROOV: GeneRative Out-Of-Vocabulary tagging
2 | 
3 | This is a minimal codebase to reproduce data and models for the paper "Open Vocabulary Extreme Classification Using Generative Models"
4 | 
5 | ## Data
6 | 
7 | To reproduce the AmazonCat-OV dataset:
8 | 1) Download raw text features for AmazonCat-13K from http://manikvarma.org/downloads/XC/XMLRepository.html
9 | 2) Format the data using utils in data_processing/extract_data.py
10 | 3) Create the shuffled data by running data_processing/create_amazoncat_ov.py
11 | 
12 | ## Models
13 | The three main steps to produce GROOV tagger models:
14 | 
15 | Finetune T5 on a dataset:
16 | ```
17 | python finetune_s2s.py --train_file_path= --test_file_path= --train_batch_size=32 --eval_batch_size=32 --output_dir=test_run_results/t5_small_10ep --model_name_or_path t5-small --use_multisoftmax --data_parallel --save_after_every_eval --eval_every_k_epoch 5 --num_epochs 10
18 | ```
19 | 
20 | Run inference on a model:
21 | ```
22 | python load_and_eval_s2s.py --train_file_path= --test_file_path= --output_dir= --eval_batch_size 10 --decode_beams 15
23 | ```
24 | 
25 | Compute PSP@K metrics on the result of inference:
26 | ```
27 | python eval.py --guess /test_preds_sum_prob.jsonl --gold --ks 1,5,10,15
28 | python eval_psp.py --train --gold --guess /test_preds_sum_prob.jsonl
29 | ```
30 | 
31 | # GET
32 | GET is an entity tagging model that extracts a set of entities without mention supervision.
33 | 
34 | ## Requirements
35 | ```
36 | python == 3.7
37 | pytorch == 1.9.1
38 | transformers == 4.9.1
39 | ```
40 | ## Usage
41 | ### Download model and data
42 | The pretrained GET model can be downloaded [here](https://dl.fbaipublicfiles.com/groov/get_model.tar.gz).
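The train and test files used throughout are JSON Lines, one example per line. A minimal sketch of writing and reading such a file (the `uid` and `output` fields match what `data_processing/create_amazoncat_ov.py` reads; the `input` text field, the toy file name, and the example values are illustrative assumptions):

```python
import json

# One record per line; "uid" and "output" are the fields read by
# data_processing/create_amazoncat_ov.py. The "input" text field and the
# example values are assumptions for illustration only.
records = [
    {"uid": "ex-0", "input": "a paperback mystery novel", "output": ["Books", "Mystery"]},
    {"uid": "ex-1", "input": "usb-c charging cable", "output": ["Electronics"]},
]

with open("toy_train.jsonl", "w") as f:
    for rec in records:
        f.write(json.dumps(rec) + "\n")

# Read the file back, one JSON object per line.
with open("toy_train.jsonl") as f:
    loaded = [json.loads(line) for line in f]

print(len(loaded), loaded[0]["output"])
```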
To replicate our experiments in the GET paper, download the [training data and WNED benchmark](https://dl.fbaipublicfiles.com/groov/get_data.tar.gz).
43 | 
44 | ### Configuration
45 | Set the paths to the model checkpoint and dataset in `local_configs.py`
46 | 
47 | ### Train the model
48 | Finetune the GET model on Wikipedia abstracts and AIDA data:
49 | ```bash
50 | # Assumes LOCAL_DATA_DIR = "../GET/data" is set in local_configs.py
51 | 
52 | python finetune_s2s.py --train_file_path pretrain_data/small_train.jsonl \
53 |     --test_file_path pretrain_data/wiki_abstract_aida_dev.jsonl \
54 |     --output_dir \
55 |     --model_name_or_path t5-base \
56 |     --train_batch_size 16 \
57 |     --eval_batch_size 2 \
58 |     --num_epochs 50 \
59 |     --max_i_length 512 \
60 |     --max_o_length 512 \
61 |     --data_parallel
62 | ```
63 | ### Evaluation
64 | Generate predictions for the AIDA test data using constrained beam search:
65 | 
66 | ```bash
67 | # Assumes LOCAL_DATA_DIR = "../GET/data" is set in local_configs.py
68 | # Assumes OUTPUT_DIR = "../GET" is set in local_configs.py
69 | 
70 | python load_and_eval_s2s.py --output_dir experiments/ \
71 |     --model_name_or_path t5-base \
72 |     --test_file_path AIDA/aida_test_dataset.jsonl \
73 |     --decode_on_lattice \
74 |     --decode_beams 5 \
75 |     --label_set_file entities.json \
76 |     --dataset_name aida_test
77 | ```
78 | 
79 | Compute the evaluation metrics on AIDA:
80 | ```bash
81 | python eval.py --guess /test_preds_naive_lattice_aida_test.jsonl \
82 |     --gold ../GET/data/AIDA/aida_test_dataset.jsonl
83 | ```
84 | 
85 | ## License
86 | See the [LICENSE](LICENSE) file for details.
87 | -------------------------------------------------------------------------------- /__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GROOV/e15f399a99add2bb52247113718e7d9fd188f58f/__init__.py -------------------------------------------------------------------------------- /data_processing/create_amazoncat_ov.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc.
and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 
7 | 
8 | import copy
9 | from collections import defaultdict
10 | import json
11 | import os
12 | import random
13 | 
14 | DATA_ROOT = "/checkpoint/danielsimig/ECG/data_test/"
15 | 
16 | # Data in the format of extract_data.py
17 | ORIG_TRAIN_FILE = os.path.join(DATA_ROOT, "AmazonCat_13k_train.jsonl")
18 | ORIG_TEST_FILE = os.path.join(DATA_ROOT, "AmazonCat_13k_test.jsonl")
19 | 
20 | # In this folder we're going to produce a set of files, with increasingly more data moved.
21 | # We produce versions with 1K, 2K, ..., 10K moved labels. 1K is eventually used in the paper.
22 | SHUFFLED_FOLDER = os.path.join(DATA_ROOT, "AmazonCat_OOV/")
23 | REMOVE_STEP = 1000
24 | 
25 | # Don't move labels that are too frequent. It's unrealistic that those labels are novel, and doing so
26 | # would decrease our train set too much anyway.
27 | MAX_FREQ_TO_REMOVE = 1000
28 | 
29 | 
30 | def read_data(path):
31 |     orig = {}
32 |     labels = set()
33 |     label_freqs = defaultdict(int)
34 |     with open(path) as f:
35 |         for line in f.readlines():
36 |             line = json.loads(line)
37 |             orig[line['uid']] = line
38 |             for label in line['output']:
39 |                 label = label.replace("_", " ")
40 |                 labels.add(label)
41 |                 label_freqs[label] += 1
42 | 
43 |     return orig, labels, label_freqs
44 | 
45 | 
46 | def write_data(data, path):
47 |     with open(path, "w") as f:
48 |         for line in data.values():
49 |             f.write(json.dumps(line) + "\n")
50 | 
51 | 
52 | train_orig, train_labels, train_label_freqs = read_data(ORIG_TRAIN_FILE)
53 | test_orig, test_labels, test_label_freqs = read_data(ORIG_TEST_FILE)
54 | 
55 | random.seed(0)
56 | remove_order = [k for k, v in train_label_freqs.items() if v < MAX_FREQ_TO_REMOVE]
57 | random.shuffle(remove_order)
58 | 
59 | train_shuffled = copy.deepcopy(train_orig)
60 | test_shuffled = copy.deepcopy(test_orig)
61 | 
62 | print(f"ORIGINAL 
SIZES: Train: {len(train_shuffled)} Test: {len(test_shuffled)}")
63 | 
64 | for i in range(10):
65 |     removed_labels = set(remove_order[i*REMOVE_STEP: (i+1) * REMOVE_STEP])
66 |     train_shuffled_tmp = {}
67 |     num_oov_labels = 0
68 |     for k, v in train_shuffled.items():
69 |         # Move any data point with an occurrence of a moved label to the test set.
70 |         # Compare labels in the same underscore-free form used by read_data.
71 |         if any(x.replace("_", " ") in removed_labels for x in v['output']):
72 |             test_shuffled[k] = v
73 |         else:
74 |             train_shuffled_tmp[k] = v
75 |     train_shuffled = train_shuffled_tmp
76 | 
77 |     print(f"After moving {(i+1) * REMOVE_STEP} labels:\tTrain: {len(train_shuffled)} \tTest: {len(test_shuffled)}")
78 | 
79 |     write_data(train_shuffled, SHUFFLED_FOLDER + f"train_max1k_{(i+1)*REMOVE_STEP}_moved.jsonl")
80 |     write_data(test_shuffled, SHUFFLED_FOLDER + f"test_max1k_{(i+1)*REMOVE_STEP}_moved.jsonl")
81 | 
-------------------------------------------------------------------------------- /data_processing/create_pel_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
2 | # All rights reserved.
3 | 
4 | # This source code is licensed under the license found in the
5 | # LICENSE file in the root directory of this source tree.
6 | 7 | 8 | from pathlib import Path 9 | import json 10 | import argparse 11 | import os 12 | from tqdm import tqdm 13 | import random 14 | import gzip 15 | from transformers import AutoTokenizer 16 | 17 | 18 | def load_data(filename, new_key=None, old_key=None): 19 | data = [] 20 | with open(filename, 'r') as fin: 21 | lines = fin.readlines() 22 | for line in tqdm(lines): 23 | instance = json.loads(line) 24 | if new_key and old_key: # Make sure instances in the merged dataset have the same output key for entities 25 | instance[new_key] = instance.pop(old_key) 26 | 27 | data.append(instance) 28 | 29 | return data 30 | 31 | def save_data(filename, data, encoding = 'utf8'): 32 | with open(filename, 'w', encoding=encoding) as fout: 33 | for res in data: 34 | json.dump(res, fout, ensure_ascii=False) 35 | fout.write("\n") 36 | 37 | # Remove tags with non-English characters, Wiktionary labels and 'None' 38 | def dedupe_labels(labels, has_mention=False): 39 | processed_labels = [] 40 | for label in labels: 41 | entity = label 42 | if has_mention: 43 | entity = label[-1] 44 | 45 | if not entity.isascii() or 'Wiktionary' in entity or entity == 'None': 46 | continue 47 | 48 | if label not in processed_labels: 49 | processed_labels.append(label) 50 | 51 | return processed_labels 52 | 53 | # Create mention table from pre-computed annotation file 54 | # https://github.com/masha-p/PPRforNED 55 | # Convert the annotation to a dict object with format {"mention" : [ candidate entities, ..]} 56 | def parse_annotation_file(file_path, mention_table): 57 | mention = None 58 | candidates = [] 59 | with open(file_path, 'r') as fin: 60 | lines = fin.readlines() 61 | for i, line in enumerate(lines): 62 | parsed_row = line.split("\t") 63 | 64 | if parsed_row[0] == "ENTITY": 65 | if mention and len(candidates) > 0: 66 | if mention not in mention_table: 67 | mention_table[mention] = candidates 68 | 69 | mention = parsed_row[1].split(':')[1] 70 | candidates = [] 71 | 72 | elif parsed_row[0] == 
'CANDIDATE': 73 | url = parsed_row[5][4:] 74 | wikiname = url.rsplit('/', 1)[-1].replace('_', ' ') 75 | candidates.append(wikiname) 76 | 77 | # Process mention annotations to parallel EL format 78 | def remove_mention_text(mentions): 79 | for mention in mentions: 80 | del mention[2] 81 | 82 | def yield_lines(filepath, n_lines=None): 83 | filepath = Path(filepath) 84 | with open(filepath, "rt") as f: 85 | for i, l in enumerate(f): 86 | if n_lines is not None and i >= n_lines: 87 | break 88 | yield l.rstrip("\n") 89 | 90 | def yield_jsonl_lines(filepath, *args, **kwargs): 91 | for line in yield_lines(filepath, *args, **kwargs): 92 | yield json.loads(line) 93 | 94 | def process_kilt_data(src_path, tar_path): 95 | with gzip.open(tar_path, "wt", compresslevel=1) as f: 96 | for data in yield_jsonl_lines(src_path): 97 | 98 | # Remove labels with special characters 99 | data['entities'] = dedupe_labels(data['entities']) 100 | 101 | # Remove duplicates in topic labels 102 | data['topics'] = dedupe_labels(data['topics']) 103 | 104 | # Skip illegal data with empty input or no labels 105 | if len(data['input']) == 0 or (len(data['entities']) == 0 or len(data['topics']) == 0): 106 | continue 107 | 108 | f.write(json.dumps(data) + "\n") 109 | 110 | # Find the index of leftmost interval with desired property 111 | def search_token_index(char_index, intervals): 112 | left, right = 0, len(intervals) - 1 113 | res = len(intervals) 114 | while left <= right: 115 | mid = (left + right) // 2 116 | start, end = intervals[mid][0], intervals[mid][1] 117 | if start <= char_index <= end: 118 | res = min(res, mid) 119 | right = mid - 1 120 | elif char_index > end: 121 | left = mid + 1 122 | else: 123 | right = mid - 1 124 | 125 | return res 126 | 127 | 128 | if __name__ == "__main__": 129 | parser = argparse.ArgumentParser() 130 | parser.add_argument( 131 | "--src_path", 132 | default="/data/KILT_entities_topics/KILT_entities_topics_train.jsonl", 133 | type=str, 134 | help="Path to Wikipedia 
abstract data" 135 | ) 136 | 137 | parser.add_argument( 138 | "--tar_path", 139 | default="/data/KILT_entities_topics/pel_KILT_entities_topics_train.jsonl.gz", 140 | type=str, 141 | help="Path to processed data (in parallel EL format)" 142 | ) 143 | 144 | parser.add_argument( 145 | "--mention_file", 146 | default="/data/mention_table.json", 147 | type=str, 148 | help="Path to pre-computed mention table" 149 | ) 150 | 151 | parser.add_argument( 152 | "--entity_file", 153 | default="/data/entities.json", 154 | type=str, 155 | help="Path to file that contains candidate entities" 156 | ) 157 | 158 | args = parser.parse_args() 159 | 160 | tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096") 161 | # Load pre-computed mention table 162 | with open(args.mention_file) as f: 163 | mention_table = json.load(f) 164 | 165 | # Load global entity set 166 | with open(args.entity_file) as f: 167 | entity_set = json.load(f) 168 | 169 | with gzip.open(args.tar_path, "wt", compresslevel=1) as f: 170 | for data in yield_jsonl_lines(args.src_path): 171 | data['anchors'] = list() 172 | encoded_input = tokenizer(data['input'], return_offsets_mapping=True) 173 | 174 | if len(encoded_input['input_ids']) >= 4096: 175 | print(data['id']) 176 | print(len(encoded_input['input_ids'])) 177 | continue 178 | 179 | for mention in data['mentions']: 180 | # Convert character index to token index for each mention span 181 | char_start = mention['anchor']['start'] 182 | char_end = mention['anchor']['end'] 183 | token_start = search_token_index(char_start, encoded_input["offset_mapping"]) 184 | token_end = search_token_index(char_end, encoded_input["offset_mapping"]) 185 | 186 | data['anchors'].append([token_start, token_end, 187 | mention['anchor']['text'], mention['entity']]) 188 | 189 | # Remove illegal or duplicate labels 190 | data['entities'] = dedupe_labels(data['entities']) 191 | data['topics'] = dedupe_labels(data['topics']) 192 | 193 | # Remove illegal labels in anchors 
(note that different mentions with the same label will be kept) 194 | data['anchors'] = dedupe_labels(data['anchors'], has_mention=True) 195 | 196 | # Read candidate entities from the pre-computed mention table, 197 | # or generate one random label from the global entity set if no match is found 198 | data['candidates'] = list() 199 | for mention in data['anchors']: 200 | mention_text = mention[2] 201 | if mention_text in mention_table: 202 | data['candidates'].append(mention_table[mention_text]) 203 | else: 204 | data['candidates'].append([mention[-1], random.sample(entity_set, k=1)[0]]) 205 | 206 | remove_mention_text(data['anchors']) 207 | 208 | # Skip illegal instances with empty input or no labels/mentions 209 | if len(data['input']) == 0 or (len(data['entities']) == 0 or len(data['topics']) == 0) or len(data['anchors']) == 0: 210 | continue 211 | 212 | data.pop('paragraphs') 213 | data.pop('mentions') 214 | 215 | f.write(json.dumps(data) + "\n") 216 | -------------------------------------------------------------------------------- /data_processing/create_wned_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
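`search_token_index` in `create_pel_data.py` above binary-searches a tokenizer's `offset_mapping` for the leftmost character interval containing a given offset. A standalone sketch with a made-up offset mapping (no tokenizer required):

```python
# Standalone copy of search_token_index from create_pel_data.py.
# Finds the index of the leftmost (start, end) interval containing
# char_index; returns len(intervals) if no interval contains it.
def search_token_index(char_index, intervals):
    left, right = 0, len(intervals) - 1
    res = len(intervals)
    while left <= right:
        mid = (left + right) // 2
        start, end = intervals[mid][0], intervals[mid][1]
        if start <= char_index <= end:
            res = min(res, mid)   # record the match, keep searching left
            right = mid - 1
        elif char_index > end:
            left = mid + 1
        else:
            right = mid - 1
    return res

# Hypothetical offset mapping, as a tokenizer might produce for
# three tokens covering characters 0-5, 5-8 and 8-12.
offsets = [(0, 5), (5, 8), (8, 12)]
```

Because interval boundaries touch (character 5 lies in both the first and second interval), the "leftmost" rule matters: `search_token_index(5, offsets)` resolves to the first interval.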
6 | 7 | 8 | import os 9 | import glob 10 | import json 11 | import xmltodict 12 | import xml.etree.ElementTree as ET 13 | from pathlib import Path 14 | 15 | 16 | def save_data(filename, data, encoding = 'utf8'): 17 | with open(filename, 'w', encoding=encoding) as fout: 18 | for res in data: 19 | json.dump(res, fout, ensure_ascii=False) 20 | fout.write("\n") 21 | 22 | def parse_wned_dataset(dataset_folder): 23 | """Convert WNED dataset to entity tagging format 24 | - dataset_folder: path to WNED dataset 25 | return: List[dict] 26 | """ 27 | 28 | dataset_name = dataset_folder.name 29 | xml_filepath = dataset_folder / f"{dataset_name}.xml" 30 | rawtext_folder = dataset_folder / "RawText" 31 | 32 | tree = ET.parse(xml_filepath) 33 | root = tree.getroot() 34 | 35 | processed_data = [] 36 | 37 | for document in root.findall('document'): 38 | doc_name = document.get('docName') 39 | doc_id = dataset_name + '-' + doc_name 40 | 41 | with open(rawtext_folder / doc_name, 'r') as fin: 42 | # Extract input context 43 | input_context = fin.read() 44 | 45 | # Extract entity annotations 46 | entity_set = set() 47 | for annotation in document.findall('annotation'): 48 | entity = annotation.find('wikiName').text 49 | if entity is not None and entity != 'NIL': 50 | # print(entity) 51 | entity_set.add(entity) 52 | 53 | processed_data.append({'id': doc_id, 'input': input_context, 'output': list(entity_set)}) 54 | 55 | return processed_data 56 | 57 | 58 | if __name__ == "__main__": 59 | datasets_folder = "/eval_datasets/basic_data/test_datasets/wned-datasets" 60 | for dataset_path in glob.glob(f'{datasets_folder}/*/'): 61 | print("Currently processing:", dataset_path) 62 | 63 | dataset_path = Path(dataset_path) 64 | et_data = parse_wned_dataset(dataset_path) 65 | 66 | output_path = Path("processed_data") 67 | os.makedirs(output_path, exist_ok=True) 68 | output_file = output_path / f"{dataset_path.name}.jsonl" 69 | 70 | save_data(output_file, et_data) 71 | 
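The annotation-extraction loop in `parse_wned_dataset` above can be exercised in memory on a hypothetical XML snippet (no dataset files needed): non-`NIL` `wikiName` entries are collected into a set, which also deduplicates repeated annotations of the same entity.

```python
import xml.etree.ElementTree as ET

# Minimal in-memory version of the annotation loop in parse_wned_dataset,
# using a made-up document in the WNED XML layout.
xml_doc = """
<dataset>
  <document docName="doc1">
    <annotation><wikiName>Barack Obama</wikiName></annotation>
    <annotation><wikiName>NIL</wikiName></annotation>
    <annotation><wikiName>Barack Obama</wikiName></annotation>
  </document>
</dataset>
"""

root = ET.fromstring(xml_doc)
entities = set()
for document in root.findall("document"):
    for annotation in document.findall("annotation"):
        entity = annotation.find("wikiName").text
        # NIL marks mentions with no Wikipedia target; drop them
        if entity is not None and entity != "NIL":
            entities.add(entity)
```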
-------------------------------------------------------------------------------- /data_processing/data_analysis.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from pathlib import Path 9 | from collections import Counter 10 | import json 11 | 12 | import pandas as pd 13 | import plotly.express as px 14 | 15 | 16 | def load_data(filename): 17 | data = [] 18 | with open(filename, 'r') as fin: 19 | lines = fin.readlines() 20 | for line in lines: 21 | data.append(json.loads(line)) 22 | 23 | return data 24 | 25 | def get_label_counter(dataset): 26 | label_list = [] 27 | for instance in dataset: 28 | label_list.extend(instance['output']) 29 | label_counter = Counter(label_list) 30 | 31 | return label_counter 32 | 33 | def create_data_frame(path: Path): 34 | """Create a DataFrame for AIDA data""" 35 | data_rows = [] 36 | label_list = [] 37 | with open(path, 'r', encoding='utf-8') as reader: 38 | for line in reader: 39 | data = json.loads(line) 40 | data_rows.append({ 41 | "id": data['id'], 42 | "input": data['input'], 43 | "labels": data['output'], 44 | "num_labels": len(data['output']), 45 | "input_len": len(data['input'].split(' ')) 46 | }) 47 | label_list.extend(data['output']) 48 | 49 | df = pd.DataFrame(data_rows) 50 | label_counter = Counter(label_list) 51 | 52 | return df, label_counter 53 | 54 | def visualize_data(path: Path): 55 | df, label_counter = create_data_frame(path) 56 | 57 | # Distribution of input length 58 | print(f"Total number of samples: {len(df)} \n") 59 | print('Input length: \n') 60 | fig = px.histogram(df.input_len, x="input_len") 61 | fig.show() 62 | 63 | # label distribution 64 | print('Label Distribution:\n') 65 | fig = px.histogram(df.num_labels, x="num_labels", nbins=20) 66 | 
fig.show() 67 | 68 | print('Most common labels: \n') 69 | label_df = pd.DataFrame(label_counter.most_common(), columns=["label", "count"]) 70 | print(label_df.head()) 71 | px.bar(label_df.head(50), x="label", y="count", title="Most common labels").show() 72 | 73 | print(f"Unique labels: {len(label_counter)}") 74 | print(f"Total labels: {sum(label_counter.values())}") 75 | 76 | # Outliers with large number of labels 77 | num_outliers_50 = len(df[df.num_labels > 50]) 78 | num_outliers_100 = len(df[df.num_labels > 100]) 79 | print('Number of outliers:\n') 80 | print(f"# of samples with 50+ labels: {num_outliers_50} ({num_outliers_50 / len(df) * 100:.4f} %)") 81 | print(f"# of samples with 100+ labels: {num_outliers_100} ({num_outliers_100 / len(df) * 100:.4f} %)") 82 | 83 | print(f"--------------------------------------------") 84 | 85 | def generate_report(df: pd.DataFrame): 86 | """Generate the distribution of num_labels and report number of outliers""" 87 | 88 | fig = px.histogram(df.num_labels, x="num_labels", nbins=20) 89 | fig.show() 90 | 91 | num_outliers_50 = len(df[df.num_labels > 50]) 92 | num_outliers_100 = len(df[df.num_labels > 100]) 93 | print(f"Total number of samples: {len(df)}") 94 | print(f"# of samples with 50+ labels: {num_outliers_50} ({num_outliers_50 / len(df) * 100:.4f} %)") 95 | print(f"# of samples with 100+ labels: {num_outliers_100} ({num_outliers_100 / len(df) * 100:.4f} %)") 96 | print(f"--------------------------------------------") 97 | 98 | 99 | if __name__ == "__main__": 100 | data_path = "/GET/data/msnbc.jsonl" 101 | visualize_data(data_path) -------------------------------------------------------------------------------- /data_processing/extract_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 
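The label statistics that `data_analysis.py` above plots all derive from a single `Counter` over the flattened label lists (`get_label_counter`). A tiny sketch on made-up records:

```python
from collections import Counter

# Hypothetical dataset in the {"output": [...]} shape used throughout
# this repo; mirrors get_label_counter in data_analysis.py.
dataset = [
    {"output": ["rock", "jazz"]},
    {"output": ["jazz"]},
]

label_list = []
for instance in dataset:
    label_list.extend(instance["output"])
label_counter = Counter(label_list)
```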
3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | import os 10 | from collections import defaultdict 11 | import mwparserfromhell 12 | 13 | 14 | """ 15 | Utilities to parse XMC data into more convenient json format. 16 | 17 | Example usage: 18 | 19 | LABELS_PATH = YOUR_BASE_DIR/raw/AmazonCat-13K.raw/Yf.txt 20 | INPUT_PATH = YOUR_BASE_DIR/raw/AmazonCat-13K.raw/trn.json 21 | OUTPUT_PATH = YOUR_BASE_DIR/AmazonCat_13k_train.json 22 | 23 | labels = load_labels(LABELS_PATH) 24 | parse_data(INPUT_PATH, OUTPUT_PATH, labels, wiki_data=False) 25 | """ 26 | 27 | 28 | def load_labels(input_path): 29 | labels = [] 30 | try: 31 | with open(input_path, "r") as ip_fp: 32 | for line in ip_fp: 33 | labels.append(line) 34 | except UnicodeDecodeError: 35 | with open(input_path, "r", encoding="ISO-8859-1") as ip_fp: 36 | for line in ip_fp: 37 | labels.append(line) 38 | 39 | return labels 40 | 41 | def clean_tag(ip_tag): 42 | op_tag = ip_tag.split("->")[1] 43 | op_tag = op_tag.strip() 44 | op_tag = op_tag.replace('_', ' ') 45 | return op_tag.strip() 46 | 47 | def parse_data(ip_json_path, op_json_path, labels, wiki_data=False, to_print=10): 48 | cnt = 0 49 | with open(ip_json_path, "r") as ip_fp, open(op_json_path, "w") as op_fp: 50 | for line in ip_fp: 51 | if cnt % 100000 == 0: 52 | print(f"{cnt} lines loaded!") 53 | op_ex = defaultdict() 54 | ip_ex = json.loads(line) 55 | 56 | op_ex['uid'] = ip_ex['uid'] 57 | if wiki_data: 58 | op_ex['input'] = ip_ex['title'].replace("_", " ").strip() + " " + mwparserfromhell.parse(ip_ex['content']).strip_code().strip() 59 | op_ex['output'] = [clean_tag(labels[tag_idx]) for tag_idx in ip_ex['target_ind']] 60 | else: 61 | op_ex['input'] = ip_ex['title'].strip() + " " + ip_ex['content'].strip() 62 | op_ex['output'] = [labels[tag_idx].strip() for tag_idx in ip_ex['target_ind']] 63 | 64 | if cnt < to_print: 65 | 
print("===============================================") 66 | print(f"UID: {op_ex['uid']}") 67 | print(f"Input: {op_ex['input'][:2000]}") 68 | print(f"Output: {op_ex['output']}") 69 | op_fp.write(json.dumps(op_ex) + "\n") 70 | cnt += 1 -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import json 9 | from itertools import chain 10 | from torch.utils.data import Dataset 11 | import random 12 | 13 | 14 | class Seq2SetDataset(Dataset): 15 | def __init__(self, path, sep, replace_underscores=False, 16 | read_per_line=False, single_label=False, output_key="output"): 17 | self.path = path 18 | self.sep = sep 19 | self.data = None 20 | self.label_set = None 21 | self.replace_underscores = replace_underscores 22 | self.read_per_line = read_per_line 23 | self.single_label = single_label 24 | 25 | self.output_key = output_key # key for gold annotations 26 | 27 | def read_data(self): 28 | 29 | print("Reading", self.path) 30 | 31 | with open(self.path, "r") as f: 32 | if self.path.split(".")[-1] == "jsonl": 33 | self.data = [json.loads(line) for line in f.readlines()] 34 | else: 35 | self.data = json.load(f) 36 | 37 | self.label_set = set(chain.from_iterable(row[self.output_key] for row in self.data)) 38 | 39 | def dedupe_data(self, tokenizer): 40 | for i, line in enumerate(self.data): 41 | new_output = [] 42 | tokenized_labels = [] 43 | for label in line[self.output_key]: 44 | tokenized = tuple(tokenizer(label).input_ids[:-1]) 45 | if tokenized not in tokenized_labels: 46 | new_output.append(label) 47 | tokenized_labels.append(tokenized) 48 | else: 49 | data_id = line["id"] if "id" in line else line["uid"] 50 | print(f"Line 
{data_id} has repeated labels!") 51 | line[self.output_key] = new_output 52 | if i % 10000 == 0: 53 | print(i) 54 | 55 | def __len__(self): 56 | assert ( 57 | self.data is not None 58 | ), "Attempted to access data before loading it. Call read_data() first" 59 | return len(self.data) 60 | 61 | def order_labels(self, label_seq): 62 | # Comment this line to disable label shuffling 63 | random.shuffle(label_seq) 64 | return label_seq 65 | 66 | def label_to_str(self, label): 67 | return label.replace("_", " ") if self.replace_underscores else label 68 | 69 | def output_str_to_labels(self, output_str): 70 | return [x.strip() for x in output_str.split(self.sep)] 71 | 72 | def token_ids_to_labels(self, tokenizer, token_ids): 73 | return self.output_str_to_labels( 74 | tokenizer.decode(token_ids).split("</s>")[0].replace("<pad>", "") 75 | ) 76 | 77 | def make_example(self, idx): 78 | 79 | assert ( 80 | self.data is not None 81 | ), "Attempted to access data before loading it. Call read_data() first" 82 | 83 | example = self.data[idx] 84 | id = example["id"] if "id" in example else example["uid"] 85 | input = example["input"] 86 | labels = [random.choice(example[self.output_key])] if self.single_label else self.order_labels(example[self.output_key]) 87 | 88 | out_str = self.sep.join(self.label_to_str(label) for label in labels) 89 | 90 | return (input.lower().strip(), out_str, id) 91 | 92 | def __getitem__(self, idx): 93 | return self.make_example(idx) 94 | 95 | def return_all_inputs(self): 96 | qs = [] 97 | for i in range(len(self.data)): 98 | qs.append(self.make_example(i)[0]) 99 | return qs 100 | 101 | def get_all_labels(self): 102 | assert ( 103 | self.label_set is not None 104 | ), "Attempted to access data before loading it. 
Call read_data() first" 105 | return {self.label_to_str(label) for label in self.label_set} 106 | -------------------------------------------------------------------------------- /decode_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | from collections import defaultdict 9 | from itertools import groupby 10 | from math import exp 11 | import torch 12 | 13 | 14 | class TrieNode: 15 | def __init__(self): 16 | self.labels = set() 17 | self.children = defaultdict(TrieNode) 18 | 19 | def freeze(self): 20 | self.children = dict(self.children) 21 | for child in self.children.values(): 22 | child.freeze() 23 | 24 | class LabelTrie: 25 | EOS_TOKEN_ID = 1 26 | PAD_TOKEN_ID = 0 27 | 28 | def __init__(self, trie, raw_label_id_map, label_id_map, sep_token, sep_token_id, num_labels): 29 | self.trie = trie 30 | self.raw_label_id_map = raw_label_id_map 31 | self.label_id_map = label_id_map 32 | self.id_raw_label_map = {v: k for k, v in raw_label_id_map.items()} 33 | self.sep_token_id = sep_token_id 34 | self.sep_token = sep_token 35 | self.EOS_TOKEN_ID = 1 36 | self.PAD_TOKEN_ID = 0 37 | self.num_labels = num_labels 38 | 39 | @classmethod 40 | def from_labels(cls, label_set, tokenizer, sep_token): 41 | 42 | sep_token_id = tokenizer.convert_tokens_to_ids(sep_token) 43 | 44 | raw_label_id_map = {} 45 | id_raw_label_map = {} 46 | label_id_map = {} 47 | for i, label in enumerate(label_set): 48 | if tuple(tokenizer(label).input_ids[:-1]) in label_id_map: 49 | label_id = label_id_map[tuple(tokenizer(label).input_ids[:-1])] 50 | print("WARNING: Different labels have the same tokenization:") 51 | print(" ", id_raw_label_map[label_id]) 52 | print(" ", label) 53 | raw_label_id_map[label] = label_id 54 | else: 55 | 
raw_label_id_map[label] = i 56 | id_raw_label_map[i] = label 57 | label_id_map[tuple(tokenizer(label).input_ids[:-1])] = i 58 | 59 | # Root node 60 | trie = TrieNode() 61 | 62 | # Add labels to trie. 63 | for label in label_id_map.keys(): 64 | current = trie 65 | for token in list(label): 66 | current.labels.add(label_id_map[label]) 67 | current = current.children[token] 68 | current.labels.add(label_id_map[label]) 69 | 70 | # Allow the label to be finished with either a [SEP] token (new label to come) 71 | # or the EOS token, meaning this is the last label 72 | continue_node = current.children[sep_token_id] 73 | continue_node.labels.add(label_id_map[label]) 74 | 75 | stop_node = current.children[cls.EOS_TOKEN_ID] 76 | stop_node.labels.add(label_id_map[label]) 77 | 78 | trie.freeze() 79 | 80 | return cls(trie, raw_label_id_map, label_id_map, sep_token, sep_token_id, len(set(label_set))) 81 | 82 | def print_strings(self, node=None, depth=0): 83 | # Use this to debug smaller tries 84 | if node is None: 85 | node = self.trie 86 | for key, child in node.children.items(): 87 | if key == self.EOS_TOKEN_ID: 88 | print(" " * depth, node.labels) 89 | else: 90 | print(" " * depth, key) 91 | self.print_strings(child, depth=depth+1) 92 | 93 | def next_allowed_token(self, input_ids, max_labels=999, permutation_only=False): 94 | # Given a sequence of token ids (already decoded), determine what the next token can be 95 | # if we want to produce valid labels. 96 | 97 | # See what labels we have already decoded and whether we're already in the middle 98 | # of a new label. 
99 | completed_label_ids = set() 100 | label_in_progress = [] 101 | for token in input_ids: 102 | if token == self.PAD_TOKEN_ID: 103 | continue 104 | if token == self.sep_token_id: 105 | completed_label_ids.add(self.label_id_map[tuple(label_in_progress)]) 106 | label_in_progress = [] 107 | elif token == self.EOS_TOKEN_ID: 108 | return [self.PAD_TOKEN_ID], completed_label_ids 109 | else: 110 | label_in_progress.append(token) 111 | 112 | # If we're in the middle of a label, only allow valid continuations 113 | current = self.trie 114 | for token in label_in_progress: 115 | current = current.children.get(token, None) 116 | if current is None: 117 | return [self.EOS_TOKEN_ID], completed_label_ids 118 | 119 | 120 | if not permutation_only: 121 | # This is normal decoding 122 | return [ 123 | token 124 | for token, child in current.children.items() 125 | if ( 126 | # Make sure the next token can lead to a label that hasn't been produced yet 127 | child.labels.difference(completed_label_ids) 128 | 129 | # If we output [SEP] token, there will be one more label, so in total 2 more than 130 | # what we have now. If that's more than max_labels, we need to produce EOS. 131 | and not ( 132 | token == self.sep_token_id 133 | and len(completed_label_ids) + 2 > max_labels 134 | ) 135 | ) 136 | ], completed_label_ids 137 | 138 | else: 139 | # A special type of decoding that produces a permutation of all the labels in the trie. 140 | # This is used for the EM experiments where we want to score the GT labels and find the best order. 
141 | allowed_tokens = [ 142 | token 143 | for token, child in current.children.items() 144 | if ( 145 | # Make sure the next token can lead to a label that hasn't been produced yet 146 | child.labels.difference(completed_label_ids) 147 | 148 | # Terminate precisely once we have produced all labels (special rule for permutation_only) 149 | and not ( 150 | token == self.sep_token_id 151 | and len(completed_label_ids) + 1 == self.num_labels 152 | ) 153 | and not ( 154 | token == self.EOS_TOKEN_ID 155 | and len(completed_label_ids) + 1 < self.num_labels 156 | ) 157 | ) 158 | ] 159 | #print(allowed_tokens, completed_label_ids, input_ids) 160 | return allowed_tokens, completed_label_ids 161 | 162 | def compute_targets(self, input_ids, input_labels_str): 163 | # Given a sampled label sequence s (tokenized and raw text version), produce a target tensor 164 | # of dimension batch_size x max_output_len x num_tokens. Target is 1 for token index i if at a given 165 | # time t token i is a possibly correct continuation of the sequence s[:t], 0 otherwise. 166 | # Token s[t+1] will definitely have target 1, but occasionally (especially at the start of the label) 167 | # many more correct tokens are possible. 168 | 169 | targets = torch.zeros( 170 | input_ids.size(0), # num_batches 171 | input_ids.size(1), # max seq len 172 | 32128, # hack, num_tokens 173 | ) 174 | 175 | for sample_idx, (sequence, labels_str) in enumerate( 176 | zip(input_ids, input_labels_str) 177 | ): 178 | 179 | # Initially, any positive label can be decoded. This set will shrink over time. 
180 | allowed_labels = { 181 | self.raw_label_id_map[l] for l in labels_str.split(self.sep_token) 182 | } 183 | 184 | trie_state = self.trie 185 | finished = False 186 | 187 | for token_idx, token_tensor in enumerate(sequence): 188 | 189 | # No option just to pad after EOS 190 | if finished: 191 | targets[sample_idx][token_idx][self.PAD_TOKEN_ID] = 1 192 | continue 193 | 194 | curr_token = token_tensor.item() 195 | # possible_tokens = [] # Uncomment to debug 196 | 197 | # Allow all tokens that could lead to a valid label 198 | for next_token, child in trie_state.children.items(): 199 | if ( 200 | child.labels.intersection(allowed_labels) 201 | # If this was the last allowed label, we shouldn't return [SEP] 202 | and not ( 203 | next_token == self.sep_token_id and len(allowed_labels) == 1 204 | ) 205 | # Otherwise we shouldn't return tokens if there's more labels 206 | and not ( 207 | next_token == self.EOS_TOKEN_ID and len(allowed_labels) > 1 208 | ) 209 | ): 210 | targets[sample_idx][token_idx][next_token] = 1 211 | # possible_tokens.append(next_token) # Uncomment to debug 212 | 213 | # print(possible_tokens, curr_token) # Uncomment to debug 214 | 215 | if curr_token == self.EOS_TOKEN_ID or curr_token == self.PAD_TOKEN_ID: 216 | # The decoding has finished, it's gonna be pad tokens from now on 217 | finished = True 218 | 219 | elif curr_token == self.sep_token_id: 220 | # A label was finished, the label id can be found in the child node 221 | # corresponding to [SEP], it should be the only id possible at some point. 
222 | child_labels = trie_state.children[self.sep_token_id].labels 223 | assert len(child_labels) == 1 224 | produced_label = list(child_labels)[0] 225 | 226 | # For debugging 227 | if produced_label not in allowed_labels: 228 | print("Produced label not in allowed labels") 229 | print(self.id_raw_label_map[produced_label]) 230 | print([self.id_raw_label_map[x] for x in allowed_labels]) 231 | 232 | assert produced_label in allowed_labels 233 | allowed_labels = allowed_labels.difference({produced_label}) 234 | 235 | # Restart traversing the trie 236 | trie_state = self.trie 237 | 238 | else: 239 | # we're in the middle of producing a label, just traverse the trie 240 | trie_state = trie_state.children[curr_token] 241 | 242 | return targets 243 | 244 | 245 | def score_labels_by_probability_sum(sequences, scores): 246 | # Score a given label by integrating over all the sequences returned from the beam search. 247 | # p(l) = \sum_{b \in beams}(int(l \in b) * p(b)) 248 | 249 | label_scores = defaultdict(float) 250 | label_positions = defaultdict(list) 251 | 252 | for ans_ids, score in zip(sequences, scores): 253 | for pos, label in enumerate(ans_ids): 254 | label_scores[label] += exp(score) 255 | label_positions[label].append(pos) 256 | 257 | return list( 258 | sorted( 259 | label_scores.keys(), 260 | key=lambda x: -( 261 | label_scores[x] 262 | - 0.0001 * sum(label_positions[x]) / len(label_positions[x]) 263 | ), 264 | ) 265 | ) 266 | 267 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
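`score_labels_by_probability_sum` in `decode_utils.py` above aggregates beam-search output: a label's score is the sum of `exp(log-prob)` over every beam containing it, with a tiny penalty for labels that tend to appear late. A standalone sketch with hypothetical beams:

```python
from collections import defaultdict
from math import exp

# Standalone copy of score_labels_by_probability_sum from decode_utils.py:
# p(l) = sum over beams b of [l in b] * p(b), minus a small tie-breaking
# penalty based on the label's average position within beams.
def score_labels_by_probability_sum(sequences, scores):
    label_scores = defaultdict(float)
    label_positions = defaultdict(list)
    for ans_ids, score in zip(sequences, scores):
        for pos, label in enumerate(ans_ids):
            label_scores[label] += exp(score)   # score is a log-probability
            label_positions[label].append(pos)
    return list(
        sorted(
            label_scores.keys(),
            key=lambda x: -(
                label_scores[x]
                - 0.0001 * sum(label_positions[x]) / len(label_positions[x])
            ),
        )
    )

# Two made-up beams with log-probabilities; "music" appears in both
# beams, so its probability mass is summed and it ranks first.
beams = [["music", "jazz"], ["music", "guitars"]]
log_probs = [-0.1, -0.7]
ranking = score_labels_by_probability_sum(beams, log_probs)
```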
6 | 7 | 8 | import argparse 9 | import json 10 | import numpy as np 11 | import pprint 12 | from collections import defaultdict, OrderedDict 13 | 14 | 15 | def load_data(filename): 16 | data = [] 17 | if ".jsonl" in filename: 18 | with open(filename, "r") as fin: 19 | lines = fin.readlines() 20 | for line in lines: 21 | data.append(json.loads(line)) 22 | else: 23 | with open(filename, "r") as fin: 24 | data = json.load(fin) 25 | return data 26 | 27 | 28 | def normalize_label(s): 29 | def remove_underscore(text): 30 | return text.replace("_", " ") 31 | 32 | def lower(text): 33 | return text.lower() 34 | 35 | return remove_underscore(lower(s)) 36 | 37 | 38 | def get_rank(guess_item, gold_item): 39 | 40 | ground_truth = set() 41 | 42 | for label in gold_item["output"]: 43 | ground_truth.add(normalize_label(label)) 44 | 45 | rank = [] 46 | for label in guess_item["output"]: 47 | if normalize_label(label) in ground_truth: 48 | rank.append(True) 49 | else: 50 | rank.append(False) 51 | 52 | return rank, len(ground_truth) 53 | 54 | 55 | # 1. Precision computation 56 | def _precision(rank): 57 | if len(rank) == 0: 58 | return 0 59 | 60 | p = rank.count(True) / len(rank) 61 | 62 | return p 63 | 64 | def _precision_at_k(rank, k): 65 | 66 | # precision @ k 67 | p = rank[:k].count(True) / k 68 | 69 | return p 70 | 71 | def _propensity_scored_precision_at_k(rank, guess_inv_propensity_scores, gold_inv_propensity_scores, k): 72 | 73 | # Sum of inverse propensities for correct labels in top K 74 | num = sum(ps for correct, ps in zip(rank[:k], guess_inv_propensity_scores[:k]) if correct) 75 | # The maximum achievable propensity sum in top K 76 | den = sum(gold_inv_propensity_scores[:k]) 77 | 78 | return num / den 79 | 80 | 81 | # 2. 
Recall computation 82 | def _recall(rank, num_distinct_labels): 83 | if num_distinct_labels == 0: 84 | return 0 85 | 86 | r = rank.count(True) / num_distinct_labels 87 | 88 | return r 89 | 90 | def _recall_at_k(rank, num_distinct_labels, k): 91 | 92 | r = rank[:k].count(True) / num_distinct_labels 93 | 94 | return r 95 | 96 | 97 | # 3. F1 computation 98 | def _f1(rank, num_distinct_labels): 99 | 100 | p = _precision(rank) 101 | r = _recall(rank, num_distinct_labels) 102 | 103 | try: 104 | f1 = (2 * p * r) / (p + r) 105 | except ZeroDivisionError: 106 | f1 = 0 107 | 108 | return f1 109 | 110 | def _f1_at_k(rank, num_distinct_labels, k): 111 | 112 | p = _precision_at_k(rank, k) 113 | r = _recall_at_k(rank, num_distinct_labels, k) 114 | 115 | try: 116 | f1 = (2 * p * r) / (p + r) 117 | except ZeroDivisionError: 118 | f1 = 0 119 | 120 | return f1 121 | 122 | 123 | def get_ranking_metrics(guess_item, gold_item, ks, inv_propensity_scores_dict=None): 124 | 125 | P_at_k = {"precision@{}".format(k): 0 for k in sorted(ks) if k > 0} 126 | PSP_at_k = {"PSP@{}".format(k): 0 for k in sorted(ks) if k > 0} 127 | R_at_k = {"recall@{}".format(k): 0 for k in sorted(ks) if k > 1} 128 | F1_at_k = {"f1@{}".format(k): 0 for k in sorted(ks) if k > 1} 129 | 130 | assert ( 131 | "output" in guess_item 132 | ), "guess should provide the output for {}".format(guess_item['id']) 133 | 134 | for k in ks: 135 | 136 | # 0. 
get rank 137 | rank, num_distinct_labels = get_rank(guess_item, gold_item) 138 | 139 | if inv_propensity_scores_dict is not None: 140 | # The less frequent the predicted label is the more it adds to the score 141 | guess_inv_propensity_scores = [inv_propensity_scores_dict[normalize_label(l)] for l in guess_item['output']] 142 | 143 | # Top propensity scores are used to compute the denominator 144 | gold_inv_propensity_scores = sorted( 145 | [inv_propensity_scores_dict[normalize_label(l)] for l in gold_item['output']], 146 | key=lambda x: -x 147 | ) 148 | 149 | 150 | if num_distinct_labels > 0: 151 | 152 | # 1. precision 153 | P_at_k["precision@{}".format(k)] = _precision_at_k(rank, k) 154 | 155 | if inv_propensity_scores_dict is not None: 156 | PSP_at_k["PSP@{}".format(k)] = _propensity_scored_precision_at_k( 157 | rank, 158 | guess_inv_propensity_scores, 159 | gold_inv_propensity_scores, 160 | k, 161 | ) 162 | 163 | # 2. recall 164 | R_at_k["recall@{}".format(k)] = _recall_at_k(rank, num_distinct_labels, k) 165 | 166 | # 3. 
F1 score 167 | F1_at_k["f1@{}".format(k)] = _f1_at_k(rank, num_distinct_labels, k) 168 | 169 | if inv_propensity_scores_dict is not None: 170 | return {**P_at_k, **PSP_at_k, **R_at_k, **F1_at_k} 171 | else: 172 | return {**P_at_k, **R_at_k, **F1_at_k} 173 | 174 | def compute(gold_dataset, guess_dataset, ks=None, inv_propensity_scores_dict=None): 175 | 176 | result = {"precision": 0, "recall": 0, "f1": 0} 177 | if ks: 178 | ks = sorted([int(x) for x in ks]) 179 | for k in ks: 180 | if k > 0: 181 | result["precision@{}".format(k)] = 0.0 182 | if inv_propensity_scores_dict is not None: 183 | result["PSP@{}".format(k)] = 0.0 184 | 185 | if k > 1: 186 | result["recall@{}".format(k)] = 0.0 187 | result["f1@{}".format(k)] = 0.0 188 | 189 | assert len(guess_dataset) == len( 190 | gold_dataset 191 | ), "different size gold: {} guess: {}".format(len(gold_dataset), len(guess_dataset)) 192 | 193 | for guess, gold in zip(guess_dataset, gold_dataset): 194 | id_key = "id" if "id" in gold else "uid" 195 | try: 196 | assert ( 197 | str(gold[id_key]).strip() == str(guess["id"]).strip() 198 | ), "Items must have same order with same IDs" 199 | except KeyError: 200 | print(gold) 201 | print(guess) 202 | raise 203 | 204 | for guess_item, gold_item in zip(guess_dataset, gold_dataset): 205 | 206 | # Aggregate rank-independent metrics 207 | rank, num_distinct_labels = get_rank(guess_item, gold_item) 208 | result["precision"] += _precision(rank) 209 | result["recall"] += _recall(rank, num_distinct_labels) 210 | result["f1"] += _f1(rank, num_distinct_labels) 211 | 212 | # Aggregate rank-based metrics 213 | if ks: 214 | ranking_metrics = get_ranking_metrics( 215 | guess_item, gold_item, ks, inv_propensity_scores_dict=inv_propensity_scores_dict 216 | ) 217 | for k in ks: 218 | if k > 0: 219 | result["precision@{}".format(k)] += ranking_metrics[ 220 | "precision@{}".format(k) 221 | ] 222 | if inv_propensity_scores_dict is not None: 223 | result["PSP@{}".format(k)] += ranking_metrics[
224 | "PSP@{}".format(k) 225 | ] 226 | 227 | if k > 1: 228 | result["recall@{}".format(k)] += ranking_metrics["recall@{}".format(k)] 229 | result["f1@{}".format(k)] += ranking_metrics["f1@{}".format(k)] 230 | 231 | if len(guess_dataset) > 0: 232 | result["precision"] /= len(guess_dataset) 233 | result["recall"] /= len(guess_dataset) 234 | result["f1"] /= len(guess_dataset) 235 | 236 | if ks: 237 | for k in ks: 238 | if k > 0: 239 | result["precision@{}".format(k)] /= len(guess_dataset) 240 | if inv_propensity_scores_dict is not None: 241 | result["PSP@{}".format(k)] /= len(guess_dataset) 242 | if k > 1: 243 | result["recall@{}".format(k)] /= len(guess_dataset) 244 | result["f1@{}".format(k)] /= len(guess_dataset) 245 | 246 | return OrderedDict(sorted(result.items(), key=lambda x: x[0])) 247 | 248 | def inv_propensity_formula(label_frequency, num_instances, A=0.55, B=1.5): 249 | # Code based on: https://fburl.com/rp6rqhvg 250 | # Related paper: http://manikvarma.org/pubs/jain16.pdf 251 | 252 | C = (np.log(num_instances)-1)*np.power(B+1, A) 253 | return 1.0 + C*np.power(label_frequency+B, -A) 254 | 255 | 256 | def compute_inv_label_propensities(label_frequencies, num_instances): 257 | return defaultdict( 258 | lambda: inv_propensity_formula(0, num_instances), 259 | {normalize_label(label): inv_propensity_formula(freq, num_instances) for label, freq in label_frequencies.items()} 260 | ) 261 | 262 | 263 | def evaluate(gold, guess, ks=None, freqs=None): 264 | pp = pprint.PrettyPrinter(indent=4) 265 | 266 | gold_records = load_data(gold) 267 | guess_records = load_data(guess) 268 | 269 | assert len(gold_records) == len(guess_records) 270 | 271 | if freqs is not None: 272 | with open(freqs) as f: 273 | label_frequencies = json.load(f) 274 | inv_propensity_scores_dict = compute_inv_label_propensities(label_frequencies, 450000) #TODO 275 | else: 276 | inv_propensity_scores_dict = None 277 | 278 | 279 | # 2. 
get retrieval metrics 280 | result = compute(gold_records, guess_records, ks, inv_propensity_scores_dict) 281 | 282 | pp.pprint(result) 283 | return result 284 | 285 | 286 | if __name__ == "__main__": 287 | parser = argparse.ArgumentParser() 288 | parser.add_argument("--guess", help="Guess KILT file") 289 | parser.add_argument("--gold", help="Gold KILT file") 290 | 291 | parser.add_argument( 292 | "--ks", 293 | type=str, 294 | required=False, 295 | default=None, 296 | help="Comma separated list of positive integers for recall@k and precision@k", 297 | ) 298 | 299 | parser.add_argument( 300 | "--freqs", 301 | type=str, 302 | required=False, 303 | default=None, 304 | help="JSON file containing frequencies of labels in training data. If this is specified, we'll compute propensity-weighted P@K metrics." 305 | ) 306 | 307 | args = parser.parse_args() 308 | 309 | if args.ks: 310 | args.ks = [int(k) for k in args.ks.split(",") if int(k) > 0] 311 | 312 | evaluate(args.gold, args.guess, args.ks, args.freqs) 313 | -------------------------------------------------------------------------------- /eval_psp.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | """ 9 | Compute Propensity-Scored Precision @ K metric for a given dataset and prediction file.
10 | 11 | Related paper: http://manikvarma.org/pubs/jain16.pdf 12 | 13 | Sample commands: 14 | 15 | python3 eval_psp.py \ 16 | --train /private/home/danielsimig/ECG/data/Eurlex_4.3K/trn.jsonl \ 17 | --gold /private/home/danielsimig/ECG/data/Eurlex_4.3K/tst.jsonl \ 18 | --guess /private/home/danielsimig/ECG/data/experiments/eurlex_msm_t5-base_b32_48429737/epoch39/test_preds_sum_prob.jsonl 19 | 20 | python3 eval_psp.py --A 0.5 --B 0.4\ 21 | --train /private/home/danielsimig/ECG/data/Wikipedia_1M/trn.jsonl \ 22 | --gold /private/home/danielsimig/ECG/data/Wikipedia_1M/sample_test.jsonl \ 23 | --guess /private/home/danielsimig/ECG/data/experiments/wikipedia_1m_van_t5-base_b32_48432070/epoch0/test_preds_sum_prob.jsonl 24 | 25 | """ 26 | 27 | 28 | import argparse 29 | import numpy as np 30 | import json 31 | import scipy.sparse as sp 32 | 33 | # You'll need to install https://github.com/kunaldahiya/pyxclib for this 34 | from xclib.evaluation.xc_metrics import Metrics, compute_inv_propesity, psprecision 35 | 36 | def normalize_label(s): 37 | def remove_underscore(text): 38 | return text.replace("_", " ") 39 | 40 | def lower(text): 41 | return text.lower() 42 | 43 | return remove_underscore(lower(s)) 44 | 45 | 46 | def load_label_map(train_file, test_file): 47 | # Assign an index to every label so that we can build a sparse representation later 48 | label_idx_map = {} 49 | 50 | print("scanning train file") 51 | with open(train_file) as f: 52 | for line in f: 53 | for label in json.loads(line[:-1])['output']: 54 | if normalize_label(label) not in label_idx_map: 55 | label_idx_map[normalize_label(label)] = len(label_idx_map) 56 | 57 | print("scanning test file") 58 | with open(test_file) as f: 59 | for line in f: 60 | for label in json.loads(line[:-1])['output']: 61 | if normalize_label(label) not in label_idx_map: 62 | label_idx_map[normalize_label(label)] = len(label_idx_map) 63 | 64 | return label_idx_map 65 | 66 | 67 | def load_file(path, label_idx_map, 
preserve_order=False): 68 | 69 | # By default just load binary matrices. For preds use anything that preserves order 70 | score_fn = lambda x: 1 71 | if preserve_order: 72 | score_fn = lambda x: 1000-x 73 | 74 | unseen_labels = set() 75 | with open(path) as f: 76 | rows = [] 77 | for line in f: 78 | row = json.loads(line[:-1]) 79 | rows.append(row['output']) 80 | 81 | data = [] 82 | ids = [] 83 | label_ids = [] 84 | 85 | for rid, row in enumerate(rows): 86 | for rank, label in enumerate(row): 87 | if normalize_label(label) not in label_idx_map: 88 | unseen_labels.add(normalize_label(label)) 89 | continue 90 | data.append(score_fn(rank)) 91 | ids.append(rid) 92 | label_ids.append(label_idx_map[normalize_label(label)]) 93 | 94 | Y = sp.csc_matrix((data, (ids, label_ids)), shape = (len(rows), len(label_idx_map))) 95 | 96 | if len(unseen_labels) > 0: 97 | print(path, "contains", len(unseen_labels), "unseen labels") 98 | return Y 99 | 100 | 101 | 102 | if __name__ == "__main__": 103 | 104 | parser = argparse.ArgumentParser() 105 | parser.add_argument("--guess", help="Model prediction filepath") 106 | parser.add_argument("--gold", help="Gold labels filepath") 107 | parser.add_argument("--train", help="Training data filepath") 108 | parser.add_argument( 109 | "--A", 110 | type=float, 111 | default=0.55, 112 | help="Parameter A for calculating propensity. Should be 0.5 for Wiki, 0.55 otherwise.", 113 | ) 114 | parser.add_argument( 115 | "--B", 116 | type=float, 117 | default=1.5, 118 | help="Parameter B for calculating propensity. 
Should be 0.4 for Wiki, 1.5 otherwise.", 119 | ) 120 | parser.add_argument( 121 | "--ks", 122 | type=str, 123 | required=False, 124 | default="1,3,5,10", 125 | help="Comma separated list of positive integers for recall@k and precision@k", 126 | ) 127 | 128 | args = parser.parse_args() 129 | ks = [int(k) for k in args.ks.split(",")] 130 | 131 | label_idx_map = load_label_map(args.train, args.gold) 132 | 133 | print("Loading train file...") 134 | train_Y = load_file(args.train, label_idx_map) 135 | print("Shape: ", train_Y.shape) 136 | 137 | print("Loading gold file...") 138 | gold_Y = load_file(args.gold, label_idx_map) 139 | print("Shape: ", gold_Y.shape) 140 | 141 | print("Loading guess file...") 142 | guess_Y = load_file(args.guess, label_idx_map, preserve_order=True) 143 | print("Shape: ", guess_Y.shape) 144 | 145 | inv_label_propensities = compute_inv_propesity(train_Y, A=args.A, B=args.B) 146 | psp_at_k = psprecision(guess_Y, gold_Y, inv_label_propensities, k=max(ks)) 147 | for k in ks: 148 | print(f"PSP@{k}: {psp_at_k[k-1]:.3f}") -------------------------------------------------------------------------------- /finetune_s2s.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
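The PSP@k pipeline in `eval_psp.py` above delegates the propensity math to xclib's `compute_inv_propesity` and `psprecision`. For readers without that dependency, here is a minimal self-contained sketch of the same metric, using the Jain et al. (2016) propensity formula that also appears in `eval.py`; the toy label frequencies and the `num_instances` value are made up for illustration:

```python
import numpy as np

def inv_propensity(freq, num_instances, A=0.55, B=1.5):
    # Same formula as inv_propensity_formula() in eval.py (Jain et al. 2016):
    # the rarer the label, the larger its inverse propensity, so predicting
    # tail labels correctly earns more credit.
    C = (np.log(num_instances) - 1) * np.power(B + 1, A)
    return 1.0 + C * np.power(freq + B, -A)

def psp_at_k(pred, gold, inv_ps, k):
    # Numerator: inverse propensities of correct labels in the top-k predictions.
    num = sum(inv_ps[l] for l in pred[:k] if l in gold)
    # Denominator: best achievable sum, i.e. the k largest gold propensities.
    den = sum(sorted((inv_ps[l] for l in gold), reverse=True)[:k])
    return num / den

# Hypothetical training-set frequencies for two labels.
freqs = {"common": 1000, "rare": 2}
inv_ps = {l: inv_propensity(f, num_instances=1100) for l, f in freqs.items()}

# With gold = {"common", "rare"}, predicting the rare label at rank 1
# scores a perfect 1.0, while predicting the common one scores less.
```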
6 | 7 | 8 | import os 9 | import shutil 10 | 11 | import torch 12 | 13 | from data_utils import Seq2SetDataset 14 | from decode_utils import LabelTrie 15 | from local_configs import LOCAL_DATA_DIR, OUTPUT_DIR 16 | from params import ArgumentsS2S 17 | from s2s_model import make_s2s_model, train_s2s 18 | from utils import prepare_tokenizer 19 | 20 | parser = ArgumentsS2S() 21 | s2s_args = parser.parse_args() 22 | 23 | # Prepare directory where we store output 24 | if not os.path.isdir(os.path.join(OUTPUT_DIR, "experiments")): 25 | os.mkdir(os.path.join(OUTPUT_DIR, "experiments")) 26 | output_dir = os.path.join(OUTPUT_DIR, "experiments", s2s_args.output_dir) 27 | 28 | if s2s_args.output_dir == "tmp": 29 | # This is the default value, we will just overwrite files here. Do not store important stuff in tmp! 30 | if os.path.isdir(output_dir): 31 | shutil.rmtree(output_dir) 32 | else: 33 | # If we specified a directory, we want to make sure we don't accidentally overwrite something 34 | assert not os.path.isdir(output_dir), "Output directory already exists!" 
35 | 36 | if not os.path.isdir(output_dir): 37 | os.mkdir(output_dir) 38 | 39 | args_str = '\n'.join(f'{k} : {v}' for k, v in vars(s2s_args).items()) 40 | print("*** ARGS ***\n", args_str, "\n******") 41 | with open(os.path.join(output_dir, "args.txt"), "w") as f: 42 | f.write(args_str) 43 | 44 | # Initialize model and tokenizer 45 | model_path = None 46 | if s2s_args.checkpoint_dir: 47 | model_path = os.path.join(OUTPUT_DIR, s2s_args.checkpoint_dir) 48 | 49 | s2s_scheduler, s2s_optimizer, tokenizer, model, best_eval, start_epoch = make_s2s_model( 50 | model_name=s2s_args.model_name_or_path, from_file=model_path, device=s2s_args.device 51 | ) 52 | 53 | # Construct trie used for decoding and multi-option loss 54 | prepare_tokenizer(tokenizer) 55 | sep_token = tokenizer.sep_token if tokenizer.sep_token else "[SEP]" 56 | 57 | # Prepare datasets 58 | train_data = os.path.join(LOCAL_DATA_DIR, s2s_args.train_file_path) 59 | test_data = os.path.join(LOCAL_DATA_DIR, s2s_args.test_file_path) 60 | s2s_train_set = Seq2SetDataset( 61 | train_data, sep_token, 62 | replace_underscores=s2s_args.replace_underscores, 63 | single_label=s2s_args.single_label, 64 | output_key=s2s_args.output_key 65 | ) 66 | s2s_dev_set = Seq2SetDataset( 67 | test_data, sep_token, 68 | replace_underscores=s2s_args.replace_underscores, 69 | single_label=s2s_args.single_label, 70 | output_key=s2s_args.output_key 71 | ) 72 | 73 | s2s_train_set.read_data() 74 | s2s_dev_set.read_data() 75 | 76 | # Amazon dataset has a particular label that sometimes causes the same label to appear 77 | # repeatedly in the output after tokenization, breaking the set assumption and thus the code. 78 | print("Sanity checking data...") 79 | s2s_train_set.dedupe_data(tokenizer) 80 | s2s_dev_set.dedupe_data(tokenizer) 81 | print("Done!") 82 | 83 | 84 | if s2s_args.use_multisoftmax: 85 | # To allow for any possible next token at a given time, we need a label trie 86 | # that will compute the corresponding target tensors. 
87 | print("Computing label trie...") 88 | label_trie = LabelTrie.from_labels( 89 | s2s_train_set.get_all_labels().union(s2s_dev_set.get_all_labels()), 90 | tokenizer, 91 | sep_token, 92 | ) 93 | print("Done!") 94 | else: 95 | label_trie = None 96 | 97 | # for using on multiple gpu using DataParallel 98 | if s2s_args.data_parallel: 99 | s2s_model = torch.nn.DataParallel(model) 100 | else: 101 | s2s_model = model 102 | 103 | # Restore epoch num 104 | if start_epoch is None: 105 | start_epoch = 0 106 | else: 107 | start_epoch += 1 108 | 109 | train_s2s( 110 | s2s_model, 111 | tokenizer, 112 | label_trie, 113 | s2s_optimizer, 114 | s2s_scheduler, 115 | best_eval, 116 | s2s_train_set, 117 | s2s_dev_set, 118 | s2s_args, 119 | output_dir, 120 | start_epoch 121 | ) 122 | -------------------------------------------------------------------------------- /load_and_eval_s2s.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
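Both `finetune_s2s.py` and `load_and_eval_s2s.py` rely on `LabelTrie` from `decode_utils.py`, which is not shown in this excerpt; it stores the token-id sequence of every known label so that decoding can be restricted to valid labels. A minimal sketch of the underlying idea (the `TinyLabelTrie` class and the token ids are invented for illustration):

```python
class TinyLabelTrie:
    # Minimal trie over label token-id sequences: given a decoded prefix,
    # report which next tokens keep the output on some valid label.
    def __init__(self):
        self.children = {}

    @classmethod
    def from_token_sequences(cls, sequences):
        root = cls()
        for seq in sequences:
            node = root
            for tok in seq:
                node = node.children.setdefault(tok, cls())
        return root

    def next_allowed_tokens(self, prefix):
        node = self
        for tok in prefix:
            if tok not in node.children:
                return []  # prefix left the trie: no valid continuation
            node = node.children[tok]
        return sorted(node.children)

# Token-id sequences for three hypothetical labels.
trie = TinyLabelTrie.from_token_sequences([[5, 9], [5, 7, 2], [8, 1]])
print(trie.next_allowed_tokens([5]))  # [7, 9]
```

A function like `next_allowed_tokens` is exactly the shape expected by `prefix_allowed_tokens_fn` in the `generate` call used later for lattice decoding.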
6 | 7 | 8 | from collections import defaultdict 9 | import functools 10 | import json 11 | import math 12 | import os 13 | import random 14 | from time import time, strftime 15 | import torch 16 | from torch.nn.parallel.data_parallel import DataParallel 17 | from tqdm import tqdm 18 | 19 | from torch.utils.data import DataLoader, SequentialSampler 20 | 21 | from data_utils import Seq2SetDataset 22 | from decode_utils import score_labels_by_probability_sum, LabelTrie 23 | from local_configs import LOCAL_DATA_DIR, OUTPUT_DIR 24 | from s2s_model import make_s2s_batch, compute_metrics 25 | from transformers import AutoModelForSeq2SeqLM, AutoTokenizer 26 | from utils import prepare_tokenizer 27 | from params import ArgumentsS2S 28 | 29 | def show_oov_samples(preds, golds, seen_label_set): 30 | preds_at_pos = defaultdict(int) 31 | oov_samples_by_pos = defaultdict(list) 32 | for p, g in zip(preds, golds): 33 | for pos, label in enumerate(p): 34 | preds_at_pos[pos] += 1 35 | if label not in seen_label_set: 36 | oov_samples_by_pos[pos].append((label, p, g)) 37 | 38 | for pos in range(5): 39 | oov_rate = len(oov_samples_by_pos[pos]) * 1.0 / preds_at_pos[pos] if preds_at_pos[pos] > 0 else 0 40 | print( 41 | f"\n *** OOV rate at position {pos}: {oov_rate} ***" 42 | ) 43 | if len(oov_samples_by_pos[pos]) >= 5: 44 | for oov_pred, prediction, gold_labels in random.sample( 45 | oov_samples_by_pos[pos], 5 46 | ): 47 | print("\n OOV prediction: ", oov_pred) 48 | print(" Predicted: ", prediction) 49 | print(" Gold: ", gold_labels) 50 | 51 | 52 | def decode_s2s(model, dataset, label_set, tokenizer, args, sep_token): 53 | model.eval() 54 | # make iterator 55 | train_sampler = SequentialSampler(dataset) 56 | model_collate_fn = functools.partial( 57 | make_s2s_batch, 58 | model=model, 59 | tokenizer=tokenizer, 60 | max_i_len=args.max_i_length, 61 | max_o_len=args.max_o_length, 62 | device=args.device, 63 | add_example_ids=True, 64 | ) 65 | data_loader = DataLoader( 66 | dataset, 67 | 
batch_size=args.eval_batch_size, 68 | sampler=train_sampler, 69 | collate_fn=model_collate_fn, 70 | ) 71 | epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) 72 | # accumulate loss since last print 73 | loc_steps = 0 74 | loc_loss = 0.0 75 | st_time = time() 76 | 77 | if args.decode_on_lattice: 78 | # Build trie of all possible labels 79 | label_trie = LabelTrie.from_labels(label_set, tokenizer, sep_token) 80 | 81 | def decode_on_label_lattice(batch_id, input_ids): 82 | next_tokens, completed_label_ids = label_trie.next_allowed_token( 83 | input_ids.tolist()[1:] 84 | ) 85 | 86 | # Uncomment this to debug the decoding process 87 | # if batch_id == 0: 88 | # input_tokens = tokenizer.decode(input_ids[1:], skip_special_tokens=False) 89 | # print(" " * len (input_ids), input_ids.tolist()[1:], f'"{input_tokens}"', next_tokens if len(next_tokens) < 100 else "all tokens") 90 | 91 | return next_tokens 92 | 93 | with torch.no_grad(): 94 | preds_by_method = defaultdict(list) 95 | golds = [] 96 | for step, batch_inputs in enumerate(epoch_iterator): 97 | 98 | example_ids = batch_inputs["example_ids"] 99 | del batch_inputs["example_ids"] 100 | 101 | pre_loss = model(**batch_inputs)[0] 102 | 103 | if isinstance(model, DataParallel): 104 | model_gen = model.module 105 | loss = pre_loss.sum() / pre_loss.shape[0] 106 | else: 107 | model_gen = model 108 | loss = pre_loss 109 | 110 | generated_ids = model_gen.generate( 111 | input_ids=batch_inputs["input_ids"], 112 | attention_mask=batch_inputs["attention_mask"], 113 | min_length=1, 114 | max_length=args.max_o_length + 1, 115 | do_sample=False, 116 | early_stopping=True, 117 | num_beams=args.decode_beams, 118 | temperature=1.0, 119 | top_k=None, 120 | top_p=None, 121 | eos_token_id=tokenizer.eos_token_id, 122 | no_repeat_ngram_size=3, 123 | num_return_sequences=args.decode_beams, 124 | decoder_start_token_id=tokenizer.bos_token_id, 125 | prefix_allowed_tokens_fn=decode_on_label_lattice 126 | if args.decode_on_lattice 
127 | else None, 128 | return_dict_in_generate=True, 129 | output_scores=True, 130 | ) 131 | 132 | if args.decode_beams > 1: 133 | # Use beam search to find most likely sequences and integrate over those sequences 134 | 135 | # Decoder doesn't have a separate dimension for beam size, need to reshape. 136 | # num_examples might be less than args.eval_batch_size in the last batch 137 | num_examples = generated_ids["sequences"].size()[0] // args.decode_beams 138 | for example_id, sequences, scores, labels in zip( 139 | example_ids, 140 | generated_ids["sequences"].view( 141 | num_examples, args.decode_beams, -1 142 | ), 143 | generated_ids["sequences_scores"].view( 144 | num_examples, args.decode_beams, 1 145 | ), 146 | batch_inputs["labels"], 147 | ): 148 | top_label_sequences = [] 149 | top_scores = [] 150 | for sequence in sequences: 151 | top_label_sequences.append( 152 | dataset.token_ids_to_labels(tokenizer, sequence) 153 | ) 154 | 155 | naive_preds = top_label_sequences[0] 156 | filtered_preds = [l for l in naive_preds if l in label_set] 157 | sum_prob_preds = score_labels_by_probability_sum( 158 | top_label_sequences, 159 | scores, 160 | ) 161 | filtered_sum_prob_preds = [ 162 | l for l in sum_prob_preds if l in label_set 163 | ] 164 | 165 | preds_by_method["naive"].append((example_id, naive_preds)) 166 | preds_by_method["filtered"].append((example_id, filtered_preds)) 167 | preds_by_method["sum_prob"].append((example_id, sum_prob_preds)) 168 | preds_by_method["filtered_sum_prob"].append((example_id, filtered_sum_prob_preds)) 169 | golds.append(dataset.token_ids_to_labels(tokenizer, labels)) 170 | else: 171 | # No beam search, work with the simple greedy output sequence 172 | for example_id, output_token_ids, label_token_ids in zip( 173 | example_ids, generated_ids["sequences"], batch_inputs["labels"] 174 | ): 175 | 176 | naive_preds = dataset.token_ids_to_labels( 177 | tokenizer, output_token_ids 178 | ) 179 | filtered_preds = [l for l in naive_preds if l in 
label_set] 180 | gold = dataset.token_ids_to_labels(tokenizer, label_token_ids) 181 | 182 | preds_by_method["naive"].append((example_id, naive_preds)) 183 | preds_by_method["filtered"].append((example_id, filtered_preds)) 184 | golds.append(gold) 185 | 186 | loc_loss += loss.item() 187 | loc_steps += 1 188 | if step % args.print_freq == 0: 189 | print( 190 | "{:5d} of {:5d}".format(step, len(dataset) // args.eval_batch_size) 191 | ) 192 | 193 | # For the impatient kind 194 | # if len(golds) % 100 == 0: 195 | # for method, preds in preds_by_method.items(): 196 | # metrics = compute_metrics(preds, golds) 197 | # print(f" {method}: " + " ".join(f"{k}: {v:.3f}" for k, v in metrics.items())) 198 | 199 | print("Loss: {:.3f}".format(loc_loss / loc_steps)) 200 | 201 | for method, preds in preds_by_method.items(): 202 | 203 | metrics = compute_metrics([x[1] for x in preds], golds) 204 | print(f"{method}: " + " ".join(f"{k}: {v:.3f}" for k, v in metrics.items())) 205 | method_str = method + ("_lattice" if args.decode_on_lattice else "") + f"_{args.dataset_name}" 206 | preds_file = os.path.join(OUTPUT_DIR, args.output_dir, f"{args.pred_file_prefix}_{method_str}.jsonl") 207 | with open(preds_file, "w") as outfile: 208 | for id, preds in preds: 209 | outfile.write( 210 | json.dumps( 211 | { 212 | "id": id, 213 | "output": preds 214 | } 215 | ) + "\n" 216 | ) 217 | 218 | golds_file = os.path.join(OUTPUT_DIR, args.output_dir, f"{args.pred_file_prefix}_golds.jsonl") 219 | with open(golds_file, "w") as outfile: 220 | json.dump(golds, outfile) 221 | 222 | # Show stats and examples for the raw model output 223 | preds = [x[1] for x in preds_by_method["naive"]] 224 | 225 | show_oov_samples(preds, golds, label_set) 226 | 227 | 228 | parser = ArgumentsS2S(decode_mode=True) 229 | args = parser.parse_args() 230 | 231 | model = AutoModelForSeq2SeqLM.from_pretrained( 232 | os.path.join(OUTPUT_DIR, args.output_dir) 233 | ).to(args.device) 234 | 235 | tokenizer = 
AutoTokenizer.from_pretrained(args.model_name_or_path) 236 | prepare_tokenizer(tokenizer) 237 | sep = tokenizer.sep_token if tokenizer.sep_token else "[SEP]" 238 | print("sep token: ", sep) 239 | 240 | s2s_train_set = Seq2SetDataset( 241 | os.path.join(LOCAL_DATA_DIR, args.train_file_path), 242 | sep, 243 | replace_underscores=args.replace_underscores, 244 | output_key=args.output_key 245 | ) 246 | s2s_dev_set = Seq2SetDataset( 247 | os.path.join(LOCAL_DATA_DIR, args.test_file_path), 248 | sep, 249 | replace_underscores=args.replace_underscores, 250 | output_key=args.output_key 251 | ) 252 | 253 | s2s_train_set.read_data() 254 | s2s_dev_set.read_data() 255 | 256 | # Read candidate entities from external file 257 | all_labels_set = set() 258 | if args.label_set_file: 259 | print("Reading entity list from file for creating entity trie: ", args.label_set_file) 260 | with open(args.label_set_file, 'r') as fin: 261 | all_labels_set = set(json.load(fin)) 262 | 263 | print("Finish loading entity set.") 264 | 265 | else: 266 | train_label_set = s2s_train_set.get_all_labels() 267 | dev_label_set = s2s_dev_set.get_all_labels() 268 | print("# of distinct labels in train set:", len(train_label_set)) 269 | print("# of distinct labels in dev set:", len(dev_label_set)) 270 | print("# of new labels in dev set:", len(dev_label_set.difference(train_label_set))) 271 | all_labels_set = train_label_set.union(dev_label_set) 272 | 273 | 274 | decode_s2s(model, s2s_dev_set, all_labels_set, tokenizer, args, sep) 275 | -------------------------------------------------------------------------------- /local_configs.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | # Cluster-specific configs go here. 
This file is in gitignore so changes won't 8 | # impact other checkouts of the code. 9 | 10 | 11 | LOCAL_DATA_DIR = "/private/home/xiaodu/GET/data/" 12 | OUTPUT_DIR = "/private/home/xiaodu/GET" 13 | -------------------------------------------------------------------------------- /params.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import argparse 9 | import os 10 | 11 | 12 | class ArgumentsS2S(argparse.ArgumentParser): 13 | def __init__( 14 | self, 15 | add_s2s_args=True, 16 | decode_mode=False, 17 | description="S2S parser", 18 | ): 19 | super().__init__( 20 | description=description, 21 | allow_abbrev=False, 22 | conflict_handler="resolve", 23 | formatter_class=argparse.HelpFormatter, 24 | add_help=add_s2s_args, 25 | ) 26 | 27 | if add_s2s_args: 28 | self.add_s2s_args() 29 | 30 | if decode_mode: 31 | self.add_decode_args() 32 | 33 | def add_s2s_args(self): 34 | parser = self.add_argument_group("Common Arguments") 35 | 36 | # Directories 37 | parser.add_argument( 38 | "--model_name_or_path", 39 | default="t5-large", 40 | type=str, 41 | help="Pretrained model name or path", 42 | ) 43 | parser.add_argument( 44 | "--output_dir", 45 | type=str, 46 | default="tmp", 47 | help="Model output path", 48 | ) 49 | parser.add_argument( 50 | "--train_file_path", 51 | type=str, 52 | help="Path of the training file relative to the working folder defined in local_configs.py" 53 | ) 54 | parser.add_argument( 55 | "--test_file_path", 56 | type=str, 57 | help="Path of the test file relative to the working folder defined in local_configs.py" 58 | ) 59 | parser.add_argument( 60 | "--checkpoint_dir", 61 | type=str, 62 | help="Load pretrained model from this directory" 63 | ) 64 | 65 | # GPU use 66 | parser.add_argument(
"--device", 68 | default="cuda", 69 | type=str, 70 | help="Device: CPU or CUDA", 71 | ) 72 | parser.add_argument( 73 | "--data_parallel", 74 | action="store_true", 75 | help="Use torch.DataParallel(). Don't set device when using this!" 76 | ) 77 | 78 | # Model settings 79 | parser.add_argument( 80 | "--max_i_length", 81 | default=512, 82 | type=int, 83 | help="Max input length", 84 | ) 85 | parser.add_argument( 86 | "--max_o_length", 87 | default=256, 88 | type=int, 89 | help="Max output length", 90 | ) 91 | parser.add_argument( 92 | "--single_label", 93 | action="store_true", 94 | ) 95 | parser.add_argument( 96 | "--use_multisoftmax", 97 | action="store_true", 98 | ) 99 | 100 | # Train / eval settings 101 | parser.add_argument( 102 | "--train_batch_size", 103 | default=2, 104 | type=int, 105 | ) 106 | parser.add_argument( 107 | "--backward_freq", 108 | default=1, 109 | type=int, 110 | ) 111 | parser.add_argument( 112 | "--learning_rate", 113 | default=2e-4, 114 | type=float, 115 | ) 116 | parser.add_argument( 117 | "--num_epochs", 118 | default=200, 119 | type=int, 120 | ) 121 | parser.add_argument( 122 | "--eval_batch_size", 123 | default=1, 124 | type=int, 125 | ) 126 | parser.add_argument( 127 | "--eval_every_k_epoch", 128 | default=1, 129 | type=int, 130 | ) 131 | parser.add_argument( 132 | "--eval_sampling_rate", 133 | default=None, 134 | type=float, 135 | help="Only use this portion of the eval set.", 136 | ) 137 | parser.add_argument( 138 | "--print_freq", 139 | default=20, 140 | type=int, 141 | ) 142 | parser.add_argument( 143 | "--save_after_every_eval", 144 | action="store_true", 145 | ) 146 | 147 | parser.add_argument( 148 | "--output_key", 149 | default="output", 150 | type=str, 151 | help="The key that points to gold labels in training / test data", 152 | ) 153 | 154 | # Misc 155 | parser.add_argument( 156 | "--replace_underscores", 157 | default=True, 158 | type=bool, 159 | ) 160 | parser.add_argument( 161 | "--use_proxy", 162 | action="store_true", 
163 | ) 164 | 165 | def add_decode_args(self): 166 | parser = self.add_argument_group("Decoding-related Arguments") 167 | parser.add_argument( 168 | "--decode_on_lattice", 169 | action="store_true", 170 | ) 171 | parser.add_argument( 172 | "--decode_beams", 173 | default=32, 174 | type=int, 175 | ) 176 | parser.add_argument( 177 | "--pred_file_prefix", 178 | default="test_preds", 179 | type=str, 180 | ) 181 | parser.add_argument( 182 | "--label_set_file", 183 | type=str, 184 | help="Path to the file that contains the set of labels for constrained decoding" 185 | ) 186 | parser.add_argument( 187 | "--dataset_name", 188 | type=str, 189 | default="aida", 190 | help="Identifier of the dataset, used for naming the prediction file" 191 | ) -------------------------------------------------------------------------------- /s2s_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree.
6 | 7 | 8 | import functools 9 | import math 10 | import os 11 | import random 12 | from collections import defaultdict 13 | from time import time, strftime 14 | 15 | import torch 16 | import torch.multiprocessing as mp 17 | from torch.nn.parallel.data_parallel import DataParallel 18 | from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, SubsetRandomSampler 19 | from tqdm import tqdm 20 | from transformers import ( 21 | AdamW, 22 | AutoModelForSeq2SeqLM, 23 | AutoTokenizer, 24 | get_linear_schedule_with_warmup, 25 | T5ForConditionalGeneration, 26 | ) 27 | from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments 28 | from torch.nn import CrossEntropyLoss 29 | 30 | from data_utils import Seq2SetDataset 31 | 32 | 33 | def compute_metrics(preds, golds): 34 | metrics = defaultdict(float) 35 | 36 | num_datapoints = len(golds) 37 | assert len(preds) == num_datapoints 38 | 39 | for g, p in zip(golds, preds): 40 | g_labels = set(g) 41 | p_labels = set(p) 42 | inter = p_labels.intersection(g_labels) 43 | 44 | metrics["micro_accuracy"] += ( 45 | (1.0) * len(inter) / len(p_labels) if len(p_labels) > 0 else 0.0 46 | ) 47 | metrics["micro_recall"] += (1.0) * len(inter) / len(g_labels) 48 | for k in [1, 3, 5]: 49 | topk_inter = set(p[:k]).intersection(g_labels) 50 | metrics[f"P@{k}"] += (1.0) * len(topk_inter) / k 51 | 52 | return {k: v * 1.0 / num_datapoints for k, v in metrics.items()} 53 | 54 | 55 | def make_s2s_model(model_name="facebook/bart-large", from_file=None, device="cuda"): 56 | tokenizer = AutoTokenizer.from_pretrained(model_name) 57 | if from_file is not None: 58 | model = AutoModelForSeq2SeqLM.from_pretrained(from_file).to(device) 59 | elif "led-base" in model_name: 60 | model = AutoModelForSeq2SeqLM.from_pretrained( 61 | model_name, gradient_checkpointing=True, use_cache=False 62 | ) 63 | else: 64 | model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device) 65 | 66 | s2s_optimizer = None 67 | s2s_scheduler = None 
68 | best_eval = None 69 | start_epoch = None 70 | if from_file is not None: 71 | param_dict = torch.load( 72 | os.path.join(from_file, "state_dict.pth"), map_location=device 73 | ) # has model weights, optimizer, and scheduler states 74 | s2s_optimizer = AdamW(model.parameters(), lr=0.0001, eps=1e-8) 75 | s2s_scheduler = get_linear_schedule_with_warmup( 76 | s2s_optimizer, 77 | num_warmup_steps=400, 78 | num_training_steps=1, 79 | ) 80 | s2s_optimizer.load_state_dict(param_dict["optimizer"]) 81 | s2s_scheduler.load_state_dict(param_dict["scheduler"]) 82 | if "loss_spearman" in param_dict["best_eval"]: 83 | best_eval = param_dict["best_eval"]["loss_spearman"] 84 | else: 85 | best_eval = param_dict["best_eval"]["loss"] 86 | 87 | if "epoch" in param_dict: 88 | start_epoch = int(param_dict["epoch"]) 89 | 90 | return s2s_scheduler, s2s_optimizer, tokenizer, model, best_eval, start_epoch 91 | 92 | 93 | def make_s2s_batch( 94 | io_list, 95 | tokenizer, 96 | model, 97 | label_trie=None, 98 | max_i_len=512, 99 | max_o_len=16, 100 | device="cuda:0", 101 | add_example_ids=False 102 | ): 103 | i_ls = [i for i, _, _ in io_list] 104 | o_ls = [o for _, o, _ in io_list] 105 | 106 | i_toks = tokenizer( 107 | i_ls, max_length=max_i_len, padding="max_length", truncation=True 108 | ) 109 | i_ids, i_mask = ( 110 | torch.LongTensor(i_toks["input_ids"]).to(device), 111 | torch.LongTensor(i_toks["attention_mask"]).to(device), 112 | ) 113 | 114 | o_toks = tokenizer( 115 | o_ls, max_length=max_o_len + 1, padding="max_length", truncation=True 116 | ) 117 | 118 | o_ids, o_mask = ( 119 | torch.LongTensor(o_toks["input_ids"]).to(device), 120 | torch.LongTensor(o_toks["attention_mask"]).to(device), 121 | ) 122 | 123 | # Based on HF examples 124 | if isinstance(model, DataParallel): 125 | model = model.module 126 | 127 | if isinstance(model, T5ForConditionalGeneration): 128 | decoder_input_ids = model._shift_right(o_ids) 129 | lm_labels = o_ids 130 | else: 131 | decoder_input_ids = o_ids[:, 
:-1].contiguous() 132 | lm_labels = o_ids[:, 1:].contiguous().clone() 133 | 134 | model_inputs = { 135 | "input_ids": i_ids, 136 | "attention_mask": i_mask, 137 | "decoder_input_ids": decoder_input_ids, 138 | "labels": lm_labels, 139 | } 140 | 141 | # Compute target for multi-option loss 142 | if label_trie: 143 | model_inputs["targets"] = label_trie.compute_targets(o_ids, o_ls) 144 | if add_example_ids: 145 | model_inputs["example_ids"] = [id for _, _, id in io_list] 146 | 147 | return model_inputs 148 | 149 | 150 | def train_s2s_epoch( 151 | model, 152 | dataset, 153 | tokenizer, 154 | label_trie, 155 | optimizer, 156 | scheduler, 157 | args, 158 | output_dir, 159 | e=0, 160 | curriculum=False, 161 | ): 162 | model.train() 163 | # make iterator 164 | if curriculum: 165 | train_sampler = SequentialSampler(dataset) 166 | else: 167 | train_sampler = RandomSampler(dataset) 168 | 169 | tokenizer.source_len = [0.0, 0.0, 0.0, 0.0] 170 | model_collate_fn = functools.partial( 171 | make_s2s_batch, 172 | model=model, 173 | tokenizer=tokenizer, 174 | label_trie=label_trie, 175 | max_i_len=args.max_i_length, 176 | max_o_len=args.max_o_length, 177 | device=args.device, 178 | ) 179 | data_loader = DataLoader( 180 | dataset, 181 | batch_size=args.train_batch_size, 182 | sampler=train_sampler, 183 | collate_fn=model_collate_fn, 184 | ) 185 | epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) 186 | 187 | # accumulate loss since last print 188 | loc_steps = 0 189 | loc_loss = 0.0 190 | st_time = time() 191 | for step, batch_inputs in enumerate(epoch_iterator): 192 | 193 | if args.use_multisoftmax: 194 | # Targets matrix that allows for all possible next tokens at a given time 195 | # Dimension: batch_size x max_output_len x num_tokens 196 | # Need to pass label_trie to model_collate_fn for this to work 197 | # Passing this as part of batch_inputs is a hack. 
The model doesn't actually expect this input, so we need to pop it here 198 | targets = batch_inputs.pop("targets", None).to(args.device) 199 | 200 | model_output = model(**batch_inputs) 201 | 202 | if args.use_multisoftmax: 203 | # Variation on SoftMax that allows us to distribute weight over different outcomes 204 | 205 | # We'll compute the loss ourselves outside of the model. We need the raw logits for that. 206 | lm_logits = model_output[1] 207 | 208 | # We want to take the sum of the exp() of the logits corresponding to possible next tokens. 209 | # For everything else we want 0, which is achieved by exp(-inf) 210 | gt_label_logits = lm_logits.masked_fill(targets == 0, float("-inf")) 211 | 212 | # Intuitively, this removes competition between correct labels 213 | # torch.logsumexp(gt_label_logits, dim=2) is just the multi-option equivalent of the 214 | # log(exp(gt_label_logit)) term in the single-label case. 215 | multisoftmax_per_token = -torch.logsumexp( 216 | gt_label_logits, dim=2 217 | ) + torch.logsumexp(lm_logits, dim=2) 218 | 219 | pre_loss = torch.mean(multisoftmax_per_token) 220 | else: 221 | # Use the vanilla softmax of the HF model of choice 222 | pre_loss = model_output[0] 223 | 224 | loss = pre_loss 225 | loss.mean().backward() 226 | # optimizer 227 | if step % args.backward_freq == 0: 228 | optimizer.step() 229 | scheduler.step() 230 | model.zero_grad() 231 | 232 | loc_loss += loss.mean().item() 233 | loc_steps += 1 234 | if step % args.print_freq == 0 or step == 1: 235 | print( 236 | "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}".format( 237 | e, 238 | step, 239 | len(dataset) // args.train_batch_size, 240 | loc_loss / loc_steps, 241 | time() - st_time, 242 | ) 243 | ) 244 | with open(os.path.join(output_dir, "output_train.txt"), "a+") as dev_file: 245 | dev_file.write( 246 | "{:2d} {:5d} of {:5d} \t L: {:.3f} \t -- {:.3f}\n".format( 247 | e, 248 | step, 249 | len(dataset) // args.train_batch_size, 250 | loc_loss / loc_steps, 251 | time() - st_time,
252 | ) 253 | ) 254 | loc_loss = 0 255 | loc_steps = 0 256 | 257 | 258 | def eval_s2s_epoch(model, dataset, tokenizer, args, output_dir, sample=None): 259 | model.eval() 260 | 261 | print("Eval with sampling rate", sample) 262 | 263 | num_examples = len(dataset) 264 | if sample is not None: 265 | num_examples_used = int(num_examples * sample) 266 | torch.manual_seed(0) 267 | eval_sampler = SubsetRandomSampler(indices=torch.randperm(num_examples)[:num_examples_used]) 268 | num_examples = num_examples_used 269 | else: 270 | eval_sampler = SequentialSampler(dataset) 271 | 272 | model_collate_fn = functools.partial( 273 | make_s2s_batch, 274 | model=model, 275 | tokenizer=tokenizer, 276 | max_i_len=args.max_i_length, 277 | max_o_len=args.max_o_length, 278 | device=args.device, 279 | ) 280 | data_loader = DataLoader( 281 | dataset, 282 | batch_size=args.eval_batch_size, 283 | sampler=eval_sampler, 284 | collate_fn=model_collate_fn, 285 | ) 286 | epoch_iterator = tqdm(data_loader, desc="Iteration", disable=True) 287 | # accumulate loss since last print 288 | loc_steps = 0 289 | loc_loss = 0.0 290 | st_time = time() 291 | 292 | with torch.no_grad(): 293 | preds = [] 294 | golds = [] 295 | for step, batch_inputs in enumerate(epoch_iterator): 296 | pre_loss = model(**batch_inputs)[0] 297 | 298 | if isinstance(model, DataParallel): 299 | model_gen = model.module 300 | loss = pre_loss.sum() / pre_loss.shape[0] 301 | else: 302 | model_gen = model 303 | loss = pre_loss 304 | 305 | generated_ids = model_gen.generate( 306 | input_ids=batch_inputs["input_ids"], 307 | attention_mask=batch_inputs["attention_mask"], 308 | min_length=1, 309 | max_length=args.max_o_length + 1, 310 | do_sample=False, 311 | early_stopping=True, 312 | num_beams=1, 313 | temperature=1.0, 314 | top_k=None, 315 | top_p=None, 316 | eos_token_id=tokenizer.eos_token_id, 317 | no_repeat_ngram_size=3, 318 | num_return_sequences=1, 319 | decoder_start_token_id=tokenizer.bos_token_id, 320 | ) 321 | 322 | 
generated_ids = list(generated_ids) 323 | 324 | raw_preds = [ 325 | (tokenizer.decode(ans_ids).split(tokenizer.eos_token)[0].replace(tokenizer.pad_token, "")) 326 | for ans_ids in generated_ids 327 | ] 328 | 329 | pred = [dataset.output_str_to_labels(pred) for pred in raw_preds] 330 | 331 | gold = [ 332 | dataset.output_str_to_labels( 333 | tokenizer.decode(ans_ids).split(tokenizer.eos_token)[0].replace(tokenizer.pad_token, "") 334 | ) 335 | for ans_ids in batch_inputs["labels"] 336 | ] 337 | 338 | # Print to quickly debug predictions 339 | # print(generated_ids[0]) 340 | # print("raw_pred", raw_preds[0]) 341 | # print("pred", pred[0]) 342 | # print("gold", gold[0]) 343 | 344 | golds.extend(gold) 345 | preds.extend(pred) 346 | 347 | loc_loss += loss.item() 348 | loc_steps += 1 349 | if step % args.print_freq == 0: 350 | print( 351 | "{:5d} of {:5d} \t L: {:.6f} \t -- {:.3f}".format( 352 | step, 353 | num_examples // args.eval_batch_size, 354 | loc_loss / loc_steps, 355 | time() - st_time, 356 | ) 357 | ) 358 | 359 | with open(os.path.join(output_dir, "predictions_dev.txt"), "a") as dev_file: 360 | for g, p in zip(golds, preds): 361 | dev_file.write(str(g) + "\t" + str(p) + "\n") 362 | 363 | metrics = compute_metrics(preds, golds) 364 | metric_str = "L: {:.3f} ".format(loc_loss / loc_steps) + " ".join( 365 | f"{k}: {v:.3f}" for k, v in metrics.items() 366 | ) 367 | 368 | with open(os.path.join(output_dir, "output_dev.txt"), "a") as dev_file: 369 | dev_file.write(metric_str + "\n") 370 | print(metric_str) 371 | 372 | return loc_loss, metrics["P@3"] # Use P@3 to decide best model 373 | 374 | 375 | def save_checkpoint(output_dir, s2s_model, eval_acc, s2s_optimizer, s2s_scheduler, epoch): 376 | start_time = time() 377 | print("Saving checkpoint starts at", strftime('%l:%M%p %Z on %b %d, %Y')) 378 | 379 | best_eval = eval_acc 380 | m_save_dict = { 381 | "optimizer": s2s_optimizer.state_dict(), 382 | "scheduler": s2s_scheduler.state_dict(), 383 | "best_eval": {"em": eval_acc}, 384 | "epoch": epoch 385 | } 386 | print("Saving model
{}".format(output_dir)) 387 | 388 | if isinstance(s2s_model, DataParallel): 389 | s2s_model.module.save_pretrained(output_dir) 390 | else: 391 | s2s_model.save_pretrained(output_dir) 392 | 393 | torch.save(m_save_dict, os.path.join(output_dir, "state_dict.pth")) 394 | print("Saving checkpoint took", int(time() - start_time), "seconds") 395 | 396 | 397 | def train_s2s( 398 | s2s_model, 399 | s2s_tokenizer, 400 | label_trie, 401 | s2s_optimizer, 402 | s2s_scheduler, 403 | best_eval, 404 | s2s_train_dset, 405 | s2s_valid_dset, 406 | s2s_args, 407 | output_dir, 408 | start_epoch=0 409 | ): 410 | if s2s_optimizer is None: 411 | s2s_optimizer = AdamW( 412 | s2s_model.parameters(), lr=s2s_args.learning_rate, eps=1e-8 413 | ) 414 | if s2s_scheduler is None: 415 | s2s_scheduler = get_linear_schedule_with_warmup( 416 | s2s_optimizer, 417 | num_warmup_steps=400, 418 | num_training_steps=(s2s_args.num_epochs + 1) 419 | * math.ceil(len(s2s_train_dset) / s2s_args.train_batch_size), 420 | ) 421 | for e in range(start_epoch, start_epoch + s2s_args.num_epochs): 422 | train_s2s_epoch( 423 | s2s_model, 424 | s2s_train_dset, 425 | s2s_tokenizer, 426 | label_trie, 427 | s2s_optimizer, 428 | s2s_scheduler, 429 | s2s_args, 430 | output_dir, 431 | e, 432 | curriculum=(e == 0), 433 | ) 434 | 435 | # Decoding can be slow, we can control how often we want to do that 436 | if e % s2s_args.eval_every_k_epoch == s2s_args.eval_every_k_epoch - 1: 437 | start_time = time() 438 | print("Eval starts at", strftime('%l:%M%p %Z on %b %d, %Y')) 439 | 440 | eval_l, eval_acc = eval_s2s_epoch( 441 | s2s_model, 442 | s2s_valid_dset, 443 | s2s_tokenizer, 444 | s2s_args, 445 | output_dir, 446 | sample=s2s_args.eval_sampling_rate 447 | ) 448 | 449 | print("Evaluation took", int(time() - start_time), "seconds") 450 | 451 | if s2s_args.save_after_every_eval: 452 | checkpoint_dir = os.path.join(output_dir, f"epoch{e}") 453 | os.mkdir(checkpoint_dir) 454 | save_checkpoint(checkpoint_dir, s2s_model, eval_acc, 
s2s_optimizer, s2s_scheduler, e) 455 | elif best_eval is None or eval_acc > best_eval: 456 | best_eval = eval_acc 457 | save_checkpoint(output_dir, s2s_model, eval_acc, s2s_optimizer, s2s_scheduler, e) 458 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GROOV/e15f399a99add2bb52247113718e7d9fd188f58f/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/GROOV/e15f399a99add2bb52247113718e7d9fd188f58f/tests/test_data/__init__.py -------------------------------------------------------------------------------- /tests/test_data/gold.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B0002OQPSK", "input": "Amazon.com: Intimo Men's Classic Silk Knit Thong, Navy, X-Large: Clothing This ultra-light and elastic men's Silk Knit thong is cool all weather comfort with the freedom that only a thong can offer.
This elegant thong is crafted from pure, knitted silk with a gentle stretch that hugs your body.", "output": ["clothing & accessories", "g-strings & thongs", "men", "underwear"]} -------------------------------------------------------------------------------- /tests/test_data/gold_multi.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B000C7YDQ8", "input": "Standard Motor Products SG121 Oxygen Sensor Oxygen Sensor, OEM Fit And Quality", "output": ["automotive", "oxygen", "replacement parts", "sensors"]} 2 | {"id": "1580421717", "input": "One Move Checkmates: 201 Instructive and Challenging Mates for Beginners Eric Schiller, author of more than 100 chess books, is widely considered one of the foremost chess analysts, writers and teachers.", "output": ["books", "chess", "humor & entertainment", "puzzles & games"]} 3 | {"id": "B0008109SO", "input": "Star Wars Death Star Model Kit This is a model kit for the Star Wars Death Star.", "output": ["hobbies", "model building kits & tools", "toys & games"]} -------------------------------------------------------------------------------- /tests/test_data/pred_all.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B0002OQPSK", "input": "Amazon.com: Intimo Men's Classic Silk Knit Thong, Navy, X-Large: Clothing This ultra-light and elastic men's Silk Knit thong is cool all weather comfort with the freedom that only a thong can offer. 
This elegant thong is crafted from pure, knitted silk with a gentle stretch that hugs your body.", "output": ["clothing & accessories", "men", "underwear"]} -------------------------------------------------------------------------------- /tests/test_data/pred_miss.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B0002OQPSK", "input": "Amazon.com: Intimo Men's Classic Silk Knit Thong, Navy, X-Large: Clothing This ultra-light and elastic men's Silk Knit thong is cool all weather comfort with the freedom that only a thong can offer. This elegant thong is crafted from pure, knitted silk with a gentle stretch that hugs your body.", "output": ["clothing", "g-strings", "accessories", "pants"]} -------------------------------------------------------------------------------- /tests/test_data/pred_multi.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B000C7YDQ8", "input": "Standard Motor Products SG121 Oxygen Sensor Oxygen Sensor, OEM Fit And Quality", "output": ["automotive", "motors", "sensors"]} 2 | {"id": "1580421717", "input": "One Move Checkmates: 201 Instructive and Challenging Mates for Beginners Eric Schiller, author of more than 100 chess books, is widely considered one of the foremost chess analysts, writers and teachers.", "output": ["books", "entertainment", "games"]} 3 | {"id": "B0008109SO", "input": "Star Wars Death Star Model Kit This is a model kit for the Star Wars Death Star.", "output": ["tools", "toys", "games"]} -------------------------------------------------------------------------------- /tests/test_data/pred_part.jsonl: -------------------------------------------------------------------------------- 1 | {"id": "B0002OQPSK", "input": "Amazon.com: Intimo Men's Classic Silk Knit Thong, Navy, X-Large: Clothing This ultra-light and elastic men's Silk Knit thong is cool all weather comfort with the freedom that only a thong can offer. 
This elegant thong is crafted from pure, knitted silk with a gentle stretch that hugs your body.", "output": ["clothes", "thongs", "men", "underwear"]} -------------------------------------------------------------------------------- /tests/test_eval.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 6 | 7 | 8 | import unittest 9 | import importlib.resources 10 | 11 | import sys 12 | sys.path.append("..") 13 | import eval 14 | from tests import test_data 15 | 16 | 17 | class TestEval(unittest.TestCase): 18 | 19 | # Test the computation of rank-based metrics 20 | def test_rank_metrics(self, ks=[1, 4]): 21 | 22 | with importlib.resources.open_text(test_data, "gold.jsonl") as gold_file: 23 | 24 | gold_records = eval.load_data(gold_file.name) 25 | 26 | # 1. no matching 27 | with importlib.resources.open_text( 28 | test_data, "pred_miss.jsonl" 29 | ) as guess_file: 30 | 31 | guess_records = eval.load_data(guess_file.name) 32 | 33 | # compute evaluation metrics 34 | result = eval.compute(gold_records, guess_records, ks) 35 | 36 | self.assertEqual(result["precision"], 0.0) 37 | self.assertEqual(result["recall"], 0.0) 38 | 39 | self.assertEqual(result["precision@1"], 0.0) 40 | self.assertEqual(result["precision@4"], 0.0) 41 | self.assertEqual(result["recall@4"], 0.0) 42 | 43 | # 2. 
partial matching 44 | with importlib.resources.open_text( 45 | test_data, "pred_part.jsonl" 46 | ) as guess_file: 47 | 48 | guess_records = eval.load_data(guess_file.name) 49 | 50 | # compute evaluation metrics 51 | result = eval.compute(gold_records, guess_records, ks) 52 | 53 | self.assertAlmostEqual(result["precision"], 1 / 2) 54 | self.assertAlmostEqual(result["recall"], 1 / 2) 55 | self.assertAlmostEqual(result["f1"], 1 / 2) 56 | 57 | self.assertEqual(result["precision@1"], 0.0) 58 | self.assertAlmostEqual(result["precision@4"], 1 / 2) 59 | self.assertAlmostEqual(result["recall@4"], 1 / 2) 60 | self.assertAlmostEqual(result["f1@4"], 1 / 2) 61 | 62 | # 3. all correct prediction 63 | with importlib.resources.open_text( 64 | test_data, "pred_all.jsonl" 65 | ) as guess_file: 66 | 67 | guess_records = eval.load_data(guess_file.name) 68 | 69 | # compute evaluation metrics 70 | result = eval.compute(gold_records, guess_records, ks) 71 | 72 | self.assertEqual(result["precision"], 1.0) 73 | self.assertAlmostEqual(result["recall"], 3 / 4) 74 | self.assertAlmostEqual(result["f1"], 6 / 7) 75 | 76 | self.assertEqual(result["precision@1"], 1.0) 77 | self.assertAlmostEqual(result["precision@4"], 3 / 4) 78 | self.assertAlmostEqual(result["recall@4"], 3 / 4) 79 | self.assertAlmostEqual(result["f1@4"], 3 / 4) 80 | 81 | 82 | # Test the computation of rank-independent metrics 83 | def test_stat_metrics(self): 84 | 85 | with importlib.resources.open_text(test_data, "gold.jsonl") as gold_file: 86 | 87 | gold_records = eval.load_data(gold_file.name) 88 | 89 | with importlib.resources.open_text( 90 | test_data, "pred_part.jsonl" 91 | ) as guess_file: 92 | 93 | guess_records = eval.load_data(guess_file.name) 94 | result = eval.compute(gold_records, guess_records) 95 | 96 | self.assertAlmostEqual(result["precision"], 1 / 2) 97 | self.assertAlmostEqual(result["recall"], 1 / 2) 98 | self.assertAlmostEqual(result["f1"], 1 / 2) 99 | 100 | 101 | # Test the computation of aggregated metrics 
among multiple samples 102 | def test_average_metrics(self): 103 | 104 | with importlib.resources.open_text(test_data, "gold_multi.jsonl") as gold_file: 105 | 106 | gold_records = eval.load_data(gold_file.name) 107 | 108 | with importlib.resources.open_text( 109 | test_data, "pred_multi.jsonl" 110 | ) as guess_file: 111 | 112 | guess_records = eval.load_data(guess_file.name) 113 | result = eval.compute(gold_records, guess_records) 114 | 115 | self.assertAlmostEqual(result["precision"], 1 / 3) 116 | self.assertAlmostEqual(result["recall"], 1 / 4) 117 | self.assertAlmostEqual(result["f1"], 2 / 7) 118 | 119 | 120 | if __name__ == '__main__': 121 | unittest.main() -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | # All rights reserved. 3 | 4 | # This source code is licensed under the license found in the 5 | # LICENSE file in the root directory of this source tree. 
6 | 7 | 8 | import csv 9 | import os 10 | import pandas 11 | 12 | from itertools import chain 13 | 14 | 15 | def create_schema_from_train_test(train_labels, test_labels, schema_file): 16 | labels = {} 17 | counter = 0 18 | with open(train_labels) as f: 19 | lines = f.readlines() 20 | for line in lines: 21 | for l in line.split(): 22 | if l not in labels.values(): 23 | labels[counter] = l 24 | counter += 1 25 | with open(test_labels) as f: 26 | lines = f.readlines() 27 | for line in lines: 28 | for l in line.split(): 29 | if l not in labels.values(): 30 | labels[counter] = l 31 | counter += 1 32 | out_file = open(schema_file, "w") 33 | writer = csv.writer(out_file) 34 | for key, value in labels.items(): 35 | writer.writerow([key, value]) 36 | out_file.close() 37 | 38 | 39 | # schema is given in a file, each line is ID, topic_name 40 | def read_schema(data_file): 41 | with open(data_file) as csvfile: 42 | reader = csv.reader(csvfile, delimiter=",") 43 | line_count = 0 44 | data = {} 45 | for row in reader: 46 | if line_count == 0: 47 | line_count = line_count + 1 48 | continue 49 | 50 | id = row[0] 51 | name = row[1] 52 | data[id] = name 53 | return data 54 | 55 | 56 | def prepare_tokenizer(tokenizer): 57 | special_tokens = [] 58 | special_tokens.extend(["", "", "", "[SEP]"]) 59 | tokenizer.add_special_tokens({"additional_special_tokens": special_tokens}) 60 | 61 | 62 | def read_multilabel(data_file, labels): 63 | data_df = pandas.read_pickle(data_file) 64 | 65 | post_ids = data_df.post_id 66 | post_texts = data_df.post_text 67 | ocr_texts = data_df.ocr_text 68 | landing_texts = data_df.landing_page_text 69 | tags_list = data_df.tags 70 | data = [] 71 | for post_id, post_text, ocr_text, landing_text, tags in zip( 72 | post_ids, post_texts, ocr_texts, landing_texts, tags_list 73 | ): 74 | tags_names = [] 75 | for l in tags: 76 | if l in labels: 77 | tags_names.append(labels[l]) 78 | else: 79 | print(l) 80 | data.append([post_id, post_text, ocr_text, landing_text, tags_names]) 81 |
return data 82 | 83 | 84 | def try_convert(val): 85 | try: 86 | return float(val) 87 | except ValueError: 88 | return -1 89 | 90 | 91 | if __name__ == "__main__": 92 | print("main") 93 | create_schema_from_train_test( 94 | "EUR-Lex/train_labels.txt", "EUR-Lex/test_labels.txt", "EUR-Lex/EUR-Lex-schema" 95 | ) 96 | --------------------------------------------------------------------------------
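The multi-option ("multisoftmax") loss in s2s_model.py's train_s2s_epoch can be sketched in plain Python for a single decoding step. This is an illustrative re-derivation under stated assumptions, not code from the repo: the function name is hypothetical, and the `allowed` index set stands in for the mask that `label_trie.compute_targets` produces.

```python
import math

def multisoftmax_per_token(logits, allowed):
    # logits: raw scores over the vocabulary for one decoding step.
    # allowed: indices of tokens that are acceptable continuations
    # (in the repo this mask comes from label_trie.compute_targets).
    #
    # Loss = -log( sum_{i in allowed} softmax(logits)_i )
    #      = logsumexp(all logits) - logsumexp(allowed logits),
    # matching the two torch.logsumexp terms in train_s2s_epoch.
    lse_all = math.log(sum(math.exp(x) for x in logits))
    lse_allowed = math.log(
        sum(math.exp(x) for i, x in enumerate(logits) if i in allowed)
    )
    return lse_all - lse_allowed
```

With a single allowed token this reduces to the ordinary cross-entropy term; with several allowed tokens, probability mass on any acceptable label lowers the loss, which is what the "removes competition between correct labels" comment in the source refers to.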