├── .gitattributes ├── .gitignore ├── DATA_LICENSE ├── LICENSE ├── README.md ├── arguments.py ├── configs └── default_offload_opt_param.json ├── data └── README.md ├── figures └── apo_framework_v.png ├── model.py ├── requirements.txt ├── reward_datasets.py ├── tools ├── apo_data_converter.py ├── convert_apo_data.sh ├── inference_llm.py ├── llm_response_gen.sh └── rejection_sampling.py ├── train.py ├── trainer.py └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.json filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | *.*~ 6 | *tmp.py 7 | *.bak 8 | 9 | # wandb 10 | wandb/ 11 | 12 | # C extensions 13 | *.so 14 | 15 | # Distribution / packaging 16 | .Python 17 | build/ 18 | develop-eggs/ 19 | dist/ 20 | downloads/ 21 | eggs/ 22 | .eggs/ 23 | lib/ 24 | lib64/ 25 | parts/ 26 | sdist/ 27 | var/ 28 | wheels/ 29 | pip-wheel-metadata/ 30 | share/python-wheels/ 31 | *.egg-info/ 32 | .installed.cfg 33 | *.egg 34 | MANIFEST 35 | 36 | # PyInstaller 37 | # Usually these files are written by a python script from a template 38 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 39 | *.manifest 40 | *.spec 41 | 42 | # Installer logs 43 | pip-log.txt 44 | pip-delete-this-directory.txt 45 | 46 | # Unit test / coverage reports 47 | htmlcov/ 48 | .tox/ 49 | .nox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | *.py,cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | db.sqlite3-journal 69 | 70 | # Flask stuff: 71 | instance/ 72 | .webassets-cache 73 | 74 | # Scrapy stuff: 75 | .scrapy 76 | 77 | # Sphinx documentation 78 | docs/_build/ 79 | 80 | # PyBuilder 81 | target/ 82 | 83 | # Jupyter Notebook 84 | .ipynb_checkpoints 85 | 86 | # IPython 87 | profile_default/ 88 | ipython_config.py 89 | 90 | # pyenv 91 | .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 101 | __pypackages__/ 102 | 103 | # Celery stuff 104 | celerybeat-schedule 105 | celerybeat.pid 106 | 107 | # SageMath parsed files 108 | *.sage.py 109 | 110 | # Environments 111 | .env 112 | .venv 113 | env/ 114 | venv/ 115 | ENV/ 116 | env.bak/ 117 | venv.bak/ 118 | 119 | # Spyder project settings 120 | .spyderproject 121 | .spyproject 122 | 123 | # Rope project settings 124 | .ropeproject 125 | 126 | # mkdocs documentation 127 | /site 128 | 129 | # mypy 130 | .mypy_cache/ 131 | .dmypy.json 132 | dmypy.json 133 | 134 | # Pyre type checker 135 | .pyre/ 136 | 137 | .DS_Store 138 | .idea 139 | -------------------------------------------------------------------------------- /DATA_LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. 
A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. 
NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. 
You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. 
If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 
341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 
408 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adversarial Preference Optimization 2 | 3 | [![Code License](https://img.shields.io/badge/Code%20License-Apache_2.0-green.svg)](https://github.com/Linear95/APO/blob/main/LICENSE) 4 | [![Data License](https://img.shields.io/badge/Data%20License-CC%20By%20NC%204.0-red.svg)](https://github.com/Linear95/APO/blob/main/DATA_LICENSE) 5 | [![Python 3.8+](https://img.shields.io/badge/python-3.8+-blue.svg)](https://www.python.org/downloads/release/python-380/) 6 | 7 | This repo contains the implementation of the ACL 2024 paper: 8 | - [Adversarial Preference Optimization: Enhancing Your Alignment via RM-LLM Game](https://arxiv.org/abs/2311.08045). 9 | 10 | In Adversarial Preference Optimization (APO), we let the reward model (RM) and LLM agent play a min-max game, through which both models can be further enhanced without additional preference annotation. 11 | 12 |

13 | ![APO framework overview](figures/apo_framework_v.png) 14 |

15 | 16 | For an overview, the repo contains: 17 | - [Split Helpful\&Harmless](https://drive.google.com/drive/folders/1v0xNMMOfL9lfFLzTGCerZCPNPJrR9ZLX?usp=sharing) (HH) dataset 18 | - [GPT-4 responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) as golden annotations on the HH-RM training set 19 | - The base RM, testing RM, and APO RM training \& scoring pipelines 20 | - The LLM response generation [pipeline](https://github.com/Linear95/APO/blob/main/tools/llm_response_gen.sh) 21 | 22 | 23 | ## Environment 24 | We use `Python 3.8` with the dependencies listed in `requirements.txt`. To build the appropriate environment, use the following command: 25 | ``` 26 | pip3 install -r requirements.txt 27 | ``` 28 | 29 | ## Data \& Annotation 30 | 31 | To update the RM and the LLM separately, we split the cleaned [Helpful\&Harmless](https://github.com/Linear95/DSP/tree/main/data) (HH) dataset into an RM training set and an LLM training set. 32 | | Data Type| HH-RM Train Set | HH-LLM Train Set| HH Test Set| 33 | | --------:| :----------|:-------| :--------| 34 | | Preference Pairs | [RM training set](https://drive.google.com/file/d/12DefElb3DazIPeaIEwd0B_9La84Slc7f/view?usp=sharing) | [RM validation set](https://drive.google.com/file/d/1ZqTuupFxrK2m3_E6ezMRcdT_4k6zX-IW/view?usp=sharing) (sampled 10K pairs) | [RM testing set](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=sharing)| 35 | | Golden Answers | [APO positive responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) | | | 36 | | LLM Samples | APO negative responses ([`alpaca_rm_samples`](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing)) | LLM alignment samples ([`alpaca_llm_samples`](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing))| [LLM testing queries](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=drive_link)| 37 | 38 | 39 | On both the HH-RM and HH-LLM training sets, we infer four LLM responses for each query, released as [`alpaca_rm_samples`](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing) and [`alpaca_llm_samples`](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing). `alpaca_rm_samples` is combined with the golden responses on the HH-RM set to form the APO RM training pairs. `alpaca_llm_samples` is further scored by RMs and used for LLM alignment. To generate the LLM responses yourself, run the command: 40 | ```bash 41 | bash tools/llm_response_gen.sh 42 | ``` 43 | 44 | 45 | 46 | ## RM Training 47 | 48 | ### Base RM Training 49 | 50 | We build our RM on the pretrained LLaMA-7B ([`decapoda-research/llama-7b-hf`](https://huggingface.co/decapoda-research/llama-7b-hf)).
To train the base RM for rejection sampling, use the following command: 51 | 52 | ```bash 53 | REPO_DIR=<path_to_this_repo> 54 | DATA_DIR=${REPO_DIR}/data/hh-split 55 | TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json" 56 | TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json\ 57 | ${DATA_DIR}/eval_data/hh_split_llm.valid.json" 58 | 59 | NUM_GPUS=8 60 | BATCH_SIZE=64 61 | MICRO_BATCH_SIZE=1 62 | LEARNING_RATE=1e-6 63 | GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE)) 64 | 65 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \ 66 | --task_type hh_split \ 67 | --do_train True \ 68 | --eval_at_start False \ 69 | --model_type reward \ 70 | --model_name_or_path "decapoda-research/llama-7b-hf" \ 71 | --data_type "comparison_pair" \ 72 | --train_data_path ${TRAIN_DATA_LIST} \ 73 | --eval_data_path ${TEST_DATA_LIST} \ 74 | --rm_calibration True \ 75 | --data_suffix rm_base \ 76 | --add_sep_token True \ 77 | --remove_unused_columns false \ 78 | --output_dir <path_to_save_checkpoints> \ 79 | --num_train_epochs 1 \ 80 | --per_device_train_batch_size ${MICRO_BATCH_SIZE} \ 81 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \ 82 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \ 83 | --evaluation_strategy steps \ 84 | --padding_side right \ 85 | --truncation_side left \ 86 | --pooling_type last \ 87 | --max_length 512 \ 88 | --save_strategy steps \ 89 | --learning_rate ${LEARNING_RATE} \ 90 | --warmup_steps 100 \ 91 | --deepspeed configs/default_offload_opt_param.json \ 92 | --tf32 false --fp16 false 93 | ``` 94 | 95 | We also train a testing RM to automatically evaluate the LLM response quality on the testing queries. To train the testing RM, change `TRAIN_DATA_LIST=${DATA_DIR}/hh_cleaned_origin.train.json` in the above command, so that the RM learns from all the HH training comparisons. 96 | 97 | The RM training data files (values in `TRAIN_DATA_LIST`) are lists of dictionaries, where each dictionary is an RM training item (`--data_type="comparison_pair"`) with the following keys (see the example item below): 98 | - `text`: a list of query-response strings, in which each query and its response are joined by the special separator token (`SEP_TOKEN`, defined in `utils.py`). 99 | - `scores`: a list of float numbers, representing the preference scores of the corresponding query-response texts. 100 | - `query_id`: a unique ID for the RM training item.
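For illustration, a hypothetical `comparison_pair` item could look as follows. The query, responses, scores, and ID are all made up, and `<sep>` stands in for the separator token inserted when `--add_sep_token True`; only the structure follows the key descriptions above:

```json
{
  "query_id": "hh_rm_train_00001",
  "text": [
    "\n\nHuman: How do I bake bread?\n\nAssistant:<sep>Mix flour, water, yeast, and salt, let the dough rise, then bake it in a hot oven.",
    "\n\nHuman: How do I bake bread?\n\nAssistant:<sep>I am not sure."
  ],
  "scores": [1.0, 0.0]
}
```

Higher scores mark preferred responses; `prepare_data_item` in `reward_datasets.py` min-max normalizes each item's scores into the [0, 1] range before training.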
101 | 102 | 103 | 104 | ### APO RM Training 105 | 106 | To train the APO RM, first merge the LLM samples and golden annotations into APO comparison pairs: 107 | ```bash 108 | REPO_DIR=<path_to_this_repo> 109 | DATA_DIR="${REPO_DIR}/data/hh-split" 110 | 111 | python3 ${REPO_DIR}/tools/apo_data_converter.py \ 112 | --golden_data_path ${DATA_DIR}/rm_data/hh_split_rm.golden.json \ 113 | --sample_data_path ${DATA_DIR}/rm_data/hh_split_rm_alpaca_v0.sample.json \ 114 | --output_dir ${DATA_DIR}/apo_data \ 115 | --apo_data_name "rm_apo_data_v0" 116 | ``` 117 | 118 | Then use the following command to conduct APO RM finetuning: 119 | ```bash 120 | REPO_DIR=<path_to_this_repo> 121 | DATA_DIR=${REPO_DIR}/data/hh-split 122 | TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json \ 123 | ${DATA_DIR}/apo_data/rm_apo_data_v0_text_scores.json" 124 | NUM_APO_SAMPLES=4 125 | 126 | TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json \ 127 | ${DATA_DIR}/eval_data/hh_split_llm.valid.json" 128 | 129 | NUM_GPUS=8 130 | BATCH_SIZE=64 131 | MICRO_BATCH_SIZE=1 132 | LEARNING_RATE=1e-6 133 | APO_COEFF=0.1 134 | GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE)) 135 | 136 | 137 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \ 138 | --task_type apo \ 139 | --do_train True \ 140 | --eval_at_start False \ 141 | --model_type reward \ 142 | --model_name_or_path "decapoda-research/llama-7b-hf" \ 143 | --data_type "comparison_pair" \ 144 | --train_data_path ${TRAIN_DATA_LIST} \ 145 | --eval_data_path ${TEST_DATA_LIST} \ 146 | --rm_calibration True \ 147 | --data_suffix rm_apo_v1 \ 148 | --add_sep_token True \ 149 | --remove_unused_columns false \ 150 | --output_dir <path_to_save_checkpoints> \ 151 | --num_train_epochs 1 \ 152 | --apo_loss_coeff ${APO_COEFF} \ 153 | --apo_sample_num ${NUM_APO_SAMPLES} \ 154 | --per_device_train_batch_size ${MICRO_BATCH_SIZE} \ 155 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \ 156 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \ 157 | --evaluation_strategy steps \ 158 | --padding_side right \ 159 | --truncation_side left \ 160 | --pooling_type last \ 161 | --max_length 512 \ 162 | --save_strategy steps \ 163 | --save_total_limit 10 \ 164 | --learning_rate ${LEARNING_RATE} \ 165 | --warmup_steps 100 \ 166 | --deepspeed configs/default_offload_opt_param.json \ 167 | --tf32 false --fp16 false 168 | ``` 169 | ## RM Scoring 170 | 171 | After finishing the RM training, we can use the following command to score new LLM samples: 172 | ```bash 173 | REPO_DIR=<path_to_this_repo> 174 | DATA_DIR=${REPO_DIR}/data/hh-split/llm_data 175 | DATA_PATH="${DATA_DIR}/hh_split_llm_alpaca_v0.sample.json" 176 | 177 | MODEL_PATH=<path_to_trained_rm_checkpoint> 178 | MODEL_NAME="base_rm" # or "apo_rm" 179 | 180 | NUM_GPUS=8 181 | MICRO_BATCH_SIZE=16 182 | 183 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \ 184 | --task_type inference \ 185 | --do_train False \ 186 | --eval_at_start True \ 187 | --model_type reward \ 188 | --model_name_or_path ${MODEL_PATH} \ 189 | --data_type "reject_sample" \ 190 | --eval_data_path ${DATA_PATH} \ 191 | --rm_calibration False \ 192 | --data_suffix ${MODEL_NAME} \ 193 | --add_sep_token True \ 194 | --remove_unused_columns false \ 195 | --output_dir <path_to_output_dir> \ 196 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \ 197 | --evaluation_strategy steps \ 198 | --padding_side right \ 199 | --truncation_side left \ 200 | --pooling_type last \ 201 | --max_length 512 \ 202 | --deepspeed configs/default_offload_opt_param.json \ 203 | --tf32 false --fp16 false 204 | 205 | 206 | # rejection sampling
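# (tools/rejection_sampling.py matches the RM scores in ${SCORE_PATH} back to the
#  responses in ${DATA_PATH} and keeps the best-scored response per query, i.e. a
#  best-of-n selection; this summary is inferred from the arguments below, so
#  check the script itself for the exact selection rule)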
207 | SCORE_PATH=${DATA_PATH}_pred_${MODEL_NAME}_results.json 208 | OUTPUT_FILE_NAME=${DATA_PATH}_rjs_${MODEL_NAME}.json 209 | 210 | python3 ${REPO_DIR}/tools/rejection_sampling.py \ 211 | --data_path ${DATA_DIR} \ 212 | --score_path ${SCORE_PATH} \ 213 | --output_dir ${DATA_DIR} \ 214 | --rm_scorer ${MODEL_NAME} \ 215 | --output_file_name ${OUTPUT_FILE_NAME} 216 | 217 | # remove tmp inference files 218 | rm ${DATA_DIR}/*rank*.jsonl 219 | ``` 220 | After the inference and rejection-sampling steps, we obtain the selected-response file `${DATA_PATH}_rjs_${MODEL_NAME}.json`. Then we can update the Alpaca model with the training pipeline [here](https://github.com/tatsu-lab/stanford_alpaca). 221 | 222 | 223 | ## Citation 224 | ``` 225 | @inproceedings{cheng2024adversarial, 226 | title={Adversarial Preference Optimization: Enhancing Your Alignment via RM-LLM Game}, 227 | author={Cheng, Pengyu and Yang, Yifan and Li, Jian and Dai, Yong and Hu, Tianhao and Cao, Peixin and Du, Nan and Li, Xiaolong}, 228 | booktitle={Findings of the Association for Computational Linguistics}, 229 | year={2024} 230 | } 231 | ``` 232 | -------------------------------------------------------------------------------- /arguments.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple, Union 2 | 3 | from dataclasses import dataclass, field 4 | from transformers import TrainingArguments 5 | 6 | @dataclass 7 | class CustomTrainingArguments(TrainingArguments): 8 | # experiment setups 9 | reward_domain: str = field( 10 | default="normal", 11 | metadata={"help": "the domain for reward model training."} 12 | ) 13 | # tokenizer params 14 | padding_side: str = field( 15 | default="right", 16 | metadata={"help": "the direction for tokenizer to add padding tokens."} 17 | ) 18 | 19 | truncation_side: str = field( 20 | default="left", 21 | metadata={"help": "the direction for tokenizer to truncate long inputs."} 22 | ) 23 | 24 | add_sep_token: bool = field( 25 | default=False, 26 | metadata={"help": "whether to add a separator token between query and response."} 27 | ) 28 | 29 | tokenizer_path: str = field( 30 | default="llama-7b-hf", 31 | metadata={"help": "the path to load pretrained tokenizer."} 32 | ) 33 | 34 | 35 | # model params 36 | model_type: str = field( 37 | default="llama", 38 | metadata={"help": "the base model type for reward model, selected from [llama, bert]."} 39 | ) 40 | 41 | model_prefix: str = field( 42 | default="llama", 43 | metadata={"help": "the base model type for reward model, selected from [llama, bert]."} 44 | ) 45 | 46 | 47 | pooling_type: str = field( 48 | default="average", 49 | metadata={"help": "the pooling method for reward model, selected from [average, max, last]."} 50 | ) 51 | 52 | model_name_or_path: str = field( 53 | default="llama-7b-hf", 54 | metadata={"help": "the path to load pretrained model."} 55 | ) 56 | 57 | 58 | # data params 59 | 60 | apo_sample_num: int = field( 61 | default=1, 62 | metadata={"help": "the maximum number of responses per data item"} 63 | ) 64 | 65 | 66 | data_dir: str = field( 67 | default="path/to/cleaned_data", 68 | metadata={"help": "the directory to load data."} 69 | ) 70 | 71 | data_type: str = field( 72 | default="no_type", 73 | metadata={"help": "the type of data."} 74 | ) 75 | data_path: str = field( 76 | default="yahma/alpaca-cleaned", 77 | metadata={"help": "the path to load data."} 78 | ) 79 | 80 | train_data_path: List[str] = field( 81 | default_factory=lambda: ["/data/to/train/dataset"], 82 | metadata={"help": "the paths to
the training datasets."} 83 | ) 84 | 85 | 86 | eval_data_path: List[str] = field( 87 | default_factory=lambda: ["/data/to/eval/dataset"], 88 | metadata={"help": "the paths to the evaluation datasets."} 89 | ) 90 | 91 | 92 | data_prefix: str = field( 93 | default="yahma/alpaca-cleaned", 94 | metadata={"help": "the prefix to load train and test data."} 95 | ) 96 | 97 | data_suffix: str = field( 98 | default="yahma/alpaca-cleaned", 99 | metadata={"help": "the suffix to save inference data."} 100 | ) 101 | 102 | 103 | format_mode: str = field( 104 | default="lab_mode", 105 | metadata={"help": "the format to process data"} 106 | ) 107 | 108 | 109 | # training hyperparams 110 | task_type: str = field( 111 | default="training", 112 | metadata={"help": "the task type"} 113 | ) 114 | 115 | 116 | eval_at_start: bool = field( 117 | default=False, 118 | metadata={"help": "whether to run evaluation at the start of training."} 119 | ) 120 | 121 | debug_mode: bool = field( 122 | default=False, 123 | metadata={"help": "whether to use debug mode."} 124 | ) 125 | 126 | cache_dir: Optional[str] = field(default=None) 127 | 128 | optim: str = field(default="adamw_torch", metadata={"help": "the optimizer to use"}) 129 | 130 | apo_loss_type: str = field(default="ranking", metadata={"help": "use `ranking` or `diff` loss for apo"}) 131 | 132 | apo_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for apo loss."}) 133 | 134 | lm_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for language modeling loss."}) 135 | 136 | rm_kl_coeff: float = field(default=1., metadata={"help": "the coefficient for apo rm kl regularizer."}) 137 | 138 | contrast_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for contrastive learning loss."}) 139 | 140 | lm_score_thresh: float = field(default=0.85, metadata={"help": "the threshold to select responses for language modeling"}) 141 | 142 | max_length: int = field( 143 | default=256, 144 | metadata={"help": "the max input sequence length."} 145 | ) 146 | 147 | batch_size: int = field( 148 | default=256, 149 | metadata={"help": "the overall training batch size"} 150 | ) 151 | 152 | micro_batch_size: int = field( 153 | default=32, 154 | metadata={"help": "the batch size on each device, equivalent to `per_gpu_train_batch_size`"} 155 | ) 156 | 157 | 158 | valid_data_size: int = field( 159 | default=0, 160 | metadata={"help": "the data size for validation data"} 161 | ) 162 | 163 | resume_from_checkpoint: Optional[str] = field( 164 | default=None, 165 | metadata={"help": "either training checkpoint or final adapter"} 166 | ) 167 | # generation parameters: 168 | max_new_tokens: int = field( 169 | default=256, 170 | metadata={"help": "the max number of new tokens to generate."} 171 | ) 172 | 173 | # evaluation parameters: 174 | rm_calibration: bool = field( 175 | default=False, 176 | metadata={"help": "whether to evaluate the calibration score for RM"} 177 | ) 178 | 179 | calibration_bins: List[int] = field( 180 | default_factory=lambda: [10], 181 | metadata={"help": "number of bins for RM calibration"} 182 | ) 183 | 184 | 185 | save_calibration: bool = field( 186 | default=False, 187 | metadata={"help": "whether to save the calibration results for RM"} 188 | ) 189 | 190 | -------------------------------------------------------------------------------- /configs/default_offload_opt_param.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr":
"auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | "scheduler": { 15 | "type": "WarmupDecayLR", 16 | "params": { 17 | "total_num_steps": "auto", 18 | "warmup_min_lr": "auto", 19 | "warmup_max_lr": "auto", 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | "zero_optimization": { 24 | "stage": 3, 25 | "offload_optimizer": { 26 | "device": "cpu", 27 | "pin_memory": true 28 | }, 29 | "offload_param": { 30 | "device": "cpu", 31 | "pin_memory": true 32 | }, 33 | "overlap_comm": true, 34 | "contiguous_gradients": true, 35 | "sub_group_size": 1e9, 36 | "reduce_bucket_size": "auto", 37 | "stage3_prefetch_bucket_size": "auto", 38 | "stage3_param_persistence_threshold": "auto", 39 | "stage3_max_live_parameters": 1e9, 40 | "stage3_max_reuse_distance": 1e9, 41 | "stage3_gather_16bit_weights_on_model_save": true 42 | }, 43 | "gradient_accumulation_steps": "auto", 44 | "gradient_clipping": "auto", 45 | "steps_per_print": 5, 46 | "train_batch_size": "auto", 47 | "train_micro_batch_size_per_gpu": "auto", 48 | "wall_clock_breakdown": false 49 | } 50 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # APO Training & Evaluating Data 2 | 3 | we have move the data to [Google Drive](https://drive.google.com/drive/folders/1v0xNMMOfL9lfFLzTGCerZCPNPJrR9ZLX) due to the GitHub LFS storage limitation. 4 | 5 | The data separation can be found below: 6 | 7 | 8 | | Data Type| HH-RM Train Set | HH-LLM Train Set| HH Test Set| 9 | | --------:| :----------|:-------| :--------| 10 | | Preference Pairs | [RM training set](https://drive.google.com/file/d/12DefElb3DazIPeaIEwd0B_9La84Slc7f/view?usp=sharing) | [RM validation set](https://drive.google.com/file/d/1ZqTuupFxrK2m3_E6ezMRcdT_4k6zX-IW/view?usp=sharing) | [RM testing set](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=sharing)| 11 | | Golden Answers | [APO positive responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) | - | -| 12 | |User Queries | [APO negative responses](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing) (Alpaca samples)| [LLM (Alpaca) rejection samples](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing)| [LLM testing Queries](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=drive_link)| 13 | -------------------------------------------------------------------------------- /figures/apo_framework_v.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Linear95/APO/4282775cf9f7dcfe04ed014835bb9d07cae5fbae/figures/apo_framework_v.png -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import List, Optional, Tuple, Union 3 | from pprint import pprint 4 | 5 | import torch 6 | import torch.utils.checkpoint 7 | from torch import nn 8 | 9 | from transformers.modeling_outputs import SequenceClassifierOutputWithPast 10 | from transformers import LlamaModel, LlamaForCausalLM, LlamaPreTrainedModel, LlamaTokenizer 11 | from transformers import BertModel, BertPreTrainedModel 12 | 13 | 14 | class LlamaRewardModel(LlamaPreTrainedModel): 15 | def __init__(self, config): 16 | super().__init__(config) 17 | self.model = 
LlamaModel(config) 18 | self.reward_head = nn.Linear(config.hidden_size, 1, bias=False) 19 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) 20 | self.post_init() 21 | 22 | def get_input_embeddings(self): 23 | return self.model.embed_tokens 24 | 25 | def set_input_embeddings(self, value): 26 | self.model.embed_tokens = value 27 | 28 | def floating_point_ops(self, inputs): 29 | return 0 30 | 31 | def forward( 32 | self, 33 | input_ids: torch.LongTensor = None, 34 | attention_mask: Optional[torch.Tensor] = None, 35 | position_ids: Optional[torch.LongTensor] = None, 36 | past_key_values: Optional[List[torch.FloatTensor]] = None, 37 | inputs_embeds: Optional[torch.FloatTensor] = None, 38 | labels: Optional[torch.LongTensor] = None, 39 | pooling_type: str = "average", 40 | padding_side: str = "right", 41 | use_cache: Optional[bool] = None, 42 | output_attentions: Optional[bool] = None, 43 | output_hidden_states: Optional[bool] = None, 44 | return_dict: Optional[bool] = None, 45 | ): 46 | r""" 47 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 48 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., 49 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If 50 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy). 51 | """ 52 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 53 | 54 | transformer_outputs = self.model( 55 | input_ids, 56 | attention_mask=attention_mask, 57 | position_ids=position_ids, 58 | past_key_values=past_key_values, 59 | inputs_embeds=inputs_embeds, 60 | use_cache=use_cache, 61 | output_attentions=output_attentions, 62 | output_hidden_states=output_hidden_states, 63 | return_dict=return_dict, 64 | ) 65 | hidden_states = transformer_outputs[0] 66 | 67 | lm_logits = self.lm_head(hidden_states) 68 | 69 | 70 | if input_ids is not None: 71 | batch_size = input_ids.shape[0] 72 | else: 73 | batch_size = inputs_embeds.shape[0] 74 | 75 | if self.config.pad_token_id is None and batch_size != 1: 76 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") 77 | if self.config.pad_token_id is None: 78 | sequence_lengths = -1 79 | else: 80 | if input_ids is not None: 81 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1)).to(hidden_states.device) 82 | else: 83 | sequence_lengths = -1 84 | 85 | if attention_mask is None: 86 | attention_mask = torch.ne(input_ids, self.config.pad_token_id).float() 87 | 88 | # print("hidden_states shape {}".format(hidden_states.shape)) 89 | # print("attention_mask shape {}".format(attention_mask.shape)) 90 | 91 | attention_mask_ext = attention_mask.unsqueeze(-1) 92 | if pooling_type in ["last", "eos"]: 93 | offset = 1 if pooling_type == "eos" else 2 94 | if padding_side == "right": 95 | pooled_hidden_state = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths - offset] 96 | else: 97 | pooled_hidden_state = hidden_states[torch.arange(batch_size, device=hidden_states.device), - offset] 98 | 99 | elif pooling_type == "average": 100 | pooled_hidden_state = (hidden_states * attention_mask_ext).sum(dim=1) / attention_mask_ext.sum(dim=1) 101 | elif pooling_type == "max": 102 | pooled_hidden_state = (hidden_states * attention_mask_ext).max(dim=1)[0] 103 | else: 104 | raise ValueError("The pooling method {} is not implemented!!".format(pooling_type)) 105 | 106 | pooled_logits = 
self.reward_head(pooled_hidden_state) 107 | 108 | #pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths] 109 | 110 | return { 111 | "lm_logits": lm_logits, 112 | "rm_logits": pooled_logits, 113 | "hidden_states": transformer_outputs[0], 114 | "rm_embeddings": pooled_hidden_state 115 | } 116 | 117 | 118 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | numpy 2 | scikit-learn 3 | fire 4 | transformers==4.28.1 5 | torch==2.0.0 6 | sentencepiece 7 | tokenizers>=0.13.3 8 | wandb 9 | datasets 10 | accelerate==0.20.3 11 | deepspeed==0.12.6 12 | pydantic==1.10.7 13 | -------------------------------------------------------------------------------- /reward_datasets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from tqdm import tqdm 4 | import gzip 5 | import random 6 | from copy import deepcopy 7 | 8 | from utils import print_rank_0 9 | from pprint import pprint 10 | import numpy as np 11 | 12 | import torch 13 | from torch.utils.data import Dataset 14 | 15 | from transformers import LlamaTokenizer 16 | 17 | from datasets import load_dataset 18 | from utils import read_json_or_jsonl_data 19 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN 20 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP 21 | 22 | 23 | class TextRewardDataset(Dataset): 24 | def __init__(self, data): 25 | self.data = data 26 | 27 | def __getitem__(self, index): 28 | return self.data[index] 29 | 30 | def __len__(self,): 31 | return len(self.data) 32 | 33 | 34 | def reward_data_collactor(args, batch, tokenizer): 35 | input_ids, attention_mask = [], [] 36 | query_ids, text, scores, apo_data_mask = [], [], [], [] 37 | 38 | max_response_num = max([len(item['scores']) for item in batch]) 39 | if args.debug_mode: 40 | print_rank_0(">>> response padding number: {}".format(max_response_num)) 41 | 42 | for item1 in batch: 43 | item = prepare_data_item(args, item1, 44 | tokenizer=tokenizer, 45 | padding=(not len(batch) == 1), 46 | max_response_num=max_response_num) 47 | 48 | scores.append(item['scores']) 49 | input_ids.append(item['tokens']['input_ids']) 50 | attention_mask.append(item['tokens']['attention_mask']) 51 | text.append(item['text']) 52 | 53 | if item.get("type", "hh") == 'apo': 54 | apo_data_mask.append(1) 55 | # coeffs.append(args.apo_loss_coeff / args.apo_sample_num) 56 | else: 57 | apo_data_mask.append(0) 58 | # coeffs.append(args.rm_kl_coeff) 59 | 60 | if "query_ids" in item: 61 | query_ids.append(item['query_ids']) 62 | 63 | if len(query_ids) > 0: 64 | assert len(query_ids) == len(scores), f"not all items have key:query_id, in {batch}" 65 | 66 | 67 | return { 68 | "scores": scores, 69 | "input_ids": input_ids, 70 | "attention_mask": attention_mask, 71 | "query_ids": query_ids, 72 | "text": text, 73 | "apo_data_mask": apo_data_mask 74 | # "coeffs": coeffs 75 | } 76 | 77 | 78 | def reward_tokenize(sentences, tokenizer, padding="longest", add_sep_token=False): 79 | if isinstance(sentences, str): 80 | sentences = [sentences] 81 | 82 | input_ids = [] 83 | for sent in sentences: 84 | if add_sep_token: 85 | query, response = sent.split(SEP_TOKEN) 86 | query_ids = tokenizer.encode(query, add_special_tokens=False) 87 | response_ids = tokenizer.encode(response, add_special_tokens=False) 88 | input_ids.append( 89 | [tokenizer.bos_token_id] + 
query_ids + [tokenizer.sep_token_id] + response_ids + [tokenizer.eos_token_id] 90 | ) 91 | else: 92 | if SEP_TOKEN in sent: 93 | query, response = sent.split(SEP_TOKEN) 94 | query_ids = tokenizer.encode(query, add_special_tokens=False) 95 | response_ids = tokenizer.encode(response, add_special_tokens=False) 96 | input_ids.append( 97 | [tokenizer.bos_token_id] + query_ids + response_ids + [tokenizer.eos_token_id] 98 | ) 99 | else: 100 | input_ids.append( 101 | [tokenizer.bos_token_id] + tokenizer.encode(sent, add_special_tokens=False) + [tokenizer.eos_token_id] 102 | ) 103 | 104 | return batch_padding(input_ids, tokenizer, padding=padding) 105 | 106 | 107 | def batch_padding(input_ids, tokenizer, padding='longest'): 108 | if padding == 'longest': 109 | max_input_length = max([len(inp_ids) for inp_ids in input_ids]) 110 | max_length = min(tokenizer.model_max_length, max_input_length) 111 | else: 112 | max_length = tokenizer.model_max_length 113 | 114 | outputs = {"input_ids": [], "attention_mask": []} 115 | for inp_ids in input_ids: 116 | attn_mask = [1] * len(inp_ids) 117 | if len(inp_ids) >= max_length: 118 | if tokenizer.truncation_side == 'left': 119 | inp_ids = inp_ids[-max_length :] 120 | attn_mask = attn_mask[-max_length :] 121 | else: 122 | inp_ids = inp_ids[:max_length] 123 | attn_mask = attn_mask[:max_length] 124 | else: 125 | if tokenizer.padding_side == 'left': 126 | inp_ids = [tokenizer.pad_token_id] * (max_length - len(inp_ids)) + inp_ids 127 | attn_mask = [0] * (max_length - len(attn_mask)) + attn_mask 128 | else: 129 | inp_ids = inp_ids + [tokenizer.pad_token_id] * (max_length - len(inp_ids)) 130 | attn_mask = attn_mask + [0] * (max_length - len(attn_mask)) 131 | 132 | outputs['input_ids'].append(deepcopy(inp_ids)) 133 | outputs['attention_mask'].append(deepcopy(attn_mask)) 134 | return outputs 135 | 136 | 137 | def prepare_data_item(args, item, tokenizer=None, padding=False, max_response_num=1): 138 | new_item = deepcopy(item) 139 | if len(new_item['scores']) != len(new_item['text']): 140 | raise ValueError("invalid data point {}".format(new_item)) 141 | 142 | 143 | 144 | if "query_ids" in new_item and len(new_item['scores']) != len(new_item['query_ids']): 145 | raise ValueError("invalid data point {}".format(new_item)) 146 | 147 | 148 | # score_idx = np.argsort(new_item['scores']) 149 | max_score = max(new_item['scores']) + 1e-5 150 | min_score = min(new_item['scores']) - 1e-5 151 | new_item['scores'] = [(score - min_score) / (max_score - min_score) for score in new_item['scores']] 152 | 153 | if padding: 154 | new_item['text'] += ["\n\nHuman: ?\n\nAssistant: Some"] * (max_response_num - len(new_item['text'])) 155 | new_item['scores'] += [-1.] 
* (max_response_num - len(new_item['scores'])) 156 | if "query_ids" in new_item: 157 | new_item['query_ids'] += [ "unk" + STRING_SEP + "pad" + STRING_SEP + "unk"] * (max_response_num - len(new_item['query_ids'])) 158 | 159 | 160 | if tokenizer is not None: 161 | try: 162 | new_item['tokens'] = reward_tokenize( 163 | sentences=new_item['text'], 164 | tokenizer=tokenizer, 165 | padding="max_length" if padding else "longest", 166 | add_sep_token=args.add_sep_token 167 | ) 168 | except Exception as e: 169 | raise ValueError(f"tokenization error with {new_item}") from e 170 | 171 | return new_item 172 | 173 | 174 | 175 | def load_rejection_samples(data_path): 176 | data_list = read_json_or_jsonl_data(data_path) 177 | outputs = [] 178 | for item in data_list: 179 | # print_rank_0(item) 180 | if 'query' in item: 181 | query = str(item['query']) 182 | else: 183 | query = str(item['instruction']) 184 | 185 | query_id = str(item['query_id']) 186 | 187 | for key in item: 188 | #if "hh_best" in key or "gpt4" in key: 189 | if "sample_" in key or "gpt4" in key or 'ans_' in key: 190 | outputs.append({ 191 | "text": [ query + SEP_TOKEN + str(item[key])], 192 | "query_ids": [ data_path + STRING_SEP + query_id + STRING_SEP + key], 193 | "scores": [-1] 194 | }) 195 | print(f">>> loaded {len(outputs)} rejection samples in total.") 196 | print(outputs[0]) 197 | return outputs 198 | 199 | 200 | def load_text_score_dataset(args, data_path): 201 | print_rank_0("loading text-scores dataset from: \n {}".format(data_path)) 202 | 203 | if args.data_type == "reject_sample": 204 | data_list = load_rejection_samples(data_path) 205 | else: 206 | data_list = read_json_or_jsonl_data(data_path) 207 | for item in data_list: 208 | item['query_ids'] = [os.path.split(data_path)[1]] * len(item['text']) 209 | 210 | 211 | 212 | print_rank_0("finished loading {} data points.".format(len(data_list))) 213 | return data_list 214 | 215 | 216 | -------------------------------------------------------------------------------- /tools/apo_data_converter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import json 4 | import argparse 5 | 6 | from pprint import pprint 7 | from tqdm import tqdm 8 | 9 | def preprocess_response(response): 10 | while "\nHuman" in response: 11 | # remove any extra dialogue turns the LLM generated after the current response. 
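# Alpaca-style generations often continue the dialogue with new "\nHuman:"
# turns; truncating at the first "\nHuman" keeps only the current turn.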
12 | response = response.split("\nHuman")[0].strip() 13 | 14 | return response 15 | 16 | 17 | def convert_item(item, sampling=False): 18 | sample_names = ['sample_0', 'sample_1', 'sample_2', 'sample_3'] 19 | if "\nHuman:" in item['golden']: 20 | print(item) 21 | gpt_response = preprocess_response(item['golden']) 22 | 23 | if sampling: 24 | sample_names = [random.choice(sample_names)] 25 | 26 | outputs = [] 27 | for sample_name in sample_names: 28 | query = item['query'] 29 | query_id = str(item['query_id']) 30 | res_response = preprocess_response(item[sample_name]) 31 | data_point = { 32 | "text": [query+''+gpt_response, query+''+res_response], 33 | "scores": [1., 0.], 34 | "type": "apo" 35 | } 36 | outputs.append(data_point) 37 | 38 | return outputs 39 | 40 | 41 | if __name__ == '__main__': 42 | 43 | parser = argparse.ArgumentParser(description='parser for preference data processing.') 44 | parser.add_argument("--golden_data_path", type=str, default="", help="the path to golden annotation data.") 45 | parser.add_argument("--sample_data_path", type=str, default="", help="the path to llm sample data.") 46 | parser.add_argument("--output_dir", type=str, default="", help="the path to output converted data.") 47 | parser.add_argument("--apo_data_name", type=str, default="", help="the name prefix of the converted APO data.") 48 | parser.add_argument("--sampling", action="store_true", help="whether to randomly select one of the LLM samples for each query") 49 | args = parser.parse_args() 50 | 51 | 52 | with open(args.sample_data_path, 'r') as f: 53 | sft_samples = json.load(f) 54 | print(f'finished loading {len(sft_samples)} samples') 55 | 56 | with open(args.golden_data_path, 'r') as f: 57 | golden_samples = json.load(f) 58 | 59 | print(f'finished loading {len(golden_samples)} samples') 60 | 61 | merged_data = {} 62 | 63 | for item in tqdm(sft_samples): 64 | query_id = str(item['query_id']) 65 | merged_data[query_id] = item 66 | 67 | for item in tqdm(golden_samples): 68 | query_id = str(item['query_id']) 69 | merged_data[query_id]['golden'] = item['golden'] 70 | 71 | score_dict = None 72 | outputs = [] 73 | for query_id, item in merged_data.items(): 74 | new_results = convert_item(item, sampling=args.sampling) 75 | outputs.extend(new_results) 76 | # except: 77 | # pprint(item1) 78 | # error_count += 1 79 | 80 | # print(f"get {error_count} error items") 81 | 82 | if not os.path.exists(args.output_dir): 83 | os.mkdir(args.output_dir) 84 | 85 | if args.sampling: 86 | output_path = f"{args.output_dir}/{args.apo_data_name}_sampled_text_scores.json" 87 | else: 88 | output_path = f"{args.output_dir}/{args.apo_data_name}_text_scores.json" 89 | 90 | print(f'writing {len(outputs)} processed data points to {output_path}') 91 | with open(output_path, 'w') as f: 92 | json.dump(outputs, f, ensure_ascii=False, indent=2) 93 | 94 | 95 | -------------------------------------------------------------------------------- /tools/convert_apo_data.sh: -------------------------------------------------------------------------------- 1 | 2 | REPO_DIR=path/to/APO/repo 3 | DATA_DIR="${REPO_DIR}/data/hh-split" 4 | 5 | python3 ${REPO_DIR}/tools/apo_data_converter.py \ 6 | --golden_data_path ${DATA_DIR}/rm_data/hh_split_rm.golden.json \ 7 | --sample_data_path ${DATA_DIR}/rm_data/hh_split_rm_alpaca_v0.sample.json \ 8 | --output_dir ${DATA_DIR}/apo_data \ 9 | --apo_data_name "rm_apo_data_v0" 10 | -------------------------------------------------------------------------------- /tools/inference_llm.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | from copy import deepcopy 4 | import json 5 | import glob 6 | from dataclasses import dataclass 7 | from typing import Dict, Sequence 8 | from tqdm import tqdm 9 | 10 | 11 | import torch 12 | import torch.distributed as dist 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.utils.data import Dataset, DataLoader 15 | 16 | import transformers 17 | from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM 18 | from datasets import load_dataset 19 | from arguments import CustomTrainingArguments 20 | 21 | from utils import print_rank_0, read_json_or_jsonl_data, SEP_TOKEN 22 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN 23 | 24 | from reward_datasets import TextRewardDataset, batch_padding 25 | 26 | IGNORE_INDEX = -100 27 | PROMPT_DICT = { 28 | "prompt_input": ( 29 | "\n\nHuman: {instruction}\n{input}\n\nAssistant: " 30 | ), 31 | "prompt_no_input": ( 32 | "\n\nHuman: {instruction}\n\nAssistant: " 33 | ), 34 | } 35 | 36 | B_INST, E_INST = "[INST]", "[/INST]" # for llama-2-chat 37 | 38 | def load_query_data_for_generation(args, data_path): 39 | all_data = read_json_or_jsonl_data(data_path) 40 | outputs = [] 41 | for idx, item in enumerate(all_data): 42 | if args.data_type == "comparison_pair": 43 | query = item['text'][0].split(SEP_TOKEN)[0] 44 | outputs.append({ 45 | "query": query, 46 | "query_id": item.get("query_id", str(idx)) 47 | }) 48 | else: 49 | outputs.append({ 50 | 'query': item['query'], 51 | "query_id": item.get("query_id", str(idx)) 52 | }) 53 | return TextRewardDataset(outputs) 54 | 55 | 56 | def query_data_collactor(args, batch, tokenizer): 57 | input_ids, attention_mask, labels = [], [], [] 58 | text = [item['query'] for item in batch] 59 | query_ids = [item['query_id'] for item in batch] 60 | 61 | for sent in text: 62 | if args.model_prefix == "llama-2-chat": 63 | # check details at https://huggingface.co/meta-llama/Llama-2-7b-chat 64 | sent = sent.replace("\nAssistant", f" {E_INST} ").replace("\nHuman", f" {tokenizer.eos_token} {tokenizer.bos_token} {B_INST} ") 65 | sent = sent.strip().strip(tokenizer.eos_token) 66 | input_query_ids = tokenizer.encode(sent, add_special_tokens=False) 67 | 68 | else: 69 | input_query_ids = tokenizer.encode(sent) 70 | 71 | input_ids.append(input_query_ids) 72 | 73 | outputs = batch_padding(input_ids, tokenizer) 74 | outputs['query_ids'] = query_ids 75 | outputs['text'] = text 76 | return outputs 77 | 78 | 79 | def main(): 80 | parser = transformers.HfArgumentParser(CustomTrainingArguments) 81 | args = parser.parse_args_into_dataclasses()[0] 82 | 83 | # setup model 84 | #--------------------------------------------------------------------------------- 85 | device = torch.cuda.current_device() 86 | print_rank_0(f"start loading model from {args.model_name_or_path}") 87 | model = LlamaForCausalLM.from_pretrained( 88 | args.model_name_or_path, 89 | # torch_dtype=torch.float16, 90 | ) 91 | print_rank_0(model) 92 | 93 | tokenizer = AutoTokenizer.from_pretrained( 94 | args.model_name_or_path, 95 | padding_side="left", # for batch decode 96 | truncation_side='left', 97 | model_max_length=args.max_length, 98 | ) 99 | 100 | if tokenizer.pad_token is None: 101 | tokenizer.pad_token = tokenizer.eos_token 102 | tokenizer.pad_token_id = 0 103 | # tokenizer.pad_token = DEFAULT_PAD_TOKEN 104 | # 
smart_tokenizer_and_embedding_resize( 105 | # special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN), 106 | # tokenizer=tokenizer, 107 | # model=model, 108 | # ) 109 | 110 | eval_dataset = load_query_data_for_generation(args, args.data_path) 111 | 112 | sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False) 113 | dataloader = DataLoader( 114 | eval_dataset, 115 | shuffle=False, 116 | collate_fn=lambda x: query_data_collactor(args, x, tokenizer), 117 | batch_size=args.per_device_eval_batch_size, 118 | sampler=sampler, 119 | ) 120 | 121 | if args.task_type == "testing": 122 | generation_config = GenerationConfig( 123 | temperature=0.3, 124 | do_sample=True, 125 | max_new_tokens=512, 126 | top_k=5, 127 | top_p=0.85, 128 | bos_token_id=tokenizer.bos_token_id, 129 | eos_token_id=tokenizer.eos_token_id, 130 | pad_token_id=0, 131 | repetition_penalty=1.05, 132 | num_return_sequences=1, 133 | ) 134 | elif args.task_type == "sampling": 135 | if args.model_prefix == "llama-2-chat": 136 | temperature = 0.6 137 | top_p=0.9 138 | else: 139 | temperature = 1.2 140 | top_p=1. 141 | 142 | generation_config = GenerationConfig( 143 | temperature=temperature, # default=0.8 144 | do_sample=True, 145 | min_length=1, 146 | max_new_tokens=256, 147 | top_p=top_p, 148 | bos_token_id=tokenizer.bos_token_id, 149 | eos_token_id=tokenizer.eos_token_id, 150 | pad_token_id=0, 151 | num_return_sequences=4, 152 | ) 153 | 154 | 155 | model.to(device) 156 | model.eval() 157 | 158 | all_outputs = [] 159 | progress_bar = tqdm(range(len(dataloader)), disable=(dist.get_rank() != 0)) 160 | for step, batch in enumerate(dataloader): 161 | progress_bar.update(1) 162 | input_ids = torch.Tensor(batch['input_ids']).long().to(model.device) 163 | attention_mask = torch.Tensor(batch['attention_mask']).float().to(model.device) 164 | query_ids = batch['query_ids'] 165 | text = batch['text'] 166 | 167 | batch_size = input_ids.shape[0] 168 | 169 | with torch.no_grad(): 170 | generation_output = model.generate( 171 | input_ids=input_ids, 172 | attention_mask=attention_mask, 173 | generation_config=generation_config, 174 | return_dict_in_generate=True, 175 | ) 176 | output_seq = generation_output.sequences.reshape(batch_size, generation_config.num_return_sequences, -1) 177 | 178 | inputs_string = tokenizer.batch_decode(input_ids.reshape(batch_size, -1), skip_special_tokens=True) 179 | 180 | for idx in range(len(inputs_string)): 181 | new_item = {"query_id": query_ids[idx], "query": text[idx]} 182 | output_responses = tokenizer.batch_decode(output_seq[idx], skip_special_tokens=True) 183 | for res_idx, output_res in enumerate(output_responses): 184 | response_sample = output_res.replace(inputs_string[idx], '') 185 | if args.model_prefix == "llama-2-chat": 186 | #sent = sent.replace("\nAssistant", f" {E_INST} ").replace("\nHuman", f" {tokenizer.eos_token} {tokenizer.bos_token} {B_INST} ") 187 | response_sample = response_sample.replace(E_INST, "\nAssistant").replace(B_INST, "\nHuman") 188 | #response_sample = response_sample.replace(E_INST, "\n\nAssistant:").replace(B_INST, "\n\nHuman:") 189 | 190 | new_item[f"sample_{res_idx}"] = response_sample 191 | 192 | all_outputs.append(new_item) 193 | 194 | if dist.get_rank() == 0 and (step % 10 == 0): 195 | print_rank_0(f"finished {step} of {len(dataloader)}") 196 | print_rank_0(all_outputs[-1]) 197 | 198 | 199 | output_file_prefix = f"{args.output_dir}/{args.model_prefix}_{args.task_type}_{args.data_suffix}" 200 | with 
open(f"{output_file_prefix}_rank{dist.get_rank()}.json", 'w') as f: 201 | json.dump(all_outputs, f, ensure_ascii=False, indent=2) 202 | print(f"rank {dist.get_rank()} finishs inference.") 203 | 204 | del model 205 | torch.cuda.empty_cache() 206 | dist.barrier() 207 | if dist.get_rank() == 0: 208 | result_paths = glob.glob(f"{output_file_prefix}_rank*.json") 209 | all_results = [] 210 | for res_path in result_paths: 211 | new_results = read_json_or_jsonl_data(res_path) 212 | all_results.extend(new_results) 213 | 214 | print(f"totally loaded {len(all_results)} results") 215 | with open(f"{output_file_prefix}_results.json", 'w') as f: 216 | json.dump(all_results, f, ensure_ascii=False, indent=2) 217 | print(f"finished inference results merge.") 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /tools/llm_response_gen.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | 3 | REPO_DIR= 4 | export PYTHONPATH=${REPO_DIR} 5 | 6 | 7 | MODEL_DIR="chavinlo/alpaca-native" 8 | MODEL_NAME="alpaca" 9 | 10 | #TASK_TYPE="testing" 11 | TASK_TYPE="sampling" 12 | 13 | DATA_DIR=${REPO_DIR}/data/hh-split 14 | if [[ "${TASK_TYPE}" == "testing" ]]; then 15 | DATA_PATH=${DATA_DIR}/eval_data/hh_cleaned_origin.test.json 16 | DATA_NAME="hh_test" 17 | DATA_TYPE="comparison_pair" 18 | else 19 | DATA_DIR=${REPO_DIR}/data/hh-split 20 | DATA_PATH=${DATA_DIR}/llm_data/hh_split_llm.train.json 21 | DATA_NAME="hh_llm_train" 22 | DATA_TYPE="comparison_pair" 23 | fi 24 | 25 | OUTPUT_DIR=${DATA_DIR}/sample_data 26 | mkdir -p $OUTPUT_DIR 27 | 28 | 29 | EVAL_MICRO_BATCH_SIZE=1 30 | MAX_INPUT_LENGTH=512 31 | 32 | torchrun --nproc_per_node 8 --master_port 6000 ${REPO_DIR}/tools/inference_llm.py \ 33 | --model_name_or_path $MODEL_DIR \ 34 | --model_prefix ${MODEL_NAME} \ 35 | --data_path $DATA_PATH \ 36 | --output_dir $OUTPUT_DIR \ 37 | --per_device_eval_batch_size $EVAL_MICRO_BATCH_SIZE \ 38 | --task_type ${TASK_TYPE} \ 39 | --data_suffix ${DATA_NAME} \ 40 | --max_length ${MAX_INPUT_LENGTH} \ 41 | --data_type ${DATA_TYPE} 42 | -------------------------------------------------------------------------------- /tools/rejection_sampling.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import argparse 5 | import glob 6 | from copy import deepcopy 7 | from pprint import pprint 8 | 9 | def get_best_key(item, item_scores, filter_pattern=False): 10 | max_score = -100000000. 
11 | result = None 12 | for key, value in item_scores.items(): 13 | if len(item_scores) > 1 and key == "hh_best": 14 | continue 15 | if value > max_score: 16 | if item[key].strip() == "": 17 | continue 18 | else: 19 | result = deepcopy(key) 20 | max_score = value 21 | return result 22 | 23 | def get_scores_from_list(score_list): 24 | score_dict = {} 25 | for item in score_list: 26 | for key, value in item.items(): 27 | query_id, ans_id = key.split(':') 28 | if query_id in score_dict: 29 | if ans_id in score_dict[query_id] and value != score_dict[query_id][ans_id]: 30 | print(f">>>>> warning!") 31 | print(f">>>>> replacing {query_id}: {ans_id} value {score_dict[query_id][ans_id]} with {value}") 32 | 33 | score_dict[query_id][ans_id] = value 34 | else: 35 | score_dict[query_id] = {ans_id: value} 36 | return score_dict 37 | 38 | def get_scores(data_path, rm_scorer): 39 | file_names = glob.glob(f"{data_path}_*pred_{rm_scorer}*rank*.jsonl") 40 | score_dict = {} 41 | for file_name in file_names: 42 | with open(file_name, 'r') as f: 43 | lines = f.readlines() 44 | scores = [json.loads(l.strip()) for l in lines] 45 | for item in scores: 46 | for key, value in item.items(): 47 | query_id, ans_id = key.split(':') 48 | if query_id in score_dict: 49 | if ans_id in score_dict[query_id] and value != score_dict[query_id][ans_id]: 50 | print(f">>>>> warning!") 51 | print(f">>>>> replacing {query_id}: {ans_id} value {score_dict[query_id][ans_id]} with {value}") 52 | 53 | score_dict[query_id][ans_id] = value 54 | else: 55 | score_dict[query_id] = {ans_id: value} 56 | return score_dict 57 | 58 | def rejection_sample(data_path, score_path=None, rm_scorer=None): 59 | with open(data_path, 'r') as f: 60 | data_list = json.load(f) 61 | 62 | print(f"loaded {len(data_list)} samples in total for rejection sampling") 63 | 64 | if score_path: 65 | with open(score_path, 'r') as f: 66 | score_list = json.load(f) 67 | data_scores = get_scores_from_list(score_list) 68 | elif rm_scorer: 69 | data_scores = get_scores(data_path, rm_scorer) 70 | else: 71 | raise ValueError('cannot find score data') 72 | 73 | hh_best_counter = 0 74 | outputs = [] 75 | for item in data_list: 76 | query_id = str(item['query_id']) 77 | item_scores = data_scores[query_id] 78 | 79 | 80 | #best_res_key = max(item_scores, key=item_scores.get) 81 | best_res_key = get_best_key(item, item_scores, filter_pattern=True) 82 | if best_res_key is None: 83 | best_res_key = get_best_key(item, item_scores, filter_pattern=False) 84 | if best_res_key is None: 85 | print(item) 86 | continue 87 | 88 | item['target'] = item[best_res_key] 89 | item['scores'] = item_scores 90 | 91 | if best_res_key == "hh_best": 92 | hh_best_counter += 1 93 | outputs.append(deepcopy(item)) 94 | print(f"got {hh_best_counter} data points with hh_best selected") 95 | return outputs 96 | 97 | 98 | if __name__ == "__main__": 99 | parser = argparse.ArgumentParser(description='parser for preference data processing.') 100 | parser.add_argument("--data_path", type=str, default="", help="the path to input data.") 101 | parser.add_argument("--output_dir", type=str, default="", help="the path to output data.") 102 | parser.add_argument("--output_file_name", type=str, default="", help="the path to output data.") 103 | parser.add_argument("--score_path", type=str, default="", help="the path to the score prediction file") 104 | parser.add_argument("--rm_scorer", type=str, default="", help="the rm model name to get score") 105 | 106 | parser.add_argument("--domain", type=str, 
default="general", help="the domain of the preference data, selected from [general, normal, academy, business, entertainment, literature].") 107 | 108 | parser.add_argument("--convert", action='store_true', help="whether convert responses into the preference text-score format.") 109 | parser.add_argument("--to_pairs", action='store_true', help="whether convert responses into pair comparisons.") 110 | 111 | args = parser.parse_args() 112 | #outputs = rejection_sample(args.data_path, f"{args.data_path}_{args.rm_scorer}_prediction.json") 113 | outputs = rejection_sample(args.data_path, args.score_path, args.rm_scorer) 114 | 115 | if len(args.output_file_name) == 0: 116 | 117 | _, file_name = os.path.split(args.score_path) 118 | print(file_name) 119 | args.output_file_name = f"{args.output_dir}/{file_name}_sft.json" 120 | 121 | with open(f"{args.output_file_name}", 'w', encoding="utf-8") as f: 122 | json.dump(outputs, f, ensure_ascii=False, indent=2) 123 | 124 | 125 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import copy 3 | import logging 4 | from dataclasses import dataclass, field 5 | from typing import Dict, Optional, Sequence, List 6 | import json 7 | import random 8 | 9 | import torch 10 | import torch.distributed as dist 11 | import transformers 12 | 13 | from torch.utils.data import Dataset 14 | from transformers import Trainer, AutoConfig 15 | from transformers import EvalPrediction 16 | 17 | 18 | from model import LlamaRewardModel 19 | 20 | from reward_datasets import TextRewardDataset, reward_data_collactor 21 | from reward_datasets import load_text_score_dataset 22 | from arguments import CustomTrainingArguments 23 | from trainer import RewardModelTrainer, compute_metrics 24 | 25 | from utils import print_rank_0, set_reward_tokenizer, merge_json_or_jsonl_data 26 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN 27 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP, INFER_TMP_FILE 28 | 29 | 30 | 31 | def get_eval_datasets(args): 32 | data_dict = {} 33 | 34 | for data_path in args.eval_data_path: 35 | eval_data_list = load_text_score_dataset(args=args, data_path=data_path) 36 | 37 | eval_dataset = TextRewardDataset(eval_data_list) 38 | 39 | data_name = os.path.split(data_path)[-1] 40 | data_dict[data_name] = eval_dataset 41 | print_rank_0(">> finished loading {} data with data size = {}".format(data_name, len(eval_dataset))) 42 | 43 | if args.debug_mode: 44 | print_rank_0(f">>> check loaded data:") 45 | print_rank_0(f">>> {eval_dataset[0]}") 46 | 47 | return data_dict 48 | 49 | def get_train_dataset(args): 50 | all_train_data = [] 51 | for train_data_path in args.train_data_path: 52 | train_data = load_text_score_dataset(args=args, data_path=train_data_path) 53 | all_train_data.extend(train_data) 54 | 55 | if args.debug_mode: 56 | print_rank_0(f">>> check loaded data:") 57 | print_rank_0(f">>> {all_train_data[0]}") 58 | 59 | train_set = TextRewardDataset(all_train_data) 60 | return train_set 61 | 62 | 63 | def train(): 64 | parser = transformers.HfArgumentParser(CustomTrainingArguments) 65 | args = parser.parse_args_into_dataclasses()[0] 66 | print_rank_0(args) 67 | 68 | # load data 69 | #--------------------------------------------------------------------------------- 70 | if args.do_train: 71 | train_dataset = get_train_dataset(args) 72 | else: 73 | train_dataset = None 74 | 75 | 
eval_dataset_dict = get_eval_datasets(args) 76 | 77 | # setup model 78 | #--------------------------------------------------------------------------------- 79 | print_rank_0(f"Begin loading model from {args.model_name_or_path}") 80 | if args.model_type == "reward": 81 | model = LlamaRewardModel.from_pretrained(args.model_name_or_path) 82 | elif args.model_type == "sft": 83 | model = LlamaForCausalLM.from_pretrained(args.model_name_or_path) 84 | 85 | print_rank_0(model) 86 | print_rank_0(f"Finished loading model from {args.model_name_or_path}") 87 | 88 | model.is_parallelizable = True 89 | model.model_parallel = True 90 | 91 | # setup tokenizer 92 | #--------------------------------------------------------------------------------- 93 | tokenizer = transformers.AutoTokenizer.from_pretrained( 94 | args.model_name_or_path, 95 | model_max_length=args.max_length, 96 | padding_side=args.padding_side, 97 | truncation_side=args.truncation_side, 98 | use_fast=False, 99 | ) 100 | 101 | if args.model_type == "reward": 102 | model, tokenizer = set_reward_tokenizer(model=model, tokenizer=tokenizer) 103 | 104 | # build trainer 105 | #--------------------------------------------------------------------------------- 106 | 107 | trainer = RewardModelTrainer( 108 | model=model, 109 | tokenizer=tokenizer, 110 | args=args, 111 | compute_metrics=lambda x: compute_metrics(args, x), 112 | train_dataset=train_dataset, 113 | eval_dataset=eval_dataset_dict, 114 | data_collator=lambda x: reward_data_collactor(args, x, tokenizer) 115 | ) 116 | 117 | if args.do_train: 118 | if args.eval_at_start: 119 | for eval_set_name, eval_dataset in eval_dataset_dict.items(): 120 | eval_result = trainer.evaluate(eval_dataset=eval_dataset, metric_key_prefix="eval_"+eval_set_name) 121 | print_rank_0(eval_result) 122 | 123 | if args.resume_from_checkpoint: 124 | train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint) 125 | else: 126 | train_result = trainer.train() 127 | 128 | metrics = train_result.metrics 129 | trainer.log_metrics("train", metrics) 130 | trainer.save_metrics("train", metrics) 131 | 132 | trainer.save_state() 133 | trainer.save_model(output_dir=args.output_dir) 134 | 135 | 136 | final_eval_results ={} 137 | for eval_set_name, eval_dataset in eval_dataset_dict.items(): 138 | args.current_eval_filename = os.path.split(eval_set_name)[-1] 139 | eval_result = trainer.evaluate(eval_dataset=eval_dataset, metric_key_prefix="eval_"+eval_set_name) 140 | 141 | print_rank_0(eval_result) 142 | final_eval_results[eval_set_name] = eval_result 143 | 144 | if args.task_type == "inference": 145 | torch.distributed.barrier() 146 | if dist.get_rank() == 0: 147 | print_rank_0(eval_set_name) 148 | data_path = eval_dataset[0]['query_ids'][0].split(STRING_SEP)[0] 149 | 150 | result_temp = INFER_TMP_FILE.format(data_path=data_path, 151 | data_suffix=args.data_suffix, 152 | rank="*") 153 | print_rank_0(f"begin merge temp file from {result_temp}") 154 | outputs = merge_json_or_jsonl_data(result_temp) 155 | with open(f"{data_path}_pred_{args.data_suffix}_results.json", 'w') as f: 156 | json.dump(outputs, f, ensure_ascii=False, indent=2) 157 | 158 | 159 | 160 | with open(f"{args.output_dir}/final_eval_results.json", 'w') as f: 161 | json.dump(final_eval_results, f, ensure_ascii=False) 162 | 163 | 164 | 165 | if __name__ == "__main__": 166 | train() 167 | -------------------------------------------------------------------------------- /trainer.py: 
-------------------------------------------------------------------------------- 1 | import copy 2 | import logging 3 | from dataclasses import dataclass, field 4 | from typing import Dict, Optional, Sequence, List 5 | import json 6 | import datetime 7 | 8 | import numpy as np 9 | import sklearn 10 | 11 | import torch 12 | import torch.distributed as dist 13 | import torch.nn.functional as F 14 | import transformers 15 | 16 | 17 | from transformers import Trainer, AutoConfig 18 | from transformers import EvalPrediction 19 | 20 | from utils import print_rank_0, calibration_error, numpy_sigmoid 21 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP, INFER_TMP_FILE 22 | 23 | 24 | 25 | def rm_calibration_errors(args, labels, probs, masks, num_bins): 26 | label_list = labels.reshape(-1).tolist() 27 | prob_list = probs.reshape(-1).tolist() 28 | mask_list = masks.reshape(-1).tolist() 29 | 30 | y_true, y_prob = [], [] 31 | for label, prob, mask in zip(label_list, prob_list, mask_list): 32 | if mask: 33 | y_true.append(label) 34 | y_prob.append(prob) 35 | 36 | if args.debug_mode: 37 | print_rank_0(f">>>>> check calibration inputs mask filtered...") 38 | print_rank_0(f">>>>>>>> y_true: {y_true[:10]}") 39 | print_rank_0(f">>>>>>>> y_prob: {y_prob[:10]}") 40 | 41 | return calibration_error(np.array(y_true), np.array(y_prob), n_bins=num_bins) 42 | 43 | 44 | def compute_metrics(args, prediction: EvalPrediction): 45 | logits = torch.from_numpy(prediction.predictions) 46 | scores = torch.from_numpy(prediction.label_ids) 47 | 48 | if args.debug_mode: 49 | print_rank_0(f">> check eval_prediction inputs...") 50 | print_rank_0(f">>> logits: {logits[:5]}") 51 | print_rank_0(f">>> scores: {scores[:5]}") 52 | 53 | logits_diff = logits.unsqueeze(1) - logits.unsqueeze(2) # [batch_size, num_sample, num_sample] 54 | 55 | score_mask_larger = (scores.unsqueeze(1) > scores.unsqueeze(2)) * 1. 56 | score_mask_smaller = (scores.unsqueeze(1) < scores.unsqueeze(2)) * 1. 57 | score_mask = score_mask_larger - score_mask_smaller 58 | pad_mask = (scores >= 0).unsqueeze(1) * 1. * (scores >= 0).unsqueeze(2) 59 | 60 | 61 | # calculate accuracy... 62 | pred_compare = (logits_diff.detach() * score_mask > 0.) * 1. 
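# logits_diff[b, i, j] > 0 means response i out-scores response j; multiplying
# by score_mask (+1 / -1 / 0) makes pred_compare equal 1 exactly on the pairs
# the model ranks in the same order as the labels.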
63 | total_mask = (score_mask_larger + score_mask_smaller) * pad_mask 64 | #correct_compare = (pred_compare == score_mask_larger) * total_mask 65 | correct_compare = pred_compare * total_mask 66 | 67 | all_acc = correct_compare.sum() / total_mask.sum() if total_mask.sum() > 0 else total_mask.sum() 68 | average_score = logits.mean().item() 69 | 70 | calibration_errors = {} 71 | if args.rm_calibration: 72 | for num_bins in args.calibration_bins: 73 | expected_error, average_error, max_error = rm_calibration_errors( 74 | args=args, 75 | labels=score_mask_larger, 76 | #probs=torch.sigmoid(logits_diff), 77 | probs=numpy_sigmoid(logits_diff.numpy()), 78 | masks=total_mask, 79 | num_bins=num_bins 80 | ) 81 | # if args.save_calibration and args.task_type == "eval": 82 | # time = datetime.datetime.now() 83 | # time_stamp = time.strftime("%d-%H:%M:%S") 84 | # if dist.get_rank() == 0: 85 | # outputs = {"prob_true": prob_true.tolist(), "prob_pred": prob_pred.tolist()} 86 | # with open(f"{args.output_dir}/calibration_result_t{args.current_eval_filename}_bin{num_bins}.json", 'w') as f: 87 | # json.dump(outputs, f, ensure_ascii=False, indent=2) 88 | 89 | calibration_errors[f"calibration_ECE_bin{num_bins}"] = expected_error 90 | calibration_errors[f"calibration_ACE_bin{num_bins}"] = average_error 91 | calibration_errors[f"calibration_MCE_bin{num_bins}"] = max_error 92 | 93 | if args.debug_mode: 94 | print_rank_0(f">> check eval_prediction outputs...") 95 | print_rank_0(f">>> correct_compare: {correct_compare}") 96 | print_rank_0(f">>> total_mask: {total_mask}") 97 | print_rank_0(f">>> all_acc: {all_acc}") 98 | print_rank_0(f">>> calibration error: {calibration_errors}") 99 | 100 | return {"Preference Acc": all_acc.item(), "Avg Score": average_score, **calibration_errors} 101 | 102 | 103 | def reward_model_loss(logits, scores, coeffs=None, loss_type="ranking"): # `logits`, `scores` with shape [bs, r], `coeffs` with shape [bs] 104 | logits_diff = logits.unsqueeze(1) - logits.unsqueeze(2) # shape [bs, r, r] 105 | 106 | score_mask_larger = (scores.unsqueeze(1) > scores.unsqueeze(2)) * 1. 107 | score_mask_smaller = (scores.unsqueeze(1) < scores.unsqueeze(2)) * 1. 108 | score_mask = score_mask_larger - score_mask_smaller 109 | pad_mask = (scores >= 0).unsqueeze(1) * 1. 
* (scores >= 0).unsqueeze(2) 110 | 111 | total_mask = (score_mask_larger + score_mask_smaller) * pad_mask 112 | 113 | if loss_type == "diff": 114 | log_prob = logits_diff * score_mask * pad_mask # shape [bs, r, r] 115 | else: 116 | log_prob = torch.nn.functional.logsigmoid(logits_diff * score_mask * pad_mask) # shape [bs, r, r] 117 | 118 | if coeffs is not None: 119 | log_prob = log_prob * coeffs.unsqueeze(-1).unsqueeze(-1) 120 | 121 | total_loss = - (log_prob * total_mask).sum() 122 | total_pairs = total_mask.sum() 123 | 124 | return total_loss / total_pairs if total_pairs > 0 else total_loss 125 | #return - log_prob.mean() 126 | 127 | 128 | class RewardModelTrainer(Trainer): 129 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[List[str]] = None): 130 | device = model.device 131 | labels = torch.Tensor(inputs['scores']).float().to(device) 132 | 133 | with torch.no_grad(): 134 | loss, logits = self.compute_loss(model, inputs, return_outputs=True) 135 | loss = loss.mean().detach() 136 | # logits = outputs.logits 137 | 138 | if prediction_loss_only: 139 | return (loss, None, None) 140 | 141 | return (loss, logits, labels) 142 | 143 | 144 | def compute_loss(self, model, inputs, return_outputs=False): 145 | device = model.device 146 | scores = torch.Tensor(inputs['scores']).float().to(device) # shape [batch_size, response_num] 147 | input_ids = torch.Tensor(inputs['input_ids']).long().to(device) # shape [batch_size, response_num, seq_length] 148 | attention_mask = torch.Tensor(inputs['attention_mask']).float().to(device) 149 | # coeffs = torch.Tensor(inputs['coeffs']).float().to(device) 150 | apo_data_mask = torch.Tensor(inputs['apo_data_mask']).float().to(device) # shape [batch_size] value 1 if apo data 151 | 152 | batch_size, response_num, seq_length = input_ids.shape 153 | 154 | if self.args.debug_mode: 155 | print(f">>> input_ids shape {input_ids.shape}") 156 | 157 | outputs = model( 158 | input_ids=input_ids.view(-1, seq_length), 159 | attention_mask=attention_mask.view(-1, seq_length), 160 | padding_side=self.args.padding_side, 161 | pooling_type=self.args.pooling_type 162 | ) 163 | 164 | batch_logits = outputs['rm_logits'].view(batch_size, response_num) # shape [bs, r] 165 | 166 | if self.args.task_type == "apo": 167 | rm_kl_loss = reward_model_loss(batch_logits, scores, coeffs=(1. 
- apo_data_mask), loss_type="ranking") 168 | apo_loss = reward_model_loss(batch_logits, scores, coeffs=apo_data_mask, loss_type=self.args.apo_loss_type) 169 | total_loss = self.args.rm_kl_coeff * rm_kl_loss + self.args.apo_loss_coeff / self.args.apo_sample_num * apo_loss 170 | else: 171 | total_loss = reward_model_loss(batch_logits, scores, coeffs=None, loss_type="ranking") 172 | 173 | if self.args.debug_mode: 174 | print_rank_0(f">>> debug") 175 | print_rank_0(f">>> input_ids shape {input_ids.shape}") 176 | print_rank_0(f">>> Batch rm logits {batch_logits}") 177 | 178 | if self.args.task_type == "inference": 179 | query_ids = inputs['query_ids'] 180 | new_results = [] 181 | 182 | for i_bs in range(batch_size): 183 | for j_sample in range(response_num): 184 | data_path, query_id, ans_id = query_ids[i_bs][j_sample].split(STRING_SEP) 185 | new_results.append( 186 | json.dumps({f"{query_id}:{ans_id}": batch_logits[i_bs][j_sample].item()}, ensure_ascii=False) 187 | ) 188 | 189 | output_file_path = INFER_TMP_FILE.format(data_path=data_path, 190 | data_suffix=self.args.data_suffix, 191 | rank=dist.get_rank()) 192 | with open(output_file_path, 'a') as f: 193 | f.write("\n".join(new_results)+"\n") 194 | 195 | return (total_loss, batch_logits) if return_outputs else total_loss 196 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import logging 3 | import math 4 | import os 5 | import io 6 | import sys 7 | import time 8 | import json 9 | import glob 10 | from typing import Optional, Sequence, Union, List, Dict 11 | 12 | import openai 13 | import tqdm 14 | from openai import openai_object 15 | import copy 16 | 17 | import numpy as np 18 | import torch 19 | 20 | SEP_TOKEN="" 21 | STRING_SEP="<:>" 22 | 23 | DEFAULT_PAD_TOKEN = "[PAD]" 24 | DEFAULT_EOS_TOKEN = "" 25 | DEFAULT_BOS_TOKEN = "" 26 | DEFAULT_UNK_TOKEN = "" 27 | 28 | QUERY_PROMPT="## Human:\n{request}\n\n## Assistant:\n{response}" 29 | 30 | INFER_TMP_FILE="{data_path}_pred_{data_suffix}_results_rank_{rank}.jsonl" 31 | 32 | def numpy_sigmoid(x): 33 | # r_x = x - x.max() 34 | return 1. / (1. 
+ np.exp(-x)) 35 | 36 | 37 | def read_json_or_jsonl_data(data_path): 38 | if data_path[-5:] == ".json": 39 | with open(data_path, 'r') as f: 40 | data_list = json.load(f) 41 | else: 42 | with open(data_path, 'r') as f: 43 | lines = f.read().strip().split('\n') 44 | data_list = [json.loads(l) for l in lines] 45 | 46 | print_rank_0(f">>> loaded {len(data_list)} data points in total from {data_path}") 47 | return data_list 48 | 49 | def merge_json_or_jsonl_data(data_path_pattern): 50 | file_names = glob.glob(data_path_pattern) 51 | print_rank_0(f"loading {len(file_names)} files from {data_path_pattern}.") 52 | outputs = [] 53 | for file_name in file_names: 54 | new_data = read_json_or_jsonl_data(file_name) 55 | if isinstance(new_data, list): 56 | outputs.extend(new_data) 57 | elif isinstance(new_data, dict): 58 | outputs.append(new_data) 59 | return outputs 60 | 61 | 62 | def print_rank_0(message): 63 | if torch.distributed.is_initialized(): 64 | if torch.distributed.get_rank() == 0: 65 | print(message, flush=True) 66 | else: 67 | print(message, flush=True) 68 | 69 | 70 | def set_reward_tokenizer(model, tokenizer): 71 | # hard-coded special-token ids for the LLaMA tokenizer (unk=0, bos=1, eos=2, pad=3, sep=4) 72 | tokenizer.pad_token_id = 3 73 | tokenizer.bos_token_id = 1 74 | tokenizer.eos_token_id = 2 75 | tokenizer.unk_token_id = 0 76 | tokenizer.sep_token_id = 4 77 | 78 | model.config.pad_token_id = tokenizer.pad_token_id 79 | model.config.bos_token_id = tokenizer.bos_token_id 80 | model.config.eos_token_id = tokenizer.eos_token_id 81 | 82 | print_rank_0(tokenizer) 83 | return model, tokenizer 84 | 85 | 86 | 87 | 88 | def calibration_error( 89 | y_true, 90 | y_prob, 91 | n_bins=5, 92 | strategy="uniform", 93 | ): 94 | if len(y_true) == 0: 95 | return 0., 0., 0. 96 | 97 | if strategy == "quantile": # Determine bin edges by distribution of data 98 | quantiles = np.linspace(0, 1, n_bins + 1) 99 | bins = np.percentile(y_prob, quantiles * 100) 100 | elif strategy == "uniform": 101 | bins = np.linspace(0.0, 1.0, n_bins + 1) 102 | else: 103 | raise ValueError( 104 | "Invalid entry to 'strategy' input. Strategy " 105 | "must be either 'quantile' or 'uniform'." 106 | ) 107 | 108 | binids = np.searchsorted(bins[1:-1], y_prob) 109 | 110 | bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins)) 111 | bin_true = np.bincount(binids, weights=y_true, minlength=len(bins)) 112 | bin_total = np.bincount(binids, minlength=len(bins)) 113 | 114 | nonzero = bin_total != 0 115 | # prob_true = bin_true[nonzero] / bin_total[nonzero] 116 | # prob_pred = bin_sums[nonzero] / bin_total[nonzero] 117 | 118 | # return prob_true, prob_pred, bin_total[nonzero] 119 | try: 120 | expected_error = np.abs(bin_sums - bin_true).sum() / len(y_prob) 121 | average_error = (np.abs(bin_sums[nonzero] - bin_true[nonzero]) / bin_total[nonzero]).mean() 122 | max_error = (np.abs(bin_sums[nonzero] - bin_true[nonzero]) / bin_total[nonzero]).max() 123 | except Exception as e: 124 | print_rank_0(">>>> WARNING: encountered an error in calibration calculation") 125 | print_rank_0(e) 126 | expected_error, average_error, max_error = 0., 0., 0. 127 | 128 | return expected_error, average_error, max_error 129 | --------------------------------------------------------------------------------
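A minimal usage sketch for the calibration utilities above (not part of the repository): the toy `labels` and `logit_diffs` arrays below are hypothetical, and `utils.py` is assumed to be importable with its dependencies installed.

import numpy as np

from utils import numpy_sigmoid, calibration_error

# Hypothetical pairwise comparisons: label 1 means the first response truly
# beats the second; logit_diffs are reward-model score differences per pair.
labels = np.array([1., 1., 0., 1., 0.])
logit_diffs = np.array([2.3, 0.4, -1.1, 0.2, 0.9])

probs = numpy_sigmoid(logit_diffs)  # predicted win probabilities in [0, 1]
ece, ace, mce = calibration_error(labels, probs, n_bins=5)
print(f"ECE={ece:.4f}  ACE={ace:.4f}  MCE={mce:.4f}")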