├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── THIRD-PARTY-LICENSES.txt
├── lookahead
│   ├── generation.py
│   ├── lookahead.py
│   ├── run_generate.py
│   └── scorer.py
├── ranking
│   ├── generate_documents.py
│   └── rank.py
└── teacher-student
    ├── src
    │   ├── data.py
    │   ├── deepspeed_config.json
    │   ├── run_summarization.py
    │   └── trainer.py
    └── train_script.sh
/CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute to. 
As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 2 | 3 | // SPDX-License-Identifier: CC-BY-NC-4.0 4 | 5 | Attribution-NonCommercial 4.0 International 6 | 7 | ======================================================================= 8 | 9 | Creative Commons Corporation ("Creative Commons") is not a law firm and 10 | does not provide legal services or legal advice. Distribution of 11 | Creative Commons public licenses does not create a lawyer-client or 12 | other relationship. Creative Commons makes its licenses and related 13 | information available on an "as-is" basis. Creative Commons gives no 14 | warranties regarding its licenses, any material licensed under their 15 | terms and conditions, or any related information. Creative Commons 16 | disclaims all liability for damages resulting from their use to the 17 | fullest extent possible. 18 | 19 | Using Creative Commons Public Licenses 20 | 21 | Creative Commons public licenses provide a standard set of terms and 22 | conditions that creators and other rights holders may use to share 23 | original works of authorship and other material subject to copyright 24 | and certain other rights specified in the public license below. The 25 | following considerations are for informational purposes only, are not 26 | exhaustive, and do not form part of our licenses. 27 | 28 | Considerations for licensors: Our public licenses are 29 | intended for use by those authorized to give the public 30 | permission to use material in ways otherwise restricted by 31 | copyright and certain other rights. Our licenses are 32 | irrevocable. Licensors should read and understand the terms 33 | and conditions of the license they choose before applying it. 34 | Licensors should also secure all rights necessary before 35 | applying our licenses so that the public can reuse the 36 | material as expected. Licensors should clearly mark any 37 | material not subject to the license. This includes other CC- 38 | licensed material, or material used under an exception or 39 | limitation to copyright. More considerations for licensors: 40 | wiki.creativecommons.org/Considerations_for_licensors 41 | 42 | Considerations for the public: By using one of our public 43 | licenses, a licensor grants the public permission to use the 44 | licensed material under specified terms and conditions. 
If 45 | the licensor's permission is not necessary for any reason--for 46 | example, because of any applicable exception or limitation to 47 | copyright--then that use is not regulated by the license. Our 48 | licenses grant only permissions under copyright and certain 49 | other rights that a licensor has authority to grant. Use of 50 | the licensed material may still be restricted for other 51 | reasons, including because others have copyright or other 52 | rights in the material. A licensor may make special requests, 53 | such as asking that all changes be marked or described. 54 | Although not required by our licenses, you are encouraged to 55 | respect those requests where reasonable. More considerations 56 | for the public: 57 | wiki.creativecommons.org/Considerations_for_licensees 58 | 59 | ======================================================================= 60 | 61 | Creative Commons Attribution-NonCommercial 4.0 International Public 62 | License 63 | 64 | By exercising the Licensed Rights (defined below), You accept and agree 65 | to be bound by the terms and conditions of this Creative Commons 66 | Attribution-NonCommercial 4.0 International Public License ("Public 67 | License"). To the extent this Public License may be interpreted as a 68 | contract, You are granted the Licensed Rights in consideration of Your 69 | acceptance of these terms and conditions, and the Licensor grants You 70 | such rights in consideration of benefits the Licensor receives from 71 | making the Licensed Material available under these terms and 72 | conditions. 73 | 74 | 75 | Section 1 -- Definitions. 76 | 77 | a. Adapted Material means material subject to Copyright and Similar 78 | Rights that is derived from or based upon the Licensed Material 79 | and in which the Licensed Material is translated, altered, 80 | arranged, transformed, or otherwise modified in a manner requiring 81 | permission under the Copyright and Similar Rights held by the 82 | Licensor. For purposes of this Public License, where the Licensed 83 | Material is a musical work, performance, or sound recording, 84 | Adapted Material is always produced where the Licensed Material is 85 | synched in timed relation with a moving image. 86 | 87 | b. Adapter's License means the license You apply to Your Copyright 88 | and Similar Rights in Your contributions to Adapted Material in 89 | accordance with the terms and conditions of this Public License. 90 | 91 | c. Copyright and Similar Rights means copyright and/or similar rights 92 | closely related to copyright including, without limitation, 93 | performance, broadcast, sound recording, and Sui Generis Database 94 | Rights, without regard to how the rights are labeled or 95 | categorized. For purposes of this Public License, the rights 96 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 97 | Rights. 98 | d. Effective Technological Measures means those measures that, in the 99 | absence of proper authority, may not be circumvented under laws 100 | fulfilling obligations under Article 11 of the WIPO Copyright 101 | Treaty adopted on December 20, 1996, and/or similar international 102 | agreements. 103 | 104 | e. Exceptions and Limitations means fair use, fair dealing, and/or 105 | any other exception or limitation to Copyright and Similar Rights 106 | that applies to Your use of the Licensed Material. 107 | 108 | f. Licensed Material means the artistic or literary work, database, 109 | or other material to which the Licensor applied this Public 110 | License. 
111 | 112 | g. Licensed Rights means the rights granted to You subject to the 113 | terms and conditions of this Public License, which are limited to 114 | all Copyright and Similar Rights that apply to Your use of the 115 | Licensed Material and that the Licensor has authority to license. 116 | 117 | h. Licensor means the individual(s) or entity(ies) granting rights 118 | under this Public License. 119 | 120 | i. NonCommercial means not primarily intended for or directed towards 121 | commercial advantage or monetary compensation. For purposes of 122 | this Public License, the exchange of the Licensed Material for 123 | other material subject to Copyright and Similar Rights by digital 124 | file-sharing or similar means is NonCommercial provided there is 125 | no payment of monetary compensation in connection with the 126 | exchange. 127 | 128 | j. Share means to provide material to the public by any means or 129 | process that requires permission under the Licensed Rights, such 130 | as reproduction, public display, public performance, distribution, 131 | dissemination, communication, or importation, and to make material 132 | available to the public including in ways that members of the 133 | public may access the material from a place and at a time 134 | individually chosen by them. 135 | 136 | k. Sui Generis Database Rights means rights other than copyright 137 | resulting from Directive 96/9/EC of the European Parliament and of 138 | the Council of 11 March 1996 on the legal protection of databases, 139 | as amended and/or succeeded, as well as other essentially 140 | equivalent rights anywhere in the world. 141 | 142 | l. You means the individual or entity exercising the Licensed Rights 143 | under this Public License. Your has a corresponding meaning. 144 | 145 | 146 | Section 2 -- Scope. 147 | 148 | a. License grant. 149 | 150 | 1. Subject to the terms and conditions of this Public License, 151 | the Licensor hereby grants You a worldwide, royalty-free, 152 | non-sublicensable, non-exclusive, irrevocable license to 153 | exercise the Licensed Rights in the Licensed Material to: 154 | 155 | a. reproduce and Share the Licensed Material, in whole or 156 | in part, for NonCommercial purposes only; and 157 | 158 | b. produce, reproduce, and Share Adapted Material for 159 | NonCommercial purposes only. 160 | 161 | 2. Exceptions and Limitations. For the avoidance of doubt, where 162 | Exceptions and Limitations apply to Your use, this Public 163 | License does not apply, and You do not need to comply with 164 | its terms and conditions. 165 | 166 | 3. Term. The term of this Public License is specified in Section 167 | 6(a). 168 | 169 | 4. Media and formats; technical modifications allowed. The 170 | Licensor authorizes You to exercise the Licensed Rights in 171 | all media and formats whether now known or hereafter created, 172 | and to make technical modifications necessary to do so. The 173 | Licensor waives and/or agrees not to assert any right or 174 | authority to forbid You from making technical modifications 175 | necessary to exercise the Licensed Rights, including 176 | technical modifications necessary to circumvent Effective 177 | Technological Measures. For purposes of this Public License, 178 | simply making modifications authorized by this Section 2(a) 179 | (4) never produces Adapted Material. 180 | 181 | 5. Downstream recipients. 182 | 183 | a. Offer from the Licensor -- Licensed Material. 
Every 184 | recipient of the Licensed Material automatically 185 | receives an offer from the Licensor to exercise the 186 | Licensed Rights under the terms and conditions of this 187 | Public License. 188 | 189 | b. No downstream restrictions. You may not offer or impose 190 | any additional or different terms or conditions on, or 191 | apply any Effective Technological Measures to, the 192 | Licensed Material if doing so restricts exercise of the 193 | Licensed Rights by any recipient of the Licensed 194 | Material. 195 | 196 | 6. No endorsement. Nothing in this Public License constitutes or 197 | may be construed as permission to assert or imply that You 198 | are, or that Your use of the Licensed Material is, connected 199 | with, or sponsored, endorsed, or granted official status by, 200 | the Licensor or others designated to receive attribution as 201 | provided in Section 3(a)(1)(A)(i). 202 | 203 | b. Other rights. 204 | 205 | 1. Moral rights, such as the right of integrity, are not 206 | licensed under this Public License, nor are publicity, 207 | privacy, and/or other similar personality rights; however, to 208 | the extent possible, the Licensor waives and/or agrees not to 209 | assert any such rights held by the Licensor to the limited 210 | extent necessary to allow You to exercise the Licensed 211 | Rights, but not otherwise. 212 | 213 | 2. Patent and trademark rights are not licensed under this 214 | Public License. 215 | 216 | 3. To the extent possible, the Licensor waives any right to 217 | collect royalties from You for the exercise of the Licensed 218 | Rights, whether directly or through a collecting society 219 | under any voluntary or waivable statutory or compulsory 220 | licensing scheme. In all other cases the Licensor expressly 221 | reserves any right to collect such royalties, including when 222 | the Licensed Material is used other than for NonCommercial 223 | purposes. 224 | 225 | 226 | Section 3 -- License Conditions. 227 | 228 | Your exercise of the Licensed Rights is expressly made subject to the 229 | following conditions. 230 | 231 | a. Attribution. 232 | 233 | 1. If You Share the Licensed Material (including in modified 234 | form), You must: 235 | 236 | a. retain the following if it is supplied by the Licensor 237 | with the Licensed Material: 238 | 239 | i. identification of the creator(s) of the Licensed 240 | Material and any others designated to receive 241 | attribution, in any reasonable manner requested by 242 | the Licensor (including by pseudonym if 243 | designated); 244 | 245 | ii. a copyright notice; 246 | 247 | iii. a notice that refers to this Public License; 248 | 249 | iv. a notice that refers to the disclaimer of 250 | warranties; 251 | 252 | v. a URI or hyperlink to the Licensed Material to the 253 | extent reasonably practicable; 254 | 255 | b. indicate if You modified the Licensed Material and 256 | retain an indication of any previous modifications; and 257 | 258 | c. indicate the Licensed Material is licensed under this 259 | Public License, and include the text of, or the URI or 260 | hyperlink to, this Public License. 261 | 262 | 2. You may satisfy the conditions in Section 3(a)(1) in any 263 | reasonable manner based on the medium, means, and context in 264 | which You Share the Licensed Material. For example, it may be 265 | reasonable to satisfy the conditions by providing a URI or 266 | hyperlink to a resource that includes the required 267 | information. 268 | 269 | 3. 
If requested by the Licensor, You must remove any of the 270 | information required by Section 3(a)(1)(A) to the extent 271 | reasonably practicable. 272 | 273 | 4. If You Share Adapted Material You produce, the Adapter's 274 | License You apply must not prevent recipients of the Adapted 275 | Material from complying with this Public License. 276 | 277 | 278 | Section 4 -- Sui Generis Database Rights. 279 | 280 | Where the Licensed Rights include Sui Generis Database Rights that 281 | apply to Your use of the Licensed Material: 282 | 283 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 284 | to extract, reuse, reproduce, and Share all or a substantial 285 | portion of the contents of the database for NonCommercial purposes 286 | only; 287 | 288 | b. if You include all or a substantial portion of the database 289 | contents in a database in which You have Sui Generis Database 290 | Rights, then the database in which You have Sui Generis Database 291 | Rights (but not its individual contents) is Adapted Material; and 292 | 293 | c. You must comply with the conditions in Section 3(a) if You Share 294 | all or a substantial portion of the contents of the database. 295 | 296 | For the avoidance of doubt, this Section 4 supplements and does not 297 | replace Your obligations under this Public License where the Licensed 298 | Rights include other Copyright and Similar Rights. 299 | 300 | 301 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 302 | 303 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 304 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 305 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 306 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 307 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 308 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 309 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 310 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 311 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 312 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 313 | 314 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 315 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 316 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 317 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 318 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 319 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 320 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 321 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 322 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 323 | 324 | c. The disclaimer of warranties and limitation of liability provided 325 | above shall be interpreted in a manner that, to the extent 326 | possible, most closely approximates an absolute disclaimer and 327 | waiver of all liability. 328 | 329 | 330 | Section 6 -- Term and Termination. 331 | 332 | a. This Public License applies for the term of the Copyright and 333 | Similar Rights licensed here. However, if You fail to comply with 334 | this Public License, then Your rights under this Public License 335 | terminate automatically. 336 | 337 | b. Where Your right to use the Licensed Material has terminated under 338 | Section 6(a), it reinstates: 339 | 340 | 1. 
automatically as of the date the violation is cured, provided 341 | it is cured within 30 days of Your discovery of the 342 | violation; or 343 | 344 | 2. upon express reinstatement by the Licensor. 345 | 346 | For the avoidance of doubt, this Section 6(b) does not affect any 347 | right the Licensor may have to seek remedies for Your violations 348 | of this Public License. 349 | 350 | c. For the avoidance of doubt, the Licensor may also offer the 351 | Licensed Material under separate terms or conditions or stop 352 | distributing the Licensed Material at any time; however, doing so 353 | will not terminate this Public License. 354 | 355 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 356 | License. 357 | 358 | 359 | Section 7 -- Other Terms and Conditions. 360 | 361 | a. The Licensor shall not be bound by any additional or different 362 | terms or conditions communicated by You unless expressly agreed. 363 | 364 | b. Any arrangements, understandings, or agreements regarding the 365 | Licensed Material not stated herein are separate from and 366 | independent of the terms and conditions of this Public License. 367 | 368 | 369 | Section 8 -- Interpretation. 370 | 371 | a. For the avoidance of doubt, this Public License does not, and 372 | shall not be interpreted to, reduce, limit, restrict, or impose 373 | conditions on any use of the Licensed Material that could lawfully 374 | be made without permission under this Public License. 375 | 376 | b. To the extent possible, if any provision of this Public License is 377 | deemed unenforceable, it shall be automatically reformed to the 378 | minimum extent necessary to make it enforceable. If the provision 379 | cannot be reformed, it shall be severed from this Public License 380 | without affecting the enforceability of the remaining terms and 381 | conditions. 382 | 383 | c. No term or condition of this Public License will be waived and no 384 | failure to comply consented to unless expressly agreed to by the 385 | Licensor. 386 | 387 | d. Nothing in this Public License constitutes or may be interpreted 388 | as a limitation upon, or waiver of, any privileges and immunities 389 | that apply to the Licensor or You, including from the legal 390 | processes of any jurisdiction or authority. 391 | 392 | ======================================================================= 393 | 394 | Creative Commons is not a party to its public 395 | licenses. Notwithstanding, Creative Commons may elect to apply one of 396 | its public licenses to material it publishes and in those instances 397 | will be considered the “Licensor.” The text of the Creative Commons 398 | public licenses is dedicated to the public domain under the CC0 Public 399 | Domain Dedication. Except for the limited purpose of indicating that 400 | material is shared under a Creative Commons public license or as 401 | otherwise permitted by the Creative Commons policies published at 402 | creativecommons.org/policies, Creative Commons does not authorize the 403 | use of the trademark "Creative Commons" or any other trademark or logo 404 | of Creative Commons without its prior written consent including, 405 | without limitation, in connection with any unauthorized modifications 406 | to any of its public licenses or any other arrangements, 407 | understandings, or agreements concerning use of licensed material. For 408 | the avoidance of doubt, this paragraph does not form part of the 409 | public licenses. 
410 | 411 | Creative Commons may be contacted at creativecommons.org. 412 | 413 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Faithfulness-Aware Decoding Strategies for Abstractive Summarization (EACL 2023)

This repository contains the code for the paper "Faithfulness-Aware Decoding Strategies for Abstractive Summarization."

**Authors:** [David Wan](https://meetdavidwan.github.io), [Mengwen Liu](https://www.amazon.science/author/mengwen-liu), [Kathleen McKeown](http://www.cs.columbia.edu/~kathy), [Markus Dreyer](https://markusdreyer.org), and [Mohit Bansal](https://www.cs.unc.edu/~mbansal)

**Arxiv:** https://arxiv.org/abs/2303.03278

## 1. Generating Summaries with Lookahead

### 1.1. Environment
Needed packages:
- PyTorch
- [transformers](https://github.com/huggingface/transformers/tree/v4.21.0) >= 4.21.0
- [datasets](https://github.com/huggingface/datasets/tree/2.4.0) >= 2.4.0

### 1.2. Description
Please use the `lookahead/run_generate.py` file for running generation. We modify Hugging Face's generation code to allow for the lookahead heuristics.

To generate with the baseline decoding methods, i.e., the original generation methods, simply run the file without `--do_lookahead`. To decode with lookahead, add `--do_lookahead` and set the lookahead options described below.

If using beam search with `--num_return_sequences > 1`, the output will be a JSON file where each item is a list of summary candidates. Otherwise, it is a plain text file where each line is an output summary.

### 1.3. Important arguments
- `--model_name`: The Hugging Face model to use for summarization.
- `--document_file`: The input document file, where each line represents a single document.
- `--output_file`: The output summary file to write to.
- `--batch_size, --num_beams, --num_return_sequences, --max_input_length, --max_output_length, --do_sample`: The arguments that control the base decoding method. Please refer to Hugging Face's original generation function for more details.
- `--do_lookahead`: Controls whether to use the lookahead.
- `--lookahead_length`: How many tokens to look into the future. By default, it should be the same as `--max_output_length` so that the lookahead generates the full summary.
- `--lookahead_lambda`: The weight for the lookahead heuristics.
- `--top_k`: How many top tokens the lookahead should consider when generating the full summary and providing the heuristic score. This should be set greater than `--lookahead_beam`.
- `--lookahead_decoding_type`: Which decoding strategy to use for the lookahead. The setup is similar to the base decoding strategy.
- `--lookahead_beam`: The beam size of the lookahead, used when `--lookahead_decoding_type=beam`.
- `--scorer_model_type, --scorer_num_layers`: The configuration for the BERTScore scorer. Please refer to the official code for more details.
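For instance, a run that combines several of these options might look like the following sketch. The model name and all flag values here are illustrative assumptions, not settings from the paper; note that `--top_k` is set greater than `--lookahead_beam`, as required above.

```
# beam search with beam-search lookahead (illustrative values)
python run_generate.py \
    --model_name facebook/bart-large-xsum \
    --document_file xsum.document \
    --output_file xsum_beam_lookahead_beam.summary \
    --num_beams 4 \
    --max_output_length 64 \
    --do_lookahead \
    --lookahead_decoding_type beam \
    --lookahead_beam 2 \
    --lookahead_length 64 \
    --lookahead_lambda 5 \
    --top_k 5
```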
### 1.4. Examples
```
# greedy without lookahead
python run_generate.py --document_file xsum.document --output_file xsum_greedy.summary

# beam without lookahead
python run_generate.py --document_file xsum.document --output_file xsum_beam.json --num_beams 10 --num_return_sequences 10

# greedy with greedy lookahead
python run_generate.py --document_file xsum.document --output_file xsum_greedy_lookahead_greedy.summary --do_lookahead --lookahead_decoding_type greedy
```

## 2. Decoding with Ranking

### 2.1. Additional Dependencies
Please install the dependencies specified in 1.1 first.
- pandas

The scripts can be found under the `ranking` directory. We assume that all files are named with the same prefix, for example, `FILE_PREFIX=xsum_beam10`. The steps for ranking are as follows (an end-to-end sketch follows the list):
1. Use the `lookahead/run_generate.py` file to generate beam outputs, producing `${FILE_PREFIX}_summary.json`.
2. Generate the document file (with each document repeated according to the beam size) with `ranking/generate_documents.py xsum.document 10 ${FILE_PREFIX}.document`.
3. Follow the instructions in each metric's official repo to install and run the metrics. You may need to edit the code to allow for saving the scores. We expect each metric file to be in JSON format, for example, `${FILE_PREFIX}_factcc.json`, `${FILE_PREFIX}_dae.json`, ...
4. Run `ranking/rank.py --file_prefix ${FILE_PREFIX}`. It will generate `${FILE_PREFIX}.csv`, a pandas CSV file that stores all of the information, and `${FILE_PREFIX}_ranked_summary.txt`, in which each line is the top-ranked summary for the corresponding input document.
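Putting the steps together, an end-to-end run might look like the following sketch. The metric-scoring step is only outlined in a comment, since it depends on each metric's own tooling, and the `--output_file` name simply follows the naming convention above.

```
FILE_PREFIX=xsum_beam10

# Step 1: generate beam candidates (a JSON file with a list of candidates per document)
python lookahead/run_generate.py --document_file xsum.document \
    --output_file ${FILE_PREFIX}_summary.json --num_beams 10 --num_return_sequences 10

# Step 2: repeat each input document to match the beam size
python ranking/generate_documents.py xsum.document 10 ${FILE_PREFIX}.document

# Step 3: run each faithfulness metric with its official code, saving scores to
# ${FILE_PREFIX}_factcc.json, ${FILE_PREFIX}_dae.json, ... (see step 3 above)

# Step 4: rank the candidates
python ranking/rank.py --file_prefix ${FILE_PREFIX}
```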
## 3. Training with Distillation

### 3.1. Additional Dependencies
Please install the dependencies specified in 1.1 first.
- nltk
- numpy
- [Deepspeed](https://github.com/microsoft/DeepSpeed) (Optional)

### 3.2. Description

The code can be found in `teacher-student/src`. The script `teacher-student/src/run_summarization.py` trains the summarization model with multiple references. The file is very similar to transformers' summarization code, except for modifications that allow for multiple references. One additional argument is:
- `--additional_reference_file`: This should point to a file containing summaries of the training data, one summary per line.

An example of a training script can be found in `teacher-student/train_script.sh`.

## Security

See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.

## License Summary

The documentation is made available under the CC-BY-NC-4.0 License. See the LICENSE file.

## Citation

```
@inproceedings{wan-etal-2023-faithful-generation,
    title = "Faithfulness-Aware Decoding Strategies for Abstractive Summarization",
    author = "Wan, David and
      Liu, Mengwen and
      McKeown, Kathleen and
      Dreyer, Markus and
      Bansal, Mohit",
    booktitle = "Proceedings of the 17th Conference of the European Chapter of the Association for Computational Linguistics",
    year = {2023}
}
```

--------------------------------------------------------------------------------
/THIRD-PARTY-LICENSES.txt:
--------------------------------------------------------------------------------
1 | ** NLTK; version 3.7 -- http://www.nltk.org/ 2 | ** Hugging Face Transformers; version 4.21 -- https://github.com/huggingface/transformers 3 | 4 | Apache License 5 | Version 2.0, January 2004 6 | http://www.apache.org/licenses/ 7 | 8 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 9 | 10 | 1. Definitions. 11 | 12 | "License" shall mean the terms and conditions for use, reproduction, and 13 | distribution as defined by Sections 1 through 9 of this document. 14 | 15 | "Licensor" shall mean the copyright owner or entity authorized by the copyright 16 | owner that is granting the License. 17 | 18 | "Legal Entity" shall mean the union of the acting entity and all other entities 19 | that control, are controlled by, or are under common control with that entity. 20 | For the purposes of this definition, "control" means (i) the power, direct or 21 | indirect, to cause the direction or management of such entity, whether by 22 | contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the 23 | outstanding shares, or (iii) beneficial ownership of such entity. 24 | 25 | "You" (or "Your") shall mean an individual or Legal Entity exercising 26 | permissions granted by this License. 27 | 28 | "Source" form shall mean the preferred form for making modifications, including 29 | but not limited to software source code, documentation source, and configuration 30 | files. 31 | 32 | "Object" form shall mean any form resulting from mechanical transformation or 33 | translation of a Source form, including but not limited to compiled object code, 34 | generated documentation, and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or Object form, made 37 | available under the License, as indicated by a copyright notice that is included 38 | in or attached to the work (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object form, that 41 | is based on (or derived from) the Work and for which the editorial revisions, 42 | annotations, elaborations, or other modifications represent, as a whole, an 43 | original work of authorship. For the purposes of this License, Derivative Works 44 | shall not include works that remain separable from, or merely link (or bind by 45 | name) to the interfaces of, the Work and Derivative Works thereof. 46 | 47 | "Contribution" shall mean any work of authorship, including the original version 48 | of the Work and any modifications or additions to that Work or Derivative Works 49 | thereof, that is intentionally submitted to Licensor for inclusion in the Work 50 | by the copyright owner or by an individual or Legal Entity authorized to submit 51 | on behalf of the copyright owner. 
For the purposes of this definition, 52 | "submitted" means any form of electronic, verbal, or written communication sent 53 | to the Licensor or its representatives, including but not limited to 54 | communication on electronic mailing lists, source code control systems, and 55 | issue tracking systems that are managed by, or on behalf of, the Licensor for 56 | the purpose of discussing and improving the Work, but excluding communication 57 | that is conspicuously marked or otherwise designated in writing by the copyright 58 | owner as "Not a Contribution." 59 | 60 | "Contributor" shall mean Licensor and any individual or Legal Entity on behalf 61 | of whom a Contribution has been received by Licensor and subsequently 62 | incorporated within the Work. 63 | 64 | 2. Grant of Copyright License. Subject to the terms and conditions of this 65 | License, each Contributor hereby grants to You a perpetual, worldwide, non- 66 | exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, 67 | prepare Derivative Works of, publicly display, publicly perform, sublicense, and 68 | distribute the Work and such Derivative Works in Source or Object form. 69 | 70 | 3. Grant of Patent License. Subject to the terms and conditions of this License, 71 | each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no- 72 | charge, royalty-free, irrevocable (except as stated in this section) patent 73 | license to make, have made, use, offer to sell, sell, import, and otherwise 74 | transfer the Work, where such license applies only to those patent claims 75 | licensable by such Contributor that are necessarily infringed by their 76 | Contribution(s) alone or by combination of their Contribution(s) with the Work 77 | to which such Contribution(s) was submitted. If You institute patent litigation 78 | against any entity (including a cross-claim or counterclaim in a lawsuit) 79 | alleging that the Work or a Contribution incorporated within the Work 80 | constitutes direct or contributory patent infringement, then any patent licenses 81 | granted to You under this License for that Work shall terminate as of the date 82 | such litigation is filed. 83 | 84 | 4. Redistribution. 
You may reproduce and distribute copies of the Work or 85 | Derivative Works thereof in any medium, with or without modifications, and in 86 | Source or Object form, provided that You meet the following conditions: 87 | 88 | (a) You must give any other recipients of the Work or Derivative Works a 89 | copy of this License; and 90 | 91 | (b) You must cause any modified files to carry prominent notices stating 92 | that You changed the files; and 93 | 94 | (c) You must retain, in the Source form of any Derivative Works that You 95 | distribute, all copyright, patent, trademark, and attribution notices from the 96 | Source form of the Work, excluding those notices that do not pertain to any part 97 | of the Derivative Works; and 98 | 99 | (d) If the Work includes a "NOTICE" text file as part of its distribution, 100 | then any Derivative Works that You distribute must include a readable copy of 101 | the attribution notices contained within such NOTICE file, excluding those 102 | notices that do not pertain to any part of the Derivative Works, in at least one 103 | of the following places: within a NOTICE text file distributed as part of the 104 | Derivative Works; within the Source form or documentation, if provided along 105 | with the Derivative Works; or, within a display generated by the Derivative 106 | Works, if and wherever such third-party notices normally appear. The contents of 107 | the NOTICE file are for informational purposes only and do not modify the 108 | License. You may add Your own attribution notices within Derivative Works that 109 | You distribute, alongside or as an addendum to the NOTICE text from the Work, 110 | provided that such additional attribution notices cannot be construed as 111 | modifying the License. 112 | 113 | You may add Your own copyright statement to Your modifications and may 114 | provide additional or different license terms and conditions for use, 115 | reproduction, or distribution of Your modifications, or for any such Derivative 116 | Works as a whole, provided Your use, reproduction, and distribution of the Work 117 | otherwise complies with the conditions stated in this License. 118 | 119 | 5. Submission of Contributions. Unless You explicitly state otherwise, any 120 | Contribution intentionally submitted for inclusion in the Work by You to the 121 | Licensor shall be under the terms and conditions of this License, without any 122 | additional terms or conditions. Notwithstanding the above, nothing herein shall 123 | supersede or modify the terms of any separate license agreement you may have 124 | executed with Licensor regarding such Contributions. 125 | 126 | 6. Trademarks. This License does not grant permission to use the trade names, 127 | trademarks, service marks, or product names of the Licensor, except as required 128 | for reasonable and customary use in describing the origin of the Work and 129 | reproducing the content of the NOTICE file. 130 | 131 | 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in 132 | writing, Licensor provides the Work (and each Contributor provides its 133 | Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 134 | KIND, either express or implied, including, without limitation, any warranties 135 | or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 136 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 137 | appropriateness of using or redistributing the Work and assume any risks 138 | associated with Your exercise of permissions under this License. 139 | 140 | 8. Limitation of Liability. In no event and under no legal theory, whether in 141 | tort (including negligence), contract, or otherwise, unless required by 142 | applicable law (such as deliberate and grossly negligent acts) or agreed to in 143 | writing, shall any Contributor be liable to You for damages, including any 144 | direct, indirect, special, incidental, or consequential damages of any character 145 | arising as a result of this License or out of the use or inability to use the 146 | Work (including but not limited to damages for loss of goodwill, work stoppage, 147 | computer failure or malfunction, or any and all other commercial damages or 148 | losses), even if such Contributor has been advised of the possibility of such 149 | damages. 150 | 151 | 9. Accepting Warranty or Additional Liability. While redistributing the Work or 152 | Derivative Works thereof, You may choose to offer, and charge a fee for, 153 | acceptance of support, warranty, indemnity, or other liability obligations 154 | and/or rights consistent with this License. However, in accepting such 155 | obligations, You may act only on Your own behalf and on Your sole 156 | responsibility, not on behalf of any other Contributor, and only if You agree to 157 | indemnify, defend, and hold each Contributor harmless for any liability incurred 158 | by, or claims asserted against, such Contributor by reason of your accepting any 159 | such warranty or additional liability. 160 | 161 | END OF TERMS AND CONDITIONS 162 | 163 | APPENDIX: How to apply the Apache License to your work. 164 | 165 | To apply the Apache License to your work, attach the following boilerplate 166 | notice, with the fields enclosed by brackets "[]" replaced with your own 167 | identifying information. (Don't include the brackets!) The text should be 168 | enclosed in the appropriate comment syntax for the file format. We also 169 | recommend that a file or class name and description of purpose be included on 170 | the same "printed page" as the copyright notice for easier identification within 171 | third-party archives. 172 | 173 | Copyright [yyyy] [name of copyright owner] 174 | 175 | Licensed under the Apache License, Version 2.0 (the "License"); 176 | you may not use this file except in compliance with the License. 177 | You may obtain a copy of the License at 178 | 179 | http://www.apache.org/licenses/LICENSE-2.0 180 | 181 | Unless required by applicable law or agreed to in writing, software 182 | distributed under the License is distributed on an "AS IS" BASIS, 183 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 184 | See the License for the specific language governing permissions and 185 | limitations under the License. 186 | 187 | * For NLTK see also this required NOTICE: 188 | Copyright (C) 2001-2019 NLTK Project 189 | 190 | Licensed under the Apache License, Version 2.0 (the 'License'); 191 | you may not use this file except in compliance with the License. 192 | You may obtain a copy of the License at 193 | 194 | http://www.apache.org/licenses/LICENSE-2.0 195 | 196 | Unless required by applicable law or agreed to in writing, software 197 | distributed under the License is distributed on an 'AS IS' BASIS, 198 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
199 | See the License for the specific language governing permissions and 200 | limitations under the License. 201 | * For Hugging Face Transformers see also this required NOTICE: 202 | Copyright 2018- The Hugging Face team. All rights reserved. 203 | 204 | Apache License 205 | Version 2.0, January 2004 206 | http://www.apache.org/licenses/ 207 | 208 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 209 | 210 | 1. Definitions. 211 | 212 | "License" shall mean the terms and conditions for use, reproduction, 213 | and distribution as defined by Sections 1 through 9 of this document. 214 | 215 | "Licensor" shall mean the copyright owner or entity authorized by 216 | the copyright owner that is granting the License. 217 | 218 | "Legal Entity" shall mean the union of the acting entity and all 219 | other entities that control, are controlled by, or are under common 220 | control with that entity. For the purposes of this definition, 221 | "control" means (i) the power, direct or indirect, to cause the 222 | direction or management of such entity, whether by contract or 223 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 224 | outstanding shares, or (iii) beneficial ownership of such entity. 225 | 226 | "You" (or "Your") shall mean an individual or Legal Entity 227 | exercising permissions granted by this License. 228 | 229 | "Source" form shall mean the preferred form for making modifications, 230 | including but not limited to software source code, documentation 231 | source, and configuration files. 232 | 233 | "Object" form shall mean any form resulting from mechanical 234 | transformation or translation of a Source form, including but 235 | not limited to compiled object code, generated documentation, 236 | and conversions to other media types. 237 | 238 | "Work" shall mean the work of authorship, whether in Source or 239 | Object form, made available under the License, as indicated by a 240 | copyright notice that is included in or attached to the work 241 | (an example is provided in the Appendix below). 242 | 243 | "Derivative Works" shall mean any work, whether in Source or Object 244 | form, that is based on (or derived from) the Work and for which the 245 | editorial revisions, annotations, elaborations, or other modifications 246 | represent, as a whole, an original work of authorship. For the 247 | purposes 248 | of this License, Derivative Works shall not include works that remain 249 | separable from, or merely link (or bind by name) to the interfaces of, 250 | the Work and Derivative Works thereof. 251 | 252 | "Contribution" shall mean any work of authorship, including 253 | the original version of the Work and any modifications or additions 254 | to that Work or Derivative Works thereof, that is intentionally 255 | submitted to Licensor for inclusion in the Work by the copyright owner 256 | or by an individual or Legal Entity authorized to submit on behalf of 257 | the copyright owner. 
For the purposes of this definition, "submitted" 258 | means any form of electronic, verbal, or written communication sent 259 | to the Licensor or its representatives, including but not limited to 260 | communication on electronic mailing lists, source code control 261 | systems, 262 | and issue tracking systems that are managed by, or on behalf of, the 263 | Licensor for the purpose of discussing and improving the Work, but 264 | excluding communication that is conspicuously marked or otherwise 265 | designated in writing by the copyright owner as "Not a Contribution." 266 | 267 | "Contributor" shall mean Licensor and any individual or Legal Entity 268 | on behalf of whom a Contribution has been received by Licensor and 269 | subsequently incorporated within the Work. 270 | 271 | 2. Grant of Copyright License. Subject to the terms and conditions of 272 | this License, each Contributor hereby grants to You a perpetual, 273 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 274 | copyright license to reproduce, prepare Derivative Works of, 275 | publicly display, publicly perform, sublicense, and distribute the 276 | Work and such Derivative Works in Source or Object form. 277 | 278 | 3. Grant of Patent License. Subject to the terms and conditions of 279 | this License, each Contributor hereby grants to You a perpetual, 280 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 281 | (except as stated in this section) patent license to make, have made, 282 | use, offer to sell, sell, import, and otherwise transfer the Work, 283 | where such license applies only to those patent claims licensable 284 | by such Contributor that are necessarily infringed by their 285 | Contribution(s) alone or by combination of their Contribution(s) 286 | with the Work to which such Contribution(s) was submitted. If You 287 | institute patent litigation against any entity (including a 288 | cross-claim or counterclaim in a lawsuit) alleging that the Work 289 | or a Contribution incorporated within the Work constitutes direct 290 | or contributory patent infringement, then any patent licenses 291 | granted to You under this License for that Work shall terminate 292 | as of the date such litigation is filed. 293 | 294 | 4. Redistribution. 
You may reproduce and distribute copies of the 295 | Work or Derivative Works thereof in any medium, with or without 296 | modifications, and in Source or Object form, provided that You 297 | meet the following conditions: 298 | 299 | (a) You must give any other recipients of the Work or 300 | Derivative Works a copy of this License; and 301 | 302 | (b) You must cause any modified files to carry prominent notices 303 | stating that You changed the files; and 304 | 305 | (c) You must retain, in the Source form of any Derivative Works 306 | that You distribute, all copyright, patent, trademark, and 307 | attribution notices from the Source form of the Work, 308 | excluding those notices that do not pertain to any part of 309 | the Derivative Works; and 310 | 311 | (d) If the Work includes a "NOTICE" text file as part of its 312 | distribution, then any Derivative Works that You distribute must 313 | include a readable copy of the attribution notices contained 314 | within such NOTICE file, excluding those notices that do not 315 | pertain to any part of the Derivative Works, in at least one 316 | of the following places: within a NOTICE text file distributed 317 | as part of the Derivative Works; within the Source form or 318 | documentation, if provided along with the Derivative Works; or, 319 | within a display generated by the Derivative Works, if and 320 | wherever such third-party notices normally appear. The contents 321 | of the NOTICE file are for informational purposes only and 322 | do not modify the License. You may add Your own attribution 323 | notices within Derivative Works that You distribute, alongside 324 | or as an addendum to the NOTICE text from the Work, provided 325 | that such additional attribution notices cannot be construed 326 | as modifying the License. 327 | 328 | You may add Your own copyright statement to Your modifications and 329 | may provide additional or different license terms and conditions 330 | for use, reproduction, or distribution of Your modifications, or 331 | for any such Derivative Works as a whole, provided Your use, 332 | reproduction, and distribution of the Work otherwise complies with 333 | the conditions stated in this License. 334 | 335 | 5. Submission of Contributions. Unless You explicitly state otherwise, 336 | any Contribution intentionally submitted for inclusion in the Work 337 | by You to the Licensor shall be under the terms and conditions of 338 | this License, without any additional terms or conditions. 339 | Notwithstanding the above, nothing herein shall supersede or modify 340 | the terms of any separate license agreement you may have executed 341 | with Licensor regarding such Contributions. 342 | 343 | 6. Trademarks. This License does not grant permission to use the trade 344 | names, trademarks, service marks, or product names of the Licensor, 345 | except as required for reasonable and customary use in describing the 346 | origin of the Work and reproducing the content of the NOTICE file. 347 | 348 | 7. Disclaimer of Warranty. Unless required by applicable law or 349 | agreed to in writing, Licensor provides the Work (and each 350 | Contributor provides its Contributions) on an "AS IS" BASIS, 351 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 352 | implied, including, without limitation, any warranties or conditions 353 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 354 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 355 | appropriateness of using or redistributing the Work and assume any 356 | risks associated with Your exercise of permissions under this License. 357 | 358 | 8. Limitation of Liability. In no event and under no legal theory, 359 | whether in tort (including negligence), contract, or otherwise, 360 | unless required by applicable law (such as deliberate and grossly 361 | negligent acts) or agreed to in writing, shall any Contributor be 362 | liable to You for damages, including any direct, indirect, special, 363 | incidental, or consequential damages of any character arising as a 364 | result of this License or out of the use or inability to use the 365 | Work (including but not limited to damages for loss of goodwill, 366 | work stoppage, computer failure or malfunction, or any and all 367 | other commercial damages or losses), even if such Contributor 368 | has been advised of the possibility of such damages. 369 | 370 | 9. Accepting Warranty or Additional Liability. While redistributing 371 | the Work or Derivative Works thereof, You may choose to offer, 372 | and charge a fee for, acceptance of support, warranty, indemnity, 373 | or other liability obligations and/or rights consistent with this 374 | License. However, in accepting such obligations, You may act only 375 | on Your own behalf and on Your sole responsibility, not on behalf 376 | of any other Contributor, and only if You agree to indemnify, 377 | defend, and hold each Contributor harmless for any liability 378 | incurred by, or claims asserted against, such Contributor by reason 379 | of your accepting any such warranty or additional liability. 380 | 381 | END OF TERMS AND CONDITIONS 382 | 383 | APPENDIX: How to apply the Apache License to your work. 384 | 385 | To apply the Apache License to your work, attach the following 386 | boilerplate notice, with the fields enclosed by brackets "[]" 387 | replaced with your own identifying information. (Don't include 388 | the brackets!) The text should be enclosed in the appropriate 389 | comment syntax for the file format. We also recommend that a 390 | file or class name and description of purpose be included on the 391 | same "printed page" as the copyright notice for easier 392 | identification within third-party archives. 393 | 394 | Copyright [yyyy] [name of copyright owner] 395 | 396 | Licensed under the Apache License, Version 2.0 (the "License"); 397 | you may not use this file except in compliance with the License. 398 | You may obtain a copy of the License at 399 | 400 | http://www.apache.org/licenses/LICENSE-2.0 401 | 402 | Unless required by applicable law or agreed to in writing, software 403 | distributed under the License is distributed on an "AS IS" BASIS, 404 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 405 | See the License for the specific language governing permissions and 406 | limitations under the License. 407 | 408 | ------ 409 | 410 | ** factCC; version n/a -- https://github.com/salesforce/factCC 411 | Copyright (c) 2019, Salesforce.com, Inc. 412 | 413 | Copyright (c) 2019, Salesforce.com, Inc. 414 | All rights reserved. 415 | 416 | Redistribution and use in source and binary forms, with or without modification, 417 | are permitted provided that the following conditions are met: 418 | 419 | * Redistributions of source code must retain the above copyright notice, this 420 | list of conditions and the following disclaimer. 
421 | 422 | * Redistributions in binary form must reproduce the above copyright notice, this 423 | list of conditions and the following disclaimer in the documentation and/or 424 | other materials provided with the distribution. 425 | 426 | * Neither the name of Salesforce.com nor the names of its contributors may be 427 | used to endorse or promote products derived from this software without specific 428 | prior written permission. 429 | 430 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 431 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 432 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 433 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR 434 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 435 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 436 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 437 | ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 438 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 439 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 440 | 441 | ------ 442 | 443 | ** Pytorch 1.12; version 1.12 -- https://github.com/pytorch/pytorch 444 | From PyTorch: 445 | 446 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 447 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 448 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 449 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 450 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 451 | Copyright (c) 2011-2013 NYU (Clement Farabet) 452 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, 453 | Iain Melvin, Jason Weston) 454 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 455 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, 456 | Johnny Mariethoz) 457 | 458 | From Caffe2: 459 | 460 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 461 | 462 | All contributions by Facebook: 463 | Copyright (c) 2016 Facebook Inc. 464 | 465 | All contributions by Google: 466 | Copyright (c) 2015 Google Inc. 467 | All rights reserved. 468 | 469 | All contributions by Yangqing Jia: 470 | Copyright (c) 2015 Yangqing Jia 471 | All rights reserved. 472 | 473 | All contributions by Kakao Brain: 474 | Copyright 2019-2020 Kakao Brain 475 | 476 | All contributions by Cruise LLC: 477 | Copyright (c) 2022 Cruise LLC. 478 | All rights reserved. 479 | 480 | All contributions from Caffe: 481 | Copyright(c) 2013, 2014, 2015, the respective contributors 482 | All rights reserved. 483 | 484 | All other contributions: 485 | Copyright(c) 2015, 2016 the respective contributors 486 | All rights reserved. 487 | 488 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 489 | copyright over their contributions to Caffe2. The project versioning records 490 | all such contribution and copyright details. If a contributor wants to further 491 | mark their specific copyright on a particular contribution, they should 492 | indicate their copyright solely in the commit message of the change when it is 493 | committed. 494 | 495 | All rights reserved. 496 | 497 | Redistribution and use in source and binary forms, with or without 498 | modification, are permitted provided that the following conditions are met: 499 | 500 | 1. 
Redistributions of source code must retain the above copyright 501 | notice, this list of conditions and the following disclaimer. 502 | 503 | 2. Redistributions in binary form must reproduce the above copyright 504 | notice, this list of conditions and the following disclaimer in the 505 | documentation and/or other materials provided with the distribution. 506 | 507 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories 508 | America 509 | and IDIAP Research Institute nor the names of its contributors may be 510 | used to endorse or promote products derived from this software without 511 | specific prior written permission. 512 | 513 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 514 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 515 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 516 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 517 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 518 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 519 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 520 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 521 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 522 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 523 | POSSIBILITY OF SUCH DAMAGE. 524 | 525 | From PyTorch: 526 | 527 | Copyright (c) 2016- Facebook, Inc (Adam Paszke) 528 | Copyright (c) 2014- Facebook, Inc (Soumith Chintala) 529 | Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) 530 | Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) 531 | Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) 532 | Copyright (c) 2011-2013 NYU (Clement Farabet) 533 | Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, 534 | Iain Melvin, Jason Weston) 535 | Copyright (c) 2006 Idiap Research Institute (Samy Bengio) 536 | Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, 537 | Johnny Mariethoz) 538 | 539 | From Caffe2: 540 | 541 | Copyright (c) 2016-present, Facebook Inc. All rights reserved. 542 | 543 | All contributions by Facebook: 544 | Copyright (c) 2016 Facebook Inc. 545 | 546 | All contributions by Google: 547 | Copyright (c) 2015 Google Inc. 548 | All rights reserved. 549 | 550 | All contributions by Yangqing Jia: 551 | Copyright (c) 2015 Yangqing Jia 552 | All rights reserved. 553 | 554 | All contributions by Kakao Brain: 555 | Copyright 2019-2020 Kakao Brain 556 | 557 | All contributions by Cruise LLC: 558 | Copyright (c) 2022 Cruise LLC. 559 | All rights reserved. 560 | 561 | All contributions from Caffe: 562 | Copyright(c) 2013, 2014, 2015, the respective contributors 563 | All rights reserved. 564 | 565 | All other contributions: 566 | Copyright(c) 2015, 2016 the respective contributors 567 | All rights reserved. 568 | 569 | Caffe2 uses a copyright model similar to Caffe: each contributor holds 570 | copyright over their contributions to Caffe2. The project versioning records 571 | all such contribution and copyright details. If a contributor wants to further 572 | mark their specific copyright on a particular contribution, they should 573 | indicate their copyright solely in the commit message of the change when it is 574 | committed. 575 | 576 | All rights reserved. 
577 | 578 | Redistribution and use in source and binary forms, with or without 579 | modification, are permitted provided that the following conditions are met: 580 | 581 | 1. Redistributions of source code must retain the above copyright 582 | notice, this list of conditions and the following disclaimer. 583 | 584 | 2. Redistributions in binary form must reproduce the above copyright 585 | notice, this list of conditions and the following disclaimer in the 586 | documentation and/or other materials provided with the distribution. 587 | 588 | 3. Neither the names of Facebook, Deepmind Technologies, NYU, NEC Laboratories 589 | America 590 | and IDIAP Research Institute nor the names of its contributors may be 591 | used to endorse or promote products derived from this software without 592 | specific prior written permission. 593 | 594 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 595 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 596 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE 597 | ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE 598 | LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR 599 | CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 600 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 601 | INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN 602 | CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) 603 | ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE 604 | POSSIBILITY OF SUCH DAMAGE. 605 | 606 | ------ 607 | 608 | ** BertScore; version 0.3.13 -- https://github.com/Tiiiger/bert_score 609 | Copyright (c) 2019 Tianyi Zhang, Varsha Kishore, Felix Wu, Kilian Q. Weinberger, 610 | and Yoav Artzi. 611 | 612 | MIT License 613 | 614 | Copyright (c) 2019 Tianyi Zhang, Varsha Kishore, Felix Wu, Kilian Q. Weinberger, 615 | and Yoav Artzi. 616 | 617 | Permission is hereby granted, free of charge, to any person obtaining a copy 618 | of this software and associated documentation files (the "Software"), to deal 619 | in the Software without restriction, including without limitation the rights 620 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 621 | copies of the Software, and to permit persons to whom the Software is 622 | furnished to do so, subject to the following conditions: 623 | 624 | The above copyright notice and this permission notice shall be included in all 625 | copies or substantial portions of the Software. 626 | 627 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 628 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 629 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 630 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 631 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 632 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 633 | SOFTWARE. 
634 |
--------------------------------------------------------------------------------
/lookahead/lookahead.py:
--------------------------------------------------------------------------------
1 | import inspect
2 |
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | import copy
13 |
14 | from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
15 | from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
16 | from transformers.generation_logits_process import (
17 |     EncoderNoRepeatNGramLogitsProcessor,
18 |     ExponentialDecayLengthPenalty,
19 |     ForcedBOSTokenLogitsProcessor,
20 |     ForcedEOSTokenLogitsProcessor,
21 |     HammingDiversityLogitsProcessor,
22 |     InfNanRemoveLogitsProcessor,
23 |     LogitNormalization,
24 |     LogitsProcessorList,
25 |     MinLengthLogitsProcessor,
26 |     NoBadWordsLogitsProcessor,
27 |     NoRepeatNGramLogitsProcessor,
28 |     PrefixConstrainedLogitsProcessor,
29 |     RepetitionPenaltyLogitsProcessor,
30 |     TemperatureLogitsWarper,
31 |     TopKLogitsWarper,
32 |     TopPLogitsWarper,
33 |     TypicalLogitsWarper,
34 | )
35 | from transformers.generation_stopping_criteria import (
36 |     MaxLengthCriteria,
37 |     MaxTimeCriteria,
38 |     StoppingCriteria,
39 |     StoppingCriteriaList,
40 |     validate_stopping_criteria,
41 | )
42 | from transformers.pytorch_utils import torch_int_div
43 | from transformers.utils import ModelOutput, logging
44 |
45 | from transformers.generation_utils import (
46 |     GreedySearchEncoderDecoderOutput,
47 |     GreedySearchDecoderOnlyOutput,
48 |     BeamSearchEncoderDecoderOutput,
49 |     BeamSearchDecoderOnlyOutput,
50 |     SampleEncoderDecoderOutput,
51 |     SampleDecoderOnlyOutput,
52 | )
53 |
54 | logger = logging.get_logger(__name__)
55 |
56 | class Lookahead:
57 |     """
58 |     Object that performs the lookahead. It closely mirrors GenerationMixin, since it also needs to decode the sequence,
59 |     but it additionally computes the heuristic score for each candidate continuation.
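    Example (a minimal sketch; the model and scorer settings below mirror run_generate.py
    and are illustrative, not the only valid configuration):

        model = AutoModelForSeq2SeqLM.from_pretrained("facebook/bart-large-xsum").cuda()
        tokenizer = AutoTokenizer.from_pretrained("facebook/bart-large-xsum")
        scorer = BERTScoreScorer(model_name="roberta-large", num_layers=17)
        lookahead = Lookahead(model, tokenizer, scorer, lookahead_length=64,
                              lookahead_lambda=25.0, lookahead_top_k=5,
                              decoding_type="greedy", max_length=64)
        # Generator (generation.py) then calls lookahead.score(...) during decoding;
        # see run_generate.py for the end-to-end wiring.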
60 | """ 61 | 62 | def __init__( 63 | self, 64 | model, 65 | tokenizer, 66 | scorer, 67 | lookahead_length=1, 68 | lookahead_lambda=1.0, 69 | lookahead_top_k=5, 70 | decoding_type="greedy", 71 | max_length: Optional[int] = None, 72 | min_length: Optional[int] = None, 73 | do_sample: Optional[bool] = None, 74 | early_stopping: Optional[bool] = None, 75 | num_beams: Optional[int] = None, 76 | temperature: Optional[float] = None, 77 | top_k: Optional[int] = None, 78 | top_p: Optional[float] = None, 79 | typical_p: Optional[float] = None, 80 | repetition_penalty: Optional[float] = None, 81 | bad_words_ids: Optional[Iterable[int]] = None, 82 | force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, 83 | bos_token_id: Optional[int] = None, 84 | pad_token_id: Optional[int] = None, 85 | eos_token_id: Optional[int] = None, 86 | length_penalty: Optional[float] = None, 87 | no_repeat_ngram_size: Optional[int] = None, 88 | encoder_no_repeat_ngram_size: Optional[int] = None, 89 | num_return_sequences: Optional[int] = None, 90 | max_time: Optional[float] = None, 91 | max_new_tokens: Optional[int] = None, 92 | decoder_start_token_id: Optional[int] = None, 93 | use_cache: Optional[bool] = None, 94 | num_beam_groups: Optional[int] = None, 95 | diversity_penalty: Optional[float] = None, 96 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 97 | logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), 98 | renormalize_logits: Optional[bool] = None, 99 | stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), 100 | constraints: Optional[List[Constraint]] = None, 101 | output_attentions: Optional[bool] = None, 102 | output_hidden_states: Optional[bool] = None, 103 | output_scores: Optional[bool] = None, 104 | return_dict_in_generate: Optional[bool] = None, 105 | forced_bos_token_id: Optional[int] = None, 106 | forced_eos_token_id: Optional[int] = None, 107 | remove_invalid_values: Optional[bool] = None, 108 | synced_gpus: Optional[bool] = False, 109 | exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, 110 | ): 111 | """ 112 | model: The Huggingface Model 113 | tokenizer: The tokenizer for decoding the summaries 114 | scorer: Scorer object that calculates the score given document and summary 115 | lookahead_length: The number of tokens to look ahead 116 | lookahead_lambda: The weight for the score 117 | lookahead_top_k: The number of top tokens to consider for expansion 118 | decoding_type: The decoding type for lookahead. 
One of: greedy, beam, or sample.
119 |
120 |         Other parameters mirror the arguments expected by GenerationMixin to control generation.
121 |         """
122 |         self.model = model
123 |         self.tokenizer = tokenizer
124 |         self.scorer = scorer
125 |
126 |         if lookahead_length == -1:
127 |             assert max_length is not None
128 |             self.lookahead_length = max_length
129 |             self.lookahead_until_sent = True
130 |         else:
131 |             self.lookahead_length = lookahead_length
132 |             self.lookahead_until_sent = False
133 |
134 |         self.lookahead_lambda = lookahead_lambda
135 |         self.lookahead_top_k = lookahead_top_k
136 |         self.decoding_type = decoding_type
137 |
138 |         if self.decoding_type == "greedy":
139 |             self.decoding_func = self.greedy_search
140 |         elif self.decoding_type == "beam":
141 |             self.decoding_func = self.beam_search
142 |         elif self.decoding_type == "sample":
143 |             self.decoding_func = self.sample
144 |
145 |         # generation parameters from generate()
146 |         self.bos_token_id = self.model.config.bos_token_id
147 |         self.num_beams = num_beams if num_beams is not None else self.model.config.num_beams
148 |         self.length_penalty = length_penalty if length_penalty is not None else self.model.config.length_penalty
149 |         self.early_stopping = early_stopping if early_stopping is not None else self.model.config.early_stopping
150 |         self.num_beam_groups = num_beam_groups if num_beam_groups is not None else self.model.config.num_beam_groups
151 |         self.num_return_sequences = (
152 |             num_return_sequences if num_return_sequences is not None else self.model.config.num_return_sequences
153 |         )
154 |
155 |         self.pad_token_id = self.model.config.pad_token_id
156 |         self.eos_token_id = self.model.config.eos_token_id
157 |
158 |         if self.eos_token_id is None and hasattr(self.model.config, "decoder"):
159 |             self.eos_token_id = self.model.config.decoder.eos_token_id
160 |
161 |         if self.pad_token_id is None and self.eos_token_id is not None:
162 |             # special case if pad_token_id is not defined
163 |             logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.")
164 |             self.pad_token_id = self.eos_token_id
165 |         self.max_length = max_length
166 |         self.min_length = min_length
167 |         self.temperature = temperature
168 |         self.top_k = top_k
169 |         self.top_p = top_p
170 |         self.typical_p = typical_p
171 |         self.repetition_penalty = repetition_penalty
172 |         self.bad_words_ids = bad_words_ids
173 |         self.force_words_ids = force_words_ids
174 |         self.no_repeat_ngram_size = no_repeat_ngram_size
175 |         self.encoder_no_repeat_ngram_size = encoder_no_repeat_ngram_size
176 |         self.max_new_tokens = max_new_tokens
177 |         self.decoder_start_token_id = decoder_start_token_id
178 |         self.use_cache = use_cache
179 |         self.diversity_penalty = diversity_penalty
180 |         self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
181 |         self.renormalize_logits = renormalize_logits
182 |         self.constraints = constraints
183 |         self.forced_bos_token_id = forced_bos_token_id
184 |         self.forced_eos_token_id = forced_eos_token_id
185 |         self.remove_invalid_values = remove_invalid_values
186 |         self.exponential_decay_length_penalty = exponential_decay_length_penalty
187 |         self.synced_gpus = synced_gpus
188 |
189 |         # self.return_dict_in_generate = return_dict_in_generate
190 |         self.return_dict_in_generate = True
191 |         self.output_attentions = output_attentions
192 |         self.output_hidden_states = output_hidden_states
193 |         self.output_scores = output_scores
194 |
195 |         # If not provided, the logits processor will be prepared later, since it requires the input tensor
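        # Note on how the heuristic enters decoding (illustrative numbers, not from
        # a real run): score() below returns lookahead_lambda * log h(candidate),
        # which the caller combines with the model's own log-probabilities, e.g. with
        # lookahead_lambda = 25 and a BS-Fact score h = 0.9 for one token's rollout:
        #     combined ≈ log p(token | prefix) + 25 * log(0.9) = -1.20 - 2.63 = -3.83
        # so tokens whose rollouts look unfaithful are penalized more heavily.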
196 |         self.logits_processor = logits_processor
197 |
198 |         # prepare stopping criteria
199 |         self.stopping_criteria = self.model._get_stopping_criteria(
200 |             max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
201 |         )
202 |
203 |         self.logits_warper = self.model._get_logits_warper(
204 |             top_k=self.top_k,
205 |             top_p=self.top_p,
206 |             typical_p=self.typical_p,
207 |             temperature=self.temperature,
208 |             num_beams=self.num_beams,
209 |             renormalize_logits=self.renormalize_logits,
210 |         )
211 |
212 |
213 |     def score(
214 |         self,
215 |         input_ids,
216 |         next_token_scores,
217 |         num_beams=1,
218 |         **model_kwargs,
219 |     ):
220 |         """
221 |         Main function to call for the lookahead. It generates the continuation sequences and returns the calculated heuristic scores.
222 |         """
223 |
224 |         # prepare for generation
225 |         if self.logits_processor is None:
226 |             input_ids_seq_length = input_ids.size(1)
227 |             inputs_tensor = model_kwargs["encoder_outputs"][self.model.main_input_name]
228 |
229 |             self.logits_processor = self.model._get_logits_processor(
230 |                 repetition_penalty=self.repetition_penalty,
231 |                 no_repeat_ngram_size=self.no_repeat_ngram_size,
232 |                 encoder_no_repeat_ngram_size=self.encoder_no_repeat_ngram_size,
233 |                 input_ids_seq_length=input_ids_seq_length,
234 |                 encoder_input_ids=inputs_tensor,
235 |                 bad_words_ids=self.bad_words_ids,
236 |                 min_length=self.min_length,
237 |                 max_length=self.max_length,
238 |                 eos_token_id=self.eos_token_id,
239 |                 forced_bos_token_id=self.forced_bos_token_id,
240 |                 forced_eos_token_id=self.forced_eos_token_id,
241 |                 prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
242 |                 num_beams=self.num_beams,
243 |                 num_beam_groups=self.num_beam_groups,
244 |                 diversity_penalty=self.diversity_penalty,
245 |                 remove_invalid_values=self.remove_invalid_values,
246 |                 exponential_decay_length_penalty=self.exponential_decay_length_penalty,
247 |                 logits_processor=self.logits_processor,
248 |                 renormalize_logits=self.renormalize_logits,
249 |             )
250 |
251 |         do_sample = "sample" in self.decoding_type
252 |         use_beam = "beam" in self.decoding_type
253 |         beam_scorer = None
254 |
255 |         if use_beam:
256 |             batch_size = input_ids.shape[0] * self.lookahead_top_k
257 |             beam_scorer = BeamSearchScorer(
258 |                 batch_size=batch_size,
259 |                 num_beams=self.num_beams,
260 |                 max_length=self.stopping_criteria.max_length,
261 |                 device=input_ids.device,
262 |                 length_penalty=self.length_penalty,
263 |                 do_early_stopping=self.early_stopping,
264 |                 num_beam_hyps_to_keep=self.num_return_sequences,
265 |                 num_beam_groups=self.num_beam_groups,
266 |             )
267 |
268 |         indices = torch.arange(input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)
269 |
270 |         # expand for the top k tokens to use with the scorer
271 |         _, top_k_indices = torch.topk(next_token_scores, k=self.lookahead_top_k, dim=-1)
272 |         top_k_indices = top_k_indices.reshape(-1)
273 |
274 |         indices = indices.repeat_interleave(self.lookahead_top_k)
275 |         input_ids = torch.cat([input_ids[indices], top_k_indices.unsqueeze(1)], dim=1)
276 |
277 |         # adjust model_kwargs
278 |         model_kwargs = self.expand_model_kwargs(model_kwargs, indices)
279 |
280 |         # expand if necessary for beam; currently ignoring sampling with multiple num sequences
281 |         if use_beam:
282 |             input_ids, model_kwargs = self.model._expand_inputs_for_generation(
283 |                 input_ids,
284 |                 expand_size=self.num_beams,
285 |                 is_encoder_decoder=self.model.config.is_encoder_decoder,
286 |                 **model_kwargs,
287 |             )
288 |             indices = indices.repeat_interleave(self.num_beams)
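        # Shape walk-through with illustrative sizes: for 2 rows entering score()
        # and lookahead_top_k = 3, top_k_indices has shape (6,) and
        # indices == tensor([0, 0, 0, 1, 1, 1]) after repeat_interleave, so
        # input_ids grows from (2, cur_len) to (6, cur_len + 1) -- one row per
        # candidate token. With the "beam" lookahead type and num_beams = 2,
        # each candidate row is duplicated once more, giving 12 rows.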
289 |         # expand inputs for generation; note `_expand_inputs_for_generation` does not expand `past`, so that is handled here
290 |         if "past" in model_kwargs:
291 |             model_kwargs["past"] = tuple([tuple([p.repeat_interleave(self.num_beams, dim=0) for p in past]) for past in model_kwargs["past"]])
292 |
293 |         # calling the respective decoding function
294 |         # the only difference between this implementation and the original is the addition of the lookahead length and breaking once it is reached
295 |         if self.lookahead_length == 0:
296 |             seq = input_ids
297 |         else:
298 |             dec_out = self.decoding_func(input_ids, beam_scorer, **model_kwargs)
299 |             seq = dec_out["sequences"]
300 |
301 |         # generate the actual summary
302 |         dec_seq = self.tokenizer.batch_decode(seq, skip_special_tokens=True)
303 |
304 |         # calculate the score given the heuristics; need to account for different indices when doing beam search
305 |         _lookahead_scores = self.scorer.score(dec_seq, torch.div(indices, num_beams, rounding_mode="trunc"))
306 |         _lookahead_scores = torch.clamp(_lookahead_scores, min=1e-9).log()
307 |
308 |         _lookahead_scores = _lookahead_scores.view(-1, self.lookahead_top_k, self.num_beams)
309 |         _lookahead_scores, _ = _lookahead_scores.max(-1)
310 |
311 |         lookahead_scores = torch.ones_like(next_token_scores, dtype=_lookahead_scores.dtype, device=next_token_scores.device) * 1e-9
312 |         lookahead_scores = lookahead_scores.log()
313 |
314 |         next_token_scores = F.log_softmax(next_token_scores, dim=-1)
315 |
316 |         if use_beam:
317 |             # undo the repeat_interleave over beams
318 |             indices = indices.view(-1, self.num_beams)[:, 0]
319 |
320 |         lookahead_scores[indices, top_k_indices] = _lookahead_scores.view(-1)
321 |
322 |         return self.lookahead_lambda * lookahead_scores
323 |
324 |     def greedy_search(
325 |         self,
326 |         input_ids: torch.LongTensor,
327 |         beam_scorer = None,
328 |         **model_kwargs,
329 |     ):
330 |         # init attention / hidden states / scores tuples
331 |         scores = () if (self.return_dict_in_generate and self.output_scores) else None
332 |         decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
333 |         cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
334 |         decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None
335 |
336 |         # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
337 |         if self.return_dict_in_generate and self.model.config.is_encoder_decoder:
338 |             encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None
339 |             encoder_hidden_states = (
340 |                 model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None
341 |             )
342 |
343 |         # keep track of which sequences are already finished
344 |         unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
345 |         cur_len = input_ids.shape[-1]
346 |
347 |         lookahead_length = self.lookahead_length + cur_len
348 |
349 |         this_peer_finished = False  # used by synced_gpus only
350 |         while True:
351 |
352 |             if self.synced_gpus:
353 |                 # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
354 |                 # The following logic allows an early break if all peers finished generating their sequence
355 |                 this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
356 |                 # send 0.0 if we finished, 1.0 otherwise
357 |                 dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
358 |                 # did all peers finish?
the reduced sum will be 0.0 then 359 | if this_peer_finished_flag.item() == 0.0: 360 | break 361 | 362 | # prepare model inputs 363 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 364 | 365 | # forward pass to get next token 366 | outputs = self.model( 367 | **model_inputs, 368 | return_dict=True, 369 | output_attentions=self.output_attentions, 370 | output_hidden_states=self.output_hidden_states, 371 | ) 372 | 373 | if self.synced_gpus and this_peer_finished: 374 | cur_len = cur_len + 1 375 | continue # don't waste resources running the code we don't need 376 | 377 | next_token_logits = outputs.logits[:, -1, :] 378 | 379 | # Store scores, attentions and hidden_states when required 380 | if self.return_dict_in_generate: 381 | if self.output_scores: 382 | scores += (next_token_logits,) 383 | if self.output_attentions: 384 | decoder_attentions += ( 385 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 386 | ) 387 | if self.model.config.is_encoder_decoder: 388 | cross_attentions += (outputs.cross_attentions,) 389 | 390 | if self.output_hidden_states: 391 | decoder_hidden_states += ( 392 | (outputs.decoder_hidden_states,) 393 | if self.model.config.is_encoder_decoder 394 | else (outputs.hidden_states,) 395 | ) 396 | 397 | # pre-process distribution 398 | next_tokens_scores = self.logits_processor(input_ids, next_token_logits) 399 | 400 | # argmax 401 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 402 | 403 | # finished sentences should have their next token be a padding token 404 | if self.eos_token_id is not None: 405 | if self.pad_token_id is None: 406 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 407 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences) 408 | 409 | # update generated ids, model inputs, and length for next step 410 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 411 | model_kwargs = self.model._update_model_kwargs_for_generation( 412 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 413 | ) 414 | cur_len = cur_len + 1 415 | 416 | # Lookahead break 417 | if cur_len >= lookahead_length: 418 | break 419 | 420 | # if eos_token was found in one sentence, set sentence to finished 421 | if self.eos_token_id is not None: 422 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long()) 423 | 424 | # stop when each sentence is finished, or if we exceed the maximum length 425 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores): 426 | if not self.synced_gpus: 427 | break 428 | else: 429 | this_peer_finished = True 430 | 431 | if self.return_dict_in_generate: 432 | if self.model.config.is_encoder_decoder: 433 | return GreedySearchEncoderDecoderOutput( 434 | sequences=input_ids, 435 | scores=scores, 436 | encoder_attentions=encoder_attentions, 437 | encoder_hidden_states=encoder_hidden_states, 438 | decoder_attentions=decoder_attentions, 439 | cross_attentions=cross_attentions, 440 | decoder_hidden_states=decoder_hidden_states, 441 | ) 442 | else: 443 | return GreedySearchDecoderOnlyOutput( 444 | sequences=input_ids, 445 | scores=scores, 446 | attentions=decoder_attentions, 447 | hidden_states=decoder_hidden_states, 448 | ) 449 | else: 450 | return input_ids 451 | 452 | def beam_search( 453 | self, 454 | input_ids: torch.LongTensor, 455 | beam_scorer = None, 456 | **model_kwargs, 457 | ): 458 
| batch_size = len(beam_scorer._beam_hyps) 459 | num_beams = beam_scorer.num_beams 460 | 461 | batch_beam_size, cur_len = input_ids.shape 462 | 463 | lookahead_length = self.lookahead_length + cur_len 464 | 465 | if num_beams * batch_size != batch_beam_size: 466 | raise ValueError( 467 | f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 468 | ) 469 | 470 | # init attention / hidden states / scores tuples 471 | scores = () if (self.return_dict_in_generate and self.output_scores) else None 472 | beam_indices = ( 473 | tuple(() for _ in range(batch_beam_size)) if (self.return_dict_in_generate and self.output_scores) else None 474 | ) 475 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 476 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 477 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None 478 | 479 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 480 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder: 481 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None 482 | encoder_hidden_states = ( 483 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None 484 | ) 485 | 486 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) 487 | beam_scores[:, 1:] = -1e9 488 | beam_scores = beam_scores.view((batch_size * num_beams,)) 489 | 490 | this_peer_finished = False # used by synced_gpus only 491 | while True: 492 | 493 | if self.synced_gpus: 494 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 495 | # The following logic allows an early break if all peers finished generating their sequence 496 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 497 | # send 0.0 if we finished, 1.0 otherwise 498 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 499 | # did all peers finish? the reduced sum will be 0.0 then 500 | if this_peer_finished_flag.item() == 0.0: 501 | break 502 | 503 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 504 | 505 | outputs = self.model( 506 | **model_inputs, 507 | return_dict=True, 508 | output_attentions=self.output_attentions, 509 | output_hidden_states=self.output_hidden_states, 510 | ) 511 | 512 | if self.synced_gpus and this_peer_finished: 513 | cur_len = cur_len + 1 514 | continue # don't waste resources running the code we don't need 515 | 516 | next_token_logits = outputs.logits[:, -1, :] 517 | # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` 518 | # cannot be generated both before and after the `nn.functional.log_softmax` operation. 
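            # Scoring walk-through (illustrative sizes): with batch_size = 2,
            # num_beams = 3 and vocab_size = 50265 (BART's vocabulary), the scores
            # computed below are flattened to (2, 3 * 50265) so that topk can pick
            # 2 * num_beams = 6 candidates per example across all beams;
            # next_indices = candidate // vocab_size recovers the source beam and
            # next_tokens = candidate % vocab_size recovers the token id.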
519 | next_token_logits = self.model.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) 520 | next_token_scores = nn.functional.log_softmax( 521 | next_token_logits, dim=-1 522 | ) # (batch_size * num_beams, vocab_size) 523 | 524 | next_token_scores_processed = self.logits_processor(input_ids, next_token_scores) 525 | next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) 526 | 527 | # Store scores, attentions and hidden_states when required 528 | if self.return_dict_in_generate: 529 | if self.output_scores: 530 | scores += (next_token_scores_processed,) 531 | if self.output_attentions: 532 | decoder_attentions += ( 533 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 534 | ) 535 | if self.model.config.is_encoder_decoder: 536 | cross_attentions += (outputs.cross_attentions,) 537 | 538 | if self.output_hidden_states: 539 | decoder_hidden_states += ( 540 | (outputs.decoder_hidden_states,) 541 | if self.model.config.is_encoder_decoder 542 | else (outputs.hidden_states,) 543 | ) 544 | 545 | # reshape for beam search 546 | vocab_size = next_token_scores.shape[-1] 547 | next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) 548 | 549 | next_token_scores, next_tokens = torch.topk( 550 | next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True 551 | ) 552 | 553 | next_indices = torch_int_div(next_tokens, vocab_size) 554 | next_tokens = next_tokens % vocab_size 555 | 556 | # stateless 557 | beam_outputs = beam_scorer.process( 558 | input_ids, 559 | next_token_scores, 560 | next_tokens, 561 | next_indices, 562 | pad_token_id=self.pad_token_id, 563 | eos_token_id=self.eos_token_id, 564 | ) 565 | 566 | beam_scores = beam_outputs["next_beam_scores"] 567 | beam_next_tokens = beam_outputs["next_beam_tokens"] 568 | beam_idx = beam_outputs["next_beam_indices"] 569 | 570 | input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) 571 | 572 | model_kwargs = self.model._update_model_kwargs_for_generation( 573 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 574 | ) 575 | if model_kwargs["past"] is not None: 576 | model_kwargs["past"] = self.model._reorder_cache(model_kwargs["past"], beam_idx) 577 | 578 | if self.return_dict_in_generate and self.output_scores: 579 | beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) 580 | 581 | # increase cur_len 582 | cur_len = cur_len + 1 583 | 584 | if cur_len >= lookahead_length: 585 | break 586 | 587 | if beam_scorer.is_done or self.stopping_criteria(input_ids, scores): 588 | if not self.synced_gpus: 589 | break 590 | else: 591 | this_peer_finished = True 592 | 593 | sequence_outputs = beam_scorer.finalize( 594 | input_ids, 595 | beam_scores, 596 | next_tokens, 597 | next_indices, 598 | pad_token_id=self.pad_token_id, 599 | eos_token_id=self.eos_token_id, 600 | max_length=self.stopping_criteria.max_length, 601 | ) 602 | 603 | if self.return_dict_in_generate: 604 | if not self.output_scores: 605 | sequence_outputs["sequence_scores"] = None 606 | else: 607 | num_return_sequences = beam_scorer.num_beam_hyps_to_keep 608 | # return only as many indices as sequences 609 | beam_indices = tuple( 610 | (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) 611 | ) 612 | beam_indices = sum(beam_indices, ()) 613 | 614 | if self.model.config.is_encoder_decoder: 615 | return 
BeamSearchEncoderDecoderOutput( 616 | sequences=sequence_outputs["sequences"], 617 | sequences_scores=sequence_outputs["sequence_scores"], 618 | scores=scores, 619 | beam_indices=beam_indices, 620 | encoder_attentions=encoder_attentions, 621 | encoder_hidden_states=encoder_hidden_states, 622 | decoder_attentions=decoder_attentions, 623 | cross_attentions=cross_attentions, 624 | decoder_hidden_states=decoder_hidden_states, 625 | ) 626 | else: 627 | return BeamSearchDecoderOnlyOutput( 628 | sequences=sequence_outputs["sequences"], 629 | sequences_scores=sequence_outputs["sequence_scores"], 630 | scores=scores, 631 | beam_indices=beam_indices, 632 | attentions=decoder_attentions, 633 | hidden_states=decoder_hidden_states, 634 | ) 635 | else: 636 | return sequence_outputs["sequences"] 637 | 638 | def sample( 639 | self, 640 | input_ids: torch.LongTensor, 641 | beam_scorer = None, 642 | **model_kwargs, 643 | ): 644 | scores = () if (self.return_dict_in_generate and self.output_scores) else None 645 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 646 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 647 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None 648 | 649 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 650 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder: 651 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None 652 | encoder_hidden_states = ( 653 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None 654 | ) 655 | 656 | # keep track of which sequences are already finished 657 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 658 | cur_len = input_ids.shape[-1] 659 | 660 | lookahead_length = self.lookahead_length + cur_len 661 | 662 | this_peer_finished = False # used by synced_gpus only 663 | # auto-regressive generation 664 | while True: 665 | 666 | if self.synced_gpus: 667 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 668 | # The following logic allows an early break if all peers finished generating their sequence 669 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 670 | # send 0.0 if we finished, 1.0 otherwise 671 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 672 | # did all peers finish? 
the reduced sum will be 0.0 then 673 | if this_peer_finished_flag.item() == 0.0: 674 | break 675 | 676 | # prepare model inputs 677 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 678 | 679 | # forward pass to get next token 680 | outputs = self.model( 681 | **model_inputs, 682 | return_dict=True, 683 | output_attentions=self.output_attentions, 684 | output_hidden_states=self.output_hidden_states, 685 | ) 686 | 687 | if self.synced_gpus and this_peer_finished: 688 | cur_len = cur_len + 1 689 | continue # don't waste resources running the code we don't need 690 | 691 | next_token_logits = outputs.logits[:, -1, :] 692 | 693 | # pre-process distribution 694 | next_token_scores = self.logits_processor(input_ids, next_token_logits) 695 | next_token_scores = self.logits_warper(input_ids, next_token_scores) 696 | 697 | # Store scores, attentions and hidden_states when required 698 | if self.return_dict_in_generate: 699 | if self.output_scores: 700 | scores += (next_token_scores,) 701 | if self.output_attentions: 702 | decoder_attentions += ( 703 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 704 | ) 705 | if self.model.config.is_encoder_decoder: 706 | cross_attentions += (outputs.cross_attentions,) 707 | 708 | if self.output_hidden_states: 709 | decoder_hidden_states += ( 710 | (outputs.decoder_hidden_states,) 711 | if self.model.config.is_encoder_decoder 712 | else (outputs.hidden_states,) 713 | ) 714 | 715 | # sample 716 | probs = nn.functional.softmax(next_token_scores, dim=-1) 717 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 718 | 719 | # finished sentences should have their next token be a padding token 720 | if self.eos_token_id is not None: 721 | if self.pad_token_id is None: 722 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 723 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences) 724 | 725 | # update generated ids, model inputs, and length for next step 726 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 727 | model_kwargs = self.model._update_model_kwargs_for_generation( 728 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 729 | ) 730 | cur_len = cur_len + 1 731 | 732 | if cur_len >= lookahead_length: 733 | break 734 | 735 | # if eos_token was found in one sentence, set sentence to finished 736 | if self.eos_token_id is not None: 737 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long()) 738 | 739 | # stop when each sentence is finished, or if we exceed the maximum length 740 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores): 741 | if not self.synced_gpus: 742 | break 743 | else: 744 | this_peer_finished = True 745 | 746 | if self.return_dict_in_generate: 747 | if self.model.config.is_encoder_decoder: 748 | return SampleEncoderDecoderOutput( 749 | sequences=input_ids, 750 | scores=scores, 751 | encoder_attentions=encoder_attentions, 752 | encoder_hidden_states=encoder_hidden_states, 753 | decoder_attentions=decoder_attentions, 754 | cross_attentions=cross_attentions, 755 | decoder_hidden_states=decoder_hidden_states, 756 | ) 757 | else: 758 | return SampleDecoderOnlyOutput( 759 | sequences=input_ids, 760 | scores=scores, 761 | attentions=decoder_attentions, 762 | hidden_states=decoder_hidden_states, 763 | ) 764 | else: 765 | return input_ids 766 | 767 | 768 | def 
expand_model_kwargs(self, model_kwargs, indices):
769 |         model_kwargs = copy.deepcopy(model_kwargs)
770 |         if "attention_mask" in model_kwargs:
771 |             model_kwargs["attention_mask"] = model_kwargs["attention_mask"][indices]
772 |         if "encoder_outputs" in model_kwargs:
773 |             for k, v in model_kwargs["encoder_outputs"].items():
774 |                 if v is not None:
775 |                     model_kwargs["encoder_outputs"][k] = v[indices]
776 |         if "past" in model_kwargs:
777 |             model_kwargs["past"] = tuple([tuple([p[indices] for p in past]) for past in model_kwargs["past"]])
778 |         return model_kwargs
--------------------------------------------------------------------------------
/lookahead/run_generate.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2 | from scorer import BERTScoreScorer
3 | from lookahead import Lookahead
4 | from generation import Generator
5 | from tqdm import tqdm
6 |
7 | import json
8 |
9 | import argparse
10 |
11 | parser = argparse.ArgumentParser()
12 |
13 | # base decoding model
14 | parser.add_argument("--model_name", type=str, default="facebook/bart-large-xsum")
15 | parser.add_argument("--cache_dir", type=str, default="./cache")
16 |
17 | # input output
18 | parser.add_argument("--document_file", type=str, required=True)
19 | parser.add_argument("--output_file", type=str, required=True)
20 |
21 | # base decoding configuration. Please refer to Huggingface's GenerationMixin for an explanation of the parameters
22 | parser.add_argument("--batch_size", type=int, default=16)
23 | parser.add_argument("--num_beams", type=int, default=1)
24 | parser.add_argument("--num_return_sequences", type=int, default=1)
25 | parser.add_argument("--max_input_length", type=int, default=512)
26 | parser.add_argument("--max_output_length", type=int, default=64)
27 | parser.add_argument("--do_sample", action='store_true', default=False)
28 |
29 | # lookahead configuration
30 | parser.add_argument("--do_lookahead", action="store_true", default=False)
31 | parser.add_argument("--lookahead_length", type=int, default=64)
32 | parser.add_argument("--lookahead_lambda", type=int, default=25)
33 | parser.add_argument("--top_k", type=int, default=5)
34 | parser.add_argument("--lookahead_decoding_type", type=str, default="greedy", choices=["greedy","beam","sample"])
35 | parser.add_argument("--lookahead_beam", type=int, default=1)
36 |
37 | # scorer configuration
38 | parser.add_argument("--scorer_model_type", type=str, default="roberta-large")
39 | parser.add_argument("--scorer_num_layers", type=int, default=17)
40 |
41 | args = parser.parse_args()
42 |
43 | # loading model
44 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
45 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir=args.cache_dir)
46 | model = model.cuda()  # can optionally call .half() for mixed precision
47 |
48 | # loading input
49 | documents = [line.strip() for line in open(args.document_file)]
50 |
51 | # Load scorer for lookahead
52 | scorer = BERTScoreScorer(
53 |     model_name=args.scorer_model_type,
54 |     num_layers=args.scorer_num_layers,
55 |     cache_dir=args.cache_dir,
56 | )
57 |
58 | # Create lookahead
59 | lookahead = None
60 | if args.do_lookahead:
61 |     lookahead = Lookahead(
62 |         model,
63 |         tokenizer,
64 |         scorer,
65 |         lookahead_length=args.lookahead_length,
66 |         lookahead_lambda=args.lookahead_lambda,
67 |         lookahead_top_k=args.top_k,
68 |         decoding_type=args.lookahead_decoding_type,
69 |         num_beams=args.lookahead_beam,
70 |         num_return_sequences=args.lookahead_beam,
71 |         max_length=args.max_output_length,
72 |     )
73 |
74 | # Create generator with lookahead
75 | generator = Generator(model, lookahead=lookahead)
76 |
77 | summaries = []
78 |
79 | for i in tqdm(range(0, len(documents), args.batch_size)):
80 |     input_str = documents[i:i+args.batch_size]
81 |
82 |     # IMPORTANT! Need to prepare the documents for the scorer first
83 |     if generator.lookahead is not None:
84 |         generator.lookahead.scorer.prepare_document(input_str)
85 |
86 |     inputs = tokenizer(input_str, max_length=args.max_input_length, padding=True, truncation=True, return_tensors="pt")
87 |
88 |     inputs = {k: v.cuda() for k, v in inputs.items()}
89 |
90 |     output = generator.generate(
91 |         input_ids=inputs["input_ids"],
92 |         attention_mask=inputs["attention_mask"],
93 |         num_beams=args.num_beams,
94 |         num_return_sequences=args.num_return_sequences,
95 |         max_length=args.max_output_length,
96 |         do_sample=args.do_sample,
97 |     )
98 |
99 |     output = tokenizer.batch_decode(output, skip_special_tokens=True)
100 |
101 |     if args.num_return_sequences == 1:
102 |         summaries += output
103 |     else:
104 |         for j in range(0, len(output), args.num_return_sequences):
105 |             summaries.append(output[j:j+args.num_return_sequences])
106 |
107 | # Save file
108 | with open(args.output_file, "w") as f:
109 |     if args.num_return_sequences == 1:
110 |         for line in summaries:
111 |             f.write(line + "\n")
112 |     else:
113 |         json.dump(summaries, f)
--------------------------------------------------------------------------------
/lookahead/scorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers import AutoModel, AutoTokenizer
5 |
6 | class BERTScoreScorer:
7 |     """
8 |     Scorer using BS-Fact; code adapted from the official BERTScore repo: https://github.com/Tiiiger/bert_score
9 |     """
10 |     def __init__(self, model_name="roberta-large", device="cuda", num_layers=17, cache_dir=".cache"):
11 |         model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
12 |         self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
13 |         # We assume roberta-large is used; please see https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L247
14 |         # if you wish to use another model and select its recommended layer
15 |         model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
16 |
17 |         self.model = model.to(device)
18 |         self.device = device
19 |
20 |     def prepare_document(self, input_str):
21 |         """
22 |         Prepare anything that requires processing on the document side.
23 |         This is called only once per batch to save computation.
24 |         """
25 |         self.bertscore_input_embedding, self.bertscore_input_attention_mask, self.bertscore_input_idf = self.encode_text(input_str)
26 |
27 |     def score(self, summaries, index):
28 |         """
29 |         Output the score for each example.
30 |         summaries: The summary strings
31 |         index: The index of each example (the document that the summary should be compared to). It should ideally just be range(), except for beam search.
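        Example (hypothetical strings and indices):

            scorer.prepare_document(["document one ...", "document two ..."])
            # two candidate summaries per document; index maps summary -> document
            scores = scorer.score(
                ["summ 1a", "summ 1b", "summ 2a", "summ 2b"],
                torch.tensor([0, 0, 1, 1]),
            )  # tensor of shape (4,) with the BS-Fact precision of each summary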
32 |         """
33 |         bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = self.encode_text(summaries)
34 |
35 |         bertscore_input_embedding = self.bertscore_input_embedding[index]
36 |         bertscore_input_attention_mask = self.bertscore_input_attention_mask[index]
37 |         bertscore_input_idf = self.bertscore_input_idf[index]
38 |
39 |         bertscore_scores = self.compute_bertscore(
40 |             bertscore_input_embedding,
41 |             bertscore_input_attention_mask,
42 |             bertscore_input_idf,
43 |             bertscore_output_embedding,
44 |             bertscore_output_attention_mask,
45 |             bertscore_output_idf,
46 |         )
47 |         return bertscore_scores
48 |
49 |     def encode_text(self, input_str):
50 |         """
51 |         Helper function to encode any string to tensors using the tokenizer
52 |         """
53 |         inputs = self.tokenizer(input_str, padding=True, truncation=True, return_tensors="pt")
54 |         inputs = {k: v.to(self.device) for k, v in inputs.items()}
55 |         with torch.no_grad():
56 |             outputs = self.model(**inputs)
57 |
58 |         # idf weights: zero out the special tokens (cls/sep) in the input ids, then normalize
59 |         idf = torch.clone(inputs["attention_mask"]).float()
60 |         idf[inputs["input_ids"] == self.tokenizer.sep_token_id] = 0
61 |         idf[inputs["input_ids"] == self.tokenizer.cls_token_id] = 0
62 |         idf.div_(idf.sum(dim=1, keepdim=True))
63 |
64 |         return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf
65 |
66 |     def compute_bertscore(self, doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf):
67 |         """
68 |         Helper function modified from the official greedy_cos_idf() method: https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469
69 |         """
70 |
71 |         batch_size = doc_embedding.size(0)
72 |         sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2))
73 |         masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float())
74 |         masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)
75 |
76 |         masks = masks.float().to(sim.device)
77 |         sim = sim * masks
78 |
79 |         precision = sim.max(dim=2)[0]
80 |         precision_scale = summ_idf.to(precision.device)
81 |         P = (precision * precision_scale).sum(dim=1)
82 |
83 |         summ_zero_mask = summ_masks.sum(dim=1).eq(2)
84 |         if torch.any(summ_zero_mask):
85 |             P = P.masked_fill(summ_zero_mask, 0.0)
86 |
87 |         doc_zero_mask = doc_masks.sum(dim=1).eq(2)
88 |         if torch.any(doc_zero_mask):
89 |             P = P.masked_fill(doc_zero_mask, 0.0)
90 |
91 |         return P
--------------------------------------------------------------------------------
/ranking/generate_documents.py:
--------------------------------------------------------------------------------
1 | import sys
2 |
3 | input_file = sys.argv[1]
4 | num_beams = int(sys.argv[2])
5 | output_file = sys.argv[3]
6 |
7 | documents = [line.strip() for line in open(input_file)]
8 |
9 |
10 | documents_extended = []
11 | for doc in documents:
12 |     for _ in range(num_beams):
13 |         documents_extended.append(doc)
14 |
15 | with open(output_file, "w") as f:
16 |     for doc in documents_extended:
17 |         f.write(doc + "\n")
18 |
--------------------------------------------------------------------------------
/ranking/rank.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | import pandas as pd
4 |
5 | file_prefix = sys.argv[1]
6 |
7 | num_beams = int(sys.argv[2])
8 |
9 | document = [line.strip() for line in open(file_prefix + ".document")]
10 | summary = [line.strip() for line in open(file_prefix + ".summary")]
11 |
12 | ids = []
13 | for i in range(len(document)):
14 |     ids.append(i // num_beams)
15 |
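# Each metric file is expected to be a JSON list of per-summary scores aligned
# with the .document/.summary files above. For example, for the illustrative
# prefix "xsum_test" the script would read xsum_test_factcc.json,
# xsum_test_dae.json, xsum_test_bsfact.json and xsum_test_questeval.json.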
16 | factcc = json.load(open(file_prefix + "_factcc.json"))
17 | dae = json.load(open(file_prefix + "_dae.json"))
18 | bsfact = json.load(open(file_prefix + "_bsfact.json"))
19 | questeval = json.load(open(file_prefix + "_questeval.json"))
20 |
21 | # create csv
22 |
23 | d = {
24 |     "id": ids,
25 |     "document": document,
26 |     "summary": summary,
27 |     "questeval": questeval,
28 |     "dae": dae,
29 |     "factcc": factcc,
30 |     "bsfact": bsfact,
31 | }
32 |
33 | print({k: len(v) for k, v in d.items()})
34 |
35 | df = pd.DataFrame.from_dict(d)
36 |
37 | print(df)
38 |
39 | df.to_csv(file_prefix + ".csv", index=False)
40 |
41 | # rank with a linear combination of the four metrics
42 | models = ["bsfact", "factcc", "dae", "questeval"]
43 | weight, bias = [1.96576989, 0.2972612, -0.29037403, 0.93960678], -1.9096430327732379
44 |
45 | df["composite"] = [sum([row[m] * w for m, w in zip(models, weight)]) + bias for i, row in df.iterrows()]
46 |
47 | summ = df.loc[df.groupby(["id"])["composite"].idxmax()]["summary"]
48 |
49 | with open(file_prefix + "_ranked_summary.txt", "w") as f:
50 |     for s in summ:
51 |         f.write(s + "\n")
--------------------------------------------------------------------------------
/teacher-student/src/data.py:
--------------------------------------------------------------------------------
1 | import random
2 | import warnings
3 | from collections.abc import Mapping
4 | from dataclasses import dataclass
5 | from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
6 |
7 | from transformers import PreTrainedTokenizerBase
8 | from transformers.utils import PaddingStrategy
9 |
10 | @dataclass
11 | class DataCollatorForSeq2SeqWithMultipleReferences:
12 |     """
13 |     This is similar to DataCollatorForSeq2Seq, except that it also accounts for additional output summaries.
14 |
15 |     Data collator that will dynamically pad the inputs received, as well as the labels.
16 |     Args:
17 |         tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
18 |             The tokenizer used for encoding the data.
19 |         model ([`PreTrainedModel`]):
20 |             The model that is being trained. If set and it has the *prepare_decoder_input_ids_from_labels* method, use it to
21 |             prepare the *decoder_input_ids*.
22 |             This is useful when using *label_smoothing* to avoid calculating loss twice.
23 |         padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
24 |             Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
25 |             among:
26 |             - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single sequence
27 |               is provided).
28 |             - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
29 |               acceptable input length for the model if that argument is not provided.
30 |             - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
31 |               lengths).
32 |         max_length (`int`, *optional*):
33 |             Maximum length of the returned list and optionally padding length (see above).
34 |         pad_to_multiple_of (`int`, *optional*):
35 |             If set will pad the sequence to a multiple of the provided value.
36 |             This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
37 |             7.5 (Volta).
38 |         label_pad_token_id (`int`, *optional*, defaults to -100):
39 |             The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
40 |         return_tensors (`str`):
41 |             The type of Tensor to return. Allowable values are "np", "pt" and "tf".
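    Example of the extra fields this collator pads (hypothetical token ids):

        features = [
            {"input_ids": [0, 713, 2], "labels": [0, 42, 2],
             "additional_labels": [0, 8, 2],                       # one extra reference
             "additional_candidates": [[0, 8, 2], [0, 9, 17, 2]]}  # k candidate summaries
        ]
        batch = collator(features)
        # besides the usual keys, batch now contains "additional_decoder_input_ids"
        # and "candidates_decoder_input_ids" of shape (batch, k, max_candidate_len)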
42 | """ 43 | 44 | tokenizer: PreTrainedTokenizerBase 45 | model: Optional[Any] = None 46 | padding: Union[bool, str, PaddingStrategy] = True 47 | max_length: Optional[int] = None 48 | pad_to_multiple_of: Optional[int] = None 49 | label_pad_token_id: int = -100 50 | return_tensors: str = "pt" 51 | 52 | def __call__(self, features, return_tensors=None): 53 | import numpy as np 54 | 55 | if return_tensors is None: 56 | return_tensors = self.return_tensors 57 | labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None 58 | additional_labels = [feature["additional_labels"] for feature in features] if "additional_labels" in features[0].keys() else None 59 | additional_candidates = [feature["additional_candidates"] for feature in features] if "additional_candidates" in features[0].keys() else None 60 | 61 | # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the 62 | # same length to return tensors. 63 | if labels is not None: 64 | max_label_length = max(len(l) for l in labels) 65 | if self.pad_to_multiple_of is not None: 66 | max_label_length = ( 67 | (max_label_length + self.pad_to_multiple_of - 1) 68 | // self.pad_to_multiple_of 69 | * self.pad_to_multiple_of 70 | ) 71 | 72 | padding_side = self.tokenizer.padding_side 73 | for feature in features: 74 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["labels"])) 75 | if isinstance(feature["labels"], list): 76 | feature["labels"] = ( 77 | feature["labels"] + remainder if padding_side == "right" else remainder + feature["labels"] 78 | ) 79 | elif padding_side == "right": 80 | feature["labels"] = np.concatenate([feature["labels"], remainder]).astype(np.int64) 81 | else: 82 | feature["labels"] = np.concatenate([remainder, feature["labels"]]).astype(np.int64) 83 | 84 | if additional_labels is not None: 85 | max_label_length = max(len(l) for l in additional_labels) 86 | if self.pad_to_multiple_of is not None: 87 | max_label_length = ( 88 | (max_label_length + self.pad_to_multiple_of - 1) 89 | // self.pad_to_multiple_of 90 | * self.pad_to_multiple_of 91 | ) 92 | 93 | padding_side = self.tokenizer.padding_side 94 | for feature in features: 95 | remainder = [self.label_pad_token_id] * (max_label_length - len(feature["additional_labels"])) 96 | if isinstance(feature["additional_labels"], list): 97 | feature["additional_labels"] = ( 98 | feature["additional_labels"] + remainder if padding_side == "right" else remainder + feature["additional_labels"] 99 | ) 100 | elif padding_side == "right": 101 | feature["additional_labels"] = np.concatenate([feature["additional_labels"], remainder]).astype(np.int64) 102 | else: 103 | feature["additional_labels"] = np.concatenate([remainder, feature["additional_labels"]]).astype(np.int64) 104 | 105 | if additional_candidates is not None: 106 | max_label_length = max(max([len(l) for l in ll]) for ll in additional_candidates) 107 | if self.pad_to_multiple_of is not None: 108 | max_label_length = ( 109 | (max_label_length + self.pad_to_multiple_of - 1) 110 | // self.pad_to_multiple_of 111 | * self.pad_to_multiple_of 112 | ) 113 | 114 | padding_side = self.tokenizer.padding_side 115 | 116 | for feature in features: 117 | padded_feature = [] 118 | for feat in feature["additional_candidates"]: 119 | remainder = [self.label_pad_token_id] * (max_label_length - len(feat)) 120 | if isinstance(feat, list): 121 | _feat = ( 122 | feat + remainder if padding_side == "right" else remainder + feat 123 | ) 124 | elif 
padding_side == "right": 125 | _feat = np.concatenate([feat, remainder]).astype(np.int64) 126 | else: 127 | _feat = np.concatenate([remainder, feat]).astype(np.int64) 128 | padded_feature.append(_feat) 129 | feature["additional_candidates"] = padded_feature 130 | 131 | features = self.tokenizer.pad( 132 | features, 133 | padding=self.padding, 134 | max_length=self.max_length, 135 | pad_to_multiple_of=self.pad_to_multiple_of, 136 | return_tensors=return_tensors, 137 | ) 138 | 139 | # prepare decoder_input_ids 140 | if ( 141 | labels is not None 142 | and self.model is not None 143 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels") 144 | ): 145 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"]) 146 | features["decoder_input_ids"] = decoder_input_ids 147 | 148 | if ( 149 | additional_labels is not None 150 | and self.model is not None 151 | and hasattr(self.model, "prepare_decoder_input_ids_from_labels") 152 | ): 153 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["additional_labels"]) 154 | features["additional_decoder_input_ids"] = decoder_input_ids 155 | 156 | # additional candidates 157 | if additional_candidates is not None: 158 | additional_candidates = features.pop("additional_candidates") 159 | num_candidates = additional_candidates.size(1) 160 | 161 | # additional_candidates.masked_fill_(additional_candidates == -100, self.tokenizer.pad_token_id) 162 | # features["candidates_decoder_input_ids"] = additional_candidates 163 | 164 | # no need to shift 165 | additional_candidates = additional_candidates.view(-1, additional_candidates.size(-1)) 166 | decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=additional_candidates) 167 | decoder_input_ids = decoder_input_ids.view(-1, num_candidates, decoder_input_ids.size(-1)) 168 | features["candidates_decoder_input_ids"] = decoder_input_ids 169 | 170 | return features -------------------------------------------------------------------------------- /teacher-student/src/deepspeed_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | 30 | "zero_optimization": { 31 | "stage": 2, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "allgather_partitions": true, 37 | "allgather_bucket_size": 2e8, 38 | "overlap_comm": true, 39 | "reduce_scatter": true, 40 | "reduce_bucket_size": 2e8, 41 | "contiguous_gradients": true 42 | }, 43 | 44 | "gradient_accumulation_steps": "auto", 45 | "gradient_clipping": "auto", 46 | "steps_per_print": 2000, 47 | "train_batch_size": "auto", 48 | "train_micro_batch_size_per_gpu": "auto", 49 | "wall_clock_breakdown": false 50 | } -------------------------------------------------------------------------------- /teacher-student/src/run_summarization.py: -------------------------------------------------------------------------------- 1 | """ 2 | Fine-tuning script for summarization adapted from Huggingface 
(https://github.com/huggingface/transformers/blob/main/examples/pytorch/summarization/run_summarization.py).
"""

import logging
import os
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import nltk  # Here to have a nice missing dependency error message early on
import numpy as np
from datasets import load_dataset, load_metric

import transformers
from filelock import FileLock
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    MBart50Tokenizer,
    MBart50TokenizerFast,
    MBartTokenizer,
    MBartTokenizerFast,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    set_seed,
)
from transformers.trainer_utils import get_last_checkpoint
from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
from transformers.utils.versions import require_version

from data import DataCollatorForSeq2SeqWithMultipleReferences
from trainer import CustomTrainer

# Will error if the minimal version of Transformers is not installed. Remove at your own risk.
# check_min_version("4.21.0.dev0")

# require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")

logger = logging.getLogger(__name__)

try:
    nltk.data.find("tokenizers/punkt")
except (LookupError, OSError):
    if is_offline_mode():
        raise LookupError(
            "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
        )
    with FileLock(".lock") as lock:
        nltk.download("punkt", quiet=True)

# A list of all multilingual tokenizers which require the lang attribute.
MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
    )
    use_fast_tokenizer: bool = field(
        default=True,
        metadata={"help": "Whether to use one of the fast tokenizers (backed by the tokenizers library) or not."},
    )
    model_revision: str = field(
        default="main",
        metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
    )
    use_auth_token: bool = field(
        default=False,
        metadata={
            "help": (
                "Will use the token generated when running `transformers-cli login` (necessary to use this script "
                "with private models)."
93 | ) 94 | }, 95 | ) 96 | resize_position_embeddings: Optional[bool] = field( 97 | default=None, 98 | metadata={ 99 | "help": ( 100 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 101 | "the model's position embeddings." 102 | ) 103 | }, 104 | ) 105 | alpha: float = field(default=1.0) 106 | 107 | @dataclass 108 | class DataTrainingArguments: 109 | """ 110 | Arguments pertaining to what data we are going to input our model for training and eval. 111 | """ 112 | 113 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 114 | 115 | dataset_name: Optional[str] = field( 116 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 117 | ) 118 | dataset_config_name: Optional[str] = field( 119 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 120 | ) 121 | text_column: Optional[str] = field( 122 | default=None, 123 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 124 | ) 125 | summary_column: Optional[str] = field( 126 | default=None, 127 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 128 | ) 129 | train_file: Optional[str] = field( 130 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 131 | ) 132 | validation_file: Optional[str] = field( 133 | default=None, 134 | metadata={ 135 | "help": ( 136 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 137 | ) 138 | }, 139 | ) 140 | test_file: Optional[str] = field( 141 | default=None, 142 | metadata={ 143 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 144 | }, 145 | ) 146 | overwrite_cache: bool = field( 147 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 148 | ) 149 | preprocessing_num_workers: Optional[int] = field( 150 | default=None, 151 | metadata={"help": "The number of processes to use for the preprocessing."}, 152 | ) 153 | max_source_length: Optional[int] = field( 154 | default=1024, 155 | metadata={ 156 | "help": ( 157 | "The maximum total input sequence length after tokenization. Sequences longer " 158 | "than this will be truncated, sequences shorter will be padded." 159 | ) 160 | }, 161 | ) 162 | max_target_length: Optional[int] = field( 163 | default=128, 164 | metadata={ 165 | "help": ( 166 | "The maximum total sequence length for target text after tokenization. Sequences longer " 167 | "than this will be truncated, sequences shorter will be padded." 168 | ) 169 | }, 170 | ) 171 | val_max_target_length: Optional[int] = field( 172 | default=None, 173 | metadata={ 174 | "help": ( 175 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 176 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 177 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 178 | "during ``evaluate`` and ``predict``." 179 | ) 180 | }, 181 | ) 182 | pad_to_max_length: bool = field( 183 | default=False, 184 | metadata={ 185 | "help": ( 186 | "Whether to pad all samples to model maximum sentence length. " 187 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. 
More " 188 | "efficient on GPU but very bad for TPU." 189 | ) 190 | }, 191 | ) 192 | max_train_samples: Optional[int] = field( 193 | default=None, 194 | metadata={ 195 | "help": ( 196 | "For debugging purposes or quicker training, truncate the number of training examples to this " 197 | "value if set." 198 | ) 199 | }, 200 | ) 201 | max_eval_samples: Optional[int] = field( 202 | default=None, 203 | metadata={ 204 | "help": ( 205 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 206 | "value if set." 207 | ) 208 | }, 209 | ) 210 | max_predict_samples: Optional[int] = field( 211 | default=None, 212 | metadata={ 213 | "help": ( 214 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 215 | "value if set." 216 | ) 217 | }, 218 | ) 219 | num_beams: Optional[int] = field( 220 | default=None, 221 | metadata={ 222 | "help": ( 223 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 224 | "which is used during ``evaluate`` and ``predict``." 225 | ) 226 | }, 227 | ) 228 | ignore_pad_token_for_loss: bool = field( 229 | default=True, 230 | metadata={ 231 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 232 | }, 233 | ) 234 | source_prefix: Optional[str] = field( 235 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 236 | ) 237 | 238 | forced_bos_token: Optional[str] = field( 239 | default=None, 240 | metadata={ 241 | "help": ( 242 | "The token to force as the first generated token after the decoder_start_token_id." 243 | "Useful for multilingual models like mBART where the first generated token" 244 | "needs to be the target language token (Usually it is the target language token)" 245 | ) 246 | }, 247 | ) 248 | 249 | reference_file: Optional[str] = field( 250 | default=None 251 | ) 252 | additional_reference_file: Optional[str] = field( 253 | default=None 254 | ) 255 | 256 | def __post_init__(self): 257 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 258 | raise ValueError("Need either a dataset name or a training/validation file.") 259 | else: 260 | if self.train_file is not None: 261 | extension = self.train_file.split(".")[-1] 262 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 263 | if self.validation_file is not None: 264 | extension = self.validation_file.split(".")[-1] 265 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 266 | if self.val_max_target_length is None: 267 | self.val_max_target_length = self.max_target_length 268 | 269 | 270 | summarization_name_mapping = { 271 | "amazon_reviews_multi": ("review_body", "review_title"), 272 | "big_patent": ("description", "abstract"), 273 | "cnn_dailymail": ("article", "highlights"), 274 | "orange_sum": ("text", "summary"), 275 | "pn_summary": ("article", "summary"), 276 | "psc": ("extract_text", "summary_text"), 277 | "samsum": ("dialogue", "summary"), 278 | "thaisum": ("body", "summary"), 279 | "xglue": ("news_body", "news_title"), 280 | "xsum": ("document", "summary"), 281 | "wiki_summary": ("article", "highlights"), 282 | "multi_news": ("document", "summary"), 283 | } 284 | 285 | 286 | def main(): 287 | # See all possible arguments in src/transformers/training_args.py 288 | # or by passing the --help flag to this script. 
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
    send_example_telemetry("run_summarization", model_args, data_args)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log a short summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {bool(training_args.local_rank != -1)}, 16-bit training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    if data_args.source_prefix is None and model_args.model_name_or_path in [
        "t5-small",
        "t5-base",
        "t5-large",
        "t5-3b",
        "t5-11b",
    ]:
        logger.warning(
            "You're running a t5 model but didn't provide a source prefix, which is expected for T5, e.g. "
            "`--source_prefix 'summarize: '`"
        )

    # Detecting last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
    # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
    # (the dataset will be downloaded automatically from the datasets Hub).
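    # Illustrative example (hypothetical file contents): a JSON-lines train file for this script could
    # contain rows such as {"document": "Full article text ...", "summary": "A short summary."},
    # matching the xsum column names in summarization_name_mapping above.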
356 | # 357 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 358 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 359 | # 360 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 361 | # download the dataset. 362 | if data_args.dataset_name is not None: 363 | # Downloading and loading a dataset from the hub. 364 | raw_datasets = load_dataset( 365 | data_args.dataset_name, 366 | data_args.dataset_config_name, 367 | cache_dir=model_args.cache_dir, 368 | use_auth_token=True if model_args.use_auth_token else None, 369 | ) 370 | else: 371 | data_files = {} 372 | if data_args.train_file is not None: 373 | data_files["train"] = data_args.train_file 374 | extension = data_args.train_file.split(".")[-1] 375 | if data_args.validation_file is not None: 376 | data_files["validation"] = data_args.validation_file 377 | extension = data_args.validation_file.split(".")[-1] 378 | if data_args.test_file is not None: 379 | data_files["test"] = data_args.test_file 380 | extension = data_args.test_file.split(".")[-1] 381 | raw_datasets = load_dataset( 382 | extension, 383 | data_files=data_files, 384 | cache_dir=model_args.cache_dir, 385 | use_auth_token=True if model_args.use_auth_token else None, 386 | ) 387 | 388 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 389 | # https://huggingface.co/docs/datasets/loading_datasets.html. 390 | 391 | # Load pretrained model and tokenizer 392 | # 393 | # Distributed training: 394 | # The .from_pretrained methods guarantee that only one local process can concurrently 395 | # download model & vocab. 
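    # Config, tokenizer and model are loaded in that order, so that --config_name and --tokenizer_name
    # can each point somewhere other than --model_name_or_path when needed.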
396 | config = AutoConfig.from_pretrained( 397 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 398 | cache_dir=model_args.cache_dir, 399 | revision=model_args.model_revision, 400 | use_auth_token=True if model_args.use_auth_token else None, 401 | ) 402 | tokenizer = AutoTokenizer.from_pretrained( 403 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 404 | cache_dir=model_args.cache_dir, 405 | use_fast=model_args.use_fast_tokenizer, 406 | revision=model_args.model_revision, 407 | use_auth_token=True if model_args.use_auth_token else None, 408 | ) 409 | model = AutoModelForSeq2SeqLM.from_pretrained( 410 | model_args.model_name_or_path, 411 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 412 | config=config, 413 | cache_dir=model_args.cache_dir, 414 | revision=model_args.model_revision, 415 | use_auth_token=True if model_args.use_auth_token else None, 416 | ) 417 | 418 | model.resize_token_embeddings(len(tokenizer)) 419 | 420 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 421 | if isinstance(tokenizer, MBartTokenizer): 422 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 423 | else: 424 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 425 | 426 | if model.config.decoder_start_token_id is None: 427 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 428 | 429 | if ( 430 | hasattr(model.config, "max_position_embeddings") 431 | and model.config.max_position_embeddings < data_args.max_source_length 432 | ): 433 | if model_args.resize_position_embeddings is None: 434 | logger.warning( 435 | "Increasing the model's number of position embedding vectors from" 436 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}." 437 | ) 438 | model.resize_position_embeddings(data_args.max_source_length) 439 | elif model_args.resize_position_embeddings: 440 | model.resize_position_embeddings(data_args.max_source_length) 441 | else: 442 | raise ValueError( 443 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has" 444 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing" 445 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the" 446 | " model's position encodings by passing `--resize_position_embeddings`." 447 | ) 448 | 449 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 450 | 451 | # Preprocessing the datasets. 452 | # We need to tokenize inputs and targets. 453 | if training_args.do_train: 454 | column_names = raw_datasets["train"].column_names 455 | elif training_args.do_eval: 456 | column_names = raw_datasets["validation"].column_names 457 | elif training_args.do_predict: 458 | column_names = raw_datasets["test"].column_names 459 | else: 460 | logger.info("There is nothing to do. 
Please pass `do_train`, `do_eval` and/or `do_predict`.")
        return

    if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
        assert (
            data_args.lang is not None
        ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires the --lang argument"

        tokenizer.src_lang = data_args.lang
        tokenizer.tgt_lang = data_args.lang

        # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
        # as the first generated token. We ask the user to explicitly provide this as the --forced_bos_token argument.
        forced_bos_token_id = (
            tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
        )
        model.config.forced_bos_token_id = forced_bos_token_id

    # Get the column names for input/target.
    dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
    if data_args.text_column is None:
        text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
    else:
        text_column = data_args.text_column
        if text_column not in column_names:
            raise ValueError(
                f"`--text_column` value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
            )
    if data_args.summary_column is None:
        summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
    else:
        summary_column = data_args.summary_column
        if summary_column not in column_names:
            raise ValueError(
                f"`--summary_column` value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
            )

    # Temporarily set max_target_length for training.
    max_target_length = data_args.max_target_length
    padding = "max_length" if data_args.pad_to_max_length else False

    if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
        logger.warning(
            "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for "
            f"`{model.__class__.__name__}`. This will lead to the loss being calculated twice and will take up more memory."
        )

    def preprocess_function(examples):
        # Collect input/target pairs. Note: the upstream filtering of pairs where at least one record is
        # None is disabled here (the filter is kept as a commented-out line below).
        inputs, targets = [], []
        for i in range(len(examples[text_column])):
            # if examples[text_column][i] and examples[summary_column][i]:
            inputs.append(examples[text_column][i])
            targets.append(examples[summary_column][i])

        inputs = [prefix + inp for inp in inputs]
        model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)

        # Setup the tokenizer for targets
        labels = targets
        additional_labels = None
        # `additional_candidates` is referenced in the padding fix-up below but is never produced by this
        # function, so initialize it explicitly to avoid a NameError.
        additional_candidates = None
        with tokenizer.as_target_tokenizer():
            if "additional_labels" in examples:
                additional_labels = tokenizer(examples["additional_labels"], max_length=max_target_length, padding=padding, truncation=True)

            labels = tokenizer(labels, max_length=max_target_length, padding=padding, truncation=True)

        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
        # padding in the loss.
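        # Worked example (illustrative token ids, not from any dataset): with tokenizer.pad_token_id == 1,
        # a padded label row [8774, 51, 2, 1, 1] becomes [8774, 51, 2, -100, -100], since -100 is the
        # index that the cross-entropy loss ignores.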
530 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 531 | labels["input_ids"] = [ 532 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 533 | ] 534 | 535 | if additional_labels is not None: 536 | additional_labels["input_ids"] = [ 537 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in additional_labels["input_ids"] 538 | ] 539 | 540 | if additional_candidates is not None: 541 | additional_candidates["input_ids"] = [ 542 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in additional_candidates["input_ids"] 543 | ] 544 | 545 | 546 | model_inputs["labels"] = labels["input_ids"] 547 | if additional_labels is not None: 548 | model_inputs["additional_labels"] = additional_labels["input_ids"] 549 | 550 | return model_inputs 551 | 552 | if training_args.do_train: 553 | if "train" not in raw_datasets: 554 | raise ValueError("--do_train requires a train dataset") 555 | train_dataset = raw_datasets["train"] 556 | 557 | # replace reference summary with reference_file if needed 558 | if data_args.reference_file is not None: 559 | references = [line.strip() for line in open(data_args.reference_file)] 560 | assert len(references) == len(train_dataset[summary_column]) 561 | print("replacing reference...") 562 | print(train_dataset[0][summary_column]) 563 | train_dataset = train_dataset.remove_columns(summary_column).add_column(summary_column, references) 564 | print(train_dataset[0][summary_column]) 565 | 566 | if data_args.additional_reference_file is not None: 567 | references = [line.strip() for line in open(data_args.additional_reference_file)] 568 | assert len(references) == len(train_dataset[summary_column]) 569 | train_dataset = train_dataset.add_column("additional_labels", references) 570 | 571 | if data_args.max_train_samples is not None: 572 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 573 | train_dataset = train_dataset.select(range(max_train_samples)) 574 | with training_args.main_process_first(desc="train dataset map pre-processing"): 575 | train_dataset = train_dataset.map( 576 | preprocess_function, 577 | batched=True, 578 | num_proc=data_args.preprocessing_num_workers, 579 | remove_columns=column_names, 580 | load_from_cache_file=not data_args.overwrite_cache, 581 | desc="Running tokenizer on train dataset", 582 | ) 583 | 584 | if training_args.do_eval: 585 | max_target_length = data_args.val_max_target_length 586 | if "validation" not in raw_datasets: 587 | raise ValueError("--do_eval requires a validation dataset") 588 | eval_dataset = raw_datasets["validation"] 589 | if data_args.max_eval_samples is not None: 590 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 591 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 592 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 593 | eval_dataset = eval_dataset.map( 594 | preprocess_function, 595 | batched=True, 596 | num_proc=data_args.preprocessing_num_workers, 597 | remove_columns=column_names, 598 | load_from_cache_file=not data_args.overwrite_cache, 599 | desc="Running tokenizer on validation dataset", 600 | ) 601 | 602 | if training_args.do_predict: 603 | max_target_length = data_args.val_max_target_length 604 | if "test" not in raw_datasets: 605 | raise ValueError("--do_predict requires a test dataset") 606 | predict_dataset = raw_datasets["test"] 607 | if data_args.max_predict_samples is not None: 608 | 
max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 609 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 610 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 611 | predict_dataset = predict_dataset.map( 612 | preprocess_function, 613 | batched=True, 614 | num_proc=data_args.preprocessing_num_workers, 615 | remove_columns=column_names, 616 | load_from_cache_file=not data_args.overwrite_cache, 617 | desc="Running tokenizer on prediction dataset", 618 | ) 619 | 620 | # Data collator 621 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 622 | # data_collator = DataCollatorForSeq2Seq( 623 | data_collator = DataCollatorForSeq2SeqWithMultipleReferences( 624 | tokenizer, 625 | model=model, 626 | label_pad_token_id=label_pad_token_id, 627 | pad_to_multiple_of=8 if training_args.fp16 else None, 628 | ) 629 | 630 | # Metric 631 | metric = load_metric("rouge") 632 | 633 | def postprocess_text(preds, labels): 634 | preds = [pred.strip() for pred in preds] 635 | labels = [label.strip() for label in labels] 636 | 637 | # rougeLSum expects newline after each sentence 638 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 639 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 640 | 641 | return preds, labels 642 | 643 | def compute_metrics(eval_preds): 644 | preds, labels = eval_preds 645 | if isinstance(preds, tuple): 646 | preds = preds[0] 647 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 648 | if data_args.ignore_pad_token_for_loss: 649 | # Replace -100 in the labels as we can't decode them. 650 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 651 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 652 | 653 | # Some simple post-processing 654 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 655 | 656 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 657 | # Extract a few results from ROUGE 658 | result = {key: value.mid.fmeasure * 100 for key, value in result.items()} 659 | 660 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 661 | result["gen_len"] = np.mean(prediction_lens) 662 | result = {k: round(v, 4) for k, v in result.items()} 663 | return result 664 | 665 | # Initialize our Trainer 666 | # trainer = Seq2SeqTrainer( 667 | # model=model, 668 | # args=training_args, 669 | # train_dataset=train_dataset if training_args.do_train else None, 670 | # eval_dataset=eval_dataset if training_args.do_eval else None, 671 | # tokenizer=tokenizer, 672 | # data_collator=data_collator, 673 | # compute_metrics=compute_metrics if training_args.predict_with_generate else None, 674 | # ) 675 | 676 | trainer = CustomTrainer( 677 | model=model, 678 | args=training_args, 679 | train_dataset=train_dataset if training_args.do_train else None, 680 | eval_dataset=eval_dataset if training_args.do_eval else None, 681 | tokenizer=tokenizer, 682 | data_collator=data_collator, 683 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 684 | alpha = model_args.alpha, 685 | ) 686 | 687 | # Training 688 | if training_args.do_train: 689 | checkpoint = None 690 | if training_args.resume_from_checkpoint is not None: 691 | checkpoint = training_args.resume_from_checkpoint 692 | elif last_checkpoint is not None: 693 | checkpoint = last_checkpoint 694 | 
train_result = trainer.train(resume_from_checkpoint=checkpoint) 695 | trainer.save_model() # Saves the tokenizer too for easy upload 696 | 697 | metrics = train_result.metrics 698 | max_train_samples = ( 699 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 700 | ) 701 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 702 | 703 | trainer.log_metrics("train", metrics) 704 | trainer.save_metrics("train", metrics) 705 | trainer.save_state() 706 | 707 | # Evaluation 708 | results = {} 709 | max_length = ( 710 | training_args.generation_max_length 711 | if training_args.generation_max_length is not None 712 | else data_args.val_max_target_length 713 | ) 714 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 715 | if training_args.do_eval: 716 | logger.info("*** Evaluate ***") 717 | metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 718 | max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 719 | metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 720 | 721 | trainer.log_metrics("eval", metrics) 722 | trainer.save_metrics("eval", metrics) 723 | 724 | if training_args.do_predict: 725 | logger.info("*** Predict ***") 726 | predict_results = trainer.predict( 727 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 728 | ) 729 | metrics = predict_results.metrics 730 | max_predict_samples = ( 731 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 732 | ) 733 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 734 | 735 | trainer.log_metrics("predict", metrics) 736 | trainer.save_metrics("predict", metrics) 737 | 738 | if trainer.is_world_process_zero(): 739 | if training_args.predict_with_generate: 740 | predictions = tokenizer.batch_decode( 741 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 742 | ) 743 | predictions = [pred.strip().replace("\n"," ") for pred in predictions] 744 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 745 | with open(output_prediction_file, "w") as writer: 746 | writer.write("\n".join(predictions)) 747 | 748 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 749 | if data_args.dataset_name is not None: 750 | kwargs["dataset_tags"] = data_args.dataset_name 751 | if data_args.dataset_config_name is not None: 752 | kwargs["dataset_args"] = data_args.dataset_config_name 753 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 754 | else: 755 | kwargs["dataset"] = data_args.dataset_name 756 | 757 | if data_args.lang is not None: 758 | kwargs["language"] = data_args.lang 759 | 760 | if training_args.push_to_hub: 761 | trainer.push_to_hub(**kwargs) 762 | else: 763 | trainer.create_model_card(**kwargs) 764 | 765 | return results 766 | 767 | 768 | def _mp_fn(index): 769 | # For xla_spawn (TPUs) 770 | main() 771 | 772 | 773 | if __name__ == "__main__": 774 | main() -------------------------------------------------------------------------------- /teacher-student/src/trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from transformers import Seq2SeqTrainer 5 

class CustomTrainer(Seq2SeqTrainer):
    """
    Custom Seq2SeqTrainer that trains with a combined objective: the standard cross-entropy loss on the
    reference summaries plus an `alpha`-weighted cross-entropy loss on additional reference summaries.
    Adapted from the original Hugging Face trainer code.
    """
    def __init__(
        self,
        model=None,
        args=None,
        data_collator=None,
        train_dataset=None,
        eval_dataset=None,
        tokenizer=None,
        model_init=None,
        compute_metrics=None,
        callbacks=None,
        optimizers=(None, None),
        preprocess_logits_for_metrics=None,
        alpha=1.0,
    ):
        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )
        self.alpha = alpha

    def compute_loss(self, model, inputs, return_outputs=False):
        additional_decoder_input_ids = inputs.pop("additional_decoder_input_ids", None)
        additional_labels = inputs.pop("additional_labels", None)

        # Run the encoder only once and reuse its outputs for both decoder passes to save computation.
        encoder_outputs = model.get_encoder()(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
        )
        inputs["encoder_outputs"] = encoder_outputs

        # Original cross-entropy loss on the reference labels. Unpack the (loss, outputs) tuple when
        # `return_outputs=True`, so the extra term below is added to the loss alone.
        outputs = None
        if return_outputs:
            loss, outputs = super().compute_loss(model, inputs, return_outputs=True)
        else:
            loss = super().compute_loss(model, inputs, return_outputs=False)

        # Additional labels: compute the loss for labels and additional_labels separately.
        if additional_labels is not None:
            inputs["decoder_input_ids"] = additional_decoder_input_ids
            inputs["labels"] = additional_labels
            additional_loss = super().compute_loss(model, inputs, return_outputs=False)

            loss = loss + self.alpha * additional_loss

        return (loss, outputs) if return_outputs else loss
--------------------------------------------------------------------------------
/teacher-student/train_script.sh:
--------------------------------------------------------------------------------
# NOTE: flag names must match the HfArgumentParser field names
# (label_smoothing_factor, gradient_accumulation_steps).
deepspeed --include localhost:0,1,2,3 src/run_summarization.py --fp16 \
    --deepspeed src/deepspeed_config.json \
    --dataset_name xsum \
    --model_name_or_path facebook/bart-large \
    --do_train --evaluation_strategy no \
    --label_smoothing_factor 0.1 --learning_rate 3e-5 --gradient_accumulation_steps 4 --per_device_train_batch_size 8 \
    --max_source_length 512 --max_target_length 64 \
    --warmup_steps 500 --max_grad_norm 0.1 --max_steps 15000 --save_strategy no \
    --output_dir out_xsum --overwrite_cache --remove_unused_columns true --additional_reference_file additional_summary.txt

# Optionally, run without deepspeed:
# python src/run_summarization.py --fp16 \
#     --dataset_name xsum \
#     --model_name_or_path facebook/bart-large \
#     --do_train --evaluation_strategy no \
#     --label_smoothing_factor 0.1 --learning_rate 3e-5 --gradient_accumulation_steps 4 --per_device_train_batch_size 8 \
#     --max_source_length 512 --max_target_length 64 \
#     --warmup_steps 500 --max_grad_norm 0.1 --max_steps 15000 --save_strategy no \
#     --output_dir out_xsum --overwrite_cache --remove_unused_columns true
--------------------------------------------------------------------------------
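# Illustration (not a file from this repository): a minimal, self-contained sketch of the objective that
# CustomTrainer.compute_loss implements. The function name, argument names and tensor shapes below are
# assumptions made for this example, not part of the original code.

import torch
import torch.nn.functional as F


def combined_cross_entropy(
    logits: torch.Tensor,              # (batch, tgt_len, vocab) decoder logits for the reference labels
    labels: torch.Tensor,              # (batch, tgt_len) token ids, padding encoded as -100
    additional_logits: torch.Tensor,   # (batch, tgt_len, vocab) logits for the additional references
    additional_labels: torch.Tensor,   # (batch, tgt_len) token ids, padding encoded as -100
    alpha: float = 1.0,
) -> torch.Tensor:
    vocab = logits.size(-1)
    # Standard cross entropy on the reference summaries; positions marked -100 are ignored.
    loss = F.cross_entropy(logits.reshape(-1, vocab), labels.reshape(-1), ignore_index=-100)
    # Alpha-weighted cross entropy on the additional (e.g. teacher-generated) summaries.
    additional_loss = F.cross_entropy(
        additional_logits.reshape(-1, vocab), additional_labels.reshape(-1), ignore_index=-100
    )
    return loss + alpha * additional_loss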