├── DATA_LICENSE ├── LICENSE ├── README.md ├── WEIGHT_LICENSE ├── benchmarks ├── Pred_L-Eval │ └── llm_gpt4_eval.pred.jsonl ├── Pred_LongBench │ ├── 2wikimqa.jsonl │ ├── gov_report.jsonl │ ├── hotpotqa.jsonl │ ├── lcc.jsonl │ ├── multi_news.jsonl │ ├── multifieldqa_en.jsonl │ ├── musique.jsonl │ ├── narrativeqa.jsonl │ ├── passage_count.jsonl │ ├── passage_retrieval_en.jsonl │ ├── qasper.jsonl │ ├── qmsum.jsonl │ ├── repobench-p.jsonl │ ├── result.json │ ├── samsum.jsonl │ ├── trec.jsonl │ └── triviaqa.jsonl └── README.md ├── demo.py ├── ds_configs ├── stage2.json └── stage3.json ├── eval.py ├── eval_distributed.py ├── fine-tune.py ├── get_trainable_weights.py ├── gptneox_attn_replace.py ├── imgs ├── LongAlpaca.png ├── Shift-short-attention2.png ├── data-distribution-in-longalpaca12k.png ├── demo-compare-harrypotter.png ├── demo-compare-journeytothewest.png ├── demo-compare-threebody.png ├── economy-comparison.png ├── economy-prediction.png ├── paper-improvements.png ├── paper-review.png └── paper-style-compare-cvpr-iclr.png ├── inference-qlora.py ├── inference.py ├── llama_attn_replace.py ├── llama_attn_replace_sft.py ├── merge_lora_weights_and_save_hf_model.py ├── passkey_retrivial.py ├── pdf2txt ├── README.md ├── backbone.py ├── beit.py ├── config.py ├── configs │ ├── Base-RCNN-FPN.yaml │ └── cascade_dit_large.yaml ├── pdf2txt.py └── requirements.txt ├── requirements.txt ├── run_streaming_llama_longalpaca.py ├── streaming_llm ├── __init__.py ├── enable_streaming_llm.py ├── kv_cache.py ├── pos_shift │ ├── __init__.py │ ├── modify_falcon.py │ ├── modify_gpt_neox.py │ └── modify_llama.py └── utils.py ├── supervised-fine-tune-qlora.py └── supervised-fine-tune.py /DATA_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | 
does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. 
More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. 
Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. 
Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. 
You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. 
retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. 
for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. 
TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. 
Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. 
Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. 
For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /benchmarks/Pred_LongBench/2wikimqa.jsonl: -------------------------------------------------------------------------------- 1 | {"pred": "Izmit", "answers": ["Ozalj"], "all_classes": null, "length": 4696} 2 | {"pred": "Elizabeth", "answers": ["John the Baptist"], "all_classes": null, "length": 4776} 3 | {"pred": "Sam Spiegel Film and Television School", "answers": ["Cahiers du cinéma"], "all_classes": null, "length": 4274} 4 | {"pred": "England", "answers": ["no"], "all_classes": null, "length": 8125} 5 | {"pred": "1483", "answers": ["1510"], "all_classes": null, "length": 4621} 6 | {"pred": "Abd al-Muttalib", "answers": ["Edward Watson"], "all_classes": null, "length": 4625} 7 | {"pred": "1275", "answers": ["16 September 1360"], "all_classes": null, "length": 5001} 8 | {"pred": "Helmichis's father-in-law is Alboin.", "answers": ["Cunimund"], "all_classes": null, "length": 7639} 9 | {"pred": " Dublin", "answers": ["St Patrick's College"], "all_classes": null, "length": 3964} 10 | {"pred": "Wine of Morning director Katherine Stenholm worked at 
Unusual Films.", "answers": ["Bob Jones University"], "all_classes": null, "length": 5162} 11 | {"pred": "The House Of The Seven Hawks", "answers": ["The House Of The Seven Hawks"], "all_classes": null, "length": 10338} 12 | {"pred": "Marie Of Brabant's paternal grandmother is Marie Of Brabant.", "answers": ["Marie of Hohenstaufen"], "all_classes": null, "length": 3596} 13 | {"pred": "Adelaide, Countess of Soissons", "answers": ["Guy II, Count of Soissons"], "all_classes": null, "length": 1280} 14 | {"pred": "1839", "answers": ["26 April 1872"], "all_classes": null, "length": 3432} 15 | {"pred": "United States", "answers": ["America"], "all_classes": null, "length": 4442} 16 | {"pred": "Larry Parks", "answers": ["Ona Munson"], "all_classes": null, "length": 10444} 17 | {"pred": "The Death of Black King (1932)", "answers": ["The Death Of Black King"], "all_classes": null, "length": 1696} 18 | {"pred": "Yes", "answers": ["no"], "all_classes": null, "length": 5932} 19 | {"pred": " No", "answers": ["no"], "all_classes": null, "length": 535} 20 | {"pred": "Kanneshwara Rama", "answers": ["Mysore"], "all_classes": null, "length": 3532} 21 | {"pred": "Hamar", "answers": ["Kristiania"], "all_classes": null, "length": 2782} 22 | {"pred": "Kathy Griffin graduated from Oak Park High School.", "answers": ["Lee Strasberg Theatre and Film Institute"], "all_classes": null, "length": 7578} 23 | {"pred": "Hell Up In Harlem", "answers": ["Hell Up In Harlem"], "all_classes": null, "length": 9013} 24 | {"pred": "Cipriano Castro", "answers": ["Cipriano Castro"], "all_classes": null, "length": 6026} 25 | {"pred": "The Year Of The Rabbit", "answers": ["Monster On The Campus"], "all_classes": null, "length": 6192} 26 | {"pred": "Younger", "answers": ["Altuğ Çelikbilek"], "all_classes": null, "length": 889} 27 | {"pred": "Joel The Lump Of Coal", "answers": ["Jugband Blues"], "all_classes": null, "length": 3517} 28 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4141} 29 
| {"pred": "Baldwin I Rátót", "answers": ["Leustach Rátót"], "all_classes": null, "length": 3948} 30 | {"pred": "Kantipur", "answers": ["Nepal"], "all_classes": null, "length": 4625} 31 | {"pred": "Archibald Acheson, 1st Earl of Gosford", "answers": ["Archibald Acheson, 2nd Earl of Gosford"], "all_classes": null, "length": 4383} 32 | {"pred": "Abd al-Muttalib", "answers": ["Jean Paul Getty"], "all_classes": null, "length": 5464} 33 | {"pred": "Goring-on-Thames", "answers": ["Goring-on-Thames, Oxfordshire"], "all_classes": null, "length": 11686} 34 | {"pred": "British", "answers": ["yes"], "all_classes": null, "length": 3122} 35 | {"pred": "The Abduction Club", "answers": ["Wooden Crosses"], "all_classes": null, "length": 4154} 36 | {"pred": "The Magic Aster", "answers": ["Above Rubies"], "all_classes": null, "length": 3299} 37 | {"pred": "Blue Blood And Red", "answers": ["Blue Blood And Red"], "all_classes": null, "length": 4436} 38 | {"pred": "Paris, France", "answers": ["Castlemaine, Victoria, Australia"], "all_classes": null, "length": 2896} 39 | {"pred": "Paul De Scherff", "answers": ["Lyudmyla Olyanovska"], "all_classes": null, "length": 6416} 40 | {"pred": "Eyüp Cemetery, Istanbul", "answers": ["Constantinople"], "all_classes": null, "length": 4769} 41 | {"pred": "Melun", "answers": ["East Francia"], "all_classes": null, "length": 4957} 42 | {"pred": "Tex And The Lord Of The Deep", "answers": ["Henry Goes Arizona"], "all_classes": null, "length": 2540} 43 | {"pred": "Wales", "answers": ["United Kingdom"], "all_classes": null, "length": 8759} 44 | {"pred": "All-American Co-Ed", "answers": ["All-American Co-Ed"], "all_classes": null, "length": 5527} 45 | {"pred": "Buenos Aires", "answers": ["Buenos Aires"], "all_classes": null, "length": 3859} 46 | {"pred": "Louise's mother-in-law is Duchess Magdalene Sibylle of Holstein-Gottorp.", "answers": ["Charlotte Amalie of Hesse-Kassel"], "all_classes": null, "length": 5695} 47 | {"pred": " Stahleck Castle", "answers": 
["Brunswick"], "all_classes": null, "length": 3954} 48 | {"pred": "Marcus Annius Libo's aunt is Rupilia Faustina.", "answers": ["Vibia Sabina"], "all_classes": null, "length": 3690} 49 | {"pred": "Riding the California Trail", "answers": ["Bajo Otro Sol"], "all_classes": null, "length": 1120} 50 | {"pred": "Lisa Azuelos", "answers": ["Marie Laforêt"], "all_classes": null, "length": 3219} 51 | {"pred": "Val Kilmer", "answers": ["Sandra Nelson"], "all_classes": null, "length": 10498} 52 | {"pred": "2 March 1702", "answers": ["May 19, 1669"], "all_classes": null, "length": 4570} 53 | {"pred": "Pamplona", "answers": ["Palencia"], "all_classes": null, "length": 3496} 54 | {"pred": " Dance With A Stranger", "answers": ["Miley Naa Miley Hum"], "all_classes": null, "length": 3934} 55 | {"pred": "Space Probe Taurus", "answers": ["Tom Mix In Arabia"], "all_classes": null, "length": 3324} 56 | {"pred": "Daughter of the Jungle", "answers": ["Seven In The Sun"], "all_classes": null, "length": 2708} 57 | {"pred": "Dr. Socrates", "answers": ["Dr. 
Socrates"], "all_classes": null, "length": 4218} 58 | {"pred": "Changeland", "answers": ["Changeland"], "all_classes": null, "length": 6736} 59 | {"pred": "Peter Rosegger", "answers": ["Ruel Redinger"], "all_classes": null, "length": 2046} 60 | {"pred": "Pyotr Karatygin's sibling-in-law is Vasily Karatygin.", "answers": ["Alexandra Kolosova"], "all_classes": null, "length": 2898} 61 | {"pred": " Maxine Caroll Lawrence", "answers": ["Jessi Colter"], "all_classes": null, "length": 7833} 62 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 2351} 63 | {"pred": "William Pooley", "answers": ["William Pooley"], "all_classes": null, "length": 2189} 64 | {"pred": " French", "answers": ["Sweden"], "all_classes": null, "length": 4845} 65 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 1412} 66 | {"pred": "Melody Of The World", "answers": ["Melody Of The World"], "all_classes": null, "length": 4784} 67 | {"pred": "Madrid", "answers": ["Madrid"], "all_classes": null, "length": 2868} 68 | {"pred": "Oskar Roehler", "answers": ["Gisela Elsner"], "all_classes": null, "length": 3211} 69 | {"pred": "The Great Man's Lady", "answers": ["La Belle Américaine"], "all_classes": null, "length": 4142} 70 | {"pred": "Oklahoma City, Oklahoma", "answers": ["Oklahoma City, Oklahoma"], "all_classes": null, "length": 3277} 71 | {"pred": "Dubai", "answers": ["Dubai"], "all_classes": null, "length": 4240} 72 | {"pred": "France", "answers": ["La Trinité"], "all_classes": null, "length": 4966} 73 | {"pred": "Duke Paul Frederick of Mecklenburg-Schwerin", "answers": ["Prince Albert of Prussia"], "all_classes": null, "length": 4350} 74 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 562} 75 | {"pred": "Pembroke Castle", "answers": ["Banbury"], "all_classes": null, "length": 3847} 76 | {"pred": "Sir Paul Gore, 1st Baronet", "answers": ["Sir Paul Gore, 1st Baronet"], "all_classes": null, "length": 2779} 77 | {"pred": "Ireland", "answers": 
["Fatimid caliphate"], "all_classes": null, "length": 8314} 78 | {"pred": "Santa Monica", "answers": ["Dayton, Ohio"], "all_classes": null, "length": 3761} 79 | {"pred": "Bill Graham", "answers": ["Ona Munson"], "all_classes": null, "length": 11373} 80 | {"pred": "Pigsty (Film)", "answers": ["The Drover'S Sweetheart"], "all_classes": null, "length": 3866} 81 | {"pred": "Fayetteville, Arkansas", "answers": ["Tumkur"], "all_classes": null, "length": 3459} 82 | {"pred": "Closely Watched Trains", "answers": ["Det Sande Ansigt"], "all_classes": null, "length": 4713} 83 | {"pred": "Il Gaucho", "answers": ["Bomgay"], "all_classes": null, "length": 4261} 84 | {"pred": "Pliska", "answers": ["Preslav"], "all_classes": null, "length": 5869} 85 | {"pred": "The Pyramid", "answers": ["Revolt Of The Praetorians"], "all_classes": null, "length": 3504} 86 | {"pred": "France", "answers": ["yes"], "all_classes": null, "length": 10582} 87 | {"pred": "Eric XIV of Sweden died of arsenic poisoning.", "answers": ["poisoning"], "all_classes": null, "length": 5584} 88 | {"pred": "American", "answers": ["America"], "all_classes": null, "length": 5631} 89 | {"pred": "Kamehameha I", "answers": ["Kingdom of Hawaii"], "all_classes": null, "length": 8829} 90 | {"pred": "Mi Novia Está De Madre", "answers": ["X-Paroni"], "all_classes": null, "length": 2674} 91 | {"pred": "Vasantha Raagam", "answers": ["Vasantha Raagam"], "all_classes": null, "length": 6052} 92 | {"pred": " Lee Kun-Hee", "answers": ["Lee Byung-chul"], "all_classes": null, "length": 6337} 93 | {"pred": "Charles I, Duke of Bourbon", "answers": ["John I, Duke of Bourbon"], "all_classes": null, "length": 4501} 94 | {"pred": "1839", "answers": ["27 June 1839"], "all_classes": null, "length": 10742} 95 | {"pred": "Nathan Juran", "answers": ["Gura Humorului"], "all_classes": null, "length": 3888} 96 | {"pred": "Amandine Bourgeois", "answers": ["Françoise Hardy"], "all_classes": null, "length": 5091} 97 | {"pred": "Charles Wheatstone", 
"answers": ["Charles Wheatstone"], "all_classes": null, "length": 8386} 98 | {"pred": "Perryville, Missouri", "answers": ["Washington"], "all_classes": null, "length": 1917} 99 | {"pred": "London Melody", "answers": ["London Melody"], "all_classes": null, "length": 9611} 100 | {"pred": "Diego Fernández de Oviedo", "answers": ["Flaín Muñoz"], "all_classes": null, "length": 3208} 101 | {"pred": "Michael Sorvino", "answers": ["Mira Sorvino"], "all_classes": null, "length": 4398} 102 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 1863} 103 | {"pred": "Denmark", "answers": ["Norway"], "all_classes": null, "length": 3211} 104 | {"pred": "Death by illness", "answers": ["illness"], "all_classes": null, "length": 4438} 105 | {"pred": "She died in Amesbury Priory.", "answers": ["Harby"], "all_classes": null, "length": 10562} 106 | {"pred": "The director of film Tiramisu (2002) earned the Hundred Flowers Award for Best Director.", "answers": ["Hong Kong Film Award for Best Director"], "all_classes": null, "length": 3401} 107 | {"pred": "192", "answers": ["3 September 1992"], "all_classes": null, "length": 1680} 108 | {"pred": "Nathaniel McLenaghan", "answers": ["Nathaniel Mclenaghan"], "all_classes": null, "length": 3026} 109 | {"pred": "Tisch School of the Arts", "answers": ["Tisch"], "all_classes": null, "length": 3749} 110 | {"pred": "Cuchillos De Fuego", "answers": ["Cuchillos De Fuego"], "all_classes": null, "length": 2697} 111 | {"pred": "The Ballad Of Josie", "answers": ["Moment Of Danger"], "all_classes": null, "length": 4003} 112 | {"pred": "De AS", "answers": ["De As"], "all_classes": null, "length": 2425} 113 | {"pred": "The Piper's Price", "answers": ["The Piper'S Price"], "all_classes": null, "length": 4134} 114 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 5918} 115 | {"pred": "1753", "answers": ["13 March 1753"], "all_classes": null, "length": 4323} 116 | {"pred": "True To The Navy", "answers": ["No Trees In 
The Street"], "all_classes": null, "length": 7162} 117 | {"pred": "Malayalam", "answers": ["Methala"], "all_classes": null, "length": 3129} 118 | {"pred": "House of Dark Shadows", "answers": ["Alkohol"], "all_classes": null, "length": 5610} 119 | {"pred": "Do Musafir", "answers": ["Do Musafir"], "all_classes": null, "length": 1138} 120 | {"pred": " Yes", "answers": ["no"], "all_classes": null, "length": 1241} 121 | {"pred": "New York City", "answers": ["New York"], "all_classes": null, "length": 2516} 122 | {"pred": "Tiger In The Smoke", "answers": ["Contragolpe"], "all_classes": null, "length": 3675} 123 | {"pred": "Mumbai", "answers": ["Mumbai"], "all_classes": null, "length": 3052} 124 | {"pred": "The Comedians of Comedy", "answers": ["The Comedians Of Comedy"], "all_classes": null, "length": 4756} 125 | {"pred": "Tombstone Rashomon", "answers": ["Tombstone Rashomon"], "all_classes": null, "length": 5772} 126 | {"pred": "Dhuen Ki Lakeer", "answers": ["Dhuen Ki Lakeer"], "all_classes": null, "length": 4828} 127 | {"pred": "Perdón, viejita", "answers": ["Perdón, Viejita"], "all_classes": null, "length": 10456} 128 | {"pred": "University of Wisconsin-Madison", "answers": ["University of Wisconsin"], "all_classes": null, "length": 2748} 129 | {"pred": "Dudley Russell", "answers": ["Dudley Russell"], "all_classes": null, "length": 4526} 130 | {"pred": "Vytautas Straižys", "answers": ["Mirjam Polkunen"], "all_classes": null, "length": 3620} 131 | {"pred": "Russia", "answers": ["Saint Petersburg"], "all_classes": null, "length": 9479} 132 | {"pred": "Menno Meyjes", "answers": ["Eindhoven"], "all_classes": null, "length": 3592} 133 | {"pred": "Women's Suffrage Journal", "answers": ["Women'S Suffrage Journal"], "all_classes": null, "length": 3828} 134 | {"pred": "Fairmont, West Virginia", "answers": ["Fairmont, West Virginia"], "all_classes": null, "length": 3228} 135 | {"pred": "The Market Of Souls", "answers": ["The Market Of Souls"], "all_classes": null, "length": 
2305} 136 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4201} 137 | {"pred": "Marrakech", "answers": ["Morocco"], "all_classes": null, "length": 4417} 138 | {"pred": "Beaulieu-sur-Loire]", "answers": ["Akkadian empire"], "all_classes": null, "length": 8701} 139 | {"pred": "Tarzan The Magnificent", "answers": ["Tarzan The Magnificent"], "all_classes": null, "length": 4580} 140 | {"pred": "Wolf Warrior", "answers": ["Wolf Warrior"], "all_classes": null, "length": 3774} 141 | {"pred": "Trịnh Cương", "answers": ["Trịnh Bính"], "all_classes": null, "length": 3718} 142 | {"pred": "Nathan Juran", "answers": ["Gura Humorului"], "all_classes": null, "length": 4854} 143 | {"pred": "Bernie Bonvois", "answers": ["Crenshaw"], "all_classes": null, "length": 7562} 144 | {"pred": "Winter Sleepers", "answers": ["Winter Sleepers"], "all_classes": null, "length": 3065} 145 | {"pred": "Bruno II of Isenburg-Braunsberg", "answers": ["Adolf II of Berg"], "all_classes": null, "length": 4040} 146 | {"pred": " No", "answers": ["no"], "all_classes": null, "length": 6771} 147 | {"pred": "Helge Schneider", "answers": ["Mülheim an der Ruhr"], "all_classes": null, "length": 3140} 148 | {"pred": "Una Prostituta Al Servizio Del Pubblico E In Regola Con Le Leggi Dello Stato", "answers": ["Una Prostituta Al Servizio Del Pubblico E In Regola Con Le Leggi Dello Stato"], "all_classes": null, "length": 2809} 149 | {"pred": " Yes", "answers": ["yes"], "all_classes": null, "length": 10991} 150 | {"pred": "Prince Bhanurangsi Savangwongse", "answers": ["Mongkut"], "all_classes": null, "length": 3759} 151 | {"pred": "Elizabeth Blount", "answers": ["Lady Frances Manners"], "all_classes": null, "length": 3293} 152 | {"pred": "L'Arbre, le maire et la médiathèque]", "answers": ["Cahiers du cinéma"], "all_classes": null, "length": 9453} 153 | {"pred": "Mexico", "answers": ["Newport, Wales"], "all_classes": null, "length": 3945} 154 | {"pred": "Orange County, Virginia", "answers": 
["Gordonsville, Virginia"], "all_classes": null, "length": 9096} 155 | {"pred": "Romania", "answers": ["Mangalia"], "all_classes": null, "length": 8046} 156 | {"pred": "Pier-Luc Funk", "answers": ["Félix Leclerc"], "all_classes": null, "length": 4871} 157 | {"pred": "Henri I, Duke of Nemours", "answers": ["Jacques"], "all_classes": null, "length": 5577} 158 | {"pred": "Gisela Of Hungary's father-in-law is Stephen I Of Hungary.", "answers": ["Géza"], "all_classes": null, "length": 11622} 159 | {"pred": "Thuya", "answers": ["Yi Jiang"], "all_classes": null, "length": 4142} 160 | {"pred": "5 September 1896", "answers": ["November 3, 1867"], "all_classes": null, "length": 6374} 161 | {"pred": "Guy Newall", "answers": ["Ivy Duke"], "all_classes": null, "length": 10243} 162 | {"pred": "181", "answers": ["1 August 1813"], "all_classes": null, "length": 2765} 163 | {"pred": "The Yellow Teddy Bears", "answers": ["Season Of Strangers"], "all_classes": null, "length": 6966} 164 | {"pred": " American", "answers": ["yes"], "all_classes": null, "length": 4450} 165 | {"pred": "Abdul Ali Lalu", "answers": ["Antoine Casavant"], "all_classes": null, "length": 3131} 166 | {"pred": "Allahabad", "answers": ["Oxford"], "all_classes": null, "length": 4103} 167 | {"pred": "Sir John Gage", "answers": ["Sir Hew Dalrymple, 2nd Baronet"], "all_classes": null, "length": 2932} 168 | {"pred": "Khud-Daar", "answers": ["Murderer In The Fog"], "all_classes": null, "length": 2692} 169 | {"pred": "Paul W. S. Anderson", "answers": ["Milla"], "all_classes": null, "length": 7443} 170 | {"pred": "Naples", "answers": ["Lecce"], "all_classes": null, "length": 3523} 171 | {"pred": "Lady Magdalene's director, J. 
Neil Schulman, won the \"Special Jury Prize for Libertarian Ideals\" at the 20", "answers": ["Prometheus Award"], "all_classes": null, "length": 3529} 172 | {"pred": "German", "answers": ["Duchy of Saxe-Gotha-Altenburg"], "all_classes": null, "length": 4847} 173 | {"pred": "The Longshots", "answers": ["Station For Two"], "all_classes": null, "length": 9998} 174 | {"pred": "Inverkeithing, Fife, Scotland", "answers": ["Inverkeithing"], "all_classes": null, "length": 4097} 175 | {"pred": "1234", "answers": ["1234"], "all_classes": null, "length": 9182} 176 | {"pred": "1642", "answers": ["25 July 1666"], "all_classes": null, "length": 6136} 177 | {"pred": "Special Delivery (1927 Film)", "answers": ["The Wind'S Fierce"], "all_classes": null, "length": 9043} 178 | {"pred": "He died by suicide by drowning himself in his swimming pool.", "answers": ["drowning"], "all_classes": null, "length": 8169} 179 | {"pred": "Abu Dhabi", "answers": ["Abu Dhabi"], "all_classes": null, "length": 3593} 180 | {"pred": "Ottensen", "answers": ["Prenzlau"], "all_classes": null, "length": 3584} 181 | {"pred": "Poland", "answers": ["Poznań"], "all_classes": null, "length": 3814} 182 | {"pred": "Picton", "answers": ["Italy"], "all_classes": null, "length": 4597} 183 | {"pred": "New Zealand", "answers": ["Toronto, Ontario"], "all_classes": null, "length": 3366} 184 | {"pred": "The Third Kiss", "answers": ["Forbidden Daughters"], "all_classes": null, "length": 4389} 185 | {"pred": "Yes", "answers": ["yes"], "all_classes": null, "length": 4958} 186 | {"pred": "Bogdan Țăruș", "answers": ["Bogdan Țăruș"], "all_classes": null, "length": 3425} 187 | {"pred": "Folgore Division", "answers": ["Sandflow"], "all_classes": null, "length": 1945} 188 | {"pred": "Durango Valley Raiders", "answers": ["From Corleone To Brooklyn"], "all_classes": null, "length": 3894} 189 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 667} 190 | {"pred": "Pacific Palisades", "answers": ["Vienna"], 
"all_classes": null, "length": 10974} 191 | {"pred": "Johnny Ekström", "answers": ["Johnny Ekström"], "all_classes": null, "length": 6044} 192 | {"pred": " No", "answers": ["yes"], "all_classes": null, "length": 592} 193 | {"pred": "English", "answers": ["England"], "all_classes": null, "length": 11950} 194 | {"pred": "Norman Panama died on January 13, 2003 due to complications of Parkinson's disease.", "answers": ["Parkinson"], "all_classes": null, "length": 4337} 195 | {"pred": "Thomas De Berkeley, 5th Baron Berkeley", "answers": ["Roger Mortimer, 1st Earl of March"], "all_classes": null, "length": 3947} 196 | {"pred": "Ali Dinar", "answers": ["Albertus Seba"], "all_classes": null, "length": 4441} 197 | {"pred": "Henry Krause", "answers": ["Ferdinand I of the Two Sicilies"], "all_classes": null, "length": 8935} 198 | {"pred": "Prince Of Arcadia", "answers": ["Prince Of Arcadia"], "all_classes": null, "length": 3155} 199 | {"pred": "Italy", "answers": ["Milan"], "all_classes": null, "length": 3195} 200 | {"pred": "Germany", "answers": ["Köln"], "all_classes": null, "length": 3719} 201 | -------------------------------------------------------------------------------- /benchmarks/Pred_LongBench/passage_retrieval_en.jsonl: -------------------------------------------------------------------------------- 1 | {"pred": "\nParagraph 15", "answers": ["Paragraph 15"], "all_classes": null, "length": 8525} 2 | {"pred": "\nParagraph 28", "answers": ["Paragraph 8"], "all_classes": null, "length": 8679} 3 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 9562} 4 | {"pred": "\nParagraph 17", "answers": ["Paragraph 17"], "all_classes": null, "length": 9273} 5 | {"pred": " Paragraph 14", "answers": ["Paragraph 14"], "all_classes": null, "length": 9175} 6 | {"pred": "\nParagraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 9139} 7 | {"pred": "\nParagraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8832} 8 | 
{"pred": "Paragraph 28", "answers": ["Paragraph 27"], "all_classes": null, "length": 8995} 9 | {"pred": "\nParagraph 30\nThe embodiment of the Gopher mascot came to life in 1952 when University of Minnesota assistant band", "answers": ["Paragraph 30"], "all_classes": null, "length": 7789} 10 | {"pred": "Paragraph 20]", "answers": ["Paragraph 12"], "all_classes": null, "length": 10027} 11 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9273} 12 | {"pred": "Paragraph 20]", "answers": ["Paragraph 30"], "all_classes": null, "length": 10178} 13 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 9221} 14 | {"pred": "Paragraph 17", "answers": ["Paragraph 19"], "all_classes": null, "length": 9718} 15 | {"pred": "Paragraph 20", "answers": ["Paragraph 3"], "all_classes": null, "length": 11340} 16 | {"pred": "Paragraph 1]", "answers": ["Paragraph 8"], "all_classes": null, "length": 9750} 17 | {"pred": "Paragraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 8965} 18 | {"pred": "Paragraph 17]", "answers": ["Paragraph 17"], "all_classes": null, "length": 8965} 19 | {"pred": "\nParagraph 1", "answers": ["Paragraph 4"], "all_classes": null, "length": 9275} 20 | {"pred": "Paragraph 1]", "answers": ["Paragraph 3"], "all_classes": null, "length": 10934} 21 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9713} 22 | {"pred": "Paragraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 9947} 23 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 8991} 24 | {"pred": "Paragraph 17]", "answers": ["Paragraph 12"], "all_classes": null, "length": 9211} 25 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8267} 26 | {"pred": "\nParagraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 8180} 27 | {"pred": "Paragraph 1]", "answers": ["Paragraph 1"], "all_classes": 
null, "length": 10576} 28 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9513} 29 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9747} 30 | {"pred": "Paragraph 29", "answers": ["Paragraph 3"], "all_classes": null, "length": 7534} 31 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9908} 32 | {"pred": "Paragraph 27", "answers": ["Paragraph 25"], "all_classes": null, "length": 8838} 33 | {"pred": "Paragraph 17]", "answers": ["Paragraph 19"], "all_classes": null, "length": 10016} 34 | {"pred": "\nParagraph 20", "answers": ["Paragraph 22"], "all_classes": null, "length": 7992} 35 | {"pred": "Paragraph 17]", "answers": ["Paragraph 5"], "all_classes": null, "length": 9899} 36 | {"pred": "Paragraph 19", "answers": ["Paragraph 15"], "all_classes": null, "length": 9483} 37 | {"pred": "Paragraph 28]", "answers": ["Paragraph 28"], "all_classes": null, "length": 10118} 38 | {"pred": "Paragraph 1", "answers": ["Paragraph 5"], "all_classes": null, "length": 8975} 39 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9934} 40 | {"pred": "\nParagraph 10", "answers": ["Paragraph 6"], "all_classes": null, "length": 7736} 41 | {"pred": "Paragraph 1", "answers": ["Paragraph 12"], "all_classes": null, "length": 8771} 42 | {"pred": "Paragraph 27]", "answers": ["Paragraph 26"], "all_classes": null, "length": 10004} 43 | {"pred": "\nParagraph 20", "answers": ["Paragraph 30"], "all_classes": null, "length": 9000} 44 | {"pred": "Paragraph 1]", "answers": ["Paragraph 6"], "all_classes": null, "length": 10255} 45 | {"pred": "\nParagraph 14", "answers": ["Paragraph 12"], "all_classes": null, "length": 8776} 46 | {"pred": "Paragraph 23]", "answers": ["Paragraph 23"], "all_classes": null, "length": 10742} 47 | {"pred": "Paragraph 1]", "answers": ["Paragraph 17"], "all_classes": null, "length": 10320} 48 | {"pred": "Paragraph 1", "answers": ["Paragraph 
1"], "all_classes": null, "length": 8971} 49 | {"pred": " Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 10277} 50 | {"pred": "Paragraph 21]", "answers": ["Paragraph 21"], "all_classes": null, "length": 9525} 51 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8674} 52 | {"pred": "Paragraph 29", "answers": ["Paragraph 29"], "all_classes": null, "length": 8629} 53 | {"pred": "Paragraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 9970} 54 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8791} 55 | {"pred": "Paragraph 17]", "answers": ["Paragraph 15"], "all_classes": null, "length": 9949} 56 | {"pred": "Paragraph 20]", "answers": ["Paragraph 22"], "all_classes": null, "length": 9832} 57 | {"pred": "Paragraph 16", "answers": ["Paragraph 16"], "all_classes": null, "length": 8952} 58 | {"pred": "\nParagraph 14", "answers": ["Paragraph 14"], "all_classes": null, "length": 8253} 59 | {"pred": "Paragraph 29", "answers": ["Paragraph 29"], "all_classes": null, "length": 10431} 60 | {"pred": "\nParagraph 19", "answers": ["Paragraph 19"], "all_classes": null, "length": 10196} 61 | {"pred": "Paragraph 20]", "answers": ["Paragraph 28"], "all_classes": null, "length": 10633} 62 | {"pred": "\nParagraph 25", "answers": ["Paragraph 26"], "all_classes": null, "length": 9824} 63 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9324} 64 | {"pred": "\nParagraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 10299} 65 | {"pred": " Paragraph 17", "answers": ["Paragraph 8"], "all_classes": null, "length": 9489} 66 | {"pred": "\nParagraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 8281} 67 | {"pred": " Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8804} 68 | {"pred": "Paragraph 16]", "answers": ["Paragraph 15"], "all_classes": null, "length": 8473} 69 | {"pred": 
"Paragraph 17]", "answers": ["Paragraph 21"], "all_classes": null, "length": 9216} 70 | {"pred": "Paragraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 10941} 71 | {"pred": "Paragraph 1]", "answers": ["Paragraph 5"], "all_classes": null, "length": 10198} 72 | {"pred": "Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 8526} 73 | {"pred": "Paragraph 2", "answers": ["Paragraph 20"], "all_classes": null, "length": 8039} 74 | {"pred": "\nParagraph 17", "answers": ["Paragraph 15"], "all_classes": null, "length": 9207} 75 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9357} 76 | {"pred": "Paragraph 20]", "answers": ["Paragraph 23"], "all_classes": null, "length": 9713} 77 | {"pred": "\nParagraph 1", "answers": ["Paragraph 26"], "all_classes": null, "length": 8114} 78 | {"pred": "Paragraph 1", "answers": ["Paragraph 18"], "all_classes": null, "length": 9148} 79 | {"pred": "\nParagraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8795} 80 | {"pred": "Paragraph 19", "answers": ["Paragraph 18"], "all_classes": null, "length": 8491} 81 | {"pred": "Paragraph 17", "answers": ["Paragraph 7"], "all_classes": null, "length": 8642} 82 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 11270} 83 | {"pred": "Paragraph 2", "answers": ["Paragraph 23"], "all_classes": null, "length": 9267} 84 | {"pred": "Paragraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8747} 85 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 10176} 86 | {"pred": "\nParagraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 8712} 87 | {"pred": "Paragraph 10", "answers": ["Paragraph 3"], "all_classes": null, "length": 8315} 88 | {"pred": "Paragraph 17]", "answers": ["Paragraph 5"], "all_classes": null, "length": 9403} 89 | {"pred": "Paragraph 1", "answers": ["Paragraph 27"], "all_classes": null, "length": 
10242} 90 | {"pred": "Paragraph 24", "answers": ["Paragraph 24"], "all_classes": null, "length": 8143} 91 | {"pred": "Paragraph 1]", "answers": ["Paragraph 29"], "all_classes": null, "length": 9508} 92 | {"pred": "\nParagraph 23", "answers": ["Paragraph 23"], "all_classes": null, "length": 7817} 93 | {"pred": " Paragraph 17", "answers": ["Paragraph 17"], "all_classes": null, "length": 7533} 94 | {"pred": "Paragraph 17", "answers": ["Paragraph 16"], "all_classes": null, "length": 9056} 95 | {"pred": "Paragraph 19]", "answers": ["Paragraph 18"], "all_classes": null, "length": 9565} 96 | {"pred": "Paragraph 17]", "answers": ["Paragraph 13"], "all_classes": null, "length": 10347} 97 | {"pred": "Paragraph 20]", "answers": ["Paragraph 20"], "all_classes": null, "length": 9260} 98 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 11516} 99 | {"pred": "Paragraph 23]", "answers": ["Paragraph 23"], "all_classes": null, "length": 9334} 100 | {"pred": "Paragraph 1]", "answers": ["Paragraph 19"], "all_classes": null, "length": 10614} 101 | {"pred": "Paragraph 19]", "answers": ["Paragraph 19"], "all_classes": null, "length": 9251} 102 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9922} 103 | {"pred": "Paragraph 17]", "answers": ["Paragraph 13"], "all_classes": null, "length": 9459} 104 | {"pred": " Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 9025} 105 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9717} 106 | {"pred": "Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 9446} 107 | {"pred": "Paragraph 1", "answers": ["Paragraph 3"], "all_classes": null, "length": 10308} 108 | {"pred": "\nParagraph 10\nThe airport is northeast of downtown in east Omaha. 
Although the airport is in Nebraska", "answers": ["Paragraph 12"], "all_classes": null, "length": 8558} 109 | {"pred": "Paragraph 17]", "answers": ["Paragraph 10"], "all_classes": null, "length": 9485} 110 | {"pred": "Paragraph 20", "answers": ["Paragraph 3"], "all_classes": null, "length": 8673} 111 | {"pred": "Paragraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 9147} 112 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 10298} 113 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8716} 114 | {"pred": "\nParagraph 2", "answers": ["Paragraph 7"], "all_classes": null, "length": 9171} 115 | {"pred": "Paragraph 19", "answers": ["Paragraph 20"], "all_classes": null, "length": 9357} 116 | {"pred": "Paragraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 10597} 117 | {"pred": "\nParagraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 8070} 118 | {"pred": "\nParagraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 9095} 119 | {"pred": "Paragraph 17", "answers": ["Paragraph 14"], "all_classes": null, "length": 8867} 120 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9895} 121 | {"pred": " Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9072} 122 | {"pred": "Paragraph 1", "answers": ["Paragraph 19"], "all_classes": null, "length": 10852} 123 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9767} 124 | {"pred": "\nParagraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9547} 125 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8081} 126 | {"pred": "\nParagraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 7860} 127 | {"pred": "Paragraph 28]", "answers": ["Paragraph 28"], "all_classes": null, "length": 9423} 128 | {"pred": "\nParagraph 1", "answers": ["Paragraph 
11"], "all_classes": null, "length": 7818} 129 | {"pred": "Paragraph 17]", "answers": ["Paragraph 18"], "all_classes": null, "length": 10417} 130 | {"pred": "Paragraph 1", "answers": ["Paragraph 9"], "all_classes": null, "length": 8644} 131 | {"pred": "Paragraph 14", "answers": ["Paragraph 16"], "all_classes": null, "length": 8429} 132 | {"pred": "\nParagraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 8931} 133 | {"pred": " Paragraph 1", "answers": ["Paragraph 27"], "all_classes": null, "length": 9098} 134 | {"pred": " Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 8751} 135 | {"pred": "Paragraph 26]", "answers": ["Paragraph 26"], "all_classes": null, "length": 9296} 136 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9412} 137 | {"pred": "Paragraph 20]", "answers": ["Paragraph 20"], "all_classes": null, "length": 9968} 138 | {"pred": "Paragraph 19]", "answers": ["Paragraph 29"], "all_classes": null, "length": 9105} 139 | {"pred": "\nParagraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 8924} 140 | {"pred": "Paragraph 17]", "answers": ["Paragraph 6"], "all_classes": null, "length": 9991} 141 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 10601} 142 | {"pred": "Paragraph 1", "answers": ["Paragraph 4"], "all_classes": null, "length": 8559} 143 | {"pred": "Paragraph 10]", "answers": ["Paragraph 29"], "all_classes": null, "length": 10041} 144 | {"pred": " Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8851} 145 | {"pred": "Paragraph 7]", "answers": ["Paragraph 7"], "all_classes": null, "length": 10303} 146 | {"pred": "Paragraph 2", "answers": ["Paragraph 22"], "all_classes": null, "length": 9065} 147 | {"pred": "\nParagraph 14", "answers": ["Paragraph 8"], "all_classes": null, "length": 8976} 148 | {"pred": "\nParagraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 8863} 149 | 
{"pred": " Paragraph 1", "answers": ["Paragraph 23"], "all_classes": null, "length": 9616} 150 | {"pred": "Paragraph 1", "answers": ["Paragraph 8"], "all_classes": null, "length": 9063} 151 | {"pred": "Paragraph 1", "answers": ["Paragraph 11"], "all_classes": null, "length": 8464} 152 | {"pred": " Paragraph 29", "answers": ["Paragraph 28"], "all_classes": null, "length": 9939} 153 | {"pred": " Paragraph 14", "answers": ["Paragraph 16"], "all_classes": null, "length": 7539} 154 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 9321} 155 | {"pred": "Paragraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 9976} 156 | {"pred": "Paragraph 1", "answers": ["Paragraph 2"], "all_classes": null, "length": 7814} 157 | {"pred": "Paragraph 28", "answers": ["Paragraph 28"], "all_classes": null, "length": 10274} 158 | {"pred": "\nParagraph 20", "answers": ["Paragraph 7"], "all_classes": null, "length": 9091} 159 | {"pred": "Paragraph 19]", "answers": ["Paragraph 27"], "all_classes": null, "length": 10887} 160 | {"pred": "Paragraph 1", "answers": ["Paragraph 16"], "all_classes": null, "length": 9628} 161 | {"pred": "\nParagraph 1", "answers": ["Paragraph 1"], "all_classes": null, "length": 8060} 162 | {"pred": "\nParagraph 27", "answers": ["Paragraph 26"], "all_classes": null, "length": 9395} 163 | {"pred": "\nParagraph 10", "answers": ["Paragraph 6"], "all_classes": null, "length": 8232} 164 | {"pred": "Paragraph 19]", "answers": ["Paragraph 15"], "all_classes": null, "length": 9132} 165 | {"pred": "\nParagraph 17", "answers": ["Paragraph 14"], "all_classes": null, "length": 8557} 166 | {"pred": " Paragraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 9374} 167 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 9024} 168 | {"pred": "\nParagraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 8492} 169 | {"pred": "Paragraph 1", "answers": ["Paragraph 5"], 
"all_classes": null, "length": 8944} 170 | {"pred": "\nParagraph 10", "answers": ["Paragraph 12"], "all_classes": null, "length": 8243} 171 | {"pred": "Paragraph 17", "answers": ["Paragraph 12"], "all_classes": null, "length": 8452} 172 | {"pred": "Paragraph 10", "answers": ["Paragraph 8"], "all_classes": null, "length": 8418} 173 | {"pred": "Paragraph 1", "answers": ["Paragraph 20"], "all_classes": null, "length": 9471} 174 | {"pred": "\nParagraph 25", "answers": ["Paragraph 25"], "all_classes": null, "length": 8888} 175 | {"pred": "\nParagraph 2", "answers": ["Paragraph 23"], "all_classes": null, "length": 8088} 176 | {"pred": "Paragraph 30", "answers": ["Paragraph 30"], "all_classes": null, "length": 10318} 177 | {"pred": "Paragraph 1", "answers": ["Paragraph 7"], "all_classes": null, "length": 10754} 178 | {"pred": " Paragraph 1", "answers": ["Paragraph 10"], "all_classes": null, "length": 8055} 179 | {"pred": "\nParagraph 17", "answers": ["Paragraph 30"], "all_classes": null, "length": 9471} 180 | {"pred": "Paragraph 10", "answers": ["Paragraph 9"], "all_classes": null, "length": 9415} 181 | {"pred": "Paragraph 17]", "answers": ["Paragraph 17"], "all_classes": null, "length": 10382} 182 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 10098} 183 | {"pred": "Paragraph 18]", "answers": ["Paragraph 10"], "all_classes": null, "length": 10972} 184 | {"pred": "Paragraph 27", "answers": ["Paragraph 26"], "all_classes": null, "length": 8937} 185 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8275} 186 | {"pred": "Paragraph 1", "answers": ["Paragraph 26"], "all_classes": null, "length": 9032} 187 | {"pred": "Paragraph 17", "answers": ["Paragraph 8"], "all_classes": null, "length": 7990} 188 | {"pred": "Paragraph 1", "answers": ["Paragraph 30"], "all_classes": null, "length": 9969} 189 | {"pred": "\nParagraph 17", "answers": ["Paragraph 13"], "all_classes": null, "length": 9356} 190 | 
{"pred": "Paragraph 1", "answers": ["Paragraph 12"], "all_classes": null, "length": 9827} 191 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 8886} 192 | {"pred": "Paragraph 1", "answers": ["Paragraph 6"], "all_classes": null, "length": 7953} 193 | {"pred": "Paragraph 1", "answers": ["Paragraph 24"], "all_classes": null, "length": 9019} 194 | {"pred": "Paragraph 24]", "answers": ["Paragraph 25"], "all_classes": null, "length": 9668} 195 | {"pred": "Paragraph 24]", "answers": ["Paragraph 24"], "all_classes": null, "length": 9446} 196 | {"pred": "Paragraph 1]", "answers": ["Paragraph 1"], "all_classes": null, "length": 9333} 197 | {"pred": "Paragraph 25]", "answers": ["Paragraph 25"], "all_classes": null, "length": 9006} 198 | {"pred": "Paragraph 2", "answers": ["Paragraph 2"], "all_classes": null, "length": 10134} 199 | {"pred": "\nParagraph 27", "answers": ["Paragraph 27"], "all_classes": null, "length": 8214} 200 | {"pred": "Paragraph 1]", "answers": ["Paragraph 4"], "all_classes": null, "length": 10149} 201 | -------------------------------------------------------------------------------- /benchmarks/Pred_LongBench/result.json: -------------------------------------------------------------------------------- 1 | { 2 | "2wikimqa": 30.26, 3 | "passage_retrieval_en": 29.75, 4 | "passage_retrieval_zh": 3.96, 5 | "qasper": 29.1, 6 | "passage_count": 3.61, 7 | "gov_report": 31.53, 8 | "multifieldqa_zh": 8.48, 9 | "trec": 63.5, 10 | "multifieldqa_en": 37.15, 11 | "lsht": 26.0, 12 | "dureader": 15.25, 13 | "narrativeqa": 19.8, 14 | "lcc": 57.61, 15 | "musique": 17.14, 16 | "multi_news": 27.74, 17 | "qmsum": 24.13, 18 | "vcsum": 0.46, 19 | "samsum": 41.88, 20 | "repobench-p": 54.45, 21 | "triviaqa": 85.69, 22 | "hotpotqa": 37.01 23 | } -------------------------------------------------------------------------------- /benchmarks/README.md: -------------------------------------------------------------------------------- 1 | # Evaluation 
on LongBench and L-Eval Benchmarks 2 | 3 | We evaluate our supervised fine-tuned model, [LongAlpaca-7B-16k](https://huggingface.co/Yukang/LongAlpaca-7B-16k), on LongBench and L-Eval benchmarks. 4 | 5 | Table 1 - Evaluation on LongBench English tasks 6 | | Model | Avg | Single-Doc QA | Multi-Doc QA | Summarization | Few-shot Learning | Code | Synthetic | 7 | | --- | --- | --- | --- | --- | --- | --- | --- | 8 | | GPT-3.5-Turbo | 44.0 | 39.8 | 38.7 | 26.5 | 67.1 | 54.1 | 37.8 | 9 | | Llama2-7B-chat | 31.0 | 24.9 | 22.6 | 24.7 | 60.0 | 48.1 | 5.9 | 10 | | Ours | 36.8 | 28.7 | 28.1 | 27.8 | 63.7 | 56.0 | 16.7 | 11 | 12 | The predictions can be found [here](https://github.com/dvlab-research/LongLoRA/tree/main/benchmarks/Pred_LongBench). 13 | 14 | 15 | Table 2 - Evaluation on L-Eval open-ended tasks, compared against GPT-3.5-Turbo, with win rates judged by GPT-4. 16 | | Model | Win-rate | Wins | Ties | 17 | | --- | --- | --- | --- | 18 | | Ours | 39.06 | 45 | 60 | 19 | 20 | The predictions can be found [here](https://github.com/dvlab-research/LongLoRA/tree/main/benchmarks/Pred_L-Eval).
21 | -------------------------------------------------------------------------------- /demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import argparse 6 | import textwrap 7 | import transformers 8 | from peft import PeftModel 9 | from transformers import GenerationConfig, TextIteratorStreamer 10 | from llama_attn_replace import replace_llama_attn 11 | from threading import Thread 12 | import gradio as gr 13 | 14 | 15 | def parse_config(): 16 | parser = argparse.ArgumentParser(description='arg parser') 17 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 18 | parser.add_argument('--cache_dir', type=str, default="./cache") 19 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 20 | parser.add_argument('--flash_attn', type=bool, default=True, help='') 21 | parser.add_argument('--temperature', type=float, default=0.6, help='') 22 | parser.add_argument('--top_p', type=float, default=0.9, help='') 23 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 24 | parser.add_argument("--host", type=str, default="localhost") 25 | parser.add_argument("--port", type=int, default=8898) 26 | args = parser.parse_args() 27 | return args 28 | 29 | title = "LongLoRA and LongAlpaca for Long-context LLMs" 30 | 31 | description = """ 32 | 33 | This is the online demo of LongLoRA. \n 34 | If multiple users are using it at the same time, they will enter a queue, which may delay some time. \n 35 | **Inputs**:
36 | - **Input material txt** and **Question** are required.
37 | **Note**:
38 | - The demo model is **LongAlpaca-7B**. We use 4-bit quantization for low GPU memory inference, which may impair text-generation quality.
39 | - There are 10 book-related examples and 5 paper-related examples, 15 in total.
40 | - Note that only txt file is currently support.\n 41 | **Example questions**:
42 |   Please summarize the book in one paragraph.
43 |   Please tell me that what high-level idea the author want to indicate in this book.
44 |   Please describe the relationship among the roles in the book.
45 |   Please summarize the paper in one paragraph.
46 |   What is the main contribution of this paper?
47 | Hope you can enjoy our work! 48 |
49 | """ 50 | 51 | # Gradio 52 | article = """ 53 |

54 | 55 | Preprint Paper 56 | 57 | \n 58 |

59 | Github Repo

60 | """ 61 | 62 | PROMPT_DICT = { 63 | "prompt_no_input": ( 64 | "Below is an instruction that describes a task. " 65 | "Write a response that appropriately completes the request.\n\n" 66 | "### Instruction:\n{instruction}\n\n### Response:" 67 | ), 68 | "prompt_no_input_llama2":( 69 | "[INST] <>\n" 70 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n" 71 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 72 | "<> \n\n {instruction} [/INST]" 73 | ), 74 | } 75 | 76 | 77 | def read_txt_file(material_txt): 78 | content = "" 79 | with open(material_txt) as f: 80 | for line in f.readlines(): 81 | content += line 82 | return content 83 | 84 | def build_generator( 85 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 86 | ): 87 | def response(material, question): 88 | if material is None: 89 | return "Only support txt file." 90 | 91 | if not material.name.split(".")[-1]=='txt': 92 | return "Only support txt file." 93 | 94 | material = read_txt_file(material.name) 95 | prompt_no_input = PROMPT_DICT["prompt_no_input_llama2"] 96 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s" % question}) 97 | 98 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 99 | 100 | if len(inputs['input_ids'][0]) > 32768: 101 | return "This demo supports tokens less than 32768, while the current is %d. 
Please use material with less tokens."%len(inputs['input_ids'][0]) 102 | torch.cuda.empty_cache() 103 | 104 | streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True) 105 | generate_kwargs = dict(**inputs, 106 | max_new_tokens=max_gen_len, 107 | temperature=temperature, 108 | top_p=top_p, 109 | use_cache=use_cache, 110 | streamer=streamer, 111 | ) 112 | 113 | t = Thread(target=model.generate, kwargs=generate_kwargs) 114 | t.start() 115 | 116 | generated_text = "" 117 | for new_text in streamer: 118 | generated_text += new_text 119 | yield generated_text 120 | return generated_text 121 | 122 | return response 123 | 124 | def main(args): 125 | if args.flash_attn: 126 | replace_llama_attn(inference=True) 127 | 128 | # Set RoPE scaling factor 129 | config = transformers.AutoConfig.from_pretrained( 130 | args.base_model, 131 | cache_dir=args.cache_dir, 132 | ) 133 | 134 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 135 | if orig_ctx_len and args.context_size > orig_ctx_len: 136 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 137 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 138 | 139 | # Load model and tokenizer 140 | model = transformers.AutoModelForCausalLM.from_pretrained( 141 | args.base_model, 142 | config=config, 143 | cache_dir=args.cache_dir, 144 | torch_dtype=torch.float16, 145 | load_in_4bit=True, 146 | device_map="auto", 147 | ) 148 | model.resize_token_embeddings(32001) 149 | 150 | tokenizer = transformers.AutoTokenizer.from_pretrained( 151 | args.base_model, 152 | cache_dir=args.cache_dir, 153 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 154 | padding_side="right", 155 | use_fast=False, 156 | ) 157 | 158 | model.eval() 159 | if torch.__version__ >= "2" and sys.platform != "win32": 160 | model = torch.compile(model) 161 | # import pdb; pdb.set_trace() 162 | respond = build_generator(model, tokenizer, 
temperature=args.temperature, top_p=args.top_p, 163 | max_gen_len=args.max_gen_len, use_cache=True) 164 | 165 | demo = gr.Interface( 166 | respond, 167 | inputs=[ 168 | gr.File(type="file", label="Input material txt"), 169 | gr.Textbox(lines=1, placeholder=None, label="Question"), 170 | ], 171 | outputs=[ 172 | gr.Textbox(lines=1, placeholder=None, label="Text Output"), 173 | ], 174 | title=title, 175 | description=description, 176 | article=article, 177 | allow_flagging="auto", 178 | ) 179 | 180 | demo.queue() 181 | demo.launch(server_name=args.host, server_port=args.port, show_error=True, share=True) 182 | 183 | if __name__ == "__main__": 184 | args = parse_config() 185 | main(args) 186 | -------------------------------------------------------------------------------- /ds_configs/stage2.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "gradient_accumulation_steps": "auto", 4 | "gradient_clipping": "auto", 5 | "zero_allow_untested_optimizer": true, 6 | "bf16": { 7 | "enabled": "auto", 8 | "loss_scale": 0, 9 | "initial_scale_power": 16, 10 | "loss_scale_window": 1000, 11 | "hysteresis": 2, 12 | "min_loss_scale": 1 13 | }, 14 | "zero_optimization": { 15 | "stage": 2, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 1e9, 18 | "reduce_scatter": true, 19 | "reduce_bucket_size": 1e9, 20 | "overlap_comm": true, 21 | "contiguous_gradients": true 22 | } 23 | } 24 | -------------------------------------------------------------------------------- /ds_configs/stage3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | 
"warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": false 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # Some code based on https://github.com/epfml/landmark-attention 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import math 18 | import torch 19 | import argparse 20 | import random 21 | import numpy as np 22 | from tqdm import tqdm 23 | import transformers 24 | from peft import PeftModel 25 | from llama_attn_replace import replace_llama_attn 26 | 27 | def parse_config(): 28 | parser = argparse.ArgumentParser(description='arg parser') 29 | parser.add_argument('--batch_size', type=int, default=32, help='batch size during inference') 30 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 31 | parser.add_argument('--cache_dir', type=str, default="./cache") 32 | parser.add_argument('--seq_len', type=int, default=2048, help='context length during evaluation') 33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 34 | parser.add_argument('--peft_model', type=str, default=None, help='') 35 | parser.add_argument('--flash_attn', type=bool, default=True, help='') 36 | parser.add_argument('--data_path', type=str, default="./test.bin", help='') 37 | args = parser.parse_args() 38 | return args 39 | 40 | def get_as_batch(data, seq_length, batch_size, device='cpu', sliding_window=256): 41 | all_ix = list(range(0, len(data) - seq_length, sliding_window)) 42 | all_ix.pop() 43 | 44 | for idx in range(0, len(all_ix), batch_size): 45 | ix = all_ix[idx:idx+batch_size] 46 | assert all([idx + seq_length + 1 <= len(data) for idx in ix]) 47 | x = torch.stack([torch.from_numpy((data[i:i+seq_length]).astype(np.int64)) for i in ix]) 48 | y = torch.stack([torch.from_numpy((data[i+1:i+1+seq_length]).astype(np.int64)) for i in ix]) 49 | if device != 'cpu': 50 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 51 | yield x, y 52 | 53 | def iceildiv(x, y): 54 | return (x + y - 1) // y 55 | 56 | def evaluate(model, data, batch_size, device, seq_length, sliding_window=256, use_cache=False): 57 | stats = {} 58 | 59 | model.eval() 60 | 61 | 
loss_list_val, acc_list = [], [] 62 | loss_step_list_val = [] 63 | 64 | with torch.no_grad(): 65 | print(f"Using seq length {seq_length}") 66 | torch.set_printoptions(sci_mode=False) 67 | for idx, (x, y) in tqdm( 68 | enumerate( 69 | get_as_batch( 70 | data['val'], 71 | seq_length, 72 | batch_size, 73 | device=device, 74 | sliding_window=sliding_window 75 | ) 76 | ), 77 | total=iceildiv( 78 | iceildiv(len(data['val']), sliding_window), 79 | batch_size 80 | ) 81 | ): 82 | val_loss = 0. 83 | acc = 0. 84 | cnt = 0 85 | 86 | for part_idx, i in enumerate(range(0, x.shape[1], seq_length)): 87 | part_len = x[:, i:i + seq_length].shape[1] 88 | 89 | outputs = model( 90 | input_ids=x[:, i:i + seq_length], 91 | labels=x[:, i:i+seq_length].contiguous(), 92 | use_cache=use_cache) 93 | 94 | val_loss = outputs.loss * part_len + val_loss 95 | acc = ((outputs.logits.argmax(-1) == y[:, i:i+seq_length]).float().sum()) + acc 96 | cnt += part_len 97 | while len(loss_step_list_val) <= part_idx: 98 | loss_step_list_val.append([]) 99 | loss_step_list_val[part_idx].append(outputs.loss.item()) 100 | val_loss /= cnt 101 | acc /= cnt 102 | 103 | loss_list_val.append(val_loss.item()) 104 | acc_list.append(acc.item()) 105 | 106 | stats['val_acc'] = torch.as_tensor(acc_list).mean().item() 107 | stats['val_loss'] = torch.as_tensor(loss_list_val).mean().item() 108 | stats['val_perplexity'] = 2.71828 ** stats['val_loss'] 109 | stats['val_perplexity_per_chunk'] = torch.exp(torch.as_tensor(loss_step_list_val).mean(dim=1)) 110 | 111 | return stats 112 | 113 | def main(args): 114 | 115 | device = "cuda:0" 116 | seed = 2 117 | torch.cuda.set_device(device) 118 | 119 | torch.manual_seed(seed) 120 | random.seed(seed) 121 | np.random.seed(seed) 122 | 123 | data = {'val': np.memmap(args.data_path, dtype=np.uint16, mode='r')} 124 | 125 | print(f"Num validation tokens: {len(data['val'])}") 126 | print("data path", args.data_path) 127 | print("base model", args.base_model) 128 | print("peft model", 
args.peft_model) 129 | 130 | if args.flash_attn: 131 | replace_llama_attn(use_flash_attn=True, use_full=True) 132 | 133 | # Set RoPE scaling factor 134 | config = transformers.AutoConfig.from_pretrained( 135 | args.base_model, 136 | cache_dir=args.cache_dir, 137 | ) 138 | 139 | context_size = args.context_size if args.context_size > 0 else args.seq_len 140 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models 141 | if orig_ctx_len and context_size > orig_ctx_len: 142 | scaling_factor = float(math.ceil(context_size / orig_ctx_len)) 143 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 144 | 145 | # Load model and tokenizer 146 | model = transformers.AutoModelForCausalLM.from_pretrained( 147 | args.base_model, 148 | config=config, 149 | cache_dir=args.cache_dir, 150 | torch_dtype=torch.float16, 151 | device_map="auto", 152 | ) 153 | model.resize_token_embeddings(32001) 154 | 155 | if args.peft_model: 156 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin") 157 | if os.path.isfile(trainable_params): 158 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False) 159 | else: 160 | raise ValueError("Trainable input embedding and normalization are required.") 161 | model = PeftModel.from_pretrained( 162 | model, 163 | args.peft_model, 164 | device_map="auto", 165 | torch_dtype=torch.float16, 166 | ) 167 | 168 | stats = evaluate(model, data, args.batch_size, device, args.seq_len, sliding_window=256) 169 | 170 | print(stats) 171 | 172 | 173 | if __name__ == "__main__": 174 | args = parse_config() 175 | main(args) 176 | -------------------------------------------------------------------------------- /eval_distributed.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # Some code based on https://github.com/epfml/landmark-attention 3 | # 4 | # Licensed under the Apache License, 
Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | from dataclasses import dataclass, field 18 | from typing import Optional 19 | 20 | import math 21 | import random 22 | import transformers 23 | from peft import PeftModel 24 | 25 | from llama_attn_replace import replace_llama_attn 26 | from torch.distributed import init_process_group, destroy_process_group 27 | from torchmetrics import Accuracy 28 | from torchmetrics.text import Perplexity 29 | from torch.nn import CrossEntropyLoss 30 | 31 | import inspect 32 | from abc import ABC, abstractmethod 33 | from typing import Union 34 | 35 | from torch.utils.data import Dataset, DataLoader, DistributedSampler 36 | from transformers.modeling_utils import PreTrainedModel 37 | from torch import nn 38 | from torch.nn.parallel import DistributedDataParallel as DDP 39 | from tqdm import tqdm 40 | 41 | 42 | import numpy as np 43 | import torch 44 | 45 | 46 | class Pg19Dataset(Dataset): 47 | def __init__(self, data_path: str, seq_length: int, sliding_window: int = 256): 48 | assert seq_length >= sliding_window, f"Sliding window '{sliding_window}' must be smaller than sequence length '{seq_length}'" 49 | 50 | self.seq_length = seq_length 51 | self.data = np.memmap(data_path, dtype=np.uint16, mode='r') 52 | self.start_indices = list(range(0, len(self.data) - seq_length, sliding_window)) 53 | 54 | assert len(self) > 0, "Dataset is empty" 55 | 56 | def __len__(self): 57 | return len(self.start_indices) 58 | # return 
1000 59 | 60 | def __getitem__(self, index) -> dict[str, torch.Tensor]: 61 | start = self.start_indices[index] 62 | end = start + self.seq_length 63 | 64 | input_id = torch.from_numpy(self.data[start: end].astype(np.int64)) 65 | y = torch.from_numpy(self.data[start+1: end+1].astype(np.int64)) 66 | return { 67 | "input_ids": input_id, 68 | "labels": input_id, 69 | "ys": y 70 | } 71 | 72 | def num_tokens(self): 73 | return len(self.data) 74 | 75 | 76 | class EvalMetric(ABC): 77 | @abstractmethod 78 | def add(self, logits: torch.FloatTensor, labels: torch.LongTensor, model_output: object) -> dict[str, object]: 79 | pass 80 | 81 | @abstractmethod 82 | def compute(self) -> dict[str, object]: 83 | pass 84 | 85 | 86 | class DistributedEvaluator: 87 | def __init__(self, 88 | model: Union[PreTrainedModel, nn.Module], 89 | batch_size: int, 90 | refresh_rate: int, 91 | gpu_id: int): 92 | self.gpu_id = gpu_id 93 | self.batch_size = batch_size 94 | self.refresh_rate = refresh_rate 95 | 96 | self.model = DDP(model, device_ids=[self.gpu_id]) 97 | 98 | def evaluate(self, dataset: Dataset, metric: EvalMetric) -> dict[str, object]: 99 | data_loader = self._prepare_dataloader(dataset) 100 | self.model.eval() 101 | with torch.no_grad(): 102 | if self.is_first_device(): 103 | data_loader = tqdm(data_loader) 104 | for i, example_dict in enumerate(data_loader): 105 | sig = inspect.signature(self.model.forward) 106 | used = set(list(sig.parameters.keys()) + ["input_ids", "labels"]) 107 | inputs = {key: example_dict[key].to(self.gpu_id) for key in used if key in example_dict} 108 | outputs = self.model(**inputs) 109 | metric_result = metric.add(logits=outputs["logits"], labels=inputs["labels"], model_output=outputs) 110 | 111 | if self.is_first_device() and (i % self.refresh_rate == 0): 112 | data_loader.set_postfix(metric_result) 113 | return metric.compute() 114 | 115 | def is_first_device(self): 116 | return self.gpu_id == 0 117 | 118 | def _prepare_dataloader(self, dataset: Dataset): 
119 | return DataLoader( 120 | dataset, 121 | batch_size=self.batch_size, 122 | pin_memory=True, 123 | shuffle=False, 124 | sampler=DistributedSampler(dataset) 125 | ) 126 | 127 | 128 | class EvalMetricImpl(EvalMetric): 129 | def __init__(self, vocab_size: int, gpu_id: int): 130 | self.accuracy = Accuracy(task="multiclass", num_classes=vocab_size).to(gpu_id) 131 | self.perplexity = Perplexity(ignore_index=CrossEntropyLoss().ignore_index).to(gpu_id) 132 | self.last_loss = 0.0 133 | 134 | def add(self, logits: torch.FloatTensor, labels: torch.LongTensor, model_output: object) -> dict[str, object]: 135 | shift_predictions = logits.argmax(dim=-1)[..., :-1] 136 | shift_labels = labels[..., 1:] 137 | 138 | current_accuracy = self.accuracy.forward(preds=shift_predictions, target=shift_labels) 139 | 140 | shift_logits = logits[..., :-1, :] 141 | current_perplexity = self.perplexity.forward(preds=shift_logits, target=shift_labels) 142 | 143 | self.last_loss = model_output["loss"].item() 144 | return { 145 | "accuracy": current_accuracy.item(), 146 | "perplexity": current_perplexity.item(), 147 | "loss": self.last_loss 148 | } 149 | 150 | def compute(self) -> dict[str, object]: 151 | current_accuracy = self.accuracy.compute() 152 | current_perplexity = self.perplexity.compute() 153 | return { 154 | "accuracy": current_accuracy.item(), 155 | "perplexity": current_perplexity.item(), 156 | "loss": self.last_loss 157 | } 158 | 159 | 160 | @dataclass 161 | class EvalArguments: 162 | batch_size: int = field( 163 | default=1, 164 | metadata={"help": "batch size."}, 165 | ) 166 | base_model: Optional[str] = field(default="meta-llama/Llama-2-7b-hf") 167 | seq_len: int = field( 168 | default=2048, 169 | metadata={"help": "context length during evaluation."}, 170 | ) 171 | context_size: int = field( 172 | default=-1, 173 | metadata={"help": "context size during fine-tuning."}, 174 | ) 175 | peft_model: Optional[str] = field(default=None) 176 | flash_attn: bool = field( 177 | 
default=True, 178 | metadata={"help": "Whether use flash attention."}, 179 | ) 180 | data_path: str = field( 181 | default="./test.bin", 182 | metadata={"help": "test data path"}, 183 | ) 184 | cache_dir: Optional[str] = field(default="./.cache") 185 | progress_bar_fresh_rate: int = field( 186 | default=10, 187 | metadata={"help": "progress bar metrics fresh rate."}, 188 | ) 189 | 190 | 191 | def run_eval(args: EvalArguments): 192 | torch_dtype = torch.float16 193 | 194 | seed = 2 195 | torch.manual_seed(seed) 196 | random.seed(seed) 197 | np.random.seed(seed) 198 | 199 | dataset = Pg19Dataset(args.data_path, seq_length=args.seq_len, sliding_window=256) 200 | if args.flash_attn: 201 | replace_llama_attn(use_flash_attn=True, use_full=True) 202 | 203 | # Set RoPE scaling factor 204 | config = transformers.AutoConfig.from_pretrained( 205 | args.base_model, 206 | cache_dir=args.cache_dir, 207 | use_cache=False 208 | ) 209 | 210 | context_size = args.context_size if args.context_size > 0 else args.seq_len 211 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value should be 4096 for LLaMA2 models 212 | if orig_ctx_len and context_size > orig_ctx_len: 213 | scaling_factor = float(math.ceil(context_size / orig_ctx_len)) 214 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 215 | 216 | # Load model and tokenizer 217 | model = transformers.AutoModelForCausalLM.from_pretrained( 218 | args.base_model, 219 | config=config, 220 | cache_dir=args.cache_dir, 221 | torch_dtype=torch_dtype) 222 | model.resize_token_embeddings(32001) 223 | 224 | if args.peft_model: 225 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin") 226 | if os.path.isfile(trainable_params): 227 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), strict=False) 228 | else: 229 | raise ValueError("Trainable input embedding and normalization are required.") 230 | model = PeftModel.from_pretrained( 231 | model, 232 | 
args.peft_model, 233 | torch_dtype=torch_dtype, 234 | offload_folder=args.cache_dir, 235 | ) 236 | 237 | # This is a hacky way to enable distributed evaluation. Otherwise, without any trainable parameters, we will not 238 | # be able to use DistributedDataParallel, although we don't update any parameters during evaluation. 239 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in ["lm_head"]])] 240 | 241 | gpu_id = int(os.environ["LOCAL_RANK"]) 242 | model.to(gpu_id) 243 | 244 | evaluator = DistributedEvaluator( 245 | model=model, 246 | batch_size=args.batch_size, 247 | refresh_rate=args.progress_bar_fresh_rate, 248 | gpu_id=gpu_id) 249 | 250 | if evaluator.is_first_device(): 251 | print("data path", args.data_path) 252 | print("base model", args.base_model) 253 | print("peft model", args.peft_model) 254 | print(f"Num validation tokens: {dataset.num_tokens()}, Num validation examples: {len(dataset)}") 255 | 256 | eval_metric = EvalMetricImpl(vocab_size=config.vocab_size, gpu_id=gpu_id) 257 | result = evaluator.evaluate(dataset, eval_metric) 258 | if evaluator.is_first_device(): 259 | print(result) 260 | 261 | 262 | def ddp_setup(): 263 | init_process_group(backend="nccl") 264 | 265 | 266 | def main(cmd_args: list[str] = None): 267 | ddp_setup() 268 | parser = transformers.HfArgumentParser((EvalArguments, )) 269 | args: EvalArguments = parser.parse_args_into_dataclasses(cmd_args)[0] 270 | try: 271 | run_eval(args) 272 | finally: 273 | destroy_process_group() 274 | 275 | 276 | if __name__ == "__main__": 277 | main() 278 | -------------------------------------------------------------------------------- /fine-tune.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # Some code based on https://github.com/epfml/landmark-attention 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import os 17 | import math 18 | from dataclasses import dataclass, field 19 | from functools import partial 20 | from typing import Dict, Optional, Sequence 21 | 22 | import torch 23 | import transformers 24 | from torch.utils.data import Dataset 25 | from transformers import Trainer, DataCollatorForLanguageModeling 26 | from llama_attn_replace import replace_llama_attn 27 | from gptneox_attn_replace import replace_gpt_neox_attn 28 | from peft import LoraConfig, get_peft_model 29 | from torch.distributed import barrier 30 | 31 | 32 | from datasets import load_dataset 33 | 34 | IGNORE_INDEX = -100 35 | DEFAULT_PAD_TOKEN = "[PAD]" 36 | DEFAULT_EOS_TOKEN = "" 37 | DEFAULT_BOS_TOKEN = "" 38 | DEFAULT_UNK_TOKEN = "" 39 | 40 | 41 | @dataclass 42 | class ModelArguments: 43 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") 44 | model_type: Optional[str] = field(default="llama") 45 | 46 | @dataclass 47 | class TrainingArguments(transformers.TrainingArguments): 48 | cache_dir: Optional[str] = field(default=None) 49 | optim: str = field(default="adamw_torch") 50 | model_max_length: int = field( 51 | default=8192 * 4, 52 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 53 | ) 54 | use_flash_attn: bool = field( 55 | default=True, 56 | metadata={"help": "Whether use flash attention for training."}, 57 | ) 58 | use_full_attn: bool = field( 59 | default=False, 60 | metadata={"help": "Whether to use plain, full-attention for training."}, 61 | ) 62 | low_rank_training: bool = field( 63 | default=True, 64 | metadata={"help": "Whether use low rank adaptation for training."}, 65 | ) 66 | trainable_params: str = field( 67 | default="embed,norm", 68 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."}, 69 | ) 70 | 71 | def smart_tokenizer_and_embedding_resize( 72 | special_tokens_dict: Dict, 73 | tokenizer: transformers.PreTrainedTokenizer, 74 | model: transformers.PreTrainedModel, 75 | ): 76 | """Resize tokenizer and embedding. 77 | 78 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 79 | """ 80 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 81 | model.resize_token_embeddings(len(tokenizer)) 82 | 83 | if num_new_tokens > 0: 84 | input_embeddings = model.get_input_embeddings().weight.data 85 | output_embeddings = model.get_output_embeddings().weight.data 86 | 87 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 88 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 89 | 90 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 91 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 92 | 93 | def tokenize_fn(tokenizer, example): 94 | context_length = tokenizer.model_max_length 95 | outputs = tokenizer( 96 | tokenizer.eos_token.join(example["text"]), 97 | truncation=False, 98 | return_tensors="pt", 99 | pad_to_multiple_of=context_length, 100 | padding=True, 101 | ) 102 | return {"input_ids": outputs["input_ids"].view(-1, context_length)} 103 | 104 | def train(): 105 | parser = 
transformers.HfArgumentParser((ModelArguments, TrainingArguments)) 106 | model_args, training_args = parser.parse_args_into_dataclasses() 107 | 108 | # NOTE: May expand supported model types in the future 109 | if model_args.model_type == "gpt-neox": 110 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn) 111 | else: 112 | assert model_args.model_type == "llama", "Only support llama and gpt-neox for now" 113 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn) 114 | 115 | # Set RoPE scaling factor 116 | config = transformers.AutoConfig.from_pretrained( 117 | model_args.model_name_or_path, 118 | cache_dir=training_args.cache_dir, 119 | ) 120 | 121 | orig_rope_scaling = getattr(config, "rope_scaling", None) 122 | if orig_rope_scaling is None: 123 | orig_rope_scaling = {"factor": 1} 124 | 125 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1 126 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 127 | if orig_ctx_len: 128 | orig_ctx_len *= orig_rope_scaling_factor 129 | if training_args.model_max_length > orig_ctx_len: 130 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 131 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 132 | 133 | # Load model and tokenizer 134 | model = transformers.AutoModelForCausalLM.from_pretrained( 135 | model_args.model_name_or_path, 136 | config=config, 137 | cache_dir=training_args.cache_dir, 138 | torch_dtype=torch.bfloat16, 139 | ) 140 | 141 | tokenizer = transformers.AutoTokenizer.from_pretrained( 142 | model_args.model_name_or_path, 143 | cache_dir=training_args.cache_dir, 144 | model_max_length=training_args.model_max_length, 145 | padding_side="right", 146 | use_fast=True, 147 | ) 148 | 149 | special_tokens_dict = dict() 150 | if tokenizer.pad_token is None: 151 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 152 | if tokenizer.eos_token is None: 
153 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 154 | if tokenizer.bos_token is None: 155 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 156 | if tokenizer.unk_token is None: 157 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 158 | 159 | smart_tokenizer_and_embedding_resize( 160 | special_tokens_dict=special_tokens_dict, 161 | tokenizer=tokenizer, 162 | model=model, 163 | ) 164 | 165 | rank = int(os.environ.get('RANK', -1)) 166 | if rank > 0: 167 | barrier() 168 | dataset = load_dataset("togethercomputer/RedPajama-Data-1T-Sample", cache_dir=training_args.cache_dir) 169 | dataset = dataset.map(partial(tokenize_fn,tokenizer),batched=True, num_proc=128, remove_columns=["text", "meta"]) 170 | 171 | if rank == 0: 172 | barrier() 173 | 174 | print(dataset) 175 | 176 | data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False) 177 | 178 | if training_args.low_rank_training: 179 | if model_args.model_type == "gpt-neox": 180 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' 181 | targets = ["query_key_value", "dense"] 182 | else: 183 | targets=["q_proj", "k_proj", "v_proj", "o_proj"] 184 | 185 | config = LoraConfig( 186 | r=8, 187 | lora_alpha=16, 188 | target_modules=targets, 189 | lora_dropout=0, 190 | bias="none", 191 | task_type="CAUSAL_LM", 192 | ) 193 | model = get_peft_model(model, config) 194 | # enable trainable params 195 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 196 | 197 | model.config.use_cache = False # required for gradient checkpointing 198 | model.enable_input_require_grads() # required for gradient checkpointing 199 | model.gradient_checkpointing_enable() # enable gradient checkpointing 200 | trainer = Trainer( 201 | model=model, tokenizer=tokenizer, args=training_args, 202 | train_dataset=dataset["train"], 203 | eval_dataset=None, 204 | data_collator=data_collator) 205 | 
trainer.train() 206 | trainer.save_state() 207 | trainer.save_model(output_dir=training_args.output_dir) 208 | 209 | 210 | if __name__ == "__main__": 211 | train() 212 | -------------------------------------------------------------------------------- /get_trainable_weights.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import torch 17 | import argparse 18 | 19 | def parse_config(): 20 | parser = argparse.ArgumentParser(description='arg parser') 21 | parser.add_argument('--checkpoint_path', type=str, default="/dataset/models/checkpoint-1000") 22 | parser.add_argument('--trainable_params', type=str, default="embed,norm") 23 | args = parser.parse_args() 24 | return args 25 | 26 | 27 | def main(args): 28 | path = args.checkpoint_path 29 | trainable_params = args.trainable_params.split(",") 30 | 31 | weights_all = torch.load(os.path.join(path, "pytorch_model.bin")) 32 | 33 | weights_trainable = {} 34 | weights_lora = {} 35 | for k in weights_all: 36 | if "lora" in k: 37 | k_new = k.replace("default.", "") if "default." 
in k else k 38 | weights_lora[k_new] = weights_all[k] 39 | else: 40 | if any([n in k for n in trainable_params]): 41 | weights_trainable[k[17:]] = weights_all[k] 42 | 43 | adapter_model = os.path.join(path, "adapter_model.bin") 44 | trainable_params = os.path.join(path, "trainable_params.bin") 45 | if not os.path.isfile(adapter_model): 46 | torch.save(weights_lora, adapter_model) 47 | torch.save(weights_trainable, trainable_params) 48 | 49 | if __name__ == "__main__": 50 | args = parse_config() 51 | main(args) 52 | -------------------------------------------------------------------------------- /gptneox_attn_replace.py: -------------------------------------------------------------------------------- 1 | # Modified based on https://github.com/dvlab-research/LongLoRA 2 | 3 | from typing import Optional, Tuple 4 | import warnings 5 | import torch 6 | import transformers 7 | 8 | from einops import rearrange 9 | from flash_attn import flash_attn_varlen_qkvpacked_func, flash_attn_varlen_func 10 | from flash_attn.bert_padding import unpad_input, pad_input 11 | 12 | 13 | group_size_ratio = 1/4 14 | 15 | def rotate_half(x): 16 | """Rotates half the hidden dims of the input.""" 17 | x1 = x[..., : x.shape[-1] // 2] 18 | x2 = x[..., x.shape[-1] // 2 :] 19 | return torch.cat((-x2, x1), dim=-1) 20 | 21 | def apply_rotary_pos_emb(q, k, cos, sin, position_ids): 22 | gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1] 23 | gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3]) 24 | cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1).to(q.dtype), 2, gather_indices) 25 | sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1).to(k.dtype), 2, gather_indices) 26 | q_embed = (q * cos) + (rotate_half(q) * sin) 27 | k_embed = (k * cos) + (rotate_half(k) * sin) 28 | return q_embed, k_embed 29 | 30 | 31 | def _flash_attn_ssa(query, key, value, attention_mask=None, head_mask=None): 32 | # transform the data into the qkv packed form 33 | 
qkv = torch.stack( 34 | [query, key, value], dim=2 35 | ) # [bsz, nh, 3, q_len, hd] 36 | qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd] 37 | bsz, q_len = qkv.shape[:2] 38 | 39 | qkv = rearrange(qkv, "b s ... -> (b s) ...") 40 | cu_q_lens = torch.arange(0, (bsz + 1) * q_len, step=q_len, dtype=torch.int32, device=qkv.device) 41 | output = flash_attn_varlen_qkvpacked_func(qkv, cu_q_lens, q_len, 0.0, softmax_scale=None, causal=True) 42 | output = rearrange(output, "(b s) ... -> b s ...", b=bsz) 43 | 44 | # disable attn weights by returning None when using flash attention 45 | return output, None 46 | 47 | def _flash_attn_full(query, key, value, attention_mask=None, head_mask=None): 48 | # q, k, v: [bs, nh, seq_len, hd] 49 | batch_size, num_attention_heads, query_length, attn_head_size = query.size() 50 | key_length = key.size(-2) 51 | value_length = value.size(-2) 52 | 53 | # q, k, v: [bs, nh, seq_len, hd] -> [bs, seq_len, nh, hd] -> [bs * seq_len, nh, hd] 54 | query = query.transpose(1, 2).reshape(batch_size * query_length , num_attention_heads, attn_head_size) 55 | key = key.transpose(1, 2).reshape(batch_size * key_length, num_attention_heads, attn_head_size) 56 | value = value.transpose(1, 2).reshape(batch_size * value_length, num_attention_heads, attn_head_size) 57 | 58 | cu_seqlens_q = torch.arange( 59 | 0, 60 | (batch_size + 1) * query_length, 61 | step=query_length, 62 | dtype=torch.int32, 63 | device=query.device, 64 | ) 65 | 66 | cu_seqlens_k = torch.arange( 67 | 0, 68 | (batch_size + 1) * key_length, 69 | step=key_length, 70 | dtype=torch.int32, 71 | device=key.device, 72 | ) 73 | 74 | attn_output, attn_weights, _ = flash_attn_varlen_func( 75 | query, key, value, cu_seqlens_q, cu_seqlens_k, query_length, value_length, dropout_p=0.0, 76 | softmax_scale=None, causal=True, return_attn_probs=True 77 | ) 78 | 79 | attn_output = attn_output.view(batch_size, query_length, num_attention_heads, attn_head_size).transpose(1, 2) 80 | return attn_output, attn_weights 
81 | 82 | 83 | def get_forward_function(use_flash_attn=True, use_full=False): 84 | 85 | def forward_attention( 86 | self, 87 | hidden_states: torch.FloatTensor, 88 | attention_mask: torch.FloatTensor, 89 | position_ids: torch.LongTensor, 90 | head_mask: Optional[torch.FloatTensor] = None, 91 | layer_past: Optional[Tuple[torch.Tensor]] = None, 92 | use_cache: Optional[bool] = False, 93 | output_attentions: Optional[bool] = False, 94 | ): 95 | # NOTE: compute SS group size 96 | bsz, q_len, _ = hidden_states.size() 97 | has_layer_past = layer_past is not None 98 | 99 | # Compute QKV 100 | # Attention heads [batch, seq_len, hidden_size] 101 | # --> [batch, seq_len, (np * 3 * head_size)] 102 | qkv = self.query_key_value(hidden_states) 103 | 104 | # [batch, seq_len, (num_heads * 3 * head_size)] 105 | # --> [batch, seq_len, num_heads, 3 * head_size] 106 | new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size) 107 | qkv = qkv.view(*new_qkv_shape) 108 | 109 | # [batch, seq_len, num_attention_heads, 3 * head_size] 110 | # --> 3 [batch, num_attention_heads, seq_len, head_size] 111 | query = qkv[..., : self.head_size].permute(0, 2, 1, 3) 112 | key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3) 113 | value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3) 114 | # [bsz, nh, q_len, hd] 115 | 116 | # Compute rotary embeddings on rotary_ndims 117 | query_rot = query[..., : self.rotary_ndims] 118 | query_pass = query[..., self.rotary_ndims :] 119 | key_rot = key[..., : self.rotary_ndims] 120 | key_pass = key[..., self.rotary_ndims :] 121 | 122 | # Compute token offset for rotary embeddings (when decoding) 123 | seq_len = key.shape[-2] 124 | if has_layer_past: 125 | seq_len += layer_past[0].shape[-2] 126 | cos, sin = self.rotary_emb(value, seq_len=seq_len) 127 | query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids) 128 | query = torch.cat((query, query_pass), dim=-1) 129 | key = torch.cat((key, key_pass), 
dim=-1) 130 | 131 | # Cache QKV values 132 | if has_layer_past: 133 | past_key = layer_past[0] 134 | past_value = layer_past[1] 135 | key = torch.cat((past_key, key), dim=-2) 136 | value = torch.cat((past_value, value), dim=-2) 137 | present = (key, value) if use_cache else None 138 | 139 | # NOTE: apply shift 140 | group_size = int(q_len * group_size_ratio) 141 | if q_len % group_size > 0: 142 | raise ValueError("q_len %d should be divisible by group size %d." % (q_len, group_size)) 143 | num_group = q_len // group_size 144 | if self.training and not use_full: 145 | def shift(qkv, num_heads, head_dim): 146 | # qkv = [bsz, nh, q_len, d] 147 | qkv = qkv.transpose(1, 2) 148 | # qkv = [bsz, q_len, nh, d] 149 | qkv[:, :, num_heads//2:] = qkv[:, :, num_heads//2:].roll(-group_size//2, dims=1) 150 | 151 | # -> [bsz * n_group, group_s, nh, d) 152 | # -> [bsz * n_group, nh, group_s, d) 153 | qkv = qkv.reshape(bsz * num_group, group_size, num_heads, head_dim).transpose(1, 2) 154 | return qkv 155 | 156 | # contiguous is required as self._attn() will attempt to apply .view() on them 157 | query = shift(query, self.num_attention_heads, self.head_size).contiguous() 158 | key = shift(key, self.num_attention_heads, self.head_size).contiguous() 159 | value = shift(value, self.num_attention_heads, self.head_size).contiguous() 160 | 161 | attention_mask = attention_mask[:, :, :group_size, :group_size].repeat(num_group, 1, 1, 1) 162 | 163 | # Compute attention 164 | if use_flash_attn: 165 | _flash_attn = _flash_attn_full if use_full else _flash_attn_ssa 166 | attn_output, attn_weights = _flash_attn(query, key, value, attention_mask, head_mask) 167 | else: 168 | attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) 169 | 170 | # NOTE: shift back 171 | if self.training and not use_full: 172 | attn_output = attn_output.transpose(1, 2).contiguous() 173 | attn_output = attn_output.reshape(bsz, q_len, self.num_attention_heads, self.head_size) 174 | # [bsz, 
q_len, nh, hd] 175 | attn_output[:, :, self.num_attention_heads//2:] = attn_output[:, :, self.num_attention_heads//2:].roll(group_size//2, dims=1) 176 | attn_output = attn_output.transpose(1, 2) 177 | 178 | # Reshape outputs 179 | attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_size) 180 | attn_output = self.dense(attn_output) 181 | 182 | outputs = (attn_output, present) 183 | if output_attentions: 184 | outputs += (attn_weights,) 185 | 186 | return outputs 187 | 188 | return forward_attention 189 | 190 | 191 | def replace_gpt_neox_attn(use_flash_attn=True, use_full=False): 192 | cuda_major, cuda_minor = torch.cuda.get_device_capability() 193 | if use_flash_attn and cuda_major < 8: 194 | warnings.warn( 195 | "Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward." 196 | "ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593" 197 | "Resorting to plain attention..." 198 | ) 199 | use_flash_attn = False 200 | 201 | forward_fn = get_forward_function(use_flash_attn, use_full) 202 | transformers.models.gpt_neox.modeling_gpt_neox.GPTNeoXAttention.forward = forward_fn 203 | -------------------------------------------------------------------------------- /imgs/LongAlpaca.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/LongAlpaca.png -------------------------------------------------------------------------------- /imgs/Shift-short-attention2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/Shift-short-attention2.png -------------------------------------------------------------------------------- /imgs/data-distribution-in-longalpaca12k.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/data-distribution-in-longalpaca12k.png -------------------------------------------------------------------------------- /imgs/demo-compare-harrypotter.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-harrypotter.png -------------------------------------------------------------------------------- /imgs/demo-compare-journeytothewest.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-journeytothewest.png -------------------------------------------------------------------------------- /imgs/demo-compare-threebody.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/demo-compare-threebody.png -------------------------------------------------------------------------------- /imgs/economy-comparison.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/economy-comparison.png -------------------------------------------------------------------------------- /imgs/economy-prediction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/economy-prediction.png -------------------------------------------------------------------------------- /imgs/paper-improvements.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-improvements.png -------------------------------------------------------------------------------- /imgs/paper-review.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-review.png -------------------------------------------------------------------------------- /imgs/paper-style-compare-cvpr-iclr.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/imgs/paper-style-compare-cvpr-iclr.png -------------------------------------------------------------------------------- /inference-qlora.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import argparse 6 | import textwrap 7 | import transformers 8 | from peft import PeftModel 9 | from transformers import GenerationConfig, TextStreamer, BitsAndBytesConfig 10 | from llama_attn_replace import replace_llama_attn 11 | 12 | PROMPT_DICT = { 13 | "prompt_no_input": ( 14 | "Below is an instruction that describes a task. " 15 | "Write a response that appropriately completes the request.\n\n" 16 | "### Instruction:\n{instruction}\n\n### Response:" 17 | ), 18 | "prompt_no_input_llama2": ( 19 | "[INST] <>\n" 20 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
Please ensure that your responses are socially unbiased and positive in nature.\n\n" 21 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 22 | "<> \n\n {instruction} [/INST]" 23 | ), 24 | "prompt_llama2": "[INST]{instruction}[/INST]" 25 | } 26 | 27 | def parse_config(): 28 | parser = argparse.ArgumentParser(description='arg parser') 29 | parser.add_argument('--material', type=str, default="") 30 | parser.add_argument('--question', type=str, default="") 31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 32 | parser.add_argument('--cache_dir', type=str, default="./cache") 33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 34 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 35 | parser.add_argument('--temperature', type=float, default=0.6, help='') 36 | parser.add_argument('--top_p', type=float, default=0.9, help='') 37 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 38 | args = parser.parse_args() 39 | return args 40 | 41 | def read_txt_file(material_txt): 42 | if not material_txt.split(".")[-1]=='txt': 43 | raise ValueError("Only support txt or pdf file.") 44 | content = "" 45 | with open(material_txt) as f: 46 | for line in f.readlines(): 47 | content += line 48 | return content 49 | 50 | def build_generator( 51 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 52 | ): 53 | def response(prompt): 54 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 55 | 56 | streamer = TextStreamer(tokenizer) 57 | 58 | output = model.generate( 59 | **inputs, 60 | max_new_tokens=max_gen_len, 61 | temperature=temperature, 62 | top_p=top_p, 63 | use_cache=use_cache, 64 | streamer=streamer, 65 | ) 66 | 67 | out = 
tokenizer.decode(output[0], skip_special_tokens=True) 68 | 69 | out = out.split(prompt.lstrip(""))[1].strip() 70 | return out 71 | 72 | return response 73 | 74 | def main(args): 75 | if args.flash_attn: 76 | replace_llama_attn(inference=True) 77 | 78 | # Set RoPE scaling factor 79 | config = transformers.AutoConfig.from_pretrained( 80 | args.base_model, 81 | cache_dir=args.cache_dir, 82 | ) 83 | 84 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 85 | if orig_ctx_len and args.context_size > orig_ctx_len: 86 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 87 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 88 | 89 | # Load model and tokenizer 90 | model = transformers.AutoModelForCausalLM.from_pretrained( 91 | args.base_model, 92 | config=config, 93 | cache_dir=args.cache_dir, 94 | torch_dtype=torch.float16, 95 | device_map="auto", 96 | quantization_config = BitsAndBytesConfig( 97 | load_in_4bit=True, 98 | bnb_4bit_use_double_quant=True, 99 | bnb_4bit_quant_type="nf4", 100 | bnb_4bit_compute_dtype=torch.bfloat16 101 | ) 102 | ) 103 | model.resize_token_embeddings(32001) 104 | 105 | tokenizer = transformers.AutoTokenizer.from_pretrained( 106 | args.base_model, 107 | cache_dir=args.cache_dir, 108 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 109 | padding_side="right", 110 | use_fast=False, 111 | ) 112 | 113 | model.eval() 114 | if torch.__version__ >= "2" and sys.platform != "win32": 115 | model = torch.compile(model) 116 | respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p, 117 | max_gen_len=args.max_gen_len, use_cache=True) 118 | 119 | material = read_txt_file(args.material) 120 | prompt_no_input = PROMPT_DICT["prompt_llama2"] 121 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s"%args.question}) 122 | 123 | output = respond(prompt=prompt) 124 | 125 | if __name__ == "__main__": 126 | args = 
parse_config() 127 | main(args) 128 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import torch 5 | import argparse 6 | import textwrap 7 | import transformers 8 | from peft import PeftModel 9 | from transformers import GenerationConfig, TextStreamer 10 | from llama_attn_replace import replace_llama_attn 11 | 12 | PROMPT_DICT = { 13 | "prompt_no_input": ( 14 | "Below is an instruction that describes a task. " 15 | "Write a response that appropriately completes the request.\n\n" 16 | "### Instruction:\n{instruction}\n\n### Response:" 17 | ), 18 | "prompt_no_input_llama2": ( 19 | "[INST] <>\n" 20 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n" 21 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. 
If you don't know the answer to a question, please don't share false information.\n" 22 | "<> \n\n {instruction} [/INST]" 23 | ), 24 | "prompt_llama2": "[INST]{instruction}[/INST]" 25 | } 26 | 27 | def parse_config(): 28 | parser = argparse.ArgumentParser(description='arg parser') 29 | parser.add_argument('--material', type=str, default="") 30 | parser.add_argument('--question', type=str, default="") 31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 32 | parser.add_argument('--cache_dir', type=str, default="./cache") 33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 34 | parser.add_argument('--flash_attn', type=bool, default=False, help='') 35 | parser.add_argument('--temperature', type=float, default=0.6, help='') 36 | parser.add_argument('--top_p', type=float, default=0.9, help='') 37 | parser.add_argument('--max_gen_len', type=int, default=512, help='') 38 | args = parser.parse_args() 39 | return args 40 | 41 | def read_txt_file(material_txt): 42 | if not material_txt.split(".")[-1]=='txt': 43 | raise ValueError("Only support txt or pdf file.") 44 | content = "" 45 | with open(material_txt) as f: 46 | for line in f.readlines(): 47 | content += line 48 | return content 49 | 50 | def build_generator( 51 | model, tokenizer, temperature=0.6, top_p=0.9, max_gen_len=4096, use_cache=True 52 | ): 53 | def response(prompt): 54 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 55 | 56 | streamer = TextStreamer(tokenizer) 57 | 58 | output = model.generate( 59 | **inputs, 60 | max_new_tokens=max_gen_len, 61 | temperature=temperature, 62 | top_p=top_p, 63 | use_cache=use_cache, 64 | streamer=streamer, 65 | ) 66 | 67 | out = tokenizer.decode(output[0], skip_special_tokens=True) 68 | 69 | out = out.split(prompt.lstrip(""))[1].strip() 70 | return out 71 | 72 | return response 73 | 74 | def main(args): 75 | if args.flash_attn: 76 | 
replace_llama_attn(inference=True) 77 | 78 | # Set RoPE scaling factor 79 | config = transformers.AutoConfig.from_pretrained( 80 | args.base_model, 81 | cache_dir=args.cache_dir, 82 | ) 83 | 84 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 85 | if orig_ctx_len and args.context_size > orig_ctx_len: 86 | scaling_factor = float(math.ceil(args.context_size / orig_ctx_len)) 87 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 88 | 89 | # Load model and tokenizer 90 | model = transformers.AutoModelForCausalLM.from_pretrained( 91 | args.base_model, 92 | config=config, 93 | cache_dir=args.cache_dir, 94 | torch_dtype=torch.float16, 95 | device_map="auto", 96 | ) 97 | model.resize_token_embeddings(32001) 98 | 99 | tokenizer = transformers.AutoTokenizer.from_pretrained( 100 | args.base_model, 101 | cache_dir=args.cache_dir, 102 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 103 | padding_side="right", 104 | use_fast=False, 105 | ) 106 | 107 | if torch.__version__ >= "2" and sys.platform != "win32": 108 | model = torch.compile(model) 109 | model.eval() 110 | 111 | respond = build_generator(model, tokenizer, temperature=args.temperature, top_p=args.top_p, 112 | max_gen_len=args.max_gen_len, use_cache=True) 113 | 114 | material = read_txt_file(args.material) 115 | prompt_no_input = PROMPT_DICT["prompt_llama2"] 116 | prompt = prompt_no_input.format_map({"instruction": material + "\n%s"%args.question}) 117 | 118 | output = respond(prompt=prompt) 119 | 120 | if __name__ == "__main__": 121 | args = parse_config() 122 | main(args) 123 | -------------------------------------------------------------------------------- /merge_lora_weights_and_save_hf_model.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the 
License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import torch 17 | import argparse 18 | import transformers 19 | from peft import PeftModel 20 | from typing import Dict 21 | 22 | IGNORE_INDEX = -100 23 | DEFAULT_PAD_TOKEN = "[PAD]" 24 | DEFAULT_EOS_TOKEN = "" 25 | DEFAULT_BOS_TOKEN = "" 26 | DEFAULT_UNK_TOKEN = "" 27 | 28 | def parse_config(): 29 | parser = argparse.ArgumentParser(description='arg parser') 30 | parser.add_argument('--base_model', type=str, default="/data/pretrained-models/llama-7b-hf") 31 | parser.add_argument('--peft_model', type=str, default=None, help='') 32 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 33 | parser.add_argument('--save_path', type=str, default=None, help='') 34 | parser.add_argument('--cache_dir', type=str, default=None, help='./cache_dir') 35 | args = parser.parse_args() 36 | return args 37 | 38 | def smart_tokenizer_and_embedding_resize( 39 | special_tokens_dict: Dict, 40 | tokenizer: transformers.PreTrainedTokenizer, 41 | model: transformers.PreTrainedModel, 42 | ): 43 | """Resize tokenizer and embedding. 44 | 45 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 
46 | """ 47 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 48 | model.resize_token_embeddings(len(tokenizer)) 49 | 50 | if num_new_tokens > 0: 51 | input_embeddings = model.get_input_embeddings().weight.data 52 | output_embeddings = model.get_output_embeddings().weight.data 53 | 54 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 55 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 56 | 57 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 58 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 59 | 60 | def main(args): 61 | device = "cuda:0" 62 | torch.cuda.set_device(device) 63 | 64 | print("base model", args.base_model) 65 | print("peft model", args.peft_model) 66 | 67 | # Load model and tokenizer 68 | model = transformers.AutoModelForCausalLM.from_pretrained( 69 | args.base_model, 70 | cache_dir=args.cache_dir, 71 | torch_dtype=torch.float16, 72 | device_map="auto", 73 | ) 74 | 75 | tokenizer = transformers.AutoTokenizer.from_pretrained( 76 | args.base_model, 77 | cache_dir=args.cache_dir, 78 | model_max_length=args.context_size, 79 | padding_side="right", 80 | use_fast=False, 81 | ) 82 | special_tokens_dict = dict() 83 | if tokenizer.pad_token is None: 84 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 85 | if tokenizer.eos_token is None: 86 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 87 | if tokenizer.bos_token is None: 88 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 89 | if tokenizer.unk_token is None: 90 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 91 | 92 | smart_tokenizer_and_embedding_resize( 93 | special_tokens_dict=special_tokens_dict, 94 | tokenizer=tokenizer, 95 | model=model, 96 | ) 97 | 98 | trainable_params = os.path.join(args.peft_model, "trainable_params.bin") 99 | if os.path.isfile(trainable_params): 100 | model.load_state_dict(torch.load(trainable_params, map_location=model.device), 
strict=False) 101 | model = PeftModel.from_pretrained( 102 | model, 103 | args.peft_model, 104 | device_map="auto", 105 | torch_dtype=torch.float16, 106 | ) 107 | model = model.merge_and_unload() 108 | model.save_pretrained(args.save_path) 109 | tokenizer.save_pretrained(args.save_path) 110 | 111 | if __name__ == "__main__": 112 | args = parse_config() 113 | main(args) 114 | -------------------------------------------------------------------------------- /passkey_retrivial.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # Core code based on https://github.com/CStanKonrad/long_llama 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import os 17 | import math 18 | import torch 19 | import argparse 20 | import random 21 | import numpy as np 22 | from numpy import random 23 | from tqdm import tqdm 24 | import transformers 25 | from peft import PeftModel 26 | from llama_attn_replace import replace_llama_attn 27 | 28 | 29 | def parse_config(): 30 | parser = argparse.ArgumentParser(description='arg parser') 31 | parser.add_argument('--base_model', type=str, default="/data1/pretrained-models/llama-7b-hf") 32 | parser.add_argument('--cache_dir', type=str, default="./cache") 33 | parser.add_argument('--context_size', type=int, default=-1, help='context size during fine-tuning') 34 | parser.add_argument('--flash_attn', type=bool, default=True, help='whether to use flash attention 2') 35 | parser.add_argument('--max_tokens', type=int, default=32000, help='maximum token length for evaluation') 36 | parser.add_argument('--interval', type=int, default=1000, help='interval for evaluation') 37 | parser.add_argument('--num_tests', type=int, default=10, help='number of repeat testing for each length') 38 | 39 | args = parser.parse_args() 40 | return args 41 | 42 | 43 | def generate_prompt_landmark(n_garbage, seed): 44 | """Generates a text file and inserts an passkey at a random position.""" 45 | rnd_state = random.get_state() 46 | random.seed(seed) 47 | n_garbage_prefix = random.randint(0, n_garbage) 48 | n_garbage_suffix = n_garbage - n_garbage_prefix 49 | 50 | task_description = "There is an important info hidden inside a lot of irrelevant text. Find it and memorize them. I will quiz you about the important information there." 51 | garbage = "The grass is green. The sky is blue. The sun is yellow. Here we go. There and back again." 
52 | garbage_inf = " ".join([garbage] * 5000) 53 | assert len(garbage_inf) >= n_garbage 54 | garbage_prefix = garbage_inf[:n_garbage_prefix] 55 | garbage_suffix = garbage_inf[:n_garbage_suffix] 56 | pass_key = random.randint(1, 50000) 57 | information_line = f"The pass key is {pass_key}. Remember it. {pass_key} is the pass key." 58 | final_question = "What is the pass key? The pass key is" 59 | lines = [ 60 | task_description, 61 | garbage_prefix, 62 | information_line, 63 | garbage_suffix, 64 | final_question, 65 | ] 66 | random.set_state(rnd_state) 67 | return "\n".join(lines), str(pass_key) 68 | 69 | 70 | def passkey_retrieval_test(model, tokenizer, device, use_cache=False, n_garbage=60000, seed=666): 71 | prompt, answer = generate_prompt_landmark(n_garbage, seed) 72 | input_ids = tokenizer(prompt, return_tensors="pt").input_ids 73 | input_ids = input_ids.to(device) 74 | len_token = input_ids.shape[-1] 75 | 76 | answer_ids = tokenizer(answer, return_tensors="pt").input_ids[:, 1:] # drop BOS 77 | generation_output = model.generate( 78 | input_ids=input_ids, max_new_tokens=answer_ids.shape[-1], num_beams=1, use_cache=use_cache 79 | ) 80 | 81 | model_answer = generation_output[0, -answer_ids.shape[-1]:].cpu() 82 | 83 | is_correct = (model_answer == answer_ids[0]).all().item() 84 | #print(f"The correct answer is {tokenizer.decode(answer_ids[0].cpu())}") 85 | #print(f"The model answer is {tokenizer.decode(model_answer.cpu())}, is_correct : {is_correct}") 86 | return is_correct, len_token 87 | 88 | 89 | def main(args): 90 | device = "cuda:0" 91 | torch.cuda.set_device(device) 92 | 93 | print("base model", args.base_model) 94 | 95 | if args.flash_attn: 96 | replace_llama_attn(use_full=True) 97 | 98 | # Set RoPE scaling factor 99 | config = transformers.AutoConfig.from_pretrained( 100 | args.base_model, 101 | cache_dir=args.cache_dir, 102 | ) 103 | 104 | context_size = args.context_size 105 | orig_ctx_len = getattr(config, "max_position_embeddings", None) # this value 
should be 4096 for LLaMA2 models 106 | if orig_ctx_len and context_size > orig_ctx_len: 107 | scaling_factor = float(math.ceil(context_size / orig_ctx_len)) 108 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 109 | 110 | # Load model and tokenizer 111 | model = transformers.AutoModelForCausalLM.from_pretrained( 112 | args.base_model, 113 | config=config, 114 | cache_dir=args.cache_dir, 115 | torch_dtype=torch.float16, 116 | device_map="auto", 117 | ) 118 | model.resize_token_embeddings(32001) 119 | 120 | tokenizer = transformers.AutoTokenizer.from_pretrained( 121 | args.base_model, 122 | cache_dir=args.cache_dir, 123 | model_max_length=args.context_size if args.context_size > orig_ctx_len else orig_ctx_len, 124 | padding_side="right", 125 | use_fast=False, 126 | ) 127 | 128 | total_test_points = args.max_tokens // args.interval 129 | all_accuries = {} 130 | for i in range(total_test_points): 131 | # This is a rough ratio to control the number of texts and tokens 132 | n_garbage = int(3.75 * (i + 1) * args.interval // 1024 * 1024) 133 | passed_tests = 0 134 | total_tokens = 0 135 | for i in range(args.num_tests): 136 | is_correct, len_tokens = passkey_retrieval_test(model, tokenizer, device, use_cache=not args.flash_attn, n_garbage=n_garbage, seed=i) 137 | passed_tests += is_correct 138 | total_tokens += len_tokens 139 | avg_tokens = total_tokens//args.num_tests 140 | accuracy = float(passed_tests)/args.num_tests 141 | print("accuracy on the token length %d is %f"%(avg_tokens, accuracy)) 142 | all_accuries[str(avg_tokens)] = accuracy 143 | print("accuries over tokens", all_accuries) 144 | 145 | 146 | if __name__ == "__main__": 147 | args = parse_config() 148 | main(args) 149 | -------------------------------------------------------------------------------- /pdf2txt/README.md: -------------------------------------------------------------------------------- 1 | # Extract text from pdf by dit detection and ocr 2 | 3 | The script uses various 
libraries such as `pdf2image`, `easyocr`, `ditod` and `detectron2` for processing. 4 | 5 | Detected objects are categorized into "text", "title", "list", "table", and "figure". 6 | 7 | The script provides detailed timing information for various processing steps, which can be useful for performance analysis. 8 | 9 | Text extraction uses `easyocr` and the results are further processed using SymSpell for word segmentation and a regular expression for filtering. 10 | 11 | ### 1. Installation 12 | ``` 13 | pip install -r requirements.txt 14 | python -m pip install 'git+https://github.com/facebookresearch/detectron2.git' 15 | apt-get install poppler-utils 16 | ``` 17 | 18 | ### 2. Download OCR model 19 | - Please download the weight [trained_ocr_cascade_large.pth](https://drive.google.com/file/d/1DtHtR3hhj8Df_Lkgdm9P79Eljot5MR_i/view?usp=share_link) first. 20 | - Please set the weight path in `configs/cascade_dit_large.yaml`. 21 | 22 | ### 3. Basic usage 23 | ``` 24 | python pdf2txt.py --pdf_path path_to_pdf_file --outputs_dir path_to_output_dir 25 | ``` 26 | 27 | The output txt file will be stored in `path_to_output_dir/txt` 28 | -------------------------------------------------------------------------------- /pdf2txt/backbone.py: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------------------------------- 2 | # VIT: Multi-Path Vision Transformer for Dense Prediction 3 | # Copyright (c) 2022 Electronics and Telecommunications Research Institute (ETRI). 4 | # All Rights Reserved. 5 | # Written by Youngwan Lee 6 | # This source code is licensed(Dual License(GPL3.0 & Commercial)) under the license found in the 7 | # LICENSE file in the root directory of this source tree. 
8 | # -------------------------------------------------------------------------------- 9 | # References: 10 | # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 11 | # CoaT: https://github.com/mlpc-ucsd/CoaT 12 | # -------------------------------------------------------------------------------- 13 | 14 | 15 | import torch 16 | 17 | from detectron2.layers import ( 18 | ShapeSpec, 19 | ) 20 | from detectron2.modeling import Backbone, BACKBONE_REGISTRY, FPN 21 | from detectron2.modeling.backbone.fpn import LastLevelP6P7, LastLevelMaxPool 22 | 23 | from beit import beit_base_patch16, dit_base_patch16, dit_large_patch16, beit_large_patch16 24 | 25 | __all__ = [ 26 | "build_vit_fpn_backbone", 27 | ] 28 | 29 | 30 | class VIT_Backbone(Backbone): 31 | """ 32 | Implement VIT backbone. 33 | """ 34 | 35 | def __init__(self, name, out_features, drop_path, img_size, pos_type, model_kwargs): 36 | super().__init__() 37 | self._out_features = out_features 38 | if 'base' in name: 39 | self._out_feature_strides = {"layer3": 4, "layer5": 8, "layer7": 16, "layer11": 32} 40 | else: 41 | self._out_feature_strides = {"layer7": 4, "layer11": 8, "layer15": 16, "layer23": 32} 42 | 43 | if name == 'beit_base_patch16': 44 | model_func = beit_base_patch16 45 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768} 46 | elif name == 'dit_base_patch16': 47 | model_func = dit_base_patch16 48 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768} 49 | elif name == "deit_base_patch16": 50 | model_func = deit_base_patch16 51 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768} 52 | elif name == "mae_base_patch16": 53 | model_func = mae_base_patch16 54 | self._out_feature_channels = {"layer3": 768, "layer5": 768, "layer7": 768, "layer11": 768} 55 | elif name == "dit_large_patch16": 56 | model_func = dit_large_patch16 57 | self._out_feature_channels = {"layer7": 
1024, "layer11": 1024, "layer15": 1024, "layer23": 1024} 58 | elif name == "beit_large_patch16": 59 | model_func = beit_large_patch16 60 | self._out_feature_channels = {"layer7": 1024, "layer11": 1024, "layer15": 1024, "layer23": 1024} 61 | else: 62 | raise ValueError("Unsupported VIT name yet.") 63 | 64 | if 'beit' in name or 'dit' in name: 65 | if pos_type == "abs": 66 | self.backbone = model_func(img_size=img_size, 67 | out_features=out_features, 68 | drop_path_rate=drop_path, 69 | use_abs_pos_emb=True, 70 | **model_kwargs) 71 | elif pos_type == "shared_rel": 72 | self.backbone = model_func(img_size=img_size, 73 | out_features=out_features, 74 | drop_path_rate=drop_path, 75 | use_shared_rel_pos_bias=True, 76 | **model_kwargs) 77 | elif pos_type == "rel": 78 | self.backbone = model_func(img_size=img_size, 79 | out_features=out_features, 80 | drop_path_rate=drop_path, 81 | use_rel_pos_bias=True, 82 | **model_kwargs) 83 | else: 84 | raise ValueError() 85 | else: 86 | self.backbone = model_func(img_size=img_size, 87 | out_features=out_features, 88 | drop_path_rate=drop_path, 89 | **model_kwargs) 90 | 91 | def forward(self, x): 92 | """ 93 | Args: 94 | x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. 95 | 96 | Returns: 97 | dict[str->Tensor]: names and the corresponding features 98 | """ 99 | assert x.dim() == 4, f"VIT takes an input of shape (N, C, H, W). Got {x.shape} instead!" 100 | return self.backbone.forward_features(x) 101 | 102 | def output_shape(self): 103 | return { 104 | name: ShapeSpec( 105 | channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] 106 | ) 107 | for name in self._out_features 108 | } 109 | 110 | 111 | def build_VIT_backbone(cfg): 112 | """ 113 | Create a VIT instance from config. 114 | 115 | Args: 116 | cfg: a detectron2 CfgNode 117 | 118 | Returns: 119 | A VIT backbone instance. 
120 | """ 121 | # fmt: off 122 | name = cfg.MODEL.VIT.NAME 123 | out_features = cfg.MODEL.VIT.OUT_FEATURES 124 | drop_path = cfg.MODEL.VIT.DROP_PATH 125 | img_size = cfg.MODEL.VIT.IMG_SIZE 126 | pos_type = cfg.MODEL.VIT.POS_TYPE 127 | 128 | model_kwargs = eval(str(cfg.MODEL.VIT.MODEL_KWARGS).replace("`", "")) 129 | 130 | return VIT_Backbone(name, out_features, drop_path, img_size, pos_type, model_kwargs) 131 | 132 | 133 | @BACKBONE_REGISTRY.register() 134 | def build_vit_fpn_backbone(cfg, input_shape: ShapeSpec): 135 | """ 136 | Create a VIT w/ FPN backbone. 137 | 138 | Args: 139 | cfg: a detectron2 CfgNode 140 | 141 | Returns: 142 | backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. 143 | """ 144 | bottom_up = build_VIT_backbone(cfg) 145 | in_features = cfg.MODEL.FPN.IN_FEATURES 146 | out_channels = cfg.MODEL.FPN.OUT_CHANNELS 147 | backbone = FPN( 148 | bottom_up=bottom_up, 149 | in_features=in_features, 150 | out_channels=out_channels, 151 | norm=cfg.MODEL.FPN.NORM, 152 | top_block=LastLevelMaxPool(), 153 | fuse_type=cfg.MODEL.FPN.FUSE_TYPE, 154 | ) 155 | return backbone 156 | -------------------------------------------------------------------------------- /pdf2txt/config.py: -------------------------------------------------------------------------------- 1 | from detectron2.config import CfgNode as CN 2 | 3 | 4 | def add_vit_config(cfg): 5 | """ 6 | Add config for VIT. 7 | """ 8 | _C = cfg 9 | 10 | _C.MODEL.VIT = CN() 11 | 12 | # CoaT model name. 13 | _C.MODEL.VIT.NAME = "" 14 | 15 | # Output features from CoaT backbone. 16 | _C.MODEL.VIT.OUT_FEATURES = ["layer3", "layer5", "layer7", "layer11"] 17 | 18 | _C.MODEL.VIT.IMG_SIZE = [224, 224] 19 | 20 | _C.MODEL.VIT.POS_TYPE = "shared_rel" 21 | 22 | _C.MODEL.VIT.DROP_PATH = 0. 
23 | 24 | _C.MODEL.VIT.MODEL_KWARGS = "{}" 25 | 26 | _C.SOLVER.OPTIMIZER = "ADAMW" 27 | 28 | _C.SOLVER.BACKBONE_MULTIPLIER = 1.0 29 | 30 | _C.AUG = CN() 31 | 32 | _C.AUG.DETR = False 33 | -------------------------------------------------------------------------------- /pdf2txt/configs/Base-RCNN-FPN.yaml: -------------------------------------------------------------------------------- 1 | MODEL: 2 | MASK_ON: True 3 | META_ARCHITECTURE: "GeneralizedRCNN" 4 | PIXEL_MEAN: [123.675, 116.280, 103.530] 5 | PIXEL_STD: [58.395, 57.120, 57.375] 6 | BACKBONE: 7 | NAME: "build_vit_fpn_backbone" 8 | VIT: 9 | OUT_FEATURES: ["layer3", "layer5", "layer7", "layer11"] 10 | DROP_PATH: 0.1 11 | IMG_SIZE: [224,224] 12 | POS_TYPE: "abs" 13 | FPN: 14 | IN_FEATURES: ["layer3", "layer5", "layer7", "layer11"] 15 | ANCHOR_GENERATOR: 16 | SIZES: [[32], [64], [128], [256], [512]] # One size for each in feature map 17 | ASPECT_RATIOS: [[0.5, 1.0, 2.0]] # Three aspect ratios (same for all in feature maps) 18 | RPN: 19 | IN_FEATURES: ["p2", "p3", "p4", "p5", "p6"] 20 | PRE_NMS_TOPK_TRAIN: 2000 # Per FPN level 21 | PRE_NMS_TOPK_TEST: 1000 # Per FPN level 22 | # Detectron1 uses 2000 proposals per-batch, 23 | # (See "modeling/rpn/rpn_outputs.py" for details of this legacy issue) 24 | # which is approximately 1000 proposals per-image since the default batch size for FPN is 2. 
25 | POST_NMS_TOPK_TRAIN: 1000 26 | POST_NMS_TOPK_TEST: 1000 27 | ROI_HEADS: 28 | NAME: "StandardROIHeads" 29 | IN_FEATURES: ["p2", "p3", "p4", "p5"] 30 | NUM_CLASSES: 5 31 | ROI_BOX_HEAD: 32 | NAME: "FastRCNNConvFCHead" 33 | NUM_FC: 2 34 | POOLER_RESOLUTION: 7 35 | ROI_MASK_HEAD: 36 | NAME: "MaskRCNNConvUpsampleHead" 37 | NUM_CONV: 4 38 | POOLER_RESOLUTION: 14 39 | DATASETS: 40 | TRAIN: ("publaynet_train",) 41 | TEST: ("publaynet_val",) 42 | SOLVER: 43 | LR_SCHEDULER_NAME: "WarmupCosineLR" 44 | AMP: 45 | ENABLED: True 46 | OPTIMIZER: "ADAMW" 47 | BACKBONE_MULTIPLIER: 1.0 48 | CLIP_GRADIENTS: 49 | ENABLED: True 50 | CLIP_TYPE: "full_model" 51 | CLIP_VALUE: 1.0 52 | NORM_TYPE: 2.0 53 | WARMUP_FACTOR: 0.01 54 | BASE_LR: 0.0004 55 | WEIGHT_DECAY: 0.05 56 | IMS_PER_BATCH: 32 57 | INPUT: 58 | CROP: 59 | ENABLED: True 60 | TYPE: "absolute_range" 61 | SIZE: (384, 600) 62 | MIN_SIZE_TRAIN: (480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800) 63 | FORMAT: "RGB" 64 | DATALOADER: 65 | FILTER_EMPTY_ANNOTATIONS: False 66 | VERSION: 2 67 | AUG: 68 | DETR: True 69 | SEED: 42 -------------------------------------------------------------------------------- /pdf2txt/configs/cascade_dit_large.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "./Base-RCNN-FPN.yaml" 2 | MODEL: 3 | PIXEL_MEAN: [ 127.5, 127.5, 127.5 ] 4 | PIXEL_STD: [ 127.5, 127.5, 127.5 ] 5 | WEIGHTS: "./trained_ocr_cascade_large.pth" 6 | VIT: 7 | NAME: "dit_large_patch16" 8 | OUT_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ] 9 | DROP_PATH: 0.2 10 | FPN: 11 | IN_FEATURES: [ "layer7", "layer11", "layer15", "layer23" ] 12 | ROI_HEADS: 13 | NAME: CascadeROIHeads 14 | ROI_BOX_HEAD: 15 | CLS_AGNOSTIC_BBOX_REG: True 16 | RPN: 17 | POST_NMS_TOPK_TRAIN: 2000 18 | SOLVER: 19 | WARMUP_ITERS: 1000 20 | IMS_PER_BATCH: 16 21 | MAX_ITER: 60000 22 | CHECKPOINT_PERIOD: 2000 23 | BASE_LR: 0.0001 24 | STEPS: (40000, 53333) 25 | AMP: 26 | ENABLED: False 27 | TEST: 28 | 
EVAL_PERIOD: 2000 29 | -------------------------------------------------------------------------------- /pdf2txt/pdf2txt.py: -------------------------------------------------------------------------------- 1 | # Written by Shaozuo Yu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import argparse 17 | import pdf2image 18 | import easyocr 19 | import cv2 20 | from config import add_vit_config 21 | from backbone import build_vit_fpn_backbone 22 | import torch 23 | from detectron2.config import get_cfg 24 | from detectron2.utils.visualizer import ColorMode, Visualizer 25 | from detectron2.data import MetadataCatalog 26 | from detectron2.engine import DefaultPredictor 27 | from detectron2.layers import nms 28 | import pickle 29 | import numpy as np 30 | import shutil 31 | from tqdm import tqdm 32 | from PIL import Image 33 | import time 34 | 35 | from symspellpy.symspellpy import SymSpell 36 | import pkg_resources 37 | import re 38 | 39 | 40 | prefix_length = 7 41 | sym_spell = SymSpell(max_dictionary_edit_distance=0, prefix_length=prefix_length) 42 | dictionary_path = pkg_resources.resource_filename("symspellpy", "frequency_dictionary_en_82_765.txt") 43 | sym_spell.load_dictionary(dictionary_path, term_index=0, count_index=1) 44 | # filter 45 | regex = "[A-Za-z0-9=:/\*]*[=:+-][A-Za-z0-9=:/\*]" 46 | 47 | 48 | def detect_objects(image_path, predictor, cfg): 49 | # Step 5: run inference 50 | img = 
cv2.imread(image_path) 51 | 52 | md = MetadataCatalog.get(cfg.DATASETS.TEST[0]) 53 | md.set(thing_classes=["text", "title", "list", "table", "figure"]) 54 | 55 | start_time = time.time() 56 | 57 | detections = predictor(img)["instances"] 58 | 59 | end_time = time.time() 60 | 61 | print(f"detection model部分执行时间: {end_time - start_time} 秒") 62 | # get boxes and scores 63 | boxes = detections.pred_boxes.tensor 64 | scores = detections.scores 65 | 66 | # NMS 67 | keep = nms(boxes, scores, 0.1) 68 | detections = detections[keep] 69 | scores = detections.scores 70 | 71 | threshold = 0.8 # you can adjust this value 72 | keep2 = torch.nonzero(scores > threshold).squeeze(1) 73 | detections = detections[keep2] 74 | 75 | return detections 76 | 77 | def process_pdf(pdf_file, outputs_dir, config_file): 78 | 79 | results = {} 80 | 81 | tmp_dir = os.path.join(outputs_dir, 'tmp') 82 | txt_dir = os.path.join(outputs_dir, 'txt') 83 | os.makedirs(tmp_dir, exist_ok=True) 84 | os.makedirs(txt_dir, exist_ok=True) 85 | 86 | #load detection model 87 | cfg = get_cfg() 88 | add_vit_config(cfg) 89 | cfg.merge_from_file(config_file) 90 | device = "cuda" if torch.cuda.is_available() else "cpu" 91 | #device = "cpu" 92 | cfg.MODEL.DEVICE = device 93 | predictor = DefaultPredictor(cfg) 94 | reader = easyocr.Reader(['en'], gpu=True) 95 | 96 | book_name = os.path.splitext(pdf_file)[0] 97 | book_base_name = os.path.basename(pdf_file) 98 | txt_file_path = os.path.join(txt_dir, f"{book_base_name}.txt") 99 | if os.path.exists(txt_file_path): 100 | raise ValueError(f"Skipping {book_name} as it already exists in the output directory.") 101 | 102 | start_time = time.time() 103 | 104 | book_name = os.path.splitext(pdf_file)[0] 105 | images = pdf2image.convert_from_path(pdf_file) 106 | 107 | end_time = time.time() 108 | 109 | print(f"pdf2image time: {end_time - start_time} s") 110 | 111 | book_results = [] 112 | for page_num, image in tqdm(enumerate(images, start=1), desc=f"Processing {book_name}", 
leave=False): 113 | image_path = os.path.join(tmp_dir, f"{book_base_name}-{page_num}.png") 114 | image.save(image_path) 115 | 116 | start_time = time.time() 117 | 118 | detections = detect_objects(image_path, predictor, cfg) 119 | 120 | end_time = time.time() 121 | 122 | print(f"detection time: {end_time - start_time} s") 123 | 124 | boxes = detections.pred_boxes.tensor.tolist() 125 | labels = detections.pred_classes.tolist() 126 | 127 | # get boxes 128 | all_detections = [(bbox, label_id) for bbox, label_id in zip(boxes, labels)] 129 | 130 | # sort 131 | all_detections.sort(key=lambda x: (x[0][1], x[0][0])) 132 | 133 | start_time = time.time() 134 | 135 | label_counter = {"figure": 0, "table": 0, 'text': 0, 'list': 0, 'title': 0} 136 | for bbox, label_id in all_detections: 137 | #print("Number of classes:", len(MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes)) 138 | label = MetadataCatalog.get(cfg.DATASETS.TEST[0]).thing_classes[label_id] 139 | cropped_image_np = np.array(image.crop(bbox)) 140 | 141 | if label in ['text', 'list', 'title']: 142 | #reader = easyocr.Reader(['en'], cudnn_benchmark=True) 143 | ocr_result = reader.readtext(cropped_image_np, batch_size=10) 144 | extracted_text = ' '.join([item[1] for item in ocr_result]) 145 | 146 | # SymSpell for word segmentation 147 | suggestions = sym_spell.word_segmentation(extracted_text) 148 | segmented_text = suggestions.corrected_string 149 | 150 | # filter 151 | filtered_text = re.sub(regex, "", segmented_text) 152 | 153 | book_results.append(extracted_text) 154 | 155 | end_time = time.time() 156 | 157 | print(f"ocr time: {end_time - start_time} s") 158 | 159 | results[book_name] = book_results 160 | with open(os.path.join(txt_dir, f"{book_base_name}.txt"), 'w') as f: 161 | f.write('\n'.join(book_results)) 162 | 163 | # delete tmp dir 164 | shutil.rmtree(tmp_dir) 165 | 166 | 167 | if __name__ == '__main__': 168 | parser = argparse.ArgumentParser(description="PDF processing script") 169 | 
parser.add_argument("--pdf_path", help="Path to PDF file", type=str, required=True) 170 | parser.add_argument("--outputs_dir", help="Directory to save outputs", type=str, required=True) 171 | parser.add_argument("--config_file", default="configs/cascade_dit_large.yaml", metavar="FILE", help="path to config file") 172 | 173 | args = parser.parse_args() 174 | 175 | process_pdf(args.pdf_path, args.outputs_dir, args.config_file) 176 | -------------------------------------------------------------------------------- /pdf2txt/requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | timm==0.5.4 4 | Pillow 5 | blobfile 6 | mypy 7 | numpy 8 | pytest 9 | requests 10 | einops 11 | tensorboardX 12 | scipy 13 | opencv-python 14 | pdf2image 15 | easyocr 16 | argparse 17 | regex 18 | symspellpy -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy>=1.26.0 2 | rouge_score>=0.1.2 3 | fire>=0.5.0 4 | # openai 5 | transformers==4.34.0 6 | torch>=2.0.0 7 | sentencepiece>=0.1.99 8 | tokenizers>=0.14.0 9 | # wandb 10 | accelerate>=0.23.0 11 | datasets>=2.14.5 12 | deepspeed>=0.10.3 13 | peft>=0.5.0 14 | # partial 15 | # gradio 16 | einops>=0.7.0 17 | bitsandbytes==0.41.1 18 | scipy>=1.11.3 19 | protobuf>=4.24.4 20 | torchmetrics>=1.2.0 21 | -------------------------------------------------------------------------------- /run_streaming_llama_longalpaca.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | warnings.filterwarnings("ignore") 4 | 5 | import torch 6 | import argparse 7 | import json 8 | import os 9 | import time 10 | import re 11 | import sys 12 | 13 | from tqdm import tqdm 14 | from streaming_llm.utils import load, download_url, load_jsonl 15 | from streaming_llm.enable_streaming_llm import enable_streaming_llm 
def _decode_words(tokenizer, token_ids):
    """Decode ``token_ids`` (special tokens skipped) and split the stripped text on spaces."""
    return (
        tokenizer.decode(
            token_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
            spaces_between_special_tokens=False,
        )
        .strip()
        .split(" ")
    )


@torch.no_grad()
def greedy_generate(model, tokenizer, input_ids, past_key_values, max_gen_len):
    """Greedily decode up to ``max_gen_len`` tokens, streaming words to stdout.

    The prompt ``input_ids`` is run in one forward pass on top of
    ``past_key_values``; every further step feeds back the single argmax token.
    After each step the decoded text is re-split on spaces and only the words
    before the last one are printed, so a partially decoded word is not shown
    early. The remainder is printed once generation stops.

    Args:
        model: causal LM called with ``input_ids``, ``past_key_values`` and
            ``use_cache=True``; its output must expose ``logits`` and
            ``past_key_values``.
        tokenizer: provides ``decode`` and ``eos_token_id``.
        input_ids: prompt token ids. The ``.item()`` calls assume batch size 1.
        past_key_values: cache from earlier turns, or ``None``.
        max_gen_len: maximum number of tokens to generate.

    Returns:
        The updated ``past_key_values``. A final EOS token is never fed back,
        so it is not part of the returned cache.
    """
    outputs = model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        use_cache=True,
    )
    past_key_values = outputs.past_key_values
    pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
    generated_ids = [pred_token_idx.item()]
    pos = 0
    for _ in range(max_gen_len - 1):
        # Checked before each step so that an EOS as the very first predicted
        # token also ends generation.
        if generated_ids[-1] == tokenizer.eos_token_id:
            break
        outputs = model(
            input_ids=pred_token_idx,
            past_key_values=past_key_values,
            use_cache=True,
        )
        past_key_values = outputs.past_key_values
        pred_token_idx = outputs.logits[:, -1, :].argmax(dim=-1).unsqueeze(1)
        generated_ids.append(pred_token_idx.item())
        generated_text = _decode_words(tokenizer, generated_ids)

        now = len(generated_text) - 1
        if now > pos:
            print(" ".join(generated_text[pos:now]), end=" ", flush=True)
            pos = now

    # Decode once more after the loop: the loop body may never have run
    # (max_gen_len <= 1 or immediate EOS), which previously left
    # `generated_text` unbound here.
    generated_text = _decode_words(tokenizer, generated_ids)
    print(" ".join(generated_text[pos:]), flush=True)
    return past_key_values


@torch.no_grad()
def streaming_inference(model, tokenizer, prompts, kv_cache=None, max_gen_len=1000):
    """Answer ``prompts`` one after another, carrying the KV cache across turns.

    Each prompt is wrapped as ``"USER: ...\\n\\nASSISTANT: "``. When ``kv_cache``
    is given, it first evicts enough cached positions to fit the prompt plus
    ``max_gen_len`` generated tokens. Without it, the cache grows without bound.
    """
    past_key_values = None
    for prompt in prompts:
        prompt = "USER: " + prompt + "\n\nASSISTANT: "
        print("\n" + prompt, end="")
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        input_ids = input_ids.to(model.device)
        seq_len = input_ids.shape[1]
        if kv_cache is not None:
            space_needed = seq_len + max_gen_len
            past_key_values = kv_cache.evict_for_space(past_key_values, space_needed)

        past_key_values = greedy_generate(
            model, tokenizer, input_ids, past_key_values, max_gen_len=max_gen_len
        )


def _str2bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``argparse`` with ``type=bool`` treats every non-empty string, including
    ``"False"``, as ``True``, so booleans need an explicit parser.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def main(args):
    """Load the model and the ``instruction`` prompts from ``args.test_filepath``, then run them."""
    model_name_or_path = args.model_name_or_path
    model, tokenizer = load(model_name_or_path)
    print(f"Loading data from {args.test_filepath} ...")

    # Context manager so the file handle is closed right after parsing.
    with open(args.test_filepath) as f:
        list_data = json.load(f)
    prompts = [sample["instruction"] for sample in list_data]

    if args.enable_streaming:
        kv_cache = enable_streaming_llm(
            model, start_size=args.start_size, recent_size=args.recent_size, use_flash_attn=args.use_flash_attn
        )
    else:
        kv_cache = None

    streaming_inference(
        model,
        tokenizer,
        prompts,
        kv_cache,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, default="Yukang/LongAlpaca-7B"
    )
    parser.add_argument("--test_filepath", type=str, default="outputs_stream.json")
    parser.add_argument("--enable_streaming", action="store_true")
    parser.add_argument("--start_size", type=int, default=4)
    parser.add_argument("--recent_size", type=int, default=8192)
    # `type=bool` would turn "--use_flash_attn False" into True; parse explicitly.
    parser.add_argument("--use_flash_attn", type=_str2bool, default=True)
    args = parser.parse_args()

    main(args)

# --------------------------------------------------------------------------------
# /streaming_llm/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/streaming_llm/__init__.py
# --------------------------------------------------------------------------------
# /streaming_llm/enable_streaming_llm.py:
# --------------------------------------------------------------------------------
from streaming_llm.kv_cache import StartRecentKVCache


def enable_streaming_llm(model, start_size, recent_size, use_flash_attn=True):
    """Prepare ``model`` for attention-sink streaming and return its KV-cache policy.

    For llama, gpt_neox and falcon model types, the attention forwards are
    patched in place with the position-shift versions. For mpt, nothing is
    patched. ``use_flash_attn`` is only consulted for llama.

    Args:
        model: Hugging Face model; dispatch uses ``model.config.model_type``.
        start_size: number of leading cache positions always kept.
        recent_size: number of most recent cache positions kept.
        use_flash_attn: llama only; select the FlashAttention forward.

    Returns:
        A ``StartRecentKVCache`` configured with the model's key/value
        sequence dimensions.

    Raises:
        ValueError: for unsupported model types.
    """
    if "llama" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
        from streaming_llm.pos_shift.modify_llama import (
            enable_llama_pos_shift_attention,
        )

        enable_llama_pos_shift_attention(model, use_flash_attn)
    elif "mpt" in model.config.model_type:
        v_seq_dim = 2
        k_seq_dim = 3
    elif "gpt_neox" in model.config.model_type:
        k_seq_dim = v_seq_dim = 2
        from streaming_llm.pos_shift.modify_gpt_neox import (
            enable_gpt_neox_pos_shift_attention,
        )

        enable_gpt_neox_pos_shift_attention(model)
    elif "falcon" in model.config.model_type:
        v_seq_dim = 1
        k_seq_dim = 1
        from streaming_llm.pos_shift.modify_falcon import (
            enable_falcon_pos_shift_attention,
        )

        enable_falcon_pos_shift_attention(model)
    else:
        raise ValueError(f"got {model.config.model_type}")
    kv_cache = StartRecentKVCache(
        start_size=start_size,
        recent_size=recent_size,
        k_seq_dim=k_seq_dim,
        v_seq_dim=v_seq_dim,
    )
    return kv_cache

# --------------------------------------------------------------------------------
# /streaming_llm/kv_cache.py:
# --------------------------------------------------------------------------------
import torch


def slice2d(x, start, end):
    """Return ``x[:, :, start:end, ...]`` (slice along dimension 2)."""
    return x[:, :, start:end, ...]


def slice3d(x, start, end):
    """Return ``x[:, :, :, start:end, ...]`` (slice along dimension 3)."""
    return x[:, :, :, start:end, ...]


def slice1d(x, start, end):
    """Return ``x[:, start:end, ...]`` (slice along dimension 1)."""
    return x[:, start:end, ...]
DIM_TO_SLICE = {
    1: slice1d,
    2: slice2d,
    3: slice3d,
}


class StartRecentKVCache:
    """KV-cache policy that keeps the first ``start_size`` and the latest ``recent_size`` positions.

    The cache is the model's ``past_key_values``: an iterable of per-layer
    ``(key, value)`` tensors whose sequence axis is ``k_seq_dim`` /
    ``v_seq_dim``. Every eviction builds new tensors with ``torch.cat``. The
    methods return either the input unchanged, ``None``, or a new list of
    ``[key, value]`` lists.
    """

    def __init__(
        self,
        start_size=4,
        recent_size=512,
        k_seq_dim=2,
        v_seq_dim=2,
    ):
        print(f"StartRecentKVCache: {start_size}, {recent_size}")
        self.start_size = start_size
        self.recent_size = recent_size
        self.cache_size = start_size + recent_size
        self.k_seq_dim = k_seq_dim
        self.v_seq_dim = v_seq_dim
        self.k_slice = DIM_TO_SLICE[k_seq_dim]
        self.v_slice = DIM_TO_SLICE[v_seq_dim]

    def _keep_spans(self, past_key_values, spans):
        """Concatenate the ``[start, end)`` ``spans`` of every layer's key and value along the sequence axis."""
        return [
            [
                torch.cat([self.k_slice(k, s, e) for s, e in spans], dim=self.k_seq_dim),
                torch.cat([self.v_slice(v, s, e) for s, e in spans], dim=self.v_seq_dim),
            ]
            for k, v in past_key_values
        ]

    def __call__(self, past_key_values):
        """Trim the cache to at most ``cache_size`` positions (start + most recent)."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len <= self.cache_size:
            return past_key_values
        return self._keep_spans(
            past_key_values,
            [(0, self.start_size), (seq_len - self.recent_size, seq_len)],
        )

    def evict_for_space(self, past_key_values, num_coming):
        """Evict middle positions so that ``num_coming`` new ones still fit within ``cache_size``.

        Keeps the first ``start_size`` positions and the last
        ``recent_size - num_coming`` positions. If ``num_coming`` exceeds
        ``recent_size``, only the start positions remain.
        """
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        if seq_len + num_coming <= self.cache_size:
            return past_key_values
        return self._keep_spans(
            past_key_values,
            [(0, self.start_size), (seq_len - self.recent_size + num_coming, seq_len)],
        )

    def evict_range(self, past_key_values, start, end):
        """Drop cache positions ``[start, end)`` from every layer."""
        if past_key_values is None:
            return None
        seq_len = past_key_values[0][0].size(self.k_seq_dim)
        assert start <= end and end <= seq_len
        return self._keep_spans(past_key_values, [(0, start), (end, seq_len)])

# --------------------------------------------------------------------------------
# /streaming_llm/pos_shift/__init__.py:
# --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/dvlab-research/LongLoRA/d4eb344c5ccc9e91c0812a2b2aeea69070df8c33/streaming_llm/pos_shift/__init__.py
# --------------------------------------------------------------------------------
# /streaming_llm/pos_shift/modify_falcon.py:
# --------------------------------------------------------------------------------
import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint

import torch.nn.functional as F

from transformers.models.falcon.modeling_falcon import (
    FalconAttention,
    rotate_half,
)
import types

__all__ = ["enable_falcon_pos_shift_attention"]


def falcon_pos_shift_attention_forward(
    self,
    hidden_states: torch.Tensor,
    alibi: torch.Tensor,
    attention_mask: torch.Tensor,
    layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    head_mask: Optional[torch.Tensor] = None,
    use_cache: bool = False,
    output_attentions: bool = False,
    **kwargs,
):
    """Falcon attention forward with StreamingLLM-style position shifting.

    Keys are cached *before* ``maybe_rotary`` is applied. On every call the
    full cached key sequence is rotated with offset 0, i.e. by its position
    inside the cache. New queries are rotated with offset ``past_len``. This
    keeps relative positions consistent after middle cache entries are evicted.
    Extra keyword arguments passed by the decoder layer are accepted and ignored.

    Returns:
        ``(output, present)``. On the alibi branch with ``output_attentions``
        set, the attention probabilities are appended.
    """
    fused_qkv = self.query_key_value(
        hidden_states
    )  # [batch_size, seq_length, 3 x hidden_size]

    # 3 x [batch_size, seq_length, num_heads, head_dim]
    (query_layer, key_layer, value_layer) = self._split_heads(fused_qkv)

    batch_size, q_length, _, _ = query_layer.shape

    query_layer = query_layer.transpose(1, 2).reshape(
        batch_size * self.num_heads, q_length, self.head_dim
    )

    # dirty hack to fix the inconsistency between falcon-40b and falcon-7b
    num_kv = self.num_heads if self.num_heads == 128 else self.num_kv
    key_layer = key_layer.transpose(1, 2).reshape(
        batch_size * num_kv,
        q_length,
        self.head_dim,
    )
    value_layer = value_layer.transpose(1, 2).reshape(
        batch_size * num_kv, q_length, self.head_dim
    )

    past_len = 0
    if layer_past is not None:
        past_len = layer_past[0].shape[1]

    query_layer_copy = query_layer.clone()
    query_layer, _ = self.maybe_rotary(query_layer, query_layer_copy, past_len)
    if layer_past is not None:
        past_key, past_value = layer_past
        # concatenate along seq_length dimension:
        #  - key / value: [batch_size * num_kv, kv_length, head_dim]
        key_layer = torch.cat((past_key, key_layer), dim=1)
        value_layer = torch.cat((past_value, value_layer), dim=1)

    if use_cache is True:
        present = (key_layer, value_layer)
    else:
        present = None

    key_layer_copy = key_layer.clone()
    _, key_layer = self.maybe_rotary(key_layer_copy, key_layer, 0)

    _, kv_length, _ = key_layer.shape

    if alibi is None:
        query_layer_ = query_layer.reshape(
            batch_size, self.num_heads, -1, self.head_dim
        )
        key_layer_ = key_layer.reshape(batch_size, num_kv, -1, self.head_dim)
        value_layer_ = value_layer.reshape(batch_size, num_kv, -1, self.head_dim)

        if layer_past is None:
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=True
            )
        elif q_length == 1:
            # A single new token may attend to the whole cache.
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, None, 0.0, is_causal=False
            )
        else:
            # Several new tokens on top of a cache: new token i may see every
            # cached key plus new keys 0..i. This is a bottom-right aligned
            # causal mask (True = attend). `is_causal=False` would let the
            # chunk see its own future tokens.
            causal_mask = torch.ones(
                q_length, kv_length, dtype=torch.bool, device=query_layer_.device
            ).tril(diagonal=kv_length - q_length)
            attn_output = F.scaled_dot_product_attention(
                query_layer_, key_layer_, value_layer_, causal_mask, 0.0
            )

        x = attn_output.view(batch_size, self.num_heads, q_length, self.head_dim)
        x = x.permute(0, 2, 1, 3)
        attn_output = x.reshape(batch_size, q_length, self.num_heads * self.head_dim)

        output_tensor = self.dense(attn_output)

        outputs = (output_tensor, present)
        assert not output_attentions  # not supported.
        return outputs
    else:
        # NOTE(review): the mask is cast to bfloat16 regardless of the model
        # dtype; kept as-is (matches the upstream code this was copied from).
        attention_mask_float = (
            (attention_mask * 1.0).masked_fill(attention_mask, -1e9).to(torch.bfloat16)
        )
        matmul_result = query_layer @ key_layer.transpose(-1, -2)

        # change view to [batch_size, num_heads, q_length, kv_length]
        attention_scores = matmul_result.view(
            batch_size, self.num_heads, q_length, kv_length
        )

        # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype - [batch_size, num_heads, q_length, kv_length]
        input_dtype = attention_scores.dtype
        # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
        if input_dtype == torch.float16 or input_dtype == torch.bfloat16:
            attention_scores = attention_scores.to(torch.float32)
        attention_probs = F.softmax(
            (attention_scores + alibi.view(batch_size, self.num_heads, 1, -1))
            * self.inv_norm_factor
            + attention_mask_float,
            dim=-1,
            dtype=hidden_states.dtype,
        )
        # [batch_size, num_heads, q_length, kv_length]
        attention_probs = self.attention_dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        # change view [batch_size x num_heads, q_length, kv_length]
        attention_probs_reshaped = attention_probs.view(
            batch_size * self.num_heads, q_length, kv_length
        )

        # matmul: [batch_size * num_heads, q_length, head_dim]
        context_layer = attention_probs_reshaped @ value_layer

        # change view [batch_size, num_heads, q_length, head_dim]
        context_layer = self._merge_heads(context_layer)

        output_tensor = self.dense(context_layer)

        outputs = (output_tensor, present)
        if output_attentions:
            outputs += (attention_probs,)

        return outputs


def enable_falcon_pos_shift_attention(model):
    """Recursively replace ``forward`` on every submodule whose name ends in ``self_attention``."""
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            enable_falcon_pos_shift_attention(
                module,
            )

        if name.endswith("self_attention"):
            model._modules[name].forward = types.MethodType(
                falcon_pos_shift_attention_forward, model._modules[name]
            )

# --------------------------------------------------------------------------------
# /streaming_llm/pos_shift/modify_gpt_neox.py:
# --------------------------------------------------------------------------------
import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint

import torch.nn.functional as F

from transformers.models.gpt_neox.modeling_gpt_neox import (
    apply_rotary_pos_emb,
    rotate_half,
    GPTNeoXAttention,
)
import types

__all__ = ["enable_gpt_neox_pos_shift_attention"]


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Rotate ``x`` with the ``cos``/``sin`` rows gathered at ``position_ids``."""
    gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
    gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
    cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


def gpt_neox_pos_shift_attention_forward(
    self,
    hidden_states: torch.FloatTensor,
    attention_mask: torch.FloatTensor,
    position_ids: torch.LongTensor,
    head_mask: Optional[torch.FloatTensor] = None,
    layer_past: Optional[Tuple[torch.Tensor]] = None,
    use_cache: Optional[bool] = False,
    output_attentions: Optional[bool] = False,
    **kwargs,
):
    """GPT-NeoX attention forward with StreamingLLM-style position shifting.

    Queries are rotated at ``position_ids``. Keys are cached un-rotated and the
    full cached key sequence is rotated with positions ``0..seq_len-1`` on every
    call. Extra keyword arguments from the caller are accepted and ignored.
    """
    has_layer_past = layer_past is not None

    # Compute QKV
    # Attention heads [batch, seq_len, hidden_size]
    #   --> [batch, seq_len, (np * 3 * head_size)]
    qkv = self.query_key_value(hidden_states)

    # [batch, seq_len, (num_heads * 3 * head_size)]
    #   --> [batch, seq_len, num_heads, 3 * head_size]
    new_qkv_shape = qkv.size()[:-1] + (self.num_attention_heads, 3 * self.head_size)
    qkv = qkv.view(*new_qkv_shape)

    # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
    query = qkv[..., : self.head_size].permute(0, 2, 1, 3)
    key = qkv[..., self.head_size : 2 * self.head_size].permute(0, 2, 1, 3)
    value = qkv[..., 2 * self.head_size :].permute(0, 2, 1, 3)

    # Compute rotary embeddings on rotary_ndims
    query_rot = query[..., : self.rotary_ndims]
    query_pass = query[..., self.rotary_ndims :]

    # Compute token offset for rotary embeddings (when decoding)
    seq_len = key.shape[-2]
    if has_layer_past:
        seq_len += layer_past[0].shape[-2]
    cos, sin = self.rotary_emb(value, seq_len=seq_len)
    query = apply_rotary_pos_emb_single(query_rot, cos, sin, position_ids)
    query = torch.cat((query, query_pass), dim=-1)

    # Cache QKV values (keys are cached before rotation)
    if has_layer_past:
        past_key = layer_past[0]
        past_value = layer_past[1]
        key = torch.cat((past_key, key), dim=-2)
        value = torch.cat((past_value, value), dim=-2)

    present = (key, value) if use_cache else None

    # Keys are rotated by their position inside the cache.
    key_rot = key[..., : self.rotary_ndims]
    key_pass = key[..., self.rotary_ndims :]
    key_position_ids = torch.arange(seq_len, device=position_ids.device).unsqueeze(0)
    key = apply_rotary_pos_emb_single(key_rot, cos, sin, key_position_ids)
    key = torch.cat((key, key_pass), dim=-1)

    # Compute attention
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

    # Reshape outputs
    attn_output = self._merge_heads(
        attn_output, self.num_attention_heads, self.head_size
    )
    attn_output = self.dense(attn_output)

    outputs = (attn_output, present)
    if output_attentions:
        outputs += (attn_weights,)

    return outputs


def enable_gpt_neox_pos_shift_attention(model):
    """Recursively replace ``forward`` on every ``GPTNeoXAttention`` submodule."""
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            enable_gpt_neox_pos_shift_attention(
                module,
            )

        if isinstance(module, GPTNeoXAttention):
            module.forward = types.MethodType(
                gpt_neox_pos_shift_attention_forward, module
            )

# --------------------------------------------------------------------------------
# /streaming_llm/pos_shift/modify_llama.py:
# --------------------------------------------------------------------------------
import math
from typing import Optional, Tuple

import torch
from torch import nn
import torch.utils.checkpoint

import torch.nn.functional as F

from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    rotate_half,
    apply_rotary_pos_emb,
    repeat_kv,
)
import types
import transformers
from einops import rearrange
from flash_attn import __version__ as flash_attn_version
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func

__all__ = ["enable_llama_pos_shift_attention"]


def apply_rotary_pos_emb_single(x, cos, sin, position_ids):
    """Rotate ``x`` with the ``cos``/``sin`` rows selected by ``position_ids``."""
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    x_embed = (x * cos) + (rotate_half(x) * sin)
    return x_embed


def llama_pos_shift_attention_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Eager Llama attention with StreamingLLM-style position shifting.

    Queries are rotated at ``position_ids``. Keys are cached un-rotated and the
    full cached key sequence is rotated with positions ``0..kv_seq_len-1``.
    Extra keyword arguments passed by the decoder layer (for example
    ``padding_mask`` in some transformers releases) are accepted and ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value or None)``.
    """
    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (
            self.num_key_value_heads * self.head_dim
        ) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
        )
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)

    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    ### Shift Pos: query pos is min(cache_size, idx)
    query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
    ###

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    ### Shift Pos: key pos is the pos in cache
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
    ###

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
        self.head_dim
    )

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
        raise ValueError(
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.size()}"
        )

    if attention_mask is not None:
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
            raise ValueError(
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
            )
        attn_weights = attn_weights + attention_mask

    # upcast attention to fp32
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
        query_states.dtype
    )
    attn_output = torch.matmul(attn_weights, value_states)

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
        raise ValueError(
            f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
            f" {attn_output.size()}"
        )

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(
            self.hidden_size // self.config.pretraining_tp, dim=2
        )
        o_proj_slices = self.o_proj.weight.split(
            self.hidden_size // self.config.pretraining_tp, dim=1
        )
        attn_output = sum(
            [
                F.linear(attn_output[i], o_proj_slices[i])
                for i in range(self.config.pretraining_tp)
            ]
        )
    else:
        attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def llama_pos_shift_attention_forward_flashattn(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    """Llama position-shift attention using FlashAttention when no cache is given.

    The FlashAttention kernel runs only when the call has no incoming cache,
    so queries and keys have the same length. It assumes there are no padded
    positions (the mask is all True) and a fp16/bf16 model. With an incoming
    cache, the same eager computation as ``llama_pos_shift_attention_forward``
    is used. Extra keyword arguments from the decoder layer are ignored.

    Returns:
        ``(attn_output, attn_weights or None, past_key_value or None)``.
    """
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states)
    key_states = self.k_proj(hidden_states)
    value_states = self.v_proj(hidden_states)

    query_states = query_states.view(
        bsz, q_len, self.num_heads, self.head_dim
    ).transpose(1, 2)
    key_states = key_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)
    value_states = value_states.view(
        bsz, q_len, self.num_key_value_heads, self.head_dim
    ).transpose(1, 2)

    # Record whether a cache came in *before* `past_key_value` is rebound to
    # the updated cache below. Testing it afterwards made the flash branch
    # unreachable whenever use_cache=True.
    has_past = past_key_value is not None

    kv_seq_len = key_states.shape[-2]
    if has_past:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    ### Shift Pos: query pos is min(cache_size, idx)
    query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids)
    ###

    if has_past:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    ### Shift Pos: key pos is the pos in cache
    key_position_ids = torch.arange(kv_seq_len, device=position_ids.device).unsqueeze(0)
    key_states = apply_rotary_pos_emb_single(key_states, cos, sin, key_position_ids)
    ###

    # repeat k/v heads if n_kv_heads < n_heads
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if not has_past:
        qkv = torch.stack(
            [query_states, key_states, value_states], dim=2
        )  # [bsz, nh, 3, q_len, hd]
        qkv = qkv.transpose(1, 3)  # [bsz, q_len, 3, nh, hd]

        # All positions are treated as valid (no padding). Use hidden_states'
        # device so a None attention_mask does not crash.
        key_padding_mask = torch.full(
            (bsz, q_len), True, dtype=torch.bool, device=hidden_states.device
        )
        nheads = qkv.shape[-2]
        x = rearrange(qkv, "b s three h d -> b s (three h d)")
        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
        x_unpad = rearrange(
            x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads
        )
        output_unpad = flash_attn_varlen_qkvpacked_func(
            x_unpad, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
        )
        output = rearrange(
            pad_input(
                rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, bsz, q_len
            ),
            "b s (h d) -> b s h d",
            h=nheads,
        )
        output = output.reshape(bsz, q_len, self.num_heads, self.head_dim)

        attn_output = self.o_proj(rearrange(output, "b s h d -> b s (h d)"))
        attn_weights = None
    else:
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(
            self.head_dim
        )

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
            query_states.dtype
        )
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if self.config.pretraining_tp > 1:
            attn_output = attn_output.split(
                self.hidden_size // self.config.pretraining_tp, dim=2
            )
            o_proj_slices = self.o_proj.weight.split(
                self.hidden_size // self.config.pretraining_tp, dim=1
            )
            attn_output = sum(
                [
                    F.linear(attn_output[i], o_proj_slices[i])
                    for i in range(self.config.pretraining_tp)
                ]
            )
        else:
            attn_output = self.o_proj(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value


def enable_llama_pos_shift_attention(model, use_flash_attn=True):
    """Recursively patch every ``LlamaAttention`` submodule with a position-shift forward.

    Args:
        model: module tree to patch in place.
        use_flash_attn: choose ``llama_pos_shift_attention_forward_flashattn``
            (True) or the eager ``llama_pos_shift_attention_forward`` (False).
            The flag is forwarded to the recursive calls; it used to be dropped,
            so nested attention modules always got the flash variant.
    """
    forward = (
        llama_pos_shift_attention_forward_flashattn
        if use_flash_attn
        else llama_pos_shift_attention_forward
    )
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            enable_llama_pos_shift_attention(
                module,
                use_flash_attn,
            )

        if isinstance(module, LlamaAttention):
            model._modules[name].forward = types.MethodType(
                forward, model._modules[name]
            )

# --------------------------------------------------------------------------------
# /streaming_llm/utils.py:
# --------------------------------------------------------------------------------
import torch
import argparse
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
import os.path as osp
import ssl
import urllib.request
import os
import json


def parse_args():
    """Parse the command-line options used by the streaming evaluation scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path", type=str, default="models/llama/llama-7b"
    )
    parser.add_argument("--revision", type=str, default="main")
    parser.add_argument("--tokenizer_name_or_path", type=str, default=None)
    parser.add_argument("--dataset_name", type=str, default="wikitext")

    parser.add_argument("--task", type=str, default="wikitext-2-raw-v1")
    parser.add_argument(
        "--split", type=str, default="test", choices=["validation", "test"]
    )

    parser.add_argument(
        "--num_samples",
        type=int,
        default=1,
    )

    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs/debug",
    )

    parser.add_argument("--enable_start_recent_kv_cache", action="store_true")
    parser.add_argument("--start_size", type=int, default=1)
    parser.add_argument("--recent_size", type=int, default=255)
    parser.add_argument("--enable_pos_shift", action="store_true")

    parser.add_argument("--num_eval_tokens", type=int, default=None)

    args = parser.parse_args()
    return args


def load(model_name_or_path):
    """Load a tokenizer and an fp16 causal LM (``device_map="auto"``) in eval mode.

    If the tokenizer has no pad token, its pad id is set to the EOS id, or to 0
    when there is no EOS either.
    """
    print(f"Loading model from {model_name_or_path} ...")
    # NOTE: tensor parallelism is known to cause bugs when running Falcon.
    tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            tokenizer.pad_token_id = 0

    model.eval()

    return model, tokenizer


def download_url(url: str, folder="folder"):
    """
    Downloads the content of an url to a folder. Modified from \
    https://github.com/pyg-team/pytorch_geometric/tree/master/torch_geometric

    Args:
        url (string): The url of target file.
        folder (string): The target folder.

    Returns:
        string: File path of downloaded files.
    """

    file = url.rpartition("/")[2]
    # `startswith` instead of `file[0]` so an empty name (URL ending in "/")
    # does not raise IndexError.
    file = file if file.startswith("?") else file.split("?")[0]
    path = osp.join(folder, file)
    if osp.exists(path):
        print(f"File {file} exists, use existing file.")
        return path

    print(f"Downloading {url}")
    os.makedirs(folder, exist_ok=True)
    # NOTE(review): TLS certificate verification is disabled (inherited from
    # the PyG helper); only use this with trusted URLs.
    ctx = ssl._create_unverified_context()
    # Close the HTTP response as well as the output file.
    with urllib.request.urlopen(url, context=ctx) as data, open(path, "wb") as f:
        f.write(data.read())

    return path


def load_jsonl(
    file_path,
):
    """Read a JSON-lines file into a list with one parsed object per line."""
    list_data_dict = []
    with open(file_path, "r") as f:
        for line in f:
            list_data_dict.append(json.loads(line))
    return list_data_dict

# --------------------------------------------------------------------------------
# /supervised-fine-tune-qlora.py:
# --------------------------------------------------------------------------------
# Written by Yukang Chen
# Some code based on https://github.com/huggingface/peft/blob/main/examples/fp4_finetuning/finetune_fp4_opt_bnb_peft.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15 | 16 | import io 17 | import os 18 | import copy 19 | import json 20 | import math 21 | import logging 22 | from dataclasses import dataclass, field 23 | from typing import Dict, Optional, Sequence 24 | 25 | import torch 26 | import torch.nn as nn 27 | import transformers 28 | from torch.utils.data import Dataset 29 | from transformers import Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig 30 | from llama_attn_replace_sft import replace_llama_attn 31 | from gptneox_attn_replace import replace_gpt_neox_attn 32 | from peft import LoraConfig, get_peft_model 33 | from torch.distributed import barrier 34 | 35 | IGNORE_INDEX = -100 36 | DEFAULT_PAD_TOKEN = "[PAD]" 37 | DEFAULT_EOS_TOKEN = "" 38 | DEFAULT_BOS_TOKEN = "" 39 | DEFAULT_UNK_TOKEN = "" 40 | 41 | def _make_r_io_base(f, mode: str): 42 | if not isinstance(f, io.IOBase): 43 | f = open(f, mode=mode) 44 | return f 45 | 46 | def jload(f, mode="r"): 47 | """Load a .json file into a dictionary.""" 48 | f = _make_r_io_base(f, mode) 49 | jdict = json.load(f) 50 | f.close() 51 | return jdict 52 | 53 | PROMPT_DICT = { 54 | "prompt_input": ( 55 | "Below is an instruction that describes a task, paired with an input that provides further context. " 56 | "Write a response that appropriately completes the request.\n\n" 57 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 58 | ), 59 | "prompt_no_input": ( 60 | "Below is an instruction that describes a task. " 61 | "Write a response that appropriately completes the request.\n\n" 62 | "### Instruction:\n{instruction}\n\n### Response:" 63 | ), 64 | "prompt_no_input_llama2":( 65 | "[INST] <>\n" 66 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
Please ensure that your responses are socially unbiased and positive in nature.\n\n" 67 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 68 | "<> \n\n {instruction} [/INST]" 69 | ), 70 | "prompt_input_llama2": ( 71 | "[INST] <>\n" 72 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n" 73 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 74 | "<> \n\n {instruction} \n{input} [/INST]" 75 | ), 76 | "prompt_llama2": "[INST]{instruction}[/INST]" 77 | } 78 | 79 | 80 | @dataclass 81 | class ModelArguments: 82 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") 83 | model_type: Optional[str] = field(default="llama") 84 | 85 | 86 | @dataclass 87 | class DataArguments: 88 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 89 | 90 | 91 | @dataclass 92 | class TrainingArguments(transformers.TrainingArguments): 93 | cache_dir: Optional[str] = field(default=None) 94 | optim: str = field(default="adamw_torch") 95 | model_max_length: int = field( 96 | default=8192 * 4, 97 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 98 | ) 99 | use_flash_attn: bool = field( 100 | default=True, 101 | metadata={"help": "Whether use flash attention for training."}, 102 | ) 103 | use_full_attn: bool = field( 104 | default=False, 105 | metadata={"help": "Whether to use plain, full-attention for training."}, 106 | ) 107 | low_rank_training: bool = field( 108 | default=True, 109 | metadata={"help": "Whether use low rank adaptation for training."}, 110 | ) 111 | trainable_params: str = field( 112 | default="embed,norm", 113 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."}, 114 | ) 115 | 116 | def smart_tokenizer_and_embedding_resize( 117 | special_tokens_dict: Dict, 118 | tokenizer: transformers.PreTrainedTokenizer, 119 | model: transformers.PreTrainedModel, 120 | ): 121 | """Resize tokenizer and embedding. 122 | 123 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 124 | """ 125 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 126 | model.resize_token_embeddings(len(tokenizer)) 127 | 128 | if num_new_tokens > 0: 129 | input_embeddings = model.get_input_embeddings().weight.data 130 | output_embeddings = model.get_output_embeddings().weight.data 131 | 132 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 133 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 134 | 135 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 136 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 137 | 138 | 139 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 140 | """Tokenize a list of strings.""" 141 | tokenized_list = [ 142 | tokenizer( 143 | text, 144 | return_tensors="pt", 145 | padding="longest", 146 | max_length=tokenizer.model_max_length, 147 | truncation=True, 148 | ) 149 | for text in strings 150 | ] 151 
| input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 152 | input_ids_lens = labels_lens = [ 153 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 154 | ] 155 | return dict( 156 | input_ids=input_ids, 157 | labels=labels, 158 | input_ids_lens=input_ids_lens, 159 | labels_lens=labels_lens, 160 | ) 161 | 162 | 163 | def preprocess( 164 | sources: Sequence[str], 165 | targets: Sequence[str], 166 | tokenizer: transformers.PreTrainedTokenizer, 167 | ) -> Dict: 168 | """Preprocess the data by tokenizing.""" 169 | examples = [s + t for s, t in zip(sources, targets)] 170 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 171 | input_ids = examples_tokenized["input_ids"] 172 | labels = copy.deepcopy(input_ids) 173 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 174 | label[:source_len] = IGNORE_INDEX 175 | return dict(input_ids=input_ids, labels=labels) 176 | 177 | 178 | class SupervisedDataset(Dataset): 179 | """Dataset for supervised fine-tuning.""" 180 | 181 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 182 | super(SupervisedDataset, self).__init__() 183 | logging.warning("Loading data...") 184 | list_data_dict = jload(data_path) 185 | 186 | logging.warning("Formatting inputs...") 187 | 188 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_llama2"] 189 | sources = [ 190 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 191 | for example in list_data_dict 192 | ] 193 | 194 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 195 | 196 | logging.warning("Tokenizing inputs... 
This may take some time...") 197 | data_dict = preprocess(sources, targets, tokenizer) 198 | 199 | self.input_ids = data_dict["input_ids"] 200 | self.labels = data_dict["labels"] 201 | 202 | def __len__(self): 203 | return len(self.input_ids) 204 | 205 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 206 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 207 | 208 | 209 | @dataclass 210 | class DataCollatorForSupervisedDataset(object): 211 | """Collate examples for supervised fine-tuning.""" 212 | 213 | tokenizer: transformers.PreTrainedTokenizer 214 | 215 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 216 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 217 | input_ids = torch.nn.utils.rnn.pad_sequence( 218 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 219 | ) 220 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 221 | return dict( 222 | input_ids=input_ids, 223 | labels=labels, 224 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 225 | ) 226 | 227 | 228 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 229 | """Make dataset and collator for supervised fine-tuning.""" 230 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 231 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 232 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 233 | 234 | 235 | def train(): 236 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 237 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 238 | 239 | # NOTE: May expand supported model types in the future 240 | if model_args.model_type == "gpt-neox": 241 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn) 
242 | else: 243 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn) 244 | 245 | # Set RoPE scaling factor 246 | config = transformers.AutoConfig.from_pretrained( 247 | model_args.model_name_or_path, 248 | cache_dir=training_args.cache_dir, 249 | ) 250 | 251 | orig_rope_scaling = getattr(config, "rope_scaling", None) 252 | if orig_rope_scaling is None: 253 | orig_rope_scaling = {"factor": 1} 254 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1 255 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 256 | if orig_ctx_len: 257 | orig_ctx_len *= orig_rope_scaling_factor 258 | if training_args.model_max_length > orig_ctx_len: 259 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 260 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 261 | 262 | # Load model and tokenizer 263 | model = transformers.AutoModelForCausalLM.from_pretrained( 264 | model_args.model_name_or_path, 265 | config=config, 266 | cache_dir=training_args.cache_dir, 267 | torch_dtype=torch.bfloat16, 268 | quantization_config=BitsAndBytesConfig( 269 | load_in_4bit=True, 270 | llm_int8_threshold=6.0, 271 | llm_int8_has_fp16_weight=False, 272 | bnb_4bit_compute_dtype=torch.bfloat16, 273 | bnb_4bit_use_double_quant=True, 274 | bnb_4bit_quant_type="nf4", 275 | ), 276 | ) 277 | 278 | for param in model.parameters(): 279 | param.requires_grad = False # freeze the model - train adapters later 280 | if param.ndim == 1: 281 | # cast the small parameters (e.g. 
layernorm) to fp32 for stability 282 | param.data = param.data.to(torch.float32) 283 | 284 | tokenizer = transformers.AutoTokenizer.from_pretrained( 285 | model_args.model_name_or_path, 286 | cache_dir=training_args.cache_dir, 287 | model_max_length=training_args.model_max_length, 288 | padding_side="right", 289 | use_fast=True, 290 | ) 291 | 292 | special_tokens_dict = dict() 293 | if tokenizer.pad_token is None: 294 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 295 | if tokenizer.eos_token is None: 296 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 297 | if tokenizer.bos_token is None: 298 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 299 | if tokenizer.unk_token is None: 300 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 301 | 302 | smart_tokenizer_and_embedding_resize( 303 | special_tokens_dict=special_tokens_dict, 304 | tokenizer=tokenizer, 305 | model=model, 306 | ) 307 | 308 | data_module = make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 309 | 310 | if training_args.low_rank_training: 311 | if model_args.model_type == "gpt-neox": 312 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' 313 | targets = ["query_key_value", "dense"] 314 | else: 315 | targets=["q_proj", "k_proj", "v_proj", "o_proj"] 316 | 317 | config = LoraConfig( 318 | r=8, 319 | lora_alpha=16, 320 | target_modules=targets, 321 | lora_dropout=0, 322 | bias="none", 323 | task_type="CAUSAL_LM", 324 | ) 325 | model = get_peft_model(model, config) 326 | # enable trainable params 327 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 328 | 329 | class CastOutputToFloat(nn.Sequential): 330 | def forward(self, x): 331 | return super().forward(x).to(torch.float32) 332 | 333 | model.lm_head = CastOutputToFloat(model.lm_head) 334 | 335 | # Verifying the datatypes. 
336 | dtypes = {} 337 | for _, p in model.named_parameters(): 338 | dtype = p.dtype 339 | if dtype not in dtypes: 340 | dtypes[dtype] = 0 341 | dtypes[dtype] += p.numel() 342 | total = 0 343 | for k, v in dtypes.items(): 344 | total += v 345 | for k, v in dtypes.items(): 346 | print(k, v, v / total) 347 | 348 | model.config.use_cache = False # required for gradient checkpointing 349 | model.enable_input_require_grads() # required for gradient checkpointing 350 | model.gradient_checkpointing_enable() # enable gradient checkpointing 351 | 352 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 353 | trainer.train() 354 | trainer.save_state() 355 | trainer.save_model(output_dir=training_args.output_dir) 356 | 357 | 358 | if __name__ == "__main__": 359 | train() 360 | -------------------------------------------------------------------------------- /supervised-fine-tune.py: -------------------------------------------------------------------------------- 1 | # Written by Yukang Chen 2 | # Some code based on https://github.com/epfml/landmark-attention 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import io 17 | import os 18 | import copy 19 | import json 20 | import math 21 | import logging 22 | from dataclasses import dataclass, field 23 | from typing import Dict, Optional, Sequence 24 | 25 | import torch 26 | import transformers 27 | from torch.utils.data import Dataset 28 | from transformers import Trainer, DataCollatorForLanguageModeling 29 | from llama_attn_replace_sft import replace_llama_attn 30 | from gptneox_attn_replace import replace_gpt_neox_attn 31 | from peft import LoraConfig, get_peft_model 32 | from torch.distributed import barrier 33 | 34 | IGNORE_INDEX = -100 35 | DEFAULT_PAD_TOKEN = "[PAD]" 36 | DEFAULT_EOS_TOKEN = "" 37 | DEFAULT_BOS_TOKEN = "" 38 | DEFAULT_UNK_TOKEN = "" 39 | 40 | def _make_r_io_base(f, mode: str): 41 | if not isinstance(f, io.IOBase): 42 | f = open(f, mode=mode) 43 | return f 44 | 45 | def jload(f, mode="r"): 46 | """Load a .json file into a dictionary.""" 47 | f = _make_r_io_base(f, mode) 48 | jdict = json.load(f) 49 | f.close() 50 | return jdict 51 | 52 | PROMPT_DICT = { 53 | "prompt_input": ( 54 | "Below is an instruction that describes a task, paired with an input that provides further context. " 55 | "Write a response that appropriately completes the request.\n\n" 56 | "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:" 57 | ), 58 | "prompt_no_input": ( 59 | "Below is an instruction that describes a task. " 60 | "Write a response that appropriately completes the request.\n\n" 61 | "### Instruction:\n{instruction}\n\n### Response:" 62 | ), 63 | "prompt_no_input_llama2":( 64 | "[INST] <>\n" 65 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. 
Please ensure that your responses are socially unbiased and positive in nature.\n\n" 66 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 67 | "<> \n\n {instruction} [/INST]" 68 | ), 69 | "prompt_input_llama2": ( 70 | "[INST] <>\n" 71 | "You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\n" 72 | "If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n" 73 | "<> \n\n {instruction} \n{input} [/INST]" 74 | ), 75 | "prompt_llama2": "[INST]{instruction}[/INST]" 76 | } 77 | 78 | 79 | @dataclass 80 | class ModelArguments: 81 | model_name_or_path: Optional[str] = field(default="EleutherAI/pythia-1.4b-deduped") 82 | model_type: Optional[str] = field(default="llama") 83 | 84 | 85 | @dataclass 86 | class DataArguments: 87 | data_path: str = field(default=None, metadata={"help": "Path to the training data."}) 88 | 89 | 90 | @dataclass 91 | class TrainingArguments(transformers.TrainingArguments): 92 | cache_dir: Optional[str] = field(default=None) 93 | optim: str = field(default="adamw_torch") 94 | model_max_length: int = field( 95 | default=8192 * 4, 96 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 97 | ) 98 | use_flash_attn: bool = field( 99 | default=True, 100 | metadata={"help": "Whether use flash attention for training."}, 101 | ) 102 | use_full_attn: bool = field( 103 | default=False, 104 | metadata={"help": "Whether to use plain, full-attention for training."}, 105 | ) 106 | low_rank_training: bool = field( 107 | default=True, 108 | metadata={"help": "Whether use low rank adaptation for training."}, 109 | ) 110 | trainable_params: str = field( 111 | default="embed,norm", 112 | metadata={"help": "Additional trainable parameters except LoRA weights, if low rank training."}, 113 | ) 114 | 115 | def smart_tokenizer_and_embedding_resize( 116 | special_tokens_dict: Dict, 117 | tokenizer: transformers.PreTrainedTokenizer, 118 | model: transformers.PreTrainedModel, 119 | ): 120 | """Resize tokenizer and embedding. 121 | 122 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 123 | """ 124 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 125 | model.resize_token_embeddings(len(tokenizer)) 126 | 127 | if num_new_tokens > 0: 128 | input_embeddings = model.get_input_embeddings().weight.data 129 | output_embeddings = model.get_output_embeddings().weight.data 130 | 131 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 132 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 133 | 134 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 135 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 136 | 137 | 138 | def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict: 139 | """Tokenize a list of strings.""" 140 | tokenized_list = [ 141 | tokenizer( 142 | text, 143 | return_tensors="pt", 144 | padding="longest", 145 | max_length=tokenizer.model_max_length, 146 | truncation=True, 147 | ) 148 | for text in strings 149 | ] 150 
| input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list] 151 | input_ids_lens = labels_lens = [ 152 | tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list 153 | ] 154 | return dict( 155 | input_ids=input_ids, 156 | labels=labels, 157 | input_ids_lens=input_ids_lens, 158 | labels_lens=labels_lens, 159 | ) 160 | 161 | 162 | def preprocess( 163 | sources: Sequence[str], 164 | targets: Sequence[str], 165 | tokenizer: transformers.PreTrainedTokenizer, 166 | ) -> Dict: 167 | """Preprocess the data by tokenizing.""" 168 | examples = [s + t for s, t in zip(sources, targets)] 169 | examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)] 170 | input_ids = examples_tokenized["input_ids"] 171 | labels = copy.deepcopy(input_ids) 172 | for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]): 173 | label[:source_len] = IGNORE_INDEX 174 | return dict(input_ids=input_ids, labels=labels) 175 | 176 | 177 | class SupervisedDataset(Dataset): 178 | """Dataset for supervised fine-tuning.""" 179 | 180 | def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer): 181 | super(SupervisedDataset, self).__init__() 182 | logging.warning("Loading data...") 183 | list_data_dict = jload(data_path) 184 | 185 | logging.warning("Formatting inputs...") 186 | 187 | prompt_input, prompt_no_input = PROMPT_DICT["prompt_input_llama2"], PROMPT_DICT["prompt_llama2"] 188 | sources = [ 189 | prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example) 190 | for example in list_data_dict 191 | ] 192 | 193 | targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict] 194 | 195 | logging.warning("Tokenizing inputs... 
This may take some time...") 196 | data_dict = preprocess(sources, targets, tokenizer) 197 | 198 | self.input_ids = data_dict["input_ids"] 199 | self.labels = data_dict["labels"] 200 | 201 | def __len__(self): 202 | return len(self.input_ids) 203 | 204 | def __getitem__(self, i) -> Dict[str, torch.Tensor]: 205 | return dict(input_ids=self.input_ids[i], labels=self.labels[i]) 206 | 207 | 208 | @dataclass 209 | class DataCollatorForSupervisedDataset(object): 210 | """Collate examples for supervised fine-tuning.""" 211 | 212 | tokenizer: transformers.PreTrainedTokenizer 213 | 214 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 215 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 216 | input_ids = torch.nn.utils.rnn.pad_sequence( 217 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 218 | ) 219 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 220 | return dict( 221 | input_ids=input_ids, 222 | labels=labels, 223 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 224 | ) 225 | 226 | 227 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict: 228 | """Make dataset and collator for supervised fine-tuning.""" 229 | train_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=data_args.data_path) 230 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 231 | return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator) 232 | 233 | 234 | def train(): 235 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 236 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 237 | 238 | # NOTE: May expand supported model types in the future 239 | if model_args.model_type == "gpt-neox": 240 | replace_gpt_neox_attn(training_args.use_flash_attn, training_args.use_full_attn) 
241 | else: 242 | replace_llama_attn(training_args.use_flash_attn, training_args.use_full_attn) 243 | 244 | # Set RoPE scaling factor 245 | config = transformers.AutoConfig.from_pretrained( 246 | model_args.model_name_or_path, 247 | cache_dir=training_args.cache_dir, 248 | ) 249 | 250 | orig_rope_scaling = getattr(config, "rope_scaling", None) 251 | if orig_rope_scaling is None: 252 | orig_rope_scaling = {"factor": 1} 253 | orig_rope_scaling_factor = orig_rope_scaling["factor"] if "factor" in orig_rope_scaling.keys() else 1 254 | orig_ctx_len = getattr(config, "max_position_embeddings", None) 255 | if orig_ctx_len: 256 | orig_ctx_len *= orig_rope_scaling_factor 257 | if training_args.model_max_length > orig_ctx_len: 258 | scaling_factor = float(math.ceil(training_args.model_max_length / orig_ctx_len)) 259 | config.rope_scaling = {"type": "linear", "factor": scaling_factor} 260 | 261 | # Load model and tokenizer 262 | model = transformers.AutoModelForCausalLM.from_pretrained( 263 | model_args.model_name_or_path, 264 | config=config, 265 | cache_dir=training_args.cache_dir, 266 | torch_dtype=torch.bfloat16, 267 | ) 268 | 269 | tokenizer = transformers.AutoTokenizer.from_pretrained( 270 | model_args.model_name_or_path, 271 | cache_dir=training_args.cache_dir, 272 | model_max_length=training_args.model_max_length, 273 | padding_side="right", 274 | use_fast=True, 275 | ) 276 | 277 | special_tokens_dict = dict() 278 | if tokenizer.pad_token is None: 279 | special_tokens_dict["pad_token"] = DEFAULT_PAD_TOKEN 280 | if tokenizer.eos_token is None: 281 | special_tokens_dict["eos_token"] = DEFAULT_EOS_TOKEN 282 | if tokenizer.bos_token is None: 283 | special_tokens_dict["bos_token"] = DEFAULT_BOS_TOKEN 284 | if tokenizer.unk_token is None: 285 | special_tokens_dict["unk_token"] = DEFAULT_UNK_TOKEN 286 | 287 | smart_tokenizer_and_embedding_resize( 288 | special_tokens_dict=special_tokens_dict, 289 | tokenizer=tokenizer, 290 | model=model, 291 | ) 292 | 293 | data_module = 
make_supervised_data_module(tokenizer=tokenizer, data_args=data_args) 294 | 295 | if training_args.low_rank_training: 296 | if model_args.model_type == "gpt-neox": 297 | # added `dense` to match with llama as the basic LoRA would only target 'query_key_value' 298 | targets = ["query_key_value", "dense"] 299 | else: 300 | targets=["q_proj", "k_proj", "v_proj", "o_proj"] 301 | 302 | config = LoraConfig( 303 | r=8, 304 | lora_alpha=16, 305 | target_modules=targets, 306 | lora_dropout=0, 307 | bias="none", 308 | task_type="CAUSAL_LM", 309 | ) 310 | model = get_peft_model(model, config) 311 | # enable trainable params 312 | [p.requires_grad_() for n, p in model.named_parameters() if any([k in n for k in training_args.trainable_params.split(",")])] 313 | 314 | model.config.use_cache = False # required for gradient checkpointing 315 | model.enable_input_require_grads() # required for gradient checkpointing 316 | model.gradient_checkpointing_enable() # enable gradient checkpointing 317 | 318 | trainer = Trainer(model=model, tokenizer=tokenizer, args=training_args, **data_module) 319 | trainer.train() 320 | trainer.save_state() 321 | trainer.save_model(output_dir=training_args.output_dir) 322 | 323 | 324 | if __name__ == "__main__": 325 | train() 326 | --------------------------------------------------------------------------------