├── .gitattributes
├── .gitignore
├── DATA_LICENSE
├── LICENSE
├── README.md
├── arguments.py
├── configs
│   └── default_offload_opt_param.json
├── data
│   └── README.md
├── figures
│   └── apo_framework_v.png
├── model.py
├── requirements.txt
├── reward_datasets.py
├── tools
│   ├── apo_data_converter.py
│   ├── convert_apo_data.sh
│   ├── inference_llm.py
│   ├── llm_response_gen.sh
│   └── rejection_sampling.py
├── train.py
├── trainer.py
└── utils.py
/.gitattributes:
--------------------------------------------------------------------------------
1 | *.json filter=lfs diff=lfs merge=lfs -text
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 | *.*~
6 | *tmp.py
7 | *.bak
8 |
9 | # wandb
10 | wandb/
11 |
12 | # C extensions
13 | *.so
14 |
15 | # Distribution / packaging
16 | .Python
17 | build/
18 | develop-eggs/
19 | dist/
20 | downloads/
21 | eggs/
22 | .eggs/
23 | lib/
24 | lib64/
25 | parts/
26 | sdist/
27 | var/
28 | wheels/
29 | pip-wheel-metadata/
30 | share/python-wheels/
31 | *.egg-info/
32 | .installed.cfg
33 | *.egg
34 | MANIFEST
35 |
36 | # PyInstaller
37 | # Usually these files are written by a python script from a template
38 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
39 | *.manifest
40 | *.spec
41 |
42 | # Installer logs
43 | pip-log.txt
44 | pip-delete-this-directory.txt
45 |
46 | # Unit test / coverage reports
47 | htmlcov/
48 | .tox/
49 | .nox/
50 | .coverage
51 | .coverage.*
52 | .cache
53 | nosetests.xml
54 | coverage.xml
55 | *.cover
56 | *.py,cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 | db.sqlite3-journal
69 |
70 | # Flask stuff:
71 | instance/
72 | .webassets-cache
73 |
74 | # Scrapy stuff:
75 | .scrapy
76 |
77 | # Sphinx documentation
78 | docs/_build/
79 |
80 | # PyBuilder
81 | target/
82 |
83 | # Jupyter Notebook
84 | .ipynb_checkpoints
85 |
86 | # IPython
87 | profile_default/
88 | ipython_config.py
89 |
90 | # pyenv
91 | .python-version
92 |
93 | # pipenv
94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
97 | # install all needed dependencies.
98 | #Pipfile.lock
99 |
100 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
101 | __pypackages__/
102 |
103 | # Celery stuff
104 | celerybeat-schedule
105 | celerybeat.pid
106 |
107 | # SageMath parsed files
108 | *.sage.py
109 |
110 | # Environments
111 | .env
112 | .venv
113 | env/
114 | venv/
115 | ENV/
116 | env.bak/
117 | venv.bak/
118 |
119 | # Spyder project settings
120 | .spyderproject
121 | .spyproject
122 |
123 | # Rope project settings
124 | .ropeproject
125 |
126 | # mkdocs documentation
127 | /site
128 |
129 | # mypy
130 | .mypy_cache/
131 | .dmypy.json
132 | dmypy.json
133 |
134 | # Pyre type checker
135 | .pyre/
136 |
137 | .DS_Store
138 | .idea
139 |
--------------------------------------------------------------------------------
/DATA_LICENSE:
--------------------------------------------------------------------------------
1 | Attribution-NonCommercial 4.0 International
2 |
3 | =======================================================================
4 |
5 | Creative Commons Corporation ("Creative Commons") is not a law firm and
6 | does not provide legal services or legal advice. Distribution of
7 | Creative Commons public licenses does not create a lawyer-client or
8 | other relationship. Creative Commons makes its licenses and related
9 | information available on an "as-is" basis. Creative Commons gives no
10 | warranties regarding its licenses, any material licensed under their
11 | terms and conditions, or any related information. Creative Commons
12 | disclaims all liability for damages resulting from their use to the
13 | fullest extent possible.
14 |
15 | Using Creative Commons Public Licenses
16 |
17 | Creative Commons public licenses provide a standard set of terms and
18 | conditions that creators and other rights holders may use to share
19 | original works of authorship and other material subject to copyright
20 | and certain other rights specified in the public license below. The
21 | following considerations are for informational purposes only, are not
22 | exhaustive, and do not form part of our licenses.
23 |
24 | Considerations for licensors: Our public licenses are
25 | intended for use by those authorized to give the public
26 | permission to use material in ways otherwise restricted by
27 | copyright and certain other rights. Our licenses are
28 | irrevocable. Licensors should read and understand the terms
29 | and conditions of the license they choose before applying it.
30 | Licensors should also secure all rights necessary before
31 | applying our licenses so that the public can reuse the
32 | material as expected. Licensors should clearly mark any
33 | material not subject to the license. This includes other CC-
34 | licensed material, or material used under an exception or
35 | limitation to copyright. More considerations for licensors:
36 | wiki.creativecommons.org/Considerations_for_licensors
37 |
38 | Considerations for the public: By using one of our public
39 | licenses, a licensor grants the public permission to use the
40 | licensed material under specified terms and conditions. If
41 | the licensor's permission is not necessary for any reason--for
42 | example, because of any applicable exception or limitation to
43 | copyright--then that use is not regulated by the license. Our
44 | licenses grant only permissions under copyright and certain
45 | other rights that a licensor has authority to grant. Use of
46 | the licensed material may still be restricted for other
47 | reasons, including because others have copyright or other
48 | rights in the material. A licensor may make special requests,
49 | such as asking that all changes be marked or described.
50 | Although not required by our licenses, you are encouraged to
51 | respect those requests where reasonable. More considerations
52 | for the public:
53 | wiki.creativecommons.org/Considerations_for_licensees
54 |
55 | =======================================================================
56 |
57 | Creative Commons Attribution-NonCommercial 4.0 International Public
58 | License
59 |
60 | By exercising the Licensed Rights (defined below), You accept and agree
61 | to be bound by the terms and conditions of this Creative Commons
62 | Attribution-NonCommercial 4.0 International Public License ("Public
63 | License"). To the extent this Public License may be interpreted as a
64 | contract, You are granted the Licensed Rights in consideration of Your
65 | acceptance of these terms and conditions, and the Licensor grants You
66 | such rights in consideration of benefits the Licensor receives from
67 | making the Licensed Material available under these terms and
68 | conditions.
69 |
70 |
71 | Section 1 -- Definitions.
72 |
73 | a. Adapted Material means material subject to Copyright and Similar
74 | Rights that is derived from or based upon the Licensed Material
75 | and in which the Licensed Material is translated, altered,
76 | arranged, transformed, or otherwise modified in a manner requiring
77 | permission under the Copyright and Similar Rights held by the
78 | Licensor. For purposes of this Public License, where the Licensed
79 | Material is a musical work, performance, or sound recording,
80 | Adapted Material is always produced where the Licensed Material is
81 | synched in timed relation with a moving image.
82 |
83 | b. Adapter's License means the license You apply to Your Copyright
84 | and Similar Rights in Your contributions to Adapted Material in
85 | accordance with the terms and conditions of this Public License.
86 |
87 | c. Copyright and Similar Rights means copyright and/or similar rights
88 | closely related to copyright including, without limitation,
89 | performance, broadcast, sound recording, and Sui Generis Database
90 | Rights, without regard to how the rights are labeled or
91 | categorized. For purposes of this Public License, the rights
92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar
93 | Rights.
94 | d. Effective Technological Measures means those measures that, in the
95 | absence of proper authority, may not be circumvented under laws
96 | fulfilling obligations under Article 11 of the WIPO Copyright
97 | Treaty adopted on December 20, 1996, and/or similar international
98 | agreements.
99 |
100 | e. Exceptions and Limitations means fair use, fair dealing, and/or
101 | any other exception or limitation to Copyright and Similar Rights
102 | that applies to Your use of the Licensed Material.
103 |
104 | f. Licensed Material means the artistic or literary work, database,
105 | or other material to which the Licensor applied this Public
106 | License.
107 |
108 | g. Licensed Rights means the rights granted to You subject to the
109 | terms and conditions of this Public License, which are limited to
110 | all Copyright and Similar Rights that apply to Your use of the
111 | Licensed Material and that the Licensor has authority to license.
112 |
113 | h. Licensor means the individual(s) or entity(ies) granting rights
114 | under this Public License.
115 |
116 | i. NonCommercial means not primarily intended for or directed towards
117 | commercial advantage or monetary compensation. For purposes of
118 | this Public License, the exchange of the Licensed Material for
119 | other material subject to Copyright and Similar Rights by digital
120 | file-sharing or similar means is NonCommercial provided there is
121 | no payment of monetary compensation in connection with the
122 | exchange.
123 |
124 | j. Share means to provide material to the public by any means or
125 | process that requires permission under the Licensed Rights, such
126 | as reproduction, public display, public performance, distribution,
127 | dissemination, communication, or importation, and to make material
128 | available to the public including in ways that members of the
129 | public may access the material from a place and at a time
130 | individually chosen by them.
131 |
132 | k. Sui Generis Database Rights means rights other than copyright
133 | resulting from Directive 96/9/EC of the European Parliament and of
134 | the Council of 11 March 1996 on the legal protection of databases,
135 | as amended and/or succeeded, as well as other essentially
136 | equivalent rights anywhere in the world.
137 |
138 | l. You means the individual or entity exercising the Licensed Rights
139 | under this Public License. Your has a corresponding meaning.
140 |
141 |
142 | Section 2 -- Scope.
143 |
144 | a. License grant.
145 |
146 | 1. Subject to the terms and conditions of this Public License,
147 | the Licensor hereby grants You a worldwide, royalty-free,
148 | non-sublicensable, non-exclusive, irrevocable license to
149 | exercise the Licensed Rights in the Licensed Material to:
150 |
151 | a. reproduce and Share the Licensed Material, in whole or
152 | in part, for NonCommercial purposes only; and
153 |
154 | b. produce, reproduce, and Share Adapted Material for
155 | NonCommercial purposes only.
156 |
157 | 2. Exceptions and Limitations. For the avoidance of doubt, where
158 | Exceptions and Limitations apply to Your use, this Public
159 | License does not apply, and You do not need to comply with
160 | its terms and conditions.
161 |
162 | 3. Term. The term of this Public License is specified in Section
163 | 6(a).
164 |
165 | 4. Media and formats; technical modifications allowed. The
166 | Licensor authorizes You to exercise the Licensed Rights in
167 | all media and formats whether now known or hereafter created,
168 | and to make technical modifications necessary to do so. The
169 | Licensor waives and/or agrees not to assert any right or
170 | authority to forbid You from making technical modifications
171 | necessary to exercise the Licensed Rights, including
172 | technical modifications necessary to circumvent Effective
173 | Technological Measures. For purposes of this Public License,
174 | simply making modifications authorized by this Section 2(a)
175 | (4) never produces Adapted Material.
176 |
177 | 5. Downstream recipients.
178 |
179 | a. Offer from the Licensor -- Licensed Material. Every
180 | recipient of the Licensed Material automatically
181 | receives an offer from the Licensor to exercise the
182 | Licensed Rights under the terms and conditions of this
183 | Public License.
184 |
185 | b. No downstream restrictions. You may not offer or impose
186 | any additional or different terms or conditions on, or
187 | apply any Effective Technological Measures to, the
188 | Licensed Material if doing so restricts exercise of the
189 | Licensed Rights by any recipient of the Licensed
190 | Material.
191 |
192 | 6. No endorsement. Nothing in this Public License constitutes or
193 | may be construed as permission to assert or imply that You
194 | are, or that Your use of the Licensed Material is, connected
195 | with, or sponsored, endorsed, or granted official status by,
196 | the Licensor or others designated to receive attribution as
197 | provided in Section 3(a)(1)(A)(i).
198 |
199 | b. Other rights.
200 |
201 | 1. Moral rights, such as the right of integrity, are not
202 | licensed under this Public License, nor are publicity,
203 | privacy, and/or other similar personality rights; however, to
204 | the extent possible, the Licensor waives and/or agrees not to
205 | assert any such rights held by the Licensor to the limited
206 | extent necessary to allow You to exercise the Licensed
207 | Rights, but not otherwise.
208 |
209 | 2. Patent and trademark rights are not licensed under this
210 | Public License.
211 |
212 | 3. To the extent possible, the Licensor waives any right to
213 | collect royalties from You for the exercise of the Licensed
214 | Rights, whether directly or through a collecting society
215 | under any voluntary or waivable statutory or compulsory
216 | licensing scheme. In all other cases the Licensor expressly
217 | reserves any right to collect such royalties, including when
218 | the Licensed Material is used other than for NonCommercial
219 | purposes.
220 |
221 |
222 | Section 3 -- License Conditions.
223 |
224 | Your exercise of the Licensed Rights is expressly made subject to the
225 | following conditions.
226 |
227 | a. Attribution.
228 |
229 | 1. If You Share the Licensed Material (including in modified
230 | form), You must:
231 |
232 | a. retain the following if it is supplied by the Licensor
233 | with the Licensed Material:
234 |
235 | i. identification of the creator(s) of the Licensed
236 | Material and any others designated to receive
237 | attribution, in any reasonable manner requested by
238 | the Licensor (including by pseudonym if
239 | designated);
240 |
241 | ii. a copyright notice;
242 |
243 | iii. a notice that refers to this Public License;
244 |
245 | iv. a notice that refers to the disclaimer of
246 | warranties;
247 |
248 | v. a URI or hyperlink to the Licensed Material to the
249 | extent reasonably practicable;
250 |
251 | b. indicate if You modified the Licensed Material and
252 | retain an indication of any previous modifications; and
253 |
254 | c. indicate the Licensed Material is licensed under this
255 | Public License, and include the text of, or the URI or
256 | hyperlink to, this Public License.
257 |
258 | 2. You may satisfy the conditions in Section 3(a)(1) in any
259 | reasonable manner based on the medium, means, and context in
260 | which You Share the Licensed Material. For example, it may be
261 | reasonable to satisfy the conditions by providing a URI or
262 | hyperlink to a resource that includes the required
263 | information.
264 |
265 | 3. If requested by the Licensor, You must remove any of the
266 | information required by Section 3(a)(1)(A) to the extent
267 | reasonably practicable.
268 |
269 | 4. If You Share Adapted Material You produce, the Adapter's
270 | License You apply must not prevent recipients of the Adapted
271 | Material from complying with this Public License.
272 |
273 |
274 | Section 4 -- Sui Generis Database Rights.
275 |
276 | Where the Licensed Rights include Sui Generis Database Rights that
277 | apply to Your use of the Licensed Material:
278 |
279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280 | to extract, reuse, reproduce, and Share all or a substantial
281 | portion of the contents of the database for NonCommercial purposes
282 | only;
283 |
284 | b. if You include all or a substantial portion of the database
285 | contents in a database in which You have Sui Generis Database
286 | Rights, then the database in which You have Sui Generis Database
287 | Rights (but not its individual contents) is Adapted Material; and
288 |
289 | c. You must comply with the conditions in Section 3(a) if You Share
290 | all or a substantial portion of the contents of the database.
291 |
292 | For the avoidance of doubt, this Section 4 supplements and does not
293 | replace Your obligations under this Public License where the Licensed
294 | Rights include other Copyright and Similar Rights.
295 |
296 |
297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298 |
299 | a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309 |
310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319 |
320 | c. The disclaimer of warranties and limitation of liability provided
321 | above shall be interpreted in a manner that, to the extent
322 | possible, most closely approximates an absolute disclaimer and
323 | waiver of all liability.
324 |
325 |
326 | Section 6 -- Term and Termination.
327 |
328 | a. This Public License applies for the term of the Copyright and
329 | Similar Rights licensed here. However, if You fail to comply with
330 | this Public License, then Your rights under this Public License
331 | terminate automatically.
332 |
333 | b. Where Your right to use the Licensed Material has terminated under
334 | Section 6(a), it reinstates:
335 |
336 | 1. automatically as of the date the violation is cured, provided
337 | it is cured within 30 days of Your discovery of the
338 | violation; or
339 |
340 | 2. upon express reinstatement by the Licensor.
341 |
342 | For the avoidance of doubt, this Section 6(b) does not affect any
343 | right the Licensor may have to seek remedies for Your violations
344 | of this Public License.
345 |
346 | c. For the avoidance of doubt, the Licensor may also offer the
347 | Licensed Material under separate terms or conditions or stop
348 | distributing the Licensed Material at any time; however, doing so
349 | will not terminate this Public License.
350 |
351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352 | License.
353 |
354 |
355 | Section 7 -- Other Terms and Conditions.
356 |
357 | a. The Licensor shall not be bound by any additional or different
358 | terms or conditions communicated by You unless expressly agreed.
359 |
360 | b. Any arrangements, understandings, or agreements regarding the
361 | Licensed Material not stated herein are separate from and
362 | independent of the terms and conditions of this Public License.
363 |
364 |
365 | Section 8 -- Interpretation.
366 |
367 | a. For the avoidance of doubt, this Public License does not, and
368 | shall not be interpreted to, reduce, limit, restrict, or impose
369 | conditions on any use of the Licensed Material that could lawfully
370 | be made without permission under this Public License.
371 |
372 | b. To the extent possible, if any provision of this Public License is
373 | deemed unenforceable, it shall be automatically reformed to the
374 | minimum extent necessary to make it enforceable. If the provision
375 | cannot be reformed, it shall be severed from this Public License
376 | without affecting the enforceability of the remaining terms and
377 | conditions.
378 |
379 | c. No term or condition of this Public License will be waived and no
380 | failure to comply consented to unless expressly agreed to by the
381 | Licensor.
382 |
383 | d. Nothing in this Public License constitutes or may be interpreted
384 | as a limitation upon, or waiver of, any privileges and immunities
385 | that apply to the Licensor or You, including from the legal
386 | processes of any jurisdiction or authority.
387 |
388 | =======================================================================
389 |
390 | Creative Commons is not a party to its public
391 | licenses. Notwithstanding, Creative Commons may elect to apply one of
392 | its public licenses to material it publishes and in those instances
393 | will be considered the “Licensor.” The text of the Creative Commons
394 | public licenses is dedicated to the public domain under the CC0 Public
395 | Domain Dedication. Except for the limited purpose of indicating that
396 | material is shared under a Creative Commons public license or as
397 | otherwise permitted by the Creative Commons policies published at
398 | creativecommons.org/policies, Creative Commons does not authorize the
399 | use of the trademark "Creative Commons" or any other trademark or logo
400 | of Creative Commons without its prior written consent including,
401 | without limitation, in connection with any unauthorized modifications
402 | to any of its public licenses or any other arrangements,
403 | understandings, or agreements concerning use of licensed material. For
404 | the avoidance of doubt, this paragraph does not form part of the
405 | public licenses.
406 |
407 | Creative Commons may be contacted at creativecommons.org.
408 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adversarial Preference Optimization
2 |
3 | [](https://github.com/Linear95/APO/blob/main/LICENSE)
4 | [](https://github.com/Linear95/APO/blob/main/DATA_LICENSE)
5 | [](https://www.python.org/downloads/release/python-380/)
6 |
7 | This repo contains the implementation of the ACL 2024 paper:
8 | - [Adversarial Preference Optimization: Enhancing Your Alignment via RM-LLM Game](https://arxiv.org/abs/2311.08045).
9 |
10 | In Adversarial Preference Optimization (APO), we let the reward model (RM) and LLM agent play a min-max game, through which both models can be further enhanced without additional preference annotation.
11 |
12 |
13 |
14 |
15 |
16 | For an overview, the repo contains:
17 | - [Split Helpful\&Harmless](https://drive.google.com/drive/folders/1v0xNMMOfL9lfFLzTGCerZCPNPJrR9ZLX?usp=sharing) (HH) dataset
18 | - [GPT-4 responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) as golden annotations on the HH-RM training set
19 | - The base RM, testing RM, and APO RM training \& scoring pipelines
20 | - The LLM response generation [pipeline](https://github.com/Linear95/APO/blob/main/tools/llm_response_gen.sh)
21 |
22 |
23 | ## Environment
24 | We use Python 3.8 with the dependencies listed in `requirements.txt`. To set up the environment, run:
25 | ```bash
26 | pip3 install -r requirements.txt
27 | ```
28 |
29 | ## Data \& Annotation
30 |
31 | To update the RM and the LLM separately, we split the cleaned [Helpful\&Harmless](https://github.com/Linear95/DSP/tree/main/data) (HH) dataset into an RM training set and an LLM training set.
32 | | Data Type| HH-RM Train Set | HH-LLM Train Set| HH Test Set|
33 | | --------:| :----------|:-------| :--------|
34 | | Preference Pairs | [RM training set](https://drive.google.com/file/d/12DefElb3DazIPeaIEwd0B_9La84Slc7f/view?usp=sharing) | [RM validation set](https://drive.google.com/file/d/1ZqTuupFxrK2m3_E6ezMRcdT_4k6zX-IW/view?usp=sharing) (sampled 10K pairs) | [RM testing set](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=sharing)|
35 | | Golden Answers | [APO positive responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) | | |
36 | | LLM Samples | APO negative responses ([`alpaca_rm_samples`](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing)) | LLM alignment samples ([`alpaca_llm_samples`](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing))| [LLM testing Queries](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=drive_link)|
37 |
38 |
39 | On both the HH-RM and HH-LLM training sets, we generate four LLM responses for each query, released as [`alpaca_rm_samples`](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing) and [`alpaca_llm_samples`](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing), respectively. `alpaca_rm_samples` is combined with the golden responses on the HH-RM set to form the APO RM training pairs. `alpaca_llm_samples` is further scored by RMs and used for LLM alignment. To generate the LLM responses yourself, run:
40 | ```bash
41 | bash tools/llm_response_gen.sh
42 | ```
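The script drives the repo's generation tool; as a rough illustration of the sampling step only (not the actual `tools/inference_llm.py`, and with a hypothetical checkpoint path and prompt format), drawing four responses per query with `transformers` looks like:

```python
# Illustrative only: sample four responses per query with nucleus sampling.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/alpaca")  # hypothetical path
model = AutoModelForCausalLM.from_pretrained(
    "path/to/alpaca", torch_dtype=torch.float16, device_map="auto"
)

query = "Human: How can I improve my sleep schedule?\nAssistant:"
inputs = tokenizer(query, return_tensors="pt").to(model.device)
outputs = model.generate(
    **inputs,
    do_sample=True,
    top_p=0.9,
    max_new_tokens=256,
    num_return_sequences=4,  # four samples per query, matching the released data
)
prompt_len = inputs["input_ids"].shape[1]
responses = [tokenizer.decode(seq[prompt_len:], skip_special_tokens=True) for seq in outputs]
```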
43 |
44 |
45 |
46 | ## RM Training
47 |
48 | ### Base RM Training
49 |
50 | We build our RM on the pretrained LLaMA-7B ([`decapoda-research/llama-7b-hf`](https://huggingface.co/decapoda-research/llama-7b-hf)). To train the base RM for rejection sampling, use the following command:
51 |
52 | ```bash
53 | REPO_DIR=
54 | DATA_DIR=${REPO_DIR}/data/hh-split
55 | TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json"
56 | TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json\
57 | ${DATA_DIR}/eval_data/hh_split_llm.valid.json"
58 |
59 | NUM_GPUS=8
60 | BATCH_SIZE=64
61 | MICRO_BATCH_SIZE=1
62 | LEARNING_RATE=1e-6
63 | GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE))
64 |
65 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \
66 | --task_type hh_split \
67 | --do_train True \
68 | --eval_at_start False \
69 | --model_type reward \
70 | --model_name_or_path "decapoda-research/llama-7b-hf" \
71 | --data_type "comparison_pair" \
72 | --train_data_path ${TRAIN_DATA_LIST} \
73 | --eval_data_path ${TEST_DATA_LIST} \
74 | --rm_calibration True \
75 | --data_suffix rm_base \
76 | --add_sep_token True \
77 | --remove_unused_columns false \
78 | --output_dir \
79 | --num_train_epochs 1 \
80 | --per_device_train_batch_size ${MICRO_BATCH_SIZE} \
81 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \
82 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \
83 | --evaluation_strategy steps \
84 | --padding_side right \
85 | --truncation_side left \
86 | --pooling_type last \
87 | --max_length 512 \
88 | --save_strategy steps \
89 | --learning_rate ${LEARNING_RATE} \
90 | --warmup_steps 100 \
91 | --deepspeed configs/default_offload_opt_param.json \
92 | --tf32 false --fp16 false
93 | ```
94 |
95 | We also train a testing RM to automatically evaluate LLM response quality on the testing queries. To train it, set `TRAIN_DATA_LIST=${DATA_DIR}/hh_cleaned_origin.train.json` in the above command so that the RM learns from all the HH training comparisons.
96 |
97 | The RM training data files (values in `TRAIN_DATA_LIST`) are lists of dictionaries, where each dictionary is one RM training item (`--data_type="comparison_pair"`) with the following keys:
98 | - `text`: a list of query-response texts, each joined with a special separator token (`SEP_TOKEN` in `utils.py`).
99 | - `scores`: a list of floats giving the preference score of each corresponding query-response text.
100 | - `query_id`: a unique ID for the RM training item.
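For instance, one comparison item could look like this (illustrative values only; the literal separator string is whatever `SEP_TOKEN` resolves to in `utils.py`):

```python
# A hypothetical "comparison_pair" item with made-up values.
SEP_TOKEN = "<sep>"  # assumption: the real value is defined in utils.py

item = {
    "text": [
        "Human: How do I stay focused?" + SEP_TOKEN + "Assistant: Try time-boxing tasks and removing distractions.",
        "Human: How do I stay focused?" + SEP_TOKEN + "Assistant: Just focus harder.",
    ],
    "scores": [1.0, 0.0],  # higher score = more preferred response
    "query_id": "hh_rm_000001",
}
```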
101 |
102 |
103 |
104 | ### APO RM Training
105 |
106 | To train the APO RM, first merge LLM samples and golden annotations into APO comparison pairs:
107 | ```bash
108 | REPO_DIR=
109 | DATA_DIR="${REPO_DIR}/data/hh-split"
110 |
111 | python3 ${REPO_DIR}/tools/apo_data_converter.py \
112 | --golden_data_path ${DATA_DIR}/rm_data/hh_split_rm.golden.json \
113 | --sample_data_path ${DATA_DIR}/rm_data/hh_split_rm_alpaca_v0.sample.json \
114 | --output_dir ${DATA_DIR}/apo_data \
115 | --apo_data_name "rm_apo_data_v0"
116 | ```
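The converter pairs the golden (GPT-4) response with the LLM samples for each query, yielding `comparison_pair` items in which the golden response receives the highest score. A rough sketch of that merge under assumed field names (the actual logic lives in `tools/apo_data_converter.py`):

```python
# Sketch of merging golden answers with LLM samples into APO comparison pairs.
# Field names ("query", "golden_response", "responses") are assumptions.
def build_apo_pairs(golden_items, sample_items, sep_token="<sep>"):
    samples_by_id = {s["query_id"]: s for s in sample_items}
    apo_items = []
    for g in golden_items:
        s = samples_by_id.get(g["query_id"])
        if s is None:
            continue  # no LLM samples for this query
        texts = [g["query"] + sep_token + g["golden_response"]]
        texts += [g["query"] + sep_token + r for r in s["responses"]]
        apo_items.append({
            "text": texts,
            "scores": [1.0] + [0.0] * len(s["responses"]),  # golden ranked above all samples
            "query_id": g["query_id"],
        })
    return apo_items
```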
117 |
118 | Then run APO RM fine-tuning with the following command:
119 | ```bash
120 | REPO_DIR=
121 | DATA_DIR=${REPO_DIR}/data/hh-split
122 | TRAIN_DATA_LIST="${DATA_DIR}/rm_data/hh_split_rm.train.json \
123 | ${DATA_DIR}/apo_data/rm_apo_data_v0_text_scores.json"
124 | NUM_APO_SAMPLES=4
125 |
126 | TEST_DATA_LIST="${DATA_DIR}/eval_data/hh_cleaned_origin.test.json \
127 | ${DATA_DIR}/eval_data/hh_split_llm.valid.json"
128 |
129 | NUM_GPUS=8
130 | BATCH_SIZE=64
131 | MICRO_BATCH_SIZE=1
132 | LEARNING_RATE=1e-6
133 | APO_COEFF=0.1
134 | GRADIENT_ACCUMULATION_STEP=$((BATCH_SIZE / NUM_GPUS / MICRO_BATCH_SIZE))
135 |
136 |
137 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \
138 | --task_type apo \
139 | --do_train True \
140 | --eval_at_start False \
141 | --model_type reward \
142 | --model_name_or_path "decapoda-research/llama-7b-hf" \
143 | --data_type "comparison_pair" \
144 | --train_data_path ${TRAIN_DATA_LIST} \
145 | --eval_data_path ${TEST_DATA_LIST} \
146 | --rm_calibration True \
147 | --data_suffix rm_apo_v1 \
148 | --add_sep_token True \
149 | --remove_unused_columns false \
150 | --output_dir \
151 | --num_train_epochs 1 \
152 | --apo_loss_coeff ${APO_COEFF} \
153 | --apo_sample_num ${NUM_APO_SAMPLES} \
154 | --per_device_train_batch_size ${MICRO_BATCH_SIZE} \
155 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \
156 | --gradient_accumulation_steps ${GRADIENT_ACCUMULATION_STEP} \
157 | --evaluation_strategy steps \
158 | --padding_side right \
159 | --truncation_side left \
160 | --pooling_type last \
161 | --max_length 512 \
162 | --save_strategy steps \
163 | --save_total_limit 10 \
164 | --learning_rate ${LEARNING_RATE} \
165 | --warmup_steps 100 \
166 | --deepspeed configs/default_offload_opt_param.json \
167 | --tf32 false --fp16 false
168 | ```
169 | ## RM Scoring
170 |
171 | After RM training finishes, we can use the following command to score new LLM samples:
172 | ```bash
173 | REPO_DIR=
174 | DATA_DIR=${REPO_DIR}/data/hh-split/llm_data
175 | DATA_PATH="${DATA_DIR}/hh_split_llm_alpaca_v0.sample.json"
176 |
177 | MODEL_PATH=
178 | MODEL_NAME="base_rm" # or "apo_rm"
179 |
180 | NUM_GPUS=8
181 | MICRO_BATCH_SIZE=16
182 |
183 | torchrun --nproc_per_node=${NUM_GPUS} --master_port=6000 ${REPO_DIR}/train.py \
184 | --task_type inference \
185 | --do_train False \
186 | --eval_at_start True \
187 | --model_type reward \
188 | --model_name_or_path ${MODEL_PATH} \
189 | --data_type "reject_sample" \
190 | --eval_data_path ${DATA_PATH} \
191 | --rm_calibration False \
192 | --data_suffix ${MODEL_NAME} \
193 | --add_sep_token True \
194 | --remove_unused_columns false \
195 | --output_dir \
196 | --per_device_eval_batch_size ${MICRO_BATCH_SIZE} \
197 | --evaluation_strategy steps \
198 | --padding_side right \
199 | --truncation_side left \
200 | --pooling_type last \
201 | --max_length 512 \
202 | --deepspeed configs/default_offload_opt_param.json \
203 | --tf32 false --fp16 false
204 |
205 |
206 | # rejection sampling
207 | SCORE_PATH=${DATA_PATH}_pred_${MODEL_NAME}_results.json
208 | OUTPUT_FILE_NAME=${DATA_PATH}_rjs_${MODEL_NAME}.json
209 |
210 | python3 ${REPO_DIR}/tools/rejection_sampling.py \
211 | --data_path ${DATA_DIR} \
212 | --score_path ${SCORE_PATH} \
213 | --output_dir ${DATA_DIR} \
214 | --rm_scorer ${MODEL_NAME} \
215 | --output_file_name ${OUTPUT_FILE_NAME}
216 |
217 | # remove tmp inference files
218 | rm ${DATA_DIR}/*rank*.jsonl
219 | ```
220 | After the inference process, we obtain an RM scoring file `${DATA_PATH}_rjs_${MODEL_NAME}.json`. We can then update the Alpaca model with the training pipeline [here](https://github.com/tatsu-lab/stanford_alpaca).
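For reference, the rejection-sampling step above keeps, for each query, the response the RM scores highest; a minimal sketch with assumed field names (the actual implementation is `tools/rejection_sampling.py`):

```python
# Minimal best-of-n selection over RM scores; field names are assumptions.
import json

def rejection_sample(score_path, output_path):
    with open(score_path) as f:
        scored = json.load(f)  # assumed: [{"query_id": ..., "text": ..., "score": ...}, ...]

    best = {}
    for item in scored:
        qid = item["query_id"]
        if qid not in best or item["score"] > best[qid]["score"]:
            best[qid] = item  # keep the highest-scoring response per query

    with open(output_path, "w") as f:
        json.dump(list(best.values()), f, indent=2)
```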
221 |
222 |
223 | ## Citation
224 | ```
225 | @inproceedings{cheng2024adversarial,
226 | title={Adversarial Preference Optimization: Enhancing Your Alignment via RM-LLM Game},
227 | author={Cheng, Pengyu and Yang, Yifan and Li, Jian and Dai, Yong and Hu, Tianhao and Cao, Peixin and Du, Nan and Li, Xiaolong},
228 | booktitle={Findings of the Association for Computational Linguistics},
229 | year={2024}
230 | }
231 | ```
232 |
--------------------------------------------------------------------------------
/arguments.py:
--------------------------------------------------------------------------------
1 | from typing import List, Optional, Tuple, Union
2 |
3 | from dataclasses import dataclass, field
4 | from transformers import TrainingArguments
5 |
6 | @dataclass
7 | class CustomTrainingArguments(TrainingArguments):
8 | # experiment setups
9 | reward_domain: str = field(
10 | default="normal",
11 | metadata={"help": "the domain for reward model training."}
12 | )
13 | # tokenizer params
14 | padding_side: str = field(
15 | default="right",
16 | metadata={"help": "the direction for tokenizer to add padding tokens."}
17 | )
18 |
19 | truncation_side: str = field(
20 | default="left",
21 | metadata={"help": "the direction for the tokenizer to truncate overlong sequences."}
22 | )
23 |
24 | add_sep_token: bool = field(
25 | default=False,
26 | metadata={"help": "whether to add a SEP token between query and response."}
27 | )
28 |
29 | tokenizer_path: str = field(
30 | default="llama-7b-hf",
31 | metadata={"help": "the path to load pretrained tokenizer."}
32 | )
33 |
34 |
35 | # model params
36 | model_type: str = field(
37 | default="llama",
38 | metadata={"help": "the base model type for reward model, selected from [llama, bert]."}
39 | )
40 |
41 | model_prefix: str = field(
42 | default="llama",
43 | metadata={"help": "the base model type for reward model, selected from [llama, bert]."}
44 | )
45 |
46 |
47 | pooling_type: str = field(
48 | default="average",
49 | metadata={"help": "the pooling method for reward model, selected from [average, max, last]."}
50 | )
51 |
52 | model_name_or_path: str = field(
53 | default="llama-7b-hf",
54 | metadata={"help": "the path to load pretrained model."}
55 | )
56 |
57 |
58 | # data params
59 |
60 | apo_sample_num: int = field(
61 | default=1,
62 | metadata={"help": "the maximum response number of each data item"}
63 | )
64 |
65 |
66 | data_dir: str = field(
67 | default="path/to/cleaned_data",
68 | metadata={"help": "the directory to load data."}
69 | )
70 |
71 | data_type: str = field(
72 | default="no_type",
73 | metadata={"help": "the type of data."}
74 | )
75 | data_path: str = field(
76 | default="yahma/alpaca-cleaned",
77 | metadata={"help": "the path to load data."}
78 | )
79 |
80 | train_data_path: List[str] = field(
81 | default_factory=lambda: ["/data/to/train/dataset"],
82 | metadata={"help": "train datasets paths."}
83 | )
84 |
85 |
86 | eval_data_path: List[str] = field(
87 | default_factory=lambda: ["/data/to/eval/dataset"],
88 | metadata={"help": "evaluation datasets paths."}
89 | )
90 |
91 |
92 | data_prefix: str = field(
93 | default="yahma/alpaca-cleaned",
94 | metadata={"help": "the prefix to load train and test data."}
95 | )
96 |
97 | data_suffix: str = field(
98 | default="yahma/alpaca-cleaned",
99 | metadata={"help": "the suffix to save inference data."}
100 | )
101 |
102 |
103 | format_mode: str = field(
104 | default="lab_mode",
105 | metadata={"help": "the format to process data"}
106 | )
107 |
108 |
109 | # training hyperparams
110 | task_type: str = field(
111 | default="training",
112 | metadata={"help": "the task type"}
113 | )
114 |
115 |
116 | eval_at_start: bool = field(
117 | default=False,
118 | metadata={"help": "whether make eval at start."}
119 | )
120 |
121 | debug_mode: bool = field(
122 | default=False,
123 | metadata={"help": "whether use the debug mode."}
124 | )
125 |
126 | cache_dir: Optional[str] = field(default=None)
127 |
128 | optim: str = field(default="adamw_torch", metadata={"help": "the optimizer to use."})
129 |
130 | apo_loss_type: str = field(default="ranking", metadata={"help": "use `ranking` or `diff` loss for apo"})
131 |
132 | apo_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for apo loss."})
133 |
134 | lm_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for language modeling loss."})
135 |
136 | rm_kl_coeff: float = field(default=1., metadata={"help": "the coefficient for apo rm kl regularizer."})
137 |
138 | contrast_loss_coeff: float = field(default=0., metadata={"help": "the coefficient for contrastive learning loss."})
139 |
140 | lm_score_thresh: float = field(default=0.85, metadata={"help": "the threshold to select response for language modeling"})
141 |
142 | max_length: int = field(
143 | default=256,
144 | metadata={"help": "the max sentence sequence length."}
145 | )
146 |
147 | batch_size: int = field(
148 | default=256,
149 | metadata={"help": "the overall training batch size"}
150 | )
151 |
152 | micro_batch_size: int = field(
153 | default=32,
154 | metadata={"help": "the batch size on each device, equivalent to `per_gpu_train_batch_size`"}
155 | )
156 |
157 |
158 | valid_data_size: int = field(
159 | default=0,
160 | metadata={"help": "the data size for validation data"}
161 | )
162 |
163 | resume_from_checkpoint: Optional[str] = field(
164 | default=None,
165 | metadata={"help": "either training checkpoint or final adapter"}
166 | )
167 | # generation parameters:
168 | max_new_tokens: int = field(
169 | default=256,
170 | metadata={"help": "the max number of new tokens to generate."}
171 | )
172 |
173 | # evaluation parameters:
174 | rm_calibration: bool = field(
175 | default=False,
176 | metadata={"help": "whether evaluate the calibration score for RM"}
177 | )
178 |
179 | calibration_bins: List[int] = field(
180 | default_factory=lambda: [10],
181 | metadata={"help": "number of bins for RM calibration"}
182 | )
183 |
184 |
185 | save_calibration: bool = field(
186 | default=False,
187 | metadata={"help": "whether save the calibration results for RM"}
188 | )
189 |
190 |
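Since `CustomTrainingArguments` subclasses `TrainingArguments`, the flags in the README commands bind directly to these fields through `HfArgumentParser`. A minimal parsing sketch (the actual entry point is `train.py`):

```python
# Sketch: how command-line flags map onto CustomTrainingArguments.
# Run with at least --output_dir, which TrainingArguments requires.
from transformers import HfArgumentParser
from arguments import CustomTrainingArguments

parser = HfArgumentParser(CustomTrainingArguments)
(args,) = parser.parse_args_into_dataclasses()
print(args.task_type, args.pooling_type, args.apo_loss_coeff)
```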
--------------------------------------------------------------------------------
/configs/default_offload_opt_param.json:
--------------------------------------------------------------------------------
1 | {
2 | "bf16": {
3 | "enabled": "auto"
4 | },
5 | "optimizer": {
6 | "type": "AdamW",
7 | "params": {
8 | "lr": "auto",
9 | "betas": "auto",
10 | "eps": "auto",
11 | "weight_decay": "auto"
12 | }
13 | },
14 | "scheduler": {
15 | "type": "WarmupDecayLR",
16 | "params": {
17 | "total_num_steps": "auto",
18 | "warmup_min_lr": "auto",
19 | "warmup_max_lr": "auto",
20 | "warmup_num_steps": "auto"
21 | }
22 | },
23 | "zero_optimization": {
24 | "stage": 3,
25 | "offload_optimizer": {
26 | "device": "cpu",
27 | "pin_memory": true
28 | },
29 | "offload_param": {
30 | "device": "cpu",
31 | "pin_memory": true
32 | },
33 | "overlap_comm": true,
34 | "contiguous_gradients": true,
35 | "sub_group_size": 1e9,
36 | "reduce_bucket_size": "auto",
37 | "stage3_prefetch_bucket_size": "auto",
38 | "stage3_param_persistence_threshold": "auto",
39 | "stage3_max_live_parameters": 1e9,
40 | "stage3_max_reuse_distance": 1e9,
41 | "stage3_gather_16bit_weights_on_model_save": true
42 | },
43 | "gradient_accumulation_steps": "auto",
44 | "gradient_clipping": "auto",
45 | "steps_per_print": 5,
46 | "train_batch_size": "auto",
47 | "train_micro_batch_size_per_gpu": "auto",
48 | "wall_clock_breakdown": false
49 | }
50 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # APO Training & Evaluating Data
2 |
3 | We have moved the data to [Google Drive](https://drive.google.com/drive/folders/1v0xNMMOfL9lfFLzTGCerZCPNPJrR9ZLX) due to the GitHub LFS storage limitation.
4 |
5 | The data split is summarized below:
6 |
7 |
8 | | Data Type| HH-RM Train Set | HH-LLM Train Set| HH Test Set|
9 | | --------:| :----------|:-------| :--------|
10 | | Preference Pairs | [RM training set](https://drive.google.com/file/d/12DefElb3DazIPeaIEwd0B_9La84Slc7f/view?usp=sharing) | [RM validation set](https://drive.google.com/file/d/1ZqTuupFxrK2m3_E6ezMRcdT_4k6zX-IW/view?usp=sharing) | [RM testing set](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=sharing)|
11 | | Golden Answers | [APO positive responses](https://drive.google.com/file/d/1hDo6Sk8QX1c3kP_qJUgZ4J16kHAi0hEq/view?usp=sharing) | - | -|
12 | | LLM Samples | [APO negative responses](https://drive.google.com/file/d/1_wiKVKob6QVOHja4C_N-y5LlvHZE9ZiZ/view?usp=sharing) (Alpaca samples)| [LLM (Alpaca) rejection samples](https://drive.google.com/file/d/1ZpAXK0F-YC919_vP7gnyGpo8ezQGIv5O/view?usp=sharing)| [LLM testing queries](https://drive.google.com/file/d/1ite1KXZlGs1ojCVB20rLHlj7_3KlOULY/view?usp=drive_link)|
13 |
--------------------------------------------------------------------------------
/figures/apo_framework_v.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Linear95/APO/4282775cf9f7dcfe04ed014835bb9d07cae5fbae/figures/apo_framework_v.png
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import List, Optional, Tuple, Union
3 | from pprint import pprint
4 |
5 | import torch
6 | import torch.utils.checkpoint
7 | from torch import nn
8 |
9 | from transformers.modeling_outputs import SequenceClassifierOutputWithPast
10 | from transformers import LlamaModel, LlamaForCausalLM, LlamaPreTrainedModel, LlamaTokenizer
11 | from transformers import BertModel, BertPreTrainedModel
12 |
13 |
14 | class LlamaRewardModel(LlamaPreTrainedModel):
15 | def __init__(self, config):
16 | super().__init__(config)
17 | self.model = LlamaModel(config)
18 | self.reward_head = nn.Linear(config.hidden_size, 1, bias=False)
19 | self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
20 | self.post_init()
21 |
22 | def get_input_embeddings(self):
23 | return self.model.embed_tokens
24 |
25 | def set_input_embeddings(self, value):
26 | self.model.embed_tokens = value
27 |
28 | def floating_point_ops(self, inputs):
29 | return 0
30 |
31 | def forward(
32 | self,
33 | input_ids: torch.LongTensor = None,
34 | attention_mask: Optional[torch.Tensor] = None,
35 | position_ids: Optional[torch.LongTensor] = None,
36 | past_key_values: Optional[List[torch.FloatTensor]] = None,
37 | inputs_embeds: Optional[torch.FloatTensor] = None,
38 | labels: Optional[torch.LongTensor] = None,
39 | pooling_type: str = "average",
40 | padding_side: str = "right",
41 | use_cache: Optional[bool] = None,
42 | output_attentions: Optional[bool] = None,
43 | output_hidden_states: Optional[bool] = None,
44 | return_dict: Optional[bool] = None,
45 | ):
46 | r"""
47 | labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
48 | Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
49 | config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
50 | `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
51 | """
52 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
53 |
54 | transformer_outputs = self.model(
55 | input_ids,
56 | attention_mask=attention_mask,
57 | position_ids=position_ids,
58 | past_key_values=past_key_values,
59 | inputs_embeds=inputs_embeds,
60 | use_cache=use_cache,
61 | output_attentions=output_attentions,
62 | output_hidden_states=output_hidden_states,
63 | return_dict=return_dict,
64 | )
65 | hidden_states = transformer_outputs[0]
66 |
67 | lm_logits = self.lm_head(hidden_states)
68 |
69 |
70 | if input_ids is not None:
71 | batch_size = input_ids.shape[0]
72 | else:
73 | batch_size = inputs_embeds.shape[0]
74 |
75 | if self.config.pad_token_id is None and batch_size != 1:
76 | raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
77 | if self.config.pad_token_id is None:
78 | sequence_lengths = -1
79 | else:
80 | if input_ids is not None:
81 | sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1)).to(hidden_states.device)
82 | else:
83 | sequence_lengths = -1
84 |
85 | if attention_mask is None:
86 | attention_mask = torch.ne(input_ids, self.config.pad_token_id).float()
87 |
88 | # print("hidden_states shape {}".format(hidden_states.shape))
89 | # print("attention_mask shape {}".format(attention_mask.shape))
90 |
91 | attention_mask_ext = attention_mask.unsqueeze(-1)
92 | if pooling_type in ["last", "eos"]:
93 | offset = 1 if pooling_type == "eos" else 2
94 | if padding_side == "right":
95 | pooled_hidden_state = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_lengths - offset]
96 | else:
97 | pooled_hidden_state = hidden_states[torch.arange(batch_size, device=hidden_states.device), - offset]
98 |
99 | elif pooling_type == "average":
100 | pooled_hidden_state = (hidden_states * attention_mask_ext).sum(dim=1) / attention_mask_ext.sum(dim=1)
101 | elif pooling_type == "max":
102 | pooled_hidden_state = (hidden_states * attention_mask_ext).max(dim=1)[0]
103 | else:
104 |             raise ValueError("The pooling method {} is not implemented!".format(pooling_type))
105 |
106 | pooled_logits = self.reward_head(pooled_hidden_state)
107 |
108 | #pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
109 |
110 | return {
111 | "lm_logits": lm_logits,
112 | "rm_logits": pooled_logits,
113 | "hidden_states": transformer_outputs[0],
114 | "rm_embeddings": pooled_hidden_state
115 | }
116 |
117 |
118 |
--------------------------------------------------------------------------------
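Note: a minimal usage sketch for `LlamaRewardModel` (not part of the repo); the tiny randomly-initialized config and the token ids below are invented purely for illustration:

    import torch
    from transformers import LlamaConfig
    from model import LlamaRewardModel

    # a deliberately tiny config so the sketch runs on CPU in seconds
    config = LlamaConfig(vocab_size=128, hidden_size=32, intermediate_size=64,
                         num_hidden_layers=2, num_attention_heads=4, pad_token_id=3)
    model = LlamaRewardModel(config)

    input_ids = torch.tensor([[1, 5, 6, 7, 2, 3, 3]])  # right-padded with pad id 3
    out = model(input_ids=input_ids, pooling_type="average", padding_side="right")

    print(out["lm_logits"].shape)  # torch.Size([1, 7, 128]): per-token LM logits
    print(out["rm_logits"].shape)  # torch.Size([1, 1]): one scalar reward per sequence

With `pooling_type="average"`, the attention mask (derived from `pad_token_id` when not supplied) zeroes out padded positions before the mean, so the reward depends only on real tokens.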
/requirements.txt:
--------------------------------------------------------------------------------
1 | numpy
2 | scikit-learn
3 | fire
4 | transformers==4.28.1
5 | torch==2.0.0
6 | sentencepiece
7 | tokenizers>=0.13.3
8 | wandb
9 | datasets
10 | accelerate==0.20.3
11 | deepspeed==0.12.6
12 | pydantic==1.10.7
13 |
--------------------------------------------------------------------------------
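A quick environment sanity check (a convenience sketch, not from the repo) against the two pins that matter most for reproducing results, since later releases may change LlamaModel/Trainer behavior:

    import torch
    import transformers

    assert transformers.__version__ == "4.28.1", transformers.__version__
    assert torch.__version__.startswith("2.0.0"), torch.__version__
    print("environment matches requirements.txt")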
/reward_datasets.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from tqdm import tqdm
4 | import gzip
5 | import random
6 | from copy import deepcopy
7 |
8 | from utils import print_rank_0
9 | from pprint import pprint
10 | import numpy as np
11 |
12 | import torch
13 | from torch.utils.data import Dataset
14 |
15 | from transformers import LlamaTokenizer
16 |
17 | from datasets import load_dataset
18 | from utils import read_json_or_jsonl_data
19 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN
20 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP
21 |
22 |
23 | class TextRewardDataset(Dataset):
24 | def __init__(self, data):
25 | self.data = data
26 |
27 | def __getitem__(self, index):
28 | return self.data[index]
29 |
30 | def __len__(self,):
31 | return len(self.data)
32 |
33 |
34 | def reward_data_collactor(args, batch, tokenizer):
35 | input_ids, attention_mask = [], []
36 | query_ids, text, scores, apo_data_mask = [], [], [], []
37 |
38 | max_response_num = max([len(item['scores']) for item in batch])
39 | if args.debug_mode:
40 | print_rank_0(">>> response padding number: {}".format(max_response_num))
41 |
42 | for item1 in batch:
43 | item = prepare_data_item(args, item1,
44 | tokenizer=tokenizer,
45 | padding=(not len(batch) == 1),
46 | max_response_num=max_response_num)
47 |
48 | scores.append(item['scores'])
49 | input_ids.append(item['tokens']['input_ids'])
50 | attention_mask.append(item['tokens']['attention_mask'])
51 | text.append(item['text'])
52 |
53 | if item.get("type", "hh") == 'apo':
54 | apo_data_mask.append(1)
55 | # coeffs.append(args.apo_loss_coeff / args.apo_sample_num)
56 | else:
57 | apo_data_mask.append(0)
58 | # coeffs.append(args.rm_kl_coeff)
59 |
60 | if "query_ids" in item:
61 | query_ids.append(item['query_ids'])
62 |
63 | if len(query_ids) > 0:
64 |         assert len(query_ids) == len(scores), f"not all items in the batch have the key 'query_ids': {batch}"
65 |
66 |
67 | return {
68 | "scores": scores,
69 | "input_ids": input_ids,
70 | "attention_mask": attention_mask,
71 | "query_ids": query_ids,
72 | "text": text,
73 | "apo_data_mask": apo_data_mask
74 | # "coeffs": coeffs
75 | }
76 |
77 |
78 | def reward_tokenize(sentences, tokenizer, padding="longest", add_sep_token=False):
79 | if isinstance(sentences, str):
80 | sentences = [sentences]
81 |
82 | input_ids = []
83 | for sent in sentences:
84 | if add_sep_token:
85 | query, response = sent.split(SEP_TOKEN)
86 | query_ids = tokenizer.encode(query, add_special_tokens=False)
87 | response_ids = tokenizer.encode(response, add_special_tokens=False)
88 | input_ids.append(
89 | [tokenizer.bos_token_id] + query_ids + [tokenizer.sep_token_id] + response_ids + [tokenizer.eos_token_id]
90 | )
91 | else:
92 | if SEP_TOKEN in sent:
93 | query, response = sent.split(SEP_TOKEN)
94 | query_ids = tokenizer.encode(query, add_special_tokens=False)
95 | response_ids = tokenizer.encode(response, add_special_tokens=False)
96 | input_ids.append(
97 | [tokenizer.bos_token_id] + query_ids + response_ids + [tokenizer.eos_token_id]
98 | )
99 | else:
100 | input_ids.append(
101 | [tokenizer.bos_token_id] + tokenizer.encode(sent, add_special_tokens=False) + [tokenizer.eos_token_id]
102 | )
103 |
104 | return batch_padding(input_ids, tokenizer, padding=padding)
105 |
106 |
107 | def batch_padding(input_ids, tokenizer, padding='longest'):
108 | if padding == 'longest':
109 | max_input_length = max([len(inp_ids) for inp_ids in input_ids])
110 | max_length = min(tokenizer.model_max_length, max_input_length)
111 | else:
112 | max_length = tokenizer.model_max_length
113 |
114 | outputs = {"input_ids": [], "attention_mask": []}
115 | for inp_ids in input_ids:
116 | attn_mask = [1] * len(inp_ids)
117 | if len(inp_ids) >= max_length:
118 | if tokenizer.truncation_side == 'left':
119 | inp_ids = inp_ids[-max_length :]
120 | attn_mask = attn_mask[-max_length :]
121 | else:
122 | inp_ids = inp_ids[:max_length]
123 | attn_mask = attn_mask[:max_length]
124 | else:
125 | if tokenizer.padding_side == 'left':
126 | inp_ids = [tokenizer.pad_token_id] * (max_length - len(inp_ids)) + inp_ids
127 | attn_mask = [0] * (max_length - len(attn_mask)) + attn_mask
128 | else:
129 | inp_ids = inp_ids + [tokenizer.pad_token_id] * (max_length - len(inp_ids))
130 | attn_mask = attn_mask + [0] * (max_length - len(attn_mask))
131 |
132 | outputs['input_ids'].append(deepcopy(inp_ids))
133 | outputs['attention_mask'].append(deepcopy(attn_mask))
134 | return outputs
135 |
136 |
137 | def prepare_data_item(args, item, tokenizer=None, padding=False, max_response_num=1):
138 | new_item = deepcopy(item)
139 |     if not len(new_item['scores']) == len(new_item['text']):
140 |         # every candidate text needs a matching score
141 |         raise ValueError("invalid data point {}".format(new_item))
142 |
143 |
144 |     if "query_ids" in new_item and not len(new_item['scores']) == len(new_item['query_ids']):
145 |         # when query ids are present, one is needed per candidate
146 |         raise ValueError("invalid data point {}".format(new_item))
147 |
148 | # score_idx = np.argsort(new_item['scores'])
149 | max_score = max(new_item['scores']) + 1e-5
150 | min_score = min(new_item['scores']) - 1e-5
151 |     new_item['scores'] = [(score - min_score) / (max_score - min_score) for score in new_item['scores']]
152 |
153 | if padding:
154 |         new_item['text'] += ["\n\nHuman: ?\n\nAssistant: Some"] * (max_response_num - len(new_item['text']))  # dummy filler responses
155 |         new_item['scores'] += [-1.] * (max_response_num - len(new_item['scores']))  # score -1 marks padding; pad_mask in the loss filters it out
156 | if "query_ids" in new_item:
157 | new_item['query_ids'] += [ "unk" + STRING_SEP + "pad" + STRING_SEP + "unk"] * (max_response_num - len(new_item['query_ids']))
158 |
159 |
160 | if tokenizer is not None:
161 | try:
162 | new_item['tokens'] = reward_tokenize(
163 | sentences=new_item['text'],
164 | tokenizer=tokenizer,
165 | padding="max_length" if padding else "longest",
166 | add_sep_token=args.add_sep_token
167 | )
168 |         except Exception as e:
169 |             raise ValueError(f"get tokenization error with {new_item}") from e
170 |
171 | return new_item
172 |
173 |
174 |
175 | def load_rejection_samples(data_path):
176 | data_list = read_json_or_jsonl_data(data_path)
177 | outputs = []
178 | for item in data_list:
179 | # print_rank_0(item)
180 | if 'query' in item:
181 | query = str(item['query'])
182 | else:
183 | query = str(item['instruction'])
184 |
185 | query_id = str(item['query_id'])
186 |
187 | for key in item:
188 | #if "hh_best" in key or "gpt4" in key:
189 | if "sample_" in key or "gpt4" in key or 'ans_' in key:
190 | outputs.append({
191 | "text": [ query + SEP_TOKEN + str(item[key])],
192 | "query_ids": [ data_path + STRING_SEP + query_id + STRING_SEP + key],
193 | "scores": [-1]
194 | })
195 |     print(f">>> loaded {len(outputs)} rejection samples in total.")
196 | print(outputs[0])
197 | return outputs
198 |
199 |
200 | def load_text_score_dataset(args, data_path):
201 | print_rank_0("loading text-scores dataset from: \n {}".format(data_path))
202 |
203 | if args.data_type == "reject_sample":
204 | data_list = load_rejection_samples(data_path)
205 | else:
206 | data_list = read_json_or_jsonl_data(data_path)
207 | for item in data_list:
208 | item['query_ids'] = [os.path.split(data_path)[1]] * len(item['text'])
209 |
210 |
211 |
212 | print_rank_0("finished loading with {} data.".format(len(data_list)))
213 | return data_list
214 |
215 |
216 |
--------------------------------------------------------------------------------
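The padding/truncation contract of `batch_padding` is easiest to see on toy ids. A sketch assuming only that the tokenizer exposes the four attributes the function reads (a `SimpleNamespace` stands in for a real tokenizer):

    from types import SimpleNamespace
    from reward_datasets import batch_padding

    tok = SimpleNamespace(model_max_length=8, padding_side="right",
                          truncation_side="right", pad_token_id=0)

    out = batch_padding([[1, 5, 2], [1, 5, 6, 7, 2]], tok, padding="longest")
    print(out["input_ids"])       # [[1, 5, 2, 0, 0], [1, 5, 6, 7, 2]]
    print(out["attention_mask"])  # [[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]]

`padding="longest"` pads to the longest sequence in the batch (capped by `model_max_length`); any other value pads every sequence to `model_max_length`, which is what `prepare_data_item` requests when response-number padding is needed.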
/tools/apo_data_converter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import json
4 | import argparse
5 |
6 | from pprint import pprint
7 | from tqdm import tqdm
8 |
9 | def preprocess_response(response):
10 | while "\nHuman" in response:
11 | # remove the additional generation of LLM after the current turn responses.
12 | response = response.split("\nHuman")[0].strip()
13 |
14 | return response
15 |
16 |
17 | def convert_item(item, sampling=False):
18 | sample_names = ['sample_0', 'sample_1', 'sample_2', 'sample_3']
19 | if "\nHuman:" in item['golden']:
20 | print(item)
21 | gpt_response = preprocess_response(item['golden'])
22 |
23 | if sampling:
24 | sample_names = [random.choice(sample_names)]
25 |
26 | outputs = []
27 | for sample_name in sample_names:
28 | query = item['query']
29 | query_id = str(item['query_id'])
30 | res_response = preprocess_response(item[sample_name])
31 | data_point = {
32 |             "text": [query + '<sep>' + gpt_response, query + '<sep>' + res_response],
33 | "scores": [1., 0.],
34 | "type": "apo"
35 | }
36 | outputs.append(data_point)
37 |
38 | return outputs
39 |
40 |
41 | if __name__ == '__main__':
42 |
43 |     parser = argparse.ArgumentParser(description='parser for preference data processing.')
44 | parser.add_argument("--golden_data_path", type=str, default="", help="the path to golden annotation data.")
45 | parser.add_argument("--sample_data_path", type=str, default="", help="the path to llm sample data.")
46 | parser.add_argument("--output_dir", type=str, default="", help="the path to output converted data.")
47 |     parser.add_argument("--apo_data_name", type=str, default="", help="the file-name prefix for the converted data.")
48 |     parser.add_argument("--sampling", action="store_true", help="whether to randomly select one of the LLM samples for each query")
49 | args = parser.parse_args()
50 |
51 |
52 | with open(args.sample_data_path, 'r') as f:
53 | sft_samples = json.load(f)
54 |     print(f'finished loading {len(sft_samples)} samples')
55 |
56 | with open(args.golden_data_path, 'r') as f:
57 | golden_samples = json.load(f)
58 |
59 |     print(f'finished loading {len(golden_samples)} samples')
60 |
61 | merged_data = {}
62 |
63 | for item in tqdm(sft_samples):
64 | query_id = str(item['query_id'])
65 | merged_data[query_id] = item
66 |
67 | for item in tqdm(golden_samples):
68 | query_id = str(item['query_id'])
69 | merged_data[query_id]['golden'] = item['golden']
70 |
71 | score_dict = None
72 | outputs = []
73 | for query_id, item in merged_data.items():
74 | new_results = convert_item(item, sampling=args.sampling)
75 | outputs.extend(new_results)
76 | # except:
77 | # pprint(item1)
78 | # error_count += 1
79 |
80 | # print(f"get {error_count} error items")
81 |
82 | if not os.path.exists(args.output_dir):
83 | os.mkdir(args.output_dir)
84 |
85 | if args.sampling:
86 | output_path = f"{args.output_dir}/{args.apo_data_name}_sampled_text_scores.json"
87 | else:
88 | output_path = f"{args.output_dir}/{args.apo_data_name}_text_scores.json"
89 |
90 |     print(f'writing {len(outputs)} processed data points to {output_path}')
91 | with open(output_path, 'w') as f:
92 | json.dump(outputs, f, ensure_ascii=False, indent=2)
93 |
94 |
95 |
--------------------------------------------------------------------------------
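The converter joins the golden and sampled files on `query_id`. An illustrative record (all values invented) and the text-score pair `convert_item` emits from it, assuming the script is run from the repo root so `tools` is importable:

    from tools.apo_data_converter import convert_item

    item = {
        "query": "\n\nHuman: What is APO?\n\nAssistant: ",
        "query_id": "42",
        "golden": "APO stands for Adversarial Preference Optimization.",
        "sample_0": "I am not sure.",
        "sample_1": "It is a sport.",
        "sample_2": "No idea.",
        "sample_3": "Hard to say.",
    }

    pairs = convert_item(item, sampling=True)  # one golden-vs-sample pair
    print(pairs[0]["scores"])  # [1.0, 0.0]: the golden answer always outranks the sample
    print(pairs[0]["type"])    # 'apo', so trainer.py routes these pairs to the APO loss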
/tools/convert_apo_data.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | REPO_DIR=path/to/APO/repo
3 | DATA_DIR="${REPO_DIR}/data/hh-split"
4 |
5 | python3 ${REPO_DIR}/tools/apo_data_converter.py \
6 | --golden_data_path ${DATA_DIR}/rm_data/hh_split_rm.golden.json \
7 | --sample_data_path ${DATA_DIR}/rm_data/hh_split_rm_alpaca_v0.sample.json \
8 | --output_dir ${DATA_DIR}/apo_data \
9 | --apo_data_name "rm_apo_data_v0"
10 |
--------------------------------------------------------------------------------
/tools/inference_llm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 | from copy import deepcopy
4 | import json
5 | import glob
6 | from dataclasses import dataclass
7 | from typing import Dict, Sequence
8 | from tqdm import tqdm
9 |
10 |
11 | import torch
12 | import torch.distributed as dist
13 | from torch.nn.parallel import DistributedDataParallel as DDP
14 | from torch.utils.data import Dataset, DataLoader
15 |
16 | import transformers
17 | from transformers import GenerationConfig, AutoModelForCausalLM, AutoTokenizer, LlamaForCausalLM
18 | from datasets import load_dataset
19 | from arguments import CustomTrainingArguments
20 |
21 | from utils import print_rank_0, read_json_or_jsonl_data, SEP_TOKEN
22 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_UNK_TOKEN
23 |
24 | from reward_datasets import TextRewardDataset, batch_padding
25 |
26 | IGNORE_INDEX = -100
27 | PROMPT_DICT = {
28 | "prompt_input": (
29 | "\n\nHuman: {instruction}\n{input}\n\nAssistant: "
30 | ),
31 | "prompt_no_input": (
32 | "\n\nHuman: {instruction}\n\nAssistant: "
33 | ),
34 | }
35 |
36 | B_INST, E_INST = "[INST]", "[/INST]" # for llama-2-chat
37 |
38 | def load_query_data_for_generation(args, data_path):
39 | all_data = read_json_or_jsonl_data(data_path)
40 | outputs = []
41 | for idx, item in enumerate(all_data):
42 | if args.data_type == "comparison_pair":
43 | query = item['text'][0].split(SEP_TOKEN)[0]
44 | outputs.append({
45 | "query": query,
46 | "query_id": item.get("query_id", str(idx))
47 | })
48 | else:
49 | outputs.append({
50 | 'query': item['query'],
51 | "query_id": item.get("query_id", str(idx))
52 | })
53 | return TextRewardDataset(outputs)
54 |
55 |
56 | def query_data_collactor(args, batch, tokenizer):
57 | input_ids, attention_mask, labels = [], [], []
58 | text = [item['query'] for item in batch]
59 | query_ids = [item['query_id'] for item in batch]
60 |
61 | for sent in text:
62 | if args.model_prefix == "llama-2-chat":
63 | # check details at https://huggingface.co/meta-llama/Llama-2-7b-chat
64 | sent = sent.replace("\nAssistant", f" {E_INST} ").replace("\nHuman", f" {tokenizer.eos_token} {tokenizer.bos_token} {B_INST} ")
65 | sent = sent.strip().strip(tokenizer.eos_token)
66 | input_query_ids = tokenizer.encode(sent, add_special_tokens=False)
67 |
68 | else:
69 | input_query_ids = tokenizer.encode(sent)
70 |
71 | input_ids.append(input_query_ids)
72 |
73 | outputs = batch_padding(input_ids, tokenizer)
74 | outputs['query_ids'] = query_ids
75 | outputs['text'] = text
76 | return outputs
77 |
78 |
79 | def main():
80 | parser = transformers.HfArgumentParser(CustomTrainingArguments)
81 | args = parser.parse_args_into_dataclasses()[0]
82 |     dist.init_process_group(backend="nccl")  # DistributedSampler and dist.get_rank() below require an initialized process group
83 | # setup model
84 | #---------------------------------------------------------------------------------
85 |     device = int(os.environ.get("LOCAL_RANK", 0)); torch.cuda.set_device(device)  # bind each torchrun rank to its own GPU
86 | print_rank_0(f"start loading model from {args.model_name_or_path}")
87 | model = LlamaForCausalLM.from_pretrained(
88 | args.model_name_or_path,
89 | # torch_dtype=torch.float16,
90 | )
91 | print_rank_0(model)
92 |
93 | tokenizer = AutoTokenizer.from_pretrained(
94 | args.model_name_or_path,
95 | padding_side="left", # for batch decode
96 | truncation_side='left',
97 | model_max_length=args.max_length,
98 | )
99 |
100 | if tokenizer.pad_token is None:
101 | tokenizer.pad_token = tokenizer.eos_token
102 |         tokenizer.pad_token_id = 0  # reuse llama id 0 (<unk>) for padding, matching pad_token_id=0 in the generation configs below
103 | # tokenizer.pad_token = DEFAULT_PAD_TOKEN
104 | # smart_tokenizer_and_embedding_resize(
105 | # special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
106 | # tokenizer=tokenizer,
107 | # model=model,
108 | # )
109 |
110 | eval_dataset = load_query_data_for_generation(args, args.data_path)
111 |
112 | sampler = torch.utils.data.distributed.DistributedSampler(eval_dataset, shuffle=False)
113 | dataloader = DataLoader(
114 | eval_dataset,
115 | shuffle=False,
116 | collate_fn=lambda x: query_data_collactor(args, x, tokenizer),
117 | batch_size=args.per_device_eval_batch_size,
118 | sampler=sampler,
119 | )
120 |
121 | if args.task_type == "testing":
122 | generation_config = GenerationConfig(
123 | temperature=0.3,
124 | do_sample=True,
125 | max_new_tokens=512,
126 | top_k=5,
127 | top_p=0.85,
128 | bos_token_id=tokenizer.bos_token_id,
129 | eos_token_id=tokenizer.eos_token_id,
130 | pad_token_id=0,
131 | repetition_penalty=1.05,
132 | num_return_sequences=1,
133 | )
134 | elif args.task_type == "sampling":
135 | if args.model_prefix == "llama-2-chat":
136 | temperature = 0.6
137 | top_p=0.9
138 | else:
139 | temperature = 1.2
140 | top_p=1.
141 |
142 | generation_config = GenerationConfig(
143 | temperature=temperature, # default=0.8
144 | do_sample=True,
145 | min_length=1,
146 | max_new_tokens=256,
147 | top_p=top_p,
148 | bos_token_id=tokenizer.bos_token_id,
149 | eos_token_id=tokenizer.eos_token_id,
150 | pad_token_id=0,
151 | num_return_sequences=4,
152 | )
153 |     else:
154 |         raise ValueError(f"unsupported task_type for generation: {args.task_type}")
155 | model.to(device)
156 | model.eval()
157 |
158 | all_outputs = []
159 | progress_bar = tqdm(range(len(dataloader)), disable=(dist.get_rank() != 0))
160 | for step, batch in enumerate(dataloader):
161 | progress_bar.update(1)
162 | input_ids = torch.Tensor(batch['input_ids']).long().to(model.device)
163 | attention_mask = torch.Tensor(batch['attention_mask']).float().to(model.device)
164 | query_ids = batch['query_ids']
165 | text = batch['text']
166 |
167 | batch_size = input_ids.shape[0]
168 |
169 | with torch.no_grad():
170 | generation_output = model.generate(
171 | input_ids=input_ids,
172 | attention_mask=attention_mask,
173 | generation_config=generation_config,
174 | return_dict_in_generate=True,
175 | )
176 | output_seq = generation_output.sequences.reshape(batch_size, generation_config.num_return_sequences, -1)
177 |
178 | inputs_string = tokenizer.batch_decode(input_ids.reshape(batch_size, -1), skip_special_tokens=True)
179 |
180 | for idx in range(len(inputs_string)):
181 | new_item = {"query_id": query_ids[idx], "query": text[idx]}
182 | output_responses = tokenizer.batch_decode(output_seq[idx], skip_special_tokens=True)
183 | for res_idx, output_res in enumerate(output_responses):
184 | response_sample = output_res.replace(inputs_string[idx], '')
185 | if args.model_prefix == "llama-2-chat":
186 | #sent = sent.replace("\nAssistant", f" {E_INST} ").replace("\nHuman", f" {tokenizer.eos_token} {tokenizer.bos_token} {B_INST} ")
187 | response_sample = response_sample.replace(E_INST, "\nAssistant").replace(B_INST, "\nHuman")
188 | #response_sample = response_sample.replace(E_INST, "\n\nAssistant:").replace(B_INST, "\n\nHuman:")
189 |
190 | new_item[f"sample_{res_idx}"] = response_sample
191 |
192 | all_outputs.append(new_item)
193 |
194 | if dist.get_rank() == 0 and (step % 10 == 0):
195 | print_rank_0(f"finished {step} of {len(dataloader)}")
196 | print_rank_0(all_outputs[-1])
197 |
198 |
199 | output_file_prefix = f"{args.output_dir}/{args.model_prefix}_{args.task_type}_{args.data_suffix}"
200 | with open(f"{output_file_prefix}_rank{dist.get_rank()}.json", 'w') as f:
201 | json.dump(all_outputs, f, ensure_ascii=False, indent=2)
202 |     print(f"rank {dist.get_rank()} finished inference.")
203 |
204 | del model
205 | torch.cuda.empty_cache()
206 | dist.barrier()
207 | if dist.get_rank() == 0:
208 | result_paths = glob.glob(f"{output_file_prefix}_rank*.json")
209 | all_results = []
210 | for res_path in result_paths:
211 | new_results = read_json_or_jsonl_data(res_path)
212 | all_results.extend(new_results)
213 |
214 |         print(f"loaded {len(all_results)} results in total")
215 | with open(f"{output_file_prefix}_results.json", 'w') as f:
216 | json.dump(all_results, f, ensure_ascii=False, indent=2)
217 |         print("finished merging inference results.")
218 |
219 | if __name__ == "__main__":
220 | main()
221 |
--------------------------------------------------------------------------------
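The llama-2-chat branch of `query_data_collactor` rewrites the `\n\nHuman:` / `\n\nAssistant:` dialogue markup into `[INST]` turns. A standalone sketch, with literal `</s>` / `<s>` standing in for `tokenizer.eos_token` / `tokenizer.bos_token`:

    B_INST, E_INST = "[INST]", "[/INST]"

    sent = "\n\nHuman: Hi there.\n\nAssistant: Hello!\n\nHuman: How are you?\n\nAssistant: "
    sent = sent.replace("\nAssistant", f" {E_INST} ").replace("\nHuman", f" </s> <s> {B_INST} ")
    # caution: str.strip("</s>") strips the *characters* <, /, s, > from both ends,
    # not the substring; here that removes the leading "</s>" left by the first turn
    sent = sent.strip().strip("</s>")
    print(sent)  # each turn now reads "<s> [INST] : ... [/INST] : ..."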
/tools/llm_response_gen.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 |
3 | REPO_DIR=  # set to the local path of the APO repo
4 | export PYTHONPATH=${REPO_DIR}
5 |
6 |
7 | MODEL_DIR="chavinlo/alpaca-native"
8 | MODEL_NAME="alpaca"
9 |
10 | #TASK_TYPE="testing"
11 | TASK_TYPE="sampling"
12 |
13 | DATA_DIR=${REPO_DIR}/data/hh-split
14 | if [[ "${TASK_TYPE}" == "testing" ]]; then
15 | DATA_PATH=${DATA_DIR}/eval_data/hh_cleaned_origin.test.json
16 | DATA_NAME="hh_test"
17 | DATA_TYPE="comparison_pair"
18 | else
19 | DATA_DIR=${REPO_DIR}/data/hh-split
20 | DATA_PATH=${DATA_DIR}/llm_data/hh_split_llm.train.json
21 | DATA_NAME="hh_llm_train"
22 | DATA_TYPE="comparison_pair"
23 | fi
24 |
25 | OUTPUT_DIR=${DATA_DIR}/sample_data
26 | mkdir -p $OUTPUT_DIR
27 |
28 |
29 | EVAL_MICRO_BATCH_SIZE=1
30 | MAX_INPUT_LENGTH=512
31 |
32 | torchrun --nproc_per_node 8 --master_port 6000 ${REPO_DIR}/tools/inference_llm.py \
33 | --model_name_or_path $MODEL_DIR \
34 | --model_prefix ${MODEL_NAME} \
35 | --data_path $DATA_PATH \
36 | --output_dir $OUTPUT_DIR \
37 | --per_device_eval_batch_size $EVAL_MICRO_BATCH_SIZE \
38 | --task_type ${TASK_TYPE} \
39 | --data_suffix ${DATA_NAME} \
40 | --max_length ${MAX_INPUT_LENGTH} \
41 | --data_type ${DATA_TYPE}
42 |
--------------------------------------------------------------------------------
/tools/rejection_sampling.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import argparse
5 | import glob
6 | from copy import deepcopy
7 | from pprint import pprint
8 |
9 | def get_best_key(item, item_scores, filter_pattern=False):  # filter_pattern is accepted but not used in the body yet
10 |     max_score = float("-inf")
11 | result = None
12 | for key, value in item_scores.items():
13 | if len(item_scores) > 1 and key == "hh_best":
14 | continue
15 | if value > max_score:
16 | if item[key].strip() == "":
17 | continue
18 | else:
19 | result = deepcopy(key)
20 | max_score = value
21 | return result
22 |
23 | def get_scores_from_list(score_list):
24 | score_dict = {}
25 | for item in score_list:
26 | for key, value in item.items():
27 | query_id, ans_id = key.split(':')
28 | if query_id in score_dict:
29 | if ans_id in score_dict[query_id] and value != score_dict[query_id][ans_id]:
30 | print(f">>>>> warning!")
31 | print(f">>>>> replacing {query_id}: {ans_id} value {score_dict[query_id][ans_id]} with {value}")
32 |
33 | score_dict[query_id][ans_id] = value
34 | else:
35 | score_dict[query_id] = {ans_id: value}
36 | return score_dict
37 |
38 | def get_scores(data_path, rm_scorer):  # same pivot as get_scores_from_list, reading the per-rank .jsonl prediction shards
39 | file_names = glob.glob(f"{data_path}_*pred_{rm_scorer}*rank*.jsonl")
40 | score_dict = {}
41 | for file_name in file_names:
42 | with open(file_name, 'r') as f:
43 | lines = f.readlines()
44 | scores = [json.loads(l.strip()) for l in lines]
45 | for item in scores:
46 | for key, value in item.items():
47 | query_id, ans_id = key.split(':')
48 | if query_id in score_dict:
49 | if ans_id in score_dict[query_id] and value != score_dict[query_id][ans_id]:
50 | print(f">>>>> warning!")
51 | print(f">>>>> replacing {query_id}: {ans_id} value {score_dict[query_id][ans_id]} with {value}")
52 |
53 | score_dict[query_id][ans_id] = value
54 | else:
55 | score_dict[query_id] = {ans_id: value}
56 | return score_dict
57 |
58 | def rejection_sample(data_path, score_path=None, rm_scorer=None):
59 | with open(data_path, 'r') as f:
60 | data_list = json.load(f)
61 |
62 |     print(f"loaded {len(data_list)} samples in total for rejection sampling")
63 |
64 | if score_path is not None:
65 | with open(score_path, 'r') as f:
66 | score_list = json.load(f)
67 | data_scores = get_scores_from_list(score_list)
68 | elif rm_scorer is not None:
69 | data_scores = get_scores(data_path, rm_scorer)
70 | else:
71 |         raise ValueError('cannot find score data')
72 |
73 | hh_best_counter = 0
74 | outputs = []
75 |     for item in data_list:  # slice with [:10] for quick debugging
76 | query_id = str(item['query_id'])
77 | item_scores = data_scores[query_id]
78 |
79 |
80 | #best_res_key = max(item_scores, key=item_scores.get)
81 | best_res_key = get_best_key(item, item_scores, filter_pattern=True)
82 | if best_res_key is None:
83 | best_res_key = get_best_key(item, item_scores, filter_pattern=False)
84 | if best_res_key is None:
85 | print(item)
86 | continue
87 |
88 | item['target'] = item[best_res_key]
89 | item['scores'] = item_scores
90 |
91 | if best_res_key == "hh_best":
92 | hh_best_counter += 1
93 | outputs.append(deepcopy(item))
94 |     print(f"selected hh_best for {hh_best_counter} items")
95 | return outputs
96 |
97 |
98 | if __name__ == "__main__":
99 |     parser = argparse.ArgumentParser(description='parser for preference data processing.')
100 |     parser.add_argument("--data_path", type=str, default="", help="the path to input data.")
101 |     parser.add_argument("--output_dir", type=str, default="", help="the path to the output directory.")
102 |     parser.add_argument("--output_file_name", type=str, default="", help="the file name for the output data.")
103 |     parser.add_argument("--score_path", type=str, default="", help="the path to the score file.")
104 |     parser.add_argument("--rm_scorer", type=str, default="", help="the rm model name used to get scores")
105 |
106 | parser.add_argument("--domain", type=str, default="general", help="the domain of the preference data, selected from [general, normal, academy, business, entertainment, literature].")
107 |
108 | parser.add_argument("--convert", action='store_true', help="whether convert responses into the preference text-score format.")
109 | parser.add_argument("--to_pairs", action='store_true', help="whether convert responses into pair comparisons.")
110 |
111 | args = parser.parse_args()
112 | #outputs = rejection_sample(args.data_path, f"{args.data_path}_{args.rm_scorer}_prediction.json")
113 | outputs = rejection_sample(args.data_path, args.score_path, args.rm_scorer)
114 |
115 | if len(args.output_file_name) == 0:
116 |
117 | _, file_name = os.path.split(args.score_path)
118 | print(file_name)
119 | args.output_file_name = f"{args.output_dir}/{file_name}_sft.json"
120 |
121 | with open(f"{args.output_file_name}", 'w', encoding="utf-8") as f:
122 | json.dump(outputs, f, ensure_ascii=False, indent=2)
123 |
124 |
125 |
--------------------------------------------------------------------------------
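Score files hold `"query_id:answer_id"` keys; `get_scores_from_list` pivots them into a nested dict, and `get_best_key` then picks the highest-scored non-empty response. A toy walkthrough (ids and scores invented):

    from tools.rejection_sampling import get_scores_from_list, get_best_key

    score_list = [{"42:sample_0": 0.1}, {"42:sample_1": 0.7}, {"42:hh_best": 0.9}]
    score_dict = get_scores_from_list(score_list)
    print(score_dict)  # {'42': {'sample_0': 0.1, 'sample_1': 0.7, 'hh_best': 0.9}}

    item = {"query_id": "42", "sample_0": "short answer",
            "sample_1": "better answer", "hh_best": "reference answer"}
    print(get_best_key(item, score_dict["42"]))
    # -> 'sample_1': 'hh_best' is skipped whenever other candidates exist,
    #    so the LLM's own best sample wins even with a lower score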
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import copy
3 | import logging
4 | from dataclasses import dataclass, field
5 | from typing import Dict, Optional, Sequence, List
6 | import json
7 | import random
8 |
9 | import torch
10 | import torch.distributed as dist
11 | import transformers
12 |
13 | from torch.utils.data import Dataset
14 | from transformers import Trainer, AutoConfig
15 | from transformers import EvalPrediction
16 |
17 |
18 | from model import LlamaRewardModel
19 |
20 | from reward_datasets import TextRewardDataset, reward_data_collactor
21 | from reward_datasets import load_text_score_dataset
22 | from arguments import CustomTrainingArguments
23 | from trainer import RewardModelTrainer, compute_metrics
24 |
25 | from utils import print_rank_0, set_reward_tokenizer, merge_json_or_jsonl_data
26 | from utils import DEFAULT_PAD_TOKEN, DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, DEFAULT_UNK_TOKEN
27 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP, INFER_TMP_FILE
28 |
29 |
30 |
31 | def get_eval_datasets(args):
32 | data_dict = {}
33 |
34 | for data_path in args.eval_data_path:
35 | eval_data_list = load_text_score_dataset(args=args, data_path=data_path)
36 |
37 | eval_dataset = TextRewardDataset(eval_data_list)
38 |
39 | data_name = os.path.split(data_path)[-1]
40 | data_dict[data_name] = eval_dataset
41 | print_rank_0(">> finished loading {} data with data size = {}".format(data_name, len(eval_dataset)))
42 |
43 | if args.debug_mode:
44 | print_rank_0(f">>> check loaded data:")
45 | print_rank_0(f">>> {eval_dataset[0]}")
46 |
47 | return data_dict
48 |
49 | def get_train_dataset(args):
50 | all_train_data = []
51 | for train_data_path in args.train_data_path:
52 | train_data = load_text_score_dataset(args=args, data_path=train_data_path)
53 | all_train_data.extend(train_data)
54 |
55 | if args.debug_mode:
56 | print_rank_0(f">>> check loaded data:")
57 | print_rank_0(f">>> {all_train_data[0]}")
58 |
59 | train_set = TextRewardDataset(all_train_data)
60 | return train_set
61 |
62 |
63 | def train():
64 | parser = transformers.HfArgumentParser(CustomTrainingArguments)
65 | args = parser.parse_args_into_dataclasses()[0]
66 | print_rank_0(args)
67 |
68 | # load data
69 | #---------------------------------------------------------------------------------
70 | if args.do_train:
71 | train_dataset = get_train_dataset(args)
72 | else:
73 | train_dataset = None
74 |
75 | eval_dataset_dict = get_eval_datasets(args)
76 |
77 | # setup model
78 | #---------------------------------------------------------------------------------
79 | print_rank_0(f"Begin loading model from {args.model_name_or_path}")
80 | if args.model_type == "reward":
81 | model = LlamaRewardModel.from_pretrained(args.model_name_or_path)
82 | elif args.model_type == "sft":
83 |         model = transformers.LlamaForCausalLM.from_pretrained(args.model_name_or_path)
84 |
85 | print_rank_0(model)
86 | print_rank_0(f"Finished loading model from {args.model_name_or_path}")
87 |
88 | model.is_parallelizable = True
89 | model.model_parallel = True
90 |
91 | # setup tokenizer
92 | #---------------------------------------------------------------------------------
93 | tokenizer = transformers.AutoTokenizer.from_pretrained(
94 | args.model_name_or_path,
95 | model_max_length=args.max_length,
96 | padding_side=args.padding_side,
97 | truncation_side=args.truncation_side,
98 | use_fast=False,
99 | )
100 |
101 | if args.model_type == "reward":
102 | model, tokenizer = set_reward_tokenizer(model=model, tokenizer=tokenizer)
103 |
104 | # build trainer
105 | #---------------------------------------------------------------------------------
106 |
107 | trainer = RewardModelTrainer(
108 | model=model,
109 | tokenizer=tokenizer,
110 | args=args,
111 | compute_metrics=lambda x: compute_metrics(args, x),
112 | train_dataset=train_dataset,
113 | eval_dataset=eval_dataset_dict,
114 | data_collator=lambda x: reward_data_collactor(args, x, tokenizer)
115 | )
116 |
117 | if args.do_train:
118 | if args.eval_at_start:
119 | for eval_set_name, eval_dataset in eval_dataset_dict.items():
120 | eval_result = trainer.evaluate(eval_dataset=eval_dataset, metric_key_prefix="eval_"+eval_set_name)
121 | print_rank_0(eval_result)
122 |
123 | if args.resume_from_checkpoint:
124 | train_result = trainer.train(resume_from_checkpoint=args.resume_from_checkpoint)
125 | else:
126 | train_result = trainer.train()
127 |
128 | metrics = train_result.metrics
129 | trainer.log_metrics("train", metrics)
130 | trainer.save_metrics("train", metrics)
131 |
132 | trainer.save_state()
133 | trainer.save_model(output_dir=args.output_dir)
134 |
135 |
136 |     final_eval_results = {}
137 | for eval_set_name, eval_dataset in eval_dataset_dict.items():
138 | args.current_eval_filename = os.path.split(eval_set_name)[-1]
139 | eval_result = trainer.evaluate(eval_dataset=eval_dataset, metric_key_prefix="eval_"+eval_set_name)
140 |
141 | print_rank_0(eval_result)
142 | final_eval_results[eval_set_name] = eval_result
143 |
144 | if args.task_type == "inference":
145 | torch.distributed.barrier()
146 | if dist.get_rank() == 0:
147 | print_rank_0(eval_set_name)
148 | data_path = eval_dataset[0]['query_ids'][0].split(STRING_SEP)[0]
149 |
150 | result_temp = INFER_TMP_FILE.format(data_path=data_path,
151 | data_suffix=args.data_suffix,
152 | rank="*")
153 | print_rank_0(f"begin merge temp file from {result_temp}")
154 | outputs = merge_json_or_jsonl_data(result_temp)
155 | with open(f"{data_path}_pred_{args.data_suffix}_results.json", 'w') as f:
156 | json.dump(outputs, f, ensure_ascii=False, indent=2)
157 |
158 |
159 |
160 | with open(f"{args.output_dir}/final_eval_results.json", 'w') as f:
161 | json.dump(final_eval_results, f, ensure_ascii=False)
162 |
163 |
164 |
165 | if __name__ == "__main__":
166 | train()
167 |
--------------------------------------------------------------------------------
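`get_train_dataset` ultimately feeds `reward_data_collactor`, so every `--train_data_path` file must hold text-score records. A sketch of one such record (contents invented; the `<sep>` separator assumes `SEP_TOKEN` from utils.py):

    import json

    record = {
        # candidate responses to one query; a higher score means more preferred
        "text": ["\n\nHuman: Hi\n\nAssistant: <sep>Hello!",
                 "\n\nHuman: Hi\n\nAssistant: <sep>Go away."],
        "scores": [1.0, 0.0],
        # an optional "type": "apo" field marks golden-vs-sample pairs for the APO loss
    }

    with open("toy_text_scores.json", "w") as f:
        json.dump([record], f, indent=2)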
/trainer.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import logging
3 | from dataclasses import dataclass, field
4 | from typing import Dict, Optional, Sequence, List
5 | import json
6 | import datetime
7 |
8 | import numpy as np
9 | import sklearn
10 |
11 | import torch
12 | import torch.distributed as dist
13 | import torch.nn.functional as F
14 | import transformers
15 |
16 |
17 | from transformers import Trainer, AutoConfig
18 | from transformers import EvalPrediction
19 |
20 | from utils import print_rank_0, calibration_error, numpy_sigmoid
21 | from utils import QUERY_PROMPT, SEP_TOKEN, STRING_SEP, INFER_TMP_FILE
22 |
23 |
24 |
25 | def rm_calibration_errors(args, labels, probs, masks, num_bins):
26 | label_list = labels.reshape(-1).tolist()
27 | prob_list = probs.reshape(-1).tolist()
28 | mask_list = masks.reshape(-1).tolist()
29 |
30 | y_true, y_prob = [], []
31 | for label, prob, mask in zip(label_list, prob_list, mask_list):
32 | if mask:
33 | y_true.append(label)
34 | y_prob.append(prob)
35 |
36 | if args.debug_mode:
37 | print_rank_0(f">>>>> check calibration inputs mask filtered...")
38 | print_rank_0(f">>>>>>>> y_true: {y_true[:10]}")
39 | print_rank_0(f">>>>>>>> y_prob: {y_prob[:10]}")
40 |
41 | return calibration_error(np.array(y_true), np.array(y_prob), n_bins=num_bins)
42 |
43 |
44 | def compute_metrics(args, prediction: EvalPrediction):
45 | logits = torch.from_numpy(prediction.predictions)
46 | scores = torch.from_numpy(prediction.label_ids)
47 |
48 | if args.debug_mode:
49 | print_rank_0(f">> check eval_prediction inputs...")
50 | print_rank_0(f">>> logits: {logits[:5]}")
51 | print_rank_0(f">>> scores: {scores[:5]}")
52 |
53 | logits_diff = logits.unsqueeze(1) - logits.unsqueeze(2) # [batch_size, num_sample, num_sample]
54 |
55 | score_mask_larger = (scores.unsqueeze(1) > scores.unsqueeze(2)) * 1.
56 | score_mask_smaller = (scores.unsqueeze(1) < scores.unsqueeze(2)) * 1.
57 | score_mask = score_mask_larger - score_mask_smaller
58 | pad_mask = (scores >= 0).unsqueeze(1) * 1. * (scores >= 0).unsqueeze(2)
59 |
60 |
61 | # calculate accuracy...
62 | pred_compare = (logits_diff.detach() * score_mask > 0.) * 1.
63 | total_mask = (score_mask_larger + score_mask_smaller) * pad_mask
64 | #correct_compare = (pred_compare == score_mask_larger) * total_mask
65 | correct_compare = pred_compare * total_mask
66 |
67 | all_acc = correct_compare.sum() / total_mask.sum() if total_mask.sum() > 0 else total_mask.sum()
68 | average_score = logits.mean().item()
69 |
70 | calibration_errors = {}
71 | if args.rm_calibration:
72 | for num_bins in args.calibration_bins:
73 | expected_error, average_error, max_error = rm_calibration_errors(
74 | args=args,
75 | labels=score_mask_larger,
76 | #probs=torch.sigmoid(logits_diff),
77 | probs=numpy_sigmoid(logits_diff.numpy()),
78 | masks=total_mask,
79 | num_bins=num_bins
80 | )
81 | # if args.save_calibration and args.task_type == "eval":
82 | # time = datetime.datetime.now()
83 | # time_stamp = time.strftime("%d-%H:%M:%S")
84 | # if dist.get_rank() == 0:
85 | # outputs = {"prob_true": prob_true.tolist(), "prob_pred": prob_pred.tolist()}
86 | # with open(f"{args.output_dir}/calibration_result_t{args.current_eval_filename}_bin{num_bins}.json", 'w') as f:
87 | # json.dump(outputs, f, ensure_ascii=False, indent=2)
88 |
89 | calibration_errors[f"calibration_ECE_bin{num_bins}"] = expected_error
90 | calibration_errors[f"calibration_ACE_bin{num_bins}"] = average_error
91 | calibration_errors[f"calibration_MCE_bin{num_bins}"] = max_error
92 |
93 | if args.debug_mode:
94 | print_rank_0(f">> check eval_prediction outputs...")
95 | print_rank_0(f">>> correct_compare: {correct_compare}")
96 | print_rank_0(f">>> total_mask: {total_mask}")
97 | print_rank_0(f">>> all_acc: {all_acc}")
98 | print_rank_0(f">>> calibration error: {calibration_errors}")
99 |
100 | return {"Preference Acc": all_acc.item(), "Avg Score": average_score, **calibration_errors}
101 |
102 |
103 | def reward_model_loss(logits, scores, coeffs=None, loss_type="ranking"): # `logits`, `scores` with shape [bs, r], `coeffs` with shape [bs]
104 | logits_diff = logits.unsqueeze(1) - logits.unsqueeze(2) # shape [bs, r, r]
105 |
106 | score_mask_larger = (scores.unsqueeze(1) > scores.unsqueeze(2)) * 1.
107 | score_mask_smaller = (scores.unsqueeze(1) < scores.unsqueeze(2)) * 1.
108 | score_mask = score_mask_larger - score_mask_smaller
109 | pad_mask = (scores >= 0).unsqueeze(1) * 1. * (scores >= 0).unsqueeze(2)
110 |
111 | total_mask = (score_mask_larger + score_mask_smaller) * pad_mask
112 |
113 | if loss_type == "diff":
114 | log_prob = logits_diff * score_mask * pad_mask # shape [bs, r, r]
115 | else:
116 | log_prob = torch.nn.functional.logsigmoid(logits_diff * score_mask * pad_mask) # shape [bs, r, r]
117 |
118 | if coeffs is not None:
119 | log_prob = log_prob * coeffs.unsqueeze(-1).unsqueeze(-1)
120 |
121 | total_loss = - (log_prob * total_mask).sum()
122 | total_pairs = total_mask.sum()
123 |
124 | return total_loss / total_pairs if total_pairs > 0 else total_loss
125 | #return - log_prob.mean()
126 |
127 |
128 | class RewardModelTrainer(Trainer):
129 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[List[str]] = None):
130 | device = model.device
131 | labels = torch.Tensor(inputs['scores']).float().to(device)
132 |
133 | with torch.no_grad():
134 | loss, logits = self.compute_loss(model, inputs, return_outputs=True)
135 | loss = loss.mean().detach()
136 | # logits = outputs.logits
137 |
138 | if prediction_loss_only:
139 | return (loss, None, None)
140 |
141 | return (loss, logits, labels)
142 |
143 |
144 | def compute_loss(self, model, inputs, return_outputs=False):
145 | device = model.device
146 | scores = torch.Tensor(inputs['scores']).float().to(device) # shape [batch_size, response_num]
147 | input_ids = torch.Tensor(inputs['input_ids']).long().to(device) # shape [batch_size, response_num, seq_length]
148 | attention_mask = torch.Tensor(inputs['attention_mask']).float().to(device)
149 | # coeffs = torch.Tensor(inputs['coeffs']).float().to(device)
150 | apo_data_mask = torch.Tensor(inputs['apo_data_mask']).float().to(device) # shape [batch_size] value 1 if apo data
151 |
152 | batch_size, response_num, seq_length = input_ids.shape
153 |
154 | if self.args.debug_mode:
155 | print(f">>> input_ids shape {input_ids.shape}")
156 |
157 | outputs = model(
158 | input_ids=input_ids.view(-1, seq_length),
159 | attention_mask=attention_mask.view(-1, seq_length),
160 | padding_side=self.args.padding_side,
161 | pooling_type=self.args.pooling_type
162 | )
163 |
164 | batch_logits = outputs['rm_logits'].view(batch_size, response_num) # shape [bs, r]
165 |
166 | if self.args.task_type == "apo":
167 | rm_kl_loss = reward_model_loss(batch_logits, scores, coeffs=(1. - apo_data_mask), loss_type="ranking")
168 | apo_loss = reward_model_loss(batch_logits, scores, coeffs=apo_data_mask, loss_type=self.args.apo_loss_type)
169 | total_loss = self.args.rm_kl_coeff * rm_kl_loss + self.args.apo_loss_coeff / self.args.apo_sample_num * apo_loss
170 | else:
171 | total_loss = reward_model_loss(batch_logits, scores, coeffs=None, loss_type="ranking")
172 |
173 | if self.args.debug_mode:
174 | print_rank_0(f">>> debug")
175 | print_rank_0(f">>> input_ids shape {input_ids.shape}")
176 | print_rank_0(f">>> Batch rm logits {batch_logits}")
177 |
178 | if self.args.task_type == "inference":
179 | query_ids = inputs['query_ids']
180 | new_results = []
181 |
182 | for i_bs in range(batch_size):
183 | for j_sample in range(response_num):
184 | data_path, query_id, ans_id = query_ids[i_bs][j_sample].split(STRING_SEP)
185 | new_results.append(
186 | json.dumps({f"{query_id}:{ans_id}": batch_logits[i_bs][j_sample].item()}, ensure_ascii=False)
187 | )
188 |
189 | output_file_path = INFER_TMP_FILE.format(data_path=data_path,
190 | data_suffix=self.args.data_suffix,
191 | rank=dist.get_rank())
192 | with open(output_file_path, 'a') as f:
193 | f.write("\n".join(new_results)+"\n")
194 |
195 | return (total_loss, batch_logits) if return_outputs else total_loss
196 |
--------------------------------------------------------------------------------
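For a single chosen/rejected pair, `reward_model_loss` with `loss_type="ranking"` reduces to the standard pairwise loss -log sigmoid(r_chosen - r_rejected). A toy check (values invented):

    import torch
    from trainer import reward_model_loss

    logits = torch.tensor([[2.0, 0.5]])  # rewards for [chosen, rejected]
    scores = torch.tensor([[1.0, 0.0]])  # preference labels; negative scores mark padding

    loss = reward_model_loss(logits, scores, loss_type="ranking")
    expected = -torch.nn.functional.logsigmoid(torch.tensor(2.0 - 0.5))
    print(loss.item(), expected.item())  # both ~0.2014

The (i, j) and (j, i) entries of the comparison matrix contribute the same term, so the sum over ordered pairs divided by `total_pairs` equals the per-pair loss.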
/utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import logging
3 | import math
4 | import os
5 | import io
6 | import sys
7 | import time
8 | import json
9 | import glob
10 | from typing import Optional, Sequence, Union, List, Dict
11 |
12 | import tqdm
13 | import copy
14 |
15 |
16 |
17 | import numpy as np
18 | import torch
19 |
20 | SEP_TOKEN="<sep>"
21 | STRING_SEP="<:>"
22 |
23 | DEFAULT_PAD_TOKEN = "[PAD]"
24 | DEFAULT_EOS_TOKEN = "</s>"
25 | DEFAULT_BOS_TOKEN = "<s>"
26 | DEFAULT_UNK_TOKEN = "<unk>"
27 |
28 | QUERY_PROMPT="## Human:\n{request}\n\n## Assistant:\n{response}"
29 |
30 | INFER_TMP_FILE="{data_path}_pred_{data_suffix}_results_rank_{rank}.jsonl"
31 |
32 | def numpy_sigmoid(x):
33 | # r_x = x - x.max()
34 | return 1. / (1. + np.exp(-x))
35 |
36 |
37 | def read_json_or_jsonl_data(data_path):
38 | if data_path[-5:] == ".json":
39 | with open(data_path, 'r') as f:
40 | data_list = json.load(f)
41 | else:
42 | with open(data_path, 'r') as f:
43 | lines = f.read().strip().split('\n')
44 | data_list = [json.loads(l) for l in lines]
45 |
46 |     print_rank_0(f">>> loaded {len(data_list)} examples in total from {data_path}")
47 | return data_list
48 |
49 | def merge_json_or_jsonl_data(data_path_pattern):
50 | file_names = glob.glob(data_path_pattern)
51 | print_rank_0(f"load {len(file_names)} files from {data_path_pattern}.")
52 | outputs = []
53 | for file_name in file_names:
54 | new_data = read_json_or_jsonl_data(file_name)
55 | if isinstance(new_data, list):
56 | outputs.extend(new_data)
57 | elif isinstance(new_data, dict):
58 | outputs.append(new_data)
59 | return outputs
60 |
61 |
62 | def print_rank_0(message):
63 | if torch.distributed.is_initialized():
64 | if torch.distributed.get_rank() == 0:
65 | print(message, flush=True)
66 | else:
67 | print(message, flush=True)
68 |
69 |
70 | def set_reward_tokenizer(model, tokenizer):
71 |
72 |     tokenizer.pad_token_id = 3  # llama ids 0-2 are unk/bos/eos; ids 3 and 4 are repurposed here as pad and sep
73 | tokenizer.bos_token_id = 1
74 | tokenizer.eos_token_id = 2
75 | tokenizer.unk_token_id = 0
76 | tokenizer.sep_token_id = 4
77 |
78 | model.config.pad_token_id = tokenizer.pad_token_id
79 | model.config.bos_token_id = tokenizer.bos_token_id
80 | model.config.eos_token_id = tokenizer.eos_token_id
81 |
82 | print_rank_0(tokenizer)
83 | return model, tokenizer
84 |
85 |
86 |
87 |
88 | def calibration_error(
89 | y_true,
90 | y_prob,
91 | n_bins=5,
92 | strategy="uniform",
93 | ):
94 | if len(y_true) == 0:
95 | return 0., 0., 0.
96 |
97 | if strategy == "quantile": # Determine bin edges by distribution of data
98 | quantiles = np.linspace(0, 1, n_bins + 1)
99 | bins = np.percentile(y_prob, quantiles * 100)
100 | elif strategy == "uniform":
101 | bins = np.linspace(0.0, 1.0, n_bins + 1)
102 | else:
103 | raise ValueError(
104 | "Invalid entry to 'strategy' input. Strategy "
105 | "must be either 'quantile' or 'uniform'."
106 | )
107 |
108 | binids = np.searchsorted(bins[1:-1], y_prob)
109 |
110 | bin_sums = np.bincount(binids, weights=y_prob, minlength=len(bins))
111 | bin_true = np.bincount(binids, weights=y_true, minlength=len(bins))
112 | bin_total = np.bincount(binids, minlength=len(bins))
113 |
114 | nonzero = bin_total != 0
115 | # prob_true = bin_true[nonzero] / bin_total[nonzero]
116 | # prob_pred = bin_sums[nonzero] / bin_total[nonzero]
117 |
118 | # return prob_true, prob_pred, bin_total[nonzero]
119 | try:
120 | expected_error = np.abs(bin_sums - bin_true).sum() / len(y_prob)
121 | average_error = (np.abs(bin_sums[nonzero] - bin_true[nonzero]) / bin_total[nonzero]).mean()
122 | max_error = (np.abs(bin_sums[nonzero] - bin_true[nonzero]) / bin_total[nonzero]).max()
123 | except Exception as e:
124 | print_rank_0(">>>> WARNING: Encounter error in calibration calculation")
125 | print_rank_0(e)
126 | expected_error, average_error, max_error = 0., 0., 0.
127 |
128 | return expected_error, average_error, max_error
129 |
--------------------------------------------------------------------------------
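A toy run of `calibration_error` (numbers invented): four predictions fall into two uniform bins, and the three returned values are the expected, average, and maximum per-bin |confidence - accuracy| gaps:

    import numpy as np
    from utils import calibration_error

    y_true = np.array([0., 0., 1., 1.])
    y_prob = np.array([0.2, 0.4, 0.6, 0.9])

    ece, ace, mce = calibration_error(y_true, y_prob, n_bins=2)
    print(ece, ace, mce)
    # bin [0, .5): mean conf 0.3 vs accuracy 0; bin [.5, 1]: mean conf 0.75 vs accuracy 1
    # -> ECE = ACE = 0.275, MCE = 0.3 (ECE and ACE coincide here because both bins hold 2 points)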