├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── images └── read-control.png ├── requirements.txt └── src ├── inference ├── generation.py ├── inference_category.sh ├── inference_lookahead.sh ├── inference_score.sh ├── lookahead.py ├── run_lookahead.py ├── run_summarization.py └── scorer.py ├── preprocess ├── generate_prompts_category.py ├── generate_prompts_score.py └── preprocess_cnndm.py └── train ├── ds_config_stage3_fb16.json ├── rl ├── accelerate_config.yaml ├── train.py └── train_rl_cnndm.sh ├── run_summarization.py └── train_cnndm.sh /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | ## Code of Conduct 2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 4 | opensource-codeofconduct@amazon.com with any additional questions or comments. 5 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing Guidelines 2 | 3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional 4 | documentation, we greatly value feedback and contributions from our community. 5 | 6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary 7 | information to effectively respond to your bug report or contribution. 8 | 9 | 10 | ## Reporting Bugs/Feature Requests 11 | 12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features. 13 | 14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already 15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful: 16 | 17 | * A reproducible test case or series of steps 18 | * The version of our code being used 19 | * Any modifications you've made relevant to the bug 20 | * Anything unusual about your environment or deployment 21 | 22 | 23 | ## Contributing via Pull Requests 24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that: 25 | 26 | 1. You are working against the latest source on the *main* branch. 27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already. 28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted. 29 | 30 | To send us a pull request, please: 31 | 32 | 1. Fork the repository. 33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change. 34 | 3. Ensure local tests pass. 35 | 4. Commit to your fork using clear commit messages. 36 | 5. Send us a pull request, answering any default questions in the pull request interface. 37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation. 38 | 39 | GitHub provides additional document on [forking a repository](https://help.github.com/articles/fork-a-repo/) and 40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/). 
41 | 42 | 43 | ## Finding contributions to work on 44 | Looking at the existing issues is a great way to find something to contribute on. As our projects, by default, use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start. 45 | 46 | 47 | ## Code of Conduct 48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct). 49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact 50 | opensource-codeofconduct@amazon.com with any additional questions or comments. 51 | 52 | 53 | ## Security issue notifications 54 | If you discover a potential security issue in this project we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public github issue. 55 | 56 | 57 | ## Licensing 58 | 59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution. 60 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. 
If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. 
Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. 
If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. 
automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 
406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | 409 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Generating Summaries with Controllable Readability Levels (EMNLP 2023) 2 | 3 | This repository contains the code for the paper "[Generating Summaries with Controllable Readability Levels](https://arxiv.org/pdf/2310.10623)". 4 | 5 | We developed three text generation techniques for controlling readability: 6 | 7 |

8 | ![Overview of the three readability-control approaches](images/read-control.png) 9 |

10 | 11 | (a) illustrates the approach to control the summary readability via fine-grained instructions. (b) shows the RL method where given an input document and the readability level, the policy generates a summary to be scored by our Gaussian-based reward, and (c) shows the lookahead approach which uses a readability score of a future summary to guide the generation. 12 | 13 | ## Environment 14 | 15 | The easiest way to proceed is to create a conda environment: 16 | ``` 17 | conda create -n readability_summ python=3.7 18 | conda activate readability_summ 19 | ``` 20 | 21 | Further, install PyTorch: 22 | 23 | ``` 24 | conda install pytorch torchvision torchaudio cpuonly -c pytorch 25 | ``` 26 | 27 | Install the packages required: 28 | ``` 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | Install trlx (for the RL method): 33 | ``` 34 | git clone https://github.com/CarperAI/trlx.git 35 | cd trlx 36 | pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 37 | pip install -e . 38 | ``` 39 | 40 | 41 | ## Preprocess data 42 | 43 | For computing the readability scores for CNN/DM, execute: 44 | 45 | ``` 46 | cd src/preprocess 47 | python preprocess_cnndm.py 48 | ``` 49 | 50 | Generate the prompts: 51 | ``` 52 | python generate_prompts_category.py 53 | python generate_prompts_score.py 54 | ``` 55 | 56 | 57 | ## Training 58 | 59 | Execute the following commands for training for the prompt-based methods: 60 | ``` 61 | cd src/train 62 | ./train_cnndm.sh 63 | ``` 64 | 65 | For the RL method, execute: 66 | ``` 67 | cd src/train/rl 68 | ./train_rl_cnndm.sh 69 | ``` 70 | 71 | ## Inference 72 | 73 | For inference, run: 74 | ``` 75 | cd inference/ 76 | ./inference_score.sh 77 | ./inference_category.sh 78 | ``` 79 | 80 | For lookahead inference, run: 81 | ``` 82 | ./inference_lookahead.sh 83 | ``` 84 | 85 | ## Security 86 | 87 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information. 88 | 89 | ## License Summary 90 | 91 | The documentation is made available under the CC-BY-NC-4.0 License. See the LICENSE file. 92 | 93 | ## Citation 94 | 95 | ``` 96 | @inproceedings{ribeiro-etal-2023-generating, 97 | title = "Generating Summaries with Controllable Readability Levels", 98 | author = "Ribeiro, Leonardo F. R. and 99 | Bansal, Mohit and 100 | Dreyer, Markus", 101 | editor = "Bouamor, Houda and 102 | Pino, Juan and 103 | Bali, Kalika", 104 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing", 105 | month = dec, 106 | year = "2023", 107 | address = "Singapore", 108 | publisher = "Association for Computational Linguistics", 109 | url = "https://aclanthology.org/2023.emnlp-main.714", 110 | doi = "10.18653/v1/2023.emnlp-main.714", 111 | pages = "11669--11687", 112 | abstract = "Readability refers to how easily a reader can understand a written text. Several factors affect the readability level, such as the complexity of the text, its subject matter, and the reader{'}s background knowledge. Generating summaries based on different readability levels is critical for enabling knowledge consumption by diverse audiences. However, current text generation approaches lack refined control, resulting in texts that are not customized to readers{'} proficiency levels. In this work, we bridge this gap and study techniques to generate summaries at specified readability levels. 
Unlike previous methods that focus on a specific readability level (e.g., lay summarization), we generate summaries with fine-grained control over their readability. We develop three text generation techniques for controlling readability: (1) instruction-based readability control, (2) reinforcement learning to minimize the gap between requested and observed readability and (3) a decoding approach that uses lookahead to estimate the readability of upcoming decoding steps. We show that our generation methods significantly improve readability control on news summarization (CNN/DM dataset), as measured by various readability metrics and human judgement, establishing strong baselines for controllable readability in summarization.", 113 | } 114 | 115 | ``` 116 | -------------------------------------------------------------------------------- /images/read-control.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/amazon-science/controllable-readability-summarization/6ecc10458e18cf034136b6be6b07f8e1b7e8f245/images/read-control.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | transformers==4.48.0 2 | rouge-score==0.1.2 3 | accelerate==0.19.0 4 | datasets==2.12.0 5 | deepspeed==0.15.1 6 | evaluate==0.4.0 7 | py-readability-metrics==1.4.4 8 | -------------------------------------------------------------------------------- /src/inference/inference_category.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/anaconda3/etc/profile.d/conda.sh 4 | 5 | conda activate readability_summ 6 | 7 | VAL_FILE='../data/test_prompt_category.json' 8 | MODEL_PATH=$1 9 | 10 | 11 | OUTPUT_DIR='outputs/1/' 12 | CUDA_VISIBLE_DEVICES=4 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 13 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 14 | --train_file ${VAL_FILE} \ 15 | --validation_file ${VAL_FILE} \ 16 | --test_file ${VAL_FILE} \ 17 | --max_source_length 1024 \ 18 | --val_max_target_length 256 \ 19 | --max_target_length 256 \ 20 | --generation_max_length 256 \ 21 | --num_beams 3 \ 22 | --source_prefix "Write highlights for this article for a 11 years old student:\n\n" \ 23 | --evaluation_strategy "steps" \ 24 | --per_device_train_batch_size 1 \ 25 | --per_device_eval_batch_size 16 \ 26 | --predict_with_generate \ 27 | --do_predict & 28 | 29 | P1=$! 30 | 31 | 32 | OUTPUT_DIR='outputs/2/' 33 | CUDA_VISIBLE_DEVICES=5 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 34 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 35 | --train_file ${VAL_FILE} \ 36 | --validation_file ${VAL_FILE} \ 37 | --test_file ${VAL_FILE} \ 38 | --max_source_length 1024 \ 39 | --val_max_target_length 256 \ 40 | --max_target_length 256 \ 41 | --generation_max_length 256 \ 42 | --num_beams 3 \ 43 | --source_prefix "Write highlights for this article for a middle school student:\n\n" \ 44 | --evaluation_strategy "steps" \ 45 | --per_device_train_batch_size 1 \ 46 | --per_device_eval_batch_size 16 \ 47 | --predict_with_generate \ 48 | --do_predict & 49 | 50 | P2=$! 
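# Note on the pattern used throughout this script: each block launches run_summarization.py
# in the background (&) on its own GPU via CUDA_VISIBLE_DEVICES, passes the same test JSON as
# train/validation/test file (only --do_predict is run), and sets --source_prefix to the
# readability instruction (11-year-old, middle school, high school, college student). The PIDs
# are collected so all four inference jobs can be waited on at the end of the script.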
51 | 52 | 53 | OUTPUT_DIR='outputs/3/' 54 | CUDA_VISIBLE_DEVICES=6 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 55 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 56 | --train_file ${VAL_FILE} \ 57 | --validation_file ${VAL_FILE} \ 58 | --test_file ${VAL_FILE} \ 59 | --max_source_length 1024 \ 60 | --val_max_target_length 256 \ 61 | --max_target_length 256 \ 62 | --generation_max_length 256 \ 63 | --num_beams 3 \ 64 | --source_prefix "Write highlights for this article for a high school student:\n\n" \ 65 | --evaluation_strategy "steps" \ 66 | --per_device_train_batch_size 1 \ 67 | --per_device_eval_batch_size 16 \ 68 | --predict_with_generate \ 69 | --do_predict & 70 | 71 | P3=$! 72 | 73 | 74 | OUTPUT_DIR='outputs/4/' 75 | CUDA_VISIBLE_DEVICES=7 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 76 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 77 | --train_file ${VAL_FILE} \ 78 | --validation_file ${VAL_FILE} \ 79 | --test_file ${VAL_FILE} \ 80 | --max_source_length 1024 \ 81 | --val_max_target_length 256 \ 82 | --max_target_length 256 \ 83 | --generation_max_length 256 \ 84 | --num_beams 3 \ 85 | --source_prefix "Write highlights for this article for a college student:\n\n" \ 86 | --evaluation_strategy "steps" \ 87 | --per_device_train_batch_size 1 \ 88 | --per_device_eval_batch_size 16 \ 89 | --predict_with_generate \ 90 | --do_predict & 91 | 92 | P4=$! 93 | 94 | wait $P1 $P2 $P3 $P4 95 | 96 | conda deactivate -------------------------------------------------------------------------------- /src/inference/inference_lookahead.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/anaconda3/etc/profile.d/conda.sh 4 | 5 | conda activate readability_summ 6 | 7 | LOOKAHEAD_LENGTH=20 8 | DOC_FILE='../data/test_prompt_category.json' 9 | MODEL_PATH=$1 10 | 11 | 12 | PROMPT="Write highlights for this article for a 11 years old student:\n\n" 13 | OUTPUT_FILE="11yold.txt" 14 | SCORE=90 15 | CUDA_VISIBLE_DEVICES=0 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \ 16 | --prompt "${PROMPT}" --score ${SCORE} & 17 | P1=$! 18 | 19 | 20 | PROMPT="Write highlights for this article for a middle school student:\n\n" 21 | OUTPUT_FILE="middle-school.txt" 22 | SCORE=70 23 | CUDA_VISIBLE_DEVICES=1 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \ 24 | --prompt "${PROMPT}" --score ${SCORE} & 25 | P2=$! 26 | 27 | 28 | PROMPT="Write highlights for this article for a high school student:\n\n" 29 | OUTPUT_FILE="high-school.txt" 30 | SCORE=50 31 | CUDA_VISIBLE_DEVICES=2 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \ 32 | --prompt "${PROMPT}" --score ${SCORE} & 33 | P3=$! 
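# The final block below covers the college reading level. Each category prompt is paired with
# a target readability score (11 years old -> 90, middle school -> 70, high school -> 50,
# college -> 30; higher scores correspond to easier-to-read text). At each decoding step,
# run_lookahead.py greedily decodes LOOKAHEAD_LENGTH future tokens and scores that partial
# future summary to steer generation toward the requested score.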
34 | 35 | 36 | PROMPT="Write highlights for this article for a college student:\n\n" 37 | OUTPUT_FILE="college-student.txt" 38 | SCORE=30 39 | CUDA_VISIBLE_DEVICES=3 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \ 40 | --prompt "${PROMPT}" --score ${SCORE} & 41 | P4=$! 42 | 43 | wait $P1 $P2 $P3 $P4 44 | 45 | conda deactivate -------------------------------------------------------------------------------- /src/inference/inference_score.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/anaconda3/etc/profile.d/conda.sh 4 | 5 | conda activate readability_summ 6 | 7 | VAL_FILE='../data/test_prompt_score.json' 8 | MODEL_PATH=$1 9 | 10 | 11 | OUTPUT_DIR='outputs/1/' 12 | CUDA_VISIBLE_DEVICES=0 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 13 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 14 | --train_file ${VAL_FILE} \ 15 | --validation_file ${VAL_FILE} \ 16 | --test_file ${VAL_FILE} \ 17 | --max_source_length 1024 \ 18 | --source_prefix "Write highlights for this article with a flesch kincaid score of 90:\n\n" \ 19 | --evaluation_strategy "steps" \ 20 | --per_device_train_batch_size 1 \ 21 | --per_device_eval_batch_size 16 \ 22 | --predict_with_generate \ 23 | --do_predict & 24 | 25 | P1=$! 26 | 27 | 28 | OUTPUT_DIR='outputs/2/' 29 | CUDA_VISIBLE_DEVICES=1 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 30 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 31 | --train_file ${VAL_FILE} \ 32 | --validation_file ${VAL_FILE} \ 33 | --test_file ${VAL_FILE} \ 34 | --max_source_length 1024 \ 35 | --source_prefix "Write highlights for this article with a flesch kincaid score of 70:\n\n" \ 36 | --evaluation_strategy "steps" \ 37 | --per_device_train_batch_size 1 \ 38 | --per_device_eval_batch_size 16 \ 39 | --predict_with_generate \ 40 | --do_predict & 41 | 42 | P2=$! 43 | 44 | 45 | OUTPUT_DIR='outputs/3/' 46 | CUDA_VISIBLE_DEVICES=2 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 47 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 48 | --train_file ${VAL_FILE} \ 49 | --validation_file ${VAL_FILE} \ 50 | --test_file ${VAL_FILE} \ 51 | --max_source_length 1024 \ 52 | --source_prefix "Write highlights for this article with a flesch kincaid score of 50:\n\n" \ 53 | --evaluation_strategy "steps" \ 54 | --per_device_train_batch_size 1 \ 55 | --per_device_eval_batch_size 16 \ 56 | --predict_with_generate \ 57 | --do_predict & 58 | 59 | P3=$! 60 | 61 | 62 | OUTPUT_DIR='outputs/4/' 63 | CUDA_VISIBLE_DEVICES=3 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \ 64 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \ 65 | --train_file ${VAL_FILE} \ 66 | --validation_file ${VAL_FILE} \ 67 | --test_file ${VAL_FILE} \ 68 | --max_source_length 1024 \ 69 | --source_prefix "Write highlights for this article with a flesch kincaid score of 30:\n\n" \ 70 | --evaluation_strategy "steps" \ 71 | --per_device_train_batch_size 1 \ 72 | --per_device_eval_batch_size 16 \ 73 | --predict_with_generate \ 74 | --do_predict & 75 | 76 | P4=$! 
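# The four runs above prompt the model with explicit "flesch kincaid score" targets of 90, 70,
# 50 and 30, mirroring the reading-level categories used in inference_category.sh (11-year-old,
# middle school, high school and college student, respectively); higher target scores request
# easier-to-read summaries.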
77 | 78 | wait $P1 $P2 $P3 $P4 79 | 80 | conda deactivate -------------------------------------------------------------------------------- /src/inference/lookahead.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | from sys import prefix 3 | import warnings 4 | from dataclasses import dataclass 5 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union 6 | 7 | import torch 8 | import torch.distributed as dist 9 | from torch import nn 10 | import torch.nn.functional as F 11 | 12 | import copy 13 | 14 | from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint 15 | from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer 16 | from transformers.generation_logits_process import ( 17 | EncoderNoRepeatNGramLogitsProcessor, 18 | ExponentialDecayLengthPenalty, 19 | ForcedBOSTokenLogitsProcessor, 20 | ForcedEOSTokenLogitsProcessor, 21 | HammingDiversityLogitsProcessor, 22 | InfNanRemoveLogitsProcessor, 23 | LogitNormalization, 24 | LogitsProcessorList, 25 | MinLengthLogitsProcessor, 26 | NoBadWordsLogitsProcessor, 27 | NoRepeatNGramLogitsProcessor, 28 | PrefixConstrainedLogitsProcessor, 29 | RepetitionPenaltyLogitsProcessor, 30 | TemperatureLogitsWarper, 31 | TopKLogitsWarper, 32 | TopPLogitsWarper, 33 | TypicalLogitsWarper, 34 | ) 35 | from transformers.generation_stopping_criteria import ( 36 | MaxLengthCriteria, 37 | MaxTimeCriteria, 38 | StoppingCriteria, 39 | StoppingCriteriaList, 40 | validate_stopping_criteria, 41 | ) 42 | from transformers.pytorch_utils import torch_int_div 43 | from transformers.utils import ModelOutput, logging 44 | 45 | from transformers.generation_utils import ( 46 | GreedySearchEncoderDecoderOutput, 47 | GreedySearchDecoderOnlyOutput, 48 | BeamSearchEncoderDecoderOutput, 49 | BeamSearchDecoderOnlyOutput, 50 | SampleEncoderDecoderOutput, 51 | SampleDecoderOnlyOutput, 52 | ) 53 | 54 | logger = logging.get_logger(__name__) 55 | 56 | class Lookahead: 57 | """ 58 | Object that performs the lookahead. This is very similar to GenerationMixin, since it needs to decode the sequence as well, 59 | but this contains the additional function to compute heuristics score. 
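    Illustrative usage (a sketch only; the surrounding generation loop — presumably
    generation.py / run_lookahead.py in this directory — is expected to add the returned
    bonus to the model's token scores, and `scorer` is assumed to expose
    `score(texts, doc_indices) -> torch.FloatTensor`, matching how it is called in
    `score()` below):

        lookahead = Lookahead(
            model, tokenizer, scorer,
            lookahead_length=20,     # decode 20 extra tokens before scoring
            lookahead_lambda=1.0,    # weight of the lookahead bonus
            lookahead_top_k=5,       # expand only the 5 most likely next tokens
            decoding_type="greedy",
            max_length=256,
        )
        # Given the current prefix `input_ids` and the model's `next_token_scores`,
        # `score()` returns a log-space bonus: lookahead_lambda * log(heuristic score)
        # for the top-k candidate next tokens and log(1e-9) everywhere else, which the
        # caller adds to `next_token_scores` before selecting the next token.
        bonus = lookahead.score(input_ids, next_token_scores, **model_kwargs)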
60 | """ 61 | 62 | def __init__( 63 | self, 64 | model, 65 | tokenizer, 66 | scorer, 67 | lookahead_length=1, 68 | lookahead_lambda=1.0, 69 | lookahead_top_k=5, 70 | decoding_type="greedy", 71 | max_length: Optional[int] = None, 72 | min_length: Optional[int] = None, 73 | do_sample: Optional[bool] = None, 74 | early_stopping: Optional[bool] = None, 75 | num_beams: Optional[int] = None, 76 | temperature: Optional[float] = None, 77 | top_k: Optional[int] = None, 78 | top_p: Optional[float] = None, 79 | typical_p: Optional[float] = None, 80 | repetition_penalty: Optional[float] = None, 81 | bad_words_ids: Optional[Iterable[int]] = None, 82 | force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None, 83 | bos_token_id: Optional[int] = None, 84 | pad_token_id: Optional[int] = None, 85 | eos_token_id: Optional[int] = None, 86 | length_penalty: Optional[float] = None, 87 | no_repeat_ngram_size: Optional[int] = None, 88 | encoder_no_repeat_ngram_size: Optional[int] = None, 89 | num_return_sequences: Optional[int] = None, 90 | max_time: Optional[float] = None, 91 | max_new_tokens: Optional[int] = None, 92 | decoder_start_token_id: Optional[int] = None, 93 | use_cache: Optional[bool] = None, 94 | num_beam_groups: Optional[int] = None, 95 | diversity_penalty: Optional[float] = None, 96 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None, 97 | logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(), 98 | renormalize_logits: Optional[bool] = None, 99 | stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(), 100 | constraints: Optional[List[Constraint]] = None, 101 | output_attentions: Optional[bool] = None, 102 | output_hidden_states: Optional[bool] = None, 103 | output_scores: Optional[bool] = None, 104 | return_dict_in_generate: Optional[bool] = None, 105 | forced_bos_token_id: Optional[int] = None, 106 | forced_eos_token_id: Optional[int] = None, 107 | remove_invalid_values: Optional[bool] = None, 108 | synced_gpus: Optional[bool] = False, 109 | exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None, 110 | ): 111 | """ 112 | model: The Huggingface Model 113 | tokenizer: The tokenizer for decoding the summaries 114 | scorer: Scorer object that calculates the score given document and summary 115 | lookahead_length: The number of tokens to look ahead 116 | lookahead_lambda: The weight for the score 117 | lookahead_top_k: The number of top tokens to consider for expansion 118 | decoding_type: The decoding type for lookahead. 
[greedy, beam, sample] 119 | 120 | Other parameters are the same arguments expected for GenerationMixin to control the generation 121 | """ 122 | self.model = model 123 | self.tokenizer = tokenizer 124 | self.scorer = scorer 125 | 126 | if lookahead_length == -1: 127 | assert max_length is not None 128 | self.lookahead_length = max_length 129 | self.lookahead_until_sent = True 130 | else: 131 | self.lookahead_length = lookahead_length 132 | self.lookahead_until_sent = False 133 | 134 | self.lookahead_lambda = lookahead_lambda 135 | self.lookahead_top_k = lookahead_top_k 136 | self.decoding_type = decoding_type 137 | 138 | if self.decoding_type == "greedy": 139 | self.decoding_func = self.greedy_search 140 | elif self.decoding_type == "beam": 141 | self.decoding_func = self.beam_search 142 | elif self.decoding_type == "sample": 143 | self.decoding_func = self.sample 144 | 145 | # generation parameters from generate() 146 | self.bos_token_id = self.model.config.bos_token_id 147 | self.num_beams = num_beams if num_beams is not None else self.model.config.num_beams 148 | self.length_penalty = length_penalty if length_penalty is not None else self.model.config.length_penalty 149 | self.early_stopping = early_stopping if early_stopping is not None else self.model.config.early_stopping 150 | self.num_beam_groups = num_beam_groups if num_beam_groups is not None else self.model.config.num_beam_groups 151 | self.num_return_sequences = ( 152 | num_return_sequences if num_return_sequences is not None else self.model.config.num_return_sequences 153 | ) 154 | 155 | self.pad_token_id = self.model.config.pad_token_id 156 | self.eos_token_id = self.model.config.eos_token_id 157 | 158 | if self.eos_token_id is None and hasattr(self.model.config, "decoder"): 159 | self.eos_token_id = self.model.config.decoder.eos_token_id 160 | 161 | if self.pad_token_id is None and self.eos_token_id is not None: 162 | # special case if pad_token_id is not defined 163 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.") 164 | self.pad_token_id = self.eos_token_id 165 | self.max_length = max_length 166 | self.min_length = min_length 167 | self.temperature = temperature 168 | self.top_k = top_k 169 | self.top_p = top_p 170 | self.typical_p = typical_p 171 | self.reptition_penality = repetition_penalty 172 | self.bad_words_ids = bad_words_ids 173 | self.force_words_ids = force_words_ids 174 | self.no_repeat_ngram_size = no_repeat_ngram_size 175 | self.encoder_no_repeat_ngram_size = encoder_no_repeat_ngram_size 176 | self.max_new_tokens = max_new_tokens 177 | self.decoder_start_token_id = decoder_start_token_id 178 | self.use_cache = use_cache 179 | self.diversity_penalty = diversity_penalty 180 | self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn 181 | self.renormalize_logits = renormalize_logits 182 | self.contraints = constraints 183 | self.forced_bos_token_id = forced_bos_token_id 184 | self.forced_eos_token_id = forced_eos_token_id 185 | self.remove_invalid_values = remove_invalid_values 186 | self.exponential_decay_length_penalty = exponential_decay_length_penalty 187 | self.synced_gpus = synced_gpus 188 | 189 | # self.return_dict_in_generate = return_dict_in_generate 190 | self.return_dict_in_generate = True 191 | self.output_attentions = output_attentions 192 | self.output_hidden_states = output_hidden_states 193 | self.output_scores = output_scores 194 | 195 | # If not provided, logits processor will be prepared later since it requires input_tensor 196 | 
self.logits_processor = logits_processor 197 | 198 | # prepare stopping criteria 199 | self.stopping_criteria = self.model._get_stopping_criteria( 200 | max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria 201 | ) 202 | 203 | self.logits_warper = self.model._get_logits_warper( 204 | top_k=self.top_k, 205 | top_p=self.top_p, 206 | typical_p=self.typical_p, 207 | temperature=self.temperature, 208 | num_beams=self.num_beams, 209 | renormalize_logits=self.renormalize_logits, 210 | ) 211 | 212 | 213 | def score( 214 | self, 215 | input_ids, 216 | next_token_scores, 217 | num_beams=1, 218 | **model_kwargs, 219 | ): 220 | """ 221 | Main function to call for the lookahead. This function generates the sequences and return the calculated heurstics 222 | """ 223 | 224 | # prepare for generation 225 | if self.logits_processor is None: 226 | input_ids_seq_length = input_ids.size(1) 227 | inputs_tensor = model_kwargs["encoder_outputs"][self.model.main_input_name] 228 | 229 | self.logits_processor = self.model._get_logits_processor( 230 | repetition_penalty=self.repetition_penalty, 231 | no_repeat_ngram_size=self.no_repeat_ngram_size, 232 | encoder_no_repeat_ngram_size=self.encoder_no_repeat_ngram_size, 233 | input_ids_seq_length=input_ids_seq_length, 234 | encoder_input_ids=inputs_tensor, 235 | bad_words_ids=self.bad_words_ids, 236 | min_length=self.min_length, 237 | max_length=self.max_length, 238 | eos_token_id=self.eos_token_id, 239 | forced_bos_token_id=self.forced_bos_token_id, 240 | forced_eos_token_id=self.forced_eos_token_id, 241 | prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn, 242 | num_beams=self.num_beams, 243 | num_beam_groups=self.num_beam_groups, 244 | diversity_penalty=self.diversity_penalty, 245 | remove_invalid_values=self.remove_invalid_values, 246 | exponential_decay_length_penalty=self.exponential_decay_length_penalty, 247 | logits_processor=self.logits_processor, 248 | renormalize_logits=self.renormalize_logits, 249 | ) 250 | 251 | do_sample = "sample" in self.decoding_type 252 | use_beam = "beam" in self.decoding_type 253 | beam_scorer = None 254 | 255 | if use_beam: 256 | batch_size = input_ids.shape[0] * self.lookahead_top_k 257 | beam_scorer = BeamSearchScorer( 258 | batch_size=batch_size, 259 | num_beams=self.num_beams, 260 | max_length=self.stopping_criteria.max_length, 261 | device=input_ids.device, 262 | length_penalty=self.length_penalty, 263 | do_early_stopping=self.early_stopping, 264 | num_beam_hyps_to_keep=self.num_return_sequences, 265 | num_beam_groups=self.num_beam_groups, 266 | ) 267 | 268 | indices = torch.arange(input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device) 269 | 270 | # expand for top k tokens to use with scorer 271 | _, top_k_indices = torch.topk(next_token_scores, k=self.lookahead_top_k, dim=-1) 272 | top_k_indices = top_k_indices.reshape(-1) 273 | 274 | indices = indices.repeat_interleave(self.lookahead_top_k) 275 | input_ids = torch.cat([input_ids[indices],top_k_indices.unsqueeze(1)], dim=1) 276 | 277 | # adjust model_kwargs 278 | model_kwargs = self.expand_model_kwargs(model_kwargs, indices) 279 | 280 | # expand if necssary for beam, currently ignoring sampling with multiple num sequences 281 | if use_beam: 282 | input_ids, model_kwargs = self.model._expand_inputs_for_generation( 283 | input_ids, 284 | expand_size=self.num_beams, 285 | is_encoder_decoder=self.model.config.is_encoder_decoder, 286 | **model_kwargs, 287 | ) 288 | indices = indices.repeat_interleave(self.num_beams) 289 | # exapand inputs for 
generation but does not expand past 290 | if "past" in model_kwargs: 291 | model_kwargs["past"] = tuple([tuple([p.repeat_interleave(self.num_beams, dim=0) for p in past]) for past in model_kwargs["past"]]) 292 | 293 | # calling the respective decoding function 294 | # the only difference between this implementation and the original is the addition of lookahead length and breaking once that is reached 295 | if self.lookahead_length == 0: 296 | seq = input_ids 297 | else: 298 | dec_out = self.decoding_func(input_ids, beam_scorer, **model_kwargs) 299 | seq = dec_out["sequences"] 300 | 301 | # generate the actual summary 302 | dec_seq = self.tokenizer.batch_decode(seq, skip_special_tokens=True) 303 | 304 | # calculate score given the heuristics, need to account for different indices when doing beam search 305 | # import pdb 306 | # pdb.set_trace() 307 | _lookahead_scores = self.scorer.score(dec_seq, torch.div(indices, num_beams, rounding_mode="trunc")) 308 | _lookahead_scores = torch.clamp(_lookahead_scores,min=1e-9).log() 309 | 310 | _lookahead_scores = _lookahead_scores.view(-1, self.lookahead_top_k, self.num_beams) 311 | _lookahead_scores, _ = _lookahead_scores.max(-1) 312 | 313 | lookahead_scores = torch.ones_like(next_token_scores, dtype=_lookahead_scores.dtype, device=next_token_scores.device) * 1e-9 314 | lookahead_scores = lookahead_scores.log() 315 | 316 | next_token_scores = F.log_softmax(next_token_scores, dim=-1) 317 | 318 | if use_beam: 319 | # remove repat interleave for beams 320 | indices = indices.view(-1,self.num_beams)[:,0] 321 | 322 | lookahead_scores[indices, top_k_indices] = _lookahead_scores.view(-1) 323 | 324 | return self.lookahead_lambda * lookahead_scores 325 | 326 | def greedy_search( 327 | self, 328 | input_ids: torch.LongTensor, 329 | beam_scorer = None, 330 | **model_kwargs, 331 | ): 332 | # init attention / hidden states / scores tuples 333 | scores = () if (self.return_dict_in_generate and self.output_scores) else None 334 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 335 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 336 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None 337 | 338 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 339 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder: 340 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None 341 | encoder_hidden_states = ( 342 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None 343 | ) 344 | 345 | # keep track of which sequences are already finished 346 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 347 | cur_len = input_ids.shape[-1] 348 | 349 | lookahead_length = self.lookahead_length + cur_len 350 | 351 | this_peer_finished = False # used by synced_gpus only 352 | while True: 353 | 354 | if self.synced_gpus: 355 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 356 | # The following logic allows an early break if all peers finished generating their sequence 357 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 358 | # send 0.0 if we finished, 1.0 otherwise 359 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 360 | # did all peers finish? 
the reduced sum will be 0.0 then 361 | if this_peer_finished_flag.item() == 0.0: 362 | break 363 | 364 | # prepare model inputs 365 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 366 | 367 | # forward pass to get next token 368 | outputs = self.model( 369 | **model_inputs, 370 | return_dict=True, 371 | output_attentions=self.output_attentions, 372 | output_hidden_states=self.output_hidden_states, 373 | ) 374 | 375 | if self.synced_gpus and this_peer_finished: 376 | cur_len = cur_len + 1 377 | continue # don't waste resources running the code we don't need 378 | 379 | next_token_logits = outputs.logits[:, -1, :] 380 | 381 | # Store scores, attentions and hidden_states when required 382 | if self.return_dict_in_generate: 383 | if self.output_scores: 384 | scores += (next_token_logits,) 385 | if self.output_attentions: 386 | decoder_attentions += ( 387 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 388 | ) 389 | if self.model.config.is_encoder_decoder: 390 | cross_attentions += (outputs.cross_attentions,) 391 | 392 | if self.output_hidden_states: 393 | decoder_hidden_states += ( 394 | (outputs.decoder_hidden_states,) 395 | if self.model.config.is_encoder_decoder 396 | else (outputs.hidden_states,) 397 | ) 398 | 399 | # pre-process distribution 400 | next_tokens_scores = self.logits_processor(input_ids, next_token_logits) 401 | 402 | # argmax 403 | next_tokens = torch.argmax(next_tokens_scores, dim=-1) 404 | 405 | # finished sentences should have their next token be a padding token 406 | if self.eos_token_id is not None: 407 | if self.pad_token_id is None: 408 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 409 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences) 410 | 411 | # update generated ids, model inputs, and length for next step 412 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 413 | model_kwargs = self.model._update_model_kwargs_for_generation( 414 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 415 | ) 416 | cur_len = cur_len + 1 417 | 418 | # Lookahead break 419 | if cur_len >= lookahead_length: 420 | break 421 | 422 | # if eos_token was found in one sentence, set sentence to finished 423 | if self.eos_token_id is not None: 424 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long()) 425 | 426 | # stop when each sentence is finished, or if we exceed the maximum length 427 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores): 428 | if not self.synced_gpus: 429 | break 430 | else: 431 | this_peer_finished = True 432 | 433 | if self.return_dict_in_generate: 434 | if self.model.config.is_encoder_decoder: 435 | return GreedySearchEncoderDecoderOutput( 436 | sequences=input_ids, 437 | scores=scores, 438 | encoder_attentions=encoder_attentions, 439 | encoder_hidden_states=encoder_hidden_states, 440 | decoder_attentions=decoder_attentions, 441 | cross_attentions=cross_attentions, 442 | decoder_hidden_states=decoder_hidden_states, 443 | ) 444 | else: 445 | return GreedySearchDecoderOnlyOutput( 446 | sequences=input_ids, 447 | scores=scores, 448 | attentions=decoder_attentions, 449 | hidden_states=decoder_hidden_states, 450 | ) 451 | else: 452 | return input_ids 453 | 454 | def beam_search( 455 | self, 456 | input_ids: torch.LongTensor, 457 | beam_scorer = None, 458 | **model_kwargs, 459 | ): 460 
| batch_size = len(beam_scorer._beam_hyps) 461 | num_beams = beam_scorer.num_beams 462 | 463 | batch_beam_size, cur_len = input_ids.shape 464 | 465 | lookahead_length = self.lookahead_length + cur_len 466 | 467 | if num_beams * batch_size != batch_beam_size: 468 | raise ValueError( 469 | f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}." 470 | ) 471 | 472 | # init attention / hidden states / scores tuples 473 | scores = () if (self.return_dict_in_generate and self.output_scores) else None 474 | beam_indices = ( 475 | tuple(() for _ in range(batch_beam_size)) if (self.return_dict_in_generate and self.output_scores) else None 476 | ) 477 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 478 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 479 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None 480 | 481 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 482 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder: 483 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None 484 | encoder_hidden_states = ( 485 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None 486 | ) 487 | 488 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device) 489 | beam_scores[:, 1:] = -1e9 490 | beam_scores = beam_scores.view((batch_size * num_beams,)) 491 | 492 | this_peer_finished = False # used by synced_gpus only 493 | while True: 494 | 495 | if self.synced_gpus: 496 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 497 | # The following logic allows an early break if all peers finished generating their sequence 498 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 499 | # send 0.0 if we finished, 1.0 otherwise 500 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 501 | # did all peers finish? the reduced sum will be 0.0 then 502 | if this_peer_finished_flag.item() == 0.0: 503 | break 504 | 505 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 506 | 507 | outputs = self.model( 508 | **model_inputs, 509 | return_dict=True, 510 | output_attentions=self.output_attentions, 511 | output_hidden_states=self.output_hidden_states, 512 | ) 513 | 514 | if self.synced_gpus and this_peer_finished: 515 | cur_len = cur_len + 1 516 | continue # don't waste resources running the code we don't need 517 | 518 | next_token_logits = outputs.logits[:, -1, :] 519 | # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id` 520 | # cannot be generated both before and after the `nn.functional.log_softmax` operation. 
521 | next_token_logits = self.model.adjust_logits_during_generation(next_token_logits, cur_len=cur_len) 522 | next_token_scores = nn.functional.log_softmax( 523 | next_token_logits, dim=-1 524 | ) # (batch_size * num_beams, vocab_size) 525 | 526 | next_token_scores_processed = self.logits_processor(input_ids, next_token_scores) 527 | next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores) 528 | 529 | # Store scores, attentions and hidden_states when required 530 | if self.return_dict_in_generate: 531 | if self.output_scores: 532 | scores += (next_token_scores_processed,) 533 | if self.output_attentions: 534 | decoder_attentions += ( 535 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 536 | ) 537 | if self.model.config.is_encoder_decoder: 538 | cross_attentions += (outputs.cross_attentions,) 539 | 540 | if self.output_hidden_states: 541 | decoder_hidden_states += ( 542 | (outputs.decoder_hidden_states,) 543 | if self.model.config.is_encoder_decoder 544 | else (outputs.hidden_states,) 545 | ) 546 | 547 | # reshape for beam search 548 | vocab_size = next_token_scores.shape[-1] 549 | next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size) 550 | 551 | next_token_scores, next_tokens = torch.topk( 552 | next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True 553 | ) 554 | 555 | next_indices = torch_int_div(next_tokens, vocab_size) 556 | next_tokens = next_tokens % vocab_size 557 | 558 | # stateless 559 | beam_outputs = beam_scorer.process( 560 | input_ids, 561 | next_token_scores, 562 | next_tokens, 563 | next_indices, 564 | pad_token_id=self.pad_token_id, 565 | eos_token_id=self.eos_token_id, 566 | ) 567 | 568 | beam_scores = beam_outputs["next_beam_scores"] 569 | beam_next_tokens = beam_outputs["next_beam_tokens"] 570 | beam_idx = beam_outputs["next_beam_indices"] 571 | 572 | input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1) 573 | 574 | model_kwargs = self.model._update_model_kwargs_for_generation( 575 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 576 | ) 577 | if model_kwargs["past"] is not None: 578 | model_kwargs["past"] = self.model._reorder_cache(model_kwargs["past"], beam_idx) 579 | 580 | if self.return_dict_in_generate and self.output_scores: 581 | beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices)))) 582 | 583 | # increase cur_len 584 | cur_len = cur_len + 1 585 | 586 | if cur_len >= lookahead_length: 587 | break 588 | 589 | if beam_scorer.is_done or self.stopping_criteria(input_ids, scores): 590 | if not self.synced_gpus: 591 | break 592 | else: 593 | this_peer_finished = True 594 | 595 | sequence_outputs = beam_scorer.finalize( 596 | input_ids, 597 | beam_scores, 598 | next_tokens, 599 | next_indices, 600 | pad_token_id=self.pad_token_id, 601 | eos_token_id=self.eos_token_id, 602 | max_length=self.stopping_criteria.max_length, 603 | ) 604 | 605 | if self.return_dict_in_generate: 606 | if not self.output_scores: 607 | sequence_outputs["sequence_scores"] = None 608 | else: 609 | num_return_sequences = beam_scorer.num_beam_hyps_to_keep 610 | # return only as many indices as sequences 611 | beam_indices = tuple( 612 | (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size)) 613 | ) 614 | beam_indices = sum(beam_indices, ()) 615 | 616 | if self.model.config.is_encoder_decoder: 617 | return 
BeamSearchEncoderDecoderOutput( 618 | sequences=sequence_outputs["sequences"], 619 | sequences_scores=sequence_outputs["sequence_scores"], 620 | scores=scores, 621 | beam_indices=beam_indices, 622 | encoder_attentions=encoder_attentions, 623 | encoder_hidden_states=encoder_hidden_states, 624 | decoder_attentions=decoder_attentions, 625 | cross_attentions=cross_attentions, 626 | decoder_hidden_states=decoder_hidden_states, 627 | ) 628 | else: 629 | return BeamSearchDecoderOnlyOutput( 630 | sequences=sequence_outputs["sequences"], 631 | sequences_scores=sequence_outputs["sequence_scores"], 632 | scores=scores, 633 | beam_indices=beam_indices, 634 | attentions=decoder_attentions, 635 | hidden_states=decoder_hidden_states, 636 | ) 637 | else: 638 | return sequence_outputs["sequences"] 639 | 640 | def sample( 641 | self, 642 | input_ids: torch.LongTensor, 643 | beam_scorer = None, 644 | **model_kwargs, 645 | ): 646 | scores = () if (self.return_dict_in_generate and self.output_scores) else None 647 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 648 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None 649 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None 650 | 651 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states 652 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder: 653 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None 654 | encoder_hidden_states = ( 655 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None 656 | ) 657 | 658 | # keep track of which sequences are already finished 659 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1) 660 | cur_len = input_ids.shape[-1] 661 | 662 | lookahead_length = self.lookahead_length + cur_len 663 | 664 | this_peer_finished = False # used by synced_gpus only 665 | # auto-regressive generation 666 | while True: 667 | 668 | if self.synced_gpus: 669 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence. 670 | # The following logic allows an early break if all peers finished generating their sequence 671 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device) 672 | # send 0.0 if we finished, 1.0 otherwise 673 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM) 674 | # did all peers finish? 
the reduced sum will be 0.0 then 675 | if this_peer_finished_flag.item() == 0.0: 676 | break 677 | 678 | # prepare model inputs 679 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs) 680 | 681 | # forward pass to get next token 682 | outputs = self.model( 683 | **model_inputs, 684 | return_dict=True, 685 | output_attentions=self.output_attentions, 686 | output_hidden_states=self.output_hidden_states, 687 | ) 688 | 689 | if self.synced_gpus and this_peer_finished: 690 | cur_len = cur_len + 1 691 | continue # don't waste resources running the code we don't need 692 | 693 | next_token_logits = outputs.logits[:, -1, :] 694 | 695 | # pre-process distribution 696 | next_token_scores = self.logits_processor(input_ids, next_token_logits) 697 | next_token_scores = self.logits_warper(input_ids, next_token_scores) 698 | 699 | # Store scores, attentions and hidden_states when required 700 | if self.return_dict_in_generate: 701 | if self.output_scores: 702 | scores += (next_token_scores,) 703 | if self.output_attentions: 704 | decoder_attentions += ( 705 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,) 706 | ) 707 | if self.model.config.is_encoder_decoder: 708 | cross_attentions += (outputs.cross_attentions,) 709 | 710 | if self.output_hidden_states: 711 | decoder_hidden_states += ( 712 | (outputs.decoder_hidden_states,) 713 | if self.model.config.is_encoder_decoder 714 | else (outputs.hidden_states,) 715 | ) 716 | 717 | # sample 718 | probs = nn.functional.softmax(next_token_scores, dim=-1) 719 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) 720 | 721 | # finished sentences should have their next token be a padding token 722 | if self.eos_token_id is not None: 723 | if self.pad_token_id is None: 724 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") 725 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences) 726 | 727 | # update generated ids, model inputs, and length for next step 728 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1) 729 | model_kwargs = self.model._update_model_kwargs_for_generation( 730 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder 731 | ) 732 | cur_len = cur_len + 1 733 | 734 | if cur_len >= lookahead_length: 735 | break 736 | 737 | # if eos_token was found in one sentence, set sentence to finished 738 | if self.eos_token_id is not None: 739 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long()) 740 | 741 | # stop when each sentence is finished, or if we exceed the maximum length 742 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores): 743 | if not self.synced_gpus: 744 | break 745 | else: 746 | this_peer_finished = True 747 | 748 | if self.return_dict_in_generate: 749 | if self.model.config.is_encoder_decoder: 750 | return SampleEncoderDecoderOutput( 751 | sequences=input_ids, 752 | scores=scores, 753 | encoder_attentions=encoder_attentions, 754 | encoder_hidden_states=encoder_hidden_states, 755 | decoder_attentions=decoder_attentions, 756 | cross_attentions=cross_attentions, 757 | decoder_hidden_states=decoder_hidden_states, 758 | ) 759 | else: 760 | return SampleDecoderOnlyOutput( 761 | sequences=input_ids, 762 | scores=scores, 763 | attentions=decoder_attentions, 764 | hidden_states=decoder_hidden_states, 765 | ) 766 | else: 767 | return input_ids 768 | 769 | 770 | def 
expand_model_kwargs(self, model_kwargs, indices): 771 | model_kwargs = copy.deepcopy(model_kwargs) 772 | if "attention_mask" in model_kwargs: 773 | model_kwargs["attention_mask"] = model_kwargs["attention_mask"][indices] 774 | if "encoder_outputs" in model_kwargs: 775 | for k,v in model_kwargs["encoder_outputs"].items(): 776 | if v is not None: 777 | model_kwargs["encoder_outputs"][k] = v[indices] 778 | if "past" in model_kwargs: 779 | model_kwargs["past"] = tuple([tuple([p[indices] for p in past]) for past in model_kwargs["past"]]) 780 | return model_kwargs -------------------------------------------------------------------------------- /src/inference/run_lookahead.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM 2 | from scorer import FleschScorer 3 | from lookahead import Lookahead 4 | from generation import Generator 5 | from tqdm import tqdm 6 | 7 | import json 8 | 9 | import argparse 10 | 11 | def open_file(file): 12 | entities = [] 13 | 14 | for line in open(file).readlines(): 15 | entities.append(json.loads(line)) 16 | 17 | return entities 18 | 19 | parser = argparse.ArgumentParser() 20 | 21 | # base decoding model 22 | parser.add_argument("--model_name", type=str, default="facebook/bart-large-xsum") 23 | parser.add_argument("--cache_dir", type=str, default="./cache") 24 | 25 | # input output 26 | parser.add_argument("--document_file", type=str, required=True) 27 | parser.add_argument("--output_file", type=str, required=True) 28 | 29 | # base decoding configuration. Please refer to Huggingface's GenerationMixin for the explaination of the parameters 30 | parser.add_argument("--batch_size", type=int, default=8) 31 | parser.add_argument("--score", type=int, default=30) 32 | parser.add_argument("--prompt", type=str, default="") 33 | parser.add_argument("--num_beams", type=int, default=1) 34 | parser.add_argument("--num_return_sequences", type=int, default=1) 35 | parser.add_argument("--max_input_length", type=int, default=1024) 36 | parser.add_argument("--max_output_length", type=int, default=256) 37 | parser.add_argument("--do_sample", action='store_true', default=False) 38 | 39 | # lookahead configuration 40 | parser.add_argument("--do_lookahead", action="store_true", default=False) 41 | parser.add_argument("--lookahead_length", type=int, default=64) 42 | parser.add_argument("--lookahead_lambda", type=int, default=25) 43 | parser.add_argument("--top_k", type=int, default=5) 44 | parser.add_argument("--lookahead_decoding_type", type=str, default="greedy", choices=["greedy","beam","sample"]) 45 | parser.add_argument("--lookahead_beam", type=int, default=1) 46 | 47 | # scorer configuration 48 | parser.add_argument("--scorer_model_type", type=str, default="roberta-large") 49 | parser.add_argument("--scorer_num_layers", type=int, default=17) 50 | 51 | args = parser.parse_args() 52 | 53 | # loading model 54 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir) 55 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir=args.cache_dir) 56 | model = model.cuda() # can optionally call .half() for mixed precision 57 | 58 | # loading input 59 | documents = open_file(args.document_file) 60 | documents = [args.prompt + doc["input_noprompt"] for doc in documents] 61 | 62 | scorer = FleschScorer( 63 | 'flesch', 64 | args.score 65 | ) 66 | 67 | # Create lookahead 68 | lookahead = None 69 | if args.do_lookahead: 70 | lookahead = Lookahead( 71 | 
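        # The arguments below mirror the CLI flags defined above: a rollout of
        # `lookahead_length` tokens is decoded with `lookahead_decoding_type`
        # ("greedy", "beam" or "sample"), the decoded text is rated by the Flesch
        # scorer, and `lookahead_lambda` sets how strongly that rating is mixed into
        # the base model's token scores (the mixing itself lives in lookahead.py and
        # is not reproduced in this listing).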
model, 72 | tokenizer, 73 | scorer, 74 | lookahead_length=args.lookahead_length, 75 | lookahead_lambda=args.lookahead_lambda, 76 | lookahead_top_k=args.top_k, 77 | decoding_type=args.lookahead_decoding_type, 78 | num_beams=args.lookahead_beam, 79 | num_return_sequences=args.lookahead_beam, 80 | max_length=args.max_output_length, 81 | ) 82 | 83 | # Create generator with lookahead 84 | generator = Generator(model, lookahead=lookahead) 85 | 86 | summaries = [] 87 | 88 | for i in tqdm(range(0, len(documents), args.batch_size)): 89 | input_str = documents[i:i+args.batch_size] 90 | 91 | inputs = tokenizer(input_str, max_length=args.max_input_length, padding=True, truncation=True, return_tensors="pt") 92 | 93 | inputs = {k:v.cuda() for k,v in inputs.items()} 94 | 95 | output = generator.generate( 96 | input_ids = inputs["input_ids"], 97 | attention_mask=inputs["attention_mask"], 98 | num_beams=args.num_beams, 99 | num_return_sequences=args.num_return_sequences, 100 | max_length=args.max_output_length, 101 | do_sample=args.do_sample, 102 | ) 103 | 104 | output = tokenizer.batch_decode(output, skip_special_tokens=True) 105 | 106 | if args.num_return_sequences == 1: 107 | summaries += output 108 | else: 109 | for i in range(0, len(output), args.num_return_sequences): 110 | summaries.append(output[i:i+args.num_return_sequences]) 111 | 112 | # Save file 113 | with open(args.output_file, "w") as f: 114 | if args.num_return_sequences == 1: 115 | for line in summaries: 116 | f.write(line + "\n") 117 | else: 118 | json.dump(summaries, f) -------------------------------------------------------------------------------- /src/inference/run_summarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 
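# A minimal invocation sketch for the readability-controlled CNN/DailyMail data produced
# by src/preprocess (values are illustrative; the accompanying shell scripts under
# src/inference and src/train hold the actual flags and are not reproduced here):
#
#   python run_summarization.py \
#       --model_name_or_path facebook/bart-large \
#       --train_file ../data/train_prompt_category.json \
#       --validation_file ../data/validation_prompt_category.json \
#       --text_column input --summary_column summary \
#       --do_train --predict_with_generate \
#       --output_dir ./out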
20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import evaluate 33 | import transformers 34 | from filelock import FileLock 35 | from transformers import ( 36 | AutoConfig, 37 | AutoModelForSeq2SeqLM, 38 | AutoTokenizer, 39 | DataCollatorForSeq2Seq, 40 | HfArgumentParser, 41 | MBart50Tokenizer, 42 | MBart50TokenizerFast, 43 | MBartTokenizer, 44 | MBartTokenizerFast, 45 | Seq2SeqTrainer, 46 | Seq2SeqTrainingArguments, 47 | set_seed, 48 | ) 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry 51 | from transformers.utils.versions import require_version 52 | 53 | os.environ["NCCL_DEBUG"] = "INFO" 54 | 55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 56 | #check_min_version("4.25.0.dev0") 57 | 58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | try: 63 | nltk.data.find("tokenizers/punkt") 64 | except (LookupError, OSError): 65 | if is_offline_mode(): 66 | raise LookupError( 67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 68 | ) 69 | with FileLock(".lock") as lock: 70 | nltk.download("punkt", quiet=True) 71 | 72 | # A list of all multilingual tokenizer which require lang attribute. 73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 74 | 75 | 76 | @dataclass 77 | class ModelArguments: 78 | """ 79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 80 | """ 81 | 82 | model_name_or_path: str = field( 83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 84 | ) 85 | config_name: Optional[str] = field( 86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 87 | ) 88 | tokenizer_name: Optional[str] = field( 89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 90 | ) 91 | cache_dir: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 94 | ) 95 | use_fast_tokenizer: bool = field( 96 | default=True, 97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 98 | ) 99 | model_revision: str = field( 100 | default="main", 101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 102 | ) 103 | use_auth_token: bool = field( 104 | default=False, 105 | metadata={ 106 | "help": ( 107 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 108 | "with private models)." 109 | ) 110 | }, 111 | ) 112 | resize_position_embeddings: Optional[bool] = field( 113 | default=None, 114 | metadata={ 115 | "help": ( 116 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 117 | "the model's position embeddings." 
118 | ) 119 | }, 120 | ) 121 | 122 | 123 | @dataclass 124 | class DataTrainingArguments: 125 | """ 126 | Arguments pertaining to what data we are going to input our model for training and eval. 127 | """ 128 | 129 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 130 | 131 | dataset_name: Optional[str] = field( 132 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 133 | ) 134 | dataset_config_name: Optional[str] = field( 135 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 136 | ) 137 | text_column: Optional[str] = field( 138 | default=None, 139 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 140 | ) 141 | summary_column: Optional[str] = field( 142 | default=None, 143 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 144 | ) 145 | train_file: Optional[str] = field( 146 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 147 | ) 148 | validation_file: Optional[str] = field( 149 | default=None, 150 | metadata={ 151 | "help": ( 152 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 153 | ) 154 | }, 155 | ) 156 | test_file: Optional[str] = field( 157 | default=None, 158 | metadata={ 159 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 160 | }, 161 | ) 162 | overwrite_cache: bool = field( 163 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 164 | ) 165 | preprocessing_num_workers: Optional[int] = field( 166 | default=None, 167 | metadata={"help": "The number of processes to use for the preprocessing."}, 168 | ) 169 | max_source_length: Optional[int] = field( 170 | default=1024, 171 | metadata={ 172 | "help": ( 173 | "The maximum total input sequence length after tokenization. Sequences longer " 174 | "than this will be truncated, sequences shorter will be padded." 175 | ) 176 | }, 177 | ) 178 | max_target_length: Optional[int] = field( 179 | default=128, 180 | metadata={ 181 | "help": ( 182 | "The maximum total sequence length for target text after tokenization. Sequences longer " 183 | "than this will be truncated, sequences shorter will be padded." 184 | ) 185 | }, 186 | ) 187 | val_max_target_length: Optional[int] = field( 188 | default=None, 189 | metadata={ 190 | "help": ( 191 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 192 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 193 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 194 | "during ``evaluate`` and ``predict``." 195 | ) 196 | }, 197 | ) 198 | pad_to_max_length: bool = field( 199 | default=False, 200 | metadata={ 201 | "help": ( 202 | "Whether to pad all samples to model maximum sentence length. " 203 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 204 | "efficient on GPU but very bad for TPU." 205 | ) 206 | }, 207 | ) 208 | max_train_samples: Optional[int] = field( 209 | default=None, 210 | metadata={ 211 | "help": ( 212 | "For debugging purposes or quicker training, truncate the number of training examples to this " 213 | "value if set." 
214 | ) 215 | }, 216 | ) 217 | max_eval_samples: Optional[int] = field( 218 | default=None, 219 | metadata={ 220 | "help": ( 221 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 222 | "value if set." 223 | ) 224 | }, 225 | ) 226 | max_predict_samples: Optional[int] = field( 227 | default=None, 228 | metadata={ 229 | "help": ( 230 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 231 | "value if set." 232 | ) 233 | }, 234 | ) 235 | num_beams: Optional[int] = field( 236 | default=None, 237 | metadata={ 238 | "help": ( 239 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 240 | "which is used during ``evaluate`` and ``predict``." 241 | ) 242 | }, 243 | ) 244 | ignore_pad_token_for_loss: bool = field( 245 | default=True, 246 | metadata={ 247 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 248 | }, 249 | ) 250 | source_prefix: Optional[str] = field( 251 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 252 | ) 253 | 254 | forced_bos_token: Optional[str] = field( 255 | default=None, 256 | metadata={ 257 | "help": ( 258 | "The token to force as the first generated token after the decoder_start_token_id." 259 | "Useful for multilingual models like mBART where the first generated token" 260 | "needs to be the target language token (Usually it is the target language token)" 261 | ) 262 | }, 263 | ) 264 | 265 | def __post_init__(self): 266 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 267 | raise ValueError("Need either a dataset name or a training/validation file.") 268 | else: 269 | if self.train_file is not None: 270 | extension = self.train_file.split(".")[-1] 271 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 272 | if self.validation_file is not None: 273 | extension = self.validation_file.split(".")[-1] 274 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 275 | if self.val_max_target_length is None: 276 | self.val_max_target_length = self.max_target_length 277 | 278 | 279 | summarization_name_mapping = { 280 | "amazon_reviews_multi": ("review_body", "review_title"), 281 | "big_patent": ("description", "abstract"), 282 | "cnn_dailymail": ("article", "highlights"), 283 | "orange_sum": ("text", "summary"), 284 | "pn_summary": ("article", "summary"), 285 | "psc": ("extract_text", "summary_text"), 286 | "samsum": ("dialogue", "summary"), 287 | "thaisum": ("body", "summary"), 288 | "xglue": ("news_body", "news_title"), 289 | "xsum": ("document", "summary"), 290 | "wiki_summary": ("article", "highlights"), 291 | "multi_news": ("document", "summary"), 292 | } 293 | 294 | 295 | def main(): 296 | # See all possible arguments in src/transformers/training_args.py 297 | # or by passing the --help flag to this script. 298 | # We now keep distinct sets of args, for a cleaner separation of concerns. 299 | 300 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 301 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 302 | # If we pass only one argument to the script and it's the path to a json file, 303 | # let's parse it to get our arguments. 
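    # For example, a minimal `args.json` could look like this (illustrative values; any
    # field of the three dataclasses above is a valid key):
    #   {
    #     "model_name_or_path": "facebook/bart-large",
    #     "dataset_name": "cnn_dailymail",
    #     "dataset_config_name": "3.0.0",
    #     "output_dir": "./cnndm_out",
    #     "do_train": true,
    #     "predict_with_generate": true
    #   }
    # and would be consumed via `python run_summarization.py args.json`.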
304 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 305 | else: 306 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 307 | 308 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 309 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 310 | send_example_telemetry("run_summarization", model_args, data_args) 311 | print("training_args", training_args) 312 | # Setup logging 313 | logging.basicConfig( 314 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 315 | datefmt="%m/%d/%Y %H:%M:%S", 316 | handlers=[logging.StreamHandler(sys.stdout)], 317 | ) 318 | log_level = training_args.get_process_log_level() 319 | logger.setLevel(log_level) 320 | datasets.utils.logging.set_verbosity(log_level) 321 | transformers.utils.logging.set_verbosity(log_level) 322 | transformers.utils.logging.enable_default_handler() 323 | transformers.utils.logging.enable_explicit_format() 324 | 325 | # Log on each process the small summary: 326 | logger.warning( 327 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 328 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 329 | ) 330 | logger.info(f"Training/evaluation parameters {training_args}") 331 | 332 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 333 | "t5-small", 334 | "t5-base", 335 | "t5-large", 336 | "t5-3b", 337 | "t5-11b", 338 | ]: 339 | logger.warning( 340 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 341 | "`--source_prefix 'summarize: ' `" 342 | ) 343 | 344 | # Detecting last checkpoint. 345 | last_checkpoint = None 346 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 347 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 348 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 349 | raise ValueError( 350 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 351 | "Use --overwrite_output_dir to overcome." 352 | ) 353 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 354 | logger.info( 355 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 356 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 357 | ) 358 | 359 | # Set seed before initializing model. 360 | set_seed(training_args.seed) 361 | 362 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 363 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 364 | # (the dataset will be downloaded automatically from the datasets Hub). 365 | # 366 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 367 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 368 | # 369 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 370 | # download the dataset. 371 | print("data_args", data_args) 372 | if data_args.dataset_name is not None: 373 | # Downloading and loading a dataset from the hub. 
374 | raw_datasets = load_dataset( 375 | data_args.dataset_name, 376 | data_args.dataset_config_name, 377 | cache_dir=model_args.cache_dir, 378 | use_auth_token=True if model_args.use_auth_token else None, 379 | ) 380 | else: 381 | data_files = {} 382 | if data_args.train_file is not None: 383 | data_files["train"] = data_args.train_file 384 | extension = data_args.train_file.split(".")[-1] 385 | if data_args.validation_file is not None: 386 | data_files["validation"] = data_args.validation_file 387 | extension = data_args.validation_file.split(".")[-1] 388 | if data_args.test_file is not None: 389 | data_files["test"] = data_args.test_file 390 | extension = data_args.test_file.split(".")[-1] 391 | raw_datasets = load_dataset( 392 | extension, 393 | data_files=data_files, 394 | cache_dir=model_args.cache_dir, 395 | use_auth_token=True if model_args.use_auth_token else None, 396 | ) 397 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 398 | # https://huggingface.co/docs/datasets/loading_datasets.html. 399 | 400 | # Load pretrained model and tokenizer 401 | # 402 | # Distributed training: 403 | # The .from_pretrained methods guarantee that only one local process can concurrently 404 | # download model & vocab. 405 | print("model_args", model_args) 406 | config = AutoConfig.from_pretrained( 407 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 408 | cache_dir=model_args.cache_dir, 409 | revision=model_args.model_revision, 410 | use_auth_token=True if model_args.use_auth_token else None, 411 | ) 412 | tokenizer = AutoTokenizer.from_pretrained( 413 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 414 | cache_dir=model_args.cache_dir, 415 | use_fast=model_args.use_fast_tokenizer, 416 | revision=model_args.model_revision, 417 | use_auth_token=True if model_args.use_auth_token else None, 418 | ) 419 | model = AutoModelForSeq2SeqLM.from_pretrained( 420 | model_args.model_name_or_path, 421 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 422 | config=config, 423 | cache_dir=model_args.cache_dir, 424 | revision=model_args.model_revision, 425 | use_auth_token=True if model_args.use_auth_token else None, 426 | ) 427 | 428 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 429 | # on a small vocab and want a smaller embedding size, remove this test. 
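    # For the stock BART/T5 checkpoints this check is typically a no-op (the tokenizer is
    # not larger than the embedding matrix); it only triggers if extra tokens have been
    # added to the tokenizer, e.g. via `tokenizer.add_tokens([...])`.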
430 | embedding_size = model.get_input_embeddings().weight.shape[0] 431 | if len(tokenizer) > embedding_size: 432 | model.resize_token_embeddings(len(tokenizer)) 433 | 434 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 435 | if isinstance(tokenizer, MBartTokenizer): 436 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 437 | else: 438 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 439 | 440 | if model.config.decoder_start_token_id is None: 441 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 442 | 443 | if ( 444 | hasattr(model.config, "max_position_embeddings") 445 | and model.config.max_position_embeddings < data_args.max_source_length 446 | ): 447 | if model_args.resize_position_embeddings is None: 448 | logger.warning( 449 | "Increasing the model's number of position embedding vectors from" 450 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}." 451 | ) 452 | model.resize_position_embeddings(data_args.max_source_length) 453 | elif model_args.resize_position_embeddings: 454 | model.resize_position_embeddings(data_args.max_source_length) 455 | else: 456 | raise ValueError( 457 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has" 458 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing" 459 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the" 460 | " model's position encodings by passing `--resize_position_embeddings`." 461 | ) 462 | 463 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 464 | 465 | # Preprocessing the datasets. 466 | # We need to tokenize inputs and targets. 467 | if training_args.do_train: 468 | column_names = raw_datasets["train"].column_names 469 | elif training_args.do_eval: 470 | column_names = raw_datasets["validation"].column_names 471 | elif training_args.do_predict: 472 | column_names = raw_datasets["test"].column_names 473 | else: 474 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.") 475 | return 476 | 477 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 478 | assert ( 479 | data_args.lang is not None 480 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 481 | 482 | tokenizer.src_lang = data_args.lang 483 | tokenizer.tgt_lang = data_args.lang 484 | 485 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 486 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 487 | forced_bos_token_id = ( 488 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 489 | ) 490 | model.config.forced_bos_token_id = forced_bos_token_id 491 | 492 | # Get the column names for input/target. 
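    # e.g. with `--dataset_name cnn_dailymail` the mapping above yields
    # dataset_columns == ("article", "highlights"); for this repo's own preprocessed JSON
    # files the columns are instead passed explicitly, e.g.
    # `--text_column input --summary_column summary`.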
493 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 494 | if data_args.text_column is None: 495 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 496 | else: 497 | text_column = data_args.text_column 498 | if text_column not in column_names: 499 | raise ValueError( 500 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 501 | ) 502 | if data_args.summary_column is None: 503 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 504 | else: 505 | summary_column = data_args.summary_column 506 | if summary_column not in column_names: 507 | raise ValueError( 508 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 509 | ) 510 | 511 | # Temporarily set max_target_length for training. 512 | max_target_length = data_args.max_target_length 513 | padding = "max_length" if data_args.pad_to_max_length else False 514 | 515 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 516 | logger.warning( 517 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 518 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 519 | ) 520 | 521 | print(data_args) 522 | 523 | def preprocess_function(examples): 524 | # remove pairs where at least one record is None 525 | 526 | inputs, targets = [], [] 527 | for i in range(len(examples[text_column])): 528 | if examples[text_column][i] and examples[summary_column][i]: 529 | inputs.append(examples[text_column][i]) 530 | targets.append(examples[summary_column][i]) 531 | 532 | inputs = [prefix + inp for inp in inputs] 533 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 534 | 535 | # Tokenize targets with the `text_target` keyword argument 536 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 537 | 538 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 539 | # padding in the loss. 
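        # Worked example of the masking below (token ids illustrative; pad_token_id == 1 for BART):
        #   [0, 9064, 1721, 2, 1, 1]  ->  [0, 9064, 1721, 2, -100, -100]
        # so the loss, which ignores label id -100, is not computed on padding positions.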
540 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 541 | labels["input_ids"] = [ 542 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 543 | ] 544 | 545 | model_inputs["labels"] = labels["input_ids"] 546 | return model_inputs 547 | 548 | if training_args.do_train: 549 | if "train" not in raw_datasets: 550 | raise ValueError("--do_train requires a train dataset") 551 | train_dataset = raw_datasets["train"] 552 | if data_args.max_train_samples is not None: 553 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 554 | train_dataset = train_dataset.select(range(max_train_samples)) 555 | with training_args.main_process_first(desc="train dataset map pre-processing"): 556 | train_dataset = train_dataset.map( 557 | preprocess_function, 558 | batched=True, 559 | num_proc=data_args.preprocessing_num_workers, 560 | remove_columns=column_names, 561 | load_from_cache_file=not data_args.overwrite_cache, 562 | desc="Running tokenizer on train dataset", 563 | ) 564 | 565 | if training_args.do_eval: 566 | max_target_length = data_args.val_max_target_length 567 | if "validation" not in raw_datasets: 568 | raise ValueError("--do_eval requires a validation dataset") 569 | eval_dataset = raw_datasets["validation"] 570 | if data_args.max_eval_samples is not None: 571 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 572 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 573 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 574 | eval_dataset = eval_dataset.map( 575 | preprocess_function, 576 | batched=True, 577 | num_proc=data_args.preprocessing_num_workers, 578 | remove_columns=column_names, 579 | load_from_cache_file=not data_args.overwrite_cache, 580 | desc="Running tokenizer on validation dataset", 581 | ) 582 | 583 | if training_args.do_predict: 584 | max_target_length = data_args.val_max_target_length 585 | if "test" not in raw_datasets: 586 | raise ValueError("--do_predict requires a test dataset") 587 | predict_dataset = raw_datasets["test"] 588 | if data_args.max_predict_samples is not None: 589 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 590 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 591 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 592 | predict_dataset = predict_dataset.map( 593 | preprocess_function, 594 | batched=True, 595 | num_proc=data_args.preprocessing_num_workers, 596 | remove_columns=column_names, 597 | load_from_cache_file=not data_args.overwrite_cache, 598 | desc="Running tokenizer on prediction dataset", 599 | ) 600 | 601 | # Data collator 602 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 603 | data_collator = DataCollatorForSeq2Seq( 604 | tokenizer, 605 | model=model, 606 | label_pad_token_id=label_pad_token_id, 607 | pad_to_multiple_of=8 if training_args.fp16 else None, 608 | ) 609 | 610 | # Metric 611 | metric = evaluate.load("rouge") 612 | 613 | def postprocess_text(preds, labels): 614 | preds = [pred.strip() for pred in preds] 615 | labels = [label.strip() for label in labels] 616 | 617 | # rougeLSum expects newline after each sentence 618 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 619 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 620 | 621 | return preds, labels 622 | 623 | def compute_metrics(eval_preds): 624 | preds, 
labels = eval_preds 625 | if isinstance(preds, tuple): 626 | preds = preds[0] 627 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 628 | if data_args.ignore_pad_token_for_loss: 629 | # Replace -100 in the labels as we can't decode them. 630 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 631 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 632 | 633 | # Some simple post-processing 634 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 635 | 636 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 637 | result = {k: round(v * 100, 4) for k, v in result.items()} 638 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 639 | result["gen_len"] = np.mean(prediction_lens) 640 | return result 641 | 642 | # Initialize our Trainer 643 | trainer = Seq2SeqTrainer( 644 | model=model, 645 | args=training_args, 646 | train_dataset=train_dataset if training_args.do_train else None, 647 | eval_dataset=eval_dataset if training_args.do_eval else None, 648 | tokenizer=tokenizer, 649 | data_collator=data_collator, 650 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 651 | ) 652 | 653 | # Training 654 | if training_args.do_train: 655 | checkpoint = None 656 | if training_args.resume_from_checkpoint is not None: 657 | checkpoint = training_args.resume_from_checkpoint 658 | elif last_checkpoint is not None: 659 | checkpoint = last_checkpoint 660 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 661 | trainer.save_model() # Saves the tokenizer too for easy upload 662 | 663 | metrics = train_result.metrics 664 | max_train_samples = ( 665 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 666 | ) 667 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 668 | 669 | trainer.log_metrics("train", metrics) 670 | trainer.save_metrics("train", metrics) 671 | trainer.save_state() 672 | 673 | # Evaluation 674 | results = {} 675 | max_length = ( 676 | training_args.generation_max_length 677 | if training_args.generation_max_length is not None 678 | else data_args.val_max_target_length 679 | ) 680 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 681 | # if training_args.do_eval: 682 | # logger.info("*** Evaluate ***") 683 | # metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 684 | # max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 685 | # metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 686 | # 687 | # trainer.log_metrics("eval", metrics) 688 | # trainer.save_metrics("eval", metrics) 689 | 690 | if training_args.do_predict: 691 | logger.info("*** Predict ***") 692 | 693 | predict_results = trainer.predict( 694 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 695 | ) 696 | metrics = predict_results.metrics 697 | max_predict_samples = ( 698 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 699 | ) 700 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 701 | 702 | trainer.log_metrics("predict", metrics) 703 | trainer.save_metrics("predict", metrics) 704 | 705 | if trainer.is_world_process_zero(): 706 | if 
training_args.predict_with_generate: 707 | predictions = tokenizer.batch_decode( 708 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 709 | ) 710 | predictions = [pred.strip() for pred in predictions] 711 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 712 | with open(output_prediction_file, "w") as writer: 713 | writer.write("\n".join(predictions)) 714 | 715 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 716 | if data_args.dataset_name is not None: 717 | kwargs["dataset_tags"] = data_args.dataset_name 718 | if data_args.dataset_config_name is not None: 719 | kwargs["dataset_args"] = data_args.dataset_config_name 720 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 721 | else: 722 | kwargs["dataset"] = data_args.dataset_name 723 | 724 | if data_args.lang is not None: 725 | kwargs["language"] = data_args.lang 726 | 727 | if training_args.push_to_hub: 728 | trainer.push_to_hub(**kwargs) 729 | else: 730 | trainer.create_model_card(**kwargs) 731 | 732 | return results 733 | 734 | 735 | def _mp_fn(index): 736 | # For xla_spawn (TPUs) 737 | main() 738 | 739 | 740 | if __name__ == "__main__": 741 | main() 742 | -------------------------------------------------------------------------------- /src/inference/scorer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import AutoModel, AutoTokenizer 5 | 6 | class BERTScoreScorer: 7 | """ 8 | Scorer using BS-Fact, code adapted from bertscore official repo: https://github.com/Tiiiger/bert_score 9 | """ 10 | def __init__(self, model_name="roberta-large", device="cuda", num_layers=17, cache_dir=".cache"): 11 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir) 12 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 13 | # We assume we are using roberta-large, please reference https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L247 14 | # if you wish to use other model and select the recommended layer 15 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 16 | 17 | self.model = model.to(device) 18 | self.device = device 19 | 20 | def prepare_document(self, input_str): 21 | """ 22 | Prepare anything that requires processing on document. 23 | This is called each iteration only once to save computation. 24 | """ 25 | self.bertscore_input_embedding, self.bertscore_input_attention_mask, self.bertscore_input_idf = self.encode_text(input_str) 26 | 27 | def score(self, summaries, index): 28 | """ 29 | Output the score for each example. 30 | summaries: The summary strings 31 | index: The indice of example (document that it should be compared to). IT should ideally be just range() except for beam search. 
32 | """ 33 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = self.encode_text(summaries) 34 | 35 | bertscore_input_embedding = self.bertscore_input_embedding[index] 36 | bertscore_input_attention_mask = self.bertscore_input_attention_mask[index] 37 | bertscore_input_idf = self.bertscore_input_idf[index] 38 | 39 | bertscore_scores = self.compute_bertscore( 40 | bertscore_input_embedding, 41 | bertscore_input_attention_mask, 42 | bertscore_input_idf, 43 | bertscore_output_embedding, 44 | bertscore_output_attention_mask, 45 | bertscore_output_idf, 46 | ) 47 | return bertscore_scores 48 | 49 | def encode_text(self, input_str): 50 | """ 51 | Helper function to encode any string to tensor using the tokenizer 52 | """ 53 | inputs = self.tokenizer(input_str, padding=True, truncation=True, return_tensors="pt") 54 | inputs = {k:v.to(self.device) for k,v in inputs.items()} 55 | with torch.no_grad(): 56 | outputs = self.model(**inputs) 57 | 58 | # idf 59 | idf = torch.clone(inputs["attention_mask"]).float() 60 | idf[idf == self.tokenizer.sep_token_id] = 0 61 | idf[idf == self.tokenizer.cls_token_id] = 0 62 | idf.div_(idf.sum(dim=1, keepdim=True)) 63 | 64 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf 65 | 66 | def compute_bertscore(self, doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf): 67 | """ 68 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469 69 | """ 70 | 71 | batch_size = doc_embedding.size(0) 72 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2)) 73 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float()) 74 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 75 | 76 | masks = masks.float().to(sim.device) 77 | sim = sim * masks 78 | 79 | precision = sim.max(dim=2)[0] 80 | precision_scale = summ_idf.to(precision.device) 81 | P = (precision * precision_scale).sum(dim=1) 82 | 83 | summ_zero_mask = summ_masks.sum(dim=1).eq(2) 84 | if torch.any(summ_zero_mask): 85 | P = P.masked_fill(summ_zero_mask, 0.0) 86 | 87 | doc_zero_mask = doc_masks.sum(dim=1).eq(2) 88 | if torch.any(doc_zero_mask): 89 | P = P.masked_fill(doc_zero_mask, 0.0) 90 | 91 | return P 92 | 93 | from readability import Readability 94 | 95 | def get_flesch_kincaid(text): 96 | r = Readability(text) 97 | fk = r.flesch_kincaid() 98 | return fk.score 99 | 100 | 101 | def get_flesch(text): 102 | r = Readability(text) 103 | f = r.flesch() 104 | return f.score 105 | 106 | class FleschScorer: 107 | """ 108 | Scorer using BS-Fact, code adapted from bertscore official repo: https://github.com/Tiiiger/bert_score 109 | """ 110 | 111 | def __init__(self, name_module, flesch_score, device="cuda"): 112 | self.name_module = name_module 113 | self.flesch_score = flesch_score 114 | self.device = device 115 | 116 | def score(self, summaries, index): 117 | """ 118 | Output the score for each example. 119 | summaries: The summary strings 120 | index: The indice of example (document that it should be compared to). IT should ideally be just range() except for beam search. 
121 | """ 122 | 123 | flesch_scores = [] 124 | for text in summaries: 125 | try: 126 | flesch_scores.append(get_flesch(text)) 127 | except: 128 | flesch_scores.append(100) 129 | flesch_scores = [1 - (abs(fs - self.flesch_score) / 100) for fs in flesch_scores] 130 | 131 | return torch.tensor(flesch_scores).to(self.device) 132 | 133 | 134 | class BERTandFleschScoreScorer: 135 | """ 136 | Scorer using BS-Fact, code adapted from bertscore official repo: https://github.com/Tiiiger/bert_score 137 | """ 138 | 139 | def __init__(self, model_name="roberta-large", device="cuda", num_layers=17, cache_dir=".cache", flesch_score=50, 140 | readability_weight=0.8): 141 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir) 142 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 143 | # We assume we are using roberta-large, please reference https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L247 144 | # if you wish to use other model and select the recommended layer 145 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 146 | 147 | self.model = model.to(device) 148 | self.device = device 149 | self.flesch_score = flesch_score 150 | self.readability_weight = readability_weight 151 | 152 | def prepare_document(self, input_str): 153 | """ 154 | Prepare anything that requires processing on document. 155 | This is called each iteration only once to save computation. 156 | """ 157 | self.bertscore_input_embedding, self.bertscore_input_attention_mask, self.bertscore_input_idf = self.encode_text( 158 | input_str) 159 | 160 | def score(self, summaries, index): 161 | """ 162 | Output the score for each example. 163 | summaries: The summary strings 164 | index: The indice of example (document that it should be compared to). IT should ideally be just range() except for beam search. 
165 | """ 166 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = self.encode_text(summaries) 167 | 168 | bertscore_input_embedding = self.bertscore_input_embedding[index] 169 | bertscore_input_attention_mask = self.bertscore_input_attention_mask[index] 170 | bertscore_input_idf = self.bertscore_input_idf[index] 171 | 172 | bertscore_scores = self.compute_bertscore( 173 | bertscore_input_embedding, 174 | bertscore_input_attention_mask, 175 | bertscore_input_idf, 176 | bertscore_output_embedding, 177 | bertscore_output_attention_mask, 178 | bertscore_output_idf, 179 | ) 180 | 181 | flesch_scores = [] 182 | for text in summaries: 183 | try: 184 | flesch_scores.append(get_flesch(text)) 185 | except: 186 | flesch_scores.append(100) 187 | flesch_scores = [1 - (abs(fs - self.flesch_score) / 100) for fs in flesch_scores] 188 | 189 | flesch_scores = torch.tensor(flesch_scores).to(self.device) 190 | assert flesch_scores.size() == bertscore_scores.size() 191 | 192 | # import pdb 193 | # pdb.set_trace() 194 | 195 | return self.readability_weight * flesch_scores + (1 - self.readability_weight) * bertscore_scores 196 | 197 | def encode_text(self, input_str): 198 | """ 199 | Helper function to encode any string to tensor using the tokenizer 200 | """ 201 | inputs = self.tokenizer(input_str, padding=True, truncation=True, return_tensors="pt") 202 | inputs = {k: v.to(self.device) for k, v in inputs.items()} 203 | with torch.no_grad(): 204 | outputs = self.model(**inputs) 205 | 206 | # idf 207 | idf = torch.clone(inputs["attention_mask"]).float() 208 | idf[idf == self.tokenizer.sep_token_id] = 0 209 | idf[idf == self.tokenizer.cls_token_id] = 0 210 | idf.div_(idf.sum(dim=1, keepdim=True)) 211 | 212 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf 213 | 214 | def compute_bertscore(self, doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf): 215 | """ 216 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469 217 | """ 218 | 219 | batch_size = doc_embedding.size(0) 220 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2)) 221 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float()) 222 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 223 | 224 | masks = masks.float().to(sim.device) 225 | sim = sim * masks 226 | 227 | precision = sim.max(dim=2)[0] 228 | precision_scale = summ_idf.to(precision.device) 229 | P = (precision * precision_scale).sum(dim=1) 230 | 231 | summ_zero_mask = summ_masks.sum(dim=1).eq(2) 232 | if torch.any(summ_zero_mask): 233 | P = P.masked_fill(summ_zero_mask, 0.0) 234 | 235 | doc_zero_mask = doc_masks.sum(dim=1).eq(2) 236 | if torch.any(doc_zero_mask): 237 | P = P.masked_fill(doc_zero_mask, 0.0) 238 | 239 | return P -------------------------------------------------------------------------------- /src/preprocess/generate_prompts_category.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | def open_txt_file(file): 5 | entities = [] 6 | 7 | for line in open(file).readlines(): 8 | entities.append(line) 9 | 10 | return entities 11 | 12 | 13 | def open_file(file): 14 | entities = [] 15 | 16 | for line in open(file).readlines(): 17 | entities.append(json.loads(line)) 18 | 19 | return entities 20 | 21 | 22 | def save_file(data, file): 23 | 
file_writer = open(file, 'w') 24 | 25 | for line in data: 26 | file_writer.write(json.dumps(line) + "\n") 27 | 28 | 29 | def get_prompt(flesch_summary): 30 | if flesch_summary >= 80: 31 | prompt = 'Write highlights for this article for a 11 years old student:\n\n' 32 | elif 80 > flesch_summary >= 60: 33 | prompt = 'Write highlights for this article for a middle school student:\n\n' 34 | elif 60 > flesch_summary >= 40: 35 | prompt = 'Write highlights for this article for a high school student:\n\n' 36 | else: 37 | prompt = 'Write highlights for this article for a college student:\n\n' 38 | return prompt 39 | 40 | 41 | def transform_data(split): 42 | data = open_file('../data/' + split + '.json') 43 | new_data = [] 44 | 45 | for entry in tqdm(data): 46 | 47 | flesch_summary = entry["summary_metrics"]["flesch"] 48 | 49 | prompt = get_prompt(flesch_summary) 50 | entry["prompt"] = prompt 51 | entry["input_noprompt"] = entry["input"] 52 | entry["input"] = prompt + entry["input"] 53 | new_data.append(entry) 54 | 55 | save_file(new_data, '../data/' + split + '_prompt_category.json') 56 | 57 | 58 | transform_data('train') 59 | transform_data('validation') 60 | transform_data('test') 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | -------------------------------------------------------------------------------- /src/preprocess/generate_prompts_score.py: -------------------------------------------------------------------------------- 1 | import json 2 | from tqdm import tqdm 3 | 4 | def open_txt_file(file): 5 | entities = [] 6 | 7 | for line in open(file).readlines(): 8 | entities.append(line) 9 | 10 | return entities 11 | 12 | 13 | def open_file(file): 14 | entities = [] 15 | 16 | for line in open(file).readlines(): 17 | entities.append(json.loads(line)) 18 | 19 | return entities 20 | 21 | 22 | def save_file(data, file): 23 | file_writer = open(file, 'w') 24 | 25 | for line in data: 26 | file_writer.write(json.dumps(line) + "\n") 27 | 28 | 29 | def get_prompt(flesch_summary): 30 | prompt = 'Write highlights for this article with a flesch kincaid score of ' + str( 31 | int(round(flesch_summary, 0))) + ":\n\n" 32 | return prompt 33 | 34 | 35 | def transform_data(split): 36 | data = open_file('../data/' + split + '.json') 37 | new_data = [] 38 | 39 | for entry in tqdm(data): 40 | 41 | flesch_summary = entry["summary_metrics"]["flesch"] 42 | flesch_input = entry["input_metrics"]["flesch"] 43 | 44 | prompt = get_prompt(flesch_summary) 45 | entry["prompt"] = prompt 46 | entry["input_noprompt"] = entry["input"] 47 | entry["input"] = prompt + entry["input"] 48 | 49 | if split == 'test' and flesch_input >= 50: 50 | continue 51 | new_data.append(entry) 52 | 53 | 54 | save_file(new_data, '../data/' + split + '_prompt_score.json') 55 | 56 | 57 | transform_data('train') 58 | transform_data('validation') 59 | transform_data('test') 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | -------------------------------------------------------------------------------- /src/preprocess/preprocess_cnndm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from readability import Readability 3 | from datasets import load_dataset 4 | 5 | # Download the CNNDM data (e.g., from https://huggingface.co/datasets/cnn_dailymail) 6 | dataset = '' 7 | 8 | def open_txt_file(file): 9 | entities = [] 10 | 11 | for line in open(file).readlines(): 12 | entities.append(line) 13 | 14 | return entities 15 | 16 | 17 | def open_file(file): 18 | entities = [] 19 | 20 | for line in open(file).readlines(): 21 | 
entities.append(json.loads(line)) 22 | 23 | return entities 24 | 25 | def save_file(data, file): 26 | file_writer = open(file, 'w') 27 | 28 | for line in data: 29 | file_writer.write(json.dumps(line) + "\n") 30 | 31 | def get_flesch_kincaid(text): 32 | r = Readability(text) 33 | fk = r.flesch_kincaid() 34 | return fk.score 35 | 36 | 37 | def get_flesch(text): 38 | r = Readability(text) 39 | f = r.flesch() 40 | return f.score 41 | 42 | 43 | def get_dale_chall(text): 44 | r = Readability(text) 45 | dc = r.dale_chall() 46 | return dc.score 47 | 48 | 49 | def get_ari(text): 50 | r = Readability(text) 51 | ari = r.ari() 52 | return ari.score 53 | 54 | 55 | def get_coleman_liau(text): 56 | r = Readability(text) 57 | cl = r.coleman_liau() 58 | return cl.score 59 | 60 | 61 | def get_gunning_fog(text): 62 | r = Readability(text) 63 | gf = r.gunning_fog() 64 | return gf.score 65 | 66 | 67 | def get_smog(text): 68 | r = Readability(text) 69 | s = r.smog() 70 | return s.score 71 | 72 | 73 | def get_spache(text): 74 | r = Readability(text) 75 | s = r.spache() 76 | return s.score 77 | 78 | def get_linsear_write(text): 79 | r = Readability(text) 80 | lw = r.linsear_write() 81 | return lw.score 82 | 83 | 84 | def compute_metrics(text): 85 | metrics = {} 86 | flesch = get_flesch(text) 87 | metrics['flesch'] = round(flesch, 4) 88 | 89 | dale_chall = get_dale_chall(text) 90 | metrics['dale_chall'] = round(dale_chall, 4) 91 | 92 | coleman_liau = get_coleman_liau(text) 93 | metrics['coleman_liau'] = round(coleman_liau, 4) 94 | 95 | gunning_fog = get_gunning_fog(text) 96 | metrics['gunning_fog'] = round(gunning_fog, 4) 97 | 98 | return metrics 99 | 100 | 101 | def process_data(split): 102 | data = [] 103 | for idx, (dial, sum, id_) in enumerate(zip(dataset[split]['article'], dataset[split]['highlights'], dataset[split]['id'])): 104 | entry = {} 105 | entry['input'] = dial 106 | metrics = compute_metrics(entry["input"]) 107 | entry['input_metrics'] = metrics 108 | 109 | entry['summary'] = sum 110 | entry['id'] = str(id_) 111 | metrics = compute_metrics(entry["summary"].replace("\n", " ")) 112 | entry['summary_metrics'] = metrics 113 | data.append(entry) 114 | 115 | save_file(data, 'data/' + split + '.json') 116 | 117 | 118 | process_data('train') 119 | process_data('validation') 120 | process_data('test') 121 | 122 | 123 | 124 | -------------------------------------------------------------------------------- /src/train/ds_config_stage3_fb16.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": true 4 | }, 5 | 6 | "zero_optimization": { 7 | "stage": 2, 8 | "allgather_partitions": true, 9 | "allgather_bucket_size": 2e8, 10 | "overlap_comm": true, 11 | "reduce_scatter": true, 12 | "reduce_bucket_size": 2e8, 13 | "contiguous_gradients": true 14 | }, 15 | "train_batch_size": "auto", 16 | "train_micro_batch_size_per_gpu": "auto", 17 | "zero_allow_untested_optimizer": true, 18 | 19 | "optimizer": { 20 | "type": "AdamW", 21 | "params": { 22 | "lr": 1e-4, 23 | "betas": [ 24 | 0.9, 25 | 0.999 26 | ], 27 | "eps": 1e-8, 28 | "weight_decay": 0.0 29 | } 30 | }, 31 | 32 | "scheduler": { 33 | "type": "WarmupDecayLR", 34 | "params": { 35 | "total_num_steps": "auto", 36 | "warmup_min_lr": "auto", 37 | "warmup_max_lr": "auto", 38 | "warmup_num_steps": "auto" 39 | } 40 | }, 41 | 42 | "steps_per_print": 30, 43 | "wall_clock_breakdown": false 44 | } -------------------------------------------------------------------------------- /src/train/rl/accelerate_config.yaml: 
-------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: DEEPSPEED 3 | downcast_bf16: 'no' 4 | gpu_ids: 0,1,2,3,4,5,6,7 5 | machine_rank: 0 6 | main_training_function: main 7 | mixed_precision: bf16 8 | num_machines: 1 9 | num_processes: 8 10 | rdzv_backend: static 11 | same_network: true 12 | main_process_port: 61001 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | deepspeed_config: 18 | gradient_accumulation_steps: 1 19 | gradient_clipping: 1.0 20 | offload_optimizer_device: none 21 | offload_param_device: none 22 | zero3_init_flag: true 23 | zero_stage: 2 24 | 25 | 26 | -------------------------------------------------------------------------------- /src/train/rl/train.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from datasets import load_dataset 3 | from tqdm import tqdm 4 | from transformers import AutoTokenizer 5 | from readability import Readability 6 | import numpy as np 7 | import sys 8 | eps = sys.float_info.epsilon 9 | import math 10 | 11 | import trlx 12 | from trlx.data.configs import ( 13 | ModelConfig, 14 | OptimizerConfig, 15 | SchedulerConfig, 16 | TokenizerConfig, 17 | TrainConfig, 18 | TRLConfig, 19 | ) 20 | from trlx.models.modeling_ppo import PPOConfig 21 | 22 | model_dir = 'checkpoints/exec-XXXX' # select the checkpoint from the prompt-based methods 23 | 24 | config = TRLConfig( 25 | train=TrainConfig( 26 | seq_length=1024, 27 | epochs=500, 28 | total_steps=100000, 29 | batch_size=2, 30 | batch_size_eval=2, 31 | checkpoint_interval=10000, 32 | eval_interval=500, 33 | save_optimizer=False, 34 | pipeline="PromptPipeline", 35 | trainer="AcceleratePPOTrainer", 36 | checkpoint_dir='checkpoint-diverse', 37 | save_best=True 38 | ), 39 | model=ModelConfig( 40 | model_path=model_dir, 41 | model_arch_type="seq2seq", 42 | num_layers_unfrozen=-1, 43 | ), 44 | tokenizer=TokenizerConfig( 45 | tokenizer_path=model_dir, 46 | truncation_side="right", 47 | ), 48 | optimizer=OptimizerConfig( 49 | name="adamw", 50 | kwargs={ 51 | "lr": 1.0e-5, 52 | "betas": [0.9, 0.999], 53 | "eps": 1.0e-8, 54 | "weight_decay": 1.0e-6, 55 | }, 56 | ), 57 | scheduler=SchedulerConfig( 58 | name="cosine_annealing", 59 | kwargs={ 60 | "T_max": 10000, 61 | "eta_min": 1.0e-6, 62 | }, 63 | ), 64 | method=PPOConfig( 65 | name="PPOConfig", 66 | num_rollouts=512, 67 | chunk_size=4, 68 | ppo_epochs=4, 69 | init_kl_coef=0.05, 70 | target=6, 71 | horizon=10000, 72 | gamma=0.99, 73 | lam=0.95, 74 | cliprange=0.2, 75 | cliprange_value=0.2, 76 | vf_coef=1.0, 77 | scale_reward=None, 78 | ref_mean=None, 79 | ref_std=None, 80 | cliprange_reward=10, 81 | gen_kwargs={ 82 | "max_new_tokens": 256, 83 | }, 84 | gen_experience_kwargs={ 85 | "max_new_tokens": 256, 86 | "do_sample": True, 87 | "temperature": 1.0, 88 | "top_k": 50, 89 | "top_p": 0.95, 90 | }, 91 | ), 92 | ) 93 | 94 | 95 | def get_flesch_kincaid(text): 96 | r = Readability(text) 97 | fk = r.flesch_kincaid() 98 | return fk.score 99 | 100 | 101 | def get_flesch(text): 102 | r = Readability(text) 103 | f = r.flesch() 104 | return f.score 105 | 106 | import random 107 | 108 | def change_scores(input_data): 109 | new_data = [] 110 | for text in input_data: 111 | score_sum = random.choice([10, 15, 25, 30, 33, 35, 37, 40, 45, 48, 50, 52, 60, 64, 68, 70, 71, 75, 83, 84, 88, 89, 90, 92, 93, 94, 95]) 112 | new_text = "Write highlights for this article with a flesch kincaid 
score of " + str(score_sum) + ":\n\n" + text 113 | new_data.append(new_text) 114 | return new_data 115 | 116 | sigma = 10 117 | def calc_nd(value, mean): 118 | return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(- (value - mean) ** 2 / (2 * sigma ** 2)) / 0.039894228040143274 119 | 120 | 121 | import torch 122 | import torch.nn as nn 123 | import torch.nn.functional as F 124 | from transformers import AutoModel, AutoTokenizer 125 | import os 126 | model_name = "roberta-large" 127 | device = "cuda:" + str(os.environ.get('LOCAL_RANK',0)) 128 | num_layers = 17 129 | cache_dir=".cache" 130 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir) 131 | model = model.to(device) 132 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir) 133 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]]) 134 | 135 | def encode_text(input_str): 136 | """ 137 | Helper function to encode any string to tensor using the tokenizer 138 | """ 139 | inputs = tokenizer(input_str, padding=True, truncation=True, return_tensors="pt") 140 | inputs = {k: v.to(device) for k, v in inputs.items()} 141 | with torch.no_grad(): 142 | outputs = model(**inputs) 143 | 144 | # idf 145 | idf = torch.clone(inputs["attention_mask"]).float() 146 | idf[idf == tokenizer.sep_token_id] = 0 147 | idf[idf == tokenizer.cls_token_id] = 0 148 | idf.div_(idf.sum(dim=1, keepdim=True)) 149 | 150 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf 151 | 152 | def compute_bertscore(doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf): 153 | """ 154 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469 155 | """ 156 | 157 | batch_size = doc_embedding.size(0) 158 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2)) 159 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float()) 160 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim) 161 | 162 | masks = masks.float().to(sim.device) 163 | sim = sim * masks 164 | 165 | precision = sim.max(dim=2)[0] 166 | precision_scale = summ_idf.to(precision.device) 167 | P = (precision * precision_scale).sum(dim=1) 168 | 169 | summ_zero_mask = summ_masks.sum(dim=1).eq(2) 170 | if torch.any(summ_zero_mask): 171 | P = P.masked_fill(summ_zero_mask, 0.0) 172 | 173 | doc_zero_mask = doc_masks.sum(dim=1).eq(2) 174 | if torch.any(doc_zero_mask): 175 | P = P.masked_fill(doc_zero_mask, 0.0) 176 | 177 | return P 178 | 179 | 180 | if __name__ == "__main__": 181 | 182 | def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]): 183 | 184 | flesch_scores = [] 185 | original_scores = [] 186 | summaries = [] 187 | docs = [] 188 | for (generated_summary, input_doc) in zip(outputs, prompts): 189 | score_sum = int(input_doc.split("Write highlights for this article with a flesch kincaid score of ")[1][:2].replace(":", "")) 190 | original_scores.append(score_sum) 191 | doc = input_doc.split("Write highlights for this article with a flesch kincaid score of ")[1][2:] 192 | docs.append(doc) 193 | summaries.append(generated_summary.strip()) 194 | 195 | try: 196 | flesch_scores.append(get_flesch(generated_summary.strip())) 197 | except: 198 | flesch_scores.append(0) 199 | 200 | all_bertscore_scores = [] 201 | for doc, summary in zip(docs, summaries): 202 | 203 | bertscore_input_embedding, bertscore_input_attention_mask, 
bertscore_input_idf = encode_text([doc]) 204 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = encode_text([summary]) 205 | 206 | bertscore_scores = compute_bertscore( 207 | bertscore_input_embedding, 208 | bertscore_input_attention_mask, 209 | bertscore_input_idf, 210 | bertscore_output_embedding, 211 | bertscore_output_attention_mask, 212 | bertscore_output_idf, 213 | ) 214 | bertscore_scores = bertscore_scores.tolist() 215 | all_bertscore_scores.extend(bertscore_scores) 216 | 217 | assert len(original_scores) == len(flesch_scores) == len(all_bertscore_scores) 218 | 219 | flesch_scores = [calc_nd(fs, o_fs) for fs, o_fs in zip(flesch_scores, original_scores)] 220 | 221 | readability_weight = 1 222 | flesch_scores = torch.tensor(flesch_scores) 223 | all_bertscore_scores = torch.tensor(all_bertscore_scores) 224 | flesch_scores = readability_weight * flesch_scores + (1 - readability_weight) * all_bertscore_scores 225 | flesch_scores = flesch_scores.tolist() 226 | 227 | return flesch_scores 228 | 229 | 230 | train_file = '../../data/train_prompt_score.json' 231 | validation_file = '../../data/train_prompt_score.json' 232 | data_files = {"train": train_file, "validation": validation_file} 233 | dataset = load_dataset("json", data_files=data_files) 234 | dataset['train'] = dataset['train'].shuffle(seed=42) 235 | dataset['validation'] = dataset['validation'].shuffle(seed=42) 236 | 237 | validation_examples = 2000 238 | val_prompts = [prompt for prompt in dataset['validation']["input_noprompt"][0:validation_examples]] 239 | print('\ntest 0\n', val_prompts[0]) 240 | val_summaries = dataset['validation']["summary"][0:validation_examples] 241 | val_prompts = change_scores(val_prompts) 242 | assert len(val_prompts) == len(val_summaries) 243 | print('\ntest after 0 \n', val_prompts[0]) 244 | 245 | prompts = dataset['train']["input_noprompt"] 246 | summaries = dataset['train']["summary"] 247 | prompts = [prompt for prompt in prompts] 248 | prompts = change_scores(prompts) 249 | assert len(prompts) == len(summaries) 250 | 251 | # make dictionary of prompts and labels to use for reward function 252 | tokenizer = AutoTokenizer.from_pretrained(config.model.model_path) 253 | tokenizer.padding_side = "left" 254 | tokenizer.truncation_side = "right" 255 | tokenizer.sep_token = "" 256 | prompt_label = {} 257 | max_length = config.train.seq_length 258 | 259 | trlx.train( 260 | reward_fn=reward_fn, 261 | prompts=prompts, 262 | eval_prompts=val_prompts, 263 | config=config, 264 | ) 265 | -------------------------------------------------------------------------------- /src/train/rl/train_rl_cnndm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/anaconda3/etc/profile.d/conda.sh 4 | 5 | conda activate readability_summ 6 | export TOKENIZERS_PARALLELISM=true 7 | 8 | accelerate launch --config_file accelerate_config.yaml train.py 9 | 10 | conda deactivate -------------------------------------------------------------------------------- /src/train/run_summarization.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 
7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | 21 | import logging 22 | import os 23 | import sys 24 | from dataclasses import dataclass, field 25 | from typing import Optional 26 | 27 | import datasets 28 | import nltk # Here to have a nice missing dependency error message early on 29 | import numpy as np 30 | from datasets import load_dataset 31 | 32 | import evaluate 33 | import transformers 34 | from filelock import FileLock 35 | from transformers import ( 36 | AutoConfig, 37 | AutoModelForSeq2SeqLM, 38 | AutoTokenizer, 39 | DataCollatorForSeq2Seq, 40 | HfArgumentParser, 41 | MBart50Tokenizer, 42 | MBart50TokenizerFast, 43 | MBartTokenizer, 44 | MBartTokenizerFast, 45 | Seq2SeqTrainer, 46 | Seq2SeqTrainingArguments, 47 | set_seed, 48 | ) 49 | from transformers.trainer_utils import get_last_checkpoint 50 | from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry 51 | from transformers.utils.versions import require_version 52 | 53 | os.environ["NCCL_DEBUG"] = "INFO" 54 | 55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 56 | #check_min_version("4.25.0.dev0") 57 | 58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt") 59 | 60 | logger = logging.getLogger(__name__) 61 | 62 | try: 63 | nltk.data.find("tokenizers/punkt") 64 | except (LookupError, OSError): 65 | if is_offline_mode(): 66 | raise LookupError( 67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files" 68 | ) 69 | with FileLock(".lock") as lock: 70 | nltk.download("punkt", quiet=True) 71 | 72 | # A list of all multilingual tokenizer which require lang attribute. 73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast] 74 | 75 | 76 | @dataclass 77 | class ModelArguments: 78 | """ 79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
80 | """ 81 | 82 | model_name_or_path: str = field( 83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 84 | ) 85 | config_name: Optional[str] = field( 86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 87 | ) 88 | tokenizer_name: Optional[str] = field( 89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 90 | ) 91 | cache_dir: Optional[str] = field( 92 | default=None, 93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 94 | ) 95 | use_fast_tokenizer: bool = field( 96 | default=True, 97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 98 | ) 99 | model_revision: str = field( 100 | default="main", 101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 102 | ) 103 | use_auth_token: bool = field( 104 | default=False, 105 | metadata={ 106 | "help": ( 107 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 108 | "with private models)." 109 | ) 110 | }, 111 | ) 112 | resize_position_embeddings: Optional[bool] = field( 113 | default=None, 114 | metadata={ 115 | "help": ( 116 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 117 | "the model's position embeddings." 118 | ) 119 | }, 120 | ) 121 | 122 | 123 | @dataclass 124 | class DataTrainingArguments: 125 | """ 126 | Arguments pertaining to what data we are going to input our model for training and eval. 127 | """ 128 | 129 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."}) 130 | 131 | dataset_name: Optional[str] = field( 132 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."} 133 | ) 134 | dataset_config_name: Optional[str] = field( 135 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} 136 | ) 137 | text_column: Optional[str] = field( 138 | default=None, 139 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."}, 140 | ) 141 | summary_column: Optional[str] = field( 142 | default=None, 143 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."}, 144 | ) 145 | train_file: Optional[str] = field( 146 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 147 | ) 148 | validation_file: Optional[str] = field( 149 | default=None, 150 | metadata={ 151 | "help": ( 152 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 153 | ) 154 | }, 155 | ) 156 | test_file: Optional[str] = field( 157 | default=None, 158 | metadata={ 159 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)." 160 | }, 161 | ) 162 | overwrite_cache: bool = field( 163 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 164 | ) 165 | preprocessing_num_workers: Optional[int] = field( 166 | default=None, 167 | metadata={"help": "The number of processes to use for the preprocessing."}, 168 | ) 169 | max_source_length: Optional[int] = field( 170 | default=1024, 171 | metadata={ 172 | "help": ( 173 | "The maximum total input sequence length after tokenization. 
Sequences longer " 174 | "than this will be truncated, sequences shorter will be padded." 175 | ) 176 | }, 177 | ) 178 | max_target_length: Optional[int] = field( 179 | default=128, 180 | metadata={ 181 | "help": ( 182 | "The maximum total sequence length for target text after tokenization. Sequences longer " 183 | "than this will be truncated, sequences shorter will be padded." 184 | ) 185 | }, 186 | ) 187 | val_max_target_length: Optional[int] = field( 188 | default=None, 189 | metadata={ 190 | "help": ( 191 | "The maximum total sequence length for validation target text after tokenization. Sequences longer " 192 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`." 193 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used " 194 | "during ``evaluate`` and ``predict``." 195 | ) 196 | }, 197 | ) 198 | pad_to_max_length: bool = field( 199 | default=False, 200 | metadata={ 201 | "help": ( 202 | "Whether to pad all samples to model maximum sentence length. " 203 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 204 | "efficient on GPU but very bad for TPU." 205 | ) 206 | }, 207 | ) 208 | max_train_samples: Optional[int] = field( 209 | default=None, 210 | metadata={ 211 | "help": ( 212 | "For debugging purposes or quicker training, truncate the number of training examples to this " 213 | "value if set." 214 | ) 215 | }, 216 | ) 217 | max_eval_samples: Optional[int] = field( 218 | default=None, 219 | metadata={ 220 | "help": ( 221 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this " 222 | "value if set." 223 | ) 224 | }, 225 | ) 226 | max_predict_samples: Optional[int] = field( 227 | default=None, 228 | metadata={ 229 | "help": ( 230 | "For debugging purposes or quicker training, truncate the number of prediction examples to this " 231 | "value if set." 232 | ) 233 | }, 234 | ) 235 | num_beams: Optional[int] = field( 236 | default=None, 237 | metadata={ 238 | "help": ( 239 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, " 240 | "which is used during ``evaluate`` and ``predict``." 241 | ) 242 | }, 243 | ) 244 | ignore_pad_token_for_loss: bool = field( 245 | default=True, 246 | metadata={ 247 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not." 248 | }, 249 | ) 250 | source_prefix: Optional[str] = field( 251 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 252 | ) 253 | 254 | forced_bos_token: Optional[str] = field( 255 | default=None, 256 | metadata={ 257 | "help": ( 258 | "The token to force as the first generated token after the decoder_start_token_id." 259 | "Useful for multilingual models like mBART where the first generated token" 260 | "needs to be the target language token (Usually it is the target language token)" 261 | ) 262 | }, 263 | ) 264 | 265 | def __post_init__(self): 266 | if self.dataset_name is None and self.train_file is None and self.validation_file is None: 267 | raise ValueError("Need either a dataset name or a training/validation file.") 268 | else: 269 | if self.train_file is not None: 270 | extension = self.train_file.split(".")[-1] 271 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file." 
272 | if self.validation_file is not None: 273 | extension = self.validation_file.split(".")[-1] 274 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file." 275 | if self.val_max_target_length is None: 276 | self.val_max_target_length = self.max_target_length 277 | 278 | 279 | summarization_name_mapping = { 280 | "amazon_reviews_multi": ("review_body", "review_title"), 281 | "big_patent": ("description", "abstract"), 282 | "cnn_dailymail": ("article", "highlights"), 283 | "orange_sum": ("text", "summary"), 284 | "pn_summary": ("article", "summary"), 285 | "psc": ("extract_text", "summary_text"), 286 | "samsum": ("dialogue", "summary"), 287 | "thaisum": ("body", "summary"), 288 | "xglue": ("news_body", "news_title"), 289 | "xsum": ("document", "summary"), 290 | "wiki_summary": ("article", "highlights"), 291 | "multi_news": ("document", "summary"), 292 | } 293 | 294 | 295 | def main(): 296 | # See all possible arguments in src/transformers/training_args.py 297 | # or by passing the --help flag to this script. 298 | # We now keep distinct sets of args, for a cleaner separation of concerns. 299 | 300 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 301 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 302 | # If we pass only one argument to the script and it's the path to a json file, 303 | # let's parse it to get our arguments. 304 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 305 | else: 306 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 307 | 308 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The 309 | # information sent is the one passed as arguments along with your Python/PyTorch versions. 310 | send_example_telemetry("run_summarization", model_args, data_args) 311 | print("training_args", training_args) 312 | # Setup logging 313 | logging.basicConfig( 314 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 315 | datefmt="%m/%d/%Y %H:%M:%S", 316 | handlers=[logging.StreamHandler(sys.stdout)], 317 | ) 318 | log_level = training_args.get_process_log_level() 319 | logger.setLevel(log_level) 320 | datasets.utils.logging.set_verbosity(log_level) 321 | transformers.utils.logging.set_verbosity(log_level) 322 | transformers.utils.logging.enable_default_handler() 323 | transformers.utils.logging.enable_explicit_format() 324 | 325 | # Log on each process the small summary: 326 | logger.warning( 327 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 328 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 329 | ) 330 | logger.info(f"Training/evaluation parameters {training_args}") 331 | 332 | if data_args.source_prefix is None and model_args.model_name_or_path in [ 333 | "t5-small", 334 | "t5-base", 335 | "t5-large", 336 | "t5-3b", 337 | "t5-11b", 338 | ]: 339 | logger.warning( 340 | "You're running a t5 model but didn't provide a source prefix, which is the expected, e.g. with " 341 | "`--source_prefix 'summarize: ' `" 342 | ) 343 | 344 | # Detecting last checkpoint. 
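    # When --do_train is set and --overwrite_output_dir is not, a non-empty output directory must contain a checkpoint (otherwise an error is raised), and training resumes from it unless --resume_from_checkpoint overrides it. 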
345 | last_checkpoint = None 346 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 347 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 348 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 349 | raise ValueError( 350 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 351 | "Use --overwrite_output_dir to overcome." 352 | ) 353 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 354 | logger.info( 355 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 356 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 357 | ) 358 | 359 | # Set seed before initializing model. 360 | set_seed(training_args.seed) 361 | 362 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below) 363 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/ 364 | # (the dataset will be downloaded automatically from the datasets Hub). 365 | # 366 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the 367 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments). 368 | # 369 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently 370 | # download the dataset. 371 | print("data_args", data_args) 372 | if data_args.dataset_name is not None: 373 | # Downloading and loading a dataset from the hub. 374 | raw_datasets = load_dataset( 375 | data_args.dataset_name, 376 | data_args.dataset_config_name, 377 | cache_dir=model_args.cache_dir, 378 | use_auth_token=True if model_args.use_auth_token else None, 379 | ) 380 | else: 381 | data_files = {} 382 | if data_args.train_file is not None: 383 | data_files["train"] = data_args.train_file 384 | extension = data_args.train_file.split(".")[-1] 385 | if data_args.validation_file is not None: 386 | data_files["validation"] = data_args.validation_file 387 | extension = data_args.validation_file.split(".")[-1] 388 | if data_args.test_file is not None: 389 | data_files["test"] = data_args.test_file 390 | extension = data_args.test_file.split(".")[-1] 391 | raw_datasets = load_dataset( 392 | extension, 393 | data_files=data_files, 394 | cache_dir=model_args.cache_dir, 395 | use_auth_token=True if model_args.use_auth_token else None, 396 | ) 397 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at 398 | # https://huggingface.co/docs/datasets/loading_datasets.html. 399 | 400 | # Load pretrained model and tokenizer 401 | # 402 | # Distributed training: 403 | # The .from_pretrained methods guarantee that only one local process can concurrently 404 | # download model & vocab. 
405 | print("model_args", model_args) 406 | config = AutoConfig.from_pretrained( 407 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 408 | cache_dir=model_args.cache_dir, 409 | revision=model_args.model_revision, 410 | use_auth_token=True if model_args.use_auth_token else None, 411 | ) 412 | tokenizer = AutoTokenizer.from_pretrained( 413 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 414 | cache_dir=model_args.cache_dir, 415 | use_fast=model_args.use_fast_tokenizer, 416 | revision=model_args.model_revision, 417 | use_auth_token=True if model_args.use_auth_token else None, 418 | ) 419 | model = AutoModelForSeq2SeqLM.from_pretrained( 420 | model_args.model_name_or_path, 421 | from_tf=bool(".ckpt" in model_args.model_name_or_path), 422 | config=config, 423 | cache_dir=model_args.cache_dir, 424 | revision=model_args.model_revision, 425 | use_auth_token=True if model_args.use_auth_token else None, 426 | ) 427 | 428 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch 429 | # on a small vocab and want a smaller embedding size, remove this test. 430 | embedding_size = model.get_input_embeddings().weight.shape[0] 431 | if len(tokenizer) > embedding_size: 432 | model.resize_token_embeddings(len(tokenizer)) 433 | 434 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)): 435 | if isinstance(tokenizer, MBartTokenizer): 436 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang] 437 | else: 438 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang) 439 | 440 | if model.config.decoder_start_token_id is None: 441 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined") 442 | 443 | if ( 444 | hasattr(model.config, "max_position_embeddings") 445 | and model.config.max_position_embeddings < data_args.max_source_length 446 | ): 447 | if model_args.resize_position_embeddings is None: 448 | logger.warning( 449 | "Increasing the model's number of position embedding vectors from" 450 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}." 451 | ) 452 | model.resize_position_embeddings(data_args.max_source_length) 453 | elif model_args.resize_position_embeddings: 454 | model.resize_position_embeddings(data_args.max_source_length) 455 | else: 456 | raise ValueError( 457 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has" 458 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing" 459 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the" 460 | " model's position encodings by passing `--resize_position_embeddings`." 461 | ) 462 | 463 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 464 | 465 | # Preprocessing the datasets. 466 | # We need to tokenize inputs and targets. 467 | if training_args.do_train: 468 | column_names = raw_datasets["train"].column_names 469 | elif training_args.do_eval: 470 | column_names = raw_datasets["validation"].column_names 471 | elif training_args.do_predict: 472 | column_names = raw_datasets["test"].column_names 473 | else: 474 | logger.info("There is nothing to do. 
Please pass `do_train`, `do_eval` and/or `do_predict`.") 475 | return 476 | 477 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)): 478 | assert ( 479 | data_args.lang is not None 480 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument" 481 | 482 | tokenizer.src_lang = data_args.lang 483 | tokenizer.tgt_lang = data_args.lang 484 | 485 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token 486 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument. 487 | forced_bos_token_id = ( 488 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None 489 | ) 490 | model.config.forced_bos_token_id = forced_bos_token_id 491 | 492 | # Get the column names for input/target. 493 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None) 494 | if data_args.text_column is None: 495 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0] 496 | else: 497 | text_column = data_args.text_column 498 | if text_column not in column_names: 499 | raise ValueError( 500 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}" 501 | ) 502 | if data_args.summary_column is None: 503 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1] 504 | else: 505 | summary_column = data_args.summary_column 506 | if summary_column not in column_names: 507 | raise ValueError( 508 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}" 509 | ) 510 | 511 | # Temporarily set max_target_length for training. 512 | max_target_length = data_args.max_target_length 513 | padding = "max_length" if data_args.pad_to_max_length else False 514 | 515 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"): 516 | logger.warning( 517 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for" 518 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory" 519 | ) 520 | 521 | print(data_args) 522 | 523 | def preprocess_function(examples): 524 | # remove pairs where at least one record is None 525 | 526 | inputs, targets = [], [] 527 | for i in range(len(examples[text_column])): 528 | if examples[text_column][i] and examples[summary_column][i]: 529 | inputs.append(examples[text_column][i]) 530 | targets.append(examples[summary_column][i]) 531 | 532 | inputs = [prefix + inp for inp in inputs] 533 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True) 534 | 535 | # Tokenize targets with the `text_target` keyword argument 536 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True) 537 | 538 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore 539 | # padding in the loss. 
540 | if padding == "max_length" and data_args.ignore_pad_token_for_loss: 541 | labels["input_ids"] = [ 542 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"] 543 | ] 544 | 545 | model_inputs["labels"] = labels["input_ids"] 546 | return model_inputs 547 | 548 | if training_args.do_train: 549 | if "train" not in raw_datasets: 550 | raise ValueError("--do_train requires a train dataset") 551 | train_dataset = raw_datasets["train"] 552 | if data_args.max_train_samples is not None: 553 | max_train_samples = min(len(train_dataset), data_args.max_train_samples) 554 | train_dataset = train_dataset.select(range(max_train_samples)) 555 | with training_args.main_process_first(desc="train dataset map pre-processing"): 556 | train_dataset = train_dataset.map( 557 | preprocess_function, 558 | batched=True, 559 | num_proc=data_args.preprocessing_num_workers, 560 | remove_columns=column_names, 561 | load_from_cache_file=not data_args.overwrite_cache, 562 | desc="Running tokenizer on train dataset", 563 | ) 564 | 565 | if training_args.do_eval: 566 | max_target_length = data_args.val_max_target_length 567 | if "validation" not in raw_datasets: 568 | raise ValueError("--do_eval requires a validation dataset") 569 | eval_dataset = raw_datasets["validation"] 570 | if data_args.max_eval_samples is not None: 571 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples) 572 | eval_dataset = eval_dataset.select(range(max_eval_samples)) 573 | with training_args.main_process_first(desc="validation dataset map pre-processing"): 574 | eval_dataset = eval_dataset.map( 575 | preprocess_function, 576 | batched=True, 577 | num_proc=data_args.preprocessing_num_workers, 578 | remove_columns=column_names, 579 | load_from_cache_file=not data_args.overwrite_cache, 580 | desc="Running tokenizer on validation dataset", 581 | ) 582 | 583 | if training_args.do_predict: 584 | max_target_length = data_args.val_max_target_length 585 | if "test" not in raw_datasets: 586 | raise ValueError("--do_predict requires a test dataset") 587 | predict_dataset = raw_datasets["test"] 588 | if data_args.max_predict_samples is not None: 589 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples) 590 | predict_dataset = predict_dataset.select(range(max_predict_samples)) 591 | with training_args.main_process_first(desc="prediction dataset map pre-processing"): 592 | predict_dataset = predict_dataset.map( 593 | preprocess_function, 594 | batched=True, 595 | num_proc=data_args.preprocessing_num_workers, 596 | remove_columns=column_names, 597 | load_from_cache_file=not data_args.overwrite_cache, 598 | desc="Running tokenizer on prediction dataset", 599 | ) 600 | 601 | # Data collator 602 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id 603 | data_collator = DataCollatorForSeq2Seq( 604 | tokenizer, 605 | model=model, 606 | label_pad_token_id=label_pad_token_id, 607 | pad_to_multiple_of=8 if training_args.fp16 else None, 608 | ) 609 | 610 | # Metric 611 | metric = evaluate.load("rouge") 612 | 613 | def postprocess_text(preds, labels): 614 | preds = [pred.strip() for pred in preds] 615 | labels = [label.strip() for label in labels] 616 | 617 | # rougeLSum expects newline after each sentence 618 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds] 619 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels] 620 | 621 | return preds, labels 622 | 623 | def compute_metrics(eval_preds): 624 | preds, 
labels = eval_preds 625 | if isinstance(preds, tuple): 626 | preds = preds[0] 627 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 628 | if data_args.ignore_pad_token_for_loss: 629 | # Replace -100 in the labels as we can't decode them. 630 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 631 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 632 | 633 | # Some simple post-processing 634 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels) 635 | 636 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True) 637 | result = {k: round(v * 100, 4) for k, v in result.items()} 638 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds] 639 | result["gen_len"] = np.mean(prediction_lens) 640 | return result 641 | 642 | # Initialize our Trainer 643 | trainer = Seq2SeqTrainer( 644 | model=model, 645 | args=training_args, 646 | train_dataset=train_dataset if training_args.do_train else None, 647 | eval_dataset=eval_dataset if training_args.do_eval else None, 648 | tokenizer=tokenizer, 649 | data_collator=data_collator, 650 | compute_metrics=compute_metrics if training_args.predict_with_generate else None, 651 | ) 652 | 653 | # Training 654 | if training_args.do_train: 655 | checkpoint = None 656 | if training_args.resume_from_checkpoint is not None: 657 | checkpoint = training_args.resume_from_checkpoint 658 | elif last_checkpoint is not None: 659 | checkpoint = last_checkpoint 660 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 661 | trainer.save_model() # Saves the tokenizer too for easy upload 662 | 663 | metrics = train_result.metrics 664 | max_train_samples = ( 665 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset) 666 | ) 667 | metrics["train_samples"] = min(max_train_samples, len(train_dataset)) 668 | 669 | trainer.log_metrics("train", metrics) 670 | trainer.save_metrics("train", metrics) 671 | trainer.save_state() 672 | 673 | # Evaluation 674 | results = {} 675 | max_length = ( 676 | training_args.generation_max_length 677 | if training_args.generation_max_length is not None 678 | else data_args.val_max_target_length 679 | ) 680 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams 681 | # if training_args.do_eval: 682 | # logger.info("*** Evaluate ***") 683 | # metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval") 684 | # max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) 685 | # metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) 686 | # 687 | # trainer.log_metrics("eval", metrics) 688 | # trainer.save_metrics("eval", metrics) 689 | 690 | if training_args.do_predict: 691 | logger.info("*** Predict ***") 692 | 693 | predict_results = trainer.predict( 694 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams 695 | ) 696 | metrics = predict_results.metrics 697 | max_predict_samples = ( 698 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset) 699 | ) 700 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset)) 701 | 702 | trainer.log_metrics("predict", metrics) 703 | trainer.save_metrics("predict", metrics) 704 | 705 | if trainer.is_world_process_zero(): 706 | if 
training_args.predict_with_generate: 707 | predictions = tokenizer.batch_decode( 708 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True 709 | ) 710 | predictions = [pred.strip() for pred in predictions] 711 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt") 712 | with open(output_prediction_file, "w") as writer: 713 | writer.write("\n".join(predictions)) 714 | 715 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"} 716 | if data_args.dataset_name is not None: 717 | kwargs["dataset_tags"] = data_args.dataset_name 718 | if data_args.dataset_config_name is not None: 719 | kwargs["dataset_args"] = data_args.dataset_config_name 720 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}" 721 | else: 722 | kwargs["dataset"] = data_args.dataset_name 723 | 724 | if data_args.lang is not None: 725 | kwargs["language"] = data_args.lang 726 | 727 | if training_args.push_to_hub: 728 | trainer.push_to_hub(**kwargs) 729 | else: 730 | trainer.create_model_card(**kwargs) 731 | 732 | return results 733 | 734 | 735 | def _mp_fn(index): 736 | # For xla_spawn (TPUs) 737 | main() 738 | 739 | 740 | if __name__ == "__main__": 741 | main() 742 | -------------------------------------------------------------------------------- /src/train/train_cnndm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | source ~/anaconda3/etc/profile.d/conda.sh 4 | 5 | conda activate readability_summ 6 | 7 | FOLDER_OUTPUT=/mnt/hd3/checkpoints/exec-$RANDOM 8 | 9 | TRAIN_FILE='../data/train_prompt_category.json' 10 | VAL_FILE='../data/validation_prompt_category.json' 11 | 12 | MODEL_NAME='google/flan-t5-large' 13 | 14 | deepspeed --master_port 61002 --include localhost:0,1,2,3,4,5,6,7 run_summarization.py --model_name_or_path ${MODEL_NAME} \ 15 | --output_dir ${FOLDER_OUTPUT} --text_column input --summary_column summary \ 16 | --train_file ${TRAIN_FILE} \ 17 | --validation_file ${VAL_FILE} \ 18 | --learning_rate 1e-4 \ 19 | --max_source_length 1024 \ 20 | --source_prefix "" \ 21 | --num_train_epochs 20 \ 22 | --logging_steps 200 \ 23 | --preprocessing_num_workers 100 \ 24 | --eval_steps 10000 \ 25 | --save_steps 10000 \ 26 | --save_total_limit 2 \ 27 | --evaluation_strategy "steps" \ 28 | --per_device_train_batch_size 4 \ 29 | --per_device_eval_batch_size 4 \ 30 | --metric_for_best_model "rouge1" \ 31 | --load_best_model_at_end \ 32 | --predict_with_generate \ 33 | --deepspeed ds_config_stage3_fb16.json \ 34 | --bf16 \ 35 | --bf16_full_eval \ 36 | --do_train 37 | 38 | conda deactivate --------------------------------------------------------------------------------