├── LICENSE
├── README.md
├── figure
    └── subzero.png
├── large_models
    ├── lora.py
    ├── metrics.py
    ├── modeling_llama.py
    ├── modeling_mistral
    │   ├── __init__.py
    │   ├── __pycache__
    │   │   ├── __init__.cpython-310.pyc
    │   │   ├── configuration_mistral.cpython-310.pyc
    │   │   └── modeling_mistral.cpython-310.pyc
    │   ├── configuration_mistral.py
    │   └── modeling_mistral.py
    ├── modeling_opt.py
    ├── prefix_tuning.py
    ├── prompt_tuning.py
    ├── run.py
    ├── tasks.py
    ├── templates.py
    ├── trainer.py
    └── utils.py
└── requirements.txt

/LICENSE:
--------------------------------------------------------------------------------
GNU GENERAL PUBLIC LICENSE
Version 3, 29 June 2007

Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.

Preamble

The GNU General Public License is a free, copyleft license for
software and other kinds of works.

The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
the GNU General Public License is intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users. We, the Free Software Foundation, use the
GNU General Public License for most of our software; it applies also to
any other work released this way by its authors. You can apply it to
your programs, too.

When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.

To protect your rights, we need to prevent others from denying you
these rights or asking you to surrender the rights. Therefore, you have
certain responsibilities if you distribute copies of the software, or if
you modify it: responsibilities to respect the freedom of others.

For example, if you distribute copies of such a program, whether
gratis or for a fee, you must pass on to the recipients the same
freedoms that you received. You must make sure that they, too, receive
or can get the source code. And you must show them these terms so they
know their rights.

Developers that use the GNU GPL protect your rights with two steps:
(1) assert copyright on the software, and (2) offer you this License
giving you legal permission to copy, distribute and/or modify it.

For the developers' and authors' protection, the GPL clearly explains
that there is no warranty for this free software. For both users' and
authors' sake, the GPL requires that modified versions be marked as
changed, so that their problems will not be attributed erroneously to
authors of previous versions.

Some devices are designed to deny users access to install or run
modified versions of the software inside them, although the manufacturer
can do so. This is fundamentally incompatible with the aim of
protecting users' freedom to change the software. The systematic
pattern of such abuse occurs in the area of products for individuals to
use, which is precisely where it is most unacceptable. Therefore, we
have designed this version of the GPL to prohibit the practice for those
products. If such problems arise substantially in other domains, we
stand ready to extend this provision to those domains in future versions
of the GPL, as needed to protect the freedom of users.

Finally, every program is threatened constantly by software patents.
States should not allow patents to restrict development and use of
software on general-purpose computers, but in those that do, we wish to
avoid the special danger that patents applied to a free program could
make it effectively proprietary. To prevent this, the GPL assures that
patents cannot be used to render the program non-free.

The precise terms and conditions for copying, distribution and
modification follow.

TERMS AND CONDITIONS

0. Definitions.

"This License" refers to version 3 of the GNU General Public License.

"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.

"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.

To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.

A "covered work" means either the unmodified Program or a work based
on the Program.

To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.

To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.

An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.

1. Source Code.

The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.

A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.

The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.

The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.

The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.

The Corresponding Source for a work in source code form is that
same work.

2. Basic Permissions.

All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.

You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.

Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.

3. Protecting Users' Legal Rights From Anti-Circumvention Law.

No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.

When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.

4. Conveying Verbatim Copies.

You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.

You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.

5. Conveying Modified Source Versions.

You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:

a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.

b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".

c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.

d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.

A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.

6. Conveying Non-Source Forms.

You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:

a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.

b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.

c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.

d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.

e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.

A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.

A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.

"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.

If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).

The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.

Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.

7. Additional Terms.

"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.

When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.

Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:

a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or

b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or

c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or

d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or

e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or

f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.

All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.

If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.

Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.

8. Termination.

You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).

However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.

Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.

Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.

9. Acceptance Not Required for Having Copies.

You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.

10. Automatic Licensing of Downstream Recipients.

Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.

An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.

You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.

11. Patents.

A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".

A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.

Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.

In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.

If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.

If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.

A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.

Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.

12. No Surrender of Others' Freedom.

If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.

13. Use with the GNU Affero General Public License.

Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU Affero General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the special requirements of the GNU Affero General Public License,
section 13, concerning interaction through a network will apply to the
combination as such.

14. Revised Versions of this License.

The Free Software Foundation may publish revised and/or new versions of
the GNU General Public License from time to time. Such new versions will
be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.

Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU General Public License, you may choose any version ever published
by the Free Software Foundation.
578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 
613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | 635 | Copyright (C) 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | Copyright (C) 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | . 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | . 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Source code for paper "Zeroth-Order Fine-Tuning of LLMs in Random Subspaces" 2 | 3 | This is the implementation for the paper [Zeroth-Order Fine-Tuning of LLMs in Random Subspaces](http://arxiv.org/abs/2410.08989). 4 | 5 | In this paper, we propose the random Subspace Zeroth-order (SubZero) optimization to address the challenges posed by LLMs’ high dimensionality. We introduce a low-rank perturbation tailored for LLMs that significantly reduces memory consumption while improving training performance. 
Additionally, we apply SubZero to four popular fine-tuning schemes for LLMs: full-parameter tuning, LoRA, prefix tuning, and prompt tuning. This demonstrates SubZero's compatibility and versatility across different tuning approaches. 6 | 7 | Furthermore, we prove that our gradient estimate closely approximates the backpropagation gradient, exhibits lower variance than traditional ZO estimates, and ensures convergence when combined with SGD. Experimental results show that SubZero improves fine-tuning performance and converges faster than standard ZO approaches such as [MeZO](https://github.com/princeton-nlp/MeZO) across various language modeling tasks. 8 | 9 | 10 |
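The low-rank perturbation idea can be illustrated with a small NumPy sketch. This is our own illustration, not the repository's `subzero_sgd` trainer.

The sketch makes the following assumptions:

- For a weight matrix W of shape m×n, the perturbation is Z = U Z_r Vᵀ.
- U and V are random column-orthonormal bases of rank r.
- Z_r is a fresh r×r Gaussian matrix drawn at every step.
- `--gauss_rank` plays the role of r, and `--update_interval` is the period at which U and V are resampled.

The helper names `orthonormal_basis` and `subzero_grad_estimate` are hypothetical.

```python
import numpy as np


def orthonormal_basis(rng, dim, rank):
    """Hypothetical helper: random dim x rank matrix with orthonormal columns (QR of a Gaussian)."""
    q, _ = np.linalg.qr(rng.standard_normal((dim, rank)))
    return q


def subzero_grad_estimate(loss_fn, W, U, V, eps, rng):
    """Two-sided ZO estimate using a low-rank perturbation Z = U @ Z_r @ V.T (assumed form).

    Only the small r x r core Z_r is sampled here; U and V are passed in so they
    can be reused across many steps before being refreshed.
    """
    r = U.shape[1]
    Z = U @ rng.standard_normal((r, r)) @ V.T  # m x n perturbation of rank <= r
    # Central finite difference: approximates the directional derivative <grad, Z>.
    proj_grad = (loss_fn(W + eps * Z) - loss_fn(W - eps * Z)) / (2 * eps)
    return proj_grad * Z


# Toy check with L(W) = 0.5 * ||W - target||_F^2, whose true gradient is W - target.
rng = np.random.default_rng(0)
m, n, r = 40, 30, 4
target = rng.standard_normal((m, n))
W = np.zeros((m, n))


def loss(X):
    return 0.5 * np.sum((X - target) ** 2)


U = orthonormal_basis(rng, m, r)
V = orthonormal_basis(rng, n, r)
samples = [subzero_grad_estimate(loss, W, U, V, 1e-3, rng) for _ in range(2000)]
est = np.mean(samples, axis=0)

# In expectation the estimate equals U U^T grad V V^T: the gradient projected onto the subspace.
true_grad = W - target
projected = U @ U.T @ true_grad @ V @ V.T
rel_err = np.linalg.norm(est - projected) / np.linalg.norm(projected)
print(f"relative error vs. projected gradient: {rel_err:.3f}")
```

Every Z lies in the span of U (columns) and V (rows). As a result, the averaged estimate converges to the projected gradient U Uᵀ ∇L V Vᵀ rather than to the full gradient. Meanwhile, each step samples only r² random numbers instead of m·n. For this quadratic toy loss, the central difference is exact, so the only error left is Monte Carlo noise.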

11 | Fig 12 | 13 | Visualization of cosine similarity, relative variance, training loss and GPU memory cost on OPT-1.3B under the prompt tuning scheme. SubZero demonstrates reduced angle error and variance in gradient estimation, while also accelerating convergence with minimal additional memory overhead. 14 | 15 |

16 | 17 | ## Getting started 18 | - We use Python 3.10, PyTorch 2.1.0, Transformers 4.28.1, and CUDA 11.8.0. 19 | - Install the dependencies with `pip install -r requirements.txt`. 20 | 21 | ## Usage 22 | 23 | Use `run.py` for all functions (zero-shot/ICL/fine-tuning/MeZO/SubZero): 24 | ```bash 25 | python run.py {ARGUMENTS} 26 | ``` 27 | 28 | Please read `run.py` for the complete list of arguments. The most important ones are described below. 29 | * `--num_train`: Number of training examples. For ICL, this is the number of demonstrations. 30 | * `--num_dev`: Number of validation examples. 31 | * `--num_eval`: Number of test examples. 32 | * `--model_name`: Hugging Face model name or path. 33 | * `--task_name`: Task name. 34 | * `--trainer`: `none` (zero-shot/ICL), `regular` (fine-tuning), `zo_sgd` (MeZO), or `subzero_sgd` (SubZero). 35 | * `--train_as_classification`: turn this on for classification tasks (cross-entropy over the likelihood of each class's label words); otherwise, LM-style teacher forcing is used. 36 | * `--zo_eps`: ZO hyperparameter epsilon (the perturbation scale). 37 | * `--prefix_tuning`: use prefix tuning. 38 | * `--lora`: use LoRA. 39 | * `--prompt_tuning`: use prompt tuning. 40 | 41 | ## Reproducing Results 42 | 43 | We provide an example of the OPT-1.3b model performing prompt tuning on the SST-2 dataset.
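The three commands that follow share most of their flags. A hypothetical helper like the one sketched here (not part of the repository) factors out the shared flags and makes the per-method differences explicit. This is convenient when scripting seed or learning-rate sweeps.

The flag values are copied from the README commands. The README passes `--eval_steps` twice (1000, then 500); the sketch keeps only 500. This assumes that `run.py`'s argparse-style parser keeps the last value given.

```python
import shlex

# Flags shared by the MeZO, SubZero and FO-SGD examples (values copied from the README).
# Assumption: only the last --eval_steps (500) takes effect, so it is kept once here.
COMMON_FLAGS = {
    "task_name": "SST2",
    "model_name": "facebook/opt-1.3b",
    "num_train_epochs": 5,
    "per_device_train_batch_size": 16,
    "evaluation_strategy": "steps",
    "save_strategy": "steps",
    "save_total_limit": 1,
    "eval_steps": 500,
    "save_steps": 500,
    "max_steps": 20000,
    "logging_steps": 10,
    "num_eval": 1000,
    "num_train": 1000,
    "num_dev": 500,
    "perturbation_mode": "two_side",
    "train_set_seed": 0,
    "lr_scheduler_type": "constant",
    "num_virtual_tokens": 10,
    "learning_rate": "1e-3",
    "zo_eps": "1e-2",
    "weight_decay": 0,
}
# Boolean switches that take no value.
SWITCHES = ["load_best_model_at_end", "train_as_classification",
            "prompt_tuning", "prompt_init_by_real_tokens"]

# What differs between the three README commands.
METHOD_FLAGS = {
    "mezo": {"trainer": "zo_sgd"},
    "subzero": {"trainer": "subzero_sgd", "gauss_rank": 24, "update_interval": 1000},
    "sgd": {"trainer": "sgd", "optimizer": "sgd"},
}


def build_command(method: str) -> list[str]:
    """Return the argv list for `python run.py ...` reproducing one README example."""
    flags = {**COMMON_FLAGS, **METHOD_FLAGS[method],
             "output_dir": f"result/opt1.3b-SST2-prompt-{method}"}
    argv = ["python", "run.py"]
    argv += [f"--{key}={value}" for key, value in flags.items()]
    argv += [f"--{name}" for name in SWITCHES]
    return argv


if __name__ == "__main__":
    for method in METHOD_FLAGS:
        print("CUDA_VISIBLE_DEVICES=0 " + shlex.join(build_command(method)))
```

Each printed line can be pasted into a shell. Alternatively, the list returned by `build_command` can be passed directly to `subprocess.run`, with `CUDA_VISIBLE_DEVICES` set in its `env` argument.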
44 | 45 | ### MeZO-SGD 46 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-mezo --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=zo_sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 --prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0` 47 | 48 | ### SubZero-SGD 49 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-subzero --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=subzero_sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 --prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0 --gauss_rank=24 --update_interval=1000` 50 | 51 | ### FO-SGD 52 | `CUDA_VISIBLE_DEVICES=0 python run.py --task_name=SST2 --model_name=facebook/opt-1.3b --output_dir=result/opt1.3b-SST2-prompt-sgd --num_train_epochs=5 --per_device_train_batch_size=16 --load_best_model_at_end --evaluation_strategy=steps --save_strategy=steps --save_total_limit=1 --eval_steps=1000 --max_steps=20000 --logging_steps=10 --num_eval=1000 --num_train=1000 --num_dev=500 --train_as_classification --perturbation_mode=two_side --trainer=sgd --optimizer=sgd --train_set_seed=0 --lr_scheduler_type=constant --eval_steps=500 --save_steps=500 
--prompt_tuning --num_virtual_tokens=10 --prompt_init_by_real_tokens --learning_rate=1e-3 --zo_eps=1e-2 --weight_decay=0` 53 | 54 | ## Acknowledgment 55 | 56 | This project is built upon the foundation laid by [MeZO: Fine-Tuning Language Models with Just Forward Passes](https://github.com/princeton-nlp/MeZO) and [Revisiting Zeroth-Order Optimization for Memory-Efficient LLM Fine-Tuning: A Benchmark](https://github.com/ZO-Bench/ZO-LLM/tree/main). The original code from these projects is licensed under the [MIT License](https://github.com/princeton-nlp/MeZO/blob/main/LICENSE) and [License](https://github.com/ZO-Bench/ZO-LLM/blob/main/LICENSE), respectively. We would like to thank the authors for their great work and contributions. 57 | -------------------------------------------------------------------------------- /figure/subzero.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/figure/subzero.png -------------------------------------------------------------------------------- /large_models/lora.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import functional as F 10 | import math 11 | 12 | 13 | def find_module(root_module: nn.Module, key: str): 14 | """ 15 | Find a module with a specific name in a Transformer model 16 | From OpenDelta https://github.com/thunlp/OpenDelta 17 | """ 18 | sub_keys = key.split(".") 19 | parent_module = root_module 20 | for sub_key in sub_keys[:-1]: 21 | parent_module = getattr(parent_module, sub_key) 22 | module = getattr(parent_module, sub_keys[-1]) 23 | return parent_module, sub_keys[-1], module 24 | 25 | 26 | class
LoRALinear(nn.Linear): 27 | """ 28 | LoRA implemented in a dense layer 29 | From https://github.com/microsoft/LoRA/blob/main/loralib/layers.py 30 | """ 31 | 32 | def __init__( 33 | self, 34 | in_features: int, 35 | out_features: int, 36 | r: int = 0, 37 | lora_alpha: int = 1, 38 | lora_dropout: float = 0., 39 | fan_in_fan_out: bool = False, 40 | # Set this to True if the layer to replace stores weight like (fan_in, fan_out) 41 | merge_weights: bool = False, 42 | # Not sure if this will affect saving/loading models so just set it to be False 43 | **kwargs 44 | ): 45 | nn.Linear.__init__(self, in_features, out_features, **kwargs) 46 | 47 | self.r = r 48 | self.lora_alpha = lora_alpha 49 | # Optional dropout 50 | if lora_dropout > 0.: 51 | self.lora_dropout = nn.Dropout(p=lora_dropout) 52 | else: 53 | self.lora_dropout = lambda x: x 54 | # Mark the weight as unmerged 55 | self.merged = False 56 | self.merge_weights = merge_weights 57 | self.fan_in_fan_out = fan_in_fan_out 58 | # Actual trainable parameters 59 | if r > 0: 60 | self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) 61 | self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) 62 | self.scaling = self.lora_alpha / self.r 63 | # Freezing the pre-trained weight matrix 64 | self.weight.requires_grad = False 65 | self.reset_parameters() 66 | if fan_in_fan_out: 67 | self.weight.data = self.weight.data.transpose(0, 1) 68 | 69 | def reset_parameters(self): 70 | nn.Linear.reset_parameters(self) 71 | if hasattr(self, 'lora_A'): 72 | # initialize A the same way as the default for nn.Linear and B to zero 73 | nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) 74 | nn.init.zeros_(self.lora_B) 75 | 76 | def train(self, mode: bool = True): 77 | def T(w): 78 | return w.transpose(0, 1) if self.fan_in_fan_out else w 79 | 80 | nn.Linear.train(self, mode) 81 | if mode: 82 | if self.merge_weights and self.merged: 83 | # Make sure that the weights are not merged 84 | if self.r > 0: 85 | 
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling 86 | self.merged = False 87 | else: 88 | if self.merge_weights and not self.merged: 89 | # Merge the weights and mark it 90 | if self.r > 0: 91 | self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling 92 | self.merged = True 93 | 94 | def forward(self, x: torch.Tensor): 95 | def T(w): 96 | return w.transpose(0, 1) if self.fan_in_fan_out else w 97 | 98 | if self.r > 0 and not self.merged: 99 | result = F.linear(x, T(self.weight), bias=self.bias) 100 | if self.r > 0: 101 | result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 102 | 1)) * self.scaling 103 | return result 104 | else: 105 | return F.linear(x, T(self.weight), bias=self.bias) 106 | 107 | 108 | class LoRA: 109 | 110 | def __init__(self, model, r, alpha, float16): 111 | """ 112 | Input: 113 | r, alpha: LoRA hyperparameters 114 | float16: Whether the model parameters are float16 or not 115 | """ 116 | 117 | self.model = model 118 | self.hidden_dim = model.config.hidden_size 119 | self.float16 = float16 120 | 121 | if model.config.model_type == "opt": 122 | attention_name = "attn" 123 | elif model.config.model_type == "roberta": 124 | attention_name = "attention" 125 | elif model.config.model_type in ["llama", "mistral"]: 126 | attention_name = "self_attn" 127 | else: 128 | raise NotImplementedError 129 | 130 | # Insert LoRA 131 | for key, _ in model.named_modules(): 132 | if key[-len(attention_name):] == attention_name: 133 | logger.info(f"Inject lora to: {key}") 134 | _, _, attn = find_module(model, key) 135 | 136 | if model.config.model_type == "opt": 137 | original_q_weight = attn.q_proj.weight.data 138 | original_q_bias = attn.q_proj.bias.data 139 | original_v_weight = attn.v_proj.weight.data 140 | original_v_bias = attn.v_proj.bias.data 141 | attn.q_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha, 142 | 
bias=model.config.enable_bias).to(original_q_weight.device) 143 | attn.v_proj = LoRALinear(model.config.hidden_size, model.config.hidden_size, r=r, lora_alpha=alpha, 144 | bias=model.config.enable_bias).to(original_v_weight.device) 145 | if float16: 146 | attn.q_proj.half() 147 | attn.v_proj.half() 148 | attn.q_proj.weight.data = original_q_weight 149 | attn.q_proj.bias.data = original_q_bias 150 | attn.v_proj.weight.data = original_v_weight 151 | attn.v_proj.bias.data = original_v_bias 152 | elif model.config.model_type == "llama": 153 | # in early versions of transformers, llama attention bias is hard coded to False 154 | attention_bias = False if not hasattr(model.config, "attention_bias") else model.config.attention_bias 155 | original_q_weight = attn.q_proj.weight.data 156 | original_v_weight = attn.v_proj.weight.data 157 | original_q_bias = attn.q_proj.bias.data if attention_bias else None 158 | original_v_bias = attn.v_proj.bias.data if attention_bias else None 159 | attn.q_proj = LoRALinear( 160 | model.config.hidden_size, 161 | model.config.hidden_size, 162 | r=r, lora_alpha=alpha, bias=attention_bias 163 | ).to(original_q_weight.device) 164 | attn.v_proj = LoRALinear( 165 | model.config.hidden_size, 166 | model.config.hidden_size, 167 | r=r, lora_alpha=alpha, bias=attention_bias 168 | ).to(original_v_weight.device) 169 | if float16: 170 | attn.q_proj.half() 171 | attn.v_proj.half() 172 | attn.q_proj.weight.data = original_q_weight 173 | attn.v_proj.weight.data = original_v_weight 174 | if attention_bias: 175 | attn.q_proj.bias.data = original_q_bias 176 | attn.v_proj.bias.data = original_v_bias 177 | elif model.config.model_type == "mistral": 178 | # Mistral's q_proj/v_proj have no bias; v_proj outputs num_key_value_heads * head_dim features (grouped-query attention) 179 | config = model.config 180 | original_q_weight = attn.q_proj.weight.data 181 | original_v_weight = attn.v_proj.weight.data 182 | head_dim = config.hidden_size // config.num_attention_heads 183 | attn.q_proj = LoRALinear( 184 |
config.hidden_size, 185 | config.hidden_size, 186 | r=r, lora_alpha=alpha, bias=False 187 | ).to(original_q_weight.device) 188 | attn.v_proj = LoRALinear( 189 | config.hidden_size, 190 | config.num_key_value_heads * head_dim, 191 | r=r, lora_alpha=alpha, bias=False 192 | ).to(original_v_weight.device) 193 | if float16: 194 | attn.q_proj.half() 195 | attn.v_proj.half() 196 | attn.q_proj.weight.data = original_q_weight 197 | attn.v_proj.weight.data = original_v_weight 198 | else: 199 | raise NotImplementedError 200 | 201 | # Freeze non-LoRA parameters 202 | for n, p in model.named_parameters(): 203 | if "lora" not in n: 204 | p.requires_grad = False 205 | -------------------------------------------------------------------------------- /large_models/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | import re 4 | import string 5 | from collections import Counter 6 | 7 | def normalize_answer(s): 8 | """Lower text and remove punctuation, articles and extra whitespace.""" 9 | 10 | def remove_articles(text): 11 | return re.sub(r'\b(a|an|the)\b', ' ', text) 12 | 13 | def white_space_fix(text): 14 | return ' '.join(text.split()) 15 | 16 | def remove_punc(text): 17 | exclude = set(string.punctuation) 18 | return ''.join(ch for ch in text if ch not in exclude) 19 | 20 | def lower(text): 21 | return text.lower() 22 | 23 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 24 | 25 | 26 | def calculate_metric(predictions, metric_name): 27 | if metric_name == "accuracy": 28 | if isinstance(predictions[0].correct_candidate, list): 29 | return np.mean([pred.predicted_candidate in pred.correct_candidate for pred in predictions]) 30 | else: 31 | return np.mean([pred.correct_candidate == pred.predicted_candidate for pred in predictions]) 32 | elif metric_name == "em": 33 | # For question answering 34 | return np.mean([any([normalize_answer(ans) == normalize_answer(pred.predicted_candidate) for ans in
pred.correct_candidate]) for pred in predictions]) 35 | elif metric_name == "f1": 36 | # For question answering 37 | f1 = [] 38 | for pred in predictions: 39 | all_f1s = [] 40 | if pred.correct_candidate[0] == "CANNOTANSWER" or pred.correct_candidate[0] == "no answer": 41 | f1.append(int(normalize_answer(pred.correct_candidate[0]) == normalize_answer(pred.predicted_candidate))) 42 | else: 43 | for ans in pred.correct_candidate: 44 | prediction_tokens = normalize_answer(pred.predicted_candidate).split() 45 | ground_truth_tokens = normalize_answer(ans).split() 46 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 47 | num_same = sum(common.values()) 48 | if num_same == 0: 49 | all_f1s.append(0) 50 | else: 51 | precision = 1.0 * num_same / len(prediction_tokens) 52 | recall = 1.0 * num_same / len(ground_truth_tokens) 53 | all_f1s.append((2 * precision * recall) / (precision + recall)) 54 | f1.append(max(all_f1s)) 55 | 56 | return np.mean(f1) 57 | 58 | 59 | def f1(pred, gold): 60 | """ 61 | This separate F1 function is used as non-differentiable metric for SQuAD 62 | """ 63 | if gold[0] == "CANNOTANSWER" or gold[0] == "no answer": 64 | return int(normalize_answer(gold[0]) == normalize_answer(pred)) 65 | else: 66 | all_f1s = [] 67 | for ans in gold: 68 | prediction_tokens = normalize_answer(pred).split() 69 | ground_truth_tokens = normalize_answer(ans).split() 70 | common = Counter(prediction_tokens) & Counter(ground_truth_tokens) 71 | num_same = sum(common.values()) 72 | if num_same == 0: 73 | all_f1s.append(0) 74 | else: 75 | precision = 1.0 * num_same / len(prediction_tokens) 76 | recall = 1.0 * num_same / len(ground_truth_tokens) 77 | all_f1s.append((2 * precision * recall) / (precision + recall)) 78 | return np.max(all_f1s) -------------------------------------------------------------------------------- /large_models/modeling_mistral/__init__.py: -------------------------------------------------------------------------------- 1 | from 
.configuration_mistral import MistralConfig, MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP 2 | from .modeling_mistral import ( 3 | MistralModel, 4 | MistralForCausalLM, 5 | MistralForSequenceClassification, 6 | MistralPreTrainedModel, 7 | MistralForCausalLMWithHeadTuning 8 | ) 9 | -------------------------------------------------------------------------------- /large_models/modeling_mistral/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /large_models/modeling_mistral/__pycache__/configuration_mistral.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/configuration_mistral.cpython-310.pyc -------------------------------------------------------------------------------- /large_models/modeling_mistral/__pycache__/modeling_mistral.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zimingyy/SubZero/cf6effdd0aef2b1cea79b2a121dc2ad4e7037414/large_models/modeling_mistral/__pycache__/modeling_mistral.cpython-310.pyc -------------------------------------------------------------------------------- /large_models/modeling_mistral/configuration_mistral.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ Mistral model configuration""" 16 | 17 | from transformers.configuration_utils import PretrainedConfig 18 | from transformers.utils import logging 19 | 20 | 21 | logger = logging.get_logger(__name__) 22 | 23 | MISTRAL_PRETRAINED_CONFIG_ARCHIVE_MAP = { 24 | "mistralai/Mistral-7B-v0.1": "https://huggingface.co/mistralai/Mistral-7B-v0.1/resolve/main/config.json", 25 | "mistralai/Mistral-7B-Instruct-v0.1": "https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1/resolve/main/config.json", 26 | } 27 | 28 | 29 | class MistralConfig(PretrainedConfig): 30 | r""" 31 | This is the configuration class to store the configuration of a [`MistralModel`]. It is used to instantiate an 32 | Mistral model according to the specified arguments, defining the model architecture. Instantiating a configuration 33 | with the defaults will yield a similar configuration to that of the Mistral-7B-v0.1 or Mistral-7B-Instruct-v0.1. 34 | 35 | [mistralai/Mistral-7B-v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) 36 | [mistralai/Mistral-7B-Instruct-v0.1](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 37 | 38 | Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the 39 | documentation from [`PretrainedConfig`] for more information. 40 | 41 | 42 | Args: 43 | vocab_size (`int`, *optional*, defaults to 32000): 44 | Vocabulary size of the Mistral model. 
Defines the number of different tokens that can be represented by the 45 | `input_ids` passed when calling [`MistralModel`] 46 | hidden_size (`int`, *optional*, defaults to 4096): 47 | Dimension of the hidden representations. 48 | intermediate_size (`int`, *optional*, defaults to 14336): 49 | Dimension of the MLP representations. 50 | num_hidden_layers (`int`, *optional*, defaults to 32): 51 | Number of hidden layers in the Transformer decoder. 52 | num_attention_heads (`int`, *optional*, defaults to 32): 53 | Number of attention heads for each attention layer in the Transformer decoder. 54 | num_key_value_heads (`int`, *optional*, defaults to 8): 55 | This is the number of key_value heads that should be used to implement Grouped Query Attention. If 56 | `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA); if 57 | `num_key_value_heads=1`, the model will use Multi Query Attention (MQA); otherwise GQA is used. When 58 | converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed 59 | by mean-pooling all the original heads within that group. For more details, check out [this 60 | paper](https://arxiv.org/pdf/2305.13245.pdf). If not specified, defaults to `8`. 61 | hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): 62 | The non-linear activation function (function or string) in the decoder. 63 | max_position_embeddings (`int`, *optional*, defaults to `4096*32`): 64 | The maximum sequence length that this model might ever be used with. Mistral's sliding window attention 65 | allows sequences of up to 4096*32 tokens. 66 | initializer_range (`float`, *optional*, defaults to 0.02): 67 | The standard deviation of the truncated_normal_initializer for initializing all weight matrices. 68 | rms_norm_eps (`float`, *optional*, defaults to 1e-06): 69 | The epsilon used by the rms normalization layers.
70 | use_cache (`bool`, *optional*, defaults to `True`): 71 | Whether or not the model should return the last key/values attentions (not used by all models). Only 72 | relevant if `config.is_decoder=True`. 73 | pad_token_id (`int`, *optional*): 74 | The id of the padding token. 75 | bos_token_id (`int`, *optional*, defaults to 1): 76 | The id of the "beginning-of-sequence" token. 77 | eos_token_id (`int`, *optional*, defaults to 2): 78 | The id of the "end-of-sequence" token. 79 | tie_word_embeddings (`bool`, *optional*, defaults to `False`): 80 | Whether the model's input and output word embeddings should be tied. 81 | rope_theta (`float`, *optional*, defaults to 10000.0): 82 | The base period of the RoPE embeddings. 83 | sliding_window (`int`, *optional*, defaults to 4096): 84 | Sliding window attention window size. If not specified, will default to `4096`. 85 | attention_dropout (`float`, *optional*, defaults to 0.0): 86 | The dropout ratio for the attention probabilities. 87 | 88 | ```python 89 | >>> from transformers import MistralModel, MistralConfig 90 | 91 | >>> # Initializing a Mistral 7B style configuration 92 | >>> configuration = MistralConfig() 93 | 94 | >>> # Initializing a model from the Mistral 7B style configuration 95 | >>> model = MistralModel(configuration) 96 | 97 | >>> # Accessing the model configuration 98 | >>> configuration = model.config 99 | ```""" 100 | 101 | model_type = "mistral" 102 | keys_to_ignore_at_inference = ["past_key_values"] 103 | 104 | def __init__( 105 | self, 106 | vocab_size=32000, 107 | hidden_size=4096, 108 | intermediate_size=14336, 109 | num_hidden_layers=32, 110 | num_attention_heads=32, 111 | num_key_value_heads=8, 112 | hidden_act="silu", 113 | max_position_embeddings=4096 * 32, 114 | initializer_range=0.02, 115 | rms_norm_eps=1e-6, 116 | use_cache=True, 117 | pad_token_id=None, 118 | bos_token_id=1, 119 | eos_token_id=2, 120 | tie_word_embeddings=False, 121 | rope_theta=10000.0, 122 | sliding_window=4096, 123 | 
attention_dropout=0.0, 124 | **kwargs, 125 | ): 126 | self.vocab_size = vocab_size 127 | self.max_position_embeddings = max_position_embeddings 128 | self.hidden_size = hidden_size 129 | self.intermediate_size = intermediate_size 130 | self.num_hidden_layers = num_hidden_layers 131 | self.num_attention_heads = num_attention_heads 132 | self.sliding_window = sliding_window 133 | 134 | # for backward compatibility 135 | if num_key_value_heads is None: 136 | num_key_value_heads = num_attention_heads 137 | 138 | self.num_key_value_heads = num_key_value_heads 139 | self.hidden_act = hidden_act 140 | self.initializer_range = initializer_range 141 | self.rms_norm_eps = rms_norm_eps 142 | self.use_cache = use_cache 143 | self.rope_theta = rope_theta 144 | self.attention_dropout = attention_dropout 145 | 146 | super().__init__( 147 | pad_token_id=pad_token_id, 148 | bos_token_id=bos_token_id, 149 | eos_token_id=eos_token_id, 150 | tie_word_embeddings=tie_word_embeddings, 151 | **kwargs, 152 | ) -------------------------------------------------------------------------------- /large_models/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 4 | logger = logging.getLogger(__name__) 5 | logger.setLevel(logging.INFO) 6 | 7 | import torch 8 | from torch import nn 9 | 10 | 11 | def find_module(root_module: nn.Module, key: str): 12 | """ 13 | Find a module with a specific name in a Transformer model 14 | From OpenDelta https://github.com/thunlp/OpenDelta 15 | """ 16 | sub_keys = key.split(".") 17 | parent_module = root_module 18 | for sub_key in sub_keys[:-1]: 19 | parent_module = getattr(parent_module, sub_key) 20 | module = getattr(parent_module, sub_keys[-1]) 21 | return parent_module, sub_keys[-1], module 22 | 23 | 24 | def attn_forward_hook(self, *args, **kwargs): 25 | """ 26 | Replace the original attention 
forward with this to enable prefix 27 | """ 28 | 29 | def _expand_bsz(x, bsz): 30 | x = x.reshape(x.size(0), self.num_heads, -1).transpose(0, 31 | 1) # (num_prefix, hidden) -> (num_head, num_prefix, hidden/num_head) 32 | x = x.unsqueeze(0).expand(bsz, *x.shape) # -> (bsz, num_head, num_prefix, hidden/num_head) 33 | return x 34 | 35 | if "hidden_states" in kwargs: 36 | hidden_states = kwargs["hidden_states"] 37 | else: 38 | hidden_states = args[0] 39 | bsz = hidden_states.size(0) 40 | 41 | if 'past_key_value' not in kwargs or kwargs['past_key_value'] is None: 42 | if self.reparam: 43 | prefix_keys = self.prefix_mlp_keys(self.prefix_input_embeds) 44 | prefix_values = self.prefix_mlp_values(self.prefix_input_embeds) 45 | else: 46 | prefix_keys, prefix_values = self.prefix_keys, self.prefix_values 47 | kwargs['past_key_value'] = (_expand_bsz(prefix_keys, bsz), _expand_bsz(prefix_values, bsz)) 48 | 49 | if 'attention_mask' in kwargs and kwargs['attention_mask'] is not None: 50 | am = kwargs['attention_mask'] 51 | kwargs['attention_mask'] = torch.cat( 52 | [-torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], dim=-1) 53 | elif len(args) > 1: # attention mask is passed via positional argument 54 | am = args[1] 55 | am = torch.cat([-torch.zeros((*am.shape[:-1], self.num_prefix), dtype=am.dtype, device=am.device), am], 56 | dim=-1) 57 | args = (args[0], am) + args[2:] 58 | 59 | return self.original_forward(*args, **kwargs) 60 | 61 | 62 | def prepare_inputs_for_generation( 63 | self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs): 64 | """ 65 | Replace the original "prepare_inputs_for_generation" with this to pass prefix correctly 66 | """ 67 | original_input_len = input_ids.size(-1) 68 | if past_key_values: 69 | input_ids = input_ids[:, -1:] 70 | 71 | # if `inputs_embeds` are passed, we only want to use them in the 1st generation step 72 | if inputs_embeds is not None and past_key_values is None: 73 | 
model_inputs = {"inputs_embeds": inputs_embeds} 74 | else: 75 | model_inputs = {"input_ids": input_ids} 76 | 77 | if past_key_values is not None: 78 | # Extend the attention mask if it does not yet cover the prefix tokens 79 | if past_key_values[0][0].size(2) != attention_mask.size(1) - 1: 80 | num_prefix = past_key_values[0][0].size(2) - (attention_mask.size(1) - 1) 81 | attention_mask = torch.cat([torch.ones((attention_mask.size(0), num_prefix), dtype=attention_mask.dtype, 82 | device=attention_mask.device), attention_mask], dim=-1) 83 | 84 | model_inputs.update( 85 | { 86 | "past_key_values": past_key_values, 87 | "use_cache": kwargs.get("use_cache"), 88 | "attention_mask": attention_mask, 89 | } 90 | ) 91 | return model_inputs 92 | 93 | 94 | class PrefixTuning: 95 | 96 | def __init__(self, model, num_prefix, reparam=True, embed_dim=512, mid_dim=512, float16=False, 97 | init_by_real_act=False): 98 | """ 99 | Inputs: 100 | num_prefix: number of prefix tokens 101 | reparam: use reparameterization trick (not used in MeZO) 102 | embed_dim, mid_dim: hyperparameters for reparameterization trick (not used in MeZO) 103 | float16: whether the model parameters are float16 104 | init_by_real_act: init prefix tokens by real activations 105 | """ 106 | 107 | self.model = model 108 | self.num_prefix = num_prefix 109 | self.hidden_dim = model.config.hidden_size 110 | self.float16 = float16 111 | 112 | # Reparameterization 113 | self.reparam = reparam 114 | self.embed_dim = embed_dim 115 | self.mid_dim = mid_dim 116 | 117 | input_embeds = None # For reparameterization 118 | if model.config.model_type == "opt": 119 | attention_name = "attn" 120 | first_layer_name = "layers.0" 121 | layer_name = "layers." 122 | elif model.config.model_type == "roberta": 123 | attention_name = "attention" 124 | first_layer_name = "layer.0" 125 | layer_name = "layer."
126 | elif model.config.model_type in ["llama", "mistral"]: 127 | attention_name = "self_attn" 128 | first_layer_name = "layers.0" 129 | layer_name = "layers." 130 | else: 131 | raise NotImplementedError 132 | 133 | if init_by_real_act: 134 | # Initialize prefix with real words' activations 135 | assert not reparam 136 | 137 | # Randomly sample input tokens 138 | input_tokens = torch.randint(low=0, high=model.config.vocab_size, size=(1, num_prefix), 139 | dtype=torch.long).cuda() 140 | if model.config.model_type in ["opt", "llama", "mistral"]: 141 | with torch.no_grad(): 142 | # Get the real activations 143 | real_key_values = model(input_ids=input_tokens, use_cache=True).past_key_values 144 | else: 145 | raise NotImplementedError 146 | 147 | # Insert prefix 148 | for key, _ in model.named_modules(): 149 | if key[-len(attention_name):] == attention_name: 150 | layer_id = int(key.split(layer_name)[1].split(".")[0]) 151 | logger.info(f"Inject prefix to: {key}") 152 | _, _, attn = find_module(model, key) 153 | 154 | # Replace the old forward functions 155 | attn.original_forward = attn.forward 156 | attn.forward = attn_forward_hook.__get__(attn, type(attn)) 157 | if not hasattr(attn, "num_heads"): 158 | attn.num_heads = model.config.num_attention_heads 159 | first = first_layer_name in key 160 | self.add_prefix(attn, first=first, input_embeds=input_embeds) 161 | 162 | if first and self.reparam: 163 | input_embeds = attn.prefix_input_embeds 164 | if init_by_real_act: 165 | logger.info(f"Reinitialize with actual activation: {key} (layer {layer_id})") 166 | keys = real_key_values[layer_id][0].squeeze(0).transpose(0, 1).reshape(num_prefix, -1) 167 | values = real_key_values[layer_id][1].squeeze(0).transpose(0, 1).reshape(num_prefix, -1) 168 | attn.prefix_keys.data = keys.to(attn.prefix_keys.data.device) 169 | attn.prefix_values.data = values.to(attn.prefix_values.data.device) 170 | 171 | # Freeze non-prefix parameters 172 | for n, p in model.named_parameters(): 173 | if 
"prefix" not in n: 174 | p.requires_grad = False 175 | 176 | # Replace the old prepare_inputs_for_generation function 177 | model.prepare_inputs_for_generation = prepare_inputs_for_generation.__get__(model, type(model)) 178 | 179 | def add_prefix(self, module, first, input_embeds=None): 180 | device = module.k_proj.weight.data.device 181 | module.num_prefix = self.num_prefix 182 | module.reparam = self.reparam 183 | if self.reparam: 184 | if first: 185 | # For the first layer we inject the embeddings 186 | logger.info("For prefix+reparameterization, inject the embeddings in the first layer.") 187 | module.prefix_input_embeds = nn.Parameter( 188 | torch.randn(self.num_prefix, self.embed_dim, device=device, dtype=self.model.dtype), 189 | requires_grad=True) 190 | else: 191 | assert input_embeds is not None 192 | module.prefix_input_embeds = input_embeds 193 | module.prefix_mlp_keys = nn.Sequential( 194 | nn.Linear(self.embed_dim, self.mid_dim), 195 | nn.Tanh(), 196 | nn.Linear(self.mid_dim, self.hidden_dim) 197 | ).to(device) 198 | module.prefix_mlp_values = nn.Sequential( 199 | nn.Linear(self.embed_dim, self.mid_dim), 200 | nn.Tanh(), 201 | nn.Linear(self.mid_dim, self.hidden_dim) 202 | ).to(device) 203 | if self.float16: 204 | module.prefix_mlp_keys = module.prefix_mlp_keys.half() 205 | module.prefix_mlp_values = module.prefix_mlp_values.half() 206 | else: 207 | module.prefix_keys = nn.Parameter( 208 | torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), 209 | requires_grad=True) 210 | module.prefix_values = nn.Parameter( 211 | torch.randn(self.num_prefix, self.hidden_dim, device=device, dtype=self.model.dtype), 212 | requires_grad=True) 213 | -------------------------------------------------------------------------------- /large_models/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | from typing import Optional, Callable 4 | 5 
| import torch 6 | from torch import nn 7 | from transformers import PreTrainedModel 8 | 9 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 10 | logger = logging.getLogger(__name__) 11 | logger.setLevel(logging.INFO) 12 | 13 | 14 | class PromptEmbedding(nn.Module): 15 | def __init__( 16 | self, 17 | num_virtual_tokens: int, 18 | token_dim: int, 19 | init_by_real_text: bool, 20 | word_embeddings: Optional[nn.Module] = None, 21 | vocab_size: Optional[int] = None, 22 | ): 23 | super().__init__() 24 | self.num_virtual_tokens = num_virtual_tokens 25 | 26 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 27 | if init_by_real_text: 28 | init_token_ids = torch.randint( 29 | low=0, high=vocab_size, 30 | size=(num_virtual_tokens,), dtype=torch.long 31 | ).to(word_embeddings.weight.device) 32 | 33 | word_embedding_weights = word_embeddings(init_token_ids).detach().clone() 34 | word_embedding_weights = word_embedding_weights.to(torch.float32) 35 | self.embedding.weight = nn.Parameter(word_embedding_weights) 36 | 37 | def forward(self, indices): 38 | # Just get embeddings 39 | prompt_embeddings = self.embedding(indices) 40 | return prompt_embeddings 41 | 42 | 43 | def _get_batch_size(input_ids: Optional[torch.Tensor], inputs_embeds: Optional[torch.Tensor]) -> int: 44 | if (input_ids is None) and (inputs_embeds is None): 45 | raise ValueError("You have to provide either input_ids or inputs_embeds") 46 | 47 | if input_ids is not None: 48 | batch_size = input_ids.shape[0] 49 | else: 50 | batch_size = inputs_embeds.shape[0] 51 | return batch_size 52 | 53 | 54 | def _model_forward_hook( 55 | self, 56 | embedding_module: Callable, 57 | embedding_module_device_refer, 58 | hide_virtual_token_logits: bool, 59 | input_ids=None, 60 | attention_mask=None, 61 | inputs_embeds=None, 62 | labels=None, 63 | output_attentions=None, 64 | output_hidden_states=None, 65 | return_dict=None, 66 | **kwargs, 67 | ): 68 | batch_size = 
_get_batch_size(input_ids, inputs_embeds) 69 | num_virtual_tokens = self.prompt_encoder.num_virtual_tokens 70 | if attention_mask is not None: 71 | # prepend an all-ones attention mask covering the virtual tokens 72 | prefix_attention_mask = torch.ones(batch_size, num_virtual_tokens).to(attention_mask.device) 73 | attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) 74 | if kwargs.get("position_ids", None) is not None: 75 | logger.warning("Position ids are not supported for parameter efficient tuning. Ignoring position ids.") 76 | kwargs["position_ids"] = None 77 | kwargs.update( 78 | { 79 | "attention_mask": attention_mask, 80 | "output_attentions": output_attentions, 81 | "output_hidden_states": output_hidden_states, 82 | "return_dict": return_dict, 83 | } 84 | ) 85 | 86 | if labels is not None: 87 | if len(labels.shape) == 1: 88 | # if sequence classification task, labels do not have to be padded 89 | kwargs["labels"] = labels 90 | elif len(labels.shape) == 2: 91 | # supposed to be a language modeling task: pad the labels of the virtual tokens with -100 so the loss ignores them 92 | kwargs["labels"] = torch.cat( 93 | ( 94 | -100 * torch.ones(batch_size, num_virtual_tokens).to(labels.device).long(), 95 | labels, 96 | ), 97 | dim=1, 98 | ) 99 | else: 100 | raise NotImplementedError("Not implemented for labels with shape {}".format(labels.shape)) 101 | 102 | if kwargs.get("token_type_ids", None) is not None: 103 | kwargs["token_type_ids"] = torch.cat( 104 | ( 105 | torch.zeros(batch_size, num_virtual_tokens).to(kwargs["token_type_ids"].device), 106 | kwargs["token_type_ids"], 107 | ), 108 | dim=1, 109 | ).long() 110 | 111 | if kwargs.get("mask_pos", None) is not None: 112 | kwargs["mask_pos"] = num_virtual_tokens + kwargs["mask_pos"] 113 | 114 | input_device = input_ids.device if input_ids is not None else inputs_embeds.device 115 | if inputs_embeds is None: 116 | inputs_embeds = embedding_module(input_ids.to(embedding_module_device_refer.device)) 117 | inputs_embeds = inputs_embeds.to(input_device) 118 | prompts =
torch.arange(num_virtual_tokens).unsqueeze(0).expand(batch_size, -1).to( 119 | self.prompt_encoder.embedding.weight.device) 120 | prompts = self.prompt_encoder(prompts).to(dtype=inputs_embeds.dtype, device=input_device) 121 | 122 | inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) 123 | 124 | outputs = self.prompt_tuning_original_forward(inputs_embeds=inputs_embeds, **kwargs) 125 | if hide_virtual_token_logits and hasattr(outputs, "logits"): 126 | outputs.logits = outputs.logits[..., num_virtual_tokens:, :] 127 | return outputs 128 | 129 | 130 | class PromptTuning: 131 | 132 | def __init__( 133 | self, 134 | model: PreTrainedModel, 135 | num_virtual_tokens: int, 136 | init_by_real_tokens: Optional[bool] = False, 137 | hide_virtual_token_logits: Optional[bool] = True, 138 | ): 139 | """ 140 | Prompt tuning model initializer. 141 | 142 | Parameters 143 | ---------- 144 | model: PreTrainedModel, required 145 | The model to be tuned. 146 | num_virtual_tokens: int, required 147 | The number of virtual tokens to be added. 148 | init_by_real_tokens: bool, optional, default=False 149 | Whether to initialize the virtual tokens by real tokens. 150 | """ 151 | hidden_dim = model.config.hidden_size 152 | 153 | if model.config.model_type == "opt": 154 | embedding_module = model.get_input_embeddings() 155 | embedding_module_device_refer = embedding_module.weight 156 | elif model.config.model_type == "roberta": 157 | if hasattr(model, "roberta"): # is RoBERTaForMaskedLM etc. 
158 | embedding_module = partial(model.roberta.embeddings, past_key_values_length=num_virtual_tokens) 159 | embedding_module_device_refer = model.roberta.embeddings.word_embeddings.weight 160 | elif hasattr(model, "embeddings"): # is RoBERTa base model 161 | embedding_module = partial(model.embeddings, past_key_values_length=num_virtual_tokens) 162 | embedding_module_device_refer = model.embeddings.word_embeddings.weight 163 | else: 164 | raise ValueError(f"Cannot find embedding module in {model.__class__.__name__}") 165 | elif model.config.model_type in ["llama", "mistral"]: 166 | embedding_module = model.get_input_embeddings() 167 | embedding_module_device_refer = embedding_module.weight 168 | else: 169 | raise NotImplementedError 170 | 171 | model.prompt_encoder = PromptEmbedding( 172 | num_virtual_tokens, hidden_dim, init_by_real_tokens, 173 | model.get_input_embeddings(), model.config.vocab_size 174 | ) 175 | 176 | model.prompt_tuning_original_forward = model.forward 177 | 178 | if not hasattr(embedding_module_device_refer, "device"): 179 | raise ValueError(f"Cannot find device attribute in {embedding_module_device_refer.__class__.__name__}") 180 | 181 | forward_hook_kwargs = { 182 | "embedding_module": embedding_module, 183 | "embedding_module_device_refer": embedding_module_device_refer, 184 | "hide_virtual_token_logits": hide_virtual_token_logits, 185 | } 186 | model.forward = partial( 187 | _model_forward_hook.__get__(model, type(model)), 188 | **forward_hook_kwargs 189 | ) 190 | 191 | for n, p in model.named_parameters(): 192 | if "prompt_encoder" not in n: 193 | p.requires_grad = False 194 | 195 | 196 | def test_roberta(): 197 | from transformers import AutoTokenizer, RobertaModel 198 | model = RobertaModel.from_pretrained("roberta-base") 199 | tokenizer = AutoTokenizer.from_pretrained("roberta-base") 200 | 201 | PromptTuning(model, num_virtual_tokens=5, init_by_real_tokens=True) 202 | 203 | inputs = tokenizer("in heissem Liebesstreben", 
return_tensors="pt") 204 | outputs = model(**inputs) 205 | 206 | 207 | def test_opt(): 208 | from transformers import AutoTokenizer, OPTModel 209 | model = OPTModel.from_pretrained("facebook/opt-125m") 210 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m") 211 | 212 | PromptTuning(model, num_virtual_tokens=5, init_by_real_tokens=True) 213 | 214 | inputs = tokenizer("werd ich entschweben", return_tensors="pt") 215 | outputs = model(**inputs) 216 | 217 | 218 | if __name__ == "__main__": 219 | test_roberta() 220 | test_opt() 221 | -------------------------------------------------------------------------------- /large_models/run.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import random 5 | 6 | import wandb 7 | from torch.utils.tensorboard import SummaryWriter 8 | from datetime import datetime 9 | from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP 10 | from torch.utils.data import Dataset 11 | from tqdm import tqdm 12 | from transformers import ( 13 | AutoConfig, 14 | AutoTokenizer, 15 | AutoModelForCausalLM, 16 | HfArgumentParser, 17 | TrainingArguments, 18 | DataCollatorForTokenClassification 19 | ) 20 | 21 | from metrics import calculate_metric 22 | from modeling_mistral import ( 23 | MistralForCausalLM, 24 | MistralConfig 25 | ) 26 | from tasks import get_task 27 | from trainer import OurTrainer 28 | from utils import * 29 | 30 | os.environ["TRANSFORMERS_CACHE"] = "./cache" 31 | 32 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 33 | logger = logging.getLogger(__name__) 34 | logger.setLevel(logging.INFO) 35 | 36 | AutoConfig.register("mistral", MistralConfig) 37 | AutoModelForCausalLM.register(MistralConfig, MistralForCausalLM) 38 | 39 | 40 | @dataclass 41 | class OurArguments(TrainingArguments): 42 | # dataset and sampling strategy 43 | task_name: str = "SST2" # task name should 
match the string before Dataset in the Dataset class name. We support the following task_name: SST2, RTE, CB, BoolQ, WSC, WIC, MultiRC, Copa, ReCoRD, SQuAD, DROP 44 | 45 | # Number of examples 46 | num_train: int = 0 # ICL mode: number of demonstrations; training mode: number of training samples 47 | num_dev: int = None # (only enabled with training) number of development samples 48 | num_eval: int = None # number of evaluation samples 49 | num_train_sets: int = None # how many sets of training samples/demos to sample; if None and train_set_seed is None, then we will sample one set for each evaluation sample 50 | train_set_seed: int = 0 # designated seed to sample training samples/demos 51 | result_file: str = None # file name for saving performance; if None, then use the task name, model name, and config 52 | 53 | # Model loading 54 | model_name: str = "facebook/opt-125m" # HuggingFace model name 55 | load_float16: bool = False # load model parameters as float16 56 | load_bfloat16: bool = False # load model parameters as bfloat16 57 | load_int8: bool = False # load model parameters as int8 58 | max_length: int = 2048 # max length the model can take 59 | no_auto_device: bool = False # do not load model by auto device; should turn this on when using FSDP 60 | 61 | # Calibration 62 | sfc: bool = False # whether to use SFC calibration 63 | icl_sfc: bool = False # whether to use SFC calibration for ICL samples 64 | 65 | template_ver: int = 0 # template. For some tasks (SST2, RTE, Copa), we add template ver=1 as the empty template. 
66 | 67 | # Training 68 | trainer: str = "subzero_sgd" 69 | ## options 70 | ## - none: no training -- for zero-shot or in-context learning (ICL) 71 | ## - regular: regular huggingface trainer -- for fine-tuning 72 | ## - zo_sgd: zeroth-order SGD (MeZO) training 73 | ## - zo_conserv: zeroth-order SGD conservative training 74 | ## - zo_adam: zeroth-order Adam training 75 | ## - zo_sign_opt: zeroth-order sign SGD training 76 | ## - forward_grad: forward gradient 77 | ## (add) -zo_sgd_svd 78 | 79 | optimizer: str = "adamw" 80 | ## options 81 | ## - sgd 82 | ## - adam 83 | ## - adamw # this is the HuggingFace default 84 | only_train_option: bool = True # whether to only train the option part of the input 85 | train_as_classification: bool = False # take the log likelihood of all options and train as classification 86 | momentum: float = 0.0 # only works with the SGD optimizer 87 | lr_scheduler_type: str = "constant" # only works with the SGD optimizer 88 | 89 | # MeZO and SubZero 90 | zo_eps: float = 1e-3 # perturbation scale eps in MeZO 91 | perturbation_mode: str = "two_side" 92 | q: int = 1 # number of Gaussian samples for zeroth-order trainers 93 | 94 | update_interval: int = 2000 95 | gauss_rank: int = 8 96 | 97 | 98 | # Prefix tuning 99 | prefix_tuning: bool = False # whether to use prefix tuning 100 | num_prefix: int = 5 # number of prefix tokens to use 101 | no_reparam: bool = True # do not use reparameterization trick 102 | prefix_init_by_real_act: bool = True # initialize prefix by real activations of random words 103 | 104 | # prompt tuning hyperparameters 105 | prompt_tuning: bool = False # whether to use prompt tuning 106 | num_virtual_tokens: int = 10 # number of prompt tokens to use 107 | prompt_init_by_real_tokens: bool = False # whether to initialize the virtual tokens with the embeddings of randomly sampled real tokens 108 | 109 | # LoRA 110 | lora: bool = False # whether to use LoRA 111 | lora_alpha: int = 16 # alpha in LoRA 112 | lora_r: int = 8 # r in LoRA 113 | 114 | # Generation 115 | sampling: bool = False # whether to use
sampling 116 | temperature: float = 1.0 # temperature for generation 117 | num_beams: int = 1 # number of beams for generation 118 | top_k: int = None # top-k for generation 119 | top_p: float = 0.95 # top-p for generation 120 | max_new_tokens: int = 50 # max number of new tokens to generate 121 | eos_token: str = "\n" # end of sentence token 122 | 123 | # Saving 124 | save_model: bool = False # whether to save the model 125 | no_eval: bool = False # whether to skip evaluation 126 | tag: str = "" # saving tag 127 | 128 | # Linear probing 129 | linear_probing: bool = False # whether to do linear probing 130 | lp_early_stopping: bool = False # whether to do early stopping in linear probing 131 | head_tuning: bool = False # head tuning: only tune the LM head 132 | 133 | # Untie emb/lm_head weights 134 | untie_emb: bool = False # untie the embeddings and LM head 135 | 136 | # Display 137 | verbose: bool = False # verbose output 138 | 139 | # Non-diff objective 140 | non_diff: bool = False # use non-differentiable objective (only supports F1 for SQuAD for now) 141 | 142 | # Auto saving when interrupted 143 | save_on_interrupt: bool = False # save model when interrupted (useful for long training) 144 | 145 | clean_model_at_end: bool = True # remove everything at the end.
146 | 147 | def parse_args(): 148 | parser = argparse.ArgumentParser() 149 | parser = HfArgumentParser(OurArguments) 150 | args = parser.parse_args_into_dataclasses()[0] 151 | print(args) 152 | return args 153 | 154 | 155 | def set_seed(seed: int): 156 | random.seed(seed) 157 | np.random.seed(seed) 158 | torch.manual_seed(seed) 159 | torch.cuda.manual_seed_all(seed) 160 | 161 | 162 | class Framework: 163 | 164 | def __init__(self, args, task): 165 | self.args = args 166 | self.task = task 167 | self.model, self.tokenizer = self.load_model() 168 | 169 | def load_model(self): 170 | """ 171 | Load HuggingFace models 172 | """ 173 | with count_time("Loading model with FP%d" % (16 if self.args.load_float16 else 32)): 174 | free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024 ** 3) 175 | print(free_in_GB) 176 | config = AutoConfig.from_pretrained(self.args.model_name) 177 | if self.args.untie_emb: 178 | # Untie embeddings/LM head 179 | logger.warning("Untie embeddings and LM head") 180 | config.tie_word_embeddings = False 181 | if self.args.head_tuning: 182 | torch_dtype = torch.float32 183 | if self.args.load_float16: 184 | torch_dtype = torch.float16 185 | elif self.args.load_bfloat16: 186 | torch_dtype = torch.bfloat16 187 | # Head tuning 188 | if "opt" in self.args.model_name.lower(): 189 | from modeling_opt import OPTForCausalLM 190 | model = OPTForCausalLM.from_pretrained( 191 | self.args.model_name, 192 | config=config, 193 | device_map='auto', 194 | torch_dtype=torch_dtype, 195 | max_memory={i: f'{free_in_GB - 5}GB' for i in 196 | range(torch.cuda.device_count())}, 197 | ) 198 | elif "llama" in self.args.model_name.lower(): 199 | from modeling_llama import LlamaForCausalLMWithHeadTuning 200 | model = LlamaForCausalLMWithHeadTuning.from_pretrained( 201 | self.args.model_name, 202 | config=config, 203 | device_map='auto', 204 | torch_dtype=torch_dtype, 205 | max_memory={i: f'{free_in_GB - 5}GB' for i in 206 | range(torch.cuda.device_count())}, 207 | ) 208 | elif
"mistral" in self.args.model_name.lower(): 209 | from modeling_mistral import MistralForCausalLMWithHeadTuning 210 | model = MistralForCausalLMWithHeadTuning.from_pretrained( 211 | self.args.model_name, 212 | config=config, 213 | device_map='auto', 214 | torch_dtype=torch_dtype, 215 | max_memory={i: f'{free_in_GB - 5}GB' for i in 216 | range(torch.cuda.device_count())}, 217 | ) 218 | else: 219 | raise NotImplementedError(f"Head tuning is not supported for {self.args.model_name}") 220 | elif self.args.no_auto_device: 221 | # No auto device (use for FSDP) 222 | model = AutoModelForCausalLM.from_pretrained(self.args.model_name, config=config, ) 223 | else: 224 | # Auto device loading 225 | torch_dtype = torch.float32 226 | if self.args.load_float16: 227 | torch_dtype = torch.float16 228 | elif self.args.load_bfloat16: 229 | torch_dtype = torch.bfloat16 230 | model = AutoModelForCausalLM.from_pretrained(self.args.model_name, config=config, device_map='auto', 231 | torch_dtype=torch_dtype, 232 | max_memory={i: f'{free_in_GB - 0.5}GB' for i in 233 | range(torch.cuda.device_count())}, 234 | load_in_8bit=self.args.load_int8, ) 235 | model.eval() 236 | 237 | # Load tokenizer 238 | # MeZO sets use_fast=False, but that raises a TypeError when running SQuAD; use_fast=True avoids it.
tokenizer = AutoTokenizer.from_pretrained(self.args.model_name, use_fast=True) 240 | 241 | # HF tokenizer bug fix 242 | if "opt" in self.args.model_name.lower(): 243 | tokenizer.bos_token_id = 0 244 | 245 | if ("llama" in self.args.model_name.lower()) or ("mistral" in self.args.model_name.lower()): 246 | # LLaMA/Mistral tokenizers define no pad token; reuse token id 0 for padding 247 | tokenizer.pad_token_id = 0 248 | 249 | # Prefix tuning/LoRA 250 | if self.args.prefix_tuning: 251 | from prefix_tuning import PrefixTuning 252 | PrefixTuning(model, num_prefix=self.args.num_prefix, reparam=not self.args.no_reparam, 253 | float16=self.args.load_float16, init_by_real_act=self.args.prefix_init_by_real_act) 254 | if self.args.lora: 255 | from lora import LoRA 256 | LoRA(model, r=self.args.lora_r, alpha=self.args.lora_alpha, float16=self.args.load_float16) 257 | 258 | if self.args.prompt_tuning: 259 | from prompt_tuning import PromptTuning 260 | print("Adding Prompt Tuning to model...") 261 | PromptTuning( 262 | model, 263 | num_virtual_tokens=self.args.num_virtual_tokens, 264 | init_by_real_tokens=self.args.prompt_init_by_real_tokens, 265 | hide_virtual_token_logits=True, # a workaround for the other loss/prediction functions 266 | ) 267 | 268 | # for name, param in model.named_parameters(): 269 | # if name == 'prompt_encoder.embedding.weight': 270 | # print(param.shape, end="\n") 271 | 272 | 273 | print("Total/Trainable number of parameters: {}/{}".format( 274 | sum(p.numel() for p in model.parameters()), 275 | sum(p.numel() for p in model.parameters() if p.requires_grad), 276 | )) 277 | 278 | if self.args.head_tuning: 279 | if model.config.model_type in ["opt", "llama", "mistral"]: 280 | head_name = "lm_head" if self.args.untie_emb else "embed_tokens" 281 | else: 282 | raise NotImplementedError 283 | for n, p in model.named_parameters(): 284 | if head_name not in n: 285 | p.requires_grad = False 286 | else: 287 | logger.info(f"Only tuning {n}") 288 | 289 | return model, tokenizer 290 | 291 | def forward(self,
input_ids, option_len=None, generation=False): 292 | """ 293 | Given input_ids and the length of the option, return the log-likelihood of each token in the option. 294 | For generation tasks, return the generated text. 295 | This function is only for inference 296 | """ 297 | input_ids = torch.tensor([input_ids]).to(self.model.device) 298 | 299 | if generation: 300 | args = self.args 301 | # Autoregressive generation 302 | outputs = self.model.generate(input_ids, do_sample=args.sampling, temperature=args.temperature, 303 | num_beams=args.num_beams, top_p=args.top_p, top_k=args.top_k, 304 | max_new_tokens=min(args.max_new_tokens, args.max_length - input_ids.size(1)), 305 | num_return_sequences=1, 306 | eos_token_id=[ 307 | self.tokenizer.encode(args.eos_token, add_special_tokens=False)[-1], 308 | self.tokenizer.eos_token_id], ) 309 | # For generation, directly return the text output 310 | output_text = self.tokenizer.decode(outputs[0][input_ids.size(1):], skip_special_tokens=True).strip() 311 | return output_text 312 | else: 313 | with torch.inference_mode(): 314 | self.model.eval() 315 | logits = self.model(input_ids=input_ids).logits 316 | labels = input_ids[0, 1:] 317 | logits = logits[0, :-1] 318 | log_probs = F.log_softmax(logits, dim=-1) 319 | 320 | selected_log_probs = log_probs[torch.arange(len(labels)).to(labels.device), labels] 321 | selected_log_probs = selected_log_probs.cpu().detach() 322 | # Only return the option (candidate) part 323 | return selected_log_probs[-option_len:] 324 | 325 | def one_step_pred(self, train_samples, eval_sample, verbose=False): 326 | """ 327 | Return the prediction on the eval sample. 
In ICL, use train_samples as demonstrations 328 | """ 329 | verbose = verbose or self.args.verbose 330 | # if verbose: 331 | # logger.info("========= Example =========") 332 | # logger.info(f"Candidate: {eval_sample.candidates}") 333 | # logger.info(f"Correct candidate: {eval_sample.correct_candidate}") 334 | 335 | # Encode (add prompt and tokenize) the sample; if multiple-choice/classification, encode all candidates (options) 336 | encoded_candidates, option_lens = encode_prompt(self.task, 337 | self.task.get_template(template_version=self.args.template_ver), 338 | train_samples, eval_sample, 339 | self.tokenizer, max_length=self.args.max_length, 340 | generation=self.task.generation, 341 | max_new_tokens=self.args.max_new_tokens) 342 | 343 | # Calibration 344 | if self.args.sfc or self.args.icl_sfc: 345 | sfc_encoded_candidates, sfc_option_lens = encode_prompt(self.task, self.task.get_template( 346 | template_version=self.args.template_ver), train_samples, 347 | eval_sample, self.tokenizer, 348 | max_length=self.args.max_length, sfc=self.args.sfc, 349 | icl_sfc=self.args.icl_sfc, 350 | generation=self.task.generation, 351 | max_new_tokens=self.args.max_new_tokens) 352 | 353 | outputs = [] 354 | if self.task.generation: 355 | # For generation tasks, return the autoregressively-generated text 356 | output_text = self.forward(encoded_candidates[0], generation=True) 357 | # if verbose: 358 | # logger.info("=== Prompt ===") 359 | # logger.info(self.tokenizer.decode(encoded_candidates[0])) 360 | # logger.info(f"Output: {output_text}") 361 | return Prediction(correct_candidate=eval_sample.correct_candidate, predicted_candidate=output_text) 362 | else: 363 | # For classification/multiple-choice, calculate the probabilities of all candidates 364 | for candidate_id, encoded_candidate in enumerate(encoded_candidates): 365 | selected_log_probs = self.forward(encoded_candidate, option_len=option_lens[candidate_id]) 366 | if verbose: 367 | # if candidate_id == 0: 368 | # 
logger.info("=== Candidate %d ===" % candidate_id) 369 | # logger.info(self.tokenizer.decode(encoded_candidate)) 370 | # else: 371 | # logger.info("=== Candidate %d (without context)===" % candidate_id) 372 | # logger.info(self.tokenizer.decode(encoded_candidate).split(self.task.train_sep)[-1]) 373 | logger.info(f"Log probabilities of the option tokens: {selected_log_probs}") 374 | 375 | if self.args.sfc or self.args.icl_sfc: 376 | sfc_selected_log_probs = self.forward(sfc_encoded_candidates[candidate_id], 377 | option_len=sfc_option_lens[ 378 | candidate_id]) # if verbose: # logger.info("=== Candidate %d (without context) SFC ===" % candidate_id) # logger.info( # self.tokenizer.decode(sfc_encoded_candidates[candidate_id]).split(self.task.train_sep)[-1]) # logger.info(f"Log probabilities of the option tokens: {sfc_selected_log_probs}") 379 | 380 | outputs.append({"log_probs": selected_log_probs, 381 | "sfc_log_probs": sfc_selected_log_probs if self.args.sfc or self.args.icl_sfc else None}) 382 | 383 | if self.args.sfc or self.args.icl_sfc: 384 | # Calibrated probabilities (surface form competition; https://arxiv.org/pdf/2104.08315.pdf) 385 | # log p(candidate | input) = log p_lm(candidate | input) - log p_lm(candidate | sfc prompt) 386 | scores = [x['log_probs'].sum().item() - x['sfc_log_probs'].sum().item() for x in outputs] 387 | else: 388 | # (Default) length-normalized log probabilities 389 | # log p(candidate | input) = log p_lm(candidate | input) / |candidate #tokens| 390 | scores = [x['log_probs'].mean().item() for x in outputs] 391 | 392 | if verbose: 393 | logger.info(f"Prediction scores: {scores}") 394 | 395 | if isinstance(eval_sample.correct_candidate, list): 396 | # For some datasets there are multiple correct answers 397 | correct_candidate_id = [eval_sample.candidates.index(c) for c in eval_sample.correct_candidate] 398 | else: 399 | correct_candidate_id = eval_sample.candidates.index(eval_sample.correct_candidate) 400 | 401 | return 
Prediction(correct_candidate=correct_candidate_id, predicted_candidate=int(np.argmax(scores))) 402 | 403 | def evaluate(self, train_samples, eval_samples, one_train_set_per_eval_sample=False, description=None): 404 | """ 405 | Evaluate function. 406 | Here, train_samples are used for demonstrations for ICL. 407 | If one_train_set_per_eval_sample is True, then each eval sample has its own training (demonstration) set. 408 | Otherwise, the same training set is used for all eval samples. 409 | """ 410 | if one_train_set_per_eval_sample: 411 | logger.info(f"There are {len(eval_samples)} validation samples and one train set per eval sample") 412 | else: 413 | logger.info(f"There are {len(train_samples)} training samples and {len(eval_samples)} validation samples") 414 | 415 | # Prediction loop 416 | predictions = [] 417 | for eval_id, eval_sample in enumerate(tqdm(eval_samples, desc=description)): 418 | predictions.append( 419 | self.one_step_pred(train_samples[eval_id] if one_train_set_per_eval_sample else train_samples, 420 | eval_sample, verbose=False)) 421 | 422 | # Calculate metrics 423 | metric_name = getattr(self.task, "metric_name", "accuracy") 424 | metrics = {metric_name: calculate_metric(predictions, metric_name)} 425 | return metrics 426 | 427 | def train(self, train_samples, dev_samples, eval_samples, writer): 428 | """ 429 | Training function 430 | if self.num_dev is not None, eval_samples are dev_samples 431 | """ 432 | logger.info(f"Eval sample length is {len(eval_samples)}") 433 | # Set tokenizer to left padding (so that all the options are right aligned) 434 | self.tokenizer.padding_side = "left" 435 | 436 | class HFDataset(Dataset): 437 | 438 | def __init__(self, data): 439 | self.data = data 440 | 441 | def __len__(self): 442 | return len(self.data) 443 | 444 | def __getitem__(self, idx): 445 | return self.data[idx] 446 | 447 | def _convert(samples): 448 | """ 449 | Convert samples to HF-compatible dataset 450 | """ 451 | data = [] 452 | for sample 
in samples: 453 | encoded_candidates, option_lens = encode_prompt(self.task, self.task.get_template( 454 | template_version=self.args.template_ver), [], sample, 455 | self.tokenizer, max_length=self.args.max_length, 456 | generation=self.task.generation, 457 | generation_with_gold=True, 458 | max_new_tokens=self.args.max_new_tokens) 459 | if self.task.generation: 460 | correct_candidate_id = 0 461 | elif isinstance(sample.correct_candidate, list): 462 | correct_candidate_id = sample.candidates.index(sample.correct_candidate[0]) 463 | else: 464 | correct_candidate_id = sample.candidates.index(sample.correct_candidate) 465 | 466 | if self.args.non_diff: 467 | # For non-differentiable objective, there is no teacher forcing thus the 468 | # current answer part is removed 469 | encoded_candidates[correct_candidate_id] = encoded_candidates[correct_candidate_id][ 470 | :-option_lens[correct_candidate_id]] 471 | 472 | if self.args.train_as_classification: 473 | # For classification, we provide the label as the correct candidate id 474 | data.append([{"input_ids": encoded_candidates[_i], "labels": correct_candidate_id, 475 | "option_len": option_lens[_i], "num_options": len(sample.candidates)} for _i in 476 | range(len(encoded_candidates))]) 477 | elif self.args.only_train_option: 478 | # Otherwise, it is just LM-style teacher forcing 479 | if self.args.non_diff: 480 | # For non-differentiable objective, we need to provide the gold answer to calculate F1/acc 481 | data.append({"input_ids": encoded_candidates[correct_candidate_id], 482 | "labels": encoded_candidates[correct_candidate_id], 483 | "option_len": option_lens[correct_candidate_id], "gold": sample.correct_candidate}) 484 | else: 485 | data.append({"input_ids": encoded_candidates[correct_candidate_id], 486 | "labels": encoded_candidates[correct_candidate_id], 487 | "option_len": option_lens[correct_candidate_id]}) 488 | else: 489 | data.append({"input_ids": encoded_candidates[correct_candidate_id], 490 | "labels": 
encoded_candidates[correct_candidate_id]}) 491 | return data 492 | 493 | with count_time("Tokenizing training samples"): 494 | train_dataset = HFDataset(_convert(train_samples)) 495 | eval_dataset = HFDataset(_convert(eval_samples)) 496 | dev_dataset = HFDataset(_convert(dev_samples)) 497 | 498 | if self.args.only_train_option and not self.args.non_diff: 499 | # If --only_train_option and not with a non-differentiable objective, we wrap the forward function 500 | self.model.original_forward = self.model.forward 501 | self.model.forward = forward_wrap_with_option_len.__get__(self.model, type(self.model)) 502 | 503 | if self.args.non_diff: 504 | collator = NondiffCollator 505 | else: 506 | collator = DataCollatorForTokenClassification 507 | 508 | trainer = OurTrainer(model=self.model, 509 | args=self.args, 510 | train_dataset=train_dataset, 511 | eval_dataset=eval_dataset, 512 | tokenizer=self.tokenizer, 513 | data_collator=DataCollatorWithPaddingAndNesting(self.tokenizer, 514 | pad_to_multiple_of=8) if self.args.train_as_classification else collator( 515 | self.tokenizer, pad_to_multiple_of=8), 516 | eval_samples=eval_samples, 517 | dev_samples=dev_samples, 518 | evaluate_func=self.evaluate, 519 | writer=writer 520 | ) 521 | 522 | if self.args.save_on_interrupt: 523 | trainer.add_callback(SIGUSR1Callback()) 524 | 525 | # Resume training from a last checkpoint 526 | last_checkpoint = None 527 | from transformers.trainer_utils import get_last_checkpoint 528 | if os.path.isdir(self.args.output_dir) and not self.args.overwrite_output_dir: 529 | last_checkpoint = get_last_checkpoint(self.args.output_dir) 530 | if last_checkpoint is not None and self.args.resume_from_checkpoint is None: 531 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}. 
"To avoid this behavior, change " 532 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch.") 533 | if self.args.resume_from_checkpoint is not None: 534 | last_checkpoint = self.args.resume_from_checkpoint 535 | 536 | # This calls the trainer._inner_training_loop() 537 | trainer.train(resume_from_checkpoint=last_checkpoint) 538 | 539 | # Explicitly save the model 540 | if self.args.save_model: 541 | logger.info("Saving model...") 542 | trainer.save_model() 543 | 544 | # FSDP compatibility 545 | self.model = trainer.model 546 | 547 | # Reset the forward function for evaluation 548 | if self.args.only_train_option and not self.args.non_diff: 549 | if isinstance(self.model, FSDP): 550 | logger.info("This is an FSDP model now. Be careful when assigning back the original forward function") 551 | self.model._fsdp_wrapped_module.forward = self.model._fsdp_wrapped_module.original_forward 552 | else: 553 | self.model.forward = self.model.original_forward 554 | 555 | def delete_checkpoints(self): 556 | import shutil 557 | print(f"\nWARNING: Removing all checkpoint-* folders in: {self.args.output_dir}") 558 | deleted_folders = [folder for folder in os.listdir(self.args.output_dir) 559 | if os.path.isdir(os.path.join(self.args.output_dir, folder)) 560 | and folder.startswith("checkpoint-")] 561 | for f in deleted_folders: 562 | shutil.rmtree(os.path.join(self.args.output_dir, f)) 563 | print("Deleted folders:", deleted_folders) 564 | 565 | 566 | def result_file_tag(args): 567 | """ 568 | Get the result file tag 569 | """ 570 | save_model_name = args.model_name.split("/")[-1] 571 | sfc_tag = "-sfc" if args.sfc else "" 572 | icl_sfc_tag = "-icl_sfc" if args.icl_sfc else "" 573 | sample_eval_tag = "-sampleeval%d" % args.num_eval if args.num_eval is not None else "" 574 | sample_train_tag = "-ntrain%d" % args.num_train if args.num_train > 0 else "" 575 | sample_dev_tag = "-ndev%d" % args.num_dev if args.num_dev is not None else "" 576 | customized_tag = f"-{args.tag}" if
len(args.tag) > 0 else "" 577 | return f"{args.task_name}-{save_model_name}" + sfc_tag + icl_sfc_tag + sample_eval_tag + sample_train_tag + sample_dev_tag + customized_tag 578 | 579 | 580 | def main(): 581 | args = parse_args() 582 | if args.prefix_tuning: 583 | args.mode = "prefix" 584 | elif args.lora: 585 | args.mode = "lora" 586 | elif args.prompt_tuning: 587 | args.mode = "prompt" 588 | else: 589 | args.mode = "ft" 590 | args.tag = f"{args.trainer}-{args.task_name}-{args.template_ver}-{args.model_name.split('/')[-1]}-OPTIM_{args.mode}-STEP{args.max_steps}-{args.optimizer}-momen{args.momentum}-LR{args.learning_rate}-{args.lr_scheduler_type}-ZOEPS{args.zo_eps}-T{args.update_interval}-gauss_rank{args.gauss_rank}-Q{args.q}-bs{args.per_device_train_batch_size}-gradAccumulation{args.gradient_accumulation_steps}" 591 | args.run_name = args.tag 592 | args.output_dir = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}" 593 | args.result_file = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}/results.json" 594 | os.makedirs(args.output_dir, exist_ok=True) 595 | 596 | current_date = datetime.now().strftime('%Y-%m-%d_%H-%M-%S') 597 | # wandb.init(project='zo-bench', name=args.tag, config=args) 598 | tensorboard_log_dir = f"result/{args.task_name}/{args.model_name.split('/')[-1]}/{args.mode}/{args.trainer}/{args.tag}/{current_date}" 599 | args.logging_dir = os.path.join(tensorboard_log_dir, "logs") 600 | os.makedirs(args.logging_dir, exist_ok=True) 601 | 602 | writer = SummaryWriter(tensorboard_log_dir) 603 | set_seed(args.seed) 604 | task = get_task(args.task_name) 605 | 606 | # This function samples both training and validation samples. 
The validation (dev) samples are also stored in "train_sets" 607 | # Later the train_samples and dev_samples are separated 608 | train_sets = task.sample_train_sets(num_train=args.num_train, num_dev=args.num_dev, num_eval=args.num_eval, 609 | num_train_sets=args.num_train_sets, seed=args.train_set_seed) 610 | 611 | # Initialize trainer and load model 612 | framework = Framework(args, task) 613 | 614 | # ZO-Bench Added 615 | # We add these parameters to evaluate the model during training. 616 | # These two parameters will be used in the training loop 617 | # args.task = task 618 | # args.framework = framework 619 | 620 | if args.train_set_seed is not None or args.num_train_sets is not None: 621 | 622 | # Training takes this path 623 | 624 | # Eval samples share one (or multiple) training set(s) 625 | for train_set_id, train_samples in enumerate(train_sets): 626 | train_set_seed = train_set_id if args.train_set_seed is None else args.train_set_seed 627 | 628 | # Sample eval samples 629 | if args.num_eval is not None: 630 | eval_samples = task.sample_subset(data_split="valid", seed=train_set_seed, num=args.num_eval) 631 | else: 632 | eval_samples = task.valid_samples 633 | 634 | if args.trainer != "none": 635 | # Here the dev samples are split off from the end of the training samples 636 | if args.num_dev is not None: 637 | # Dev samples 638 | # assert args.num_dev + args.num_train <= len(train_samples), f"num_dev({args.num_dev})+num_train({args.num_train}) is more than actual num of training samples ({len(train_samples)})."
639 | dev_samples = train_samples[-args.num_dev:] 640 | train_samples = train_samples[:-args.num_dev] 641 | logger.info("Dev samples: %d" % len(dev_samples)) 642 | logger.info("Train samples: %d" % len(train_samples)) 643 | else: 644 | dev_samples = None 645 | logger.info("Train samples: %d" % len(train_samples)) 646 | logger.info("No dev samples") 647 | 648 | args.dev_samples = dev_samples 649 | args.eval_samples = eval_samples 650 | 651 | # Training 652 | framework.train(train_samples, dev_samples if dev_samples is not None else eval_samples, eval_samples, writer) 653 | 654 | if not args.no_eval: # True unless --no_eval is passed 655 | metrics = framework.evaluate([], eval_samples, description="Evaluating on the Test Set") 656 | _keys = list(metrics.keys()) 657 | for m in _keys: 658 | metrics["test_" + m] = metrics[m] 659 | if dev_samples is not None: 660 | dev_metrics = framework.evaluate( 661 | [], dev_samples, description="Evaluating on the Validation Set" 662 | ) 663 | _keys = list(dev_metrics.keys()) 664 | for m in _keys: 665 | metrics["val_" + m] = dev_metrics[m] 666 | else: 667 | assert args.num_dev is None 668 | # Zero-shot / in-context learning 669 | metrics = framework.evaluate(train_samples, eval_samples) 670 | logger.info(metrics) 671 | print('metrics: \n\n\n', metrics) 672 | # wandb.log(metrics) 673 | 674 | # for key, value in metrics.items(): 675 | # writer.add_scalar(key, value, global_step) 676 | 677 | if not args.no_eval: 678 | logger.info("===== Train set %d =====" % train_set_seed) 679 | logger.info(metrics) 680 | print('metrics: \n\n\n', metrics) 681 | # wandb.log(metrics) 682 | if args.local_rank <= 0: 683 | write_metrics_to_file(metrics, "result/" + result_file_tag( 684 | args) + f"-trainset{train_set_id}.json" if args.result_file is None else args.result_file) 685 | if args.trainer != "none" and args.clean_model_at_end: 686 | framework.delete_checkpoints() 687 | 688 | else: 689 | # For each eval sample, there is a training set.
No training is allowed. 690 | # This is for in-context learning (ICL) 691 | assert args.trainer == "none" 692 | if args.num_eval is not None: 693 | eval_samples = task.sample_subset(data_split="valid", seed=0, num=args.num_eval) 694 | else: 695 | eval_samples = task.valid_samples 696 | metrics = framework.evaluate(train_sets, eval_samples, one_train_set_per_eval_sample=True) 697 | logger.info(metrics) 698 | # wandb.log(metrics) 699 | if args.local_rank <= 0: 700 | write_metrics_to_file(metrics, "result/" + result_file_tag( 701 | args) + "-onetrainpereval.json" if args.result_file is None else args.result_file) 702 | 703 | writer.close() 704 | 705 | if __name__ == "__main__": 706 | main() 707 | -------------------------------------------------------------------------------- /large_models/tasks.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import sys 3 | from dataclasses import dataclass 4 | from typing import List, Union 5 | 6 | import numpy as np 7 | from datasets import load_dataset 8 | 9 | from templates import * 10 | from utils import temp_seed 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | 16 | def get_task(task_name): 17 | aa = task_name.split("__") 18 | if len(aa) == 2: 19 | task_group, subtask = aa 20 | else: 21 | task_group = aa[0] 22 | subtask = None 23 | class_ = getattr(sys.modules[__name__], f"{task_group}Dataset") 24 | instance = class_(subtask) 25 | return instance 26 | 27 | 28 | @dataclass 29 | class Sample: 30 | id: int = None 31 | data: dict = None 32 | correct_candidate: Union[str, List[str]] = None 33 | candidates: List[str] = None 34 | 35 | 36 | class Dataset: 37 | mixed_set = False 38 | train_sep = "\n\n" 39 | generation = False # whether this is a generation task 40 | 41 | def __init__(self, subtask=None, **kwargs) -> None: 42 | self.samples = None 43 | self.subtask = subtask 44 | 45 | def get_task_name(self): 46 | return self.subtask 47 |
48 | def load_dataset(self, path, **kwargs): 49 | raise NotImplementedError 50 | 51 | def get_template(self, template_version=0): 52 | templates = {0: Template} 53 | return templates[template_version]() 54 | 55 | def build_sample(self, example): 56 | return 57 | 58 | def sample_train_sets(self, num_train=32, num_dev=None, num_eval=None, num_train_sets=None, seed=None): 59 | if seed is not None: 60 | # one train/demo set using the designated seed 61 | seeds = [seed] 62 | elif num_train_sets is not None: 63 | # num_train_sets train/demo sets 64 | seeds = list(range(num_train_sets)) 65 | else: 66 | # one train/demo set per evaluation sample 67 | assert num_dev is None # not supported 68 | len_valid_samples = len(self.samples["valid"]) if num_eval is None else num_eval 69 | with temp_seed(0): 70 | seeds = np.random.randint(0, 10000, len_valid_samples) 71 | 72 | train_samples = [] 73 | for i, set_seed in enumerate(seeds): 74 | if self.mixed_set: # This is always False for now 75 | raise NotImplementedError 76 | train_samples.append(self.sample_subset(data_split="valid", seed=set_seed, num=num_train, exclude=i)) 77 | else: 78 | if num_dev is not None: 79 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed, 80 | num=num_train + num_dev)) # dev set is included at the end of train set 81 | if num_train + num_dev > len(self.samples["train"]): 82 | logger.warning("num_train + num_dev > available training examples") 83 | else: 84 | train_samples.append(self.sample_subset(data_split="train", seed=set_seed, num=num_train)) 85 | if num_dev is not None: 86 | logger.info(f"Sample train set {len(train_samples[-1])}/{len(self.samples['train'])}") 87 | logger.info(f"...
including dev set {num_dev} samples") 88 | return train_samples 89 | 90 | def sample_subset(self, data_split="train", seed=0, num=100, exclude=None): 91 | with temp_seed(seed): 92 | samples = self.samples[data_split] 93 | lens = len(samples) 94 | index = np.random.permutation(lens).tolist()[:num if exclude is None else num + 1] 95 | if exclude is not None and exclude in index: 96 | index.remove(exclude) 97 | else: 98 | index = index[:num] 99 | return [samples[i] for i in index] 100 | 101 | @property 102 | def valid_samples(self): 103 | return self.samples["valid"] 104 | 105 | 106 | class SST2Dataset(Dataset): 107 | train_sep = "\n\n" 108 | 109 | def __init__(self, subtask=None, **kwargs) -> None: 110 | self.load_dataset(subtask, **kwargs) 111 | 112 | def load_dataset(self, path, **kwargs): 113 | d = load_dataset('glue', 'sst2') 114 | train_d = d["train"] 115 | validation_d = d["validation"] 116 | 117 | train_samples = [self.build_sample(example) for example in train_d] 118 | valid_samples = [self.build_sample(example) for example in validation_d] 119 | 120 | self.samples = {"train": train_samples, "valid": valid_samples} 121 | 122 | # for generative tasks, candidates are [] 123 | def build_sample(self, example): 124 | label = int(example["label"]) 125 | # print('example', example) 126 | return Sample(id=example["idx"], data=example, correct_candidate=label, candidates=[0, 1]) 127 | 128 | def get_template(self, template_version=0): 129 | return {0: SST2Template, 1: SST2TemplateEmpty}[template_version]() 130 | 131 | class SST5Dataset(Dataset): 132 | train_sep = "\n\n" 133 | 134 | def __init__(self, subtask=None, **kwargs) -> None: 135 | self.load_dataset(subtask, **kwargs) 136 | 137 | def load_dataset(self, path, **kwargs): 138 | d = load_dataset("SetFit/sst5") 139 | # print(d) 140 | train_d = d["train"] 141 | validation_d = d["validation"] 142 | 143 | train_samples = [self.build_sample(example) for example in train_d] 144 | valid_samples = 
[self.build_sample(example) for example in validation_d] 145 | 146 | self.samples = {"train": train_samples, "valid": valid_samples} 147 | 148 | # for generative tasks, candidates are [] 149 | def build_sample(self, example): 150 | label = int(example["label"]) 151 | # print('example', example) 152 | return Sample(data=example, correct_candidate=label, candidates=[0, 1, 2, 3, 4]) 153 | 154 | def get_template(self, template_version=0): 155 | return {0: SST5Template, 1: SST5TemplateEmpty}[template_version]() 156 | 157 | class CopaDataset(Dataset): 158 | train_sep = "\n\n" 159 | mixed_set = False 160 | 161 | def __init__(self, subtask=None, **kwargs) -> None: 162 | self.load_dataset(subtask, **kwargs) 163 | 164 | def load_dataset(self, path, **kwargs): 165 | train_examples = load_dataset('super_glue', "copa")["train"] 166 | valid_examples = load_dataset('super_glue', "copa")["validation"] 167 | 168 | train_samples = [self.build_sample(example) for example in train_examples] 169 | valid_samples = [self.build_sample(example) for example in valid_examples] 170 | self.samples = {"train": train_samples, "valid": valid_samples} 171 | 172 | # for generative tasks, candidates are [] 173 | def build_sample(self, example): 174 | sample = \ 175 | Sample( 176 | id=example["idx"], 177 | data=example, 178 | candidates=[example["choice1"], example["choice2"]], 179 | correct_candidate=example[f"choice{example['label'] + 1}"], 180 | ) 181 | 182 | return sample 183 | 184 | def get_template(self, template_version=0): 185 | return {0: CopaTemplate, 1: CopaTemplateEmpty}[template_version]() 186 | 187 | 188 | class BoolQDataset(Dataset): 189 | def __init__(self, subtask=None, **kwargs) -> None: 190 | self.load_dataset(subtask, **kwargs) 191 | 192 | def load_dataset(self, path, **kwargs): 193 | d = load_dataset("boolq") 194 | train_set = d["train"] 195 | valid_set = d["validation"] 196 | 197 | train_samples = [self.build_sample(example) for example in train_set] 198 | valid_samples = 
[self.build_sample(example) for example in valid_set] 199 | self.samples = {"train": train_samples, "valid": valid_samples} 200 | 201 | def build_sample(self, example): 202 | # print('example', example) 203 | sample = \ 204 | Sample( 205 | data=example, 206 | candidates=["Yes", "No"], 207 | correct_candidate="Yes" if example["answer"] else "No", 208 | ) 209 | 210 | return sample 211 | 212 | def get_template(self, template_version=2): 213 | return {0: BoolQTemplate, 1: BoolQTemplateV2, 2: BoolQTemplateV3}[template_version]() 214 | 215 | 216 | class MultiRCDataset(Dataset): 217 | 218 | def __init__(self, subtask=None, **kwargs) -> None: 219 | self.load_dataset(subtask, **kwargs) 220 | 221 | def load_dataset(self, path, **kwargs): 222 | d = load_dataset("super_glue", "multirc") 223 | train_set = d["train"] 224 | valid_set = d["validation"] 225 | 226 | train_samples = [self.build_sample(example) for example in train_set] 227 | valid_samples = [self.build_sample(example) for example in valid_set] 228 | self.samples = {"train": train_samples, "valid": valid_samples} 229 | 230 | def build_sample(self, example): 231 | sample = \ 232 | Sample( 233 | data=example, 234 | candidates=[0, 1], 235 | correct_candidate=example['label'] 236 | ) 237 | 238 | return sample 239 | 240 | def get_template(self, template_version=0): 241 | return {0: MultiRCTemplate}[template_version]() 242 | 243 | 244 | class CBDataset(Dataset): 245 | 246 | def __init__(self, subtask=None, **kwargs) -> None: 247 | self.load_dataset(subtask, **kwargs) 248 | 249 | def load_dataset(self, path, **kwargs): 250 | d = load_dataset("super_glue", "cb") 251 | train_set = d["train"] 252 | valid_set = d["validation"] 253 | 254 | train_samples = [self.build_sample(example) for example in train_set] 255 | valid_samples = [self.build_sample(example) for example in valid_set] 256 | self.samples = {"train": train_samples, "valid": valid_samples} 257 | 258 | def build_sample(self, example): 259 | sample = \ 260 | Sample( 261 
| data=example, 262 | candidates=[0, 1, 2], 263 | correct_candidate=example['label'] 264 | ) 265 | 266 | return sample 267 | 268 | def get_template(self, template_version=0): 269 | return {0: CBTemplate}[template_version]() 270 | 271 | 272 | class WICDataset(Dataset): 273 | 274 | def __init__(self, subtask=None, **kwargs) -> None: 275 | self.load_dataset(subtask, **kwargs) 276 | 277 | def load_dataset(self, path, **kwargs): 278 | d = load_dataset("super_glue", "wic") 279 | train_set = d["train"] 280 | valid_set = d["validation"] 281 | 282 | train_samples = [self.build_sample(example) for example in train_set] 283 | valid_samples = [self.build_sample(example) for example in valid_set] 284 | self.samples = {"train": train_samples, "valid": valid_samples} 285 | 286 | def build_sample(self, example): 287 | sample = \ 288 | Sample( 289 | data=example, 290 | candidates=[0, 1], 291 | correct_candidate=example['label'] 292 | ) 293 | 294 | return sample 295 | 296 | def get_template(self, template_version=0): 297 | return {0: WICTemplate}[template_version]() 298 | 299 | 300 | class WSCDataset(Dataset): 301 | 302 | def __init__(self, subtask=None, **kwargs) -> None: 303 | self.load_dataset(subtask, **kwargs) 304 | 305 | def load_dataset(self, path, **kwargs): 306 | d = load_dataset("super_glue", "wsc.fixed") 307 | train_set = d["train"] 308 | valid_set = d["validation"] 309 | 310 | train_samples = [self.build_sample(example) for example in train_set] 311 | valid_samples = [self.build_sample(example) for example in valid_set] 312 | self.samples = {"train": train_samples, "valid": valid_samples} 313 | 314 | def build_sample(self, example): 315 | sample = \ 316 | Sample( 317 | data=example, 318 | candidates=[0, 1], 319 | correct_candidate=example['label'] 320 | ) 321 | 322 | return sample 323 | 324 | def get_template(self, template_version=0): 325 | return {0: WSCTemplate}[template_version]() 326 | 327 | 328 | class ReCoRDDataset(Dataset): 329 | 330 | def __init__(self, 
subtask=None, **kwargs) -> None: 331 | self.load_dataset(subtask, **kwargs) 332 | 333 | def load_dataset(self, path, **kwargs): 334 | d = load_dataset("super_glue", "record") 335 | train_set = d["train"] 336 | valid_set = d["validation"] 337 | 338 | train_samples = [self.build_sample(example) for example in train_set] 339 | valid_samples = [self.build_sample(example) for example in valid_set] 340 | self.samples = {"train": train_samples, "valid": valid_samples} 341 | 342 | def build_sample(self, example): 343 | sample = \ 344 | Sample( 345 | data=example, 346 | candidates=example['entities'], 347 | correct_candidate=example['answers'] 348 | ) 349 | 350 | return sample 351 | 352 | def get_template(self, template_version=0): 353 | return {0: ReCoRDTemplateGPT3}[template_version]() 354 | 355 | 356 | class RTEDataset(Dataset): 357 | 358 | def __init__(self, subtask=None, **kwargs) -> None: 359 | self.load_dataset(subtask, **kwargs) 360 | 361 | def load_dataset(self, path, **kwargs): 362 | d = load_dataset("super_glue", "rte") 363 | train_set = d["train"] 364 | valid_set = d["validation"] 365 | 366 | train_samples = [self.build_sample(example) for example in train_set] 367 | valid_samples = [self.build_sample(example) for example in valid_set] 368 | self.samples = {"train": train_samples, "valid": valid_samples} 369 | 370 | def build_sample(self, example): 371 | sample = \ 372 | Sample( 373 | data=example, 374 | candidates=[0, 1], 375 | correct_candidate=example['label'] 376 | ) 377 | 378 | return sample 379 | 380 | def get_template(self, template_version=0): 381 | return {0: RTETemplate, 1: RTETemplateEmpty}[template_version]() 382 | 383 | 384 | class SQuADDataset(Dataset): 385 | metric_name = "f1" 386 | generation = True 387 | 388 | def __init__(self, subtask=None, **kwargs) -> None: 389 | self.load_dataset() 390 | 391 | def load_dataset(self): 392 | dataset = load_dataset("squad") 393 | train_examples = dataset["train"] 394 | valid_examples = dataset["validation"] 
395 | 396 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)] 397 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)] 398 | self.samples = {"train": train_samples, "valid": valid_samples} 399 | 400 | # for generative tasks, candidates are [] 401 | def build_sample(self, example, idx): 402 | answers = example['answers']['text'] 403 | assert len(answers) > 0 404 | return Sample( 405 | id=idx, 406 | data={ 407 | "title": example['title'], 408 | "context": example['context'], 409 | "question": example['question'], 410 | "answers": answers 411 | }, 412 | candidates=None, 413 | correct_candidate=answers 414 | ) 415 | 416 | def get_template(self, template_version=0): 417 | return {0: SQuADv2Template}[template_version]() 418 | 419 | 420 | class DROPDataset(Dataset): 421 | metric_name = "f1" 422 | generation = True 423 | 424 | def __init__(self, subtask=None, **kwargs) -> None: 425 | self.load_dataset() 426 | 427 | def load_dataset(self): 428 | dataset = load_dataset("drop") 429 | train_examples = dataset["train"] 430 | valid_examples = dataset["validation"] 431 | 432 | train_samples = [self.build_sample(example, idx) for idx, example in enumerate(train_examples)] 433 | valid_samples = [self.build_sample(example, idx) for idx, example in enumerate(valid_examples)] 434 | self.samples = {"train": train_samples, "valid": valid_samples} 435 | 436 | # for generative tasks, candidates are [] 437 | def build_sample(self, example, idx): 438 | answers = example['answers_spans']['spans'] 439 | assert len(answers) > 0 440 | return Sample( 441 | id=idx, 442 | data={ 443 | "context": example['passage'], 444 | "question": example['question'], 445 | "answers": answers 446 | }, 447 | candidates=None, 448 | correct_candidate=answers 449 | ) 450 | 451 | def get_template(self, template_version=0): 452 | return {0: DROPTemplate}[template_version]() 453 | 454 | 455 | class WinoGrandeDataset(Dataset): 456 
| def __init__(self, subtask=None, **kwargs) -> None: 457 | super().__init__(subtask, **kwargs) 458 | self.load_dataset(subtask, **kwargs) 459 | 460 | def load_dataset(self, path, **kwargs): 461 | train_set = load_dataset('winogrande', 'winogrande_m', split='train') 462 | valid_set = load_dataset('winogrande', 'winogrande_m', split='validation') 463 | 464 | train_samples = [self.build_sample(example) for example in train_set] 465 | valid_samples = [self.build_sample(example) for example in valid_set] 466 | self.samples = {"train": train_samples, "valid": valid_samples} 467 | 468 | def build_sample(self, example): 469 | """ 470 | Prompt adapted from https://arxiv.org/pdf/2110.08207.pdf 471 | """ 472 | sentence = example["sentence"] 473 | context, target = sentence.split("_") 474 | sample = Sample( 475 | data=example, 476 | candidates=[example['option1'] + target, example['option2'] + target], 477 | correct_candidate=example[f'option{example["answer"]}'] + target, 478 | ) 479 | return sample 480 | 481 | def get_template(self, template_version=0): 482 | if template_version == 0: 483 | return WinoGrandeTemplate() 484 | else: 485 | raise NotImplementedError(f"Template version {template_version} not implemented for WinoGrande") 486 | -------------------------------------------------------------------------------- /large_models/templates.py: -------------------------------------------------------------------------------- 1 | class Template: 2 | def encode(self, sample): 3 | """ 4 | Return prompted version of the example (without the answer/candidate) 5 | """ 6 | raise NotImplementedError 7 | 8 | def verbalize(self, sample, candidate): 9 | """ 10 | Return the prompted version of the example (with the answer/candidate) 11 | """ 12 | return candidate 13 | 14 | def encode_sfc(self, sample): 15 | """ 16 | Same as encode, but for SFC (calibration) -- this usually means the input is not included 17 | """ 18 | return "" 19 | 20 | def verbalize_sfc(self, sample, candidate): 21 | 
""" 22 | Same as verbalize, but for SFC (calibration) -- this usually means the input is not included 23 | """ 24 | return candidate 25 | 26 | 27 | class SST2Template(Template): 28 | verbalizer = {0: "terrible", 1: "great"} 29 | 30 | def encode(self, sample): 31 | text = sample.data["sentence"].strip() 32 | return f"{text} It was" 33 | 34 | def verbalize(self, sample, candidate): 35 | text = sample.data["sentence"].strip() 36 | return f"{text} It was {self.verbalizer[candidate]}" 37 | 38 | def encode_sfc(self, sample): 39 | return f" It was" 40 | 41 | def verbalize_sfc(self, sample, candidate): 42 | return f" It was {self.verbalizer[candidate]}" 43 | 44 | class SST2TemplateEmpty(Template): 45 | verbalizer = {0: "terrible", 1: "great"} 46 | 47 | def encode(self, sample): 48 | text = sample.data["sentence"].strip() 49 | return f"{text} " 50 | 51 | def verbalize(self, sample, candidate): 52 | text = sample.data["sentence"].strip() 53 | return f"{text} {self.verbalizer[candidate]}" 54 | 55 | def encode_sfc(self, sample): 56 | return f" " 57 | 58 | def verbalize_sfc(self, sample, candidate): 59 | return f" {self.verbalizer[candidate]}" 60 | 61 | 62 | class CopaTemplate(Template): 63 | capitalization: str = "correct" 64 | effect_conj: str = " so " 65 | cause_conj: str = " because " 66 | 67 | def get_conjucture(self, sample): 68 | if sample.data["question"] == "effect": 69 | conjunction = self.effect_conj 70 | elif sample.data["question"] == "cause": 71 | conjunction = self.cause_conj 72 | else: 73 | raise NotImplementedError 74 | return conjunction 75 | 76 | def get_prompt(self, sample): 77 | premise = sample.data["premise"].rstrip() 78 | if premise.endswith("."): # TODO Add other scripts with different punctuation 79 | premise = premise[:-1] 80 | conjunction = self.get_conjucture(sample) 81 | prompt = premise + conjunction 82 | if self.capitalization == "upper": 83 | prompt = prompt.upper() 84 | elif self.capitalization == "lower": 85 | prompt = prompt.lower() 86 | 
        return prompt

    def encode(self, sample):
        prompt = self.get_prompt(sample)
        return prompt

    def capitalize(self, c):
        if self.capitalization == "correct":
            words = c.split(" ")
            if words[0] != "I":
                words[0] = words[0].lower()
            return " ".join(words)
        elif self.capitalization == "bug":
            return c
        elif self.capitalization == "upper":
            return c.upper()
        elif self.capitalization == "lower":
            return c.lower()
        else:
            raise NotImplementedError

    def verbalize(self, sample, candidate):
        prompt = self.get_prompt(sample)
        return prompt + self.capitalize(candidate)

    def encode_sfc(self, sample):
        conjunction = self.get_conjucture(sample)
        return conjunction.strip()

    def verbalize_sfc(self, sample, candidate):
        conjunction = self.get_conjucture(sample)
        sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate)
        return sfc_prompt


class CopaTemplateEmpty(Template):
    capitalization: str = "correct"
    effect_conj: str = " "
    cause_conj: str = " "

    def get_conjucture(self, sample):
        if sample.data["question"] == "effect":
            conjunction = self.effect_conj
        elif sample.data["question"] == "cause":
            conjunction = self.cause_conj
        else:
            raise NotImplementedError
        return conjunction

    def get_prompt(self, sample):
        premise = sample.data["premise"].rstrip()
        if premise.endswith("."):  # TODO Add other scripts with different punctuation
            premise = premise[:-1]
        conjunction = self.get_conjucture(sample)
        prompt = premise + conjunction
        if self.capitalization == "upper":
            prompt = prompt.upper()
        elif self.capitalization == "lower":
            prompt = prompt.lower()
        return prompt

    def encode(self, sample):
        prompt = self.get_prompt(sample)
        return prompt

    def capitalize(self, c):
        if self.capitalization == "correct":
            words = c.split(" ")
            if words[0] != "I":
                words[0] = words[0].lower()
            return " ".join(words)
        elif self.capitalization == "bug":
            return c
        elif self.capitalization == "upper":
            return c.upper()
        elif self.capitalization == "lower":
            return c.lower()
        else:
            raise NotImplementedError

    def verbalize(self, sample, candidate):
        prompt = self.get_prompt(sample)
        return prompt + self.capitalize(candidate)

    def encode_sfc(self, sample):
        conjunction = self.get_conjucture(sample)
        return conjunction.strip()

    def verbalize_sfc(self, sample, candidate):
        conjunction = self.get_conjucture(sample)
        sfc_prompt = conjunction.strip() + " " + self.capitalize(candidate)
        return sfc_prompt


class BoolQTemplate(Template):
    def encode(self, sample):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}"

    def verbalize(self, sample, candidate):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question} {candidate}"

    def encode_sfc(self, sample):
        return ""

    def verbalize_sfc(self, sample, candidate):
        return candidate


class BoolQTemplateV2(Template):
    def encode(self, sample):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}\\n\\n"

    def verbalize(self, sample, candidate):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}\\n\\n{candidate}"

    def encode_sfc(self, sample):
        return ""

    def verbalize_sfc(self, sample, candidate):
        return candidate


class BoolQTemplateV3(Template):
    def encode(self, sample):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}\n"

    def verbalize(self, sample, candidate):
        passage = sample.data["passage"]
        question = sample.data["question"]
        if not question.endswith("?"):
            question = question + "?"
        question = question[0].upper() + question[1:]
        return f"{passage} {question}\n{candidate}"

    def encode_sfc(self, sample):
        return ""

    def verbalize_sfc(self, sample, candidate):
        return candidate


class MultiRCTemplate(Template):
    # From PromptSource 1
    verbalizer = {0: "No", 1: "Yes"}

    def encode(self, sample):
        paragraph = sample.data["paragraph"]
        question = sample.data["question"]
        answer = sample.data["answer"]
        return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n"

    def verbalize(self, sample, candidate):
        paragraph = sample.data["paragraph"]
        question = sample.data["question"]
        answer = sample.data["answer"]
        return f"{paragraph}\nQuestion: {question}\nI found this answer \"{answer}\". Is that correct? Yes or No?\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"


class CBTemplate(Template):
    # From PromptSource 1
    verbalizer = {0: "Yes", 1: "No", 2: "Maybe"}

    def encode(self, sample):
        premise = sample.data["premise"]
        hypothesis = sample.data["hypothesis"]
        return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n"

    def verbalize(self, sample, candidate):
        premise = sample.data["premise"]
        hypothesis = sample.data["hypothesis"]
        return f"Suppose {premise} Can we infer that \"{hypothesis}\"? Yes, No, or Maybe?\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"


class WICTemplate(Template):
    # From PromptSource 1
    verbalizer = {0: "No", 1: "Yes"}

    def encode(self, sample):
        sent1 = sample.data["sentence1"]
        sent2 = sample.data["sentence2"]
        word = sample.data["word"]
        return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n"

    def verbalize(self, sample, candidate):
        sent1 = sample.data["sentence1"]
        sent2 = sample.data["sentence2"]
        word = sample.data["word"]
        return f"Does the word \"{word}\" have the same meaning in these two sentences? Yes, No?\n{sent1}\n{sent2}\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"


class WSCTemplate(Template):
    # From PromptSource 1
    verbalizer = {0: "No", 1: "Yes"}

    def encode(self, sample):
        text = sample.data['text']
        span1 = sample.data['span1_text']
        span2 = sample.data['span2_text']
        return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n"

    def verbalize(self, sample, candidate):
        text = sample.data['text']
        span1 = sample.data['span1_text']
        span2 = sample.data['span2_text']
        return f"{text}\nIn the previous sentence, does the pronoun \"{span2.lower()}\" refer to {span1}? Yes or No?\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"


class ReCoRDTemplate(Template):
    # From PromptSource 1 but modified

    def encode(self, sample):
        passage = sample.data['passage']
        query = sample.data['query']
        return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer:"

    def verbalize(self, sample, candidate):
        passage = sample.data['passage']
        query = sample.data['query']
        return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"

    def encode_sfc(self, sample):
        return f"Answer:"

    def verbalize_sfc(self, sample, candidate):
        return f"Answer: {candidate}"


class ReCoRDTemplateGPT3(Template):
    # From PromptSource 1 but modified

    def encode(self, sample):
        passage = sample.data['passage'].replace("@highlight\n", "- ")
        return f"{passage}\n-"

    def verbalize(self, sample, candidate):
        passage = sample.data['passage'].replace("@highlight\n", "- ")
        query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
        return f"{passage}\n- {query}"

        # passage = sample.data['passage']
        # query = sample.data['query']
        # return f"{passage}\n{query}\nQuestion: what is the \"@placeholder\"\nAnswer: {candidate}"

    def encode_sfc(self, sample):
        return f"-"

    def verbalize_sfc(self, sample, candidate):
        query = sample.data['query'].replace("@placeholder", candidate[0] if isinstance(candidate, list) else candidate)
        return f"- {query}"


class RTETemplate(Template):
    # From PromptSource 1
    verbalizer = {0: "Yes", 1: "No"}

    def encode(self, sample):
        premise = sample.data['premise']
        hypothesis = sample.data['hypothesis']
        return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n"

    def verbalize(self, sample, candidate):
        premise = sample.data['premise']
        hypothesis = sample.data['hypothesis']
        return f"{premise}\nDoes this mean that \"{hypothesis}\" is true? Yes or No?\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"

class RTETemplateEmpty(Template):
    # From PromptSource 1
    verbalizer = {0: "Yes", 1: "No"}

    def encode(self, sample):
        premise = sample.data['premise']
        hypothesis = sample.data['hypothesis']
        return f"{premise}\n\"{hypothesis}\"\n"

    def verbalize(self, sample, candidate):
        premise = sample.data['premise']
        hypothesis = sample.data['hypothesis']
        return f"{premise}\n\"{hypothesis}\"\n{self.verbalizer[candidate]}"

    def encode_sfc(self, sample):
        return f""

    def verbalize_sfc(self, sample, candidate):
        return f"{self.verbalizer[candidate]}"


class SQuADv2Template(Template):

    def encode(self, sample):
        question = sample.data['question'].strip()
        title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0]  # there are multiple answers. for the prompt we only take the first one

        return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer:"

    def verbalize(self, sample, candidate):
        question = sample.data['question'].strip()
        title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0]  # there are multiple answers. for the prompt we only take the first one

        return f"Title: {title}\nContext: {context}\nQuestion: {question}\nAnswer: {answer}\n"

    def encode_sfc(self, sample):
        raise NotImplementedError

    def verbalize_sfc(self, sample, candidate):
        raise NotImplementedError


class DROPTemplate(Template):

    def encode(self, sample):
        question = sample.data['question'].strip()
        # title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0]  # there are multiple answers. for the prompt we only take the first one

        return f"Passage: {context}\nQuestion: {question}\nAnswer:"

    def verbalize(self, sample, candidate):
        question = sample.data['question'].strip()
        # title = sample.data['title']
        context = sample.data['context']
        answer = sample.data['answers'][0]  # there are multiple answers. for the prompt we only take the first one

        return f"Passage: {context}\nQuestion: {question}\nAnswer: {answer}\n"

    def encode_sfc(self, sample):
        raise NotImplementedError

    def verbalize_sfc(self, sample, candidate):
        raise NotImplementedError


class WinoGrandeTemplate(Template):
    @staticmethod
    def get_prompt(sample):
        """
        Prompt adapted from https://arxiv.org/pdf/2110.08207.pdf
        """
        sentence = sample.data["sentence"]
        context, target = sentence.split("_")
        return context

    def encode(self, sample):
        prompt = self.get_prompt(sample)
        return prompt

    def verbalize(self, sample, candidate):
        prompt = self.get_prompt(sample)
        return prompt + candidate

    def encode_sfc(self, sample):
        return ""

    def verbalize_sfc(self, sample, candidate):
        return candidate

--------------------------------------------------------------------------------
/large_models/utils.py:
--------------------------------------------------------------------------------
import contextlib
import json
import logging
import signal
import time
from collections.abc import Mapping
from dataclasses import asdict, dataclass, is_dataclass
from typing import Any, Dict, List, NewType, Optional, Union

import numpy as np
import torch
import torch.nn.functional as F
import transformers
from torch.nn import CrossEntropyLoss
from transformers.data.data_collator import DataCollatorMixin
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy

InputDataClass = NewType("InputDataClass", Any)

logger = logging.getLogger(__name__)


def forward_wrap_with_option_len(
    self,
    input_ids=None,
    labels=None,
    option_len=None,
    num_options=None,
    return_dict=None,
    **kwargs
):
    """
    Replaces the original forward function of Transformer models to enable:
    (1) Partial target sequence: the loss is only calculated on part of the sequence
    (2) Classification-style training: a classification (cross-entropy) loss is calculated over several options
    Input:
    - input_ids, labels: same as the original forward function; when num_options is given, labels instead
      hold the gold option index, repeated once per option row
    - option_len: a 1-D tensor of ints (one entry per sequence) giving the option lengths; the loss is
      calculated only on the last option_len tokens of each sequence
    - num_options: a list of ints with one entry per sequence, giving the number of options of the example
      that sequence belongs to (#label words for classification tasks, #choices for multiple-choice tasks);
      when given, a classification loss is calculated.
45 | """ 46 | outputs = self.original_forward(input_ids=input_ids, **kwargs) 47 | 48 | if labels is None: 49 | return outputs 50 | 51 | # in prompt tuning, we need to remove the virtual tokens from the logits to match the input ids 52 | logits = outputs.logits 53 | 54 | loss = None 55 | # Shift so that tokens < n predict n 56 | shift_logits = logits[..., :-1, :].contiguous() 57 | # Here we use input_ids (which should always = labels) bc sometimes labels are correct candidate IDs 58 | shift_labels = torch.clone(input_ids)[..., 1:].contiguous() 59 | shift_labels[shift_labels == self.config.pad_token_id] = -100 60 | 61 | # Apply option len (do not calculate loss on the non-option part) 62 | # for _i, _len in enumerate(option_len): 63 | # shift_labels[_i, :-_len] = -100 64 | # re-write the above code to avoid the for loop 65 | non_option_len = shift_labels.shape[1] - option_len 66 | mask = torch.arange( 67 | shift_labels.shape[1], device=shift_labels.device 68 | ).expand(shift_labels.shape[0], -1) < non_option_len.unsqueeze(-1) 69 | shift_labels[mask] = -100 70 | 71 | # Calculate the loss 72 | loss_fct = CrossEntropyLoss(ignore_index=-100) 73 | 74 | if num_options is not None: 75 | # Train as a classification tasks 76 | log_probs = F.log_softmax(shift_logits, dim=-1) 77 | mask = shift_labels != -100 # Option part 78 | shift_labels[~mask] = 0 # So that it doesn't mess up with indexing 79 | 80 | selected_log_probs = torch.gather(log_probs, dim=-1, index=shift_labels.unsqueeze(-1)).squeeze( 81 | -1) # (bsz x num_options, len) 82 | selected_log_probs = (selected_log_probs * mask).sum(-1) / mask.sum(-1) # (bsz x num_options) 83 | 84 | if any([x != num_options[0] for x in num_options]): 85 | # Multi choice tasks with different number of options 86 | loss = 0 87 | start_id = 0 88 | count = 0 89 | while start_id < len(num_options): 90 | end_id = start_id + num_options[start_id] 91 | _logits = selected_log_probs[start_id:end_id].unsqueeze(0) # (1, num_options) 92 | _labels = 
labels[start_id:end_id][0].unsqueeze(0) # (1) 93 | loss = loss_fct(_logits, _labels) + loss 94 | count += 1 95 | start_id = end_id 96 | loss = loss / count 97 | else: 98 | num_options = num_options[0] 99 | selected_log_probs = selected_log_probs.view(-1, num_options) # (bsz, num_options) 100 | labels = labels.view(-1, num_options)[:, 0] # Labels repeat so we only take the first one 101 | # print('selected_log_probs', selected_log_probs.shape, selected_log_probs.softmax(dim=1).argmax(dim=1)) 102 | # print('log', selected_log_probs.argmax(dim=1)) 103 | loss = loss_fct(selected_log_probs, labels) 104 | 105 | else: 106 | loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)) 107 | 108 | if not return_dict: 109 | output = (logits,) + outputs[1:] 110 | return (loss,) + output if loss is not None else output 111 | 112 | return CausalLMOutputWithPast( 113 | loss=loss, 114 | logits=logits, 115 | past_key_values=outputs.past_key_values, 116 | hidden_states=outputs.hidden_states, 117 | attentions=outputs.attentions, 118 | ) 119 | 120 | 121 | def encode_prompt(task, template, train_samples, eval_sample, tokenizer, max_length, sfc=False, icl_sfc=False, 122 | generation=False, generation_with_gold=False, max_new_tokens=None): 123 | """ 124 | Encode prompts for eval_sample 125 | Input: 126 | - task, template: task and template class 127 | - train_samples, eval_sample: demonstrations and the actual sample 128 | - tokenizer, max_length: tokenizer and max length 129 | - sfc: generate prompts for calibration (surface form competition; https://arxiv.org/abs/2104.08315) 130 | - icl_sfc: generate prompts for ICL version calibration 131 | - generation: whether it is an generation task 132 | - generation_with_gold: whether to include the generation-task gold answers (for training) 133 | - max_new_tokens: max number of new tokens to generate so that we can save enough space 134 | (only for generation tasks) 135 | Output: 136 | - encodings: a list of N lists 
of tokens. N is the number of options for classification/multiple-choice. 137 | - option_lens: a list of N integers indicating the number of option tokens. 138 | """ 139 | 140 | # Demonstrations for ICL 141 | train_prompts = [template.verbalize(sample, sample.correct_candidate).strip() for sample in train_samples] 142 | train_prompts = task.train_sep.join(train_prompts).strip() 143 | 144 | # sfc or icl_sfc indicates that this example is used for calibration 145 | if sfc or icl_sfc: 146 | encode_fn = template.encode_sfc 147 | verbalize_fn = template.verbalize_sfc 148 | else: 149 | encode_fn = template.encode 150 | verbalize_fn = template.verbalize 151 | 152 | unverbalized_eval_prompt = encode_fn(eval_sample).strip(' ') 153 | if not generation: 154 | # We generate one prompt for each candidate (different classes in classification) 155 | # or different choices in multiple-choice tasks 156 | verbalized_eval_prompts = [verbalize_fn(eval_sample, cand).strip(' ') for cand in eval_sample.candidates] 157 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 158 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for 159 | verbalized_eval_prompt in verbalized_eval_prompts] 160 | 161 | if sfc: 162 | # Without demonstrations 163 | final_prompts = verbalized_eval_prompts 164 | else: 165 | # With demonstrations 166 | final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in 167 | verbalized_eval_prompts] 168 | else: 169 | assert not sfc and not icl_sfc, "Generation tasks do not support SFC" 170 | if generation_with_gold: 171 | verbalized_eval_prompts = [verbalize_fn(eval_sample, eval_sample.correct_candidate)] 172 | unverbalized_eval_prompt_length = len(tokenizer.encode(unverbalized_eval_prompt)) 173 | option_lens = [(len(tokenizer.encode(verbalized_eval_prompt)) - unverbalized_eval_prompt_length) for 174 | verbalized_eval_prompt in verbalized_eval_prompts] 
            final_prompts = [(train_prompts + task.train_sep + eval_prompt).lstrip().strip(' ') for eval_prompt in
                             verbalized_eval_prompts]
        else:
            option_lens = [0]
            final_prompts = [(train_prompts + task.train_sep + unverbalized_eval_prompt).lstrip().strip(' ')]

    # Tokenize
    encodings = [tokenizer.encode(final_prompt) for final_prompt in final_prompts]

    # Truncate (left truncate as demonstrations are less important)
    if generation and max_new_tokens is not None:
        max_length = max_length - max_new_tokens

    if any(len(encoding) > max_length for encoding in encodings):
        logger.warning("Exceeded max length")
        if hasattr(tokenizer, 'add_bos_token') and tokenizer.add_bos_token:
            # Keep the BOS token and drop tokens from the left of the rest
            encodings = [encoding[0:1] + encoding[1:][-(max_length - 1):] for encoding in encodings]
        else:
            encodings = [encoding[-max_length:] for encoding in encodings]

    return encodings, option_lens


@dataclass
class ICLCollator:
    """
    Collator for ICL: right-pads every field to the longest example in the batch
    """
    tokenizer: PreTrainedTokenizerBase

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        if not isinstance(features[0], Mapping):
            features = [vars(f) for f in features]
        first = features[0]
        batch = {}

        pad_id = self.tokenizer.pad_token_id

        pad_ids = {"input_ids": pad_id, "attention_mask": 0, "sfc_input_ids": pad_id, "sfc_attention_mask": 0,
                   "labels": pad_id}
        for key in first:
            pp = pad_ids[key]
            lens = [len(f[key]) for f in features]
            max_len = max(lens)
            feature = np.stack([np.pad(f[key], (0, max_len - lens[i]), "constant", constant_values=(0, pp)) for i, f in
                                enumerate(features)])
            padded_feature = torch.from_numpy(feature).long()
            batch[key] = padded_feature

        return batch


@dataclass
class DataCollatorWithPaddingAndNesting:
    """
    Collator for training: flattens each example's list of option encodings into one batch, then pads it
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pt"

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        features = [ff for f in features for ff in f]
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )
        if "label" in batch:
            batch["labels"] = batch["label"]
            del batch["label"]
        if "label_ids" in batch:
            batch["labels"] = batch["label_ids"]
            del batch["label_ids"]
        return batch


@dataclass
class NondiffCollator(DataCollatorMixin):
    """
    Collator for non-differentiable objectives
    """
    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def torch_call(self, features):
        import torch

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None

        no_labels_features = [{k: v for k, v in feature.items() if k != label_name and k != "gold"} for feature in
                              features]

        batch = self.tokenizer.pad(
            no_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        if labels is None:
            return batch

        sequence_length = batch["input_ids"].shape[1]
        padding_side = self.tokenizer.padding_side

        def to_list(tensor_or_iterable):
            if isinstance(tensor_or_iterable, torch.Tensor):
                return tensor_or_iterable.tolist()
            return list(tensor_or_iterable)

        if padding_side == "right":
            batch[label_name] = [
                to_list(label) + [self.label_pad_token_id] * (sequence_length - len(label)) for label in labels
            ]
        else:
            batch[label_name] = [
                [self.label_pad_token_id] * (sequence_length - len(label)) + to_list(label) for label in labels
            ]

        batch[label_name] = torch.tensor(batch[label_name], dtype=torch.int64)
        if "gold" in features[0]:
            batch["gold"] = [feature["gold"] for feature in features]

        return batch


class SIGUSR1Callback(transformers.TrainerCallback):
    """
    This callback saves the model and stops training when a SIGUSR1 signal (e.g. a SLURM stop signal)
    or a SIGINT signal (keyboard interrupt) is received.
    """

    def __init__(self) -> None:
        super().__init__()
        self.signal_received = False
        signal.signal(signal.SIGUSR1, self.handle_signal)
        signal.signal(signal.SIGINT, self.handle_signal)
        logger.warning("Handler registered")

    def handle_signal(self, signum, frame):
        self.signal_received = True
        logger.warning("Signal received")

    def on_step_end(self, args, state, control, **kwargs):
        if self.signal_received:
            control.should_save = True
            control.should_training_stop = True

    def on_train_end(self, args, state, control, **kwargs):
        if self.signal_received:
            raise SystemExit(0)


@dataclass
class Prediction:
    correct_candidate: Union[int, str]
    predicted_candidate: Union[int, str]


@contextlib.contextmanager
def count_time(name):
    logger.info("%s...", name)
    start_time = time.time()
    try:
        yield
    finally:
        logger.info("Done with %.2fs", time.time() - start_time)


@contextlib.contextmanager
def temp_seed(seed):
    state = np.random.get_state()
    np.random.seed(seed)
    try:
        yield
    finally:
        np.random.set_state(state)


class EnhancedJSONEncoder(json.JSONEncoder):
    def default(self, o):
        if is_dataclass(o):
            return asdict(o)
        return super().default(o)


def write_predictions_to_file(final_preds, output):
    with open(output, "w") as f:
        for pred in final_preds:
            f.write(json.dumps(pred, cls=EnhancedJSONEncoder) + "\n")


def write_metrics_to_file(metrics, output):
    with open(output, "w") as f:
        json.dump(metrics, f, cls=EnhancedJSONEncoder, indent=4)

--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
accelerate==0.25.0
aiohttp==3.9.1
aiosignal==1.3.1
appdirs==1.4.4
async-timeout==4.0.3
attrs==23.1.0
certifi==2023.11.17
charset-normalizer==3.3.2
click==8.1.7
datasets==2.16.0
dill==0.3.7
docker-pycreds==0.4.0
filelock==3.13.1
fsspec==2023.10.0
gitdb==4.0.11
GitPython==3.1.40
gmpy2==2.1.2
huggingface-hub==0.20.1
idna==3.6
Jinja2==3.1.2
joblib==1.3.2
llvmlite==0.41.1
MarkupSafe==2.1.3
mpmath==1.3.0
multidict==6.0.4
multiprocess==0.70.15
networkx
numba==0.58.1
numpy
packaging==23.2
pandas
pip==23.3.2
protobuf==4.25.1
psutil==5.9.7
pyarrow==14.0.2
pyarrow-hotfix==0.6
python-dateutil==2.8.2
pytz==2023.3.post1
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
safetensors==0.4.1
scikit-learn==1.3.2
sentry-sdk==1.39.1
setproctitle==1.3.3
setuptools==68.2.2
six==1.16.0
smmap==5.0.1
sympy==1.12
threadpoolctl==3.2.0
tokenizers==0.13.3
torch==2.1.0
tqdm==4.66.1
transformers==4.28.1
typing_extensions==4.9.0
tzdata==2023.3
urllib3==2.1.0
wandb==0.16.1
wheel==0.42.0
xxhash==3.4.1
yarl==1.9.4

--------------------------------------------------------------------------------
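The option-length masking and classification-style loss in `forward_wrap_with_option_len` can be illustrated without tensors. What follows is a minimal, dependency-free sketch, not the repository's implementation. It assumes the label rows are already shifted by one position and contain no padding, and that the per-token log-probabilities of the target tokens have already been gathered. The helper names `mask_non_option`, `option_scores`, and `classification_loss` are hypothetical.

```python
import math
from typing import List

IGNORE_INDEX = -100  # same sentinel as CrossEntropyLoss(ignore_index=-100) in the wrapper


def mask_non_option(shift_labels: List[List[int]], option_len: List[int]) -> List[List[int]]:
    """Set every label before the last option_len[i] positions of row i to IGNORE_INDEX."""
    masked = []
    for row, n_opt in zip(shift_labels, option_len):
        non_option_len = len(row) - n_opt
        masked.append([IGNORE_INDEX if j < non_option_len else tok for j, tok in enumerate(row)])
    return masked


def option_scores(token_log_probs: List[List[float]], masked_labels: List[List[int]]) -> List[float]:
    """Mean log-probability of the option tokens in each row (one score per candidate)."""
    scores = []
    for lp_row, label_row in zip(token_log_probs, masked_labels):
        kept = [lp for lp, lab in zip(lp_row, label_row) if lab != IGNORE_INDEX]
        scores.append(sum(kept) / len(kept))
    return scores


def classification_loss(scores: List[float], gold: int) -> float:
    """Cross-entropy over candidates, treating the mean log-probs as logits."""
    m = max(scores)  # subtract the max for a numerically stable log-sum-exp
    log_z = m + math.log(sum(math.exp(s - m) for s in scores))
    return log_z - scores[gold]


# Two candidate rows for one example, each ending in a one-token option (toy token ids).
labels = [[5, 6, 7, 8], [5, 6, 7, 9]]
masked = mask_non_option(labels, option_len=[1, 1])
print(masked)  # [[-100, -100, -100, 8], [-100, -100, -100, 9]]
scores = option_scores([[-1.0, -2.0, -0.5, -3.0], [-1.0, -2.0, -0.5, -0.2]], masked)
print(classification_loss(scores, gold=1))
```

Averaging over the option tokens rather than summing keeps candidates of different token lengths comparable, which mirrors the `(selected_log_probs * mask).sum(-1) / mask.sum(-1)` step in the wrapper. The sketch leaves out the padding step, where the wrapper also sets positions equal to `pad_token_id` to -100.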