├── CODE_OF_CONDUCT.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── images
│   └── read-control.png
├── requirements.txt
└── src
    ├── inference
    │   ├── generation.py
    │   ├── inference_category.sh
    │   ├── inference_lookahead.sh
    │   ├── inference_score.sh
    │   ├── lookahead.py
    │   ├── run_lookahead.py
    │   ├── run_summarization.py
    │   └── scorer.py
    ├── preprocess
    │   ├── generate_prompts_category.py
    │   ├── generate_prompts_score.py
    │   └── preprocess_cnndm.py
    └── train
        ├── ds_config_stage3_fb16.json
        ├── rl
        │   ├── accelerate_config.yaml
        │   ├── train.py
        │   └── train_rl_cnndm.sh
        ├── run_summarization.py
        └── train_cnndm.sh
/CODE_OF_CONDUCT.md:
--------------------------------------------------------------------------------
1 | ## Code of Conduct
2 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
3 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
4 | opensource-codeofconduct@amazon.com with any additional questions or comments.
5 |
--------------------------------------------------------------------------------
/CONTRIBUTING.md:
--------------------------------------------------------------------------------
1 | # Contributing Guidelines
2 |
3 | Thank you for your interest in contributing to our project. Whether it's a bug report, new feature, correction, or additional
4 | documentation, we greatly value feedback and contributions from our community.
5 |
6 | Please read through this document before submitting any issues or pull requests to ensure we have all the necessary
7 | information to effectively respond to your bug report or contribution.
8 |
9 |
10 | ## Reporting Bugs/Feature Requests
11 |
12 | We welcome you to use the GitHub issue tracker to report bugs or suggest features.
13 |
14 | When filing an issue, please check existing open, or recently closed, issues to make sure somebody else hasn't already
15 | reported the issue. Please try to include as much information as you can. Details like these are incredibly useful:
16 |
17 | * A reproducible test case or series of steps
18 | * The version of our code being used
19 | * Any modifications you've made relevant to the bug
20 | * Anything unusual about your environment or deployment
21 |
22 |
23 | ## Contributing via Pull Requests
24 | Contributions via pull requests are much appreciated. Before sending us a pull request, please ensure that:
25 |
26 | 1. You are working against the latest source on the *main* branch.
27 | 2. You check existing open, and recently merged, pull requests to make sure someone else hasn't addressed the problem already.
28 | 3. You open an issue to discuss any significant work - we would hate for your time to be wasted.
29 |
30 | To send us a pull request, please:
31 |
32 | 1. Fork the repository.
33 | 2. Modify the source; please focus on the specific change you are contributing. If you also reformat all the code, it will be hard for us to focus on your change.
34 | 3. Ensure local tests pass.
35 | 4. Commit to your fork using clear commit messages.
36 | 5. Send us a pull request, answering any default questions in the pull request interface.
37 | 6. Pay attention to any automated CI failures reported in the pull request, and stay involved in the conversation.
38 |
39 | GitHub provides additional documentation on [forking a repository](https://help.github.com/articles/fork-a-repo/) and
40 | [creating a pull request](https://help.github.com/articles/creating-a-pull-request/).
41 |
42 |
43 | ## Finding contributions to work on
44 | Looking at the existing issues is a great way to find something to contribute to. As our projects use the default GitHub issue labels (enhancement/bug/duplicate/help wanted/invalid/question/wontfix), looking at any 'help wanted' issues is a great place to start.
45 |
46 |
47 | ## Code of Conduct
48 | This project has adopted the [Amazon Open Source Code of Conduct](https://aws.github.io/code-of-conduct).
49 | For more information see the [Code of Conduct FAQ](https://aws.github.io/code-of-conduct-faq) or contact
50 | opensource-codeofconduct@amazon.com with any additional questions or comments.
51 |
52 |
53 | ## Security issue notifications
54 | If you discover a potential security issue in this project, we ask that you notify AWS/Amazon Security via our [vulnerability reporting page](http://aws.amazon.com/security/vulnerability-reporting/). Please do **not** create a public GitHub issue.
55 |
56 |
57 | ## Licensing
58 |
59 | See the [LICENSE](LICENSE) file for our project's licensing. We will ask you to confirm the licensing of your contribution.
60 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
409 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Generating Summaries with Controllable Readability Levels (EMNLP 2023)
2 |
3 | This repository contains the code for the paper "[Generating Summaries with Controllable Readability Levels](https://arxiv.org/pdf/2310.10623)".
4 |
5 | We developed three text generation techniques for controlling readability:
6 |
7 |
8 | ![Overview of the three readability-controlled generation approaches](images/read-control.png)
9 |
10 |
11 | (a) illustrates the approach that controls summary readability via fine-grained instructions; (b) shows the RL method, where, given an input document and a readability level, the policy generates a summary that is scored by our Gaussian-based reward; and (c) shows the lookahead approach, which uses the readability score of a future summary to guide generation.
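For intuition, the reward in (b) can be thought of as a Gaussian centered on the requested readability level: it peaks when the observed readability of the generated summary matches the request and decays as the gap grows. A minimal sketch (the function name, the width `sigma`, and the example numbers are illustrative, not the exact settings used in the paper):

```
import math

def gaussian_readability_reward(observed: float, requested: float, sigma: float = 10.0) -> float:
    """Peaks at 1.0 when the observed readability equals the requested level
    and decays smoothly as the |observed - requested| gap grows."""
    return math.exp(-((observed - requested) ** 2) / (2 * sigma ** 2))

# A summary measured at reading ease 62 against a requested level of 70:
print(gaussian_readability_reward(62.0, 70.0))  # ~0.73
```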
12 |
13 | ## Environment
14 |
15 | The easiest way to proceed is to create a conda environment:
16 | ```
17 | conda create -n readability_summ python=3.7
18 | conda activate readability_summ
19 | ```
20 |
21 | Further, install PyTorch:
22 |
23 | ```
24 | conda install pytorch torchvision torchaudio cpuonly -c pytorch
25 | ```
26 |
27 | Install the required packages:
28 | ```
29 | pip install -r requirements.txt
30 | ```
31 |
32 | Install trlx (for the RL method):
33 | ```
34 | git clone https://github.com/CarperAI/trlx.git
35 | cd trlx
36 | pip install torch --extra-index-url https://download.pytorch.org/whl/cu118
37 | pip install -e .
38 | ```
39 |
40 |
41 | ## Preprocess data
42 |
43 | To compute the readability scores for CNN/DM, execute:
44 |
45 | ```
46 | cd src/preprocess
47 | python preprocess_cnndm.py
48 | ```
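`preprocess_cnndm.py` computes readability scores for the CNN/DM summaries. As a point of reference, this is roughly how a Flesch-Kincaid score can be obtained with py-readability-metrics (the package pinned in requirements.txt); the text below is a toy example, and the library needs at least 100 words plus NLTK's punkt tokenizer data:

```
import nltk
nltk.download("punkt", quiet=True)  # tokenizer data used by py-readability-metrics

from readability import Readability

text = " ".join(["This is a simple sentence written for the readability example."] * 20)
r = Readability(text)
print(r.flesch_kincaid().score)  # Flesch-Kincaid grade level
print(r.flesch().score)          # Flesch reading ease (higher = easier to read)
```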
49 |
50 | Generate the prompts:
51 | ```
52 | python generate_prompts_category.py
53 | python generate_prompts_score.py
54 | ```
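The prompt prefixes produced by these scripts presumably mirror the ones used at inference time (the `--source_prefix` values in `src/inference/inference_category.sh` and `inference_score.sh`); for example, where `<article>` stands for the source document and the exact JSON layout of the generated files is not shown in this README:

```
Write highlights for this article for a high school student:\n\n<article>
Write highlights for this article with a flesch kincaid score of 70:\n\n<article>
```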
55 |
56 |
57 | ## Training
58 |
59 | Execute the following commands to train the prompt-based methods:
60 | ```
61 | cd src/train
62 | ./train_cnndm.sh
63 | ```
64 |
65 | For the RL method, execute:
66 | ```
67 | cd src/train/rl
68 | ./train_rl_cnndm.sh
69 | ```
70 |
71 | ## Inference
72 |
73 | For inference, run:
74 | ```
75 | cd src/inference
76 | ./inference_score.sh
77 | ./inference_category.sh
78 | ```
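Both scripts read the fine-tuned model path from their first argument (`MODEL_PATH=$1` inside the scripts), so a typical call looks like this, with `/path/to/finetuned-checkpoint` as a placeholder:

```
./inference_score.sh /path/to/finetuned-checkpoint
./inference_category.sh /path/to/finetuned-checkpoint
```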
79 |
80 | For lookahead inference, run:
81 | ```
82 | ./inference_lookahead.sh
83 | ```
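`inference_lookahead.sh` also takes the model path as its first argument and launches one `run_lookahead.py` process per readability level. For reference, a single invocation copied from the script (again with `/path/to/finetuned-checkpoint` as a placeholder) is:

```
CUDA_VISIBLE_DEVICES=0 python run_lookahead.py \
    --document_file ../data/test_prompt_category.json \
    --output_file 11yold.txt \
    --do_lookahead \
    --lookahead_decoding_type greedy \
    --model_name /path/to/finetuned-checkpoint \
    --lookahead_length 20 \
    --prompt "Write highlights for this article for a 11 years old student:\n\n" \
    --score 90
```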
84 |
85 | ## Security
86 |
87 | See [CONTRIBUTING](CONTRIBUTING.md#security-issue-notifications) for more information.
88 |
89 | ## License Summary
90 |
91 | The documentation is made available under the CC-BY-NC-4.0 License. See the LICENSE file.
92 |
93 | ## Citation
94 |
95 | ```
96 | @inproceedings{ribeiro-etal-2023-generating,
97 | title = "Generating Summaries with Controllable Readability Levels",
98 | author = "Ribeiro, Leonardo F. R. and
99 | Bansal, Mohit and
100 | Dreyer, Markus",
101 | editor = "Bouamor, Houda and
102 | Pino, Juan and
103 | Bali, Kalika",
104 | booktitle = "Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing",
105 | month = dec,
106 | year = "2023",
107 | address = "Singapore",
108 | publisher = "Association for Computational Linguistics",
109 | url = "https://aclanthology.org/2023.emnlp-main.714",
110 | doi = "10.18653/v1/2023.emnlp-main.714",
111 | pages = "11669--11687",
112 | abstract = "Readability refers to how easily a reader can understand a written text. Several factors affect the readability level, such as the complexity of the text, its subject matter, and the reader{'}s background knowledge. Generating summaries based on different readability levels is critical for enabling knowledge consumption by diverse audiences. However, current text generation approaches lack refined control, resulting in texts that are not customized to readers{'} proficiency levels. In this work, we bridge this gap and study techniques to generate summaries at specified readability levels. Unlike previous methods that focus on a specific readability level (e.g., lay summarization), we generate summaries with fine-grained control over their readability. We develop three text generation techniques for controlling readability: (1) instruction-based readability control, (2) reinforcement learning to minimize the gap between requested and observed readability and (3) a decoding approach that uses lookahead to estimate the readability of upcoming decoding steps. We show that our generation methods significantly improve readability control on news summarization (CNN/DM dataset), as measured by various readability metrics and human judgement, establishing strong baselines for controllable readability in summarization.",
113 | }
114 |
115 | ```
116 |
--------------------------------------------------------------------------------
/images/read-control.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/amazon-science/controllable-readability-summarization/6ecc10458e18cf034136b6be6b07f8e1b7e8f245/images/read-control.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | transformers==4.48.0
2 | rouge-score==0.1.2
3 | accelerate==0.19.0
4 | datasets==2.12.0
5 | deepspeed==0.15.1
6 | evaluate==0.4.0
7 | py-readability-metrics==1.4.4
8 |
--------------------------------------------------------------------------------
/src/inference/inference_category.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 |
5 | conda activate readability_summ
6 |
7 | VAL_FILE='../data/test_prompt_category.json'
8 | MODEL_PATH=$1
9 |
10 |
11 | OUTPUT_DIR='outputs/1/'
12 | CUDA_VISIBLE_DEVICES=4 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
13 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
14 | --train_file ${VAL_FILE} \
15 | --validation_file ${VAL_FILE} \
16 | --test_file ${VAL_FILE} \
17 | --max_source_length 1024 \
18 | --val_max_target_length 256 \
19 | --max_target_length 256 \
20 | --generation_max_length 256 \
21 | --num_beams 3 \
22 | --source_prefix "Write highlights for this article for a 11 years old student:\n\n" \
23 | --evaluation_strategy "steps" \
24 | --per_device_train_batch_size 1 \
25 | --per_device_eval_batch_size 16 \
26 | --predict_with_generate \
27 | --do_predict &
28 |
29 | P1=$!
30 |
31 |
32 | OUTPUT_DIR='outputs/2/'
33 | CUDA_VISIBLE_DEVICES=5 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
34 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
35 | --train_file ${VAL_FILE} \
36 | --validation_file ${VAL_FILE} \
37 | --test_file ${VAL_FILE} \
38 | --max_source_length 1024 \
39 | --val_max_target_length 256 \
40 | --max_target_length 256 \
41 | --generation_max_length 256 \
42 | --num_beams 3 \
43 | --source_prefix "Write highlights for this article for a middle school student:\n\n" \
44 | --evaluation_strategy "steps" \
45 | --per_device_train_batch_size 1 \
46 | --per_device_eval_batch_size 16 \
47 | --predict_with_generate \
48 | --do_predict &
49 |
50 | P2=$!
51 |
52 |
53 | OUTPUT_DIR='outputs/3/'
54 | CUDA_VISIBLE_DEVICES=6 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
55 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
56 | --train_file ${VAL_FILE} \
57 | --validation_file ${VAL_FILE} \
58 | --test_file ${VAL_FILE} \
59 | --max_source_length 1024 \
60 | --val_max_target_length 256 \
61 | --max_target_length 256 \
62 | --generation_max_length 256 \
63 | --num_beams 3 \
64 | --source_prefix "Write highlights for this article for a high school student:\n\n" \
65 | --evaluation_strategy "steps" \
66 | --per_device_train_batch_size 1 \
67 | --per_device_eval_batch_size 16 \
68 | --predict_with_generate \
69 | --do_predict &
70 |
71 | P3=$!
72 |
73 |
74 | OUTPUT_DIR='outputs/4/'
75 | CUDA_VISIBLE_DEVICES=7 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
76 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
77 | --train_file ${VAL_FILE} \
78 | --validation_file ${VAL_FILE} \
79 | --test_file ${VAL_FILE} \
80 | --max_source_length 1024 \
81 | --val_max_target_length 256 \
82 | --max_target_length 256 \
83 | --generation_max_length 256 \
84 | --num_beams 3 \
85 | --source_prefix "Write highlights for this article for a college student:\n\n" \
86 | --evaluation_strategy "steps" \
87 | --per_device_train_batch_size 1 \
88 | --per_device_eval_batch_size 16 \
89 | --predict_with_generate \
90 | --do_predict &
91 |
92 | P4=$!
93 |
94 | wait $P1 $P2 $P3 $P4
95 |
96 | conda deactivate
--------------------------------------------------------------------------------
/src/inference/inference_lookahead.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 |
5 | conda activate readability_summ
6 |
7 | LOOKAHEAD_LENGTH=20
8 | DOC_FILE='../data/test_prompt_category.json'
9 | MODEL_PATH=$1
10 |
11 |
12 | PROMPT="Write highlights for this article for a 11 years old student:\n\n"
13 | OUTPUT_FILE="11yold.txt"
14 | SCORE=90
15 | CUDA_VISIBLE_DEVICES=0 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \
16 | --prompt "${PROMPT}" --score ${SCORE} &
17 | P1=$!
18 |
19 |
20 | PROMPT="Write highlights for this article for a middle school student:\n\n"
21 | OUTPUT_FILE="middle-school.txt"
22 | SCORE=70
23 | CUDA_VISIBLE_DEVICES=1 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \
24 | --prompt "${PROMPT}" --score ${SCORE} &
25 | P2=$!
26 |
27 |
28 | PROMPT="Write highlights for this article for a high school student:\n\n"
29 | OUTPUT_FILE="high-school.txt"
30 | SCORE=50
31 | CUDA_VISIBLE_DEVICES=2 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \
32 | --prompt "${PROMPT}" --score ${SCORE} &
33 | P3=$!
34 |
35 |
36 | PROMPT="Write highlights for this article for a college student:\n\n"
37 | OUTPUT_FILE="college-student.txt"
38 | SCORE=30
39 | CUDA_VISIBLE_DEVICES=3 python run_lookahead.py --document_file ${DOC_FILE} --output_file ${OUTPUT_FILE} --do_lookahead --lookahead_decoding_type greedy --model_name ${MODEL_PATH} --lookahead_length ${LOOKAHEAD_LENGTH} \
40 | --prompt "${PROMPT}" --score ${SCORE} &
41 | P4=$!
42 |
43 | wait $P1 $P2 $P3 $P4
44 |
45 | conda deactivate
--------------------------------------------------------------------------------
/src/inference/inference_score.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 |
5 | conda activate readability_summ
6 |
7 | VAL_FILE='../data/test_prompt_score.json'
8 | MODEL_PATH=$1
9 |
10 |
11 | OUTPUT_DIR='outputs/1/'
12 | CUDA_VISIBLE_DEVICES=0 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
13 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
14 | --train_file ${VAL_FILE} \
15 | --validation_file ${VAL_FILE} \
16 | --test_file ${VAL_FILE} \
17 | --max_source_length 1024 \
18 | --source_prefix "Write highlights for this article with a flesch kincaid score of 90:\n\n" \
19 | --evaluation_strategy "steps" \
20 | --per_device_train_batch_size 1 \
21 | --per_device_eval_batch_size 16 \
22 | --predict_with_generate \
23 | --do_predict &
24 |
25 | P1=$!
26 |
27 |
28 | OUTPUT_DIR='outputs/2/'
29 | CUDA_VISIBLE_DEVICES=1 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
30 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
31 | --train_file ${VAL_FILE} \
32 | --validation_file ${VAL_FILE} \
33 | --test_file ${VAL_FILE} \
34 | --max_source_length 1024 \
35 | --source_prefix "Write highlights for this article with a flesch kincaid score of 70:\n\n" \
36 | --evaluation_strategy "steps" \
37 | --per_device_train_batch_size 1 \
38 | --per_device_eval_batch_size 16 \
39 | --predict_with_generate \
40 | --do_predict &
41 |
42 | P2=$!
43 |
44 |
45 | OUTPUT_DIR='outputs/3/'
46 | CUDA_VISIBLE_DEVICES=2 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
47 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
48 | --train_file ${VAL_FILE} \
49 | --validation_file ${VAL_FILE} \
50 | --test_file ${VAL_FILE} \
51 | --max_source_length 1024 \
52 | --source_prefix "Write highlights for this article with a flesch kincaid score of 50:\n\n" \
53 | --evaluation_strategy "steps" \
54 | --per_device_train_batch_size 1 \
55 | --per_device_eval_batch_size 16 \
56 | --predict_with_generate \
57 | --do_predict &
58 |
59 | P3=$!
60 |
61 |
62 | OUTPUT_DIR='outputs/4/'
63 | CUDA_VISIBLE_DEVICES=3 python -u run_summarization.py --model_name_or_path ${MODEL_PATH} \
64 | --output_dir ${OUTPUT_DIR} --text_column input_noprompt --summary_column summary \
65 | --train_file ${VAL_FILE} \
66 | --validation_file ${VAL_FILE} \
67 | --test_file ${VAL_FILE} \
68 | --max_source_length 1024 \
69 | --source_prefix "Write highlights for this article with a flesch kincaid score of 30:\n\n" \
70 | --evaluation_strategy "steps" \
71 | --per_device_train_batch_size 1 \
72 | --per_device_eval_batch_size 16 \
73 | --predict_with_generate \
74 | --do_predict &
75 |
76 | P4=$!
77 |
78 | wait $P1 $P2 $P3 $P4
79 |
80 | conda deactivate
--------------------------------------------------------------------------------
/src/inference/lookahead.py:
--------------------------------------------------------------------------------
1 | import inspect
2 | from sys import prefix
3 | import warnings
4 | from dataclasses import dataclass
5 | from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
6 |
7 | import torch
8 | import torch.distributed as dist
9 | from torch import nn
10 | import torch.nn.functional as F
11 |
12 | import copy
13 |
14 | from transformers.generation_beam_constraints import Constraint, DisjunctiveConstraint, PhrasalConstraint
15 | from transformers.generation_beam_search import BeamScorer, BeamSearchScorer, ConstrainedBeamSearchScorer
16 | from transformers.generation_logits_process import (
17 | EncoderNoRepeatNGramLogitsProcessor,
18 | ExponentialDecayLengthPenalty,
19 | ForcedBOSTokenLogitsProcessor,
20 | ForcedEOSTokenLogitsProcessor,
21 | HammingDiversityLogitsProcessor,
22 | InfNanRemoveLogitsProcessor,
23 | LogitNormalization,
24 | LogitsProcessorList,
25 | MinLengthLogitsProcessor,
26 | NoBadWordsLogitsProcessor,
27 | NoRepeatNGramLogitsProcessor,
28 | PrefixConstrainedLogitsProcessor,
29 | RepetitionPenaltyLogitsProcessor,
30 | TemperatureLogitsWarper,
31 | TopKLogitsWarper,
32 | TopPLogitsWarper,
33 | TypicalLogitsWarper,
34 | )
35 | from transformers.generation_stopping_criteria import (
36 | MaxLengthCriteria,
37 | MaxTimeCriteria,
38 | StoppingCriteria,
39 | StoppingCriteriaList,
40 | validate_stopping_criteria,
41 | )
42 | from transformers.pytorch_utils import torch_int_div
43 | from transformers.utils import ModelOutput, logging
44 |
45 | from transformers.generation_utils import (
46 | GreedySearchEncoderDecoderOutput,
47 | GreedySearchDecoderOnlyOutput,
48 | BeamSearchEncoderDecoderOutput,
49 | BeamSearchDecoderOnlyOutput,
50 | SampleEncoderDecoderOutput,
51 | SampleDecoderOnlyOutput,
52 | )
53 |
54 | logger = logging.get_logger(__name__)
55 |
56 | class Lookahead:
57 | """
58 | Object that performs the lookahead. This is very similar to GenerationMixin, since it needs to decode the sequence as well,
59 |     but it additionally contains the function to compute the heuristic score.
60 | """
61 |
62 | def __init__(
63 | self,
64 | model,
65 | tokenizer,
66 | scorer,
67 | lookahead_length=1,
68 | lookahead_lambda=1.0,
69 | lookahead_top_k=5,
70 | decoding_type="greedy",
71 | max_length: Optional[int] = None,
72 | min_length: Optional[int] = None,
73 | do_sample: Optional[bool] = None,
74 | early_stopping: Optional[bool] = None,
75 | num_beams: Optional[int] = None,
76 | temperature: Optional[float] = None,
77 | top_k: Optional[int] = None,
78 | top_p: Optional[float] = None,
79 | typical_p: Optional[float] = None,
80 | repetition_penalty: Optional[float] = None,
81 | bad_words_ids: Optional[Iterable[int]] = None,
82 | force_words_ids: Optional[Union[Iterable[int], Iterable[Iterable[int]]]] = None,
83 | bos_token_id: Optional[int] = None,
84 | pad_token_id: Optional[int] = None,
85 | eos_token_id: Optional[int] = None,
86 | length_penalty: Optional[float] = None,
87 | no_repeat_ngram_size: Optional[int] = None,
88 | encoder_no_repeat_ngram_size: Optional[int] = None,
89 | num_return_sequences: Optional[int] = None,
90 | max_time: Optional[float] = None,
91 | max_new_tokens: Optional[int] = None,
92 | decoder_start_token_id: Optional[int] = None,
93 | use_cache: Optional[bool] = None,
94 | num_beam_groups: Optional[int] = None,
95 | diversity_penalty: Optional[float] = None,
96 | prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], List[int]]] = None,
97 | logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
98 | renormalize_logits: Optional[bool] = None,
99 | stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
100 | constraints: Optional[List[Constraint]] = None,
101 | output_attentions: Optional[bool] = None,
102 | output_hidden_states: Optional[bool] = None,
103 | output_scores: Optional[bool] = None,
104 | return_dict_in_generate: Optional[bool] = None,
105 | forced_bos_token_id: Optional[int] = None,
106 | forced_eos_token_id: Optional[int] = None,
107 | remove_invalid_values: Optional[bool] = None,
108 | synced_gpus: Optional[bool] = False,
109 | exponential_decay_length_penalty: Optional[Tuple[Union[int, float]]] = None,
110 | ):
111 | """
112 | model: The Huggingface Model
113 | tokenizer: The tokenizer for decoding the summaries
114 | scorer: Scorer object that calculates the score given document and summary
115 | lookahead_length: The number of tokens to look ahead
116 | lookahead_lambda: The weight for the score
117 | lookahead_top_k: The number of top tokens to consider for expansion
118 | decoding_type: The decoding type for lookahead. [greedy, beam, sample]
119 |
120 | Other parameters are the same arguments expected for GenerationMixin to control the generation
121 | """
122 | self.model = model
123 | self.tokenizer = tokenizer
124 | self.scorer = scorer
125 |
126 | if lookahead_length == -1:
127 | assert max_length is not None
128 | self.lookahead_length = max_length
129 | self.lookahead_until_sent = True
130 | else:
131 | self.lookahead_length = lookahead_length
132 | self.lookahead_until_sent = False
133 |
134 | self.lookahead_lambda = lookahead_lambda
135 | self.lookahead_top_k = lookahead_top_k
136 | self.decoding_type = decoding_type
137 |
138 | if self.decoding_type == "greedy":
139 | self.decoding_func = self.greedy_search
140 | elif self.decoding_type == "beam":
141 | self.decoding_func = self.beam_search
142 | elif self.decoding_type == "sample":
143 | self.decoding_func = self.sample
144 |
145 | # generation parameters from generate()
146 | self.bos_token_id = self.model.config.bos_token_id
147 | self.num_beams = num_beams if num_beams is not None else self.model.config.num_beams
148 | self.length_penalty = length_penalty if length_penalty is not None else self.model.config.length_penalty
149 | self.early_stopping = early_stopping if early_stopping is not None else self.model.config.early_stopping
150 | self.num_beam_groups = num_beam_groups if num_beam_groups is not None else self.model.config.num_beam_groups
151 | self.num_return_sequences = (
152 | num_return_sequences if num_return_sequences is not None else self.model.config.num_return_sequences
153 | )
154 |
155 | self.pad_token_id = self.model.config.pad_token_id
156 | self.eos_token_id = self.model.config.eos_token_id
157 |
158 | if self.eos_token_id is None and hasattr(self.model.config, "decoder"):
159 | self.eos_token_id = self.model.config.decoder.eos_token_id
160 |
161 | if self.pad_token_id is None and self.eos_token_id is not None:
162 | # special case if pad_token_id is not defined
163 | logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{self.eos_token_id} for open-end generation.")
164 | self.pad_token_id = self.eos_token_id
165 | self.max_length = max_length
166 | self.min_length = min_length
167 | self.temperature = temperature
168 | self.top_k = top_k
169 | self.top_p = top_p
170 | self.typical_p = typical_p
171 |         self.repetition_penalty = repetition_penalty
172 | self.bad_words_ids = bad_words_ids
173 | self.force_words_ids = force_words_ids
174 | self.no_repeat_ngram_size = no_repeat_ngram_size
175 | self.encoder_no_repeat_ngram_size = encoder_no_repeat_ngram_size
176 | self.max_new_tokens = max_new_tokens
177 | self.decoder_start_token_id = decoder_start_token_id
178 | self.use_cache = use_cache
179 | self.diversity_penalty = diversity_penalty
180 | self.prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
181 | self.renormalize_logits = renormalize_logits
182 |         self.constraints = constraints
183 | self.forced_bos_token_id = forced_bos_token_id
184 | self.forced_eos_token_id = forced_eos_token_id
185 | self.remove_invalid_values = remove_invalid_values
186 | self.exponential_decay_length_penalty = exponential_decay_length_penalty
187 | self.synced_gpus = synced_gpus
188 |
189 | # self.return_dict_in_generate = return_dict_in_generate
190 | self.return_dict_in_generate = True
191 | self.output_attentions = output_attentions
192 | self.output_hidden_states = output_hidden_states
193 | self.output_scores = output_scores
194 |
195 | # If not provided, logits processor will be prepared later since it requires input_tensor
196 | self.logits_processor = logits_processor
197 |
198 | # prepare stopping criteria
199 | self.stopping_criteria = self.model._get_stopping_criteria(
200 | max_length=max_length, max_time=max_time, stopping_criteria=stopping_criteria
201 | )
202 |
203 | self.logits_warper = self.model._get_logits_warper(
204 | top_k=self.top_k,
205 | top_p=self.top_p,
206 | typical_p=self.typical_p,
207 | temperature=self.temperature,
208 | num_beams=self.num_beams,
209 | renormalize_logits=self.renormalize_logits,
210 | )
211 |
212 |
213 | def score(
214 | self,
215 | input_ids,
216 | next_token_scores,
217 | num_beams=1,
218 | **model_kwargs,
219 | ):
220 | """
221 |         Main function to call for the lookahead. This function generates the lookahead sequences and returns the calculated heuristic scores.
222 | """
223 |
224 | # prepare for generation
225 | if self.logits_processor is None:
226 | input_ids_seq_length = input_ids.size(1)
227 | inputs_tensor = model_kwargs["encoder_outputs"][self.model.main_input_name]
228 |
229 | self.logits_processor = self.model._get_logits_processor(
230 | repetition_penalty=self.repetition_penalty,
231 | no_repeat_ngram_size=self.no_repeat_ngram_size,
232 | encoder_no_repeat_ngram_size=self.encoder_no_repeat_ngram_size,
233 | input_ids_seq_length=input_ids_seq_length,
234 | encoder_input_ids=inputs_tensor,
235 | bad_words_ids=self.bad_words_ids,
236 | min_length=self.min_length,
237 | max_length=self.max_length,
238 | eos_token_id=self.eos_token_id,
239 | forced_bos_token_id=self.forced_bos_token_id,
240 | forced_eos_token_id=self.forced_eos_token_id,
241 | prefix_allowed_tokens_fn=self.prefix_allowed_tokens_fn,
242 | num_beams=self.num_beams,
243 | num_beam_groups=self.num_beam_groups,
244 | diversity_penalty=self.diversity_penalty,
245 | remove_invalid_values=self.remove_invalid_values,
246 | exponential_decay_length_penalty=self.exponential_decay_length_penalty,
247 | logits_processor=self.logits_processor,
248 | renormalize_logits=self.renormalize_logits,
249 | )
250 |
251 | do_sample = "sample" in self.decoding_type
252 | use_beam = "beam" in self.decoding_type
253 | beam_scorer = None
254 |
255 | if use_beam:
256 | batch_size = input_ids.shape[0] * self.lookahead_top_k
257 | beam_scorer = BeamSearchScorer(
258 | batch_size=batch_size,
259 | num_beams=self.num_beams,
260 | max_length=self.stopping_criteria.max_length,
261 | device=input_ids.device,
262 | length_penalty=self.length_penalty,
263 | do_early_stopping=self.early_stopping,
264 | num_beam_hyps_to_keep=self.num_return_sequences,
265 | num_beam_groups=self.num_beam_groups,
266 | )
267 |
268 | indices = torch.arange(input_ids.size(0), dtype=input_ids.dtype, device=input_ids.device)
269 |
270 | # expand for top k tokens to use with scorer
271 | _, top_k_indices = torch.topk(next_token_scores, k=self.lookahead_top_k, dim=-1)
272 | top_k_indices = top_k_indices.reshape(-1)
273 |
274 | indices = indices.repeat_interleave(self.lookahead_top_k)
275 | input_ids = torch.cat([input_ids[indices],top_k_indices.unsqueeze(1)], dim=1)
276 |
277 | # adjust model_kwargs
278 | model_kwargs = self.expand_model_kwargs(model_kwargs, indices)
279 |
280 |         # expand if necessary for beam search; currently ignoring sampling with multiple return sequences
281 | if use_beam:
282 | input_ids, model_kwargs = self.model._expand_inputs_for_generation(
283 | input_ids,
284 | expand_size=self.num_beams,
285 | is_encoder_decoder=self.model.config.is_encoder_decoder,
286 | **model_kwargs,
287 | )
288 | indices = indices.repeat_interleave(self.num_beams)
289 |             # _expand_inputs_for_generation does not expand `past`, so expand it manually here
290 | if "past" in model_kwargs:
291 | model_kwargs["past"] = tuple([tuple([p.repeat_interleave(self.num_beams, dim=0) for p in past]) for past in model_kwargs["past"]])
292 |
293 | # calling the respective decoding function
294 | # the only difference between this implementation and the original is the addition of lookahead length and breaking once that is reached
295 | if self.lookahead_length == 0:
296 | seq = input_ids
297 | else:
298 | dec_out = self.decoding_func(input_ids, beam_scorer, **model_kwargs)
299 | seq = dec_out["sequences"]
300 |
301 | # generate the actual summary
302 | dec_seq = self.tokenizer.batch_decode(seq, skip_special_tokens=True)
303 |
304 | # calculate score given the heuristics, need to account for different indices when doing beam search
305 | # import pdb
306 | # pdb.set_trace()
307 | _lookahead_scores = self.scorer.score(dec_seq, torch.div(indices, num_beams, rounding_mode="trunc"))
308 | _lookahead_scores = torch.clamp(_lookahead_scores,min=1e-9).log()
309 |
310 | _lookahead_scores = _lookahead_scores.view(-1, self.lookahead_top_k, self.num_beams)
311 | _lookahead_scores, _ = _lookahead_scores.max(-1)
312 |
313 | lookahead_scores = torch.ones_like(next_token_scores, dtype=_lookahead_scores.dtype, device=next_token_scores.device) * 1e-9
314 | lookahead_scores = lookahead_scores.log()
315 |
316 | next_token_scores = F.log_softmax(next_token_scores, dim=-1)
317 |
318 | if use_beam:
319 |             # undo the repeat_interleave over beams
320 | indices = indices.view(-1,self.num_beams)[:,0]
321 |
322 | lookahead_scores[indices, top_k_indices] = _lookahead_scores.view(-1)
323 |
324 | return self.lookahead_lambda * lookahead_scores
325 |
326 | def greedy_search(
327 | self,
328 | input_ids: torch.LongTensor,
329 | beam_scorer = None,
330 | **model_kwargs,
331 | ):
332 | # init attention / hidden states / scores tuples
333 | scores = () if (self.return_dict_in_generate and self.output_scores) else None
334 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
335 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
336 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None
337 |
338 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
339 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder:
340 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None
341 | encoder_hidden_states = (
342 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None
343 | )
344 |
345 | # keep track of which sequences are already finished
346 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
347 | cur_len = input_ids.shape[-1]
348 |
349 | lookahead_length = self.lookahead_length + cur_len
350 |
351 | this_peer_finished = False # used by synced_gpus only
352 | while True:
353 |
354 | if self.synced_gpus:
355 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
356 | # The following logic allows an early break if all peers finished generating their sequence
357 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
358 | # send 0.0 if we finished, 1.0 otherwise
359 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
360 | # did all peers finish? the reduced sum will be 0.0 then
361 | if this_peer_finished_flag.item() == 0.0:
362 | break
363 |
364 | # prepare model inputs
365 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
366 |
367 | # forward pass to get next token
368 | outputs = self.model(
369 | **model_inputs,
370 | return_dict=True,
371 | output_attentions=self.output_attentions,
372 | output_hidden_states=self.output_hidden_states,
373 | )
374 |
375 | if self.synced_gpus and this_peer_finished:
376 | cur_len = cur_len + 1
377 | continue # don't waste resources running the code we don't need
378 |
379 | next_token_logits = outputs.logits[:, -1, :]
380 |
381 | # Store scores, attentions and hidden_states when required
382 | if self.return_dict_in_generate:
383 | if self.output_scores:
384 | scores += (next_token_logits,)
385 | if self.output_attentions:
386 | decoder_attentions += (
387 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,)
388 | )
389 | if self.model.config.is_encoder_decoder:
390 | cross_attentions += (outputs.cross_attentions,)
391 |
392 | if self.output_hidden_states:
393 | decoder_hidden_states += (
394 | (outputs.decoder_hidden_states,)
395 | if self.model.config.is_encoder_decoder
396 | else (outputs.hidden_states,)
397 | )
398 |
399 | # pre-process distribution
400 | next_tokens_scores = self.logits_processor(input_ids, next_token_logits)
401 |
402 | # argmax
403 | next_tokens = torch.argmax(next_tokens_scores, dim=-1)
404 |
405 | # finished sentences should have their next token be a padding token
406 | if self.eos_token_id is not None:
407 | if self.pad_token_id is None:
408 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
409 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences)
410 |
411 | # update generated ids, model inputs, and length for next step
412 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
413 | model_kwargs = self.model._update_model_kwargs_for_generation(
414 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder
415 | )
416 | cur_len = cur_len + 1
417 |
418 | # Lookahead break
419 | if cur_len >= lookahead_length:
420 | break
421 |
422 | # if eos_token was found in one sentence, set sentence to finished
423 | if self.eos_token_id is not None:
424 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long())
425 |
426 | # stop when each sentence is finished, or if we exceed the maximum length
427 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores):
428 | if not self.synced_gpus:
429 | break
430 | else:
431 | this_peer_finished = True
432 |
433 | if self.return_dict_in_generate:
434 | if self.model.config.is_encoder_decoder:
435 | return GreedySearchEncoderDecoderOutput(
436 | sequences=input_ids,
437 | scores=scores,
438 | encoder_attentions=encoder_attentions,
439 | encoder_hidden_states=encoder_hidden_states,
440 | decoder_attentions=decoder_attentions,
441 | cross_attentions=cross_attentions,
442 | decoder_hidden_states=decoder_hidden_states,
443 | )
444 | else:
445 | return GreedySearchDecoderOnlyOutput(
446 | sequences=input_ids,
447 | scores=scores,
448 | attentions=decoder_attentions,
449 | hidden_states=decoder_hidden_states,
450 | )
451 | else:
452 | return input_ids
453 |
454 | def beam_search(
455 | self,
456 | input_ids: torch.LongTensor,
457 | beam_scorer = None,
458 | **model_kwargs,
459 | ):
460 | batch_size = len(beam_scorer._beam_hyps)
461 | num_beams = beam_scorer.num_beams
462 |
463 | batch_beam_size, cur_len = input_ids.shape
464 |
465 | lookahead_length = self.lookahead_length + cur_len
466 |
467 | if num_beams * batch_size != batch_beam_size:
468 | raise ValueError(
469 | f"Batch dimension of `input_ids` should be {num_beams * batch_size}, but is {batch_beam_size}."
470 | )
471 |
472 | # init attention / hidden states / scores tuples
473 | scores = () if (self.return_dict_in_generate and self.output_scores) else None
474 | beam_indices = (
475 | tuple(() for _ in range(batch_beam_size)) if (self.return_dict_in_generate and self.output_scores) else None
476 | )
477 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
478 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
479 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None
480 |
481 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
482 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder:
483 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None
484 | encoder_hidden_states = (
485 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None
486 | )
487 |
488 | beam_scores = torch.zeros((batch_size, num_beams), dtype=torch.float, device=input_ids.device)
489 | beam_scores[:, 1:] = -1e9
490 | beam_scores = beam_scores.view((batch_size * num_beams,))
491 |
492 | this_peer_finished = False # used by synced_gpus only
493 | while True:
494 |
495 | if self.synced_gpus:
496 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
497 | # The following logic allows an early break if all peers finished generating their sequence
498 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
499 | # send 0.0 if we finished, 1.0 otherwise
500 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
501 | # did all peers finish? the reduced sum will be 0.0 then
502 | if this_peer_finished_flag.item() == 0.0:
503 | break
504 |
505 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
506 |
507 | outputs = self.model(
508 | **model_inputs,
509 | return_dict=True,
510 | output_attentions=self.output_attentions,
511 | output_hidden_states=self.output_hidden_states,
512 | )
513 |
514 | if self.synced_gpus and this_peer_finished:
515 | cur_len = cur_len + 1
516 | continue # don't waste resources running the code we don't need
517 |
518 | next_token_logits = outputs.logits[:, -1, :]
519 | # hack: adjust tokens for Marian. For Marian we have to make sure that the `pad_token_id`
520 | # cannot be generated both before and after the `nn.functional.log_softmax` operation.
521 | next_token_logits = self.model.adjust_logits_during_generation(next_token_logits, cur_len=cur_len)
522 | next_token_scores = nn.functional.log_softmax(
523 | next_token_logits, dim=-1
524 | ) # (batch_size * num_beams, vocab_size)
525 |
526 | next_token_scores_processed = self.logits_processor(input_ids, next_token_scores)
527 | next_token_scores = next_token_scores_processed + beam_scores[:, None].expand_as(next_token_scores)
528 |
529 | # Store scores, attentions and hidden_states when required
530 | if self.return_dict_in_generate:
531 | if self.output_scores:
532 | scores += (next_token_scores_processed,)
533 | if self.output_attentions:
534 | decoder_attentions += (
535 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,)
536 | )
537 | if self.model.config.is_encoder_decoder:
538 | cross_attentions += (outputs.cross_attentions,)
539 |
540 | if self.output_hidden_states:
541 | decoder_hidden_states += (
542 | (outputs.decoder_hidden_states,)
543 | if self.model.config.is_encoder_decoder
544 | else (outputs.hidden_states,)
545 | )
546 |
547 | # reshape for beam search
548 | vocab_size = next_token_scores.shape[-1]
549 | next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
550 |
551 | next_token_scores, next_tokens = torch.topk(
552 | next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True
553 | )
554 |
555 | next_indices = torch_int_div(next_tokens, vocab_size)
556 | next_tokens = next_tokens % vocab_size
557 |
558 | # stateless
559 | beam_outputs = beam_scorer.process(
560 | input_ids,
561 | next_token_scores,
562 | next_tokens,
563 | next_indices,
564 | pad_token_id=self.pad_token_id,
565 | eos_token_id=self.eos_token_id,
566 | )
567 |
568 | beam_scores = beam_outputs["next_beam_scores"]
569 | beam_next_tokens = beam_outputs["next_beam_tokens"]
570 | beam_idx = beam_outputs["next_beam_indices"]
571 |
572 | input_ids = torch.cat([input_ids[beam_idx, :], beam_next_tokens.unsqueeze(-1)], dim=-1)
573 |
574 | model_kwargs = self.model._update_model_kwargs_for_generation(
575 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder
576 | )
577 | if model_kwargs["past"] is not None:
578 | model_kwargs["past"] = self.model._reorder_cache(model_kwargs["past"], beam_idx)
579 |
580 | if self.return_dict_in_generate and self.output_scores:
581 | beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
582 |
583 | # increase cur_len
584 | cur_len = cur_len + 1
585 |
586 | if cur_len >= lookahead_length:
587 | break
588 |
589 | if beam_scorer.is_done or self.stopping_criteria(input_ids, scores):
590 | if not self.synced_gpus:
591 | break
592 | else:
593 | this_peer_finished = True
594 |
595 | sequence_outputs = beam_scorer.finalize(
596 | input_ids,
597 | beam_scores,
598 | next_tokens,
599 | next_indices,
600 | pad_token_id=self.pad_token_id,
601 | eos_token_id=self.eos_token_id,
602 | max_length=self.stopping_criteria.max_length,
603 | )
604 |
605 | if self.return_dict_in_generate:
606 | if not self.output_scores:
607 | sequence_outputs["sequence_scores"] = None
608 | else:
609 | num_return_sequences = beam_scorer.num_beam_hyps_to_keep
610 | # return only as many indices as sequences
611 | beam_indices = tuple(
612 | (beam_indices[i * num_beams : i * num_beams + num_return_sequences] for i in range(batch_size))
613 | )
614 | beam_indices = sum(beam_indices, ())
615 |
616 | if self.model.config.is_encoder_decoder:
617 | return BeamSearchEncoderDecoderOutput(
618 | sequences=sequence_outputs["sequences"],
619 | sequences_scores=sequence_outputs["sequence_scores"],
620 | scores=scores,
621 | beam_indices=beam_indices,
622 | encoder_attentions=encoder_attentions,
623 | encoder_hidden_states=encoder_hidden_states,
624 | decoder_attentions=decoder_attentions,
625 | cross_attentions=cross_attentions,
626 | decoder_hidden_states=decoder_hidden_states,
627 | )
628 | else:
629 | return BeamSearchDecoderOnlyOutput(
630 | sequences=sequence_outputs["sequences"],
631 | sequences_scores=sequence_outputs["sequence_scores"],
632 | scores=scores,
633 | beam_indices=beam_indices,
634 | attentions=decoder_attentions,
635 | hidden_states=decoder_hidden_states,
636 | )
637 | else:
638 | return sequence_outputs["sequences"]
639 |
640 | def sample(
641 | self,
642 | input_ids: torch.LongTensor,
643 | beam_scorer = None,
644 | **model_kwargs,
645 | ):
646 | scores = () if (self.return_dict_in_generate and self.output_scores) else None
647 | decoder_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
648 | cross_attentions = () if (self.return_dict_in_generate and self.output_attentions) else None
649 | decoder_hidden_states = () if (self.return_dict_in_generate and self.output_hidden_states) else None
650 |
651 | # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
652 | if self.return_dict_in_generate and self.model.config.is_encoder_decoder:
653 | encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if self.output_attentions else None
654 | encoder_hidden_states = (
655 | model_kwargs["encoder_outputs"].get("hidden_states") if self.output_hidden_states else None
656 | )
657 |
658 | # keep track of which sequences are already finished
659 | unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
660 | cur_len = input_ids.shape[-1]
661 |
662 | lookahead_length = self.lookahead_length + cur_len
663 |
664 | this_peer_finished = False # used by synced_gpus only
665 | # auto-regressive generation
666 | while True:
667 |
668 | if self.synced_gpus:
669 | # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
670 | # The following logic allows an early break if all peers finished generating their sequence
671 | this_peer_finished_flag = torch.tensor(0.0 if this_peer_finished else 1.0).to(input_ids.device)
672 | # send 0.0 if we finished, 1.0 otherwise
673 | dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
674 | # did all peers finish? the reduced sum will be 0.0 then
675 | if this_peer_finished_flag.item() == 0.0:
676 | break
677 |
678 | # prepare model inputs
679 | model_inputs = self.model.prepare_inputs_for_generation(input_ids, **model_kwargs)
680 |
681 | # forward pass to get next token
682 | outputs = self.model(
683 | **model_inputs,
684 | return_dict=True,
685 | output_attentions=self.output_attentions,
686 | output_hidden_states=self.output_hidden_states,
687 | )
688 |
689 | if self.synced_gpus and this_peer_finished:
690 | cur_len = cur_len + 1
691 | continue # don't waste resources running the code we don't need
692 |
693 | next_token_logits = outputs.logits[:, -1, :]
694 |
695 | # pre-process distribution
696 | next_token_scores = self.logits_processor(input_ids, next_token_logits)
697 | next_token_scores = self.logits_warper(input_ids, next_token_scores)
698 |
699 | # Store scores, attentions and hidden_states when required
700 | if self.return_dict_in_generate:
701 | if self.output_scores:
702 | scores += (next_token_scores,)
703 | if self.output_attentions:
704 | decoder_attentions += (
705 | (outputs.decoder_attentions,) if self.model.config.is_encoder_decoder else (outputs.attentions,)
706 | )
707 | if self.model.config.is_encoder_decoder:
708 | cross_attentions += (outputs.cross_attentions,)
709 |
710 | if self.output_hidden_states:
711 | decoder_hidden_states += (
712 | (outputs.decoder_hidden_states,)
713 | if self.model.config.is_encoder_decoder
714 | else (outputs.hidden_states,)
715 | )
716 |
717 | # sample
718 | probs = nn.functional.softmax(next_token_scores, dim=-1)
719 | next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
720 |
721 | # finished sentences should have their next token be a padding token
722 | if self.eos_token_id is not None:
723 | if self.pad_token_id is None:
724 | raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
725 | next_tokens = next_tokens * unfinished_sequences + self.pad_token_id * (1 - unfinished_sequences)
726 |
727 | # update generated ids, model inputs, and length for next step
728 | input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
729 | model_kwargs = self.model._update_model_kwargs_for_generation(
730 | outputs, model_kwargs, is_encoder_decoder=self.model.config.is_encoder_decoder
731 | )
732 | cur_len = cur_len + 1
733 |
734 | if cur_len >= lookahead_length:
735 | break
736 |
737 | # if eos_token was found in one sentence, set sentence to finished
738 | if self.eos_token_id is not None:
739 | unfinished_sequences = unfinished_sequences.mul((next_tokens != self.eos_token_id).long())
740 |
741 | # stop when each sentence is finished, or if we exceed the maximum length
742 | if unfinished_sequences.max() == 0 or self.stopping_criteria(input_ids, scores):
743 | if not self.synced_gpus:
744 | break
745 | else:
746 | this_peer_finished = True
747 |
748 | if self.return_dict_in_generate:
749 | if self.model.config.is_encoder_decoder:
750 | return SampleEncoderDecoderOutput(
751 | sequences=input_ids,
752 | scores=scores,
753 | encoder_attentions=encoder_attentions,
754 | encoder_hidden_states=encoder_hidden_states,
755 | decoder_attentions=decoder_attentions,
756 | cross_attentions=cross_attentions,
757 | decoder_hidden_states=decoder_hidden_states,
758 | )
759 | else:
760 | return SampleDecoderOnlyOutput(
761 | sequences=input_ids,
762 | scores=scores,
763 | attentions=decoder_attentions,
764 | hidden_states=decoder_hidden_states,
765 | )
766 | else:
767 | return input_ids
768 |
769 |
770 | def expand_model_kwargs(self, model_kwargs, indices):
771 | model_kwargs = copy.deepcopy(model_kwargs)
772 | if "attention_mask" in model_kwargs:
773 | model_kwargs["attention_mask"] = model_kwargs["attention_mask"][indices]
774 | if "encoder_outputs" in model_kwargs:
775 | for k,v in model_kwargs["encoder_outputs"].items():
776 | if v is not None:
777 | model_kwargs["encoder_outputs"][k] = v[indices]
778 | if "past" in model_kwargs:
779 | model_kwargs["past"] = tuple([tuple([p[indices] for p in past]) for past in model_kwargs["past"]])
780 | return model_kwargs
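781 | 
782 |     # Illustrative sketch (the instance name `lookahead` and the example indices are
783 |     # placeholders): `expand_model_kwargs` deep-copies the generation kwargs and selects
784 |     # the rows given by `indices` from `attention_mask`, from every tensor in
785 |     # `encoder_outputs`, and from every layer of `past`, e.g.
786 |     #
787 |     #   expanded = lookahead.expand_model_kwargs(model_kwargs, indices=[0, 0, 1])
788 |     #
789 |     # yields kwargs whose batch dimension follows `indices`, presumably so lookahead
790 |     # rollouts can be run for a chosen subset of candidates (such as the top-k tokens).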
--------------------------------------------------------------------------------
/src/inference/run_lookahead.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2 | from scorer import FleschScorer
3 | from lookahead import Lookahead
4 | from generation import Generator
5 | from tqdm import tqdm
6 |
7 | import json
8 |
9 | import argparse
10 |
11 | def open_file(file):
12 | entities = []
13 |
14 | for line in open(file).readlines():
15 | entities.append(json.loads(line))
16 |
17 | return entities
18 |
19 | parser = argparse.ArgumentParser()
20 |
21 | # base decoding model
22 | parser.add_argument("--model_name", type=str, default="facebook/bart-large-xsum")
23 | parser.add_argument("--cache_dir", type=str, default="./cache")
24 |
25 | # input output
26 | parser.add_argument("--document_file", type=str, required=True)
27 | parser.add_argument("--output_file", type=str, required=True)
28 |
29 | # base decoding configuration. Please refer to Hugging Face's GenerationMixin for an explanation of the parameters
30 | parser.add_argument("--batch_size", type=int, default=8)
31 | parser.add_argument("--score", type=int, default=30)
32 | parser.add_argument("--prompt", type=str, default="")
33 | parser.add_argument("--num_beams", type=int, default=1)
34 | parser.add_argument("--num_return_sequences", type=int, default=1)
35 | parser.add_argument("--max_input_length", type=int, default=1024)
36 | parser.add_argument("--max_output_length", type=int, default=256)
37 | parser.add_argument("--do_sample", action='store_true', default=False)
38 |
39 | # lookahead configuration
40 | parser.add_argument("--do_lookahead", action="store_true", default=False)
41 | parser.add_argument("--lookahead_length", type=int, default=64)
42 | parser.add_argument("--lookahead_lambda", type=int, default=25)
43 | parser.add_argument("--top_k", type=int, default=5)
44 | parser.add_argument("--lookahead_decoding_type", type=str, default="greedy", choices=["greedy","beam","sample"])
45 | parser.add_argument("--lookahead_beam", type=int, default=1)
46 |
47 | # scorer configuration
48 | parser.add_argument("--scorer_model_type", type=str, default="roberta-large")
49 | parser.add_argument("--scorer_num_layers", type=int, default=17)
50 |
51 | args = parser.parse_args()
52 |
53 | # loading model
54 | tokenizer = AutoTokenizer.from_pretrained(args.model_name, cache_dir=args.cache_dir)
55 | model = AutoModelForSeq2SeqLM.from_pretrained(args.model_name, cache_dir=args.cache_dir)
56 | model = model.cuda() # can optionally call .half() for mixed precision
57 |
58 | # loading input
59 | documents = open_file(args.document_file)
60 | documents = [args.prompt + doc["input_noprompt"] for doc in documents]
61 |
62 | scorer = FleschScorer(
63 | 'flesch',
64 | args.score
65 | )
66 |
67 | # Create lookahead
68 | lookahead = None
69 | if args.do_lookahead:
70 | lookahead = Lookahead(
71 | model,
72 | tokenizer,
73 | scorer,
74 | lookahead_length=args.lookahead_length,
75 | lookahead_lambda=args.lookahead_lambda,
76 | lookahead_top_k=args.top_k,
77 | decoding_type=args.lookahead_decoding_type,
78 | num_beams=args.lookahead_beam,
79 | num_return_sequences=args.lookahead_beam,
80 | max_length=args.max_output_length,
81 | )
82 |
83 | # Create generator with lookahead
84 | generator = Generator(model, lookahead=lookahead)
85 |
86 | summaries = []
87 |
88 | for i in tqdm(range(0, len(documents), args.batch_size)):
89 | input_str = documents[i:i+args.batch_size]
90 |
91 | inputs = tokenizer(input_str, max_length=args.max_input_length, padding=True, truncation=True, return_tensors="pt")
92 |
93 | inputs = {k:v.cuda() for k,v in inputs.items()}
94 |
95 | output = generator.generate(
96 | input_ids = inputs["input_ids"],
97 | attention_mask=inputs["attention_mask"],
98 | num_beams=args.num_beams,
99 | num_return_sequences=args.num_return_sequences,
100 | max_length=args.max_output_length,
101 | do_sample=args.do_sample,
102 | )
103 |
104 | output = tokenizer.batch_decode(output, skip_special_tokens=True)
105 |
106 | if args.num_return_sequences == 1:
107 | summaries += output
108 | else:
109 |         for j in range(0, len(output), args.num_return_sequences):
110 |             summaries.append(output[j:j+args.num_return_sequences])
111 |
112 | # Save file
113 | with open(args.output_file, "w") as f:
114 | if args.num_return_sequences == 1:
115 | for line in summaries:
116 | f.write(line + "\n")
117 | else:
118 | json.dump(summaries, f)
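119 | 
120 | # Example invocation (illustrative; the paths are placeholders and the document
121 | # file is expected to be JSONL with an "input_noprompt" field, as read above;
122 | # cf. inference_lookahead.sh in this directory):
123 | #
124 | #   python run_lookahead.py \
125 | #       --document_file data/test.jsonl \
126 | #       --output_file outputs/summaries.txt \
127 | #       --score 60 --num_beams 4 \
128 | #       --do_lookahead --lookahead_length 64 --lookahead_decoding_type greedy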
--------------------------------------------------------------------------------
/src/inference/run_summarization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | from dataclasses import dataclass, field
25 | from typing import Optional
26 |
27 | import datasets
28 | import nltk # Here to have a nice missing dependency error message early on
29 | import numpy as np
30 | from datasets import load_dataset
31 |
32 | import evaluate
33 | import transformers
34 | from filelock import FileLock
35 | from transformers import (
36 | AutoConfig,
37 | AutoModelForSeq2SeqLM,
38 | AutoTokenizer,
39 | DataCollatorForSeq2Seq,
40 | HfArgumentParser,
41 | MBart50Tokenizer,
42 | MBart50TokenizerFast,
43 | MBartTokenizer,
44 | MBartTokenizerFast,
45 | Seq2SeqTrainer,
46 | Seq2SeqTrainingArguments,
47 | set_seed,
48 | )
49 | from transformers.trainer_utils import get_last_checkpoint
50 | from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
51 | from transformers.utils.versions import require_version
52 |
53 | os.environ["NCCL_DEBUG"] = "INFO"
54 |
55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
56 | #check_min_version("4.25.0.dev0")
57 |
58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
59 |
60 | logger = logging.getLogger(__name__)
61 |
62 | try:
63 | nltk.data.find("tokenizers/punkt")
64 | except (LookupError, OSError):
65 | if is_offline_mode():
66 | raise LookupError(
67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
68 | )
69 | with FileLock(".lock") as lock:
70 | nltk.download("punkt", quiet=True)
71 |
72 | # A list of all multilingual tokenizer which require lang attribute.
73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
74 |
75 |
76 | @dataclass
77 | class ModelArguments:
78 | """
79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
80 | """
81 |
82 | model_name_or_path: str = field(
83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
84 | )
85 | config_name: Optional[str] = field(
86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87 | )
88 | tokenizer_name: Optional[str] = field(
89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90 | )
91 | cache_dir: Optional[str] = field(
92 | default=None,
93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
94 | )
95 | use_fast_tokenizer: bool = field(
96 | default=True,
97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
98 | )
99 | model_revision: str = field(
100 | default="main",
101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
102 | )
103 | use_auth_token: bool = field(
104 | default=False,
105 | metadata={
106 | "help": (
107 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
108 | "with private models)."
109 | )
110 | },
111 | )
112 | resize_position_embeddings: Optional[bool] = field(
113 | default=None,
114 | metadata={
115 | "help": (
116 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
117 | "the model's position embeddings."
118 | )
119 | },
120 | )
121 |
122 |
123 | @dataclass
124 | class DataTrainingArguments:
125 | """
126 | Arguments pertaining to what data we are going to input our model for training and eval.
127 | """
128 |
129 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
130 |
131 | dataset_name: Optional[str] = field(
132 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
133 | )
134 | dataset_config_name: Optional[str] = field(
135 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
136 | )
137 | text_column: Optional[str] = field(
138 | default=None,
139 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
140 | )
141 | summary_column: Optional[str] = field(
142 | default=None,
143 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
144 | )
145 | train_file: Optional[str] = field(
146 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
147 | )
148 | validation_file: Optional[str] = field(
149 | default=None,
150 | metadata={
151 | "help": (
152 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
153 | )
154 | },
155 | )
156 | test_file: Optional[str] = field(
157 | default=None,
158 | metadata={
159 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
160 | },
161 | )
162 | overwrite_cache: bool = field(
163 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
164 | )
165 | preprocessing_num_workers: Optional[int] = field(
166 | default=None,
167 | metadata={"help": "The number of processes to use for the preprocessing."},
168 | )
169 | max_source_length: Optional[int] = field(
170 | default=1024,
171 | metadata={
172 | "help": (
173 | "The maximum total input sequence length after tokenization. Sequences longer "
174 | "than this will be truncated, sequences shorter will be padded."
175 | )
176 | },
177 | )
178 | max_target_length: Optional[int] = field(
179 | default=128,
180 | metadata={
181 | "help": (
182 | "The maximum total sequence length for target text after tokenization. Sequences longer "
183 | "than this will be truncated, sequences shorter will be padded."
184 | )
185 | },
186 | )
187 | val_max_target_length: Optional[int] = field(
188 | default=None,
189 | metadata={
190 | "help": (
191 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
192 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
193 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
194 | "during ``evaluate`` and ``predict``."
195 | )
196 | },
197 | )
198 | pad_to_max_length: bool = field(
199 | default=False,
200 | metadata={
201 | "help": (
202 | "Whether to pad all samples to model maximum sentence length. "
203 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
204 | "efficient on GPU but very bad for TPU."
205 | )
206 | },
207 | )
208 | max_train_samples: Optional[int] = field(
209 | default=None,
210 | metadata={
211 | "help": (
212 | "For debugging purposes or quicker training, truncate the number of training examples to this "
213 | "value if set."
214 | )
215 | },
216 | )
217 | max_eval_samples: Optional[int] = field(
218 | default=None,
219 | metadata={
220 | "help": (
221 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
222 | "value if set."
223 | )
224 | },
225 | )
226 | max_predict_samples: Optional[int] = field(
227 | default=None,
228 | metadata={
229 | "help": (
230 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
231 | "value if set."
232 | )
233 | },
234 | )
235 | num_beams: Optional[int] = field(
236 | default=None,
237 | metadata={
238 | "help": (
239 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
240 | "which is used during ``evaluate`` and ``predict``."
241 | )
242 | },
243 | )
244 | ignore_pad_token_for_loss: bool = field(
245 | default=True,
246 | metadata={
247 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
248 | },
249 | )
250 | source_prefix: Optional[str] = field(
251 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
252 | )
253 |
254 | forced_bos_token: Optional[str] = field(
255 | default=None,
256 | metadata={
257 | "help": (
258 | "The token to force as the first generated token after the decoder_start_token_id."
259 | "Useful for multilingual models like mBART where the first generated token"
260 | "needs to be the target language token (Usually it is the target language token)"
261 | )
262 | },
263 | )
264 |
265 | def __post_init__(self):
266 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
267 | raise ValueError("Need either a dataset name or a training/validation file.")
268 | else:
269 | if self.train_file is not None:
270 | extension = self.train_file.split(".")[-1]
271 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
272 | if self.validation_file is not None:
273 | extension = self.validation_file.split(".")[-1]
274 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
275 | if self.val_max_target_length is None:
276 | self.val_max_target_length = self.max_target_length
277 |
278 |
279 | summarization_name_mapping = {
280 | "amazon_reviews_multi": ("review_body", "review_title"),
281 | "big_patent": ("description", "abstract"),
282 | "cnn_dailymail": ("article", "highlights"),
283 | "orange_sum": ("text", "summary"),
284 | "pn_summary": ("article", "summary"),
285 | "psc": ("extract_text", "summary_text"),
286 | "samsum": ("dialogue", "summary"),
287 | "thaisum": ("body", "summary"),
288 | "xglue": ("news_body", "news_title"),
289 | "xsum": ("document", "summary"),
290 | "wiki_summary": ("article", "highlights"),
291 | "multi_news": ("document", "summary"),
292 | }
293 |
294 |
295 | def main():
296 | # See all possible arguments in src/transformers/training_args.py
297 | # or by passing the --help flag to this script.
298 | # We now keep distinct sets of args, for a cleaner separation of concerns.
299 |
300 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
301 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
302 | # If we pass only one argument to the script and it's the path to a json file,
303 | # let's parse it to get our arguments.
304 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
305 | else:
306 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
307 |
308 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
309 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
310 | send_example_telemetry("run_summarization", model_args, data_args)
311 | print("training_args", training_args)
312 | # Setup logging
313 | logging.basicConfig(
314 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
315 | datefmt="%m/%d/%Y %H:%M:%S",
316 | handlers=[logging.StreamHandler(sys.stdout)],
317 | )
318 | log_level = training_args.get_process_log_level()
319 | logger.setLevel(log_level)
320 | datasets.utils.logging.set_verbosity(log_level)
321 | transformers.utils.logging.set_verbosity(log_level)
322 | transformers.utils.logging.enable_default_handler()
323 | transformers.utils.logging.enable_explicit_format()
324 |
325 | # Log on each process the small summary:
326 | logger.warning(
327 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
328 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
329 | )
330 | logger.info(f"Training/evaluation parameters {training_args}")
331 |
332 | if data_args.source_prefix is None and model_args.model_name_or_path in [
333 | "t5-small",
334 | "t5-base",
335 | "t5-large",
336 | "t5-3b",
337 | "t5-11b",
338 | ]:
339 | logger.warning(
340 |             "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
341 | "`--source_prefix 'summarize: ' `"
342 | )
343 |
344 | # Detecting last checkpoint.
345 | last_checkpoint = None
346 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
347 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
348 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
349 | raise ValueError(
350 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
351 | "Use --overwrite_output_dir to overcome."
352 | )
353 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
354 | logger.info(
355 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
356 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
357 | )
358 |
359 | # Set seed before initializing model.
360 | set_seed(training_args.seed)
361 |
362 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
363 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
364 | # (the dataset will be downloaded automatically from the datasets Hub).
365 | #
366 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
367 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
368 | #
369 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
370 | # download the dataset.
371 | print("data_args", data_args)
372 | if data_args.dataset_name is not None:
373 | # Downloading and loading a dataset from the hub.
374 | raw_datasets = load_dataset(
375 | data_args.dataset_name,
376 | data_args.dataset_config_name,
377 | cache_dir=model_args.cache_dir,
378 | use_auth_token=True if model_args.use_auth_token else None,
379 | )
380 | else:
381 | data_files = {}
382 | if data_args.train_file is not None:
383 | data_files["train"] = data_args.train_file
384 | extension = data_args.train_file.split(".")[-1]
385 | if data_args.validation_file is not None:
386 | data_files["validation"] = data_args.validation_file
387 | extension = data_args.validation_file.split(".")[-1]
388 | if data_args.test_file is not None:
389 | data_files["test"] = data_args.test_file
390 | extension = data_args.test_file.split(".")[-1]
391 | raw_datasets = load_dataset(
392 | extension,
393 | data_files=data_files,
394 | cache_dir=model_args.cache_dir,
395 | use_auth_token=True if model_args.use_auth_token else None,
396 | )
397 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
398 | # https://huggingface.co/docs/datasets/loading_datasets.html.
399 |
400 | # Load pretrained model and tokenizer
401 | #
402 | # Distributed training:
403 | # The .from_pretrained methods guarantee that only one local process can concurrently
404 | # download model & vocab.
405 | print("model_args", model_args)
406 | config = AutoConfig.from_pretrained(
407 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
408 | cache_dir=model_args.cache_dir,
409 | revision=model_args.model_revision,
410 | use_auth_token=True if model_args.use_auth_token else None,
411 | )
412 | tokenizer = AutoTokenizer.from_pretrained(
413 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
414 | cache_dir=model_args.cache_dir,
415 | use_fast=model_args.use_fast_tokenizer,
416 | revision=model_args.model_revision,
417 | use_auth_token=True if model_args.use_auth_token else None,
418 | )
419 | model = AutoModelForSeq2SeqLM.from_pretrained(
420 | model_args.model_name_or_path,
421 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
422 | config=config,
423 | cache_dir=model_args.cache_dir,
424 | revision=model_args.model_revision,
425 | use_auth_token=True if model_args.use_auth_token else None,
426 | )
427 |
428 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
429 | # on a small vocab and want a smaller embedding size, remove this test.
430 | embedding_size = model.get_input_embeddings().weight.shape[0]
431 | if len(tokenizer) > embedding_size:
432 | model.resize_token_embeddings(len(tokenizer))
433 |
434 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
435 | if isinstance(tokenizer, MBartTokenizer):
436 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
437 | else:
438 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
439 |
440 | if model.config.decoder_start_token_id is None:
441 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
442 |
443 | if (
444 | hasattr(model.config, "max_position_embeddings")
445 | and model.config.max_position_embeddings < data_args.max_source_length
446 | ):
447 | if model_args.resize_position_embeddings is None:
448 | logger.warning(
449 | "Increasing the model's number of position embedding vectors from"
450 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
451 | )
452 | model.resize_position_embeddings(data_args.max_source_length)
453 | elif model_args.resize_position_embeddings:
454 | model.resize_position_embeddings(data_args.max_source_length)
455 | else:
456 | raise ValueError(
457 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
458 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
459 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
460 | " model's position encodings by passing `--resize_position_embeddings`."
461 | )
462 |
463 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
464 |
465 | # Preprocessing the datasets.
466 | # We need to tokenize inputs and targets.
467 | if training_args.do_train:
468 | column_names = raw_datasets["train"].column_names
469 | elif training_args.do_eval:
470 | column_names = raw_datasets["validation"].column_names
471 | elif training_args.do_predict:
472 | column_names = raw_datasets["test"].column_names
473 | else:
474 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
475 | return
476 |
477 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
478 | assert (
479 | data_args.lang is not None
480 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
481 |
482 | tokenizer.src_lang = data_args.lang
483 | tokenizer.tgt_lang = data_args.lang
484 |
485 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
486 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
487 | forced_bos_token_id = (
488 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
489 | )
490 | model.config.forced_bos_token_id = forced_bos_token_id
491 |
492 | # Get the column names for input/target.
493 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
494 | if data_args.text_column is None:
495 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
496 | else:
497 | text_column = data_args.text_column
498 | if text_column not in column_names:
499 | raise ValueError(
500 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
501 | )
502 | if data_args.summary_column is None:
503 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
504 | else:
505 | summary_column = data_args.summary_column
506 | if summary_column not in column_names:
507 | raise ValueError(
508 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
509 | )
510 |
511 | # Temporarily set max_target_length for training.
512 | max_target_length = data_args.max_target_length
513 | padding = "max_length" if data_args.pad_to_max_length else False
514 |
515 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
516 | logger.warning(
517 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
518 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
519 | )
520 |
521 | print(data_args)
522 |
523 | def preprocess_function(examples):
524 | # remove pairs where at least one record is None
525 |
526 | inputs, targets = [], []
527 | for i in range(len(examples[text_column])):
528 | if examples[text_column][i] and examples[summary_column][i]:
529 | inputs.append(examples[text_column][i])
530 | targets.append(examples[summary_column][i])
531 |
532 | inputs = [prefix + inp for inp in inputs]
533 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
534 |
535 | # Tokenize targets with the `text_target` keyword argument
536 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
537 |
538 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
539 | # padding in the loss.
540 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
541 | labels["input_ids"] = [
542 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
543 | ]
544 |
545 | model_inputs["labels"] = labels["input_ids"]
546 | return model_inputs
547 |
548 | if training_args.do_train:
549 | if "train" not in raw_datasets:
550 | raise ValueError("--do_train requires a train dataset")
551 | train_dataset = raw_datasets["train"]
552 | if data_args.max_train_samples is not None:
553 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
554 | train_dataset = train_dataset.select(range(max_train_samples))
555 | with training_args.main_process_first(desc="train dataset map pre-processing"):
556 | train_dataset = train_dataset.map(
557 | preprocess_function,
558 | batched=True,
559 | num_proc=data_args.preprocessing_num_workers,
560 | remove_columns=column_names,
561 | load_from_cache_file=not data_args.overwrite_cache,
562 | desc="Running tokenizer on train dataset",
563 | )
564 |
565 | if training_args.do_eval:
566 | max_target_length = data_args.val_max_target_length
567 | if "validation" not in raw_datasets:
568 | raise ValueError("--do_eval requires a validation dataset")
569 | eval_dataset = raw_datasets["validation"]
570 | if data_args.max_eval_samples is not None:
571 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
572 | eval_dataset = eval_dataset.select(range(max_eval_samples))
573 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
574 | eval_dataset = eval_dataset.map(
575 | preprocess_function,
576 | batched=True,
577 | num_proc=data_args.preprocessing_num_workers,
578 | remove_columns=column_names,
579 | load_from_cache_file=not data_args.overwrite_cache,
580 | desc="Running tokenizer on validation dataset",
581 | )
582 |
583 | if training_args.do_predict:
584 | max_target_length = data_args.val_max_target_length
585 | if "test" not in raw_datasets:
586 | raise ValueError("--do_predict requires a test dataset")
587 | predict_dataset = raw_datasets["test"]
588 | if data_args.max_predict_samples is not None:
589 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
590 | predict_dataset = predict_dataset.select(range(max_predict_samples))
591 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
592 | predict_dataset = predict_dataset.map(
593 | preprocess_function,
594 | batched=True,
595 | num_proc=data_args.preprocessing_num_workers,
596 | remove_columns=column_names,
597 | load_from_cache_file=not data_args.overwrite_cache,
598 | desc="Running tokenizer on prediction dataset",
599 | )
600 |
601 | # Data collator
602 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
603 | data_collator = DataCollatorForSeq2Seq(
604 | tokenizer,
605 | model=model,
606 | label_pad_token_id=label_pad_token_id,
607 | pad_to_multiple_of=8 if training_args.fp16 else None,
608 | )
609 |
610 | # Metric
611 | metric = evaluate.load("rouge")
612 |
613 | def postprocess_text(preds, labels):
614 | preds = [pred.strip() for pred in preds]
615 | labels = [label.strip() for label in labels]
616 |
617 | # rougeLSum expects newline after each sentence
618 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
619 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
620 |
621 | return preds, labels
622 |
623 | def compute_metrics(eval_preds):
624 | preds, labels = eval_preds
625 | if isinstance(preds, tuple):
626 | preds = preds[0]
627 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
628 | if data_args.ignore_pad_token_for_loss:
629 | # Replace -100 in the labels as we can't decode them.
630 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
631 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
632 |
633 | # Some simple post-processing
634 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
635 |
636 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
637 | result = {k: round(v * 100, 4) for k, v in result.items()}
638 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
639 | result["gen_len"] = np.mean(prediction_lens)
640 | return result
641 |
642 | # Initialize our Trainer
643 | trainer = Seq2SeqTrainer(
644 | model=model,
645 | args=training_args,
646 | train_dataset=train_dataset if training_args.do_train else None,
647 | eval_dataset=eval_dataset if training_args.do_eval else None,
648 | tokenizer=tokenizer,
649 | data_collator=data_collator,
650 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
651 | )
652 |
653 | # Training
654 | if training_args.do_train:
655 | checkpoint = None
656 | if training_args.resume_from_checkpoint is not None:
657 | checkpoint = training_args.resume_from_checkpoint
658 | elif last_checkpoint is not None:
659 | checkpoint = last_checkpoint
660 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
661 | trainer.save_model() # Saves the tokenizer too for easy upload
662 |
663 | metrics = train_result.metrics
664 | max_train_samples = (
665 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
666 | )
667 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
668 |
669 | trainer.log_metrics("train", metrics)
670 | trainer.save_metrics("train", metrics)
671 | trainer.save_state()
672 |
673 | # Evaluation
674 | results = {}
675 | max_length = (
676 | training_args.generation_max_length
677 | if training_args.generation_max_length is not None
678 | else data_args.val_max_target_length
679 | )
680 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
681 | # if training_args.do_eval:
682 | # logger.info("*** Evaluate ***")
683 | # metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
684 | # max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
685 | # metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
686 | #
687 | # trainer.log_metrics("eval", metrics)
688 | # trainer.save_metrics("eval", metrics)
689 |
690 | if training_args.do_predict:
691 | logger.info("*** Predict ***")
692 |
693 | predict_results = trainer.predict(
694 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
695 | )
696 | metrics = predict_results.metrics
697 | max_predict_samples = (
698 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
699 | )
700 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
701 |
702 | trainer.log_metrics("predict", metrics)
703 | trainer.save_metrics("predict", metrics)
704 |
705 | if trainer.is_world_process_zero():
706 | if training_args.predict_with_generate:
707 | predictions = tokenizer.batch_decode(
708 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
709 | )
710 | predictions = [pred.strip() for pred in predictions]
711 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
712 | with open(output_prediction_file, "w") as writer:
713 | writer.write("\n".join(predictions))
714 |
715 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
716 | if data_args.dataset_name is not None:
717 | kwargs["dataset_tags"] = data_args.dataset_name
718 | if data_args.dataset_config_name is not None:
719 | kwargs["dataset_args"] = data_args.dataset_config_name
720 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
721 | else:
722 | kwargs["dataset"] = data_args.dataset_name
723 |
724 | if data_args.lang is not None:
725 | kwargs["language"] = data_args.lang
726 |
727 | if training_args.push_to_hub:
728 | trainer.push_to_hub(**kwargs)
729 | else:
730 | trainer.create_model_card(**kwargs)
731 |
732 | return results
733 |
734 |
735 | def _mp_fn(index):
736 | # For xla_spawn (TPUs)
737 | main()
738 |
739 |
740 | if __name__ == "__main__":
741 | main()
742 |
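743 | # Illustrative prediction-only invocation (placeholder paths and column names;
744 | # cf. inference_score.sh / inference_category.sh in this directory):
745 | #
746 | #   python run_summarization.py \
747 | #       --model_name_or_path facebook/bart-large \
748 | #       --test_file data/test.json --text_column article --summary_column highlights \
749 | #       --do_predict --predict_with_generate \
750 | #       --output_dir ./predict_output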
--------------------------------------------------------------------------------
/src/inference/scorer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from transformers import AutoModel, AutoTokenizer
5 |
6 | class BERTScoreScorer:
7 | """
8 | Scorer using BS-Fact, code adapted from bertscore official repo: https://github.com/Tiiiger/bert_score
9 | """
10 | def __init__(self, model_name="roberta-large", device="cuda", num_layers=17, cache_dir=".cache"):
11 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
12 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
13 | # We assume we are using roberta-large, please reference https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L247
14 | # if you wish to use other model and select the recommended layer
15 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
16 |
17 | self.model = model.to(device)
18 | self.device = device
19 |
20 | def prepare_document(self, input_str):
21 | """
22 |         Prepare anything that requires processing of the documents.
23 |         This is called only once per batch of documents so the work is not repeated at every decoding step.
24 | """
25 | self.bertscore_input_embedding, self.bertscore_input_attention_mask, self.bertscore_input_idf = self.encode_text(input_str)
26 |
27 | def score(self, summaries, index):
28 | """
29 | Output the score for each example.
30 | summaries: The summary strings
31 |         index: The indices of the examples (documents the summaries should be compared to). This is typically just range(len(summaries)), except for beam search.
32 | """
33 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = self.encode_text(summaries)
34 |
35 | bertscore_input_embedding = self.bertscore_input_embedding[index]
36 | bertscore_input_attention_mask = self.bertscore_input_attention_mask[index]
37 | bertscore_input_idf = self.bertscore_input_idf[index]
38 |
39 | bertscore_scores = self.compute_bertscore(
40 | bertscore_input_embedding,
41 | bertscore_input_attention_mask,
42 | bertscore_input_idf,
43 | bertscore_output_embedding,
44 | bertscore_output_attention_mask,
45 | bertscore_output_idf,
46 | )
47 | return bertscore_scores
48 |
49 | def encode_text(self, input_str):
50 | """
51 | Helper function to encode any string to tensor using the tokenizer
52 | """
53 | inputs = self.tokenizer(input_str, padding=True, truncation=True, return_tensors="pt")
54 | inputs = {k:v.to(self.device) for k,v in inputs.items()}
55 | with torch.no_grad():
56 | outputs = self.model(**inputs)
57 |
58 | # idf
59 | idf = torch.clone(inputs["attention_mask"]).float()
60 |         idf[inputs["input_ids"] == self.tokenizer.sep_token_id] = 0
61 |         idf[inputs["input_ids"] == self.tokenizer.cls_token_id] = 0
62 | idf.div_(idf.sum(dim=1, keepdim=True))
63 |
64 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf
65 |
66 | def compute_bertscore(self, doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf):
67 | """
68 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469
69 | """
70 |
71 | batch_size = doc_embedding.size(0)
72 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2))
73 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float())
74 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)
75 |
76 | masks = masks.float().to(sim.device)
77 | sim = sim * masks
78 |
79 | precision = sim.max(dim=2)[0]
80 | precision_scale = summ_idf.to(precision.device)
81 | P = (precision * precision_scale).sum(dim=1)
82 |
83 | summ_zero_mask = summ_masks.sum(dim=1).eq(2)
84 | if torch.any(summ_zero_mask):
85 | P = P.masked_fill(summ_zero_mask, 0.0)
86 |
87 | doc_zero_mask = doc_masks.sum(dim=1).eq(2)
88 | if torch.any(doc_zero_mask):
89 | P = P.masked_fill(doc_zero_mask, 0.0)
90 |
91 | return P
92 |
93 | from readability import Readability
94 |
95 | def get_flesch_kincaid(text):
96 | r = Readability(text)
97 | fk = r.flesch_kincaid()
98 | return fk.score
99 |
100 |
101 | def get_flesch(text):
102 | r = Readability(text)
103 | f = r.flesch()
104 | return f.score
105 |
106 | class FleschScorer:
107 | """
108 |     Readability scorer based on the Flesch reading-ease score: rewards summaries whose readability is close to the target `flesch_score`.
109 | """
110 |
111 | def __init__(self, name_module, flesch_score, device="cuda"):
112 | self.name_module = name_module
113 | self.flesch_score = flesch_score
114 | self.device = device
115 |
116 | def score(self, summaries, index):
117 | """
118 | Output the score for each example.
119 | summaries: The summary strings
120 |         index: Unused here; kept for interface compatibility with the other scorers.
121 | """
122 |
123 | flesch_scores = []
124 | for text in summaries:
125 | try:
126 | flesch_scores.append(get_flesch(text))
127 |             except Exception:  # e.g. the text is too short for the readability package to score
128 |                 flesch_scores.append(100)
129 | flesch_scores = [1 - (abs(fs - self.flesch_score) / 100) for fs in flesch_scores]
130 |
131 | return torch.tensor(flesch_scores).to(self.device)
132 |
133 |
134 | class BERTandFleschScoreScorer:
135 | """
136 |     Scorer combining Flesch readability with BS-Fact (BERTScore precision against the source document), weighted by `readability_weight`; BERTScore code adapted from the official repo: https://github.com/Tiiiger/bert_score
137 | """
138 |
139 | def __init__(self, model_name="roberta-large", device="cuda", num_layers=17, cache_dir=".cache", flesch_score=50,
140 | readability_weight=0.8):
141 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
142 | self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
143 | # We assume we are using roberta-large, please reference https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L247
144 | # if you wish to use other model and select the recommended layer
145 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
146 |
147 | self.model = model.to(device)
148 | self.device = device
149 | self.flesch_score = flesch_score
150 | self.readability_weight = readability_weight
151 |
152 | def prepare_document(self, input_str):
153 | """
154 |         Prepare anything that requires processing of the documents.
155 |         This is called only once per batch of documents so the work is not repeated at every decoding step.
156 | """
157 | self.bertscore_input_embedding, self.bertscore_input_attention_mask, self.bertscore_input_idf = self.encode_text(
158 | input_str)
159 |
160 | def score(self, summaries, index):
161 | """
162 | Output the score for each example.
163 | summaries: The summary strings
164 |         index: The indices of the examples (documents the summaries should be compared to). This is typically just range(len(summaries)), except for beam search.
165 | """
166 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = self.encode_text(summaries)
167 |
168 | bertscore_input_embedding = self.bertscore_input_embedding[index]
169 | bertscore_input_attention_mask = self.bertscore_input_attention_mask[index]
170 | bertscore_input_idf = self.bertscore_input_idf[index]
171 |
172 | bertscore_scores = self.compute_bertscore(
173 | bertscore_input_embedding,
174 | bertscore_input_attention_mask,
175 | bertscore_input_idf,
176 | bertscore_output_embedding,
177 | bertscore_output_attention_mask,
178 | bertscore_output_idf,
179 | )
180 |
181 | flesch_scores = []
182 | for text in summaries:
183 | try:
184 | flesch_scores.append(get_flesch(text))
185 |             except Exception:  # Readability can fail on very short texts; fall back to 100
186 | flesch_scores.append(100)
187 | flesch_scores = [1 - (abs(fs - self.flesch_score) / 100) for fs in flesch_scores]
188 |
189 | flesch_scores = torch.tensor(flesch_scores).to(self.device)
190 | assert flesch_scores.size() == bertscore_scores.size()
191 |
194 |
195 | return self.readability_weight * flesch_scores + (1 - self.readability_weight) * bertscore_scores
196 |
197 | def encode_text(self, input_str):
198 | """
199 |         Helper function that tokenizes the input string(s) and returns normalized token embeddings, attention masks, and idf weights
200 | """
201 | inputs = self.tokenizer(input_str, padding=True, truncation=True, return_tensors="pt")
202 | inputs = {k: v.to(self.device) for k, v in inputs.items()}
203 | with torch.no_grad():
204 | outputs = self.model(**inputs)
205 |
206 | # idf
207 |         idf = torch.clone(inputs["attention_mask"]).float()
208 |         idf[inputs["input_ids"] == self.tokenizer.sep_token_id] = 0  # zero out the idf weight at special-token positions
209 |         idf[inputs["input_ids"] == self.tokenizer.cls_token_id] = 0
210 | idf.div_(idf.sum(dim=1, keepdim=True))
211 |
212 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf
213 |
214 | def compute_bertscore(self, doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf):
215 | """
216 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469
217 | """
218 |
219 | batch_size = doc_embedding.size(0)
220 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2))
221 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float())
222 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)
223 |
224 | masks = masks.float().to(sim.device)
225 | sim = sim * masks
226 |
227 | precision = sim.max(dim=2)[0]
228 | precision_scale = summ_idf.to(precision.device)
229 | P = (precision * precision_scale).sum(dim=1)
230 |
231 | summ_zero_mask = summ_masks.sum(dim=1).eq(2)
232 | if torch.any(summ_zero_mask):
233 | P = P.masked_fill(summ_zero_mask, 0.0)
234 |
235 | doc_zero_mask = doc_masks.sum(dim=1).eq(2)
236 | if torch.any(doc_zero_mask):
237 | P = P.masked_fill(doc_zero_mask, 0.0)
238 |
239 | return P
--------------------------------------------------------------------------------
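For orientation, here is a minimal sketch (not part of the repository) of how the two scorers above turn readability and faithfulness into a single lookahead reward. The numbers are hypothetical: a candidate summary with a Flesch reading-ease score of 62, a target flesch_score of 50, a BS-Fact precision of 0.90, and the default readability_weight of 0.8.

# Hypothetical values, only to illustrate the reward combination used above
flesch_target = 50          # requested Flesch reading-ease score
readability_weight = 0.8    # default weight in BERTandFleschScoreScorer

flesch = 62.0               # Flesch score measured on the candidate summary
bs_fact = 0.90              # idf-weighted BERTScore precision against the source document

flesch_reward = 1 - abs(flesch - flesch_target) / 100                               # 0.88
combined = readability_weight * flesch_reward + (1 - readability_weight) * bs_fact  # 0.884
print(flesch_reward, combined)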
/src/preprocess/generate_prompts_category.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 |
4 | def open_txt_file(file):
5 | entities = []
6 |
7 | for line in open(file).readlines():
8 | entities.append(line)
9 |
10 | return entities
11 |
12 |
13 | def open_file(file):
14 | entities = []
15 |
16 | for line in open(file).readlines():
17 | entities.append(json.loads(line))
18 |
19 | return entities
20 |
21 |
22 | def save_file(data, file):
23 |     # Use a context manager so the file is flushed and closed
24 |     with open(file, 'w') as file_writer:
25 |         for line in data:
26 |             file_writer.write(json.dumps(line) + "\n")
27 |
28 |
29 | def get_prompt(flesch_summary):
30 | if flesch_summary >= 80:
31 | prompt = 'Write highlights for this article for a 11 years old student:\n\n'
32 | elif 80 > flesch_summary >= 60:
33 | prompt = 'Write highlights for this article for a middle school student:\n\n'
34 | elif 60 > flesch_summary >= 40:
35 | prompt = 'Write highlights for this article for a high school student:\n\n'
36 | else:
37 | prompt = 'Write highlights for this article for a college student:\n\n'
38 | return prompt
39 |
40 |
41 | def transform_data(split):
42 | data = open_file('../data/' + split + '.json')
43 | new_data = []
44 |
45 | for entry in tqdm(data):
46 |
47 | flesch_summary = entry["summary_metrics"]["flesch"]
48 |
49 | prompt = get_prompt(flesch_summary)
50 | entry["prompt"] = prompt
51 | entry["input_noprompt"] = entry["input"]
52 | entry["input"] = prompt + entry["input"]
53 | new_data.append(entry)
54 |
55 | save_file(new_data, '../data/' + split + '_prompt_category.json')
56 |
57 |
58 | transform_data('train')
59 | transform_data('validation')
60 | transform_data('test')
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
--------------------------------------------------------------------------------
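As a quick illustration (the scores below are hypothetical, not taken from the data), get_prompt() defined above buckets the reference summary's Flesch reading-ease score into one of four audience prompts:

# Hypothetical Flesch scores, only to show the bucketing in get_prompt()
for fs in (85, 70, 45, 30):
    print(fs, "->", get_prompt(fs).strip())
# 85 -> Write highlights for this article for a 11 years old student:
# 70 -> Write highlights for this article for a middle school student:
# 45 -> Write highlights for this article for a high school student:
# 30 -> Write highlights for this article for a college student: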
/src/preprocess/generate_prompts_score.py:
--------------------------------------------------------------------------------
1 | import json
2 | from tqdm import tqdm
3 |
4 | def open_txt_file(file):
5 | entities = []
6 |
7 | for line in open(file).readlines():
8 | entities.append(line)
9 |
10 | return entities
11 |
12 |
13 | def open_file(file):
14 | entities = []
15 |
16 | for line in open(file).readlines():
17 | entities.append(json.loads(line))
18 |
19 | return entities
20 |
21 |
22 | def save_file(data, file):
23 |     # Use a context manager so the file is flushed and closed
24 |     with open(file, 'w') as file_writer:
25 |         for line in data:
26 |             file_writer.write(json.dumps(line) + "\n")
27 |
28 |
29 | def get_prompt(flesch_summary):
30 | prompt = 'Write highlights for this article with a flesch kincaid score of ' + str(
31 | int(round(flesch_summary, 0))) + ":\n\n"
32 | return prompt
33 |
34 |
35 | def transform_data(split):
36 | data = open_file('../data/' + split + '.json')
37 | new_data = []
38 |
39 | for entry in tqdm(data):
40 |
41 | flesch_summary = entry["summary_metrics"]["flesch"]
42 | flesch_input = entry["input_metrics"]["flesch"]
43 |
44 | prompt = get_prompt(flesch_summary)
45 | entry["prompt"] = prompt
46 | entry["input_noprompt"] = entry["input"]
47 | entry["input"] = prompt + entry["input"]
48 |
49 | if split == 'test' and flesch_input >= 50:
50 | continue
51 | new_data.append(entry)
52 |
53 |
54 | save_file(new_data, '../data/' + split + '_prompt_score.json')
55 |
56 |
57 | transform_data('train')
58 | transform_data('validation')
59 | transform_data('test')
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
--------------------------------------------------------------------------------
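Similarly, a hypothetical example of the score-conditioned prompt built by get_prompt() in this script; note that transform_data() also drops test articles whose own Flesch score is 50 or higher.

# Hypothetical reference-summary Flesch score, only to show the prompt format
print(get_prompt(52.3).strip())
# Write highlights for this article with a flesch kincaid score of 52: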
/src/preprocess/preprocess_cnndm.py:
--------------------------------------------------------------------------------
1 | import json
2 | from readability import Readability
3 | from datasets import load_dataset
4 |
5 | # Load the CNNDM data (e.g., the Hugging Face release at https://huggingface.co/datasets/cnn_dailymail)
6 | dataset = load_dataset('cnn_dailymail', '3.0.0')
7 |
8 | def open_txt_file(file):
9 | entities = []
10 |
11 | for line in open(file).readlines():
12 | entities.append(line)
13 |
14 | return entities
15 |
16 |
17 | def open_file(file):
18 | entities = []
19 |
20 | for line in open(file).readlines():
21 | entities.append(json.loads(line))
22 |
23 | return entities
24 |
25 | def save_file(data, file):
26 |     # Use a context manager so the file is flushed and closed
27 |     with open(file, 'w') as file_writer:
28 |         for line in data:
29 |             file_writer.write(json.dumps(line) + "\n")
30 |
31 | def get_flesch_kincaid(text):
32 | r = Readability(text)
33 | fk = r.flesch_kincaid()
34 | return fk.score
35 |
36 |
37 | def get_flesch(text):
38 | r = Readability(text)
39 | f = r.flesch()
40 | return f.score
41 |
42 |
43 | def get_dale_chall(text):
44 | r = Readability(text)
45 | dc = r.dale_chall()
46 | return dc.score
47 |
48 |
49 | def get_ari(text):
50 | r = Readability(text)
51 | ari = r.ari()
52 | return ari.score
53 |
54 |
55 | def get_coleman_liau(text):
56 | r = Readability(text)
57 | cl = r.coleman_liau()
58 | return cl.score
59 |
60 |
61 | def get_gunning_fog(text):
62 | r = Readability(text)
63 | gf = r.gunning_fog()
64 | return gf.score
65 |
66 |
67 | def get_smog(text):
68 | r = Readability(text)
69 | s = r.smog()
70 | return s.score
71 |
72 |
73 | def get_spache(text):
74 | r = Readability(text)
75 | s = r.spache()
76 | return s.score
77 |
78 | def get_linsear_write(text):
79 | r = Readability(text)
80 | lw = r.linsear_write()
81 | return lw.score
82 |
83 |
84 | def compute_metrics(text):
85 | metrics = {}
86 | flesch = get_flesch(text)
87 | metrics['flesch'] = round(flesch, 4)
88 |
89 | dale_chall = get_dale_chall(text)
90 | metrics['dale_chall'] = round(dale_chall, 4)
91 |
92 | coleman_liau = get_coleman_liau(text)
93 | metrics['coleman_liau'] = round(coleman_liau, 4)
94 |
95 | gunning_fog = get_gunning_fog(text)
96 | metrics['gunning_fog'] = round(gunning_fog, 4)
97 |
98 | return metrics
99 |
100 |
101 | def process_data(split):
102 | data = []
103 |     for idx, (article, summary, id_) in enumerate(zip(dataset[split]['article'], dataset[split]['highlights'], dataset[split]['id'])):
104 |         entry = {}
105 |         entry['input'] = article
106 |         metrics = compute_metrics(entry["input"])
107 |         entry['input_metrics'] = metrics
108 | 
109 |         entry['summary'] = summary
110 | entry['id'] = str(id_)
111 | metrics = compute_metrics(entry["summary"].replace("\n", " "))
112 | entry['summary_metrics'] = metrics
113 | data.append(entry)
114 |
115 | save_file(data, 'data/' + split + '.json')
116 |
117 |
118 | process_data('train')
119 | process_data('validation')
120 | process_data('test')
121 |
122 |
123 |
124 |
--------------------------------------------------------------------------------
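For reference, each line written by save_file() above is a JSON record with the following shape (the values here are placeholders, not real data). The prompt-generation scripts read input, input_metrics.flesch, and summary_metrics.flesch from these records.

# Placeholder record illustrating the JSONL schema produced by process_data()
example_entry = {
    "input": "<full article text>",
    "input_metrics": {"flesch": 54.3, "dale_chall": 8.1, "coleman_liau": 11.2, "gunning_fog": 12.7},
    "summary": "<reference highlights>",
    "id": "<cnn_dailymail example id>",
    "summary_metrics": {"flesch": 62.8, "dale_chall": 7.4, "coleman_liau": 9.9, "gunning_fog": 10.8},
}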
/src/train/ds_config_stage3_fb16.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": true
4 | },
5 |
6 | "zero_optimization": {
7 | "stage": 2,
8 | "allgather_partitions": true,
9 | "allgather_bucket_size": 2e8,
10 | "overlap_comm": true,
11 | "reduce_scatter": true,
12 | "reduce_bucket_size": 2e8,
13 | "contiguous_gradients": true
14 | },
15 | "train_batch_size": "auto",
16 | "train_micro_batch_size_per_gpu": "auto",
17 | "zero_allow_untested_optimizer": true,
18 |
19 | "optimizer": {
20 | "type": "AdamW",
21 | "params": {
22 | "lr": 1e-4,
23 | "betas": [
24 | 0.9,
25 | 0.999
26 | ],
27 | "eps": 1e-8,
28 | "weight_decay": 0.0
29 | }
30 | },
31 |
32 | "scheduler": {
33 | "type": "WarmupDecayLR",
34 | "params": {
35 | "total_num_steps": "auto",
36 | "warmup_min_lr": "auto",
37 | "warmup_max_lr": "auto",
38 | "warmup_num_steps": "auto"
39 | }
40 | },
41 |
42 | "steps_per_print": 30,
43 | "wall_clock_breakdown": false
44 | }
--------------------------------------------------------------------------------
/src/train/rl/accelerate_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | distributed_type: DEEPSPEED
3 | downcast_bf16: 'no'
4 | gpu_ids: 0,1,2,3,4,5,6,7
5 | machine_rank: 0
6 | main_training_function: main
7 | mixed_precision: bf16
8 | num_machines: 1
9 | num_processes: 8
10 | rdzv_backend: static
11 | same_network: true
12 | main_process_port: 61001
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 | deepspeed_config:
18 | gradient_accumulation_steps: 1
19 | gradient_clipping: 1.0
20 | offload_optimizer_device: none
21 | offload_param_device: none
22 | zero3_init_flag: true
23 | zero_stage: 2
24 |
25 |
26 |
--------------------------------------------------------------------------------
/src/train/rl/train.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | from datasets import load_dataset
3 | from tqdm import tqdm
4 | from transformers import AutoTokenizer
5 | from readability import Readability
6 | import numpy as np
7 | import sys
8 | eps = sys.float_info.epsilon
9 | import math
10 |
11 | import trlx
12 | from trlx.data.configs import (
13 | ModelConfig,
14 | OptimizerConfig,
15 | SchedulerConfig,
16 | TokenizerConfig,
17 | TrainConfig,
18 | TRLConfig,
19 | )
20 | from trlx.models.modeling_ppo import PPOConfig
21 |
22 | model_dir = 'checkpoints/exec-XXXX' # select the checkpoint from the prompt-based methods
23 |
24 | config = TRLConfig(
25 | train=TrainConfig(
26 | seq_length=1024,
27 | epochs=500,
28 | total_steps=100000,
29 | batch_size=2,
30 | batch_size_eval=2,
31 | checkpoint_interval=10000,
32 | eval_interval=500,
33 | save_optimizer=False,
34 | pipeline="PromptPipeline",
35 | trainer="AcceleratePPOTrainer",
36 | checkpoint_dir='checkpoint-diverse',
37 | save_best=True
38 | ),
39 | model=ModelConfig(
40 | model_path=model_dir,
41 | model_arch_type="seq2seq",
42 | num_layers_unfrozen=-1,
43 | ),
44 | tokenizer=TokenizerConfig(
45 | tokenizer_path=model_dir,
46 | truncation_side="right",
47 | ),
48 | optimizer=OptimizerConfig(
49 | name="adamw",
50 | kwargs={
51 | "lr": 1.0e-5,
52 | "betas": [0.9, 0.999],
53 | "eps": 1.0e-8,
54 | "weight_decay": 1.0e-6,
55 | },
56 | ),
57 | scheduler=SchedulerConfig(
58 | name="cosine_annealing",
59 | kwargs={
60 | "T_max": 10000,
61 | "eta_min": 1.0e-6,
62 | },
63 | ),
64 | method=PPOConfig(
65 | name="PPOConfig",
66 | num_rollouts=512,
67 | chunk_size=4,
68 | ppo_epochs=4,
69 | init_kl_coef=0.05,
70 | target=6,
71 | horizon=10000,
72 | gamma=0.99,
73 | lam=0.95,
74 | cliprange=0.2,
75 | cliprange_value=0.2,
76 | vf_coef=1.0,
77 | scale_reward=None,
78 | ref_mean=None,
79 | ref_std=None,
80 | cliprange_reward=10,
81 | gen_kwargs={
82 | "max_new_tokens": 256,
83 | },
84 | gen_experience_kwargs={
85 | "max_new_tokens": 256,
86 | "do_sample": True,
87 | "temperature": 1.0,
88 | "top_k": 50,
89 | "top_p": 0.95,
90 | },
91 | ),
92 | )
93 |
94 |
95 | def get_flesch_kincaid(text):
96 | r = Readability(text)
97 | fk = r.flesch_kincaid()
98 | return fk.score
99 |
100 |
101 | def get_flesch(text):
102 | r = Readability(text)
103 | f = r.flesch()
104 | return f.score
105 |
106 | import random
107 |
108 | def change_scores(input_data):
109 | new_data = []
110 | for text in input_data:
111 | score_sum = random.choice([10, 15, 25, 30, 33, 35, 37, 40, 45, 48, 50, 52, 60, 64, 68, 70, 71, 75, 83, 84, 88, 89, 90, 92, 93, 94, 95])
112 | new_text = "Write highlights for this article with a flesch kincaid score of " + str(score_sum) + ":\n\n" + text
113 | new_data.append(new_text)
114 | return new_data
115 |
116 | sigma = 10
117 | def calc_nd(value, mean):
118 |     # Normalized Gaussian reward centered at `mean` with std `sigma`; dividing by the peak density 1 / (10 * sqrt(2 * pi)) rescales its maximum to 1
119 |     return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(- (value - mean) ** 2 / (2 * sigma ** 2)) / 0.039894228040143274
120 |
121 | import torch
122 | import torch.nn as nn
123 | import torch.nn.functional as F
124 | from transformers import AutoModel, AutoTokenizer
125 | import os
126 | model_name = "roberta-large"
127 | device = "cuda:" + str(os.environ.get('LOCAL_RANK',0))
128 | num_layers = 17
129 | cache_dir=".cache"
130 | model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
131 | model = model.to(device)
132 | tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
133 | model.encoder.layer = torch.nn.ModuleList([layer for layer in model.encoder.layer[:num_layers]])
134 |
135 | def encode_text(input_str):
136 | """
137 |     Helper function that tokenizes the input string(s) and returns normalized token embeddings, attention masks, and idf weights
138 | """
139 | inputs = tokenizer(input_str, padding=True, truncation=True, return_tensors="pt")
140 | inputs = {k: v.to(device) for k, v in inputs.items()}
141 | with torch.no_grad():
142 | outputs = model(**inputs)
143 |
144 | # idf
145 |     idf = torch.clone(inputs["attention_mask"]).float()
146 |     idf[inputs["input_ids"] == tokenizer.sep_token_id] = 0  # zero out the idf weight at special-token positions
147 |     idf[inputs["input_ids"] == tokenizer.cls_token_id] = 0
148 | idf.div_(idf.sum(dim=1, keepdim=True))
149 |
150 | return F.normalize(outputs[0], dim=-1), inputs["attention_mask"], idf
151 |
152 | def compute_bertscore(doc_embedding, doc_masks, doc_idf, summ_embedding, summ_masks, summ_idf):
153 | """
154 | Helper function that is modified from the official code (greedy_cos_idf() method) https://github.com/Tiiiger/bert_score/blob/dbcf6db37e8bd6ff68446f06b0ba5d0763b62d20/bert_score/utils.py#L469
155 | """
156 |
157 | batch_size = doc_embedding.size(0)
158 | sim = torch.bmm(summ_embedding, doc_embedding.transpose(1, 2))
159 | masks = torch.bmm(summ_masks.unsqueeze(2).float(), doc_masks.unsqueeze(1).float())
160 | masks = masks.expand(batch_size, -1, -1).contiguous().view_as(sim)
161 |
162 | masks = masks.float().to(sim.device)
163 | sim = sim * masks
164 |
165 | precision = sim.max(dim=2)[0]
166 | precision_scale = summ_idf.to(precision.device)
167 | P = (precision * precision_scale).sum(dim=1)
168 |
169 | summ_zero_mask = summ_masks.sum(dim=1).eq(2)
170 | if torch.any(summ_zero_mask):
171 | P = P.masked_fill(summ_zero_mask, 0.0)
172 |
173 | doc_zero_mask = doc_masks.sum(dim=1).eq(2)
174 | if torch.any(doc_zero_mask):
175 | P = P.masked_fill(doc_zero_mask, 0.0)
176 |
177 | return P
178 |
179 |
180 | if __name__ == "__main__":
181 |
182 | def reward_fn(samples: List[str], prompts: List[str], outputs: List[str]):
183 |
184 | flesch_scores = []
185 | original_scores = []
186 | summaries = []
187 | docs = []
188 | for (generated_summary, input_doc) in zip(outputs, prompts):
189 |             score_sum = int(input_doc.split("Write highlights for this article with a flesch kincaid score of ")[1][:2].replace(":", ""))  # the requested scores from change_scores() are all two digits
190 | original_scores.append(score_sum)
191 | doc = input_doc.split("Write highlights for this article with a flesch kincaid score of ")[1][2:]
192 | docs.append(doc)
193 | summaries.append(generated_summary.strip())
194 |
195 | try:
196 | flesch_scores.append(get_flesch(generated_summary.strip()))
197 |             except Exception:  # Readability can fail on very short texts; give zero reward
198 | flesch_scores.append(0)
199 |
200 | all_bertscore_scores = []
201 | for doc, summary in zip(docs, summaries):
202 |
203 | bertscore_input_embedding, bertscore_input_attention_mask, bertscore_input_idf = encode_text([doc])
204 | bertscore_output_embedding, bertscore_output_attention_mask, bertscore_output_idf = encode_text([summary])
205 |
206 | bertscore_scores = compute_bertscore(
207 | bertscore_input_embedding,
208 | bertscore_input_attention_mask,
209 | bertscore_input_idf,
210 | bertscore_output_embedding,
211 | bertscore_output_attention_mask,
212 | bertscore_output_idf,
213 | )
214 | bertscore_scores = bertscore_scores.tolist()
215 | all_bertscore_scores.extend(bertscore_scores)
216 |
217 | assert len(original_scores) == len(flesch_scores) == len(all_bertscore_scores)
218 |
219 | flesch_scores = [calc_nd(fs, o_fs) for fs, o_fs in zip(flesch_scores, original_scores)]
220 |
221 | readability_weight = 1
222 | flesch_scores = torch.tensor(flesch_scores)
223 | all_bertscore_scores = torch.tensor(all_bertscore_scores)
224 | flesch_scores = readability_weight * flesch_scores + (1 - readability_weight) * all_bertscore_scores
225 | flesch_scores = flesch_scores.tolist()
226 |
227 | return flesch_scores
228 |
229 |
230 | train_file = '../../data/train_prompt_score.json'
231 | validation_file = '../../data/train_prompt_score.json'
232 | data_files = {"train": train_file, "validation": validation_file}
233 | dataset = load_dataset("json", data_files=data_files)
234 | dataset['train'] = dataset['train'].shuffle(seed=42)
235 | dataset['validation'] = dataset['validation'].shuffle(seed=42)
236 |
237 | validation_examples = 2000
238 | val_prompts = [prompt for prompt in dataset['validation']["input_noprompt"][0:validation_examples]]
239 | print('\ntest 0\n', val_prompts[0])
240 | val_summaries = dataset['validation']["summary"][0:validation_examples]
241 | val_prompts = change_scores(val_prompts)
242 | assert len(val_prompts) == len(val_summaries)
243 | print('\ntest after 0 \n', val_prompts[0])
244 |
245 | prompts = dataset['train']["input_noprompt"]
246 | summaries = dataset['train']["summary"]
247 | prompts = [prompt for prompt in prompts]
248 | prompts = change_scores(prompts)
249 | assert len(prompts) == len(summaries)
250 |
251 | # make dictionary of prompts and labels to use for reward function
252 | tokenizer = AutoTokenizer.from_pretrained(config.model.model_path)
253 | tokenizer.padding_side = "left"
254 | tokenizer.truncation_side = "right"
255 | tokenizer.sep_token = ""
256 | prompt_label = {}
257 | max_length = config.train.seq_length
258 |
259 | trlx.train(
260 | reward_fn=reward_fn,
261 | prompts=prompts,
262 | eval_prompts=val_prompts,
263 | config=config,
264 | )
265 |
--------------------------------------------------------------------------------
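As a sanity check on the reward shaping above (illustrative values only): calc_nd() returns 1.0 when the generated summary hits the requested Flesch score exactly and decays as a Gaussian with sigma = 10. Since reward_fn sets readability_weight = 1, the BERTScore term is zeroed out and the PPO reward reduces to this readability term alone.

import numpy as np

sigma = 10

def calc_nd(value, mean):
    # Same normalized Gaussian as in train.py above
    return 1 / (sigma * np.sqrt(2 * np.pi)) * np.exp(- (value - mean) ** 2 / (2 * sigma ** 2)) / 0.039894228040143274

# Hypothetical requested score of 50 and a few measured Flesch scores of generated summaries
for measured in (50, 45, 30):
    print(measured, round(calc_nd(measured, 50), 3))
# 50 1.0
# 45 0.882
# 30 0.135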
/src/train/rl/train_rl_cnndm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 |
5 | conda activate readability_summ
6 | export TOKENIZERS_PARALLELISM=true
7 |
8 | accelerate launch --config_file accelerate_config.yaml train.py
9 |
10 | conda deactivate
--------------------------------------------------------------------------------
/src/train/run_summarization.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 |
21 | import logging
22 | import os
23 | import sys
24 | from dataclasses import dataclass, field
25 | from typing import Optional
26 |
27 | import datasets
28 | import nltk # Here to have a nice missing dependency error message early on
29 | import numpy as np
30 | from datasets import load_dataset
31 |
32 | import evaluate
33 | import transformers
34 | from filelock import FileLock
35 | from transformers import (
36 | AutoConfig,
37 | AutoModelForSeq2SeqLM,
38 | AutoTokenizer,
39 | DataCollatorForSeq2Seq,
40 | HfArgumentParser,
41 | MBart50Tokenizer,
42 | MBart50TokenizerFast,
43 | MBartTokenizer,
44 | MBartTokenizerFast,
45 | Seq2SeqTrainer,
46 | Seq2SeqTrainingArguments,
47 | set_seed,
48 | )
49 | from transformers.trainer_utils import get_last_checkpoint
50 | from transformers.utils import check_min_version, is_offline_mode, send_example_telemetry
51 | from transformers.utils.versions import require_version
52 |
53 | os.environ["NCCL_DEBUG"] = "INFO"
54 |
55 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
56 | #check_min_version("4.25.0.dev0")
57 |
58 | require_version("datasets>=1.8.0", "To fix: pip install -r examples/pytorch/summarization/requirements.txt")
59 |
60 | logger = logging.getLogger(__name__)
61 |
62 | try:
63 | nltk.data.find("tokenizers/punkt")
64 | except (LookupError, OSError):
65 | if is_offline_mode():
66 | raise LookupError(
67 | "Offline mode: run this script without TRANSFORMERS_OFFLINE first to download nltk data files"
68 | )
69 | with FileLock(".lock") as lock:
70 | nltk.download("punkt", quiet=True)
71 |
72 | # A list of all multilingual tokenizer which require lang attribute.
73 | MULTILINGUAL_TOKENIZERS = [MBartTokenizer, MBartTokenizerFast, MBart50Tokenizer, MBart50TokenizerFast]
74 |
75 |
76 | @dataclass
77 | class ModelArguments:
78 | """
79 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
80 | """
81 |
82 | model_name_or_path: str = field(
83 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
84 | )
85 | config_name: Optional[str] = field(
86 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
87 | )
88 | tokenizer_name: Optional[str] = field(
89 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
90 | )
91 | cache_dir: Optional[str] = field(
92 | default=None,
93 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
94 | )
95 | use_fast_tokenizer: bool = field(
96 | default=True,
97 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
98 | )
99 | model_revision: str = field(
100 | default="main",
101 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
102 | )
103 | use_auth_token: bool = field(
104 | default=False,
105 | metadata={
106 | "help": (
107 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
108 | "with private models)."
109 | )
110 | },
111 | )
112 | resize_position_embeddings: Optional[bool] = field(
113 | default=None,
114 | metadata={
115 | "help": (
116 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
117 | "the model's position embeddings."
118 | )
119 | },
120 | )
121 |
122 |
123 | @dataclass
124 | class DataTrainingArguments:
125 | """
126 | Arguments pertaining to what data we are going to input our model for training and eval.
127 | """
128 |
129 | lang: Optional[str] = field(default=None, metadata={"help": "Language id for summarization."})
130 |
131 | dataset_name: Optional[str] = field(
132 | default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
133 | )
134 | dataset_config_name: Optional[str] = field(
135 | default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
136 | )
137 | text_column: Optional[str] = field(
138 | default=None,
139 | metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
140 | )
141 | summary_column: Optional[str] = field(
142 | default=None,
143 | metadata={"help": "The name of the column in the datasets containing the summaries (for summarization)."},
144 | )
145 | train_file: Optional[str] = field(
146 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
147 | )
148 | validation_file: Optional[str] = field(
149 | default=None,
150 | metadata={
151 | "help": (
152 | "An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
153 | )
154 | },
155 | )
156 | test_file: Optional[str] = field(
157 | default=None,
158 | metadata={
159 | "help": "An optional input test data file to evaluate the metrics (rouge) on (a jsonlines or csv file)."
160 | },
161 | )
162 | overwrite_cache: bool = field(
163 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
164 | )
165 | preprocessing_num_workers: Optional[int] = field(
166 | default=None,
167 | metadata={"help": "The number of processes to use for the preprocessing."},
168 | )
169 | max_source_length: Optional[int] = field(
170 | default=1024,
171 | metadata={
172 | "help": (
173 | "The maximum total input sequence length after tokenization. Sequences longer "
174 | "than this will be truncated, sequences shorter will be padded."
175 | )
176 | },
177 | )
178 | max_target_length: Optional[int] = field(
179 | default=128,
180 | metadata={
181 | "help": (
182 | "The maximum total sequence length for target text after tokenization. Sequences longer "
183 | "than this will be truncated, sequences shorter will be padded."
184 | )
185 | },
186 | )
187 | val_max_target_length: Optional[int] = field(
188 | default=None,
189 | metadata={
190 | "help": (
191 | "The maximum total sequence length for validation target text after tokenization. Sequences longer "
192 | "than this will be truncated, sequences shorter will be padded. Will default to `max_target_length`."
193 | "This argument is also used to override the ``max_length`` param of ``model.generate``, which is used "
194 | "during ``evaluate`` and ``predict``."
195 | )
196 | },
197 | )
198 | pad_to_max_length: bool = field(
199 | default=False,
200 | metadata={
201 | "help": (
202 | "Whether to pad all samples to model maximum sentence length. "
203 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
204 | "efficient on GPU but very bad for TPU."
205 | )
206 | },
207 | )
208 | max_train_samples: Optional[int] = field(
209 | default=None,
210 | metadata={
211 | "help": (
212 | "For debugging purposes or quicker training, truncate the number of training examples to this "
213 | "value if set."
214 | )
215 | },
216 | )
217 | max_eval_samples: Optional[int] = field(
218 | default=None,
219 | metadata={
220 | "help": (
221 | "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
222 | "value if set."
223 | )
224 | },
225 | )
226 | max_predict_samples: Optional[int] = field(
227 | default=None,
228 | metadata={
229 | "help": (
230 | "For debugging purposes or quicker training, truncate the number of prediction examples to this "
231 | "value if set."
232 | )
233 | },
234 | )
235 | num_beams: Optional[int] = field(
236 | default=None,
237 | metadata={
238 | "help": (
239 | "Number of beams to use for evaluation. This argument will be passed to ``model.generate``, "
240 | "which is used during ``evaluate`` and ``predict``."
241 | )
242 | },
243 | )
244 | ignore_pad_token_for_loss: bool = field(
245 | default=True,
246 | metadata={
247 | "help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."
248 | },
249 | )
250 | source_prefix: Optional[str] = field(
251 | default="", metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
252 | )
253 |
254 | forced_bos_token: Optional[str] = field(
255 | default=None,
256 | metadata={
257 | "help": (
258 | "The token to force as the first generated token after the decoder_start_token_id."
259 | "Useful for multilingual models like mBART where the first generated token"
260 | "needs to be the target language token (Usually it is the target language token)"
261 | )
262 | },
263 | )
264 |
265 | def __post_init__(self):
266 | if self.dataset_name is None and self.train_file is None and self.validation_file is None:
267 | raise ValueError("Need either a dataset name or a training/validation file.")
268 | else:
269 | if self.train_file is not None:
270 | extension = self.train_file.split(".")[-1]
271 | assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
272 | if self.validation_file is not None:
273 | extension = self.validation_file.split(".")[-1]
274 | assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."
275 | if self.val_max_target_length is None:
276 | self.val_max_target_length = self.max_target_length
277 |
278 |
279 | summarization_name_mapping = {
280 | "amazon_reviews_multi": ("review_body", "review_title"),
281 | "big_patent": ("description", "abstract"),
282 | "cnn_dailymail": ("article", "highlights"),
283 | "orange_sum": ("text", "summary"),
284 | "pn_summary": ("article", "summary"),
285 | "psc": ("extract_text", "summary_text"),
286 | "samsum": ("dialogue", "summary"),
287 | "thaisum": ("body", "summary"),
288 | "xglue": ("news_body", "news_title"),
289 | "xsum": ("document", "summary"),
290 | "wiki_summary": ("article", "highlights"),
291 | "multi_news": ("document", "summary"),
292 | }
293 |
294 |
295 | def main():
296 | # See all possible arguments in src/transformers/training_args.py
297 | # or by passing the --help flag to this script.
298 | # We now keep distinct sets of args, for a cleaner separation of concerns.
299 |
300 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
301 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
302 | # If we pass only one argument to the script and it's the path to a json file,
303 | # let's parse it to get our arguments.
304 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
305 | else:
306 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
307 |
308 | # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
309 | # information sent is the one passed as arguments along with your Python/PyTorch versions.
310 | send_example_telemetry("run_summarization", model_args, data_args)
311 | print("training_args", training_args)
312 | # Setup logging
313 | logging.basicConfig(
314 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
315 | datefmt="%m/%d/%Y %H:%M:%S",
316 | handlers=[logging.StreamHandler(sys.stdout)],
317 | )
318 | log_level = training_args.get_process_log_level()
319 | logger.setLevel(log_level)
320 | datasets.utils.logging.set_verbosity(log_level)
321 | transformers.utils.logging.set_verbosity(log_level)
322 | transformers.utils.logging.enable_default_handler()
323 | transformers.utils.logging.enable_explicit_format()
324 |
325 | # Log on each process the small summary:
326 | logger.warning(
327 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
328 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
329 | )
330 | logger.info(f"Training/evaluation parameters {training_args}")
331 |
332 | if data_args.source_prefix is None and model_args.model_name_or_path in [
333 | "t5-small",
334 | "t5-base",
335 | "t5-large",
336 | "t5-3b",
337 | "t5-11b",
338 | ]:
339 | logger.warning(
340 |             "You're running a t5 model but didn't provide a source prefix, which is expected, e.g. with "
341 | "`--source_prefix 'summarize: ' `"
342 | )
343 |
344 | # Detecting last checkpoint.
345 | last_checkpoint = None
346 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
347 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
348 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
349 | raise ValueError(
350 | f"Output directory ({training_args.output_dir}) already exists and is not empty. "
351 | "Use --overwrite_output_dir to overcome."
352 | )
353 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
354 | logger.info(
355 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
356 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
357 | )
358 |
359 | # Set seed before initializing model.
360 | set_seed(training_args.seed)
361 |
362 | # Get the datasets: you can either provide your own CSV/JSON training and evaluation files (see below)
363 | # or just provide the name of one of the public datasets available on the hub at https://huggingface.co/datasets/
364 | # (the dataset will be downloaded automatically from the datasets Hub).
365 | #
366 | # For CSV/JSON files this script will use the first column for the full texts and the second column for the
367 | # summaries (unless you specify column names for this with the `text_column` and `summary_column` arguments).
368 | #
369 | # In distributed training, the load_dataset function guarantee that only one local process can concurrently
370 | # download the dataset.
371 | print("data_args", data_args)
372 | if data_args.dataset_name is not None:
373 | # Downloading and loading a dataset from the hub.
374 | raw_datasets = load_dataset(
375 | data_args.dataset_name,
376 | data_args.dataset_config_name,
377 | cache_dir=model_args.cache_dir,
378 | use_auth_token=True if model_args.use_auth_token else None,
379 | )
380 | else:
381 | data_files = {}
382 | if data_args.train_file is not None:
383 | data_files["train"] = data_args.train_file
384 | extension = data_args.train_file.split(".")[-1]
385 | if data_args.validation_file is not None:
386 | data_files["validation"] = data_args.validation_file
387 | extension = data_args.validation_file.split(".")[-1]
388 | if data_args.test_file is not None:
389 | data_files["test"] = data_args.test_file
390 | extension = data_args.test_file.split(".")[-1]
391 | raw_datasets = load_dataset(
392 | extension,
393 | data_files=data_files,
394 | cache_dir=model_args.cache_dir,
395 | use_auth_token=True if model_args.use_auth_token else None,
396 | )
397 | # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
398 | # https://huggingface.co/docs/datasets/loading_datasets.html.
399 |
400 | # Load pretrained model and tokenizer
401 | #
402 | # Distributed training:
403 | # The .from_pretrained methods guarantee that only one local process can concurrently
404 | # download model & vocab.
405 | print("model_args", model_args)
406 | config = AutoConfig.from_pretrained(
407 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
408 | cache_dir=model_args.cache_dir,
409 | revision=model_args.model_revision,
410 | use_auth_token=True if model_args.use_auth_token else None,
411 | )
412 | tokenizer = AutoTokenizer.from_pretrained(
413 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
414 | cache_dir=model_args.cache_dir,
415 | use_fast=model_args.use_fast_tokenizer,
416 | revision=model_args.model_revision,
417 | use_auth_token=True if model_args.use_auth_token else None,
418 | )
419 | model = AutoModelForSeq2SeqLM.from_pretrained(
420 | model_args.model_name_or_path,
421 | from_tf=bool(".ckpt" in model_args.model_name_or_path),
422 | config=config,
423 | cache_dir=model_args.cache_dir,
424 | revision=model_args.model_revision,
425 | use_auth_token=True if model_args.use_auth_token else None,
426 | )
427 |
428 | # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
429 | # on a small vocab and want a smaller embedding size, remove this test.
430 | embedding_size = model.get_input_embeddings().weight.shape[0]
431 | if len(tokenizer) > embedding_size:
432 | model.resize_token_embeddings(len(tokenizer))
433 |
434 | if model.config.decoder_start_token_id is None and isinstance(tokenizer, (MBartTokenizer, MBartTokenizerFast)):
435 | if isinstance(tokenizer, MBartTokenizer):
436 | model.config.decoder_start_token_id = tokenizer.lang_code_to_id[data_args.lang]
437 | else:
438 | model.config.decoder_start_token_id = tokenizer.convert_tokens_to_ids(data_args.lang)
439 |
440 | if model.config.decoder_start_token_id is None:
441 | raise ValueError("Make sure that `config.decoder_start_token_id` is correctly defined")
442 |
443 | if (
444 | hasattr(model.config, "max_position_embeddings")
445 | and model.config.max_position_embeddings < data_args.max_source_length
446 | ):
447 | if model_args.resize_position_embeddings is None:
448 | logger.warning(
449 | "Increasing the model's number of position embedding vectors from"
450 | f" {model.config.max_position_embeddings} to {data_args.max_source_length}."
451 | )
452 | model.resize_position_embeddings(data_args.max_source_length)
453 | elif model_args.resize_position_embeddings:
454 | model.resize_position_embeddings(data_args.max_source_length)
455 | else:
456 | raise ValueError(
457 | f"`--max_source_length` is set to {data_args.max_source_length}, but the model only has"
458 | f" {model.config.max_position_embeddings} position encodings. Consider either reducing"
459 | f" `--max_source_length` to {model.config.max_position_embeddings} or to automatically resize the"
460 | " model's position encodings by passing `--resize_position_embeddings`."
461 | )
462 |
463 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
464 |
465 | # Preprocessing the datasets.
466 | # We need to tokenize inputs and targets.
467 | if training_args.do_train:
468 | column_names = raw_datasets["train"].column_names
469 | elif training_args.do_eval:
470 | column_names = raw_datasets["validation"].column_names
471 | elif training_args.do_predict:
472 | column_names = raw_datasets["test"].column_names
473 | else:
474 | logger.info("There is nothing to do. Please pass `do_train`, `do_eval` and/or `do_predict`.")
475 | return
476 |
477 | if isinstance(tokenizer, tuple(MULTILINGUAL_TOKENIZERS)):
478 | assert (
479 | data_args.lang is not None
480 | ), f"{tokenizer.__class__.__name__} is a multilingual tokenizer which requires --lang argument"
481 |
482 | tokenizer.src_lang = data_args.lang
483 | tokenizer.tgt_lang = data_args.lang
484 |
485 | # For multilingual translation models like mBART-50 and M2M100 we need to force the target language token
486 | # as the first generated token. We ask the user to explicitly provide this as --forced_bos_token argument.
487 | forced_bos_token_id = (
488 | tokenizer.lang_code_to_id[data_args.forced_bos_token] if data_args.forced_bos_token is not None else None
489 | )
490 | model.config.forced_bos_token_id = forced_bos_token_id
491 |
492 | # Get the column names for input/target.
493 | dataset_columns = summarization_name_mapping.get(data_args.dataset_name, None)
494 | if data_args.text_column is None:
495 | text_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
496 | else:
497 | text_column = data_args.text_column
498 | if text_column not in column_names:
499 | raise ValueError(
500 | f"--text_column' value '{data_args.text_column}' needs to be one of: {', '.join(column_names)}"
501 | )
502 | if data_args.summary_column is None:
503 | summary_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
504 | else:
505 | summary_column = data_args.summary_column
506 | if summary_column not in column_names:
507 | raise ValueError(
508 | f"--summary_column' value '{data_args.summary_column}' needs to be one of: {', '.join(column_names)}"
509 | )
510 |
511 | # Temporarily set max_target_length for training.
512 | max_target_length = data_args.max_target_length
513 | padding = "max_length" if data_args.pad_to_max_length else False
514 |
515 | if training_args.label_smoothing_factor > 0 and not hasattr(model, "prepare_decoder_input_ids_from_labels"):
516 | logger.warning(
517 | "label_smoothing is enabled but the `prepare_decoder_input_ids_from_labels` method is not defined for"
518 | f"`{model.__class__.__name__}`. This will lead to loss being calculated twice and will take up more memory"
519 | )
520 |
521 | print(data_args)
522 |
523 | def preprocess_function(examples):
524 | # remove pairs where at least one record is None
525 |
526 | inputs, targets = [], []
527 | for i in range(len(examples[text_column])):
528 | if examples[text_column][i] and examples[summary_column][i]:
529 | inputs.append(examples[text_column][i])
530 | targets.append(examples[summary_column][i])
531 |
532 | inputs = [prefix + inp for inp in inputs]
533 | model_inputs = tokenizer(inputs, max_length=data_args.max_source_length, padding=padding, truncation=True)
534 |
535 | # Tokenize targets with the `text_target` keyword argument
536 | labels = tokenizer(text_target=targets, max_length=max_target_length, padding=padding, truncation=True)
537 |
538 | # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100 when we want to ignore
539 | # padding in the loss.
540 | if padding == "max_length" and data_args.ignore_pad_token_for_loss:
541 | labels["input_ids"] = [
542 | [(l if l != tokenizer.pad_token_id else -100) for l in label] for label in labels["input_ids"]
543 | ]
544 |
545 | model_inputs["labels"] = labels["input_ids"]
546 | return model_inputs
547 |
548 | if training_args.do_train:
549 | if "train" not in raw_datasets:
550 | raise ValueError("--do_train requires a train dataset")
551 | train_dataset = raw_datasets["train"]
552 | if data_args.max_train_samples is not None:
553 | max_train_samples = min(len(train_dataset), data_args.max_train_samples)
554 | train_dataset = train_dataset.select(range(max_train_samples))
555 | with training_args.main_process_first(desc="train dataset map pre-processing"):
556 | train_dataset = train_dataset.map(
557 | preprocess_function,
558 | batched=True,
559 | num_proc=data_args.preprocessing_num_workers,
560 | remove_columns=column_names,
561 | load_from_cache_file=not data_args.overwrite_cache,
562 | desc="Running tokenizer on train dataset",
563 | )
564 |
565 | if training_args.do_eval:
566 | max_target_length = data_args.val_max_target_length
567 | if "validation" not in raw_datasets:
568 | raise ValueError("--do_eval requires a validation dataset")
569 | eval_dataset = raw_datasets["validation"]
570 | if data_args.max_eval_samples is not None:
571 | max_eval_samples = min(len(eval_dataset), data_args.max_eval_samples)
572 | eval_dataset = eval_dataset.select(range(max_eval_samples))
573 | with training_args.main_process_first(desc="validation dataset map pre-processing"):
574 | eval_dataset = eval_dataset.map(
575 | preprocess_function,
576 | batched=True,
577 | num_proc=data_args.preprocessing_num_workers,
578 | remove_columns=column_names,
579 | load_from_cache_file=not data_args.overwrite_cache,
580 | desc="Running tokenizer on validation dataset",
581 | )
582 |
583 | if training_args.do_predict:
584 | max_target_length = data_args.val_max_target_length
585 | if "test" not in raw_datasets:
586 | raise ValueError("--do_predict requires a test dataset")
587 | predict_dataset = raw_datasets["test"]
588 | if data_args.max_predict_samples is not None:
589 | max_predict_samples = min(len(predict_dataset), data_args.max_predict_samples)
590 | predict_dataset = predict_dataset.select(range(max_predict_samples))
591 | with training_args.main_process_first(desc="prediction dataset map pre-processing"):
592 | predict_dataset = predict_dataset.map(
593 | preprocess_function,
594 | batched=True,
595 | num_proc=data_args.preprocessing_num_workers,
596 | remove_columns=column_names,
597 | load_from_cache_file=not data_args.overwrite_cache,
598 | desc="Running tokenizer on prediction dataset",
599 | )
600 |
601 | # Data collator
602 | label_pad_token_id = -100 if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id
603 | data_collator = DataCollatorForSeq2Seq(
604 | tokenizer,
605 | model=model,
606 | label_pad_token_id=label_pad_token_id,
607 | pad_to_multiple_of=8 if training_args.fp16 else None,
608 | )
609 |
610 | # Metric
611 | metric = evaluate.load("rouge")
612 |
613 | def postprocess_text(preds, labels):
614 | preds = [pred.strip() for pred in preds]
615 | labels = [label.strip() for label in labels]
616 |
617 | # rougeLSum expects newline after each sentence
618 | preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
619 | labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]
620 |
621 | return preds, labels
622 |
623 | def compute_metrics(eval_preds):
624 | preds, labels = eval_preds
625 | if isinstance(preds, tuple):
626 | preds = preds[0]
627 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
628 | if data_args.ignore_pad_token_for_loss:
629 | # Replace -100 in the labels as we can't decode them.
630 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
631 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
632 |
633 | # Some simple post-processing
634 | decoded_preds, decoded_labels = postprocess_text(decoded_preds, decoded_labels)
635 |
636 | result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
637 | result = {k: round(v * 100, 4) for k, v in result.items()}
638 | prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds]
639 | result["gen_len"] = np.mean(prediction_lens)
640 | return result
641 |
642 | # Initialize our Trainer
643 | trainer = Seq2SeqTrainer(
644 | model=model,
645 | args=training_args,
646 | train_dataset=train_dataset if training_args.do_train else None,
647 | eval_dataset=eval_dataset if training_args.do_eval else None,
648 | tokenizer=tokenizer,
649 | data_collator=data_collator,
650 | compute_metrics=compute_metrics if training_args.predict_with_generate else None,
651 | )
652 |
653 | # Training
654 | if training_args.do_train:
655 | checkpoint = None
656 | if training_args.resume_from_checkpoint is not None:
657 | checkpoint = training_args.resume_from_checkpoint
658 | elif last_checkpoint is not None:
659 | checkpoint = last_checkpoint
660 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
661 | trainer.save_model() # Saves the tokenizer too for easy upload
662 |
663 | metrics = train_result.metrics
664 | max_train_samples = (
665 | data_args.max_train_samples if data_args.max_train_samples is not None else len(train_dataset)
666 | )
667 | metrics["train_samples"] = min(max_train_samples, len(train_dataset))
668 |
669 | trainer.log_metrics("train", metrics)
670 | trainer.save_metrics("train", metrics)
671 | trainer.save_state()
672 |
673 | # Evaluation
674 | results = {}
675 | max_length = (
676 | training_args.generation_max_length
677 | if training_args.generation_max_length is not None
678 | else data_args.val_max_target_length
679 | )
680 | num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
681 | # if training_args.do_eval:
682 | # logger.info("*** Evaluate ***")
683 | # metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
684 | # max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
685 | # metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
686 | #
687 | # trainer.log_metrics("eval", metrics)
688 | # trainer.save_metrics("eval", metrics)
689 |
690 | if training_args.do_predict:
691 | logger.info("*** Predict ***")
692 |
693 | predict_results = trainer.predict(
694 | predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
695 | )
696 | metrics = predict_results.metrics
697 | max_predict_samples = (
698 | data_args.max_predict_samples if data_args.max_predict_samples is not None else len(predict_dataset)
699 | )
700 | metrics["predict_samples"] = min(max_predict_samples, len(predict_dataset))
701 |
702 | trainer.log_metrics("predict", metrics)
703 | trainer.save_metrics("predict", metrics)
704 |
705 | if trainer.is_world_process_zero():
706 | if training_args.predict_with_generate:
707 | predictions = tokenizer.batch_decode(
708 | predict_results.predictions, skip_special_tokens=True, clean_up_tokenization_spaces=True
709 | )
710 | predictions = [pred.strip() for pred in predictions]
711 | output_prediction_file = os.path.join(training_args.output_dir, "generated_predictions.txt")
712 | with open(output_prediction_file, "w") as writer:
713 | writer.write("\n".join(predictions))
714 |
715 | kwargs = {"finetuned_from": model_args.model_name_or_path, "tasks": "summarization"}
716 | if data_args.dataset_name is not None:
717 | kwargs["dataset_tags"] = data_args.dataset_name
718 | if data_args.dataset_config_name is not None:
719 | kwargs["dataset_args"] = data_args.dataset_config_name
720 | kwargs["dataset"] = f"{data_args.dataset_name} {data_args.dataset_config_name}"
721 | else:
722 | kwargs["dataset"] = data_args.dataset_name
723 |
724 | if data_args.lang is not None:
725 | kwargs["language"] = data_args.lang
726 |
727 | if training_args.push_to_hub:
728 | trainer.push_to_hub(**kwargs)
729 | else:
730 | trainer.create_model_card(**kwargs)
731 |
732 | return results
733 |
734 |
735 | def _mp_fn(index):
736 | # For xla_spawn (TPUs)
737 | main()
738 |
739 |
740 | if __name__ == "__main__":
741 | main()
742 |
--------------------------------------------------------------------------------
/src/train/train_cnndm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | source ~/anaconda3/etc/profile.d/conda.sh
4 |
5 | conda activate readability_summ
6 |
7 | FOLDER_OUTPUT=/mnt/hd3/checkpoints/exec-$RANDOM
8 |
9 | TRAIN_FILE='../data/train_prompt_category.json'
10 | VAL_FILE='../data/validation_prompt_category.json'
11 |
12 | MODEL_NAME='google/flan-t5-large'
13 |
14 | deepspeed --master_port 61002 --include localhost:0,1,2,3,4,5,6,7 run_summarization.py --model_name_or_path ${MODEL_NAME} \
15 | --output_dir ${FOLDER_OUTPUT} --text_column input --summary_column summary \
16 | --train_file ${TRAIN_FILE} \
17 | --validation_file ${VAL_FILE} \
18 | --learning_rate 1e-4 \
19 | --max_source_length 1024 \
20 | --source_prefix "" \
21 | --num_train_epochs 20 \
22 | --logging_steps 200 \
23 | --preprocessing_num_workers 100 \
24 | --eval_steps 10000 \
25 | --save_steps 10000 \
26 | --save_total_limit 2 \
27 | --evaluation_strategy "steps" \
28 | --per_device_train_batch_size 4 \
29 | --per_device_eval_batch_size 4 \
30 | --metric_for_best_model "rouge1" \
31 | --load_best_model_at_end \
32 | --predict_with_generate \
33 | --deepspeed ds_config_stage3_fb16.json \
34 | --bf16 \
35 | --bf16_full_eval \
36 | --do_train
37 |
38 | conda deactivate
--------------------------------------------------------------------------------