├── DATA_LICENSE
├── LICENSE
├── README.md
├── WEIGHT_LICENSE
├── benchmarks
│   ├── Pred_L-Eval
│   │   └── llm_gpt4_eval.pred.jsonl
│   ├── Pred_LongBench
│   │   ├── 2wikimqa.jsonl
│   │   ├── gov_report.jsonl
│   │   ├── hotpotqa.jsonl
│   │   ├── lcc.jsonl
│   │   ├── multi_news.jsonl
│   │   ├── multifieldqa_en.jsonl
│   │   ├── musique.jsonl
│   │   ├── narrativeqa.jsonl
│   │   ├── passage_count.jsonl
│   │   ├── passage_retrieval_en.jsonl
│   │   ├── qasper.jsonl
│   │   ├── qmsum.jsonl
│   │   ├── repobench-p.jsonl
│   │   ├── result.json
│   │   ├── samsum.jsonl
│   │   ├── trec.jsonl
│   │   └── triviaqa.jsonl
│   └── README.md
├── demo.py
├── ds_configs
│   ├── stage2.json
│   └── stage3.json
├── eval.py
├── eval_distributed.py
├── fine-tune.py
├── get_trainable_weights.py
├── gptneox_attn_replace.py
├── imgs
│   ├── LongAlpaca.png
│   ├── Shift-short-attention2.png
│   ├── data-distribution-in-longalpaca12k.png
│   ├── demo-compare-harrypotter.png
│   ├── demo-compare-journeytothewest.png
│   ├── demo-compare-threebody.png
│   ├── economy-comparison.png
│   ├── economy-prediction.png
│   ├── paper-improvements.png
│   ├── paper-review.png
│   └── paper-style-compare-cvpr-iclr.png
├── inference-qlora.py
├── inference.py
├── llama_attn_replace.py
├── llama_attn_replace_sft.py
├── merge_lora_weights_and_save_hf_model.py
├── passkey_retrivial.py
├── pdf2txt
│   ├── README.md
│   ├── backbone.py
│   ├── beit.py
│   ├── config.py
│   ├── configs
│   │   ├── Base-RCNN-FPN.yaml
│   │   └── cascade_dit_large.yaml
│   ├── pdf2txt.py
│   └── requirements.txt
├── requirements.txt
├── run_streaming_llama_longalpaca.py
├── streaming_llm
│   ├── __init__.py
│   ├── enable_streaming_llm.py
│   ├── kv_cache.py
│   ├── pos_shift
│   │   ├── __init__.py
│   │   ├── modify_falcon.py
│   │   ├── modify_gpt_neox.py
│   │   └── modify_llama.py
│   └── utils.py
├── supervised-fine-tune-qlora.py
└── supervised-fine-tune.py
/DATA_LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/benchmarks/Pred_LongBench/2wikimqa.jsonl:
--------------------------------------------------------------------------------
1 | {"pred": "Izmit", "answers": ["Ozalj"], "all_classes": null, "length": 4696}
2 | {"pred": "Elizabeth", "answers": ["John the Baptist"], "all_classes": null, "length": 4776}
3 | {"pred": "Sam Spiegel Film and Television School", "answers": ["Cahiers du cinéma"], "all_classes": null, "length": 4274}
4 | {"pred": "England", "answers": ["no"], "all_classes": null, "length": 8125}
5 | {"pred": "1483", "answers": ["1510"], "all_classes": null, "length": 4621}
6 | {"pred": "Abd al-Muttalib", "answers": ["Edward Watson"], "all_classes": null, "length": 4625}
7 | {"pred": "1275", "answers": ["16 September 1360"], "all_classes": null, "length": 5001}
8 | {"pred": "Helmichis's father-in-law is Alboin.", "answers": ["Cunimund"], "all_classes": null, "length": 7639}
9 | {"pred": " Dublin", "answers": ["St Patrick's College"], "all_classes": null, "length": 3964}
10 | {"pred": "Wine of Morning director Katherine Stenholm worked at Unusual Films.", "answers": ["Bob Jones University"], "all_classes": null, "length": 5162}
11 | {"pred": "The House Of The Seven Hawks", "answers": ["The House Of The Seven Hawks"], "all_classes": null, "length": 10338}
12 | {"pred": "Marie Of Brabant's paternal grandmother is Marie Of Brabant.", "answers": ["Marie of Hohenstaufen"], "all_classes": null, "length": 3596}
13 | {"pred": "Adelaide, Countess of Soissons", "answers": ["Guy II, Count of Soissons"], "all_classes": null, "length": 1280}
14 | {"pred": "1839", "answers": ["26 April 1872"], "all_classes": null, "length": 3432}
15 | {"pred": "United States", "answers": ["America"], "all_classes": null, "length": 4442}
16 | {"pred": "Larry Parks", "answers": ["Ona Munson"], "all_classes": null, "length": 10444}
17 | {"pred": "The Death of Black King (1932)", "answers": ["The Death Of Black King"], "all_classes": null, "length": 1696}
18 | {"pred": "Yes", "answers": ["no"], "all_classes": null, "length": 5932}
19 | {"pred": " No", "answers": ["no"], "all_classes": null, "length": 535}
20 | {"pred": "Kanneshwara Rama", "answers": ["Mysore"], "all_classes": null, "length": 3532}
21 | {"pred": "Hamar", "answers": ["Kristiania"], "all_classes": null, "length": 2782}
22 | {"pred": "Kathy Griffin graduated from Oak Park High School.", "answers": ["Lee Strasberg Theatre and Film Institute"], "all_classes": null, "length": 7578}
23 | {"pred": "Hell Up In Harlem", "answers": ["Hell Up In Harlem"], "all_classes": null, "length": 9013}
24 | {"pred": "Cipriano Castro", "answers": ["Cipriano Castro"], "all_classes": null, "length": 6026}
25 | {"pred": "The Year Of The Rabbit", "answers": ["Monster On The Campus"], "all_classes": null, "length": 6192}
26 | {"pred": "Younger", "answers": ["Altuğ Çelikbilek"], "all_classes": null, "length": 889}
27 | {"pred": "Joel The Lump Of Coal", "answers": ["Jugband Blues"], "all_classes": null, "length": 3517}
28 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4141}
29 | {"pred": "Baldwin I Rátót", "answers": ["Leustach Rátót"], "all_classes": null, "length": 3948}
30 | {"pred": "Kantipur", "answers": ["Nepal"], "all_classes": null, "length": 4625}
31 | {"pred": "Archibald Acheson, 1st Earl of Gosford", "answers": ["Archibald Acheson, 2nd Earl of Gosford"], "all_classes": null, "length": 4383}
32 | {"pred": "Abd al-Muttalib", "answers": ["Jean Paul Getty"], "all_classes": null, "length": 5464}
33 | {"pred": "Goring-on-Thames", "answers": ["Goring-on-Thames, Oxfordshire"], "all_classes": null, "length": 11686}
34 | {"pred": "British", "answers": ["yes"], "all_classes": null, "length": 3122}
35 | {"pred": "The Abduction Club", "answers": ["Wooden Crosses"], "all_classes": null, "length": 4154}
36 | {"pred": "The Magic Aster", "answers": ["Above Rubies"], "all_classes": null, "length": 3299}
37 | {"pred": "Blue Blood And Red", "answers": ["Blue Blood And Red"], "all_classes": null, "length": 4436}
38 | {"pred": "Paris, France", "answers": ["Castlemaine, Victoria, Australia"], "all_classes": null, "length": 2896}
39 | {"pred": "Paul De Scherff", "answers": ["Lyudmyla Olyanovska"], "all_classes": null, "length": 6416}
40 | {"pred": "Eyüp Cemetery, Istanbul", "answers": ["Constantinople"], "all_classes": null, "length": 4769}
41 | {"pred": "Melun", "answers": ["East Francia"], "all_classes": null, "length": 4957}
42 | {"pred": "Tex And The Lord Of The Deep", "answers": ["Henry Goes Arizona"], "all_classes": null, "length": 2540}
43 | {"pred": "Wales", "answers": ["United Kingdom"], "all_classes": null, "length": 8759}
44 | {"pred": "All-American Co-Ed", "answers": ["All-American Co-Ed"], "all_classes": null, "length": 5527}
45 | {"pred": "Buenos Aires", "answers": ["Buenos Aires"], "all_classes": null, "length": 3859}
46 | {"pred": "Louise's mother-in-law is Duchess Magdalene Sibylle of Holstein-Gottorp.", "answers": ["Charlotte Amalie of Hesse-Kassel"], "all_classes": null, "length": 5695}
47 | {"pred": " Stahleck Castle", "answers": ["Brunswick"], "all_classes": null, "length": 3954}
48 | {"pred": "Marcus Annius Libo's aunt is Rupilia Faustina.", "answers": ["Vibia Sabina"], "all_classes": null, "length": 3690}
49 | {"pred": "Riding the California Trail", "answers": ["Bajo Otro Sol"], "all_classes": null, "length": 1120}
50 | {"pred": "Lisa Azuelos", "answers": ["Marie Laforêt"], "all_classes": null, "length": 3219}
51 | {"pred": "Val Kilmer", "answers": ["Sandra Nelson"], "all_classes": null, "length": 10498}
52 | {"pred": "2 March 1702", "answers": ["May 19, 1669"], "all_classes": null, "length": 4570}
53 | {"pred": "Pamplona", "answers": ["Palencia"], "all_classes": null, "length": 3496}
54 | {"pred": " Dance With A Stranger", "answers": ["Miley Naa Miley Hum"], "all_classes": null, "length": 3934}
55 | {"pred": "Space Probe Taurus", "answers": ["Tom Mix In Arabia"], "all_classes": null, "length": 3324}
56 | {"pred": "Daughter of the Jungle", "answers": ["Seven In The Sun"], "all_classes": null, "length": 2708}
57 | {"pred": "Dr. Socrates", "answers": ["Dr. Socrates"], "all_classes": null, "length": 4218}
58 | {"pred": "Changeland", "answers": ["Changeland"], "all_classes": null, "length": 6736}
59 | {"pred": "Peter Rosegger", "answers": ["Ruel Redinger"], "all_classes": null, "length": 2046}
60 | {"pred": "Pyotr Karatygin's sibling-in-law is Vasily Karatygin.", "answers": ["Alexandra Kolosova"], "all_classes": null, "length": 2898}
61 | {"pred": " Maxine Caroll Lawrence", "answers": ["Jessi Colter"], "all_classes": null, "length": 7833}
62 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 2351}
63 | {"pred": "William Pooley", "answers": ["William Pooley"], "all_classes": null, "length": 2189}
64 | {"pred": " French", "answers": ["Sweden"], "all_classes": null, "length": 4845}
65 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 1412}
66 | {"pred": "Melody Of The World", "answers": ["Melody Of The World"], "all_classes": null, "length": 4784}
67 | {"pred": "Madrid", "answers": ["Madrid"], "all_classes": null, "length": 2868}
68 | {"pred": "Oskar Roehler", "answers": ["Gisela Elsner"], "all_classes": null, "length": 3211}
69 | {"pred": "The Great Man's Lady", "answers": ["La Belle Américaine"], "all_classes": null, "length": 4142}
70 | {"pred": "Oklahoma City, Oklahoma", "answers": ["Oklahoma City, Oklahoma"], "all_classes": null, "length": 3277}
71 | {"pred": "Dubai", "answers": ["Dubai"], "all_classes": null, "length": 4240}
72 | {"pred": "France", "answers": ["La Trinité"], "all_classes": null, "length": 4966}
73 | {"pred": "Duke Paul Frederick of Mecklenburg-Schwerin", "answers": ["Prince Albert of Prussia"], "all_classes": null, "length": 4350}
74 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 562}
75 | {"pred": "Pembroke Castle", "answers": ["Banbury"], "all_classes": null, "length": 3847}
76 | {"pred": "Sir Paul Gore, 1st Baronet", "answers": ["Sir Paul Gore, 1st Baronet"], "all_classes": null, "length": 2779}
77 | {"pred": "Ireland", "answers": ["Fatimid caliphate"], "all_classes": null, "length": 8314}
78 | {"pred": "Santa Monica", "answers": ["Dayton, Ohio"], "all_classes": null, "length": 3761}
79 | {"pred": "Bill Graham", "answers": ["Ona Munson"], "all_classes": null, "length": 11373}
80 | {"pred": "Pigsty (Film)", "answers": ["The Drover'S Sweetheart"], "all_classes": null, "length": 3866}
81 | {"pred": "Fayetteville, Arkansas", "answers": ["Tumkur"], "all_classes": null, "length": 3459}
82 | {"pred": "Closely Watched Trains", "answers": ["Det Sande Ansigt"], "all_classes": null, "length": 4713}
83 | {"pred": "Il Gaucho", "answers": ["Bomgay"], "all_classes": null, "length": 4261}
84 | {"pred": "Pliska", "answers": ["Preslav"], "all_classes": null, "length": 5869}
85 | {"pred": "The Pyramid", "answers": ["Revolt Of The Praetorians"], "all_classes": null, "length": 3504}
86 | {"pred": "France", "answers": ["yes"], "all_classes": null, "length": 10582}
87 | {"pred": "Eric XIV of Sweden died of arsenic poisoning.", "answers": ["poisoning"], "all_classes": null, "length": 5584}
88 | {"pred": "American", "answers": ["America"], "all_classes": null, "length": 5631}
89 | {"pred": "Kamehameha I", "answers": ["Kingdom of Hawaii"], "all_classes": null, "length": 8829}
90 | {"pred": "Mi Novia Está De Madre", "answers": ["X-Paroni"], "all_classes": null, "length": 2674}
91 | {"pred": "Vasantha Raagam", "answers": ["Vasantha Raagam"], "all_classes": null, "length": 6052}
92 | {"pred": " Lee Kun-Hee", "answers": ["Lee Byung-chul"], "all_classes": null, "length": 6337}
93 | {"pred": "Charles I, Duke of Bourbon", "answers": ["John I, Duke of Bourbon"], "all_classes": null, "length": 4501}
94 | {"pred": "1839", "answers": ["27 June 1839"], "all_classes": null, "length": 10742}
95 | {"pred": "Nathan Juran", "answers": ["Gura Humorului"], "all_classes": null, "length": 3888}
96 | {"pred": "Amandine Bourgeois", "answers": ["Françoise Hardy"], "all_classes": null, "length": 5091}
97 | {"pred": "Charles Wheatstone", "answers": ["Charles Wheatstone"], "all_classes": null, "length": 8386}
98 | {"pred": "Perryville, Missouri", "answers": ["Washington"], "all_classes": null, "length": 1917}
99 | {"pred": "London Melody", "answers": ["London Melody"], "all_classes": null, "length": 9611}
100 | {"pred": "Diego Fernández de Oviedo", "answers": ["Flaín Muñoz"], "all_classes": null, "length": 3208}
101 | {"pred": "Michael Sorvino", "answers": ["Mira Sorvino"], "all_classes": null, "length": 4398}
102 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 1863}
103 | {"pred": "Denmark", "answers": ["Norway"], "all_classes": null, "length": 3211}
104 | {"pred": "Death by illness", "answers": ["illness"], "all_classes": null, "length": 4438}
105 | {"pred": "She died in Amesbury Priory.", "answers": ["Harby"], "all_classes": null, "length": 10562}
106 | {"pred": "The director of film Tiramisu (2002) earned the Hundred Flowers Award for Best Director.", "answers": ["Hong Kong Film Award for Best Director"], "all_classes": null, "length": 3401}
107 | {"pred": "192", "answers": ["3 September 1992"], "all_classes": null, "length": 1680}
108 | {"pred": "Nathaniel McLenaghan", "answers": ["Nathaniel Mclenaghan"], "all_classes": null, "length": 3026}
109 | {"pred": "Tisch School of the Arts", "answers": ["Tisch"], "all_classes": null, "length": 3749}
110 | {"pred": "Cuchillos De Fuego", "answers": ["Cuchillos De Fuego"], "all_classes": null, "length": 2697}
111 | {"pred": "The Ballad Of Josie", "answers": ["Moment Of Danger"], "all_classes": null, "length": 4003}
112 | {"pred": "De AS", "answers": ["De As"], "all_classes": null, "length": 2425}
113 | {"pred": "The Piper's Price", "answers": ["The Piper'S Price"], "all_classes": null, "length": 4134}
114 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 5918}
115 | {"pred": "1753", "answers": ["13 March 1753"], "all_classes": null, "length": 4323}
116 | {"pred": "True To The Navy", "answers": ["No Trees In The Street"], "all_classes": null, "length": 7162}
117 | {"pred": "Malayalam", "answers": ["Methala"], "all_classes": null, "length": 3129}
118 | {"pred": "House of Dark Shadows", "answers": ["Alkohol"], "all_classes": null, "length": 5610}
119 | {"pred": "Do Musafir", "answers": ["Do Musafir"], "all_classes": null, "length": 1138}
120 | {"pred": " Yes", "answers": ["no"], "all_classes": null, "length": 1241}
121 | {"pred": "New York City", "answers": ["New York"], "all_classes": null, "length": 2516}
122 | {"pred": "Tiger In The Smoke", "answers": ["Contragolpe"], "all_classes": null, "length": 3675}
123 | {"pred": "Mumbai", "answers": ["Mumbai"], "all_classes": null, "length": 3052}
124 | {"pred": "The Comedians of Comedy", "answers": ["The Comedians Of Comedy"], "all_classes": null, "length": 4756}
125 | {"pred": "Tombstone Rashomon", "answers": ["Tombstone Rashomon"], "all_classes": null, "length": 5772}
126 | {"pred": "Dhuen Ki Lakeer", "answers": ["Dhuen Ki Lakeer"], "all_classes": null, "length": 4828}
127 | {"pred": "Perdón, viejita", "answers": ["Perdón, Viejita"], "all_classes": null, "length": 10456}
128 | {"pred": "University of Wisconsin-Madison", "answers": ["University of Wisconsin"], "all_classes": null, "length": 2748}
129 | {"pred": "Dudley Russell", "answers": ["Dudley Russell"], "all_classes": null, "length": 4526}
130 | {"pred": "Vytautas Straižys", "answers": ["Mirjam Polkunen"], "all_classes": null, "length": 3620}
131 | {"pred": "Russia", "answers": ["Saint Petersburg"], "all_classes": null, "length": 9479}
132 | {"pred": "Menno Meyjes", "answers": ["Eindhoven"], "all_classes": null, "length": 3592}
133 | {"pred": "Women's Suffrage Journal", "answers": ["Women'S Suffrage Journal"], "all_classes": null, "length": 3828}
134 | {"pred": "Fairmont, West Virginia", "answers": ["Fairmont, West Virginia"], "all_classes": null, "length": 3228}
135 | {"pred": "The Market Of Souls", "answers": ["The Market Of Souls"], "all_classes": null, "length": 2305}
136 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4201}
137 | {"pred": "Marrakech", "answers": ["Morocco"], "all_classes": null, "length": 4417}
138 | {"pred": "Beaulieu-sur-Loire]", "answers": ["Akkadian empire"], "all_classes": null, "length": 8701}
139 | {"pred": "Tarzan The Magnificent", "answers": ["Tarzan The Magnificent"], "all_classes": null, "length": 4580}
140 | {"pred": "Wolf Warrior", "answers": ["Wolf Warrior"], "all_classes": null, "length": 3774}
141 | {"pred": "Trịnh Cương", "answers": ["Trịnh Bính"], "all_classes": null, "length": 3718}
142 | {"pred": "Nathan Juran", "answers": ["Gura Humorului"], "all_classes": null, "length": 4854}
143 | {"pred": "Bernie Bonvois", "answers": ["Crenshaw"], "all_classes": null, "length": 7562}
144 | {"pred": "Winter Sleepers", "answers": ["Winter Sleepers"], "all_classes": null, "length": 3065}
145 | {"pred": "Bruno II of Isenburg-Braunsberg", "answers": ["Adolf II of Berg"], "all_classes": null, "length": 4040}
146 | {"pred": " No", "answers": ["no"], "all_classes": null, "length": 6771}
147 | {"pred": "Helge Schneider", "answers": ["Mülheim an der Ruhr"], "all_classes": null, "length": 3140}
148 | {"pred": "Una Prostituta Al Servizio Del Pubblico E In Regola Con Le Leggi Dello Stato", "answers": ["Una Prostituta Al Servizio Del Pubblico E In Regola Con Le Leggi Dello Stato"], "all_classes": null, "length": 2809}
149 | {"pred": " Yes", "answers": ["yes"], "all_classes": null, "length": 10991}
150 | {"pred": "Prince Bhanurangsi Savangwongse", "answers": ["Mongkut"], "all_classes": null, "length": 3759}
151 | {"pred": "Elizabeth Blount", "answers": ["Lady Frances Manners"], "all_classes": null, "length": 3293}
152 | {"pred": "L'Arbre, le maire et la médiathèque]", "answers": ["Cahiers du cinéma"], "all_classes": null, "length": 9453}
153 | {"pred": "Mexico", "answers": ["Newport, Wales"], "all_classes": null, "length": 3945}
154 | {"pred": "Orange County, Virginia", "answers": ["Gordonsville, Virginia"], "all_classes": null, "length": 9096}
155 | {"pred": "Romania", "answers": ["Mangalia"], "all_classes": null, "length": 8046}
156 | {"pred": "Pier-Luc Funk", "answers": ["Félix Leclerc"], "all_classes": null, "length": 4871}
157 | {"pred": "Henri I, Duke of Nemours", "answers": ["Jacques"], "all_classes": null, "length": 5577}
158 | {"pred": "Gisela Of Hungary's father-in-law is Stephen I Of Hungary.", "answers": ["Géza"], "all_classes": null, "length": 11622}
159 | {"pred": "Thuya", "answers": ["Yi Jiang"], "all_classes": null, "length": 4142}
160 | {"pred": "5 September 1896", "answers": ["November 3, 1867"], "all_classes": null, "length": 6374}
161 | {"pred": "Guy Newall", "answers": ["Ivy Duke"], "all_classes": null, "length": 10243}
162 | {"pred": "181", "answers": ["1 August 1813"], "all_classes": null, "length": 2765}
163 | {"pred": "The Yellow Teddy Bears", "answers": ["Season Of Strangers"], "all_classes": null, "length": 6966}
164 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 4450}
165 | {"pred": "Abdul Ali Lalu", "answers": ["Antoine Casavant"], "all_classes": null, "length": 3131}
166 | {"pred": "Allahabad", "answers": ["Oxford"], "all_classes": null, "length": 4103}
167 | {"pred": "Sir John Gage", "answers": ["Sir Hew Dalrymple, 2nd Baronet"], "all_classes": null, "length": 2932}
168 | {"pred": "Khud-Daar", "answers": ["Murderer In The Fog"], "all_classes": null, "length": 2692}
169 | {"pred": "Paul W. S. Anderson", "answers": ["Milla"], "all_classes": null, "length": 7443}
170 | {"pred": "Naples", "answers": ["Lecce"], "all_classes": null, "length": 3523}
171 | {"pred": "Lady Magdalene's director, J. Neil Schulman, won the \"Special Jury Prize for Libertarian Ideals\" at the 20", "answers": ["Prometheus Award"], "all_classes": null, "length": 3529}
172 | {"pred": "German", "answers": ["Duchy of Saxe-Gotha-Altenburg"], "all_classes": null, "length": 4847}
173 | {"pred": "The Longshots", "answers": ["Station For Two"], "all_classes": null, "length": 9998}
174 | {"pred": "Inverkeithing, Fife, Scotland", "answers": ["Inverkeithing"], "all_classes": null, "length": 4097}
175 | {"pred": "1234", "answers": ["1234"], "all_classes": null, "length": 9182}
176 | {"pred": "1642", "answers": ["25 July 1666"], "all_classes": null, "length": 6136}
177 | {"pred": "Special Delivery (1927 Film)", "answers": ["The Wind'S Fierce"], "all_classes": null, "length": 9043}
178 | {"pred": "He died by suicide by drowning himself in his swimming pool.", "answers": ["drowning"], "all_classes": null, "length": 8169}
179 | {"pred": "Abu Dhabi", "answers": ["Abu Dhabi"], "all_classes": null, "length": 3593}
180 | {"pred": "Ottensen", "answers": ["Prenzlau"], "all_classes": null, "length": 3584}
181 | {"pred": "Poland", "answers": ["Poznań"], "all_classes": null, "length": 3814}
182 | {"pred": "Picton", "answers": ["Italy"], "all_classes": null, "length": 4597}
183 | {"pred": "New Zealand", "answers": ["Toronto, Ontario"], "all_classes": null, "length": 3366}
184 | {"pred": "The Third Kiss", "answers": ["Forbidden Daughters"], "all_classes": null, "length": 4389}
185 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4958}
186 | {"pred": "Bogdan Țăruș", "answers": ["Bogdan Țăruș"], "all_classes": null, "length": 3425}
187 | {"pred": "Folgore Division", "answers": ["Sandflow"], "all_classes": null, "length": 1945}
188 | {"pred": "Durango Valley Raiders", "answers": ["From Corleone To Brooklyn"], "all_classes": null, "length": 3894}
189 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 667}
190 | {"pred": "Pacific Palisades", "answers": ["Vienna"], "all_classes": null, "length": 10974}
191 | {"pred": "Johnny Ekström", "answers": ["Johnny Ekström"], "all_classes": null, "length": 6044}
192 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 592}
193 | {"pred": "English", "answers": ["England"], "all_classes": null, "length": 11950}
194 | {"pred": "Norman Panama died on January 13, 2003 due to complications of Parkinson's disease.", "answers": ["Parkinson"], "all_classes": null, "length": 4337}
195 | {"pred": "Thomas De Berkeley, 5th Baron Berkeley", "answers": ["Roger Mortimer, 1st Earl of March"], "all_classes": null, "length": 3947}
196 | {"pred": "Ali Dinar", "answers": ["Albertus Seba"], "all_classes": null, "length": 4441}
197 | {"pred": "Henry Krause", "answers": ["Ferdinand I of the Two Sicilies"], "all_classes": null, "length": 8935}
198 | {"pred": "Prince Of Arcadia", "answers": ["Prince Of Arcadia"], "all_classes": null, "length": 3155}
199 | {"pred": "Italy", "answers": ["Milan"], "all_classes": null, "length": 3195}
200 | {"pred": "Germany", "answers": ["Köln"], "all_classes": null, "length": 3719}
201 |
--------------------------------------------------------------------------------
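Note: each record in the Pred_LongBench files above is one JSON object per line with the fields "pred" (model output), "answers" (list of gold answers), "all_classes", and "length". The sketch below is not the repository's own scoring code (the official LongBench metrics behind result.json differ per task); it is only a minimal, illustrative way to load such a file and compute a naive exact-match rate. The file path is an assumption.

    # Minimal sketch (not the repository's metric code): parse a
    # Pred_LongBench prediction file and report a naive exact-match rate.
    # Field names follow the records shown above.
    import json
    import string

    def normalize(text: str) -> str:
        """Lowercase, trim, and strip punctuation for a loose comparison."""
        text = text.strip().lower()
        return text.translate(str.maketrans("", "", string.punctuation))

    def exact_match_rate(path: str) -> float:
        """Fraction of records whose prediction matches any gold answer."""
        hits, total = 0, 0
        with open(path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                record = json.loads(line)
                pred = normalize(record["pred"])
                golds = [normalize(a) for a in record["answers"]]
                hits += int(pred in golds)
                total += 1
        return hits / max(total, 1)

    if __name__ == "__main__":
        # Hypothetical local path; adjust to where the benchmark files live.
        print(exact_match_rate("benchmarks/Pred_LongBench/2wikimqa.jsonl"))
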
/benchmarks/Pred_LongBench/passage_retrieval_en.jsonl:
--------------------------------------------------------------------------------
1 | {"pred": "\nParagraph 15", "answers": ["Paragraph 15"], "all_classes": null, "length": 8525}
2 | {"pred": "\nParagraph 28", "answers": ["Paragraph 8"], "all_classes": null, "length": 8679}
3 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 9562}
4 | {"pred": "\nParagraph 17", "answers": ["Paragraph 17"], "all_classes": null, "length": 9273}
5 | {"pred": " Paragraph 14", "answers": ["Paragraph 14"], "all_classes": null, "length": 9175}
6 | {"pred": "\nParagraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 9139}
7 | {"pred": "\nParagraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8832}
8 | {"pred": "Paragraph 28", "answers": ["Paragraph 27"], "all_classes": null, "length": 8995}
9 | {"pred": "\nParagraph 30\nThe embodiment of the Gopher mascot came to life in 1952 when University of Minnesota assistant band", "answers": ["Paragraph 30"], "all_classes": null, "length": 7789}
10 | {"pred": "Paragraph 20]", "answers": ["Paragraph 12"], "all_classes": null, "length": 10027}
11 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9273}
12 | {"pred": "Paragraph 20]", "answers": ["Paragraph 30"], "all_classes": null, "length": 10178}
13 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 9221}
14 | {"pred": "Paragraph 17", "answers": ["Paragraph 19"], "all_classes": null, "length": 9718}
15 | {"pred": "Paragraph 20", "answers": ["Paragraph 3"], "all_classes": null, "length": 11340}
16 | {"pred": "Paragraph 1]", "answers": ["Paragraph 8"], "all_classes": null, "length": 9750}
17 | {"pred": "Paragraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 8965}
18 | {"pred": "Paragraph 17]", "answers": ["Paragraph 17"], "all_classes": null, "length": 8965}
19 | {"pred": "\nParagraph 1", "answers": ["Paragraph 4"], "all_classes": null, "length": 9275}
20 | {"pred": "Paragraph 1]", "answers": ["Paragraph 3"], "all_classes": null, "length": 10934}
21 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9713}
22 | {"pred": "Paragraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 9947}
23 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 8991}
24 | {"pred": "Paragraph 17]", "answers": ["Paragraph 12"], "all_classes": null, "length": 9211}
25 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8267}
26 | {"pred": "\nParagraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 8180}
27 | {"pred": "Paragraph 1]", "answers": ["Paragraph 1"], "all_classes": null, "length": 10576}
28 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9513}
29 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9747}
30 | {"pred": "Paragraph 29", "answers": ["Paragraph 3"], "all_classes": null, "length": 7534}
31 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9908}
32 | {"pred": "Paragraph 27", "answers": ["Paragraph 25"], "all_classes": null, "length": 8838}
33 | {"pred": "Paragraph 17]", "answers": ["Paragraph 19"], "all_classes": null, "length": 10016}
34 | {"pred": "\nParagraph 20", "answers": ["Paragraph 22"], "all_classes": null, "length": 7992}
35 | {"pred": "Paragraph 17]", "answers": ["Paragraph 5"], "all_classes": null, "length": 9899}
36 | {"pred": "Paragraph 19", "answers": ["Paragraph 15"], "all_classes": null, "length": 9483}
37 | {"pred": "Paragraph 28]", "answers": ["Paragraph 28"], "all_classes": null, "length": 10118}
38 | {"pred": "Paragraph 1", "answers": ["Paragraph 5"], "all_classes": null, "length": 8975}
39 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9934}
40 | {"pred": "\nParagraph 10", "answers": ["Paragraph 6"], "all_classes": null, "length": 7736}
41 | {"pred": "Paragraph 1", "answers": ["Paragraph 12"], "all_classes": null, "length": 8771}
42 | {"pred": "Paragraph 27]", "answers": ["Paragraph 26"], "all_classes": null, "length": 10004}
43 | {"pred": "\nParagraph 20", "answers": ["Paragraph 30"], "all_classes": null, "length": 9000}
44 | {"pred": "Paragraph 1]", "answers": ["Paragraph 6"], "all_classes": null, "length": 10255}
45 | {"pred": "\nParagraph 14", "answers": ["Paragraph 12"], "all_classes": null, "length": 8776}
46 | {"pred": "Paragraph 23]", "answers": ["Paragraph 23"], "all_classes": null, "length": 10742}
47 | {"pred": "Paragraph 1]", "answers": ["Paragraph 17"], "all_classes": null, "length": 10320}
48 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 8971}
49 | {"pred": " Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 10277}
50 | {"pred": "Paragraph 21]", "answers": ["Paragraph 21"], "all_classes": null, "length": 9525}
51 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8674}
52 | {"pred": "Paragraph 29", "answers": ["Paragraph 29"], "all_classes": null, "length": 8629}
53 | {"pred": "Paragraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 9970}
54 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8791}
55 | {"pred": "Paragraph 17]", "answers": ["Paragraph 15"], "all_classes": null, "length": 9949}
56 | {"pred": "Paragraph 20]", "answers": ["Paragraph 22"], "all_classes": null, "length": 9832}
57 | {"pred": "Paragraph 16", "answers": ["Paragraph 16"], "all_classes": null, "length": 8952}
58 | {"pred": "\nParagraph 14", "answers": ["Paragraph 14"], "all_classes": null, "length": 8253}
59 | {"pred": "Paragraph 29", "answers": ["Paragraph 29"], "all_classes": null, "length": 10431}
60 | {"pred": "\nParagraph 19", "answers": ["Paragraph 19"], "all_classes": null, "length": 10196}
61 | {"pred": "Paragraph 20]", "answers": ["Paragraph 28"], "all_classes": null, "length": 10633}
62 | {"pred": "\nParagraph 25", "answers": ["Paragraph 26"], "all_classes": null, "length": 9824}
63 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9324}
64 | {"pred": "\nParagraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 10299}
65 | {"pred": " Paragraph 17", "answers": ["Paragraph 8"], "all_classes": null, "length": 9489}
66 | {"pred": "\nParagraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 8281}
67 | {"pred": " Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8804}
68 | {"pred": "Paragraph 16]", "answers": ["Paragraph 15"], "all_classes": null, "length": 8473}
69 | {"pred": "Paragraph 17]", "answers": ["Paragraph 21"], "all_classes": null, "length": 9216}
70 | {"pred": "Paragraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 10941}
71 | {"pred": "Paragraph 1]", "answers": ["Paragraph 5"], "all_classes": null, "length": 10198}
72 | {"pred": "Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 8526}
73 | {"pred": "Paragraph 2", "answers": ["Paragraph 20"], "all_classes": null, "length": 8039}
74 | {"pred": "\nParagraph 17", "answers": ["Paragraph 15"], "all_classes": null, "length": 9207}
75 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9357}
76 | {"pred": "Paragraph 20]", "answers": ["Paragraph 23"], "all_classes": null, "length": 9713}
77 | {"pred": "\nParagraph 1", "answers": ["Paragraph 26"], "all_classes": null, "length": 8114}
78 | {"pred": "Paragraph 1", "answers": ["Paragraph 18"], "all_classes": null, "length": 9148}
79 | {"pred": "\nParagraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8795}
80 | {"pred": "Paragraph 19", "answers": ["Paragraph 18"], "all_classes": null, "length": 8491}
81 | {"pred": "Paragraph 17", "answers": ["Paragraph 7"], "all_classes": null, "length": 8642}
82 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 11270}
83 | {"pred": "Paragraph 2", "answers": ["Paragraph 23"], "all_classes": null, "length": 9267}
84 | {"pred": "Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8747}
85 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 10176}
86 | {"pred": "\nParagraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 8712}
87 | {"pred": "Paragraph 10", "answers": ["Paragraph 3"], "all_classes": null, "length": 8315}
88 | {"pred": "Paragraph 17]", "answers": ["Paragraph 5"], "all_classes": null, "length": 9403}
89 | {"pred": "Paragraph 1", "answers": ["Paragraph 27"], "all_classes": null, "length": 10242}
90 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8143}
91 | {"pred": "Paragraph 1]", "answers": ["Paragraph 29"], "all_classes": null, "length": 9508}
92 | {"pred": "\nParagraph 23", "answers": ["Paragraph 23"], "all_classes": null, "length": 7817}
93 | {"pred": " Paragraph 17", "answers": ["Paragraph 17"], "all_classes": null, "length": 7533}
94 | {"pred": "Paragraph 17", "answers": ["Paragraph 16"], "all_classes": null, "length": 9056}
95 | {"pred": "Paragraph 19]", "answers": ["Paragraph 18"], "all_classes": null, "length": 9565}
96 | {"pred": "Paragraph 17]", "answers": ["Paragraph 13"], "all_classes": null, "length": 10347}
97 | {"pred": "Paragraph 20]", "answers": ["Paragraph 20"], "all_classes": null, "length": 9260}
98 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 11516}
99 | {"pred": "Paragraph 23]", "answers": ["Paragraph 23"], "all_classes": null, "length": 9334}
100 | {"pred": "Paragraph 1]", "answers": ["Paragraph 19"], "all_classes": null, "length": 10614}
101 | {"pred": "Paragraph 19]", "answers": ["Paragraph 19"], "all_classes": null, "length": 9251}
102 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9922}
103 | {"pred": "Paragraph 17]", "answers": ["Paragraph 13"], "all_classes": null, "length": 9459}
104 | {"pred": " Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 9025}
105 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9717}
106 | {"pred": "Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 9446}
107 | {"pred": "Paragraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 10308}
108 | {"pred": "\nParagraph 10\nThe airport is northeast of downtown in east Omaha. Although the airport is in Nebraska", "answers": ["Paragraph 12"], "all_classes": null, "length": 8558}
109 | {"pred": "Paragraph 17]", "answers": ["Paragraph 10"], "all_classes": null, "length": 9485}
110 | {"pred": "Paragraph 20", "answers": ["Paragraph 3"], "all_classes": null, "length": 8673}
111 | {"pred": "Paragraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 9147}
112 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 10298}
113 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8716}
114 | {"pred": "\nParagraph 2", "answers": ["Paragraph 7"], "all_classes": null, "length": 9171}
115 | {"pred": "Paragraph 19", "answers": ["Paragraph 20"], "all_classes": null, "length": 9357}
116 | {"pred": "Paragraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 10597}
117 | {"pred": "\nParagraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 8070}
118 | {"pred": "\nParagraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 9095}
119 | {"pred": "Paragraph 17", "answers": ["Paragraph 14"], "all_classes": null, "length": 8867}
120 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9895}
121 | {"pred": " Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9072}
122 | {"pred": "Paragraph 1", "answers": ["Paragraph 19"], "all_classes": null, "length": 10852}
123 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9767}
124 | {"pred": "\nParagraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9547}
125 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8081}
126 | {"pred": "\nParagraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 7860}
127 | {"pred": "Paragraph 28]", "answers": ["Paragraph 28"], "all_classes": null, "length": 9423}
128 | {"pred": "\nParagraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 7818}
129 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 10417}
130 | {"pred": "Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 8644}
131 | {"pred": "Paragraph 14", "answers": ["Paragraph 16"], "all_classes": null, "length": 8429}
132 | {"pred": "\nParagraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 8931}
133 | {"pred": " Paragraph 1", "answers": ["Paragraph 27"], "all_classes": null, "length": 9098}
134 | {"pred": " Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8751}
135 | {"pred": "Paragraph 26]", "answers": ["Paragraph 26"], "all_classes": null, "length": 9296}
136 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9412}
137 | {"pred": "Paragraph 20]", "answers": ["Paragraph 20"], "all_classes": null, "length": 9968}
138 | {"pred": "Paragraph 19]", "answers": ["Paragraph 29"], "all_classes": null, "length": 9105}
139 | {"pred": "\nParagraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 8924}
140 | {"pred": "Paragraph 17]", "answers": ["Paragraph 6"], "all_classes": null, "length": 9991}
141 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 10601}
142 | {"pred": "Paragraph 1", "answers": ["Paragraph 4"], "all_classes": null, "length": 8559}
143 | {"pred": "Paragraph 10]", "answers": ["Paragraph 29"], "all_classes": null, "length": 10041}
144 | {"pred": " Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8851}
145 | {"pred": "Paragraph 7]", "answers": ["Paragraph 7"], "all_classes": null, "length": 10303}
146 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9065}
147 | {"pred": "\nParagraph 14", "answers": ["Paragraph 8"], "all_classes": null, "length": 8976}
148 | {"pred": "\nParagraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 8863}
149 | {"pred": " Paragraph 1", "answers": ["Paragraph 23"], "all_classes": null, "length": 9616}
150 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9063}
151 | {"pred": "Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8464}
152 | {"pred": " Paragraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 9939}
153 | {"pred": " Paragraph 14", "answers": ["Paragraph 16"], "all_classes": null, "length": 7539}
154 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 9321}
155 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9976}
156 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 7814}
157 | {"pred": "Paragraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 10274}
158 | {"pred": "\nParagraph 20", "answers": ["Paragraph 7"], "all_classes": null, "length": 9091}
159 | {"pred": "Paragraph 19]", "answers": ["Paragraph 27"], "all_classes": null, "length": 10887}
160 | {"pred": "Paragraph 1", "answers": ["Paragraph 16"], "all_classes": null, "length": 9628}
161 | {"pred": "\nParagraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 8060}
162 | {"pred": "\nParagraph 27", "answers": ["Paragraph 26"], "all_classes": null, "length": 9395}
163 | {"pred": "\nParagraph 10", "answers": ["Paragraph 6"], "all_classes": null, "length": 8232}
164 | {"pred": "Paragraph 19]", "answers": ["Paragraph 15"], "all_classes": null, "length": 9132}
165 | {"pred": "\nParagraph 17", "answers": ["Paragraph 14"], "all_classes": null, "length": 8557}
166 | {"pred": " Paragraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 9374}
167 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9024}
168 | {"pred": "\nParagraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 8492}
169 | {"pred": "Paragraph 1", "answers": ["Paragraph 5"], "all_classes": null, "length": 8944}
170 | {"pred": "\nParagraph 10", "answers": ["Paragraph 12"], "all_classes": null, "length": 8243}
171 | {"pred": "Paragraph 17", "answers": ["Paragraph 12"], "all_classes": null, "length": 8452}
172 | {"pred": "Paragraph 10", "answers": ["Paragraph 8"], "all_classes": null, "length": 8418}
173 | {"pred": "Paragraph 1", "answers": ["Paragraph 20"], "all_classes": null, "length": 9471}
174 | {"pred": "\nParagraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 8888}
175 | {"pred": "\nParagraph 2", "answers": ["Paragraph 23"], "all_classes": null, "length": 8088}
176 | {"pred": "Paragraph 30", "answers": ["Paragraph 30"], "all_classes": null, "length": 10318}
177 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 10754}
178 | {"pred": " Paragraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 8055}
179 | {"pred": "\nParagraph 17", "answers": ["Paragraph 30"], "all_classes": null, "length": 9471}
180 | {"pred": "Paragraph 10", "answers": ["Paragraph 9"], "all_classes": null, "length": 9415}
181 | {"pred": "Paragraph 17]", "answers": ["Paragraph 17"], "all_classes": null, "length": 10382}
182 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 10098}
183 | {"pred": "Paragraph 18]", "answers": ["Paragraph 10"], "all_classes": null, "length": 10972}
184 | {"pred": "Paragraph 27", "answers": ["Paragraph 26"], "all_classes": null, "length": 8937}
185 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8275}
186 | {"pred": "Paragraph 1", "answers": ["Paragraph 26"], "all_classes": null, "length": 9032}
187 | {"pred": "Paragraph 17", "answers": ["Paragraph 8"], "all_classes": null, "length": 7990}
188 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9969}
189 | {"pred": "\nParagraph 17", "answers": ["Paragraph 13"], "all_classes": null, "length": 9356}
190 | {"pred": "Paragraph 1", "answers": ["Paragraph 12"], "all_classes": null, "length": 9827}
191 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 8886}
192 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 7953}
193 | {"pred": "Paragraph 1", "answers": ["Paragraph 24"], "all_classes": null, "length": 9019}
194 | {"pred": "Paragraph 24]", "answers": ["Paragraph 25"], "all_classes": null, "length": 9668}
195 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9446}
196 | {"pred": "Paragraph 1]", "answers": ["Paragraph 1"], "all_classes": null, "length": 9333}
197 | {"pred": "Paragraph 25]", "answers": ["Paragraph 25"], "all_classes": null, "length": 9006}
198 | {"pred": "Paragraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 10134}
199 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8214}
200 | {"pred": "Paragraph 1]", "answers": ["Paragraph 4"], "all_classes": null, "length": 10149}
201 |
--------------------------------------------------------------------------------
/benchmarks/Pred_LongBench/result.json:
--------------------------------------------------------------------------------
1 | {
2 | "2wikimqa": 30.26,
3 | "passage_retrieval_en": 29.75,
4 | "passage_retrieval_zh": 3.96,
5 | "qasper": 29.1,
6 | "passage_count": 3.61,
7 | "gov_report": 31.53,
8 | "multifieldqa_zh": 8.48,
9 | "trec": 63.5,
10 | "multifieldqa_en": 37.15,
11 | "lsht": 26.0,
12 | "dureader": 15.25,
13 | "narrativeqa": 19.8,
14 | "lcc": 57.61,
15 | "musique": 17.14,
16 | "multi_news": 27.74,
17 | "qmsum": 24.13,
18 | "vcsum": 0.46,
19 | "samsum": 41.88,
20 | "repobench-p": 54.45,
21 | "triviaqa": 85.69,
22 | "hotpotqa": 37.01
23 | }
--------------------------------------------------------------------------------
/benchmarks/README.md:
--------------------------------------------------------------------------------
1 | # Evaluation on LongBench and L-Eval Benchmarks
2 |
3 | We evaluate our supervised fine-tuned model, [LongAlpaca-7B-16k](https://huggingface.co/Yukang/LongAlpaca-7B-16k), on LongBench and L-Eval benchmarks.
4 |
5 | Table 1 - Evaluation on LongBench English tasks
6 | | Model | Avg | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Code | Synthetic |
7 | | --- | --- | --- | --- | --- | --- | --- | --- |
8 | | GPT-3.5-Turbo | 44.0 | 39.8 | 38.7 | 26.5 | 67.1 | 54.1 | 37.8 |
9 | | Llama2-7B-chat | 31.0 | 24.9 | 22.6 | 24.7 | 60.0 | 48.1 | 5.9 |
10 | | Ours | 36.8 | 28.7 | 28.1 | 27.8 | 63.7 | 56.0 | 16.7 |
11 |
12 | The predictions can be found [here](https://github.com/dvlab-research/LongLoRA/tree/main/benchmarks/Pred_LongBench).
13 |
14 |
15 | Table 2 - Evaluation on L-Eval open-ended tasks, compared against GPT-3.5-Turbo, with win rates judged by GPT-4.
16 | | Model | Win-rate (%) | Wins | Ties |
17 | | --- | --- | --- | --- |
18 | | Ours | 39.06 | 45 | 60 |
19 |
20 | The predictions can be found [here](https://github.com/dvlab-research/LongLoRA/tree/main/benchmarks/Pred_L-Eval).
21 |
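22 | The per-category numbers in Table 1 are averages of the task-level scores in `Pred_LongBench/result.json`. Below is a minimal sketch (for illustration only) that recomputes them; it assumes the standard LongBench English task grouping and that the overall average is taken over the six categories.
23 | 
24 | ```python
25 | import json
26 | 
27 | # Assumed LongBench English task grouping; it matches the columns of Table 1.
28 | CATEGORIES = {
29 |     "Single-Doc QA": ["narrativeqa", "qasper", "multifieldqa_en"],
30 |     "Multi-Doc QA": ["hotpotqa", "2wikimqa", "musique"],
31 |     "Summarization": ["gov_report", "qmsum", "multi_news"],
32 |     "Few-shot Learning": ["trec", "triviaqa", "samsum"],
33 |     "Code": ["lcc", "repobench-p"],
34 |     "Synthetic": ["passage_count", "passage_retrieval_en"],
35 | }
36 | 
37 | with open("Pred_LongBench/result.json") as f:
38 |     scores = json.load(f)
39 | 
40 | category_avg = {
41 |     name: sum(scores[task] for task in tasks) / len(tasks)
42 |     for name, tasks in CATEGORIES.items()
43 | }
44 | for name, avg in category_avg.items():
45 |     print(f"{name}: {avg:.1f}")
46 | print(f"Avg: {sum(category_avg.values()) / len(category_avg):.1f}")
47 | ```
48 | 
49 | On the scores above this reproduces the "Ours" row of Table 1, e.g. Code = (57.61 + 54.45) / 2 ≈ 56.0.
50 | 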
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import argparse
6 | import textwrap
7 | import transformers
8 | from peft import PeftModel
9 | from transformers import GenerationConfig, TextIteratorStreamer
10 | from llama_attn_replace import replace_llama_attn
11 | from threading import Thread
12 | import gradio as gr
13 |
14 |
15 | def parse_config():
16 | parser = argparse.ArgumentParser(description='arg parser')
17 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
18 | parser.add_argument('--cache_dir', type=str, default="./cache")
19 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
20 | parser.add_argument('--flash_attn', type=bool, default=True, help='')
21 | parser.add_argument('--temperature', type=float, default=0.6, help='')
22 | parser.add_argument('--top_p', type=float, default=0.9, help='')
23 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
24 | parser.add_argument("--host", type=str, default="localhost")
25 | parser.add_argument("--port", type=int, default=8898)
26 | args = parser.parse_args()
27 | return args
28 |
29 | title = "LongLoRA and LongAlpaca for Long-context LLMs"
30 |
31 | description = """
32 |
33 | This is the online demo of LongLoRA. \n
34 | If multiple users are using it at the same time, they will enter a queue, which may cause some delay. \n
35 | **Inputs**:
36 | - **Input material txt** and **Question** are required.
37 | **Note**:
38 | - The demo model is **LongAlpaca-7B**. We use 4-bit quantization for low GPU memory inference, which may impair text-generation quality.
39 | - There are 10 book-related examples and 5 paper-related examples, 15 in total.
40 | - Note that only txt files are currently supported.\n
41 | **Example questions**:
42 | Please summarize the book in one paragraph.
43 | Please tell me what high-level idea the author wants to convey in this book.
44 | Please describe the relationships among the characters in the book.
45 | Please summarize the paper in one paragraph.
46 | What is the main contribution of this paper?
47 | We hope you enjoy our work!
48 |
49 | """
50 |
51 | # Gradio
52 | article = """
53 |
54 |
55 | Preprint Paper
56 |
57 | \n
58 |
59 | Github Repo
60 | """
61 |
62 | PROMPT_DICT = {
63 | "prompt_no_input": (
64 | "Below is an instruction that describes a task. "
65 | "Write a response that appropriately completes the request.\n\n"
66 | "### Instruction:\n{instruction}\n\n### Response:"
67 | ),
68 | "prompt_no_input_llama2":(
69 |         "[INST] <<SYS>>\n"
70 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
71 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
72 |         "<</SYS>> \n\n {instruction} [/INST]"
73 | ),
74 | }
75 |
76 |
77 | def read_txt_file(material_txt):
78 | content = ""
79 | with open(material_txt) as f:
80 | for line in f.readlines():
81 | content += line
82 | return content
83 |
84 | def build_generator(
85 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
86 | ):
87 | def response(material, question):
88 | if material is None:
89 |             return "Please upload a txt file."
90 | 
91 |         if material.name.split(".")[-1] != 'txt':
92 |             return "Only txt files are supported."
93 |
94 | material = read_txt_file(material.name)
95 | prompt_no_input = PROMPT_DICT["prompt_no_input_llama2"]
96 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s" % question})
97 |
98 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
99 |
100 | if len(inputs['input_ids'][0]) > 32768:
101 |             return "This demo supports at most 32768 tokens, while the current input has %d. Please use material with fewer tokens." % len(inputs['input_ids'][0])
102 | torch.cuda.empty_cache()
103 |
104 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
105 | generate_kwargs = dict(**inputs,
106 | max_new_tokens=max_gen_len,
107 | temperature=temperature,
108 | top_p=top_p,
109 | use_cache=use_cache,
110 | streamer=streamer,
111 | )
112 |
113 | t = Thread(target=model.generate, kwargs=generate_kwargs)
114 | t.start()
115 |
116 | generated_text = ""
117 | for new_text in streamer:
118 | generated_text += new_text
119 | yield generated_text
120 | return generated_text
121 |
122 | return response
123 |
124 | def main(args):
125 | if args.flash_attn:
126 | replace_llama_attn(inference=True)
127 |
128 | # Set RoPE scaling factor
129 | config = transformers.AutoConfig.from_pretrained(
130 | args.base_model,
131 | cache_dir=args.cache_dir,
132 | )
133 |
134 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
135 | if orig_ctx_len and args.context_size > orig_ctx_len:
136 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
137 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
138 |
139 | # Load model and tokenizer
140 | model = transformers.AutoModelForCausalLM.from_pretrained(
141 | args.base_model,
142 | config=config,
143 | cache_dir=args.cache_dir,
144 | torch_dtype=torch.float16,
145 | load_in_4bit=True,
146 | device_map="auto",
147 | )
148 |     model.resize_token_embeddings(32001)  # base LLaMA vocab (32000) plus the added [PAD] token
149 |
150 | tokenizer = transformers.AutoTokenizer.from_pretrained(
151 | args.base_model,
152 | cache_dir=args.cache_dir,
153 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
154 | padding_side="right",
155 | use_fast=False,
156 | )
157 |
158 | model.eval()
159 | if torch.__version__ >= "2" and sys.platform != "win32":
160 | model = torch.compile(model)
161 | # import pdb; pdb.set_trace()
162 | respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p,
163 | max_gen_len=args.max_gen_len, use_cache=True)
164 |
165 | demo = gr.Interface(
166 | respond,
167 | inputs=[
168 | gr.File(type="file", label="Input material txt"),
169 | gr.Textbox(lines=1, placeholder=None, label="Question"),
170 | ],
171 | outputs=[
172 | gr.Textbox(lines=1, placeholder=None, label="Text Output"),
173 | ],
174 | title=title,
175 | description=description,
176 | article=article,
177 | allow_flagging="auto",
178 | )
179 |
180 | demo.queue()
181 | demo.launch(server_name=args.host, server_port=args.port, show_error=True, share=True)
182 |
183 | if __name__ == "__main__":
184 | args = parse_config()
185 | main(args)
186 |
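187 | # Illustrative launch command (not from the repo docs; the model path is a placeholder and the
188 | # remaining flags keep the defaults defined in parse_config above; set --context_size to the
189 | # context length used during fine-tuning):
190 | #
191 | #   python demo.py --base_model /path/to/LongAlpaca-7B --port 8898
192 | #
193 | # The script then serves the Gradio interface at http://localhost:8898.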
--------------------------------------------------------------------------------
/ds_configs/stage2.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "gradient_accumulation_steps": "auto",
4 | "gradient_clipping": "auto",
5 | "zero_allow_untested_optimizer": true,
6 | "bf16": {
7 | "enabled": "auto",
8 | "loss_scale": 0,
9 | "initial_scale_power": 16,
10 | "loss_scale_window": 1000,
11 | "hysteresis": 2,
12 | "min_loss_scale": 1
13 | },
14 | "zero_optimization": {
15 | "stage": 2,
16 | "allgather_partitions": true,
17 | "allgather_bucket_size": 1e9,
18 | "reduce_scatter": true,
19 | "reduce_bucket_size": 1e9,
20 | "overlap_comm": true,
21 | "contiguous_gradients": true
22 | }
23 | }
24 |
--------------------------------------------------------------------------------
/ds_configs/stage3.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": false
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
50 |
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import math
18 | import torch
19 | import argparse
20 | import random
21 | import numpy as np
22 | from tqdm import tqdm
23 | import transformers
24 | from peft import PeftModel
25 | from llama_attn_replace import replace_llama_attn
26 |
27 | def parse_config():
28 | parser = argparse.ArgumentParser(description='arg parser')
29 | parser.add_argument('--batch_size', type=int, default=32, help='batch size during inference')
30 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
31 | parser.add_argument('--cache_dir', type=str, default="./cache")
32 | parser.add_argument('--seq_len', type=int, default=2048, help='context length during evaluation')
33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
34 | parser.add_argument('--peft_model', type=str, default=None, help='')
35 | parser.add_argument('--flash_attn', type=bool, default=True, help='')
36 | parser.add_argument('--data_path', type=str, default="./test.bin", help='')
37 | args = parser.parse_args()
38 | return args
39 |
40 | def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256):
41 | all_ix = list(range(0, len(data) - seq_length, sliding_window))
42 |     all_ix.pop()  # drop the last window so the shifted targets (i+1 ... i+seq_length) stay in range
43 |
44 | for idx in range(0, len(all_ix), batch_size):
45 | ix = all_ix[idx:idx+batch_size]
46 | assert all([idx + seq_length + 1 <= len(data) for idx in ix])
47 | x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix])
48 | y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix])
49 | if device != 'cpu':
50 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
51 | yield x, y
52 |
53 | def iceildiv(x, y):
54 | return (x + y - 1) // y
55 |
56 | def evaluate(model, data, batch_size, device, seq_length, sliding_window=256, use_cache=False):
57 | stats = {}
58 |
59 | model.eval()
60 |
61 | loss_list_val, acc_list = [], []
62 | loss_step_list_val = []
63 |
64 | with torch.no_grad():
65 | print(f"Using seq length {seq_length}")
66 | torch.set_printoptions(sci_mode=False)
67 | for idx, (x, y) in tqdm(
68 | enumerate(
69 | get_as_batch(
70 | data['val'],
71 | seq_length,
72 | batch_size,
73 | device=device,
74 | sliding_window=sliding_window
75 | )
76 | ),
77 | total=iceildiv(
78 | iceildiv(len(data['val']), sliding_window),
79 | batch_size
80 | )
81 | ):
82 | val_loss = 0.
83 | acc = 0.
84 | cnt = 0
85 |
86 | for part_idx, i in enumerate(range(0, x.shape[1], seq_length)):
87 | part_len = x[:, i:i + seq_length].shape[1]
88 |
89 | outputs = model(
90 | input_ids=x[:, i:i + seq_length],
91 | labels=x[:, i:i+seq_length].contiguous(),
92 | use_cache=use_cache)
93 |
94 | val_loss = outputs.loss * part_len + val_loss
95 | acc = ((outputs.logits.argmax(-1) == y[:, i:i+seq_length]).float().sum()) + acc
96 | cnt += part_len
97 | while len(loss_step_list_val) <= part_idx:
98 | loss_step_list_val.append([])
99 | loss_step_list_val[part_idx].append(outputs.loss.item())
100 | val_loss /= cnt
101 | acc /= cnt
102 |
103 | loss_list_val.append(val_loss.item())
104 | acc_list.append(acc.item())
105 |
106 | stats['val_acc'] = torch.as_tensor(acc_list).mean().item()
107 | stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item()
108 |         stats['val_perplexity'] = math.exp(stats['val_loss'])  # e**loss, consistent with the per-chunk torch.exp below
109 | stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1))
110 |
111 | return stats
112 |
113 | def main(args):
114 |
115 | device = "cuda:0"
116 | seed = 2
117 | torch.cuda.set_device(device)
118 |
119 | torch.manual_seed(seed)
120 | random.seed(seed)
121 | np.random.seed(seed)
122 |
123 | data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')}
124 |
125 | print(f"Num validation tokens: {len(data['val'])}")
126 | print("data path", args.data_path)
127 | print("base model", args.base_model)
128 | print("peft model", args.peft_model)
129 |
130 | if args.flash_attn:
131 | replace_llama_attn(use_flash_attn=True, use_full=True)
132 |
133 | # Set RoPE scaling factor
134 | config = transformers.AutoConfig.from_pretrained(
135 | args.base_model,
136 | cache_dir=args.cache_dir,
137 | )
138 |
139 | context_size = args.context_size if args.context_size > 0 else args.seq_len
140 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
141 | if orig_ctx_len and context_size > orig_ctx_len:
142 | scaling_factor = float(math.ceil(context_size / orig_ctx_len))
143 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
144 |
145 | # Load model and tokenizer
146 | model = transformers.AutoModelForCausalLM.from_pretrained(
147 | args.base_model,
148 | config=config,
149 | cache_dir=args.cache_dir,
150 | torch_dtype=torch.float16,
151 | device_map="auto",
152 | )
153 | model.resize_token_embeddings(32001)
154 |
155 | if args.peft_model:
156 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
157 | if os.path.isfile(trainable_params):
158 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
159 | else:
160 | raise ValueError("Trainable input embedding and normalization are required.")
161 | model = PeftModel.from_pretrained(
162 | model,
163 | args.peft_model,
164 | device_map="auto",
165 | torch_dtype=torch.float16,
166 | )
167 |
168 | stats = evaluate(model, data, args.batch_size, device, args.seq_len, sliding_window=256)
169 |
170 | print(stats)
171 |
172 |
173 | if __name__ == "__main__":
174 | args = parse_config()
175 | main(args)
176 |
--------------------------------------------------------------------------------
/eval_distributed.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from dataclasses import dataclass, field
18 | from typing import Optional
19 |
20 | import math
21 | import random
22 | import transformers
23 | from peft import PeftModel
24 |
25 | from llama_attn_replace import replace_llama_attn
26 | from torch.distributed import init_process_group, destroy_process_group
27 | from torchmetrics import Accuracy
28 | from torchmetrics.text import Perplexity
29 | from torch.nn import CrossEntropyLoss
30 |
31 | import inspect
32 | from abc import ABC, abstractmethod
33 | from typing import Union
34 |
35 | from torch.utils.data import Dataset, DataLoader, DistributedSampler
36 | from transformers.modeling_utils import PreTrainedModel
37 | from torch import nn
38 | from torch.nn.parallel import DistributedDataParallel as DDP
39 | from tqdm import tqdm
40 |
41 |
42 | import numpy as np
43 | import torch
44 |
45 |
46 | class Pg19Dataset(Dataset):
47 | def __init__(self, data_path: str, seq_length: int, sliding_window: int = 256):
48 | assert seq_length >= sliding_window, f"Sliding window '{sliding_window}' must be smaller than sequence length '{seq_length}'"
49 |
50 | self.seq_length = seq_length
51 | self.data = np.memmap(data_path, dtype=np.uint16, mode='r')
52 | self.start_indices = list(range(0, len(self.data) - seq_length, sliding_window))
53 |
54 | assert len(self) > 0, "Dataset is empty"
55 |
56 | def __len__(self):
57 | return len(self.start_indices)
58 | # return 1000
59 |
60 | def __getitem__(self, index) -> dict[str, torch.Tensor]:
61 | start = self.start_indices[index]
62 | end = start + self.seq_length
63 |
64 | input_id = torch.from_numpy(self.data[start: end].astype(np.int64))
65 | y = torch.from_numpy(self.data[start+1: end+1].astype(np.int64))
66 | return {
67 | "input_ids": input_id,
68 | "labels": input_id,
69 | "ys": y
70 | }
71 |
72 | def num_tokens(self):
73 | return len(self.data)
74 |
75 |
76 | class EvalMetric(ABC):
77 | @abstractmethod
78 | def add(self, logits: torch.FloatTensor, labels: torch.LongTensor, model_output: object) -> dict[str, object]:
79 | pass
80 |
81 | @abstractmethod
82 | def compute(self) -> dict[str, object]:
83 | pass
84 |
85 |
86 | class DistributedEvaluator:
87 | def __init__(self,
88 | model: Union[PreTrainedModel, nn.Module],
89 | batch_size: int,
90 | refresh_rate: int,
91 | gpu_id: int):
92 | self.gpu_id = gpu_id
93 | self.batch_size = batch_size
94 | self.refresh_rate = refresh_rate
95 |
96 | self.model = DDP(model, device_ids=[self.gpu_id])
97 |
98 | def evaluate(self, dataset: Dataset, metric: EvalMetric) -> dict[str, object]:
99 | data_loader = self._prepare_dataloader(dataset)
100 | self.model.eval()
101 | with torch.no_grad():
102 | if self.is_first_device():
103 | data_loader = tqdm(data_loader)
104 | for i, example_dict in enumerate(data_loader):
105 | sig = inspect.signature(self.model.forward)
106 | used = set(list(sig.parameters.keys()) + ["input_ids", "labels"])
107 | inputs = {key: example_dict[key].to(self.gpu_id) for key in used if key in example_dict}
108 | outputs = self.model(**inputs)
109 | metric_result = metric.add(logits=outputs["logits"], labels=inputs["labels"], model_output=outputs)
110 |
111 | if self.is_first_device() and (i % self.refresh_rate == 0):
112 | data_loader.set_postfix(metric_result)
113 | return metric.compute()
114 |
115 | def is_first_device(self):
116 | return self.gpu_id == 0
117 |
118 | def _prepare_dataloader(self, dataset: Dataset):
119 | return DataLoader(
120 | dataset,
121 | batch_size=self.batch_size,
122 | pin_memory=True,
123 | shuffle=False,
124 | sampler=DistributedSampler(dataset)
125 | )
126 |
127 |
128 | class EvalMetricImpl(EvalMetric):
129 | def __init__(self, vocab_size: int, gpu_id: int):
130 | self.accuracy = Accuracy(task="multiclass", num_classes=vocab_size).to(gpu_id)
131 | self.perplexity = Perplexity(ignore_index=CrossEntropyLoss().ignore_index).to(gpu_id)
132 | self.last_loss = 0.0
133 |
134 | def add(self, logits: torch.FloatTensor, labels: torch.LongTensor, model_output: object) -> dict[str, object]:
135 | shift_predictions = logits.argmax(dim=-1)[..., :-1]
136 | shift_labels = labels[..., 1:]
137 |
138 | current_accuracy = self.accuracy.forward(preds=shift_predictions, target=shift_labels)
139 |
140 | shift_logits = logits[..., :-1, :]
141 | current_perplexity = self.perplexity.forward(preds=shift_logits, target=shift_labels)
142 |
143 | self.last_loss = model_output["loss"].item()
144 | return {
145 | "accuracy": current_accuracy.item(),
146 | "perplexity": current_perplexity.item(),
147 | "loss": self.last_loss
148 | }
149 |
150 | def compute(self) -> dict[str, object]:
151 | current_accuracy = self.accuracy.compute()
152 | current_perplexity = self.perplexity.compute()
153 | return {
154 | "accuracy": current_accuracy.item(),
155 | "perplexity": current_perplexity.item(),
156 | "loss": self.last_loss
157 | }
158 |
159 |
160 | @dataclass
161 | class EvalArguments:
162 | batch_size: int = field(
163 | default=1,
164 | metadata={"help": "batch size."},
165 | )
166 | base_model: Optional[str] = field(default="meta-llama/Llama-2-7b-hf")
167 | seq_len: int = field(
168 | default=2048,
169 | metadata={"help": "context length during evaluation."},
170 | )
171 | context_size: int = field(
172 | default=-1,
173 | metadata={"help": "context size during fine-tuning."},
174 | )
175 | peft_model: Optional[str] = field(default=None)
176 | flash_attn: bool = field(
177 | default=True,
178 | metadata={"help": "Whether use flash attention."},
179 | )
180 | data_path: str = field(
181 | default="./test.bin",
182 | metadata={"help": "test data path"},
183 | )
184 | cache_dir: Optional[str] = field(default="./.cache")
185 | progress_bar_fresh_rate: int = field(
186 | default=10,
187 | metadata={"help": "progress bar metrics fresh rate."},
188 | )
189 |
190 |
191 | def run_eval(args: EvalArguments):
192 | torch_dtype = torch.float16
193 |
194 | seed = 2
195 | torch.manual_seed(seed)
196 | random.seed(seed)
197 | np.random.seed(seed)
198 |
199 | dataset = Pg19Dataset(args.data_path, seq_length=args.seq_len, sliding_window=256)
200 | if args.flash_attn:
201 | replace_llama_attn(use_flash_attn=True, use_full=True)
202 |
203 | # Set RoPE scaling factor
204 | config = transformers.AutoConfig.from_pretrained(
205 | args.base_model,
206 | cache_dir=args.cache_dir,
207 | use_cache=False
208 | )
209 |
210 | context_size = args.context_size if args.context_size > 0 else args.seq_len
211 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
212 | if orig_ctx_len and context_size > orig_ctx_len:
213 | scaling_factor = float(math.ceil(context_size / orig_ctx_len))
214 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
215 |
216 | # Load model and tokenizer
217 | model = transformers.AutoModelForCausalLM.from_pretrained(
218 | args.base_model,
219 | config=config,
220 | cache_dir=args.cache_dir,
221 | torch_dtype=torch_dtype)
222 | model.resize_token_embeddings(32001)
223 |
224 | if args.peft_model:
225 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
226 | if os.path.isfile(trainable_params):
227 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
228 | else:
229 | raise ValueError("Trainable input embedding and normalization are required.")
230 | model = PeftModel.from_pretrained(
231 | model,
232 | args.peft_model,
233 | torch_dtype=torch_dtype,
234 | offload_folder=args.cache_dir,
235 | )
236 |
237 | # This is a hacky way to enable distributed evaluation. Otherwise, without any trainable parameters, we will not
238 | # be able to use DistributedDataParallel, although we don't update any parameters during evaluation.
239 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in ["lm_head"]])]
240 |
241 | gpu_id = int(os.environ["LOCAL_RANK"])
242 | model.to(gpu_id)
243 |
244 | evaluator = DistributedEvaluator(
245 | model=model,
246 | batch_size=args.batch_size,
247 | refresh_rate=args.progress_bar_fresh_rate,
248 | gpu_id=gpu_id)
249 |
250 | if evaluator.is_first_device():
251 | print("data path", args.data_path)
252 | print("base model", args.base_model)
253 | print("peft model", args.peft_model)
254 | print(f"Num validation tokens: {dataset.num_tokens()}, Num validation examples: {len(dataset)}")
255 |
256 | eval_metric = EvalMetricImpl(vocab_size=config.vocab_size, gpu_id=gpu_id)
257 | result = evaluator.evaluate(dataset, eval_metric)
258 | if evaluator.is_first_device():
259 | print(result)
260 |
261 |
262 | def ddp_setup():
263 | init_process_group(backend="nccl")
264 |
265 |
266 | def main(cmd_args: list[str] = None):
267 | ddp_setup()
268 | parser = transformers.HfArgumentParser((EvalArguments, ))
269 | args: EvalArguments = parser.parse_args_into_dataclasses(cmd_args)[0]
270 | try:
271 | run_eval(args)
272 | finally:
273 | destroy_process_group()
274 |
275 |
276 | if __name__ == "__main__":
277 | main()
278 |
--------------------------------------------------------------------------------
/fine-tune.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import math
18 | from dataclasses import dataclass, field
19 | from functools import partial
20 | from typing import Dict, Optional, Sequence
21 |
22 | import torch
23 | import transformers
24 | from torch.utils.data import Dataset
25 | from transformers import Trainer, DataCollatorForLanguageModeling
26 | from llama_attn_replace import replace_llama_attn
27 | from gptneox_attn_replace import replace_gpt_neox_attn
28 | from peft import LoraConfig, get_peft_model
29 | from torch.distributed import barrier
30 |
31 |
32 | from datasets import load_dataset
33 |
34 | IGNORE_INDEX = -100
35 | DEFAULT_PAD_TOKEN = "[PAD]"
36 | DEFAULT_EOS_TOKEN = "</s>"
37 | DEFAULT_BOS_TOKEN = "<s>"
38 | DEFAULT_UNK_TOKEN = "<unk>"
39 |
40 |
41 | @dataclass
42 | class ModelArguments:
43 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
44 | model_type: Optional[str] = field(default="llama")
45 |
46 | @dataclass
47 | class TrainingArguments(transformers.TrainingArguments):
48 | cache_dir: Optional[str] = field(default=None)
49 | optim: str = field(default="adamw_torch")
50 | model_max_length: int = field(
51 | default=8192 * 4,
52 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
53 | )
54 | use_flash_attn: bool = field(
55 | default=True,
56 | metadata={"help": "Whether use flash attention for training."},
57 | )
58 | use_full_attn: bool = field(
59 | default=False,
60 | metadata={"help": "Whether to use plain, full-attention for training."},
61 | )
62 | low_rank_training: bool = field(
63 | default=True,
64 | metadata={"help": "Whether use low rank adaptation for training."},
65 | )
66 | trainable_params: str = field(
67 | default="embed,norm",
68 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
69 | )
70 |
71 | def smart_tokenizer_and_embedding_resize(
72 | special_tokens_dict: Dict,
73 | tokenizer: transformers.PreTrainedTokenizer,
74 | model: transformers.PreTrainedModel,
75 | ):
76 | """Resize tokenizer and embedding.
77 |
78 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
79 | """
80 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
81 | model.resize_token_embeddings(len(tokenizer))
82 |
83 | if num_new_tokens > 0:
84 | input_embeddings = model.get_input_embeddings().weight.data
85 | output_embeddings = model.get_output_embeddings().weight.data
86 |
87 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
88 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
89 |
90 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
91 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
92 |
93 | def tokenize_fn(tokenizer, example):
94 | context_length = tokenizer.model_max_length
95 | outputs = tokenizer(
96 | tokenizer.eos_token.join(example["text"]),
97 | truncation=False,
98 | return_tensors="pt",
99 | pad_to_multiple_of=context_length,
100 | padding=True,
101 | )
102 | return {"input_ids": outputs["input_ids"].view(-1, context_length)}
103 |
104 | def train():
105 | parser = transformers.HfArgumentParser((ModelArguments, TrainingArguments))
106 | model_args, training_args = parser.parse_args_into_dataclasses()
107 |
108 | # NOTE: May expand supported model types in the future
109 | if model_args.model_type == "gpt-neox":
110 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
111 | else:
112 | assert model_args.model_type == "llama", "Only support llama and gpt-neox for now"
113 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
114 |
115 | # Set RoPE scaling factor
116 | config = transformers.AutoConfig.from_pretrained(
117 | model_args.model_name_or_path,
118 | cache_dir=training_args.cache_dir,
119 | )
120 |
121 | orig_rope_scaling = getattr(config, "rope_scaling", None)
122 | if orig_rope_scaling is None:
123 | orig_rope_scaling = {"factor": 1}
124 |
125 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
126 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
127 | if orig_ctx_len:
128 | orig_ctx_len *= orig_rope_scaling_factor
129 | if training_args.model_max_length > orig_ctx_len:
130 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
131 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
132 |
133 | # Load model and tokenizer
134 | model = transformers.AutoModelForCausalLM.from_pretrained(
135 | model_args.model_name_or_path,
136 | config=config,
137 | cache_dir=training_args.cache_dir,
138 | torch_dtype=torch.bfloat16,
139 | )
140 |
141 | tokenizer = transformers.AutoTokenizer.from_pretrained(
142 | model_args.model_name_or_path,
143 | cache_dir=training_args.cache_dir,
144 | model_max_length=training_args.model_max_length,
145 | padding_side="right",
146 | use_fast=True,
147 | )
148 |
149 | special_tokens_dict = dict()
150 | if tokenizer.pad_token is None:
151 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
152 | if tokenizer.eos_token is None:
153 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
154 | if tokenizer.bos_token is None:
155 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
156 | if tokenizer.unk_token is None:
157 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
158 |
159 | smart_tokenizer_and_embedding_resize(
160 | special_tokens_dict=special_tokens_dict,
161 | tokenizer=tokenizer,
162 | model=model,
163 | )
164 |
165 | rank = int(os.environ.get('RANK', -1))
166 | if rank > 0:
167 | barrier()
168 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir)
169 | dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128, remove_columns=["text", "meta"])
170 |
171 | if rank == 0:
172 | barrier()
173 |
174 | print(dataset)
175 |
176 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
177 |
178 | if training_args.low_rank_training:
179 | if model_args.model_type == "gpt-neox":
180 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
181 | targets = ["query_key_value", "dense"]
182 | else:
183 | targets=["q_proj", "k_proj", "v_proj", "o_proj"]
184 |
185 | config = LoraConfig(
186 | r=8,
187 | lora_alpha=16,
188 | target_modules=targets,
189 | lora_dropout=0,
190 | bias="none",
191 | task_type="CAUSAL_LM",
192 | )
193 | model = get_peft_model(model, config)
194 | # enable trainable params
195 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
196 |
197 | model.config.use_cache = False # required for gradient checkpointing
198 | model.enable_input_require_grads() # required for gradient checkpointing
199 | model.gradient_checkpointing_enable() # enable gradient checkpointing
200 | trainer = Trainer(
201 | model=model, tokenizer=tokenizer, args=training_args,
202 | train_dataset=dataset["train"],
203 | eval_dataset=None,
204 | data_collator=data_collator)
205 | trainer.train()
206 | trainer.save_state()
207 | trainer.save_model(output_dir=training_args.output_dir)
208 |
209 |
210 | if __name__ == "__main__":
211 | train()
212 |
--------------------------------------------------------------------------------
/get_trainable_weights.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import torch
17 | import argparse
18 |
19 | def parse_config():
20 | parser = argparse.ArgumentParser(description='arg parser')
21 | parser.add_argument('--checkpoint_path', type=str, default="/dataset/models/checkpoint-1000")
22 | parser.add_argument('--trainable_params', type=str, default="embed,norm")
23 | args = parser.parse_args()
24 | return args
25 |
26 |
27 | def main(args):
28 | path = args.checkpoint_path
29 | trainable_params = args.trainable_params.split(",")
30 |
31 | weights_all = torch.load(os.path.join(path, "pytorch_model.bin"))
32 |
33 | weights_trainable = {}
34 | weights_lora = {}
35 | for k in weights_all:
36 | if "lora" in k:
37 | k_new = k.replace("default.", "") if "default." in k else k
38 | weights_lora[k_new] = weights_all[k]
39 | else:
40 | if any([n in k for n in trainable_params]):
41 |                 weights_trainable[k[17:]] = weights_all[k]  # strip the "base_model.model." prefix (17 chars) that the PEFT wrapper adds to base-model keys
42 |
43 | adapter_model = os.path.join(path, "adapter_model.bin")
44 | trainable_params = os.path.join(path, "trainable_params.bin")
45 | if not os.path.isfile(adapter_model):
46 | torch.save(weights_lora, adapter_model)
47 | torch.save(weights_trainable, trainable_params)
48 |
49 | if __name__ == "__main__":
50 | args = parse_config()
51 | main(args)
52 |
--------------------------------------------------------------------------------
/gptneox_attn_replace.py:
--------------------------------------------------------------------------------
1 | # Modified based on https://github.com/dvlab-research/LongLoRA
2 |
3 | from typing import Optional, Tuple
4 | import warnings
5 | import torch
6 | import transformers
7 |
8 | from einops import rearrange
9 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_func
10 | from flash_attn.bert_padding import unpad_input, pad_input
11 |
12 |
13 | group_size_ratio = 1/4
14 |
15 | def rotate_half(x):
16 | """Rotates half the hidden dims of the input."""
17 | x1 = x[..., : x.shape[-1] // 2]
18 | x2 = x[..., x.shape[-1] // 2 :]
19 | return torch.cat((-x2, x1), dim=-1)
20 |
21 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
22 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
23 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
24 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices)
25 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices)
26 | q_embed = (q * cos) + (rotate_half(q) * sin)
27 | k_embed = (k * cos) + (rotate_half(k) * sin)
28 | return q_embed, k_embed
29 |
30 |
31 | def _flash_attn_ssa(query, key, value, attention_mask=None, head_mask=None):
32 | # transform the data into the qkv packed form
33 | qkv = torch.stack(
34 | [query, key, value], dim=2
35 | ) # [bsz, nh, 3, q_len, hd]
36 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
37 | bsz, q_len = qkv.shape[:2]
38 |
39 | qkv = rearrange(qkv, "b s ... -> (b s) ...")
40 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device)
41 | output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True)
42 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
43 |
44 | # disable attn weights by returning None when using flash attention
45 | return output, None
46 |
47 | def _flash_attn_full(query, key, value, attention_mask=None, head_mask=None):
48 | # q, k, v: [bs, nh, seq_len, hd]
49 | batch_size, num_attention_heads, query_length, attn_head_size = query.size()
50 | key_length = key.size(-2)
51 | value_length = value.size(-2)
52 |
53 | # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd]
54 | query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size)
55 | key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size)
56 | value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size)
57 |
58 | cu_seqlens_q = torch.arange(
59 | 0,
60 | (batch_size + 1) * query_length,
61 | step=query_length,
62 | dtype=torch.int32,
63 | device=query.device,
64 | )
65 |
66 | cu_seqlens_k = torch.arange(
67 | 0,
68 | (batch_size + 1) * key_length,
69 | step=key_length,
70 | dtype=torch.int32,
71 | device=key.device,
72 | )
73 |
74 | attn_output, attn_weights, _ = flash_attn_varlen_func(
75 | query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=0.0,
76 | softmax_scale=None, causal=True, return_attn_probs=True
77 | )
78 |
79 | attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2)
80 | return attn_output, attn_weights
81 |
82 |
83 | def get_forward_function(use_flash_attn=True, use_full=False):
84 |
85 | def forward_attention(
86 | self,
87 | hidden_states: torch.FloatTensor,
88 | attention_mask: torch.FloatTensor,
89 | position_ids: torch.LongTensor,
90 | head_mask: Optional[torch.FloatTensor] = None,
91 | layer_past: Optional[Tuple[torch.Tensor]] = None,
92 | use_cache: Optional[bool] = False,
93 | output_attentions: Optional[bool] = False,
94 | ):
95 | # NOTE: compute SS group size
96 | bsz, q_len, _ = hidden_states.size()
97 | has_layer_past = layer_past is not None
98 |
99 | # Compute QKV
100 | # Attention heads [batch, seq_len, hidden_size]
101 | # --> [batch, seq_len, (np * 3 * head_size)]
102 | qkv = self.query_key_value(hidden_states)
103 |
104 | # [batch, seq_len, (num_heads * 3 * head_size)]
105 | # --> [batch, seq_len, num_heads, 3 * head_size]
106 | new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
107 | qkv = qkv.view(*new_qkv_shape)
108 |
109 | # [batch, seq_len, num_attention_heads, 3 * head_size]
110 | # --> 3 [batch, num_attention_heads, seq_len, head_size]
111 | query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
112 | key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
113 | value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
114 | # [bsz, nh, q_len, hd]
115 |
116 | # Compute rotary embeddings on rotary_ndims
117 | query_rot = query[..., : self.rotary_ndims]
118 | query_pass = query[..., self.rotary_ndims :]
119 | key_rot = key[..., : self.rotary_ndims]
120 | key_pass = key[..., self.rotary_ndims :]
121 |
122 | # Compute token offset for rotary embeddings (when decoding)
123 | seq_len = key.shape[-2]
124 | if has_layer_past:
125 | seq_len += layer_past[0].shape[-2]
126 | cos, sin = self.rotary_emb(value, seq_len=seq_len)
127 | query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
128 | query = torch.cat((query, query_pass), dim=-1)
129 | key = torch.cat((key, key_pass), dim=-1)
130 |
131 | # Cache QKV values
132 | if has_layer_past:
133 | past_key = layer_past[0]
134 | past_value = layer_past[1]
135 | key = torch.cat((past_key, key), dim=-2)
136 | value = torch.cat((past_value, value), dim=-2)
137 | present = (key, value) if use_cache else None
138 |
139 | # NOTE: apply shift
140 | group_size = int(q_len * group_size_ratio)
141 | if q_len % group_size > 0:
142 | raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size))
143 | num_group = q_len // group_size
144 | if self.training and not use_full:
145 | def shift(qkv, num_heads, head_dim):
146 | # qkv = [bsz, nh, q_len, d]
147 | qkv = qkv.transpose(1, 2)
148 | # qkv = [bsz, q_len, nh, d]
149 | qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1)
150 |
151 | # -> [bsz * n_group, group_s, nh, d)
152 | # -> [bsz * n_group, nh, group_s, d)
153 | qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2)
154 | return qkv
155 |
156 | # contiguous is required as self._attn() will attempt to apply .view() on them
157 | query = shift(query, self.num_attention_heads, self.head_size).contiguous()
158 | key = shift(key, self.num_attention_heads, self.head_size).contiguous()
159 | value = shift(value, self.num_attention_heads, self.head_size).contiguous()
160 |
161 | attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1)
162 |
163 | # Compute attention
164 | if use_flash_attn:
165 | _flash_attn = _flash_attn_full if use_full else _flash_attn_ssa
166 | attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask)
167 | else:
168 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
169 |
170 | # NOTE: shift back
171 | if self.training and not use_full:
172 | attn_output = attn_output.transpose(1, 2).contiguous()
173 | attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size)
174 | # [bsz, q_len, nh, hd]
175 | attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1)
176 | attn_output = attn_output.transpose(1, 2)
177 |
178 | # Reshape outputs
179 | attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size)
180 | attn_output = self.dense(attn_output)
181 |
182 | outputs = (attn_output, present)
183 | if output_attentions:
184 | outputs += (attn_weights,)
185 |
186 | return outputs
187 |
188 | return forward_attention
189 |
190 |
191 | def replace_gpt_neox_attn(use_flash_attn=True, use_full=False):
192 | cuda_major, cuda_minor = torch.cuda.get_device_capability()
193 | if use_flash_attn and cuda_major < 8:
194 | warnings.warn(
195 |             "Flash attention is only supported on A100 or H100 GPUs during training due to the head dim > 64 backward pass. "
196 |             "Ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593 "
197 |             "Resorting to plain attention..."
198 | )
199 | use_flash_attn = False
200 |
201 | forward_fn = get_forward_function(use_flash_attn, use_full)
202 | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn
203 |
--------------------------------------------------------------------------------
/imgs/LongAlpaca.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/LongAlpaca.png
--------------------------------------------------------------------------------
/imgs/Shift-short-attention2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/Shift-short-attention2.png
--------------------------------------------------------------------------------
/imgs/data-distribution-in-longalpaca12k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/data-distribution-in-longalpaca12k.png
--------------------------------------------------------------------------------
/imgs/demo-compare-harrypotter.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-harrypotter.png
--------------------------------------------------------------------------------
/imgs/demo-compare-journeytothewest.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-journeytothewest.png
--------------------------------------------------------------------------------
/imgs/demo-compare-threebody.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-threebody.png
--------------------------------------------------------------------------------
/imgs/economy-comparison.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/economy-comparison.png
--------------------------------------------------------------------------------
/imgs/economy-prediction.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/economy-prediction.png
--------------------------------------------------------------------------------
/imgs/paper-improvements.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-improvements.png
--------------------------------------------------------------------------------
/imgs/paper-review.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-review.png
--------------------------------------------------------------------------------
/imgs/paper-style-compare-cvpr-iclr.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-style-compare-cvpr-iclr.png
--------------------------------------------------------------------------------
/inference-qlora.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import argparse
6 | import textwrap
7 | import transformers
8 | from peft import PeftModel
9 | from transformers import GenerationConfig, TextStreamer, BitsAndBytesConfig
10 | from llama_attn_replace import replace_llama_attn
11 |
12 | PROMPT_DICT = {
13 | "prompt_no_input": (
14 | "Below is an instruction that describes a task. "
15 | "Write a response that appropriately completes the request.\n\n"
16 | "### Instruction:\n{instruction}\n\n### Response:"
17 | ),
18 | "prompt_no_input_llama2": (
19 |         "[INST] <<SYS>>\n"
20 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
21 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
22 |         "<</SYS>> \n\n {instruction} [/INST]"
23 | ),
24 | "prompt_llama2": "[INST]{instruction}[/INST]"
25 | }
26 |
27 | def parse_config():
28 | parser = argparse.ArgumentParser(description='arg parser')
29 | parser.add_argument('--material', type=str, default="")
30 | parser.add_argument('--question', type=str, default="")
31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
32 | parser.add_argument('--cache_dir', type=str, default="./cache")
33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
34 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
35 | parser.add_argument('--temperature', type=float, default=0.6, help='')
36 | parser.add_argument('--top_p', type=float, default=0.9, help='')
37 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
38 | args = parser.parse_args()
39 | return args
40 |
41 | def read_txt_file(material_txt):
42 |     if material_txt.split(".")[-1] != 'txt':
43 |         raise ValueError("Only .txt files are supported.")
44 | content = ""
45 | with open(material_txt) as f:
46 | for line in f.readlines():
47 | content += line
48 | return content
49 |
50 | def build_generator(
51 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
52 | ):
53 | def response(prompt):
54 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
55 |
56 | streamer = TextStreamer(tokenizer)
57 |
58 | output = model.generate(
59 | **inputs,
60 | max_new_tokens=max_gen_len,
61 | temperature=temperature,
62 | top_p=top_p,
63 | use_cache=use_cache,
64 | streamer=streamer,
65 | )
66 |
67 | out = tokenizer.decode(output[0], skip_special_tokens=True)
68 |
69 |         out = out.split(prompt.lstrip("<s>"))[1].strip()
70 | return out
71 |
72 | return response
73 |
74 | def main(args):
75 | if args.flash_attn:
76 | replace_llama_attn(inference=True)
77 |
78 | # Set RoPE scaling factor
79 | config = transformers.AutoConfig.from_pretrained(
80 | args.base_model,
81 | cache_dir=args.cache_dir,
82 | )
83 |
84 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
85 | if orig_ctx_len and args.context_size > orig_ctx_len:
86 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
87 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
88 |
89 | # Load model and tokenizer
90 | model = transformers.AutoModelForCausalLM.from_pretrained(
91 | args.base_model,
92 | config=config,
93 | cache_dir=args.cache_dir,
94 | torch_dtype=torch.float16,
95 | device_map="auto",
96 | quantization_config = BitsAndBytesConfig(
97 | load_in_4bit=True,
98 | bnb_4bit_use_double_quant=True,
99 | bnb_4bit_quant_type="nf4",
100 | bnb_4bit_compute_dtype=torch.bfloat16
101 | )
102 | )
103 | model.resize_token_embeddings(32001)
104 |
105 | tokenizer = transformers.AutoTokenizer.from_pretrained(
106 | args.base_model,
107 | cache_dir=args.cache_dir,
108 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
109 | padding_side="right",
110 | use_fast=False,
111 | )
112 |
113 | model.eval()
114 | if torch.__version__ >= "2" and sys.platform != "win32":
115 | model = torch.compile(model)
116 | respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p,
117 | max_gen_len=args.max_gen_len, use_cache=True)
118 |
119 | material = read_txt_file(args.material)
120 | prompt_no_input = PROMPT_DICT["prompt_llama2"]
121 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s"%args.question})
122 |
123 | output = respond(prompt=prompt)
124 |
125 | if __name__ == "__main__":
126 | args = parse_config()
127 | main(args)
128 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import math
4 | import torch
5 | import argparse
6 | import textwrap
7 | import transformers
8 | from peft import PeftModel
9 | from transformers import GenerationConfig, TextStreamer
10 | from llama_attn_replace import replace_llama_attn
11 |
12 | PROMPT_DICT = {
13 | "prompt_no_input": (
14 | "Below is an instruction that describes a task. "
15 | "Write a response that appropriately completes the request.\n\n"
16 | "### Instruction:\n{instruction}\n\n### Response:"
17 | ),
18 | "prompt_no_input_llama2": (
19 |         "[INST] <<SYS>>\n"
20 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
21 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
22 |         "<</SYS>> \n\n {instruction} [/INST]"
23 | ),
24 | "prompt_llama2": "[INST]{instruction}[/INST]"
25 | }
26 |
27 | def parse_config():
28 | parser = argparse.ArgumentParser(description='arg parser')
29 | parser.add_argument('--material', type=str, default="")
30 | parser.add_argument('--question', type=str, default="")
31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
32 | parser.add_argument('--cache_dir', type=str, default="./cache")
33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
34 | parser.add_argument('--flash_attn', type=bool, default=False, help='')
35 | parser.add_argument('--temperature', type=float, default=0.6, help='')
36 | parser.add_argument('--top_p', type=float, default=0.9, help='')
37 | parser.add_argument('--max_gen_len', type=int, default=512, help='')
38 | args = parser.parse_args()
39 | return args
40 |
41 | def read_txt_file(material_txt):
42 |     if material_txt.split(".")[-1] != 'txt':
43 |         raise ValueError("Only .txt files are supported.")
44 | content = ""
45 | with open(material_txt) as f:
46 | for line in f.readlines():
47 | content += line
48 | return content
49 |
50 | def build_generator(
51 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True
52 | ):
53 | def response(prompt):
54 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
55 |
56 | streamer = TextStreamer(tokenizer)
57 |
58 | output = model.generate(
59 | **inputs,
60 | max_new_tokens=max_gen_len,
61 | temperature=temperature,
62 | top_p=top_p,
63 | use_cache=use_cache,
64 | streamer=streamer,
65 | )
66 |
67 | out = tokenizer.decode(output[0], skip_special_tokens=True)
68 |
69 |         out = out.split(prompt.lstrip("<s>"))[1].strip()
70 | return out
71 |
72 | return response
73 |
74 | def main(args):
75 | if args.flash_attn:
76 | replace_llama_attn(inference=True)
77 |
78 | # Set RoPE scaling factor
79 | config = transformers.AutoConfig.from_pretrained(
80 | args.base_model,
81 | cache_dir=args.cache_dir,
82 | )
83 |
84 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
85 | if orig_ctx_len and args.context_size > orig_ctx_len:
86 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len))
87 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
88 |
89 | # Load model and tokenizer
90 | model = transformers.AutoModelForCausalLM.from_pretrained(
91 | args.base_model,
92 | config=config,
93 | cache_dir=args.cache_dir,
94 | torch_dtype=torch.float16,
95 | device_map="auto",
96 | )
97 | model.resize_token_embeddings(32001)
98 |
99 | tokenizer = transformers.AutoTokenizer.from_pretrained(
100 | args.base_model,
101 | cache_dir=args.cache_dir,
102 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
103 | padding_side="right",
104 | use_fast=False,
105 | )
106 |
107 | if torch.__version__ >= "2" and sys.platform != "win32":
108 | model = torch.compile(model)
109 | model.eval()
110 |
111 | respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p,
112 | max_gen_len=args.max_gen_len, use_cache=True)
113 |
114 | material = read_txt_file(args.material)
115 | prompt_no_input = PROMPT_DICT["prompt_llama2"]
116 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s"%args.question})
117 |
118 | output = respond(prompt=prompt)
119 |
120 | if __name__ == "__main__":
121 | args = parse_config()
122 | main(args)
123 |
--------------------------------------------------------------------------------
/merge_lora_weights_and_save_hf_model.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import torch
17 | import argparse
18 | import transformers
19 | from peft import PeftModel
20 | from typing import Dict
21 |
22 | IGNORE_INDEX = -100
23 | DEFAULT_PAD_TOKEN = "[PAD]"
24 | DEFAULT_EOS_TOKEN = "</s>"
25 | DEFAULT_BOS_TOKEN = "<s>"
26 | DEFAULT_UNK_TOKEN = "<unk>"
27 |
28 | def parse_config():
29 | parser = argparse.ArgumentParser(description='arg parser')
30 | parser.add_argument('--base_model', type=str, default="/data/pretrained-models/llama-7b-hf")
31 | parser.add_argument('--peft_model', type=str, default=None, help='')
32 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
33 | parser.add_argument('--save_path', type=str, default=None, help='')
34 | parser.add_argument('--cache_dir', type=str, default=None, help='./cache_dir')
35 | args = parser.parse_args()
36 | return args
37 |
38 | def smart_tokenizer_and_embedding_resize(
39 | special_tokens_dict: Dict,
40 | tokenizer: transformers.PreTrainedTokenizer,
41 | model: transformers.PreTrainedModel,
42 | ):
43 | """Resize tokenizer and embedding.
44 |
45 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
46 | """
47 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
48 | model.resize_token_embeddings(len(tokenizer))
49 |
50 | if num_new_tokens > 0:
51 | input_embeddings = model.get_input_embeddings().weight.data
52 | output_embeddings = model.get_output_embeddings().weight.data
53 |
54 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
55 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
56 |
57 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
58 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
59 |
60 | def main(args):
61 | device = "cuda:0"
62 | torch.cuda.set_device(device)
63 |
64 | print("base model", args.base_model)
65 | print("peft model", args.peft_model)
66 |
67 | # Load model and tokenizer
68 | model = transformers.AutoModelForCausalLM.from_pretrained(
69 | args.base_model,
70 | cache_dir=args.cache_dir,
71 | torch_dtype=torch.float16,
72 | device_map="auto",
73 | )
74 |
75 | tokenizer = transformers.AutoTokenizer.from_pretrained(
76 | args.base_model,
77 | cache_dir=args.cache_dir,
78 | model_max_length=args.context_size,
79 | padding_side="right",
80 | use_fast=False,
81 | )
82 | special_tokens_dict = dict()
83 | if tokenizer.pad_token is None:
84 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
85 | if tokenizer.eos_token is None:
86 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
87 | if tokenizer.bos_token is None:
88 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
89 | if tokenizer.unk_token is None:
90 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
91 |
92 | smart_tokenizer_and_embedding_resize(
93 | special_tokens_dict=special_tokens_dict,
94 | tokenizer=tokenizer,
95 | model=model,
96 | )
97 |
98 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin")
99 | if os.path.isfile(trainable_params):
100 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False)
101 | model = PeftModel.from_pretrained(
102 | model,
103 | args.peft_model,
104 | device_map="auto",
105 | torch_dtype=torch.float16,
106 | )
107 | model = model.merge_and_unload()
108 | model.save_pretrained(args.save_path)
109 | tokenizer.save_pretrained(args.save_path)
110 |
111 | if __name__ == "__main__":
112 | args = parse_config()
113 | main(args)
114 |
--------------------------------------------------------------------------------
/passkey_retrivial.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Core code based on https://github.com/CStanKonrad/long_llama
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | import math
18 | import torch
19 | import argparse
20 | import random
21 | import numpy as np
22 | from numpy import random
23 | from tqdm import tqdm
24 | import transformers
25 | from peft import PeftModel
26 | from llama_attn_replace import replace_llama_attn
27 |
28 |
29 | def parse_config():
30 | parser = argparse.ArgumentParser(description='arg parser')
31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf")
32 | parser.add_argument('--cache_dir', type=str, default="./cache")
33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning')
34 | parser.add_argument('--flash_attn', type=bool, default=True, help='whether to use flash attention 2')
35 | parser.add_argument('--max_tokens', type=int, default=32000, help='maximum token length for evaluation')
36 | parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation')
37 | parser.add_argument('--num_tests', type=int, default=10, help='number of repeat testing for each length')
38 |
39 | args = parser.parse_args()
40 | return args
41 |
42 |
43 | def generate_prompt_landmark(n_garbage, seed):
44 |     """Generates a long distractor prompt and inserts a passkey at a random position."""
45 | rnd_state = random.get_state()
46 | random.seed(seed)
47 | n_garbage_prefix = random.randint(0, n_garbage)
48 | n_garbage_suffix = n_garbage - n_garbage_prefix
49 |
50 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there."
51 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again."
52 | garbage_inf = " ".join([garbage] * 5000)
53 | assert len(garbage_inf) >= n_garbage
54 | garbage_prefix = garbage_inf[:n_garbage_prefix]
55 | garbage_suffix = garbage_inf[:n_garbage_suffix]
56 | pass_key = random.randint(1, 50000)
57 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key."
58 | final_question = "What is the pass key? The pass key is"
59 | lines = [
60 | task_description,
61 | garbage_prefix,
62 | information_line,
63 | garbage_suffix,
64 | final_question,
65 | ]
66 | random.set_state(rnd_state)
67 | return "\n".join(lines), str(pass_key)
68 |
69 |
70 | def passkey_retrieval_test(model, tokenizer, device, use_cache=False, n_garbage=60000, seed=666):
71 | prompt, answer = generate_prompt_landmark(n_garbage, seed)
72 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
73 | input_ids = input_ids.to(device)
74 | len_token = input_ids.shape[-1]
75 |
76 | answer_ids = tokenizer(answer, return_tensors="pt").input_ids[:, 1:] # drop BOS
77 | generation_output = model.generate(
78 | input_ids=input_ids, max_new_tokens=answer_ids.shape[-1], num_beams=1, use_cache=use_cache
79 | )
80 |
81 | model_answer = generation_output[0, -answer_ids.shape[-1]:].cpu()
82 |
83 | is_correct = (model_answer == answer_ids[0]).all().item()
84 | #print(f"The correct answer is {tokenizer.decode(answer_ids[0].cpu())}")
85 | #print(f"The model answer is {tokenizer.decode(model_answer.cpu())}, is_correct : {is_correct}")
86 | return is_correct, len_token
87 |
88 |
89 | def main(args):
90 | device = "cuda:0"
91 | torch.cuda.set_device(device)
92 |
93 | print("base model", args.base_model)
94 |
95 | if args.flash_attn:
96 | replace_llama_attn(use_full=True)
97 |
98 | # Set RoPE scaling factor
99 | config = transformers.AutoConfig.from_pretrained(
100 | args.base_model,
101 | cache_dir=args.cache_dir,
102 | )
103 |
104 | context_size = args.context_size
105 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models
106 | if orig_ctx_len and context_size > orig_ctx_len:
107 | scaling_factor = float(math.ceil(context_size / orig_ctx_len))
108 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
109 |
110 | # Load model and tokenizer
111 | model = transformers.AutoModelForCausalLM.from_pretrained(
112 | args.base_model,
113 | config=config,
114 | cache_dir=args.cache_dir,
115 | torch_dtype=torch.float16,
116 | device_map="auto",
117 | )
118 | model.resize_token_embeddings(32001)
119 |
120 | tokenizer = transformers.AutoTokenizer.from_pretrained(
121 | args.base_model,
122 | cache_dir=args.cache_dir,
123 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len,
124 | padding_side="right",
125 | use_fast=False,
126 | )
127 |
128 | total_test_points = args.max_tokens // args.interval
129 |     all_accuracies = {}
130 | for i in range(total_test_points):
131 | # This is a rough ratio to control the number of texts and tokens
132 | n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024)
133 | passed_tests = 0
134 | total_tokens = 0
135 |         for k in range(args.num_tests):
136 |             is_correct, len_tokens = passkey_retrieval_test(model, tokenizer, device, use_cache=not args.flash_attn, n_garbage=n_garbage, seed=k)
137 | passed_tests += is_correct
138 | total_tokens += len_tokens
139 | avg_tokens = total_tokens//args.num_tests
140 | accuracy = float(passed_tests)/args.num_tests
141 | print("accuracy on the token length %d is %f"%(avg_tokens, accuracy))
142 |         all_accuracies[str(avg_tokens)] = accuracy
143 |     print("accuracies over tokens", all_accuracies)
144 |
145 |
146 | if __name__ == "__main__":
147 | args = parse_config()
148 | main(args)
149 |
--------------------------------------------------------------------------------
/pdf2txt/README.md:
--------------------------------------------------------------------------------
1 | # Extract text from PDF via DiT detection and OCR
2 |
3 | The script uses `pdf2image`, `easyocr`, `ditod`, and `detectron2` for processing.
4 |
5 | Detected objects are categorized into "text", "title", "list", "table", and "figure".
6 |
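As a minimal sketch (assuming `detectron2` is installed; `publaynet_val` is the test-dataset name from `configs/Base-RCNN-FPN.yaml`), the category names are attached to the dataset metadata before prediction, mirroring what `pdf2txt.py` does:

```python
from detectron2.data import MetadataCatalog

# The five layout categories produced by the detector.
LAYOUT_CLASSES = ["text", "title", "list", "table", "figure"]

# Attach them as metadata for the test dataset named in the config.
md = MetadataCatalog.get("publaynet_val")
md.set(thing_classes=LAYOUT_CLASSES)
print(md.thing_classes)
```
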
7 | The script prints timing information for each processing step (PDF rendering, layout detection, and OCR), which is useful for performance analysis.
8 |
9 | Text extraction uses `easyocr`, and the results are further processed with SymSpell for word segmentation and a regular-expression filter.
10 |
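A minimal sketch of that post-processing step (assuming `symspellpy` is installed; the frequency dictionary ships with the package, and the OCR string below is made up for illustration):

```python
import re
import pkg_resources
from symspellpy.symspellpy import SymSpell

# Load the English frequency dictionary bundled with symspellpy.
sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=7)
dictionary_path = pkg_resources.resource_filename(
    "symspellpy", "frequency_dictionary_en_82_765.txt"
)
sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)

# Re-segment OCR output that lost its spaces, then drop code-like tokens
# with the same regular expression used in pdf2txt.py.
ocr_text = "thequickbrownfox x=1 jumpsoverthelazydog"
segmented = sym_spell.word_segmentation(ocr_text).corrected_string
filtered = re.sub(r"[A-Za-z0-9=:/*]*[=:+-][A-Za-z0-9=:/*]", "", segmented)
print(filtered)
```
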
11 | ### 1. Installation
12 | ```
13 | pip install -r requirements.txt
14 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git'
15 | apt-get install poppler-utils
16 | ```
17 |
18 | ### 2. Download OCR model
19 | - Please download the weight [trained_ocr_cascade_large.pth](https://drive.google.com/file/d/1DtHtR3hhj8Df_Lkgdm9P79Eljot5MR_i/view?usp=share_link) first.
20 | - Please set the weight path in `configs/cascade_dit_large.yaml`.
21 |
22 | ### 3. Basic usage
23 | ```
24 | python pdf2txt.py --pdf_path path_to_pdf_file --outputs_dir path_to_output_dir
25 | ```
26 |
27 | The output txt file will be stored in `path_to_output_dir/txt`
28 |
--------------------------------------------------------------------------------
/pdf2txt/backbone.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------------------------------
2 | # VIT: Multi-Path Vision Transformer for Dense Prediction
3 | # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI).
4 | # All Rights Reserved.
5 | # Written by Youngwan Lee
6 | # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the
7 | # LICENSE file in the root directory of this source tree.
8 | # --------------------------------------------------------------------------------
9 | # References:
10 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
11 | # CoaT: https://github.com/mlpc-ucsd/CoaT
12 | # --------------------------------------------------------------------------------
13 |
14 |
15 | import torch
16 |
17 | from detectron2.layers import (
18 | ShapeSpec,
19 | )
20 | from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN
21 | from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool
22 |
23 | from beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16
24 |
25 | __all__ = [
26 | "build_vit_fpn_backbone",
27 | ]
28 |
29 |
30 | class VIT_Backbone(Backbone):
31 | """
32 | Implement VIT backbone.
33 | """
34 |
35 | def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs):
36 | super().__init__()
37 | self._out_features = out_features
38 | if 'base' in name:
39 | self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32}
40 | else:
41 | self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32}
42 |
43 | if name == 'beit_base_patch16':
44 | model_func = beit_base_patch16
45 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
46 | elif name == 'dit_base_patch16':
47 | model_func = dit_base_patch16
48 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
49 |         elif name == "deit_base_patch16":  # NOTE: the deit/mae model functions below are not imported in this repo
50 | model_func = deit_base_patch16
51 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
52 | elif name == "mae_base_patch16":
53 | model_func = mae_base_patch16
54 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768}
55 | elif name == "dit_large_patch16":
56 | model_func = dit_large_patch16
57 | self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
58 | elif name == "beit_large_patch16":
59 | model_func = beit_large_patch16
60 | self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024}
61 | else:
62 | raise ValueError("Unsupported VIT name yet.")
63 |
64 | if 'beit' in name or 'dit' in name:
65 | if pos_type == "abs":
66 | self.backbone = model_func(img_size=img_size,
67 | out_features=out_features,
68 | drop_path_rate=drop_path,
69 | use_abs_pos_emb=True,
70 | **model_kwargs)
71 | elif pos_type == "shared_rel":
72 | self.backbone = model_func(img_size=img_size,
73 | out_features=out_features,
74 | drop_path_rate=drop_path,
75 | use_shared_rel_pos_bias=True,
76 | **model_kwargs)
77 | elif pos_type == "rel":
78 | self.backbone = model_func(img_size=img_size,
79 | out_features=out_features,
80 | drop_path_rate=drop_path,
81 | use_rel_pos_bias=True,
82 | **model_kwargs)
83 | else:
84 | raise ValueError()
85 | else:
86 | self.backbone = model_func(img_size=img_size,
87 | out_features=out_features,
88 | drop_path_rate=drop_path,
89 | **model_kwargs)
90 |
91 | def forward(self, x):
92 | """
93 | Args:
94 | x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
95 |
96 | Returns:
97 | dict[str->Tensor]: names and the corresponding features
98 | """
99 | assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!"
100 | return self.backbone.forward_features(x)
101 |
102 | def output_shape(self):
103 | return {
104 | name: ShapeSpec(
105 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
106 | )
107 | for name in self._out_features
108 | }
109 |
110 |
111 | def build_VIT_backbone(cfg):
112 | """
113 | Create a VIT instance from config.
114 |
115 | Args:
116 | cfg: a detectron2 CfgNode
117 |
118 | Returns:
119 | A VIT backbone instance.
120 | """
121 | # fmt: off
122 | name = cfg.MODEL.VIT.NAME
123 | out_features = cfg.MODEL.VIT.OUT_FEATURES
124 | drop_path = cfg.MODEL.VIT.DROP_PATH
125 | img_size = cfg.MODEL.VIT.IMG_SIZE
126 | pos_type = cfg.MODEL.VIT.POS_TYPE
127 |
128 | model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", ""))
129 |
130 | return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs)
131 |
132 |
133 | @BACKBONE_REGISTRY.register()
134 | def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec):
135 | """
136 | Create a VIT w/ FPN backbone.
137 |
138 | Args:
139 | cfg: a detectron2 CfgNode
140 |
141 | Returns:
142 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`.
143 | """
144 | bottom_up = build_VIT_backbone(cfg)
145 | in_features = cfg.MODEL.FPN.IN_FEATURES
146 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS
147 | backbone = FPN(
148 | bottom_up=bottom_up,
149 | in_features=in_features,
150 | out_channels=out_channels,
151 | norm=cfg.MODEL.FPN.NORM,
152 | top_block=LastLevelMaxPool(),
153 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE,
154 | )
155 | return backbone
156 |
--------------------------------------------------------------------------------
/pdf2txt/config.py:
--------------------------------------------------------------------------------
1 | from detectron2.config import CfgNode as CN
2 |
3 |
4 | def add_vit_config(cfg):
5 | """
6 | Add config for VIT.
7 | """
8 | _C = cfg
9 |
10 | _C.MODEL.VIT = CN()
11 |
12 | # CoaT model name.
13 | _C.MODEL.VIT.NAME = ""
14 |
15 | # Output features from CoaT backbone.
16 | _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"]
17 |
18 | _C.MODEL.VIT.IMG_SIZE = [224, 224]
19 |
20 | _C.MODEL.VIT.POS_TYPE = "shared_rel"
21 |
22 | _C.MODEL.VIT.DROP_PATH = 0.
23 |
24 | _C.MODEL.VIT.MODEL_KWARGS = "{}"
25 |
26 | _C.SOLVER.OPTIMIZER = "ADAMW"
27 |
28 | _C.SOLVER.BACKBONE_MULTIPLIER = 1.0
29 |
30 | _C.AUG = CN()
31 |
32 | _C.AUG.DETR = False
33 |
--------------------------------------------------------------------------------
/pdf2txt/configs/Base-RCNN-FPN.yaml:
--------------------------------------------------------------------------------
1 | MODEL:
2 | MASK_ON: True
3 | META_ARCHITECTURE: "GeneralizedRCNN"
4 | PIXEL_MEAN: [123.675, 116.280, 103.530]
5 | PIXEL_STD: [58.395, 57.120, 57.375]
6 | BACKBONE:
7 | NAME: "build_vit_fpn_backbone"
8 | VIT:
9 | OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
10 | DROP_PATH: 0.1
11 | IMG_SIZE: [224,224]
12 | POS_TYPE: "abs"
13 | FPN:
14 | IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"]
15 | ANCHOR_GENERATOR:
16 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map
17 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps)
18 | RPN:
19 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"]
20 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level
21 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level
22 | # Detectron1 uses 2000 proposals per-batch,
23 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue)
24 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2.
25 | POST_NMS_TOPK_TRAIN: 1000
26 | POST_NMS_TOPK_TEST: 1000
27 | ROI_HEADS:
28 | NAME: "StandardROIHeads"
29 | IN_FEATURES: ["p2", "p3", "p4", "p5"]
30 | NUM_CLASSES: 5
31 | ROI_BOX_HEAD:
32 | NAME: "FastRCNNConvFCHead"
33 | NUM_FC: 2
34 | POOLER_RESOLUTION: 7
35 | ROI_MASK_HEAD:
36 | NAME: "MaskRCNNConvUpsampleHead"
37 | NUM_CONV: 4
38 | POOLER_RESOLUTION: 14
39 | DATASETS:
40 | TRAIN: ("publaynet_train",)
41 | TEST: ("publaynet_val",)
42 | SOLVER:
43 | LR_SCHEDULER_NAME: "WarmupCosineLR"
44 | AMP:
45 | ENABLED: True
46 | OPTIMIZER: "ADAMW"
47 | BACKBONE_MULTIPLIER: 1.0
48 | CLIP_GRADIENTS:
49 | ENABLED: True
50 | CLIP_TYPE: "full_model"
51 | CLIP_VALUE: 1.0
52 | NORM_TYPE: 2.0
53 | WARMUP_FACTOR: 0.01
54 | BASE_LR: 0.0004
55 | WEIGHT_DECAY: 0.05
56 | IMS_PER_BATCH: 32
57 | INPUT:
58 | CROP:
59 | ENABLED: True
60 | TYPE: "absolute_range"
61 | SIZE: (384, 600)
62 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800)
63 | FORMAT: "RGB"
64 | DATALOADER:
65 | FILTER_EMPTY_ANNOTATIONS: False
66 | VERSION: 2
67 | AUG:
68 | DETR: True
69 | SEED: 42
--------------------------------------------------------------------------------
/pdf2txt/configs/cascade_dit_large.yaml:
--------------------------------------------------------------------------------
1 | _BASE_: "./Base-RCNN-FPN.yaml"
2 | MODEL:
3 | PIXEL_MEAN: [ 127.5, 127.5, 127.5 ]
4 | PIXEL_STD: [ 127.5, 127.5, 127.5 ]
5 | WEIGHTS: "./trained_ocr_cascade_large.pth"
6 | VIT:
7 | NAME: "dit_large_patch16"
8 | OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
9 | DROP_PATH: 0.2
10 | FPN:
11 | IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ]
12 | ROI_HEADS:
13 | NAME: CascadeROIHeads
14 | ROI_BOX_HEAD:
15 | CLS_AGNOSTIC_BBOX_REG: True
16 | RPN:
17 | POST_NMS_TOPK_TRAIN: 2000
18 | SOLVER:
19 | WARMUP_ITERS: 1000
20 | IMS_PER_BATCH: 16
21 | MAX_ITER: 60000
22 | CHECKPOINT_PERIOD: 2000
23 | BASE_LR: 0.0001
24 | STEPS: (40000, 53333)
25 | AMP:
26 | ENABLED: False
27 | TEST:
28 | EVAL_PERIOD: 2000
29 |
--------------------------------------------------------------------------------
/pdf2txt/pdf2txt.py:
--------------------------------------------------------------------------------
1 | # Written by Shaozuo Yu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import argparse
17 | import pdf2image
18 | import easyocr
19 | import cv2
20 | from config import add_vit_config
21 | from backbone import build_vit_fpn_backbone
22 | import torch
23 | from detectron2.config import get_cfg
24 | from detectron2.utils.visualizer import ColorMode, Visualizer
25 | from detectron2.data import MetadataCatalog
26 | from detectron2.engine import DefaultPredictor
27 | from detectron2.layers import nms
28 | import pickle
29 | import numpy as np
30 | import shutil
31 | from tqdm import tqdm
32 | from PIL import Image
33 | import time
34 |
35 | from symspellpy.symspellpy import SymSpell
36 | import pkg_resources
37 | import re
38 |
39 |
40 | prefix_length = 7
41 | sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=prefix_length)
42 | dictionary_path = pkg_resources.resource_filename("symspellpy", "frequency_dictionary_en_82_765.txt")
43 | sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1)
44 | # filter
45 | regex = "[A-Za-z0-9=:/\*]*[=:+-][A-Za-z0-9=:/\*]"
46 |
47 |
48 | def detect_objects(image_path, predictor, cfg):
49 | # Step 5: run inference
50 | img = cv2.imread(image_path)
51 |
52 | md = MetadataCatalog.get(cfg.DATASETS.TEST[0])
53 | md.set(thing_classes=["text", "title", "list", "table", "figure"])
54 |
55 | start_time = time.time()
56 |
57 | detections = predictor(img)["instances"]
58 |
59 | end_time = time.time()
60 |
61 |     print(f"detection model time: {end_time - start_time} s")
62 | # get boxes and scores
63 | boxes = detections.pred_boxes.tensor
64 | scores = detections.scores
65 |
66 | # NMS
67 | keep = nms(boxes, scores, 0.1)
68 | detections = detections[keep]
69 | scores = detections.scores
70 |
71 | threshold = 0.8 # you can adjust this value
72 | keep2 = torch.nonzero(scores > threshold).squeeze(1)
73 | detections = detections[keep2]
74 |
75 | return detections
76 |
77 | def process_pdf(pdf_file, outputs_dir, config_file):
78 |
79 | results = {}
80 |
81 | tmp_dir = os.path.join(outputs_dir, 'tmp')
82 | txt_dir = os.path.join(outputs_dir, 'txt')
83 | os.makedirs(tmp_dir, exist_ok=True)
84 | os.makedirs(txt_dir, exist_ok=True)
85 |
86 | #load detection model
87 | cfg = get_cfg()
88 | add_vit_config(cfg)
89 | cfg.merge_from_file(config_file)
90 | device = "cuda" if torch.cuda.is_available() else "cpu"
91 | #device = "cpu"
92 | cfg.MODEL.DEVICE = device
93 | predictor = DefaultPredictor(cfg)
94 | reader = easyocr.Reader(['en'], gpu=True)
95 |
96 | book_name = os.path.splitext(pdf_file)[0]
97 | book_base_name = os.path.basename(pdf_file)
98 | txt_file_path = os.path.join(txt_dir, f"{book_base_name}.txt")
99 | if os.path.exists(txt_file_path):
100 | raise ValueError(f"Skipping {book_name} as it already exists in the output directory.")
101 |
102 | start_time = time.time()
103 |
104 | book_name = os.path.splitext(pdf_file)[0]
105 | images = pdf2image.convert_from_path(pdf_file)
106 |
107 | end_time = time.time()
108 |
109 | print(f"pdf2image time: {end_time - start_time} s")
110 |
111 | book_results = []
112 | for page_num, image in tqdm(enumerate(images, start=1), desc=f"Processing {book_name}", leave=False):
113 | image_path = os.path.join(tmp_dir, f"{book_base_name}-{page_num}.png")
114 | image.save(image_path)
115 |
116 | start_time = time.time()
117 |
118 | detections = detect_objects(image_path, predictor, cfg)
119 |
120 | end_time = time.time()
121 |
122 | print(f"detection time: {end_time - start_time} s")
123 |
124 | boxes = detections.pred_boxes.tensor.tolist()
125 | labels = detections.pred_classes.tolist()
126 |
127 | # get boxes
128 | all_detections = [(bbox, label_id) for bbox, label_id in zip(boxes, labels)]
129 |
130 | # sort
131 | all_detections.sort(key=lambda x: (x[0][1], x[0][0]))
132 |
133 | start_time = time.time()
134 |
135 | label_counter = {"figure": 0, "table": 0, 'text': 0, 'list': 0, 'title': 0}
136 | for bbox, label_id in all_detections:
137 | #print("Number of classes:", len(MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes))
138 | label = MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes[label_id]
139 | cropped_image_np = np.array(image.crop(bbox))
140 |
141 | if label in ['text', 'list', 'title']:
142 | #reader = easyocr.Reader(['en'], cudnn_benchmark=True)
143 | ocr_result = reader.readtext(cropped_image_np, batch_size=10)
144 | extracted_text = ' '.join([item[1] for item in ocr_result])
145 |
146 | # SymSpell for word segmentation
147 | suggestions = sym_spell.word_segmentation(extracted_text)
148 | segmented_text = suggestions.corrected_string
149 |
150 | # filter
151 | filtered_text = re.sub(regex, "", segmented_text)
152 |
153 |                 book_results.append(filtered_text)  # keep the segmented and regex-filtered text
154 |
155 | end_time = time.time()
156 |
157 | print(f"ocr time: {end_time - start_time} s")
158 |
159 | results[book_name] = book_results
160 | with open(os.path.join(txt_dir, f"{book_base_name}.txt"), 'w') as f:
161 | f.write('\n'.join(book_results))
162 |
163 | # delete tmp dir
164 | shutil.rmtree(tmp_dir)
165 |
166 |
167 | if __name__ == '__main__':
168 | parser = argparse.ArgumentParser(description="PDF processing script")
169 | parser.add_argument("--pdf_path", help="Path to PDF file", type=str, required=True)
170 | parser.add_argument("--outputs_dir", help="Directory to save outputs", type=str, required=True)
171 | parser.add_argument("--config_file", default="configs/cascade_dit_large.yaml", metavar="FILE", help="path to config file")
172 |
173 | args = parser.parse_args()
174 |
175 | process_pdf(args.pdf_path, args.outputs_dir, args.config_file)
176 |
--------------------------------------------------------------------------------
/pdf2txt/requirements.txt:
--------------------------------------------------------------------------------
1 | torch
2 | torchvision
3 | timm==0.5.4
4 | Pillow
5 | blobfile
6 | mypy
7 | numpy
8 | pytest
9 | requests
10 | einops
11 | tensorboardX
12 | scipy
13 | opencv-python
14 | pdf2image
15 | easyocr
16 | argparse
17 | regex
18 | symspellpy
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy>=1.26.0
2 | rouge_score>=0.1.2
3 | fire>=0.5.0
4 | # openai
5 | transformers==4.34.0
6 | torch>=2.0.0
7 | sentencepiece>=0.1.99
8 | tokenizers>=0.14.0
9 | # wandb
10 | accelerate>=0.23.0
11 | datasets>=2.14.5
12 | deepspeed>=0.10.3
13 | peft>=0.5.0
14 | # partial
15 | # gradio
16 | einops>=0.7.0
17 | bitsandbytes==0.41.1
18 | scipy>=1.11.3
19 | protobuf>=4.24.4
20 | torchmetrics>=1.2.0
21 |
--------------------------------------------------------------------------------
/run_streaming_llama_longalpaca.py:
--------------------------------------------------------------------------------
1 | import warnings
2 |
3 | warnings.filterwarnings("ignore")
4 |
5 | import torch
6 | import argparse
7 | import json
8 | import os
9 | import time
10 | import re
11 | import sys
12 |
13 | from tqdm import tqdm
14 | from streaming_llm.utils import load, download_url, load_jsonl
15 | from streaming_llm.enable_streaming_llm import enable_streaming_llm
16 |
17 |
18 | @torch.no_grad()
19 | def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
20 | outputs = model(
21 | input_ids=input_ids,
22 | past_key_values=past_key_values,
23 | use_cache=True,
24 | )
25 | past_key_values = outputs.past_key_values
26 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
27 | generated_ids = [pred_token_idx.item()]
28 | pos = 0
29 | for _ in range(max_gen_len - 1):
30 | outputs = model(
31 | input_ids=pred_token_idx,
32 | past_key_values=past_key_values,
33 | use_cache=True,
34 | )
35 | past_key_values = outputs.past_key_values
36 | pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
37 | generated_ids.append(pred_token_idx.item())
38 | generated_text = (
39 | tokenizer.decode(
40 | generated_ids,
41 | skip_special_tokens=True,
42 | clean_up_tokenization_spaces=True,
43 | spaces_between_special_tokens=False,
44 | )
45 | .strip()
46 | .split(" ")
47 | )
48 |
49 | now = len(generated_text) - 1
50 | if now > pos:
51 | print(" ".join(generated_text[pos:now]), end=" ", flush=True)
52 | pos = now
53 |
54 | if pred_token_idx == tokenizer.eos_token_id:
55 | break
56 | print(" ".join(generated_text[pos:]), flush=True)
57 | return past_key_values
58 |
59 |
60 | @torch.no_grad()
61 | def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
62 | past_key_values = None
63 | for idx, prompt in enumerate(prompts):
64 | prompt = "USER: " + prompt + "\n\nASSISTANT: "
65 | print("\n" + prompt, end="")
66 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids
67 | input_ids = input_ids.to(model.device)
68 | seq_len = input_ids.shape[1]
69 | if kv_cache is not None:
70 | space_needed = seq_len + max_gen_len
71 | past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)
72 |
73 | past_key_values = greedy_generate(
74 | model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
75 | )
76 |
77 |
78 | def main(args):
79 | model_name_or_path = args.model_name_or_path
80 | model, tokenizer = load(model_name_or_path)
81 | print(f"Loading data from {args.test_filepath} ...")
82 |
83 | list_data = json.load(open(args.test_filepath))
84 | prompts = []
85 | for sample in list_data:
86 | prompts += [sample["instruction"]]
87 |
88 | if args.enable_streaming:
89 | kv_cache = enable_streaming_llm(
90 | model, start_size=args.start_size, recent_size=args.recent_size, use_flash_attn=args.use_flash_attn
91 | )
92 | else:
93 | kv_cache = None
94 |
95 | streaming_inference(
96 | model,
97 | tokenizer,
98 | prompts,
99 | kv_cache,
100 | )
101 |
102 |
103 | if __name__ == "__main__":
104 | parser = argparse.ArgumentParser()
105 | parser.add_argument(
106 | "--model_name_or_path", type=str, default="Yukang/LongAlpaca-7B"
107 | )
108 | parser.add_argument("--test_filepath", type=str, default="outputs_stream.json")
109 | parser.add_argument("--enable_streaming", action="store_true")
110 | parser.add_argument("--start_size", type=int, default=4)
111 | parser.add_argument("--recent_size", type=int, default=8192)
112 | parser.add_argument("--use_flash_attn", type=bool, default=True)
113 | args = parser.parse_args()
114 |
115 | main(args)
116 |
--------------------------------------------------------------------------------
/streaming_llm/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/streaming_llm/__init__.py
--------------------------------------------------------------------------------
/streaming_llm/enable_streaming_llm.py:
--------------------------------------------------------------------------------
1 | from streaming_llm.kv_cache import StartRecentKVCache
2 |
3 |
4 | def enable_streaming_llm(model, start_size, recent_size, use_flash_attn=True):
5 | if "llama" in model.config.model_type:
6 | k_seq_dim = v_seq_dim = 2
7 | from streaming_llm.pos_shift.modify_llama import (
8 | enable_llama_pos_shift_attention,
9 | )
10 |
11 | enable_llama_pos_shift_attention(model, use_flash_attn)
12 | elif "mpt" in model.config.model_type:
13 | v_seq_dim = 2
14 | k_seq_dim = 3
15 | elif "gpt_neox" in model.config.model_type:
16 | k_seq_dim = v_seq_dim = 2
17 | from streaming_llm.pos_shift.modify_gpt_neox import (
18 | enable_gpt_neox_pos_shift_attention,
19 | )
20 |
21 | enable_gpt_neox_pos_shift_attention(model)
22 | elif "falcon" in model.config.model_type:
23 | v_seq_dim = 1
24 | k_seq_dim = 1
25 | from streaming_llm.pos_shift.modify_falcon import (
26 | enable_falcon_pos_shift_attention,
27 | )
28 |
29 | enable_falcon_pos_shift_attention(model)
30 | else:
31 | raise ValueError(f"got {model.config.model_type}")
32 | kv_cache = StartRecentKVCache(
33 | start_size=start_size,
34 | recent_size=recent_size,
35 | k_seq_dim=k_seq_dim,
36 | v_seq_dim=v_seq_dim,
37 | )
38 | return kv_cache
39 |
--------------------------------------------------------------------------------
/streaming_llm/kv_cache.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def slice2d(x, start, end):
5 | return x[:, :, start:end, ...]
6 |
7 |
8 | def slice3d(x, start, end):
9 | return x[:, :, :, start:end, ...]
10 |
11 |
12 | def slice1d(x, start, end):
13 | return x[:, start:end, ...]
14 |
15 |
16 | DIM_TO_SLICE = {
17 | 1: slice1d,
18 | 2: slice2d,
19 | 3: slice3d,
20 | }
21 |
22 |
23 | class StartRecentKVCache:
24 | def __init__(
25 | self,
26 | start_size=4,
27 | recent_size=512,
28 | k_seq_dim=2,
29 | v_seq_dim=2,
30 | ):
31 | print(f"StartRecentKVCache: {start_size}, {recent_size}")
32 | self.start_size = start_size
33 | self.recent_size = recent_size
34 | self.cache_size = start_size + recent_size
35 | self.k_seq_dim = k_seq_dim
36 | self.v_seq_dim = v_seq_dim
37 | self.k_slice = DIM_TO_SLICE[k_seq_dim]
38 | self.v_slice = DIM_TO_SLICE[v_seq_dim]
39 |
40 | def __call__(self, past_key_values):
41 | if past_key_values is None:
42 | return None
43 | seq_len = past_key_values[0][0].size(self.k_seq_dim)
44 | if seq_len <= self.cache_size:
45 | return past_key_values
46 | return [
47 | [
48 | torch.cat(
49 | [
50 | self.k_slice(k, 0, self.start_size),
51 | self.k_slice(k, seq_len - self.recent_size, seq_len),
52 | ],
53 | dim=self.k_seq_dim,
54 | ),
55 | torch.cat(
56 | [
57 | self.v_slice(v, 0, self.start_size),
58 | self.v_slice(v, seq_len - self.recent_size, seq_len),
59 | ],
60 | dim=self.v_seq_dim,
61 | ),
62 | ]
63 | for k, v in past_key_values
64 | ]
65 |
66 | def evict_for_space(self, past_key_values, num_coming):
67 | if past_key_values is None:
68 | return None
69 | seq_len = past_key_values[0][0].size(self.k_seq_dim)
70 | if seq_len + num_coming <= self.cache_size:
71 | return past_key_values
72 | return [
73 | [
74 | torch.cat(
75 | [
76 | self.k_slice(k, 0, self.start_size),
77 | self.k_slice(
78 | k, seq_len - self.recent_size + num_coming, seq_len
79 | ),
80 | ],
81 | dim=self.k_seq_dim,
82 | ),
83 | torch.cat(
84 | [
85 | self.v_slice(v, 0, self.start_size),
86 | self.v_slice(
87 | v, seq_len - self.recent_size + num_coming, seq_len
88 | ),
89 | ],
90 | dim=self.v_seq_dim,
91 | ),
92 | ]
93 | for k, v in past_key_values
94 | ]
95 |
96 | def evict_range(self, past_key_values, start, end):
97 | if past_key_values is None:
98 | return None
99 | seq_len = past_key_values[0][0].size(self.k_seq_dim)
100 | assert start <= end and end <= seq_len
101 | return [
102 | [
103 | torch.cat(
104 | [
105 | self.k_slice(k, 0, start),
106 | self.k_slice(k, end, seq_len),
107 | ],
108 | dim=self.k_seq_dim,
109 | ),
110 | torch.cat(
111 | [
112 | self.v_slice(v, 0, start),
113 | self.v_slice(v, end, seq_len),
114 | ],
115 | dim=self.v_seq_dim,
116 | ),
117 | ]
118 | for k, v in past_key_values
119 | ]
120 |
--------------------------------------------------------------------------------
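A small self-contained sketch of what StartRecentKVCache keeps, using fake single-head tensors in the llama layout [batch, heads, seq, head_dim] (k_seq_dim = v_seq_dim = 2); the sizes are chosen only for illustration.

import torch
from streaming_llm.kv_cache import StartRecentKVCache

cache = StartRecentKVCache(start_size=2, recent_size=4, k_seq_dim=2, v_seq_dim=2)

# Fake cache for one layer: 10 tokens, one head, head_dim=1, values = positions.
k = torch.arange(10.0).view(1, 1, 10, 1)
v = k.clone()
past_key_values = [(k, v)]

trimmed = cache(past_key_values)
print(trimmed[0][0].squeeze())  # tensor([0., 1., 6., 7., 8., 9.]) - sinks + recent window

# evict_for_space additionally leaves room for `num_coming` incoming tokens.
roomy = cache.evict_for_space(past_key_values, num_coming=2)
print(roomy[0][0].squeeze())    # tensor([0., 1., 8., 9.])
--------------------------------------------------------------------------------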
/streaming_llm/pos_shift/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/streaming_llm/pos_shift/__init__.py
--------------------------------------------------------------------------------
/streaming_llm/pos_shift/modify_falcon.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 | import torch.utils.checkpoint
7 |
8 | import torch.nn.functional as F
9 |
10 | from transformers.models.falcon.modeling_falcon import (
11 | FalconAttention,
12 | rotate_half,
13 | )
14 | import types
15 |
16 | __all__ = ["enable_falcon_pos_shift_attention"]
17 |
18 |
19 | def falcon_pos_shift_attention_forward(
20 | self,
21 | hidden_states: torch.Tensor,
22 | alibi: torch.Tensor,
23 | attention_mask: torch.Tensor,
24 | layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
25 | head_mask: Optional[torch.Tensor] = None,
26 | use_cache: bool = False,
27 | output_attentions: bool = False,
28 | ):
29 | fused_qkv = self.query_key_value(
30 | hidden_states
31 | ) # [batch_size, seq_length, 3 x hidden_size]
32 |
33 | # 3 x [batch_size, seq_length, num_heads, head_dim]
34 | (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)
35 |
36 | batch_size, q_length, _, _ = query_layer.shape
37 |
38 | query_layer = query_layer.transpose(1, 2).reshape(
39 | batch_size * self.num_heads, q_length, self.head_dim
40 | )
41 |
42 | # dirty hack to fix the inconsistency between falcon-40b and falcon-7b
43 | num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
44 | key_layer = key_layer.transpose(1, 2).reshape(
45 | batch_size * num_kv,
46 | q_length,
47 | self.head_dim,
48 | )
49 | value_layer = value_layer.transpose(1, 2).reshape(
50 | batch_size * num_kv, q_length, self.head_dim
51 | )
52 |
53 | past_len = 0
54 | if layer_past is not None:
55 | past_len = layer_past[0].shape[1]
56 |
57 | query_layer_copy = query_layer.clone()
58 | query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
59 | if layer_past is not None:
60 | past_key, past_value = layer_past
61 | # concatenate along seq_length dimension:
62 | # - key: [batch_size * self.num_heads, head_dim, kv_length]
63 | # - value: [batch_size * self.num_heads, kv_length, head_dim]
64 | key_layer = torch.cat((past_key, key_layer), dim=1)
65 | value_layer = torch.cat((past_value, value_layer), dim=1)
66 |
67 | if use_cache is True:
68 | present = (key_layer, value_layer)
69 | else:
70 | present = None
71 |
72 | key_layer_copy = key_layer.clone()
73 | _, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)
74 |
75 | _, kv_length, _ = key_layer.shape
76 |
77 | if alibi is None:
78 | query_layer_ = query_layer.reshape(
79 | batch_size, self.num_heads, -1, self.head_dim
80 | )
81 | key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
82 | value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)
83 |
84 | if layer_past is not None:
85 | attn_output = F.scaled_dot_product_attention(
86 | query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
87 | )
88 | else:
89 | attn_output = F.scaled_dot_product_attention(
90 | query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
91 | )
92 |
93 | x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
94 | x = x.permute(0, 2, 1, 3)
95 | attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)
96 |
97 | output_tensor = self.dense(attn_output)
98 |
99 | outputs = (output_tensor, present)
100 | assert not output_attentions # not supported.
101 | return outputs
102 | else:
103 | attention_mask_float = (
104 | (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
105 | )
106 | matmul_result = query_layer @ key_layer.transpose(-1, -2)
107 |
108 | # change view to [batch_size, num_heads, q_length, kv_length]
109 | attention_scores = matmul_result.view(
110 | batch_size, self.num_heads, q_length, kv_length
111 | )
112 |
113 | # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
114 | input_dtype = attention_scores.dtype
115 | # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
116 | if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
117 | attention_scores = attention_scores.to(torch.float32)
118 | # attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
119 | attention_probs = F.softmax(
120 | (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
121 | * self.inv_norm_factor
122 | + attention_mask_float,
123 | dim=-1,
124 | dtype=hidden_states.dtype,
125 | )
126 | # [batch_size, num_heads, q_length, kv_length]
127 | attention_probs = self.attention_dropout(attention_probs)
128 |
129 | if head_mask is not None:
130 | attention_probs = attention_probs * head_mask
131 |
132 | # change view [batch_size x num_heads, q_length, kv_length]
133 | attention_probs_reshaped = attention_probs.view(
134 | batch_size * self.num_heads, q_length, kv_length
135 | )
136 |
137 | # matmul: [batch_size * num_heads, q_length, head_dim]
138 | context_layer = attention_probs_reshaped @ value_layer
139 |
140 | # change view [batch_size, num_heads, q_length, head_dim]
141 | context_layer = self._merge_heads(context_layer)
142 |
143 | output_tensor = self.dense(context_layer)
144 |
145 | outputs = (output_tensor, present)
146 | if output_attentions:
147 | outputs += (attention_probs,)
148 |
149 | return outputs
150 |
151 |
152 | def enable_falcon_pos_shift_attention(model):
153 | for name, module in reversed(model._modules.items()):
154 | if len(list(module.children())) > 0:
155 | enable_falcon_pos_shift_attention(
156 | module,
157 | )
158 |
159 |         if name.endswith("self_attention"):
160 | model._modules[name].forward = types.MethodType(
161 | falcon_pos_shift_attention_forward, model._modules[name]
162 | )
163 |
--------------------------------------------------------------------------------
/streaming_llm/pos_shift/modify_gpt_neox.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 | import torch.utils.checkpoint
7 |
8 | import torch.nn.functional as F
9 |
10 | from transformers.models.gpt_neox.modeling_gpt_neox import (
11 | apply_rotary_pos_emb,
12 | rotate_half,
13 | GPTNeoXAttention,
14 | )
15 | import types
16 |
17 | __all__ = ["enable_gpt_neox_pos_shift_attention"]
18 |
19 |
20 | def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
21 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
22 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
23 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
24 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
25 | x_embed = (x * cos) + (rotate_half(x) * sin)
26 | return x_embed
27 |
28 |
29 | def gpt_neox_pos_shift_attention_forward(
30 | self,
31 | hidden_states: torch.FloatTensor,
32 | attention_mask: torch.FloatTensor,
33 | position_ids: torch.LongTensor,
34 | head_mask: Optional[torch.FloatTensor] = None,
35 | layer_past: Optional[Tuple[torch.Tensor]] = None,
36 | use_cache: Optional[bool] = False,
37 | output_attentions: Optional[bool] = False,
38 | ):
39 | has_layer_past = layer_past is not None
40 |
41 | # Compute QKV
42 | # Attention heads [batch, seq_len, hidden_size]
43 | # --> [batch, seq_len, (np * 3 * head_size)]
44 | qkv = self.query_key_value(hidden_states)
45 |
46 | # [batch, seq_len, (num_heads * 3 * head_size)]
47 | # --> [batch, seq_len, num_heads, 3 * head_size]
48 | new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
49 | qkv = qkv.view(*new_qkv_shape)
50 |
51 | # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
52 | query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
53 | key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
54 | value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)
55 |
56 | # Compute rotary embeddings on rotary_ndims
57 | query_rot = query[..., : self.rotary_ndims]
58 | query_pass = query[..., self.rotary_ndims :]
59 |
60 | # Compute token offset for rotary embeddings (when decoding)
61 | seq_len = key.shape[-2]
62 | if has_layer_past:
63 | seq_len += layer_past[0].shape[-2]
64 | cos, sin = self.rotary_emb(value, seq_len=seq_len)
65 | query = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
66 | query = torch.cat((query, query_pass), dim=-1)
67 |
68 | # Cache QKV values
69 | if has_layer_past:
70 | past_key = layer_past[0]
71 | past_value = layer_past[1]
72 | key = torch.cat((past_key, key), dim=-2)
73 | value = torch.cat((past_value, value), dim=-2)
74 |
75 | present = (key, value) if use_cache else None
76 |
77 | key_rot = key[..., : self.rotary_ndims]
78 | key_pass = key[..., self.rotary_ndims :]
79 | key_position_ids = torch.arange(seq_len, device=position_ids.device).unsqueeze(0)
80 | key = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
81 | key = torch.cat((key, key_pass), dim=-1)
82 |
83 | # Compute attention
84 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
85 |
86 | # Reshape outputs
87 | attn_output = self._merge_heads(
88 | attn_output, self.num_attention_heads, self.head_size
89 | )
90 | attn_output = self.dense(attn_output)
91 |
92 | outputs = (attn_output, present)
93 | if output_attentions:
94 | outputs += (attn_weights,)
95 |
96 | return outputs
97 |
98 |
99 | def enable_gpt_neox_pos_shift_attention(model):
100 | for name, module in reversed(model._modules.items()):
101 | if len(list(module.children())) > 0:
102 | enable_gpt_neox_pos_shift_attention(
103 | module,
104 | )
105 |
106 | if isinstance(module, GPTNeoXAttention):
107 | module.forward = types.MethodType(
108 | gpt_neox_pos_shift_attention_forward, module
109 | )
110 |
--------------------------------------------------------------------------------
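A minimal sketch of applying the GPT-NeoX patch directly (this is what enable_streaming_llm does for gpt_neox models); the Pythia checkpoint is an assumption, and the falcon module above is used in the same way.

from transformers import AutoModelForCausalLM
from streaming_llm.pos_shift.modify_gpt_neox import enable_gpt_neox_pos_shift_attention

model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-1.4b-deduped")
enable_gpt_neox_pos_shift_attention(model)  # patches every GPTNeoXAttention forward in place
--------------------------------------------------------------------------------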
/streaming_llm/pos_shift/modify_llama.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Optional, Tuple
3 |
4 | import torch
5 | from torch import nn
6 | import torch.utils.checkpoint
7 |
8 | import torch.nn.functional as F
9 |
10 | from transformers.models.llama.modeling_llama import (
11 | LlamaAttention,
12 | rotate_half,
13 | apply_rotary_pos_emb,
14 | repeat_kv,
15 | )
16 | import types
17 | import transformers
18 | from einops import rearrange
19 | from flash_attn import __version__ as flash_attn_version
20 | from flash_attn.bert_padding import pad_input, unpad_input
21 | from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
22 |
23 | __all__ = ["enable_llama_pos_shift_attention"]
24 |
25 |
26 | def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
27 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
28 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
29 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
30 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
31 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim]
32 | x_embed = (x * cos) + (rotate_half(x) * sin)
33 | return x_embed
34 |
35 |
36 | def llama_pos_shift_attention_forward(
37 | self,
38 | hidden_states: torch.Tensor,
39 | attention_mask: Optional[torch.Tensor] = None,
40 | position_ids: Optional[torch.LongTensor] = None,
41 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
42 | output_attentions: bool = False,
43 | use_cache: bool = False,
44 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
45 | bsz, q_len, _ = hidden_states.size()
46 |
47 | if self.config.pretraining_tp > 1:
48 | key_value_slicing = (
49 | self.num_key_value_heads * self.head_dim
50 | ) // self.config.pretraining_tp
51 | query_slices = self.q_proj.weight.split(
52 | (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
53 | )
54 | key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
55 | value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
56 |
57 | query_states = [
58 | F.linear(hidden_states, query_slices[i])
59 | for i in range(self.config.pretraining_tp)
60 | ]
61 | query_states = torch.cat(query_states, dim=-1)
62 |
63 | key_states = [
64 | F.linear(hidden_states, key_slices[i])
65 | for i in range(self.config.pretraining_tp)
66 | ]
67 | key_states = torch.cat(key_states, dim=-1)
68 |
69 | value_states = [
70 | F.linear(hidden_states, value_slices[i])
71 | for i in range(self.config.pretraining_tp)
72 | ]
73 | value_states = torch.cat(value_states, dim=-1)
74 |
75 | else:
76 | query_states = self.q_proj(hidden_states)
77 | key_states = self.k_proj(hidden_states)
78 | value_states = self.v_proj(hidden_states)
79 |
80 | query_states = query_states.view(
81 | bsz, q_len, self.num_heads, self.head_dim
82 | ).transpose(1, 2)
83 | key_states = key_states.view(
84 | bsz, q_len, self.num_key_value_heads, self.head_dim
85 | ).transpose(1, 2)
86 | value_states = value_states.view(
87 | bsz, q_len, self.num_key_value_heads, self.head_dim
88 | ).transpose(1, 2)
89 |
90 | kv_seq_len = key_states.shape[-2]
91 | if past_key_value is not None:
92 | kv_seq_len += past_key_value[0].shape[-2]
93 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
94 | ### Shift Pos: query pos is min(cache_size, idx)
95 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
96 | query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
97 | ###
98 |
99 | if past_key_value is not None:
100 | # reuse k, v, self_attention
101 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
102 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
103 |
104 | past_key_value = (key_states, value_states) if use_cache else None
105 |
106 | ### Shift Pos: key pos is the pos in cache
107 | key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
108 | key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
109 | ###
110 |
111 | # repeat k/v heads if n_kv_heads < n_heads
112 | key_states = repeat_kv(key_states, self.num_key_value_groups)
113 | value_states = repeat_kv(value_states, self.num_key_value_groups)
114 |
115 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
116 | self.head_dim
117 | )
118 |
119 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
120 | raise ValueError(
121 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
122 | f" {attn_weights.size()}"
123 | )
124 |
125 | if attention_mask is not None:
126 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
127 | raise ValueError(
128 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
129 | )
130 | attn_weights = attn_weights + attention_mask
131 |
132 | # upcast attention to fp32
133 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
134 | query_states.dtype
135 | )
136 | attn_output = torch.matmul(attn_weights, value_states)
137 |
138 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
139 | raise ValueError(
140 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
141 | f" {attn_output.size()}"
142 | )
143 |
144 | attn_output = attn_output.transpose(1, 2).contiguous()
145 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
146 |
147 | if self.config.pretraining_tp > 1:
148 | attn_output = attn_output.split(
149 | self.hidden_size // self.config.pretraining_tp, dim=2
150 | )
151 | o_proj_slices = self.o_proj.weight.split(
152 | self.hidden_size // self.config.pretraining_tp, dim=1
153 | )
154 | attn_output = sum(
155 | [
156 | F.linear(attn_output[i], o_proj_slices[i])
157 | for i in range(self.config.pretraining_tp)
158 | ]
159 | )
160 | else:
161 | attn_output = self.o_proj(attn_output)
162 |
163 | if not output_attentions:
164 | attn_weights = None
165 |
166 | return attn_output, attn_weights, past_key_value
167 |
168 |
169 | def llama_pos_shift_attention_forward_flashattn(
170 | self,
171 | hidden_states: torch.Tensor,
172 | attention_mask: Optional[torch.Tensor] = None,
173 | position_ids: Optional[torch.LongTensor] = None,
174 | past_key_value: Optional[Tuple[torch.Tensor]] = None,
175 | output_attentions: bool = False,
176 | use_cache: bool = False,
177 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
178 | bsz, q_len, _ = hidden_states.size()
179 |
180 | query_states = self.q_proj(hidden_states)
181 | key_states = self.k_proj(hidden_states)
182 | value_states = self.v_proj(hidden_states)
183 |
184 | query_states = query_states.view(
185 | bsz, q_len, self.num_heads, self.head_dim
186 | ).transpose(1, 2)
187 | key_states = key_states.view(
188 | bsz, q_len, self.num_key_value_heads, self.head_dim
189 | ).transpose(1, 2)
190 | value_states = value_states.view(
191 | bsz, q_len, self.num_key_value_heads, self.head_dim
192 | ).transpose(1, 2)
193 |
194 | kv_seq_len = key_states.shape[-2]
195 | if past_key_value is not None:
196 | kv_seq_len += past_key_value[0].shape[-2]
197 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
198 | ### Shift Pos: query pos is min(cache_size, idx)
199 | # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
200 | query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
201 | ###
202 |
203 | if past_key_value is not None:
204 | # reuse k, v, self_attention
205 | key_states = torch.cat([past_key_value[0], key_states], dim=2)
206 | value_states = torch.cat([past_key_value[1], value_states], dim=2)
207 |
208 | past_key_value = (key_states, value_states) if use_cache else None
209 |
210 | ### Shift Pos: key pos is the pos in cache
211 | key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
212 | key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
213 | ###
214 |
215 | # repeat k/v heads if n_kv_heads < n_heads
216 | key_states = repeat_kv(key_states, self.num_key_value_groups)
217 | value_states = repeat_kv(value_states, self.num_key_value_groups)
218 |
219 | if past_key_value is None:
220 | qkv = torch.stack(
221 | [query_states, key_states, value_states], dim=2
222 | ) # [bsz, nh, 3, q_len, hd]
223 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
224 |
225 |         key_padding_mask = torch.full((bsz, q_len), True, dtype=torch.bool, device=hidden_states.device)
226 | nheads = qkv.shape[-2]
227 | x = rearrange(qkv, "b s three h d -> b s (three h d)")
228 | x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
229 | x_unpad = rearrange(
230 | x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
231 | )
232 | output_unpad = flash_attn_varlen_qkvpacked_func(
233 | x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
234 | )
235 | output = rearrange(
236 | pad_input(
237 | rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
238 | ),
239 | "b s (h d) -> b s h d",
240 | h=nheads,
241 | )
242 | output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)
243 |
244 | attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
245 | attn_weights = None
246 | else:
247 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
248 | self.head_dim
249 | )
250 |
251 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
252 | raise ValueError(
253 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
254 | f" {attn_weights.size()}"
255 | )
256 |
257 | if attention_mask is not None:
258 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
259 | raise ValueError(
260 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
261 | )
262 | attn_weights = attn_weights + attention_mask
263 |
264 | # upcast attention to fp32
265 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
266 | query_states.dtype
267 | )
268 | attn_output = torch.matmul(attn_weights, value_states)
269 |
270 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
271 | raise ValueError(
272 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
273 | f" {attn_output.size()}"
274 | )
275 |
276 | attn_output = attn_output.transpose(1, 2).contiguous()
277 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
278 |
279 | if self.config.pretraining_tp > 1:
280 | attn_output = attn_output.split(
281 | self.hidden_size // self.config.pretraining_tp, dim=2
282 | )
283 | o_proj_slices = self.o_proj.weight.split(
284 | self.hidden_size // self.config.pretraining_tp, dim=1
285 | )
286 | attn_output = sum(
287 | [
288 | F.linear(attn_output[i], o_proj_slices[i])
289 | for i in range(self.config.pretraining_tp)
290 | ]
291 | )
292 | else:
293 | attn_output = self.o_proj(attn_output)
294 |
295 | if not output_attentions:
296 | attn_weights = None
297 |
298 | return attn_output, attn_weights, past_key_value
299 |
300 |
301 | def enable_llama_pos_shift_attention(model, use_flash_attn=True):
302 | for name, module in reversed(model._modules.items()):
303 | if len(list(module.children())) > 0:
304 | enable_llama_pos_shift_attention(
305 |                 module, use_flash_attn,
306 | )
307 |
308 | if isinstance(module, LlamaAttention):
309 | model._modules[name].forward = types.MethodType(
310 | llama_pos_shift_attention_forward_flashattn if use_flash_attn else llama_pos_shift_attention_forward, model._modules[name]
311 | )
312 |
--------------------------------------------------------------------------------
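A minimal sketch of patching a Hugging Face llama checkpoint with the position-shift attention above and running one forward pass; the model name is an assumption, and use_flash_attn=False selects the pure-PyTorch forward, the safer choice when flash-attn is not built for the current GPU.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from streaming_llm.pos_shift.modify_llama import enable_llama_pos_shift_attention

model_name = "meta-llama/Llama-2-7b-hf"  # assumed checkpoint
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float16, device_map="auto"
)
enable_llama_pos_shift_attention(model, use_flash_attn=False)

# Queries keep their cache-relative positions while keys are always re-embedded
# with positions 0..kv_len-1, so evicting old entries leaves no rotary "holes".
tokenizer = AutoTokenizer.from_pretrained(model_name)
ids = tokenizer("A quick smoke test.", return_tensors="pt").input_ids.to(model.device)
with torch.no_grad():
    out = model(ids, use_cache=True)
print(out.logits.shape)  # [1, seq_len, vocab_size]
--------------------------------------------------------------------------------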
/streaming_llm/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | from transformers import (
4 | AutoTokenizer,
5 | AutoModelForCausalLM,
6 | )
7 | import os.path as osp
8 | import ssl
9 | import urllib.request
10 | import os
11 | import json
12 |
13 |
14 | def parse_args():
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument(
17 | "--model_name_or_path", type=str, default="models/llama/llama-7b"
18 | )
19 | parser.add_argument("--revision", type=str, default="main")
20 | parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
21 | parser.add_argument("--dataset_name", type=str, default="wikitext")
22 |
23 | parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
24 | parser.add_argument(
25 | "--split", type=str, default="test", choices=["validation", "test"]
26 | )
27 |
28 | parser.add_argument(
29 | "--num_samples",
30 | type=int,
31 | default=1,
32 | )
33 |
34 | parser.add_argument(
35 | "--output_dir",
36 | type=str,
37 | default="outputs/debug",
38 | )
39 |
40 | parser.add_argument("--enable_start_recent_kv_cache", action="store_true")
41 | parser.add_argument("--start_size", type=int, default=1)
42 | parser.add_argument("--recent_size", type=int, default=255)
43 | parser.add_argument("--enable_pos_shift", action="store_true")
44 |
45 | parser.add_argument("--num_eval_tokens", type=int, default=None)
46 |
47 | args = parser.parse_args()
48 | return args
49 |
50 |
51 | def load(model_name_or_path):
52 | print(f"Loading model from {model_name_or_path} ...")
53 |     # note: running falcon with tensor parallelism can trigger bugs, so device_map="auto" is used below
54 | tokenizer = AutoTokenizer.from_pretrained(
55 | model_name_or_path,
56 | trust_remote_code=True,
57 | )
58 | model = AutoModelForCausalLM.from_pretrained(
59 | model_name_or_path,
60 | device_map="auto",
61 | torch_dtype=torch.float16,
62 | trust_remote_code=True,
63 | )
64 | if tokenizer.pad_token_id is None:
65 | if tokenizer.eos_token_id is not None:
66 | tokenizer.pad_token_id = tokenizer.eos_token_id
67 | else:
68 | tokenizer.pad_token_id = 0
69 |
70 | model.eval()
71 |
72 | return model, tokenizer
73 |
74 |
75 | def download_url(url: str, folder="folder"):
76 | """
77 | Downloads the content of an url to a folder. Modified from \
78 | https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric
79 |
80 | Args:
81 | url (string): The url of target file.
82 | folder (string): The target folder.
83 |
84 | Returns:
85 | string: File path of downloaded files.
86 | """
87 |
88 | file = url.rpartition("/")[2]
89 | file = file if file[0] == "?" else file.split("?")[0]
90 | path = osp.join(folder, file)
91 | if osp.exists(path):
92 | print(f"File {file} exists, use existing file.")
93 | return path
94 |
95 | print(f"Downloading {url}")
96 | os.makedirs(folder, exist_ok=True)
97 | ctx = ssl._create_unverified_context()
98 | data = urllib.request.urlopen(url, context=ctx)
99 | with open(path, "wb") as f:
100 | f.write(data.read())
101 |
102 | return path
103 |
104 |
105 | def load_jsonl(
106 | file_path,
107 | ):
108 | list_data_dict = []
109 | with open(file_path, "r") as f:
110 | for line in f:
111 | list_data_dict.append(json.loads(line))
112 | return list_data_dict
113 |
--------------------------------------------------------------------------------
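A short sketch combining the helpers above into an evaluation entry point; the URL is a placeholder assumption and the flags come from parse_args.

from streaming_llm.utils import parse_args, load, download_url, load_jsonl

args = parse_args()
model, tokenizer = load(args.model_name_or_path)

# Fetch a jsonl test file into ./data (placeholder URL, replace with a real one).
data_path = download_url("https://example.com/test_data.jsonl", folder="data")
samples = load_jsonl(data_path)
print(f"loaded {len(samples)} samples; first record keys: {list(samples[0].keys())}")
--------------------------------------------------------------------------------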
/supervised-fine-tune-qlora.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Some code based on https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import io
17 | import os
18 | import copy
19 | import json
20 | import math
21 | import logging
22 | from dataclasses import dataclass, field
23 | from typing import Dict, Optional, Sequence
24 |
25 | import torch
26 | import torch.nn as nn
27 | import transformers
28 | from torch.utils.data import Dataset
29 | from transformers import Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
30 | from llama_attn_replace_sft import replace_llama_attn
31 | from gptneox_attn_replace import replace_gpt_neox_attn
32 | from peft import LoraConfig, get_peft_model
33 | from torch.distributed import barrier
34 |
35 | IGNORE_INDEX = -100
36 | DEFAULT_PAD_TOKEN = "[PAD]"
37 | DEFAULT_EOS_TOKEN = "</s>"
38 | DEFAULT_BOS_TOKEN = "<s>"
39 | DEFAULT_UNK_TOKEN = "<unk>"
40 |
41 | def _make_r_io_base(f, mode: str):
42 | if not isinstance(f, io.IOBase):
43 | f = open(f, mode=mode)
44 | return f
45 |
46 | def jload(f, mode="r"):
47 | """Load a .json file into a dictionary."""
48 | f = _make_r_io_base(f, mode)
49 | jdict = json.load(f)
50 | f.close()
51 | return jdict
52 |
53 | PROMPT_DICT = {
54 | "prompt_input": (
55 | "Below is an instruction that describes a task, paired with an input that provides further context. "
56 | "Write a response that appropriately completes the request.\n\n"
57 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
58 | ),
59 | "prompt_no_input": (
60 | "Below is an instruction that describes a task. "
61 | "Write a response that appropriately completes the request.\n\n"
62 | "### Instruction:\n{instruction}\n\n### Response:"
63 | ),
64 | "prompt_no_input_llama2":(
65 |         "[INST] <<SYS>>\n"
66 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
67 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
68 |         "<</SYS>> \n\n {instruction} [/INST]"
69 | ),
70 | "prompt_input_llama2": (
71 |         "[INST] <<SYS>>\n"
72 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
73 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
74 |         "<</SYS>> \n\n {instruction} \n{input} [/INST]"
75 | ),
76 | "prompt_llama2": "[INST]{instruction}[/INST]"
77 | }
78 |
79 |
80 | @dataclass
81 | class ModelArguments:
82 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
83 | model_type: Optional[str] = field(default="llama")
84 |
85 |
86 | @dataclass
87 | class DataArguments:
88 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
89 |
90 |
91 | @dataclass
92 | class TrainingArguments(transformers.TrainingArguments):
93 | cache_dir: Optional[str] = field(default=None)
94 | optim: str = field(default="adamw_torch")
95 | model_max_length: int = field(
96 | default=8192 * 4,
97 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
98 | )
99 | use_flash_attn: bool = field(
100 | default=True,
101 | metadata={"help": "Whether use flash attention for training."},
102 | )
103 | use_full_attn: bool = field(
104 | default=False,
105 | metadata={"help": "Whether to use plain, full-attention for training."},
106 | )
107 | low_rank_training: bool = field(
108 | default=True,
109 | metadata={"help": "Whether use low rank adaptation for training."},
110 | )
111 | trainable_params: str = field(
112 | default="embed,norm",
113 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
114 | )
115 |
116 | def smart_tokenizer_and_embedding_resize(
117 | special_tokens_dict: Dict,
118 | tokenizer: transformers.PreTrainedTokenizer,
119 | model: transformers.PreTrainedModel,
120 | ):
121 | """Resize tokenizer and embedding.
122 |
123 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
124 | """
125 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
126 | model.resize_token_embeddings(len(tokenizer))
127 |
128 | if num_new_tokens > 0:
129 | input_embeddings = model.get_input_embeddings().weight.data
130 | output_embeddings = model.get_output_embeddings().weight.data
131 |
132 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
133 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
134 |
135 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
136 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
137 |
138 |
139 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
140 | """Tokenize a list of strings."""
141 | tokenized_list = [
142 | tokenizer(
143 | text,
144 | return_tensors="pt",
145 | padding="longest",
146 | max_length=tokenizer.model_max_length,
147 | truncation=True,
148 | )
149 | for text in strings
150 | ]
151 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
152 | input_ids_lens = labels_lens = [
153 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
154 | ]
155 | return dict(
156 | input_ids=input_ids,
157 | labels=labels,
158 | input_ids_lens=input_ids_lens,
159 | labels_lens=labels_lens,
160 | )
161 |
162 |
163 | def preprocess(
164 | sources: Sequence[str],
165 | targets: Sequence[str],
166 | tokenizer: transformers.PreTrainedTokenizer,
167 | ) -> Dict:
168 | """Preprocess the data by tokenizing."""
169 | examples = [s + t for s, t in zip(sources, targets)]
170 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
171 | input_ids = examples_tokenized["input_ids"]
172 | labels = copy.deepcopy(input_ids)
173 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
174 | label[:source_len] = IGNORE_INDEX
175 | return dict(input_ids=input_ids, labels=labels)
176 |
177 |
178 | class SupervisedDataset(Dataset):
179 | """Dataset for supervised fine-tuning."""
180 |
181 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
182 | super(SupervisedDataset, self).__init__()
183 | logging.warning("Loading data...")
184 | list_data_dict = jload(data_path)
185 |
186 | logging.warning("Formatting inputs...")
187 |
188 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_llama2"]
189 | sources = [
190 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
191 | for example in list_data_dict
192 | ]
193 |
194 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
195 |
196 | logging.warning("Tokenizing inputs... This may take some time...")
197 | data_dict = preprocess(sources, targets, tokenizer)
198 |
199 | self.input_ids = data_dict["input_ids"]
200 | self.labels = data_dict["labels"]
201 |
202 | def __len__(self):
203 | return len(self.input_ids)
204 |
205 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
206 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
207 |
208 |
209 | @dataclass
210 | class DataCollatorForSupervisedDataset(object):
211 | """Collate examples for supervised fine-tuning."""
212 |
213 | tokenizer: transformers.PreTrainedTokenizer
214 |
215 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
216 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
217 | input_ids = torch.nn.utils.rnn.pad_sequence(
218 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
219 | )
220 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
221 | return dict(
222 | input_ids=input_ids,
223 | labels=labels,
224 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
225 | )
226 |
227 |
228 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
229 | """Make dataset and collator for supervised fine-tuning."""
230 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
231 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
232 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
233 |
234 |
235 | def train():
236 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
237 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
238 |
239 | # NOTE: May expand supported model types in the future
240 | if model_args.model_type == "gpt-neox":
241 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
242 | else:
243 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
244 |
245 | # Set RoPE scaling factor
246 | config = transformers.AutoConfig.from_pretrained(
247 | model_args.model_name_or_path,
248 | cache_dir=training_args.cache_dir,
249 | )
250 |
251 | orig_rope_scaling = getattr(config, "rope_scaling", None)
252 | if orig_rope_scaling is None:
253 | orig_rope_scaling = {"factor": 1}
254 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
255 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
256 | if orig_ctx_len:
257 | orig_ctx_len *= orig_rope_scaling_factor
258 | if training_args.model_max_length > orig_ctx_len:
259 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
260 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
261 |
262 | # Load model and tokenizer
263 | model = transformers.AutoModelForCausalLM.from_pretrained(
264 | model_args.model_name_or_path,
265 | config=config,
266 | cache_dir=training_args.cache_dir,
267 | torch_dtype=torch.bfloat16,
268 | quantization_config=BitsAndBytesConfig(
269 | load_in_4bit=True,
270 | llm_int8_threshold=6.0,
271 | llm_int8_has_fp16_weight=False,
272 | bnb_4bit_compute_dtype=torch.bfloat16,
273 | bnb_4bit_use_double_quant=True,
274 | bnb_4bit_quant_type="nf4",
275 | ),
276 | )
277 |
278 | for param in model.parameters():
279 | param.requires_grad = False # freeze the model - train adapters later
280 | if param.ndim == 1:
281 | # cast the small parameters (e.g. layernorm) to fp32 for stability
282 | param.data = param.data.to(torch.float32)
283 |
284 | tokenizer = transformers.AutoTokenizer.from_pretrained(
285 | model_args.model_name_or_path,
286 | cache_dir=training_args.cache_dir,
287 | model_max_length=training_args.model_max_length,
288 | padding_side="right",
289 | use_fast=True,
290 | )
291 |
292 | special_tokens_dict = dict()
293 | if tokenizer.pad_token is None:
294 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
295 | if tokenizer.eos_token is None:
296 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
297 | if tokenizer.bos_token is None:
298 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
299 | if tokenizer.unk_token is None:
300 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
301 |
302 | smart_tokenizer_and_embedding_resize(
303 | special_tokens_dict=special_tokens_dict,
304 | tokenizer=tokenizer,
305 | model=model,
306 | )
307 |
308 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
309 |
310 | if training_args.low_rank_training:
311 | if model_args.model_type == "gpt-neox":
312 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
313 | targets = ["query_key_value", "dense"]
314 | else:
315 | targets=["q_proj", "k_proj", "v_proj", "o_proj"]
316 |
317 | config = LoraConfig(
318 | r=8,
319 | lora_alpha=16,
320 | target_modules=targets,
321 | lora_dropout=0,
322 | bias="none",
323 | task_type="CAUSAL_LM",
324 | )
325 | model = get_peft_model(model, config)
326 | # enable trainable params
327 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
328 |
329 | class CastOutputToFloat(nn.Sequential):
330 | def forward(self, x):
331 | return super().forward(x).to(torch.float32)
332 |
333 | model.lm_head = CastOutputToFloat(model.lm_head)
334 |
335 | # Verifying the datatypes.
336 | dtypes = {}
337 | for _, p in model.named_parameters():
338 | dtype = p.dtype
339 | if dtype not in dtypes:
340 | dtypes[dtype] = 0
341 | dtypes[dtype] += p.numel()
342 | total = 0
343 | for k, v in dtypes.items():
344 | total += v
345 | for k, v in dtypes.items():
346 | print(k, v, v / total)
347 |
348 | model.config.use_cache = False # required for gradient checkpointing
349 | model.enable_input_require_grads() # required for gradient checkpointing
350 | model.gradient_checkpointing_enable() # enable gradient checkpointing
351 |
352 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
353 | trainer.train()
354 | trainer.save_state()
355 | trainer.save_model(output_dir=training_args.output_dir)
356 |
357 |
358 | if __name__ == "__main__":
359 | train()
360 |
--------------------------------------------------------------------------------
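A worked example of the RoPE scaling step in train() above: with a llama-2 base context of 4096 positions (an assumption about the checkpoint being fine-tuned) and the default model_max_length of 8192 * 4, the linear scaling factor comes out to 8.0.

import math

orig_ctx_len = 4096            # max_position_embeddings of the assumed base model
orig_rope_scaling_factor = 1   # base model ships without rope_scaling
model_max_length = 8192 * 4    # default in TrainingArguments above

orig_ctx_len *= orig_rope_scaling_factor
if model_max_length > orig_ctx_len:
    scaling_factor = float(math.ceil(model_max_length / orig_ctx_len))
    rope_scaling = {"type": "linear", "factor": scaling_factor}
    print(rope_scaling)  # {'type': 'linear', 'factor': 8.0}
--------------------------------------------------------------------------------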
/supervised-fine-tune.py:
--------------------------------------------------------------------------------
1 | # Written by Yukang Chen
2 | # Some code based on https://github.com/epfml/landmark-attention
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import io
17 | import os
18 | import copy
19 | import json
20 | import math
21 | import logging
22 | from dataclasses import dataclass, field
23 | from typing import Dict, Optional, Sequence
24 |
25 | import torch
26 | import transformers
27 | from torch.utils.data import Dataset
28 | from transformers import Trainer, DataCollatorForLanguageModeling
29 | from llama_attn_replace_sft import replace_llama_attn
30 | from gptneox_attn_replace import replace_gpt_neox_attn
31 | from peft import LoraConfig, get_peft_model
32 | from torch.distributed import barrier
33 |
34 | IGNORE_INDEX = -100
35 | DEFAULT_PAD_TOKEN = "[PAD]"
36 | DEFAULT_EOS_TOKEN = "</s>"
37 | DEFAULT_BOS_TOKEN = "<s>"
38 | DEFAULT_UNK_TOKEN = "<unk>"
39 |
40 | def _make_r_io_base(f, mode: str):
41 | if not isinstance(f, io.IOBase):
42 | f = open(f, mode=mode)
43 | return f
44 |
45 | def jload(f, mode="r"):
46 | """Load a .json file into a dictionary."""
47 | f = _make_r_io_base(f, mode)
48 | jdict = json.load(f)
49 | f.close()
50 | return jdict
51 |
52 | PROMPT_DICT = {
53 | "prompt_input": (
54 | "Below is an instruction that describes a task, paired with an input that provides further context. "
55 | "Write a response that appropriately completes the request.\n\n"
56 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
57 | ),
58 | "prompt_no_input": (
59 | "Below is an instruction that describes a task. "
60 | "Write a response that appropriately completes the request.\n\n"
61 | "### Instruction:\n{instruction}\n\n### Response:"
62 | ),
63 | "prompt_no_input_llama2":(
64 |         "[INST] <<SYS>>\n"
65 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
66 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
67 |         "<</SYS>> \n\n {instruction} [/INST]"
68 | ),
69 | "prompt_input_llama2": (
70 |         "[INST] <<SYS>>\n"
71 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n"
72 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n"
73 |         "<</SYS>> \n\n {instruction} \n{input} [/INST]"
74 | ),
75 | "prompt_llama2": "[INST]{instruction}[/INST]"
76 | }
77 |
78 |
79 | @dataclass
80 | class ModelArguments:
81 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped")
82 | model_type: Optional[str] = field(default="llama")
83 |
84 |
85 | @dataclass
86 | class DataArguments:
87 | data_path: str = field(default=None, metadata={"help": "Path to the training data."})
88 |
89 |
90 | @dataclass
91 | class TrainingArguments(transformers.TrainingArguments):
92 | cache_dir: Optional[str] = field(default=None)
93 | optim: str = field(default="adamw_torch")
94 | model_max_length: int = field(
95 | default=8192 * 4,
96 | metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
97 | )
98 | use_flash_attn: bool = field(
99 | default=True,
100 | metadata={"help": "Whether use flash attention for training."},
101 | )
102 | use_full_attn: bool = field(
103 | default=False,
104 | metadata={"help": "Whether to use plain, full-attention for training."},
105 | )
106 | low_rank_training: bool = field(
107 | default=True,
108 | metadata={"help": "Whether use low rank adaptation for training."},
109 | )
110 | trainable_params: str = field(
111 | default="embed,norm",
112 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."},
113 | )
114 |
115 | def smart_tokenizer_and_embedding_resize(
116 | special_tokens_dict: Dict,
117 | tokenizer: transformers.PreTrainedTokenizer,
118 | model: transformers.PreTrainedModel,
119 | ):
120 | """Resize tokenizer and embedding.
121 |
122 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
123 | """
124 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
125 | model.resize_token_embeddings(len(tokenizer))
126 |
127 | if num_new_tokens > 0:
128 | input_embeddings = model.get_input_embeddings().weight.data
129 | output_embeddings = model.get_output_embeddings().weight.data
130 |
131 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
132 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
133 |
134 | input_embeddings[-num_new_tokens:] = input_embeddings_avg
135 | output_embeddings[-num_new_tokens:] = output_embeddings_avg
136 |
137 |
138 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
139 | """Tokenize a list of strings."""
140 | tokenized_list = [
141 | tokenizer(
142 | text,
143 | return_tensors="pt",
144 | padding="longest",
145 | max_length=tokenizer.model_max_length,
146 | truncation=True,
147 | )
148 | for text in strings
149 | ]
150 | input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
151 | input_ids_lens = labels_lens = [
152 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
153 | ]
154 | return dict(
155 | input_ids=input_ids,
156 | labels=labels,
157 | input_ids_lens=input_ids_lens,
158 | labels_lens=labels_lens,
159 | )
160 |
161 |
162 | def preprocess(
163 | sources: Sequence[str],
164 | targets: Sequence[str],
165 | tokenizer: transformers.PreTrainedTokenizer,
166 | ) -> Dict:
167 | """Preprocess the data by tokenizing."""
168 | examples = [s + t for s, t in zip(sources, targets)]
169 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
170 | input_ids = examples_tokenized["input_ids"]
171 | labels = copy.deepcopy(input_ids)
172 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
173 | label[:source_len] = IGNORE_INDEX
174 | return dict(input_ids=input_ids, labels=labels)
175 |
176 |
177 | class SupervisedDataset(Dataset):
178 | """Dataset for supervised fine-tuning."""
179 |
180 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer):
181 | super(SupervisedDataset, self).__init__()
182 | logging.warning("Loading data...")
183 | list_data_dict = jload(data_path)
184 |
185 | logging.warning("Formatting inputs...")
186 |
187 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_llama2"]
188 | sources = [
189 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
190 | for example in list_data_dict
191 | ]
192 |
193 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
194 |
195 | logging.warning("Tokenizing inputs... This may take some time...")
196 | data_dict = preprocess(sources, targets, tokenizer)
197 |
198 | self.input_ids = data_dict["input_ids"]
199 | self.labels = data_dict["labels"]
200 |
201 | def __len__(self):
202 | return len(self.input_ids)
203 |
204 | def __getitem__(self, i) -> Dict[str, torch.Tensor]:
205 | return dict(input_ids=self.input_ids[i], labels=self.labels[i])
206 |
207 |
208 | @dataclass
209 | class DataCollatorForSupervisedDataset(object):
210 | """Collate examples for supervised fine-tuning."""
211 |
212 | tokenizer: transformers.PreTrainedTokenizer
213 |
214 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
215 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
216 | input_ids = torch.nn.utils.rnn.pad_sequence(
217 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
218 | )
219 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
220 | return dict(
221 | input_ids=input_ids,
222 | labels=labels,
223 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
224 | )
225 |
226 |
227 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
228 | """Make dataset and collator for supervised fine-tuning."""
229 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path)
230 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
231 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
232 |
233 |
234 | def train():
235 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
236 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
237 |
238 | # NOTE: May expand supported model types in the future
239 | if model_args.model_type == "gpt-neox":
240 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn)
241 | else:
242 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn)
243 |
244 | # Set RoPE scaling factor
245 | config = transformers.AutoConfig.from_pretrained(
246 | model_args.model_name_or_path,
247 | cache_dir=training_args.cache_dir,
248 | )
249 |
250 | orig_rope_scaling = getattr(config, "rope_scaling", None)
251 | if orig_rope_scaling is None:
252 | orig_rope_scaling = {"factor": 1}
253 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1
254 | orig_ctx_len = getattr(config, "max_position_embeddings", None)
255 | if orig_ctx_len:
256 | orig_ctx_len *= orig_rope_scaling_factor
257 | if training_args.model_max_length > orig_ctx_len:
258 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len))
259 | config.rope_scaling = {"type": "linear", "factor": scaling_factor}
260 |
261 | # Load model and tokenizer
262 | model = transformers.AutoModelForCausalLM.from_pretrained(
263 | model_args.model_name_or_path,
264 | config=config,
265 | cache_dir=training_args.cache_dir,
266 | torch_dtype=torch.bfloat16,
267 | )
268 |
269 | tokenizer = transformers.AutoTokenizer.from_pretrained(
270 | model_args.model_name_or_path,
271 | cache_dir=training_args.cache_dir,
272 | model_max_length=training_args.model_max_length,
273 | padding_side="right",
274 | use_fast=True,
275 | )
276 |
277 | special_tokens_dict = dict()
278 | if tokenizer.pad_token is None:
279 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN
280 | if tokenizer.eos_token is None:
281 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN
282 | if tokenizer.bos_token is None:
283 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN
284 | if tokenizer.unk_token is None:
285 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN
286 |
287 | smart_tokenizer_and_embedding_resize(
288 | special_tokens_dict=special_tokens_dict,
289 | tokenizer=tokenizer,
290 | model=model,
291 | )
292 |
293 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args)
294 |
295 | if training_args.low_rank_training:
296 | if model_args.model_type == "gpt-neox":
297 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value'
298 | targets = ["query_key_value", "dense"]
299 | else:
300 | targets=["q_proj", "k_proj", "v_proj", "o_proj"]
301 |
302 | config = LoraConfig(
303 | r=8,
304 | lora_alpha=16,
305 | target_modules=targets,
306 | lora_dropout=0,
307 | bias="none",
308 | task_type="CAUSAL_LM",
309 | )
310 | model = get_peft_model(model, config)
311 | # enable trainable params
312 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])]
313 |
314 | model.config.use_cache = False # required for gradient checkpointing
315 | model.enable_input_require_grads() # required for gradient checkpointing
316 | model.gradient_checkpointing_enable() # enable gradient checkpointing
317 |
318 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module)
319 | trainer.train()
320 | trainer.save_state()
321 | trainer.save_model(output_dir=training_args.output_dir)
322 |
323 |
324 | if __name__ == "__main__":
325 | train()
326 |
--------------------------------------------------------------------------------
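A toy sketch of what DataCollatorForSupervisedDataset produces: input_ids are padded with pad_token_id, labels with IGNORE_INDEX (-100, skipped by the loss), and the attention mask marks non-padding positions. The FakeTokenizer is an assumption used only to keep the example self-contained, and DataCollatorForSupervisedDataset is assumed to be in scope (e.g. the snippet pasted at the bottom of the script above).

import torch

IGNORE_INDEX = -100

class FakeTokenizer:
    pad_token_id = 0

# The collator defined above, fed two samples whose source tokens are already masked.
collator = DataCollatorForSupervisedDataset(tokenizer=FakeTokenizer())
batch = collator([
    {"input_ids": torch.tensor([5, 6, 7]),    "labels": torch.tensor([IGNORE_INDEX, 6, 7])},
    {"input_ids": torch.tensor([5, 6, 7, 8]), "labels": torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 7, 8])},
])
print(batch["input_ids"])       # shorter sample padded with pad_token_id (0)
print(batch["labels"])          # padding positions filled with -100
print(batch["attention_mask"])  # False exactly where input_ids == pad_token_id
--------------------------------------------------------------------------------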