├── LICENSE.txt
├── README.md
├── asset
│   ├── .DS_Store
│   ├── XMAiNframe.png
│   ├── cobol_diagram-Page-1.drawio.png
│   ├── cobol_diagram-Page-8.drawio.pdf
│   ├── cobol_diagram-Page-8.drawio.png
│   ├── sample_1.png
│   ├── sample_2.png
│   └── sample_3.png
├── recipes
│   ├── accelerate_configs
│   │   ├── deepspeed_zero3.yaml
│   │   ├── deepspeed_zero3_lora.yaml
│   │   ├── fsdp.yaml
│   │   ├── fsdp_qlora.yaml
│   │   └── multi_gpu.yaml
│   └── deepseek
│       ├── full.yaml
│       ├── full_instruct.yaml
│       ├── lora_instruct.yaml
│       └── lora_sft.yaml
├── requirements.txt
├── scripts
│   ├── ft.sh
│   └── instruct.sh
├── src
│   ├── alignment
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── configs.cpython-310.pyc
│   │   │   ├── configs.cpython-38.pyc
│   │   │   ├── data.cpython-310.pyc
│   │   │   ├── data.cpython-38.pyc
│   │   │   ├── decontaminate.cpython-310.pyc
│   │   │   ├── decontaminate.cpython-38.pyc
│   │   │   ├── model_utils.cpython-310.pyc
│   │   │   └── model_utils.cpython-38.pyc
│   │   ├── configs.py
│   │   ├── data.py
│   │   ├── decontaminate.py
│   │   ├── model_utils.py
│   │   └── release.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-310.pyc
│   │   │   ├── data.cpython-310.pyc
│   │   │   └── utils.cpython-310.pyc
│   │   ├── data.py
│   │   └── utils.py
│   └── model
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-310.pyc
│       │   └── tokenizer.cpython-310.pyc
│       └── tokenizer.py
├── train_instruct.py
├── train_raw.py
└── utils.py
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | Apache License Version 2.0
2 |
3 | Copyright (c) 2024 FPT Software, Inc.
4 | All rights reserved.
5 |
6 | Apache License
7 | Version 2.0, January 2004
8 | http://www.apache.org/licenses/
9 |
10 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
11 |
12 | 1. Definitions.
13 |
14 | "License" shall mean the terms and conditions for use, reproduction,
15 | and distribution as defined by Sections 1 through 9 of this document.
16 |
17 | "Licensor" shall mean the copyright owner or entity authorized by
18 | the copyright owner that is granting the License.
19 |
20 | "Legal Entity" shall mean the union of the acting entity and all
21 | other entities that control, are controlled by, or are under common
22 | control with that entity. For the purposes of this definition,
23 | "control" means (i) the power, direct or indirect, to cause the
24 | direction or management of such entity, whether by contract or
25 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
26 | outstanding shares, or (iii) beneficial ownership of such entity.
27 |
28 | "You" (or "Your") shall mean an individual or Legal Entity
29 | exercising permissions granted by this License.
30 |
31 | "Source" form shall mean the preferred form for making modifications,
32 | including but not limited to software source code, documentation
33 | source, and configuration files.
34 |
35 | "Object" form shall mean any form resulting from mechanical
36 | transformation or translation of a Source form, including but
37 | not limited to compiled object code, generated documentation,
38 | and conversions to other media types.
39 |
40 | "Work" shall mean the work of authorship, whether in Source or
41 | Object form, made available under the License, as indicated by a
42 | copyright notice that is included in or attached to the work
43 | (an example is provided in the Appendix below).
44 |
45 | "Derivative Works" shall mean any work, whether in Source or Object
46 | form, that is based on (or derived from) the Work and for which the
47 | editorial revisions, annotations, elaborations, or other modifications
48 | represent, as a whole, an original work of authorship. For the purposes
49 | of this License, Derivative Works shall not include works that remain
50 | separable from, or merely link (or bind by name) to the interfaces of,
51 | the Work and Derivative Works thereof.
52 |
53 | "Contribution" shall mean any work of authorship, including
54 | the original version of the Work and any modifications or additions
55 | to that Work or Derivative Works thereof, that is intentionally
56 | submitted to Licensor for inclusion in the Work by the copyright owner
57 | or by an individual or Legal Entity authorized to submit on behalf of
58 | the copyright owner. For the purposes of this definition, "submitted"
59 | means any form of electronic, verbal, or written communication sent
60 | to the Licensor or its representatives, including but not limited to
61 | communication on electronic mailing lists, source code control systems,
62 | and issue tracking systems that are managed by, or on behalf of, the
63 | Licensor for the purpose of discussing and improving the Work, but
64 | excluding communication that is conspicuously marked or otherwise
65 | designated in writing by the copyright owner as "Not a Contribution."
66 |
67 | "Contributor" shall mean Licensor and any individual or Legal Entity
68 | on behalf of whom a Contribution has been received by Licensor and
69 | subsequently incorporated within the Work.
70 |
71 | 2. Grant of Copyright License. Subject to the terms and conditions of
72 | this License, each Contributor hereby grants to You a perpetual,
73 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
74 | copyright license to reproduce, prepare Derivative Works of,
75 | publicly display, publicly perform, sublicense, and distribute the
76 | Work and such Derivative Works in Source or Object form.
77 |
78 | 3. Grant of Patent License. Subject to the terms and conditions of
79 | this License, each Contributor hereby grants to You a perpetual,
80 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
81 | (except as stated in this section) patent license to make, have made,
82 | use, offer to sell, sell, import, and otherwise transfer the Work,
83 | where such license applies only to those patent claims licensable
84 | by such Contributor that are necessarily infringed by their
85 | Contribution(s) alone or by combination of their Contribution(s)
86 | with the Work to which such Contribution(s) was submitted. If You
87 | institute patent litigation against any entity (including a
88 | cross-claim or counterclaim in a lawsuit) alleging that the Work
89 | or a Contribution incorporated within the Work constitutes direct
90 | or contributory patent infringement, then any patent licenses
91 | granted to You under this License for that Work shall terminate
92 | as of the date such litigation is filed.
93 |
94 | 4. Redistribution. You may reproduce and distribute copies of the
95 | Work or Derivative Works thereof in any medium, with or without
96 | modifications, and in Source or Object form, provided that You
97 | meet the following conditions:
98 |
99 | (a) You must give any other recipients of the Work or
100 | Derivative Works a copy of this License; and
101 |
102 | (b) You must cause any modified files to carry prominent notices
103 | stating that You changed the files; and
104 |
105 | (c) You must retain, in the Source form of any Derivative Works
106 | that You distribute, all copyright, patent, trademark, and
107 | attribution notices from the Source form of the Work,
108 | excluding those notices that do not pertain to any part of
109 | the Derivative Works; and
110 |
111 | (d) If the Work includes a "NOTICE" text file as part of its
112 | distribution, then any Derivative Works that You distribute must
113 | include a readable copy of the attribution notices contained
114 | within such NOTICE file, excluding those notices that do not
115 | pertain to any part of the Derivative Works, in at least one
116 | of the following places: within a NOTICE text file distributed
117 | as part of the Derivative Works; within the Source form or
118 | documentation, if provided along with the Derivative Works; or,
119 | within a display generated by the Derivative Works, if and
120 | wherever such third-party notices normally appear. The contents
121 | of the NOTICE file are for informational purposes only and
122 | do not modify the License. You may add Your own attribution
123 | notices within Derivative Works that You distribute, alongside
124 | or as an addendum to the NOTICE text from the Work, provided
125 | that such additional attribution notices cannot be construed
126 | as modifying the License.
127 |
128 | You may add Your own copyright statement to Your modifications and
129 | may provide additional or different license terms and conditions
130 | for use, reproduction, or distribution of Your modifications, or
131 | for any such Derivative Works as a whole, provided Your use,
132 | reproduction, and distribution of the Work otherwise complies with
133 | the conditions stated in this License.
134 |
135 | 5. Submission of Contributions. Unless You explicitly state otherwise,
136 | any Contribution intentionally submitted for inclusion in the Work
137 | by You to the Licensor shall be under the terms and conditions of
138 | this License, without any additional terms or conditions.
139 | Notwithstanding the above, nothing herein shall supersede or modify
140 | the terms of any separate license agreement you may have executed
141 | with Licensor regarding such Contributions.
142 |
143 | 6. Trademarks. This License does not grant permission to use the trade
144 | names, trademarks, service marks, or product names of the Licensor,
145 | except as required for reasonable and customary use in describing the
146 | origin of the Work and reproducing the content of the NOTICE file.
147 |
148 | 7. Disclaimer of Warranty. Unless required by applicable law or
149 | agreed to in writing, Licensor provides the Work (and each
150 | Contributor provides its Contributions) on an "AS IS" BASIS,
151 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
152 | implied, including, without limitation, any warranties or conditions
153 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
154 | PARTICULAR PURPOSE. You are solely responsible for determining the
155 | appropriateness of using or redistributing the Work and assume any
156 | risks associated with Your exercise of permissions under this License.
157 |
158 | 8. Limitation of Liability. In no event and under no legal theory,
159 | whether in tort (including negligence), contract, or otherwise,
160 | unless required by applicable law (such as deliberate and grossly
161 | negligent acts) or agreed to in writing, shall any Contributor be
162 | liable to You for damages, including any direct, indirect, special,
163 | incidental, or consequential damages of any character arising as a
164 | result of this License or out of the use or inability to use the
165 | Work (including but not limited to damages for loss of goodwill,
166 | work stoppage, computer failure or malfunction, or any and all
167 | other commercial damages or losses), even if such Contributor
168 | has been advised of the possibility of such damages.
169 |
170 | 9. Accepting Warranty or Additional Liability. While redistributing
171 | the Work or Derivative Works thereof, You may choose to offer,
172 | and charge a fee for, acceptance of support, warranty, indemnity,
173 | or other liability obligations and/or rights consistent with this
174 | License. However, in accepting such obligations, You may act only
175 | on Your own behalf and on Your sole responsibility, not on behalf
176 | of any other Contributor, and only if You agree to indemnify,
177 | defend, and hold each Contributor harmless for any liability
178 | incurred by, or claims asserted against, such Contributor by reason
179 | of your accepting any such warranty or additional liability.
180 |
181 | END OF TERMS AND CONDITIONS
182 |
183 | APPENDIX: How to apply the Apache License to your work.
184 |
185 | To apply the Apache License to your work, attach the following
186 | boilerplate notice, with the fields enclosed by brackets "{}"
187 | replaced with your own identifying information. (Don't include
188 | the brackets!) The text should be enclosed in the appropriate
189 | comment syntax for the file format. We also recommend that a
190 | file or class name and description of purpose be included on the
191 | same "printed page" as the copyright notice for easier
192 | identification within third-party archives.
193 |
194 | Copyright {yyyy} {name of copyright owner}
195 |
196 | Licensed under the Apache License, Version 2.0 (the "License");
197 | you may not use this file except in compliance with the License.
198 | You may obtain a copy of the License at
199 |
200 | http://www.apache.org/licenses/LICENSE-2.0
201 |
202 | Unless required by applicable law or agreed to in writing, software
203 | distributed under the License is distributed on an "AS IS" BASIS,
204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
205 | See the License for the specific language governing permissions and
206 | limitations under the License.
207 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # XMAiNframe: A Large Language Model for Mainframe Modernization
4 | [](https://opensource.org/licenses/MIT)
5 | [](link)
6 | [](https://huggingface.co/collections/Fsoft-AIC/xmainframe-66aca02d5b552e62033dc2bc)
7 | [](https://www.python.org/downloads/release/python-3100/)
8 |
9 |
10 |
11 |
12 | ## Table of Contents
13 | - [Introduction](#introduction)
14 | - [Demonstration](#demonstration)
15 | - [Procedure of Data Construction](#procedure-of-data-construction)
16 | - [Mainframe-Training](#mainframe-training)
17 | - [Mainframe-Instruct](#mainframe-instruct)
18 | - [Model Download](#model-download)
19 | - [Evaluation Results](#evaluation-results)
20 | - [Usage](#usage)
21 |     - [Fine-tune XMAiNframe](#fine-tune-xmainframe)
22 | - [Inference](#inference)
23 | - [License](#license)
24 | - [Acknowledgements](#acknowledgements)
25 | - [Contact Us](#contact-us)
26 | - [Citation Information](#citation-information)
27 |
28 |
29 |
30 | # Introduction
31 |
32 | We are introducing **XMAiNframe**, a state-of-the-art large language model (LLM) specifically designed with knowledge of mainframe legacy systems and COBOL codebases. XMAiNframe is built on top of DeepSeek-Coder 7B and is available with 7B and 10.5B parameters.
33 | Additionally, we present [MainframeBench](https://huggingface.co/datasets/Fsoft-AIC/MainframeBench), a comprehensive benchmark for assessing mainframe knowledge, including multiple-choice questions, question answering, and COBOL code summarization. Our empirical evaluations demonstrate that XMAiNframe consistently outperforms existing state-of-the-art LLMs across these tasks. Specifically, XMAiNframe achieves 30% higher accuracy than DeepSeek-Coder on multiple-choice questions, doubles the BLEU score of Mixtral-Instruct 8x7B on question answering, and scores six times higher than GPT-3.5 on COBOL summarization. Our work highlights the potential of XMAiNframe to drive significant advancements in managing and modernizing legacy systems, thereby enhancing productivity and saving time for software developers.
34 |
35 | # Demonstration
36 |
37 | In this section, we demonstrate the capabilities of XMAiNframe by comparing it with its base model, DeepSeek-Coder-7B. We evaluate each model by showcasing its responses to a series of realistic questions about mainframe knowledge. The images below illustrate how each model handles identical prompts. As shown, the responses generated by XMAiNframe are not only accurate but also more detailed and comprehensive than those from the base model. This makes XMAiNframe particularly valuable for developers seeking a reliable and thorough AI assistant for the mainframe environment.
38 |
39 |
40 |
41 | ![](asset/sample_1.png)
42 | ![](asset/sample_2.png)
43 | ![](asset/sample_3.png)
44 |
45 |
46 |
47 |
48 |
49 | # Procedure of Data Construction
50 | ## Mainframe-Training
51 |
52 | We utilized two different sources: COBOL projects collected from GitHub via the GitHub API, and online documents relevant to mainframes. In total, the Mainframe-Training Dataset consists of 236 million tokens from documents about mainframe technology and COBOL constructs. In the pre-training process, we combined our Mainframe-Training Dataset with [SlimOrca-Dedup](https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup) to enrich the model's mainframe knowledge while retaining its general capabilities.
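
As a rough illustration of the collection step (not the exact pipeline used to build Mainframe-Training), COBOL repositories can be discovered through the GitHub search API. The query, paging depth, and authentication handling below are assumptions made purely for this sketch.

```python
import requests  # pinned in requirements.txt

# Illustrative sketch: list public COBOL repositories via the GitHub search API.
# Add an Authorization header with a personal access token for higher rate limits.
resp = requests.get(
    "https://api.github.com/search/repositories",
    params={"q": "language:COBOL", "sort": "stars", "per_page": 50},
    headers={"Accept": "application/vnd.github+json"},
    timeout=30,
)
resp.raise_for_status()
for repo in resp.json()["items"]:
    print(repo["full_name"], repo["stargazers_count"])
```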
53 |
54 | ## Mainframe-Instruct
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 | Mainframe-Instruct is a high-quality synthetic dataset created through 5 steps:
63 |
64 | - Step 1: 300 seed data instances about mainframes and COBOL are gathered and annotated by our domain experts.
65 |
66 | - Step 2: Popular LLMs are used to enrich Mainframe-Instruct from the seed data.
67 |
68 | - Step 3: GPT-4 is employed as an evaluator to judge model responses, scoring the outputs and ranking responses in a pairwise manner (a sketch follows this list).
69 |
70 | - Step 4: The data is filtered and manually checked.
71 |
72 | - Step 5: Mainframe-Instruct is divided into three tasks: Multiple Choice Questions, Question Answering, and COBOL Summarization.
74 | Below are the statistics of the Mainframe-Instruct Dataset:
75 |
76 |
77 |
78 | |                           | Training Samples | Validating Samples | Testing Samples |
79 | | :-----------------------: | :--------------: | :----------------: | :-------------: |
80 | | Multiple Choice Questions | 13,894           | 1,544              | 1,931           |
81 | | Question Answering        | 18,692           | 2,078              | 2,598           |
82 | | COBOL Summarization       | 9,081            | 1,010              | 2,523           |
100 |
101 |
102 |
103 |
104 | [MainframeBench](https://huggingface.co/datasets/Fsoft-AIC/MainframeBench), our benchmark for mainframe knowledge, is the testing set of the Mainframe-Instruct Dataset. It is used to compare our LLMs with other models and is now available on the Hugging Face Hub.
105 |
106 | ```python
107 | from datasets import load_dataset
108 |
109 | # Load each sub-set in MainframeBench
110 | QA_set = load_dataset("Fsoft-AIC/MainframeBench", 'question_answering')
111 | MC_set = load_dataset("Fsoft-AIC/MainframeBench", 'multiple_choice_question')
112 | Summarization_set = load_dataset("Fsoft-AIC/MainframeBench", 'COBOL_code_summarization')
113 | ```
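
For a quick sanity check, each loaded subset behaves like a regular 🤗 `DatasetDict`; the split and column names are whatever the dataset publishes, so the snippet below discovers them instead of hard-coding them.

```python
# Inspect what was downloaded: available splits, columns, and one raw example.
print(QA_set)
first_split = next(iter(QA_set))   # e.g. "train" or "test", depending on how the dataset is published
print(QA_set[first_split][0])
```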
114 |
115 | # Model Download
116 | We release XMAiNframe with 7B and 10.5B parameters, including base and instruct models, to the public. XMAiNframe 10.5B is expanded from DeepSeek-Coder 7B using the depth up-scaling method, without introducing additional modules or dynamic expert selection methods.
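
The exact layer split behind XMAiNframe 10.5B is described in the paper; purely as an illustration of the depth up-scaling idea, a decoder-only model can be made deeper by re-using copies of its own blocks. The overlap value and layer indexing below are assumptions made for this sketch.

```python
from copy import deepcopy

import torch
from torch import nn
from transformers import AutoModelForCausalLM

# Illustrative sketch of depth up-scaling (not the exact recipe behind XMAiNframe 10.5B).
base = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/deepseek-coder-7b-base-v1.5", torch_dtype=torch.bfloat16
)
layers = base.model.layers            # decoder blocks of a Llama-style model
n, k = len(layers), 8                 # k: number of overlapping blocks dropped from each copy (assumed value)

# Concatenate the first (n - k) blocks with the last (n - k) blocks, duplicating the middle ones;
# the deeper model is then continually pre-trained so the duplicated blocks can specialise.
base.model.layers = nn.ModuleList(
    [deepcopy(block) for block in layers[: n - k]]
    + [deepcopy(block) for block in layers[k:]]
)
base.config.num_hidden_layers = len(base.model.layers)
print(base.config.num_hidden_layers)  # 2 * n - 2 * k
```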
117 |
118 |
119 |
120 | | **Model** | **Download** |
121 | | :-----------------------------: | :----------------------------------------------------------: |
122 | | XMAiNframe-base-7b              | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-base-7b/) |
123 | | XMAiNframe-instruct-7b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-instruct-7b) |
124 | | XMAiNframe-base-10.5b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-base-10.5b) |
125 | | XMAiNframe-instruct-10.5b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-instruct-10.5b) |
126 |
127 |
128 |
129 |
130 | # Evaluation Results
131 | ## Multiple Choice Question Task
132 |
133 |
134 |
135 | | Model                        | Accuracy (%) |
136 | | :--------------------------: | :----------: |
137 | | GPT-4                        | 73.90        |
138 | | GPT-3.5                      | 74.56        |
139 | | Mixtral-Instruct 8x7B        | 68.12        |
140 | | Mistral-Instruct 7B          | 69.29        |
141 | | Neural-Chat                  | 66.35        |
142 | | DeepSeek-Coder-Instruct 6.7B | 47.49        |
143 | | DeepSeek-Coder-Instruct 33B  | 53.29        |
144 | | XMAiNframe-Instruct 7B       | 68.57        |
145 | | XMAiNframe-Instruct 10.5B    | 77.89        |
173 |
174 |
175 |
176 | ## Question Answering Task
177 |
178 |
179 |
180 | | Models                       | MAP  | F1-Score | BERTScore | RougeL | Meteor | BLEU-4 |
181 | | :--------------------------: | :--: | :------: | :-------: | :----: | :----: | :----: |
182 | | GPT 4                        | 0.12 | 0.19     | 0.88      | 0.18   | 0.34   | 5.71   |
183 | | GPT 3.5                      | 0.14 | 0.22     | 0.89      | 0.21   | 0.38   | 7.36   |
184 | | Mixtral-Instruct 8x7B        | 0.27 | 0.31     | 0.90      | 0.29   | 0.38   | 11.39  |
185 | | Mistral-Instruct 7B          | 0.12 | 0.19     | 0.87      | 0.18   | 0.34   | 5.74   |
186 | | Neural-Chat                  | 0.13 | 0.21     | 0.88      | 0.20   | 0.36   | 6.45   |
187 | | DeepSeek-Coder-Instruct 6.7B | 0.09 | 0.15     | 0.86      | 0.14   | 0.30   | 4.09   |
188 | | DeepSeek-Coder-Instruct 33B  | 0.09 | 0.15     | 0.86      | 0.15   | 0.31   | 4.41   |
189 | | XMAiNframe-Instruct 7B       | 0.45 | 0.42     | 0.92      | 0.40   | 0.42   | 20.43  |
190 | | XMAiNframe-Instruct 10.5B    | 0.43 | 0.42     | 0.92      | 0.40   | 0.42   | 20.93  |
268 |
269 |
270 |
271 | ## COBOL Code Summarization
272 |
273 |
274 | | Models                       | MAP  | F1-Score | BERTScore | RougeL | Meteor | BLEU-4 |
275 | | :--------------------------: | :--: | :------: | :-------: | :----: | :----: | :----: |
276 | | GPT 4                        | 0.12 | 0.19     | 0.88      | 0.18   | 0.34   | 5.71   |
277 | | GPT 3.5                      | 0.14 | 0.22     | 0.89      | 0.21   | 0.38   | 7.36   |
278 | | Mixtral-Instruct 8x7B        | 0.27 | 0.31     | 0.90      | 0.29   | 0.38   | 11.39  |
279 | | Mistral-Instruct 7B          | 0.12 | 0.19     | 0.87      | 0.18   | 0.34   | 5.74   |
280 | | Neural-Chat                  | 0.13 | 0.21     | 0.88      | 0.20   | 0.36   | 6.45   |
281 | | DeepSeek-Coder-Instruct 6.7B | 0.09 | 0.15     | 0.86      | 0.14   | 0.30   | 4.09   |
282 | | DeepSeek-Coder-Instruct 33B  | 0.09 | 0.15     | 0.86      | 0.15   | 0.31   | 4.41   |
283 | | XMAiNframe-Instruct 7B       | 0.45 | 0.42     | 0.92      | 0.40   | 0.42   | 20.43  |
284 | | XMAiNframe-Instruct 10.5B    | 0.43 | 0.42     | 0.92      | 0.40   | 0.42   | 20.93  |
362 |
363 |
364 |
365 | For more evaluation details and settings, please check our paper.
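
As a rough illustration (not the exact evaluation settings from the paper), the generation metrics reported above can be computed with the `evaluate` library that is already pinned in `requirements.txt`:

```python
import evaluate

# Toy predictions/references; a real evaluation would iterate over the MainframeBench test sets.
predictions = ["This paragraph reads one record from the input file."]
references = ["The paragraph reads a single record from the input file."]

rouge = evaluate.load("rouge")
meteor = evaluate.load("meteor")
bleu = evaluate.load("bleu")
bertscore = evaluate.load("bertscore")

print(rouge.compute(predictions=predictions, references=references)["rougeL"])
print(meteor.compute(predictions=predictions, references=references)["meteor"])
print(bleu.compute(predictions=predictions, references=references, max_order=4)["bleu"])
print(bertscore.compute(predictions=predictions, references=references, lang="en")["f1"])
```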
366 |
367 |
368 | # Usage
369 | ## Fine-tune XMAiNframe
370 | To run the code in this project, first create a Python virtual environment, e.g. with Conda:
371 |
372 | ```shell
373 | conda create -n xmainframe python=3.10 && conda activate xmainframe
374 | ```
375 |
376 | You can then install the remaining package dependencies as follows:
377 |
378 | ```shell
379 | git clone https://github.com/FSoft-AI4Code/XMainframe.git
380 | cd XMainframe
381 | pip install -r requirements.txt
382 | ```
383 | You can now check out the `scripts` and `recipes` directories for instructions on how to fine-tune our model 🪁!
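
For example, the QLoRA fine-tuning recipe can be launched with `accelerate` using the provided configs, along the lines of `scripts/ft.sh` (adjust `--num_processes` and the recipe file to your hardware):

```shell
ACCELERATE_LOG_LEVEL=info accelerate launch \
    --config_file recipes/accelerate_configs/fsdp_qlora.yaml \
    --num_processes=4 \
    train_raw.py recipes/deepseek/lora_sft.yaml \
    --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16
```

`scripts/instruct.sh` shows the equivalent launch for instruction tuning with `train_instruct.py` and `recipes/deepseek/lora_instruct.yaml`.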
384 |
385 |
386 | ## Inference
387 |
388 | Here is a code snippet that uses `apply_chat_template` to show how to load the tokenizer and model and how to generate content.
389 |
390 |
391 | ```python
392 | from transformers import AutoTokenizer, AutoModelForCausalLM
393 | tokenizer = AutoTokenizer.from_pretrained("Fsoft-AIC/XMAiNframe-instruct-7b")
394 | model = AutoModelForCausalLM.from_pretrained("Fsoft-AIC/XMAiNframe-instruct-7b")
395 | messages=[
396 | {'from':'system', 'value': "You are a helpful assistant"},
397 | {'from': 'human', 'value': 'What is the future of Mainframe?'}
398 | ]
399 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
400 |
401 | outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, top_k=50, top_p=0.95, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
402 | print(tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True))
403 | ```
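
The chat template shipped with the instruct models (as defined in the training recipes) expects `from`/`value` keys with `human`, `system`, and `gpt` roles. To inspect the exact prompt string the template produces before tokenization, you can render it as text:

```python
# Render the chat template as plain text instead of token ids to inspect the prompt format.
prompt_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt_text)
```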
404 |
405 |
406 |
407 |
408 |
409 |
410 | # License
411 | This code repository is licensed under the [Apache License 2.0](LICENSE.txt).
412 |
413 | # Acknowledgements
414 | This codebase is adapted from:
415 | - [alignment-handbook](https://github.com/huggingface/alignment-handbook)
416 |
417 | # Contact us
418 | If you have any questions, comments or suggestions, please do not hesitate to contact us.
419 | - Website: [fpt-aicenter](https://www.fpt-aicenter.com/ai-residency/)
420 | - Email: support.ailab@fpt.com
421 |
422 | # Citation Information
423 | More details can be found in our [technical report](https://github.com/FSoft-AI4Code/).
424 |
425 | If you're using XMAiNframe, please cite using this BibTeX:
426 | ```bibtex
427 | @article{dau2024xmainframe,
428 | title={XMainframe: A Large Language Model for Mainframe Modernization},
429 | author={Dau, Anh TV and Dao, Hieu Trung and Nguyen, Anh Tuan and Tran, Hieu Trung and Nguyen, Phong X and Bui, Nghi DQ},
430 | journal={arXiv preprint arXiv:2408.04660},
431 | year={2024}
432 | }
433 | ```
434 |
--------------------------------------------------------------------------------
/asset/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/.DS_Store
--------------------------------------------------------------------------------
/asset/XMAiNframe.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/XMAiNframe.png
--------------------------------------------------------------------------------
/asset/cobol_diagram-Page-1.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-1.drawio.png
--------------------------------------------------------------------------------
/asset/cobol_diagram-Page-8.drawio.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-8.drawio.pdf
--------------------------------------------------------------------------------
/asset/cobol_diagram-Page-8.drawio.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-8.drawio.png
--------------------------------------------------------------------------------
/asset/sample_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_1.png
--------------------------------------------------------------------------------
/asset/sample_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_2.png
--------------------------------------------------------------------------------
/asset/sample_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_3.png
--------------------------------------------------------------------------------
/recipes/accelerate_configs/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/recipes/accelerate_configs/deepspeed_zero3_lora.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: cpu #none
6 | offload_param_device: cpu #none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: false
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 4
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
--------------------------------------------------------------------------------
/recipes/accelerate_configs/fsdp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | enable_cpu_affinity: false
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: true
11 | fsdp_offload_params: false
12 | fsdp_sharding_strategy: FULL_SHARD
13 | fsdp_state_dict_type: SHARDED_STATE_DICT
14 | fsdp_sync_module_states: true
15 | fsdp_use_orig_params: true
16 | machine_rank: 0
17 | main_training_function: main
18 | mixed_precision: bf16
19 | num_machines: 1
20 | num_processes: 4
21 | rdzv_backend: static
22 | same_network: true
23 | tpu_env: []
24 | tpu_use_cluster: false
25 | tpu_use_sudo: false
26 | use_cpu: false
27 |
--------------------------------------------------------------------------------
/recipes/accelerate_configs/fsdp_qlora.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | # cpu_ram_efficient_loading: "true"
6 | fsdp_config:
7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
8 | fsdp_backward_prefetch: BACKWARD_PRE
9 | fsdp_cpu_ram_efficient_loading: true
10 | fsdp_forward_prefetch: false
11 | fsdp_offload_params: true
12 | fsdp_sharding_strategy: FULL_SHARD
13 | fsdp_state_dict_type: SHARDED_STATE_DICT
14 | fsdp_sync_module_states: true
15 | fsdp_use_orig_params: false
16 | machine_rank: 0
17 | main_training_function: main
18 | mixed_precision: 'no'
19 | num_machines: 1
20 | num_processes: 2
21 | rdzv_backend: static
22 | same_network: true
23 | tpu_env: []
24 | tpu_use_cluster: false
25 | tpu_use_sudo: false
26 | use_cpu: false
--------------------------------------------------------------------------------
/recipes/accelerate_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/recipes/deepseek/full.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | # model_name_or_path: deepseek-ai/deepseek-coder-7b-base-v1.5
3 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full
4 | model_revision: main
5 | torch_dtype: bfloat16
6 | use_flash_attention_2: false
7 | trust_remote_code: true
8 |
9 | # Data training arguments
10 | dataset_mixer:
11 | # HuggingFaceH4/ultrachat_200k: 1.0
12 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1_chunks_8096.feather: 0.1
13 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 0.1
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 12
18 |
19 | # SFT trainer config
20 | bf16: true
21 | do_eval: false
22 | evaluation_strategy: steps
23 | gradient_accumulation_steps: 2
24 | gradient_checkpointing: true
25 | hub_model_id: deepseek-7b-ft-full_longcontext
26 | hub_strategy: every_save
27 | learning_rate: 2.0e-05
28 | log_level: info
29 | logging_steps: 5
30 | logging_strategy: steps
31 | lr_scheduler_type: cosine
32 | max_seq_length: 16000
33 | max_steps: -1
34 | num_train_epochs: 1
35 | output_dir: data/deepseek-7b-ft-full_longcontext
36 | overwrite_output_dir: true
37 | per_device_eval_batch_size: 4
38 | per_device_train_batch_size: 1
39 | push_to_hub: false
40 | remove_unused_columns: true
41 | report_to:
42 | - tensorboard
43 | # save_strategy: "epoch"
44 | save_strategy: "steps"
45 | save_steps: 100
46 | save_total_limit: null
47 | seed: 42
48 | tf32: true
49 |
--------------------------------------------------------------------------------
/recipes/deepseek/full_instruct.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full/checkpoint-5425
3 | # model_name_or_path: mistralai/Mistral-7B-v0.1
4 | model_revision: main
5 | torch_dtype: bfloat16
6 | use_flash_attention_2: true
7 |
8 | # Data training arguments
9 | chat_template: "{{ bos_token }}{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value']}}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value']}}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
10 | dataset_mixer:
11 | # HuggingFaceH4/ultrachat_200k: 1.0
12 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1_chunks_8096.feather: 0.1
13 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 0.1
14 | dataset_splits:
15 | - train
16 | - test
17 | preprocessing_num_workers: 12
18 |
19 | # SFT trainer config
20 | bf16: true
21 | do_eval: false
22 | evaluation_strategy: epoch
23 | gradient_accumulation_steps: 2
24 | gradient_checkpointing: true
25 | gradient_checkpointing_kwargs:
26 | use_reentrant: False
27 | hub_model_id: deepseek_instruct_full_data_pretrained
28 | hub_strategy: every_save
29 | learning_rate: 2.0e-05
30 | log_level: info
31 | logging_steps: 3
32 | logging_strategy: steps
33 | lr_scheduler_type: cosine
34 | max_seq_length: 4096
35 | max_steps: -1
36 | num_train_epochs: 5
37 | output_dir: data/deepseek_instruct_full_data_pretrained
38 | overwrite_output_dir: true
39 | per_device_eval_batch_size: 4
40 | per_device_train_batch_size: 2
41 | push_to_hub: false
42 | remove_unused_columns: true
43 | report_to:
44 | - tensorboard
45 | save_strategy: "epoch"
46 | save_total_limit: null
47 | seed: 45
48 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/recipes/deepseek/lora_instruct.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek_16b_lora
3 | model_revision: main
4 | torch_dtype: float16
5 | use_flash_attention_2: true
6 | trust_remote_code: true
7 |
8 |
9 |
10 | # LoRA arguments
11 | load_in_4bit: true
12 | use_peft: true
13 | lora_r: 16
14 | lora_alpha: 16
15 | lora_dropout: 0.1
16 | lora_target_modules:
17 | - q_proj
18 | - k_proj
19 | - v_proj
20 | - o_proj
21 | - gate_proj
22 | - lm_head
23 | - up_proj
24 | - down_proj
25 |
26 | # Data training arguments
27 | chat_template: "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
28 | dataset_mixer:
29 | # HuggingFaceH4/ultrachat_200k: 1.0
30 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1.feather: 1
31 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 1
32 | dataset_splits:
33 | - train
34 | - test
35 | preprocessing_num_workers: 12
36 |
37 | # SFT trainer config
38 | bf16: true
39 | do_eval: true
40 | evaluation_strategy: epoch
41 | gradient_accumulation_steps: 4
42 | gradient_checkpointing: true
43 | gradient_checkpointing_kwargs:
44 | use_reentrant: false
45 | hub_model_id: deepseek_instruct_lora_16b_multigpu
46 | hub_strategy: every_save
47 | learning_rate: 2.0e-04
48 | log_level: info
49 | logging_steps: 5
50 | logging_strategy: steps
51 | lr_scheduler_type: cosine
52 | max_seq_length: 4096
53 | max_steps: -1
54 | num_train_epochs: 3
55 | output_dir: data/deepseek_instruct_lora_16b_multigpu
56 | overwrite_output_dir: true
57 | per_device_eval_batch_size: 4
58 | per_device_train_batch_size: 3
59 | push_to_hub: false
60 | # dataset_num_proc: 4
61 | report_to:
62 | - tensorboard
63 | save_strategy: "epoch"
64 | save_total_limit: null
65 | seed: 42
66 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/recipes/deepseek/lora_sft.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full
3 | model_revision: main
4 | torch_dtype: 'float16'
5 | # trust_remote_code: true
6 | use_flash_attention_2: true
7 |
8 |
9 | # LoRA arguments
10 | # load_in_8bit: true
11 | load_in_4bit: true
12 | use_peft: true
13 | lora_r: 8
14 | lora_alpha: 16
15 | lora_dropout: 0.1
16 | lora_target_modules:
17 | - q_proj
18 | - k_proj
19 | - v_proj
20 | - o_proj
21 | - gate_proj
22 | - lm_head
23 | - up_proj
24 | - down_proj
25 |
26 | # Data training arguments
27 | chat_template: "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
28 | dataset_mixer:
29 | # HuggingFaceH4/ultrachat_200k: 1.0
30 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1.feather: 1
31 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 1
32 | dataset_splits:
33 | - train
34 | - test
35 | preprocessing_num_workers: 12
36 |
37 | # SFT trainer config
38 | bf16: true
39 | do_eval: true
40 | evaluation_strategy: epoch
41 | gradient_accumulation_steps: 2
42 | gradient_checkpointing: true
43 | gradient_checkpointing_kwargs:
44 | use_reentrant: false
45 | hub_model_id: deepseek-7b-ft-lora_longcontext
46 | hub_strategy: every_save
47 | # optimizer: paged_adamw_8bit
48 | learning_rate: 1.0e-05
49 | log_level: info
50 | logging_steps: 5
51 | logging_strategy: steps
52 | lr_scheduler_type: cosine
53 | max_seq_length: 16000
54 | max_steps: -1
55 | num_train_epochs: 5
56 | output_dir: data/deepseek-7b-ft-lora_longcontext
57 | overwrite_output_dir: true
58 | per_device_eval_batch_size: 2
59 | per_device_train_batch_size: 1
60 | # save_strategy: "steps"
61 | # save_steps: 500
62 | push_to_hub: false
63 | report_to:
64 | - tensorboard
65 | save_strategy: "epoch"
66 | save_total_limit: null
67 | seed: 42
68 | warmup_ratio: 0.1
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | accelerate==0.33.0
3 | aiohttp==3.9.5
4 | aiosignal==1.3.1
5 | alignment-handbook
6 | annotated-types==0.7.0
7 | async-timeout==4.0.3
8 | attrs==23.2.0
9 | bitsandbytes==0.43.2
10 | cachetools==5.4.0
11 | certifi==2024.7.4
12 | charset-normalizer==3.3.2
13 | datasets==2.20.0
14 | deepspeed==0.12.2
15 | dill==0.3.8
16 | docstring_parser==0.16
17 | einops==0.8.0
18 | evaluate==0.4.0
19 | filelock==3.15.4
20 | flash_attn==2.6.3
21 | frozenlist==1.4.1
22 | fsspec==2024.5.0
23 | grpcio==1.65.1
24 | hf_transfer==0.1.8
25 | hjson==3.1.0
26 | huggingface-hub==0.24.2
27 | idna==3.7
28 | Jinja2==3.1.4
29 | Markdown==3.6
30 | markdown-it-py==3.0.0
31 | MarkupSafe==2.1.5
32 | mdurl==0.1.2
33 | mpmath==1.3.0
34 | multidict==6.0.5
35 | multiprocess==0.70.16
36 | networkx==3.3
37 | ninja==1.11.1.1
38 | numpy==1.26.4
39 | nvidia-cublas-cu12==12.1.3.1
40 | nvidia-cuda-cupti-cu12==12.1.105
41 | nvidia-cuda-nvrtc-cu12==12.1.105
42 | nvidia-cuda-runtime-cu12==12.1.105
43 | nvidia-cudnn-cu12==9.1.0.70
44 | nvidia-cufft-cu12==11.0.2.54
45 | nvidia-curand-cu12==10.3.2.106
46 | nvidia-cusolver-cu12==11.4.5.107
47 | nvidia-cusparse-cu12==12.1.0.106
48 | nvidia-ml-py==12.535.161
49 | nvidia-nccl-cu12==2.20.5
50 | nvidia-nvjitlink-cu12==12.5.82
51 | nvidia-nvtx-cu12==12.1.105
52 | nvitop==1.3.2
53 | packaging==24.1
54 | pandas==2.2.2
55 | peft==0.12.0
56 | protobuf==3.20.2
57 | psutil==6.0.0
58 | py-cpuinfo==9.0.0
59 | pyarrow==17.0.0
60 | pyarrow-hotfix==0.6
61 | pydantic==2.8.2
62 | pydantic_core==2.20.1
63 | Pygments==2.18.0
64 | pynvml==11.5.3
65 | python-dateutil==2.9.0.post0
66 | pytz==2024.1
67 | PyYAML==6.0.1
68 | regex==2024.7.24
69 | requests==2.32.3
70 | responses==0.18.0
71 | rich==13.7.1
72 | safetensors==0.4.3
73 | scipy==1.14.0
74 | sentencepiece==0.2.0
75 | shtab==1.7.1
76 | six==1.16.0
77 | sympy==1.13.1
78 | tensorboard==2.17.0
79 | tensorboard-data-server==0.7.2
80 | termcolor==2.4.0
81 | tokenizers==0.19.1
82 | torch==2.4.0
83 | tqdm==4.66.4
84 | transformers==4.40.2
85 | triton==3.0.0
86 | trl==0.9.6
87 | typing_extensions==4.12.2
88 | tyro==0.8.5
89 | tzdata==2024.1
90 | urllib3==2.2.2
91 | Werkzeug==3.0.3
92 | xxhash==3.4.1
93 | yarl==1.9.4
94 |
95 |
96 |
--------------------------------------------------------------------------------
/scripts/ft.sh:
--------------------------------------------------------------------------------
1 |
2 | ACCELERATE_LOG_LEVEL=info accelerate launch \
3 | --config_file recipes/accelerate_configs/fsdp_qlora.yaml \
4 | --num_processes=4 \
5 | --main_process_port 9506 \
6 | train_raw.py recipes/deepseek/lora_sft.yaml \
7 | --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 \
8 | --use_4bit_quantization=True \
9 | --use_nested_quant=True \
10 | --bnb_4bit_quant_type="nf4" \
11 | --bnb_4bit_compute_dtype=bfloat16 \
12 | --bnb_4bit_quant_storage_dtype=bfloat16
13 | # --load_in_4bit=true
14 |
--------------------------------------------------------------------------------
/scripts/instruct.sh:
--------------------------------------------------------------------------------
1 |
2 | ACCELERATE_LOG_LEVEL=info accelerate launch \
3 | --config_file recipes/accelerate_configs/multi_gpu.yaml \
4 | --num_processes=4 \
5 | --main_process_port 9505 \
6 | train_instruct.py recipes/deepseek/lora_instruct.yaml \
7 | # --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 \
8 | # --use_4bit_quantization=True \
9 | # --use_nested_quant=True \
10 | # --bnb_4bit_quant_type="nf4" \
11 | # --bnb_4bit_compute_dtype=bfloat16 \
12 | # --bnb_4bit_quant_storage_dtype=bfloat16
13 |
--------------------------------------------------------------------------------
/src/alignment/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.3.0.dev0"
2 |
3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig
4 | from .data import apply_chat_template, get_datasets
5 | from .decontaminate import decontaminate_humaneval
6 | from .model_utils import (
7 | get_checkpoint,
8 | get_kbit_device_map,
9 | get_peft_config,
10 | get_quantization_config,
11 | get_tokenizer,
12 | is_adapter_model,
13 | )
14 |
--------------------------------------------------------------------------------
/src/alignment/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/configs.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/configs.cpython-310.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/configs.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/configs.cpython-38.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/data.cpython-310.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/data.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/data.cpython-38.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/decontaminate.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/decontaminate.cpython-310.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/decontaminate.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/decontaminate.cpython-38.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/model_utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/model_utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/alignment/__pycache__/model_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/model_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/src/alignment/configs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import dataclasses
16 | import os
17 | import sys
18 | from dataclasses import dataclass, field
19 | from typing import Any, Dict, List, NewType, Optional, Tuple
20 |
21 | import transformers
22 | from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser
23 |
24 |
25 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys())
26 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)
27 |
28 |
29 | DataClassType = NewType("DataClassType", Any)
30 |
31 |
32 | class H4ArgumentParser(HfArgumentParser):
33 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]:
34 | """
35 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line.
36 |
37 | Args:
38 | yaml_arg (`str`):
39 | The path to the config file used
40 | other_args (`List[str]`, *optional`):
41 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2'].
42 |
43 | Returns:
44 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line
45 | """
46 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg))
47 |
48 | outputs = []
49 | # strip other args list into dict of key-value pairs
50 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args}
51 | used_args = {}
52 |
53 | # overwrite the default/loaded value with the value provided to the command line
54 | # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327
55 | for data_yaml, data_class in zip(arg_list, self.dataclass_types):
56 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init}
57 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys}
58 | for arg, val in other_args.items():
59 | # add only if in keys
60 |
61 | if arg in keys:
62 | base_type = data_yaml.__dataclass_fields__[arg].type
63 | inputs[arg] = val
64 |
65 | # cast type for ints, floats (default to strings)
66 | if base_type in [int, float]:
67 | inputs[arg] = base_type(val)
68 |
69 | if base_type == List[str]:
70 | inputs[arg] = [str(v) for v in val.split(",")]
71 |
72 | # bool of a non-empty string is True, so we manually check for bools
73 | if base_type == bool:
74 | if val in ["true", "True"]:
75 | inputs[arg] = True
76 | else:
77 | inputs[arg] = False
78 |
79 | # add to used-args so we can check if double add
80 | if arg not in used_args:
81 | used_args[arg] = val
82 | else:
83 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior")
84 |
85 | obj = data_class(**inputs)
86 | outputs.append(obj)
87 |
88 | return outputs
89 |
90 | def parse(self): #-> DataClassType | Tuple[DataClassType]:
91 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
92 | # If we pass only one argument to the script and it's the path to a YAML file,
93 | # let's parse it to get our arguments.
94 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1]))
95 | # parse command line args and yaml file
96 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"):
97 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:])
98 | # parse command line args only
99 | else:
100 | output = self.parse_args_into_dataclasses()
101 |
102 | if len(output) == 1:
103 | output = output[0]
104 | return output
105 |
106 |
107 | @dataclass
108 | class ModelArguments:
109 | """
110 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
111 | """
112 |
113 | base_model_revision: Optional[str] = field(
114 | default=None,
115 | metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")},
116 | )
117 | model_name_or_path: Optional[str] = field(
118 | default=None,
119 | metadata={
120 | "help": (
121 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch."
122 | )
123 | },
124 | )
125 | model_revision: str = field(
126 | default="main",
127 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
128 | )
129 | model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"})
130 | torch_dtype: Optional[str] = field(
131 | default=None,
132 | metadata={
133 | "help": (
134 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the "
135 | "dtype will be automatically derived from the model's weights."
136 | ),
137 | "choices": ["auto", "bfloat16", "float16", "float32"],
138 | },
139 | )
140 | tokenizer_name_or_path: Optional[str] = field(
141 | default=None,
142 | metadata={
143 | "help": (
144 | "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`."
145 | )
146 | },
147 | )
148 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."})
149 | use_flash_attention_2: bool = field(
150 | default=False,
151 | metadata={
152 | "help": (
153 | "Whether to use flash attention 2. You must install this manually by running `pip install flash-attn --no-build-isolation`"
154 | )
155 | },
156 | )
157 | use_peft: bool = field(
158 | default=False,
159 | metadata={"help": ("Whether to use PEFT or not for training.")},
160 | )
161 | lora_r: Optional[int] = field(
162 | default=16,
163 | metadata={"help": ("LoRA R value.")},
164 | )
165 | lora_alpha: Optional[int] = field(
166 | default=32,
167 | metadata={"help": ("LoRA alpha.")},
168 | )
169 | lora_dropout: Optional[float] = field(
170 | default=0.05,
171 | metadata={"help": ("LoRA dropout.")},
172 | )
173 | lora_target_modules: Optional[List[str]] = field(
174 | default=None,
175 | metadata={"help": ("LoRA target modules.")},
176 | )
177 | lora_modules_to_save: Optional[List[str]] = field(
178 | default=None,
179 | metadata={"help": ("Model layers to unfreeze & train")},
180 | )
181 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"})
182 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"})
183 |
184 | bnb_4bit_quant_type: Optional[str] = field(
185 | default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"}
186 | )
187 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"})
188 | bnb_4bit_quant_storage: Optional[str] = field(
189 | default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."}
190 | )
191 |
192 | def __post_init__(self):
193 | if self.load_in_8bit and self.load_in_4bit:
194 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
195 |
196 |
197 | @dataclass
198 | class DataArguments:
199 | """
200 | Arguments pertaining to what data we are going to input our model for training and eval.
201 | """
202 |
203 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
204 | dataset_mixer: Optional[Dict[str, float]] = field(
205 | default=None,
206 | metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")},
207 | )
208 | text_column: Optional[str] = field(
209 | default="text",
210 | metadata={"help": "The column name to use for the text in the dataset (only used for continued pretraining)."},
211 | )
212 | dataset_splits: Optional[List[str]] = field(
213 | default_factory=lambda: ["train", "test"],
214 | metadata={"help": ("List of train test splits to use in the dataset")},
215 | )
216 | dataset_configs: Optional[List[str]] = field(
217 | default=None,
218 | metadata={"help": "List of dataset config names. If given must be the same length as 'dataset_mixer' keys."},
219 | )
220 | preprocessing_num_workers: Optional[int] = field(
221 | default=None,
222 | metadata={"help": "The number of processes to use for the preprocessing."},
223 | )
224 | truncation_side: Optional[str] = field(
225 | default=None, metadata={"help": "Truncation side to use for the tokenizer."}
226 | )
227 | auto_insert_empty_system_msg: bool = field(
228 | default=True,
229 | metadata={
230 | "help": (
231 | "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template."
232 | )
233 | },
234 | )
235 |
236 |
237 | @dataclass
238 | class SFTConfig(transformers.TrainingArguments):
239 | """
240 | Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
241 | Also used for the continued pretraining task.
242 | """
243 |
244 | dataset_kwargs: Optional[Dict[str, Any]] = field(
245 | default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"}
246 | )
247 | max_seq_length: Optional[int] = field(
248 | default=None,
249 | metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
250 | )
251 | logging_first_step: bool = field(
252 | default=True,
253 | metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
254 | )
255 | optim: Optional[str] = field(default="adamw_torch")
256 |
257 |
258 | @dataclass
259 | class DPOConfig(transformers.TrainingArguments):
260 | """
261 | Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments
262 | """
263 |
264 | beta: Optional[float] = field(
265 | default=0.1,
266 | metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."},
267 | )
268 | hub_model_revision: Optional[str] = field(
269 | default="main",
270 | metadata={"help": ("The Hub model branch to push the model to.")},
271 | )
272 | logging_first_step: bool = field(
273 | default=True,
274 | metadata={"help": ("Whether to log and evaluate the first global_step or not.")},
275 | )
276 | max_prompt_length: Optional[int] = field(
277 | default=None,
278 | metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")},
279 | )
280 | max_length: Optional[int] = field(
281 | default=None,
282 | metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")},
283 | )
284 | optim: Optional[str] = field(default="rmsprop")
285 | remove_unused_columns: bool = field(default=False)
286 | loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")})
287 |
288 |
289 | @dataclass
290 | class ORPOConfig(transformers.TrainingArguments):
291 | max_length: Optional[int] = field(
292 | default=None,
293 | metadata={"help": "The maximum length of the sequences in the batch."},
294 | )
295 | max_prompt_length: Optional[int] = field(
296 | default=None,
297 | metadata={"help": "The maximum length of the prompt."},
298 | )
299 | max_completion_length: Optional[int] = field(
300 | default=None,
301 | metadata={"help": "The maximum length of the completions."},
302 | )
303 |
304 | beta: float = field(
305 | default=0.1,
306 | metadata={
307 | "help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss."
308 | },
309 | )
310 | disable_dropout: bool = field(
311 | default=True,
312 | metadata={"help": "Whether or not to disable dropouts in `model`."},
313 | )
314 |
315 | label_pad_token_id: int = field(
316 | default=-100,
317 | metadata={"help": "The label pad token id."},
318 | )
319 | padding_value: Optional[int] = field(
320 | default=None,
321 | metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."},
322 | )
323 | truncation_mode: str = field(
324 | default="keep_end",
325 | metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."},
326 | )
327 |
328 | generate_during_eval: bool = field(
329 | default=False,
330 | metadata={"help": "Whether to sample and log generations during evaluation step."},
331 | )
332 | is_encoder_decoder: Optional[bool] = field(
333 | default=None,
334 | metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")},
335 | )
336 |
337 | model_init_kwargs: Optional[Dict] = field(
338 | default=None,
339 | metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")},
340 | )
341 |
342 | dataset_num_proc: Optional[int] = field(
343 | default=None,
344 | metadata={"help": ("The number of workers to use to tokenize the data.")},
345 | )
346 |
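A quick usage sketch for these dataclasses (illustrative, not taken from the repo): `H4ArgumentParser`, exported from this package's `__init__.py`, fills all three from a recipe YAML or from CLI flags, which is how the training scripts below consume them; `recipes/deepseek/lora_sft.yaml` is one of the recipes shipped with the repo.

# Sketch only. Assumes the repo root is on PYTHONPATH and the recipe path is passed on
# the command line, e.g.: python sketch.py recipes/deepseek/lora_sft.yaml
from src.alignment import H4ArgumentParser, ModelArguments, DataArguments, SFTConfig

parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
model_args, data_args, training_args = parser.parse()   # same call as in train_instruct.py / train_raw.py
print(model_args.lora_r, model_args.lora_alpha)         # LoRA hyper-parameters defined above
print(data_args.dataset_splits, training_args.optim)    # ["train", "test"], "adamw_torch" unless the recipe overrides them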
--------------------------------------------------------------------------------
/src/alignment/data.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import os
17 | from typing import Any, List, Literal, Optional
18 |
19 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk
20 | from datasets.builder import DatasetGenerationError
21 |
22 | from .configs import DataArguments
23 |
24 |
25 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
26 |
27 |
28 | def maybe_insert_system_message(messages, tokenizer):
29 | if messages[0]["role"] == "system":
30 | return
31 |
32 | # chat template can be one of two attributes, we check in order
33 | chat_template = tokenizer.chat_template
34 | if chat_template is None:
35 | chat_template = tokenizer.default_chat_template
36 |
37 | # confirm the jinja template refers to a system message before inserting
38 | if "system" in chat_template or "<|im_start|>" in chat_template:
39 | messages.insert(0, {"role": "system", "content": ""})
40 |
41 |
42 | def apply_chat_template(
43 | example,
44 | tokenizer,
45 | task: Literal["sft", "generation", "rm", "dpo"],
46 | auto_insert_empty_system_msg: bool = True,
47 | ):
48 | if task in ["sft", "generation"]:
49 | messages = example["messages"]
50 | # We add an empty system message if there is none
51 | if auto_insert_empty_system_msg:
52 | maybe_insert_system_message(messages, tokenizer)
53 | example["text"] = tokenizer.apply_chat_template(
54 | messages,
55 | tokenize=False,
56 | add_generation_prompt=True if task == "generation" else False,
57 | )
58 | elif task == "rm":
59 | if all(k in example.keys() for k in ("chosen", "rejected")):
60 | chosen_messages = example["chosen"]
61 | rejected_messages = example["rejected"]
62 | # We add an empty system message if there is none
63 | if auto_insert_empty_system_msg:
64 | maybe_insert_system_message(chosen_messages, tokenizer)
65 | maybe_insert_system_message(rejected_messages, tokenizer)
66 |
67 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
68 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
69 | else:
70 | raise ValueError(
71 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
72 | )
73 | elif task in ["dpo", "orpo"]:
74 | if all(k in example.keys() for k in ("chosen", "rejected")):
75 | if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]):
76 | raise ValueError(
77 | f"Could not format example as dialogue for `{task}` task! Require OpenAI format for all messages"
78 | )
79 |
80 | # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue
81 | # We therefore need to extract the N-1 turns to form the prompt
82 | if "prompt" in example and is_openai_format(example["prompt"]):
83 | prompt_messages = example["prompt"]
84 | chosen_messages = example["chosen"]
85 | rejected_messages = example["rejected"]
86 | else:
87 | prompt_messages = example["chosen"][:-1]
88 | # Now we extract the final turn to define chosen/rejected responses
89 | chosen_messages = example["chosen"][-1:]
90 | rejected_messages = example["rejected"][-1:]
91 |
92 | # Prepend a system message if the first message is not a system message
93 | if auto_insert_empty_system_msg:
94 | maybe_insert_system_message(prompt_messages, tokenizer)
95 |
96 | example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False)
97 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
98 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
99 | else:
100 | raise ValueError(
101 | f"Could not format example as dialogue for `{task}` task! Require either the "
102 | f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}"
103 | )
104 | else:
105 | raise ValueError(
106 | f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']"
107 | )
108 | return example
109 |
110 |
111 | def is_openai_format(messages: Any) -> bool:
112 | """
113 | Check if the input messages are in OpenAI format.
114 | Args:
115 | messages (`Any`):
116 | Messages to check.
117 | Returns:
118 | `bool`: Whether the messages are in OpenAI format.
119 | """
120 | if isinstance(messages, list) and all(isinstance(message, dict) for message in messages):
121 | return all("role" in message and "content" in message for message in messages)
122 | return False
123 |
124 |
125 | def get_datasets(
126 | data_config,#: DataArguments | dict,
127 | splits: Optional[List[str]] = None,
128 | configs: Optional[List[str]] = None,
129 | columns_to_keep: Optional[List[str]] = None,
130 | shuffle: bool = True,
131 | ):# -> DatasetDict:
132 | """
133 | Loads one or more datasets with varying training set proportions.
134 |
135 | Args:
136 | data_config (`DataArguments` or `dict`):
137 | Dataset configuration and split proportions.
138 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
139 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
140 | configs (Optional[List[str]], *optional*, defaults to `None`):
141 | List of dataset config names. If given must be the same length as 'data_config' keys.
142 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
143 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
144 | and for cpt this should be (at least) the text column.
145 | shuffle (`bool`, *optional*, defaults to `True`):
146 | Whether to shuffle the training and testing/validation data.
147 |
148 | Returns
149 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets.
150 | """
151 | if type(data_config) is DataArguments:
152 | # Structure of the config to read the datasets and their mix
153 | # datasets_mixer:
154 | # - 'dataset1': 0.5
155 | # - 'dataset2': 0.3
156 | # - 'dataset3': 0.2
157 | dataset_mixer = data_config.dataset_mixer
158 | elif isinstance(data_config, dict):
159 | # Structure of the input is:
160 | # dataset_mixer = {
161 | # "dataset1": 0.5,
162 | #             "dataset2": 0.3,
163 | #             "dataset3": 0.2,
164 | # }
165 | dataset_mixer = data_config
166 | else:
167 | raise ValueError(f"Data config {data_config} not recognized.")
168 |
169 | raw_datasets = mix_datasets(
170 | dataset_mixer,
171 | splits=splits,
172 | configs=configs,
173 | columns_to_keep=columns_to_keep,
174 | shuffle=shuffle,
175 | )
176 | return raw_datasets
177 |
178 |
179 | def mix_datasets(
180 | dataset_mixer: dict,
181 | splits: Optional[List[str]] = None,
182 | configs: Optional[List[str]] = None,
183 | columns_to_keep: Optional[List[str]] = None,
184 | shuffle=True,
185 | ) -> DatasetDict:
186 | """
187 | Loads and mixes datasets according to proportions specified in `dataset_mixer`.
188 |
189 | Args:
190 | dataset_mixer (`dict`):
191 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
192 | splits (Optional[List[str]], *optional*, defaults to `None`):
193 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
194 | configs (Optional[List[str]], *optional*, defaults to `None`):
195 | List of dataset config names. If given must be the same length as 'dataset_mixer' keys.
196 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`):
197 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts,
198 | and for cpt this should be (at least) the text column.
199 | shuffle (`bool`, *optional*, defaults to `True`):
200 | Whether to shuffle the training and testing/validation data.
201 | """
202 | splits = ["train", "test"] if splits is None else splits
203 | configs = [None] * len(dataset_mixer) if not configs else configs
204 | columns_to_keep = [] if columns_to_keep is None else columns_to_keep
205 |
206 | if configs is not None and len(configs) != len(dataset_mixer):
207 | raise ValueError("The number of given dataset config names must be the same as the given number of datasets.")
208 |
209 | raw_datasets = DatasetDict()
210 | raw_train_datasets = []
211 | raw_val_datasets = []
212 | fracs = []
213 | for (ds, frac), ds_config in zip(dataset_mixer.items(), configs):
214 | fracs.append(frac)
215 | for split in splits:
216 | try:
217 | # Try first if dataset on a Hub repo
218 | dataset = load_dataset(ds, ds_config, split=split)
219 | except DatasetGenerationError:
220 | # If not, check local dataset
221 | dataset = load_from_disk(os.path.join(ds, split))
222 |
223 | # Remove redundant columns to avoid schema conflicts on load
224 | dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep])
225 | if "train" in split:
226 | raw_train_datasets.append(dataset)
227 | elif "test" in split:
228 | raw_val_datasets.append(dataset)
229 | else:
230 | raise ValueError(f"Split type {split} not recognized as one of test or train.")
231 |
232 | if any(frac < 0 for frac in fracs):
233 | raise ValueError("Dataset fractions cannot be negative.")
234 |
235 | if len(raw_train_datasets) > 0:
236 | train_subsets = []
237 | for dataset, frac in zip(raw_train_datasets, fracs):
238 | train_subset = dataset.select(range(int(frac * len(dataset))))
239 | train_subsets.append(train_subset)
240 | if shuffle:
241 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
242 | else:
243 | raw_datasets["train"] = concatenate_datasets(train_subsets)
244 | # No subsampling for test datasets to enable fair comparison across models
245 | if len(raw_val_datasets) > 0:
246 | if shuffle:
247 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
248 | else:
249 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
250 |
251 | if len(raw_datasets) == 0:
252 | raise ValueError(
253 | f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted."
254 | )
255 |
256 | return raw_datasets
257 |
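A minimal usage sketch for `get_datasets` / `mix_datasets` (illustrative; the dataset names are placeholders for Hub repos, or local folders containing `train`/`test` subfolders):

# Sketch only: keep 50% of the first training set and all of the second, retain only the
# "messages" column, and leave the test splits un-subsampled.
from src.alignment.data import get_datasets

mixer = {"org/dataset_a": 0.5, "org/dataset_b": 1.0}   # placeholder dataset names
raw = get_datasets(mixer, splits=["train", "test"], columns_to_keep=["messages"])
print(raw)   # DatasetDict with a mixed "train" split and the concatenated "test" splits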
--------------------------------------------------------------------------------
/src/alignment/decontaminate.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from typing import Any, Dict, List
17 |
18 | from datasets import load_dataset
19 |
20 |
21 | # HumanEval solutions that are considered simple/generic enough to be kept in the training dataset
22 | HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", 'return " ".join(strings)']
23 |
24 |
25 | def extract_docstring(prompt: str) -> str:
26 | if '"""' in prompt:
27 | if prompt.count('"""') == 2:
28 | return prompt.split('"""')[1].strip()
29 | elif prompt.count('"""') == 4:
30 | return prompt.split('"""')[3].strip()
31 | else:
32 | raise ValueError()
33 | elif "'''" in prompt:
34 | assert prompt.count("'''") == 2
35 | return prompt.split("'''")[1].strip()
36 | else:
37 | raise ValueError()
38 |
39 |
40 | def human_eval_docstrings() -> List[str]:
41 | # ds = load_dataset("openai_humaneval", split="test")
42 | # docstrings = [extract_docstring(v["prompt"]) for v in ds]
43 | # return docstrings
44 | return []
45 |
46 |
47 | def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]:
48 | # ds = load_dataset(dataset, split=split, name=name)
49 | # res = [sample[column].strip() for sample in ds]
50 | # # Only return non-empty strings
51 | # return [sample for sample in res if len(sample) > 0]
52 | return []
53 |
54 |
55 | FILTER_OUT = {
56 | "human_eval_docstrings": human_eval_docstrings(),
57 | "human_eval_solutions": [
58 | s
59 | for s in load_dataset_column("openai_humaneval", "canonical_solution", "test")
60 | if s not in HUMAN_EVAL_STRINGS_OK
61 | ],
62 | }
63 |
64 |
65 | def normalize_whitespace(text: str) -> str:
66 | return " ".join(text.split())
67 |
68 |
69 | def decontaminate_humaneval(
70 | samples: List[Dict[str, Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT
71 | ) -> List[Dict[str, Any]]:
72 | """
73 | filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be
74 | filtered-out.
75 | Return a list where each element is True if the corresponding file should be included in the dataset.
76 | Otherwise, the element is False.
77 | """
78 | output = []
79 |
80 | for content in samples[text_column]:
81 | content = normalize_whitespace(content.lower())
82 | matched = False
83 | for _, substrings in filter_out.items():
84 | for substring in substrings:
85 | if normalize_whitespace(substring.lower()) in content:
86 | matched = True
87 | break
88 | if matched:
89 | break
90 | # we keep files that are not matched
91 | output.append(not matched)
92 |
93 | return output
94 |
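`decontaminate_humaneval` takes a batched example dict and returns one boolean per row, so it plugs straight into `datasets.Dataset.filter(..., batched=True)`. A sketch (illustrative; note that with the benchmark loaders above stubbed out to return `[]`, nothing is actually filtered at the moment):

# Sketch only: drop rows whose text overlaps with HumanEval docstrings/solutions.
from datasets import Dataset

from src.alignment.decontaminate import decontaminate_humaneval

ds = Dataset.from_dict({"text": ["return x + y", "an unrelated COBOL paragraph"]})
clean = ds.filter(decontaminate_humaneval, fn_kwargs={"text_column": "text"}, batched=True)
print(clean.num_rows)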
--------------------------------------------------------------------------------
/src/alignment/model_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | import os
16 | from pathlib import Path
17 | from typing import Dict
18 |
19 | import torch
20 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer
21 | from transformers.trainer_utils import get_last_checkpoint
22 |
23 | from accelerate import Accelerator
24 | from huggingface_hub import list_repo_files
25 | from huggingface_hub.utils._errors import RepositoryNotFoundError
26 | from huggingface_hub.utils._validators import HFValidationError
27 | from peft import LoraConfig, PeftConfig
28 |
29 | from .configs import DataArguments, DPOConfig, ModelArguments, SFTConfig
30 | from .data import DEFAULT_CHAT_TEMPLATE
31 |
32 |
33 | def get_current_device():# -> int:
34 | """Get the current device. For GPU we return the local process index to enable multiple GPU training."""
35 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu"
36 |
37 |
38 | def get_kbit_device_map():# -> Dict[str, int] | None:
39 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`"""
40 | return {"": get_current_device()} if torch.cuda.is_available() else None
41 |
42 |
43 | def get_quantization_config(model_args: ModelArguments):# -> BitsAndBytesConfig | None:
44 | if model_args.load_in_4bit:
45 | compute_dtype = torch.float16
46 | if model_args.torch_dtype not in {"auto", None}:
47 | compute_dtype = getattr(torch, model_args.torch_dtype)
48 |
49 | quantization_config = BitsAndBytesConfig(
50 | load_in_4bit=True,
51 | bnb_4bit_compute_dtype=compute_dtype,
52 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
53 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
54 | bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
55 | )
56 | elif model_args.load_in_8bit:
57 | quantization_config = BitsAndBytesConfig(
58 | load_in_8bit=True,
59 | )
60 | else:
61 | quantization_config = None
62 |
63 | return quantization_config
64 |
65 |
66 | def get_tokenizer(
67 | model_args: ModelArguments, data_args: DataArguments, auto_set_chat_template: bool = True
68 | ):# -> PreTrainedTokenizer:
69 | """Get the tokenizer for the model."""
70 | tokenizer = AutoTokenizer.from_pretrained(
71 | model_args.model_name_or_path
72 | if model_args.tokenizer_name_or_path is None
73 | else model_args.tokenizer_name_or_path,
74 | revision=model_args.model_revision,
75 | trust_remote_code=model_args.trust_remote_code,
76 | )
77 | if tokenizer.pad_token_id is None:
78 | tokenizer.pad_token_id = tokenizer.eos_token_id
79 |
80 | if data_args.truncation_side is not None:
81 | tokenizer.truncation_side = data_args.truncation_side
82 |
83 | # Set reasonable default for models without max length
84 | if tokenizer.model_max_length > 100_000:
85 | tokenizer.model_max_length = 2048
86 |
87 | if data_args.chat_template is not None:
88 | tokenizer.chat_template = data_args.chat_template
89 | elif auto_set_chat_template and tokenizer.chat_template is None and tokenizer.default_chat_template is None:
90 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
91 |
92 | return tokenizer
93 |
94 |
95 | def get_peft_config(model_args: ModelArguments) :# -> PeftConfig | None:
96 | if model_args.use_peft is False:
97 | return None
98 |
99 | peft_config = LoraConfig(
100 | r=model_args.lora_r,
101 | lora_alpha=model_args.lora_alpha,
102 | lora_dropout=model_args.lora_dropout,
103 | bias="none",
104 | task_type="CAUSAL_LM",
105 | target_modules=model_args.lora_target_modules,
106 | modules_to_save=model_args.lora_modules_to_save,
107 | )
108 |
109 | return peft_config
110 |
111 |
112 | def is_adapter_model(model_name_or_path: str, revision: str = "main") :# -> bool:
113 | try:
114 | # Try first if model on a Hub repo
115 | repo_files = list_repo_files(model_name_or_path, revision=revision)
116 | except (HFValidationError, RepositoryNotFoundError):
117 | # If not, check local repo
118 | repo_files = os.listdir(model_name_or_path)
119 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files
120 |
121 |
122 | def get_checkpoint(training_args) :# -> Path | None:
123 | last_checkpoint = None
124 | if os.path.isdir(training_args.output_dir):
125 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
126 | return last_checkpoint
127 |
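These helpers feed straight into `AutoModelForCausalLM.from_pretrained`, the same way the training scripts below use them. A minimal sketch; the model id is a placeholder and 4-bit loading needs a CUDA device plus bitsandbytes:

# Sketch only: quantized loading with the helpers above. The model id is a placeholder,
# not necessarily what the recipes use.
from transformers import AutoModelForCausalLM
from src.alignment import ModelArguments, get_kbit_device_map, get_peft_config, get_quantization_config

model_args = ModelArguments(model_name_or_path="deepseek-ai/deepseek-coder-1.3b-base", load_in_4bit=True)
quantization_config = get_quantization_config(model_args)   # BitsAndBytesConfig or None
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
peft_config = get_peft_config(model_args)   # LoraConfig when use_peft=True, otherwise None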
--------------------------------------------------------------------------------
/src/alignment/release.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import argparse
17 | import re
18 |
19 | import packaging.version
20 |
21 |
22 | REPLACE_PATTERNS = {
23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'),
24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'),
25 | }
26 | REPLACE_FILES = {
27 | "init": "src/alignment/__init__.py",
28 | "setup": "setup.py",
29 | }
30 | README_FILE = "README.md"
31 |
32 |
33 | def update_version_in_file(fname, version, pattern):
34 | """Update the version in one file using a specific pattern."""
35 | with open(fname, "r", encoding="utf-8", newline="\n") as f:
36 | code = f.read()
37 | re_pattern, replace = REPLACE_PATTERNS[pattern]
38 | replace = replace.replace("VERSION", version)
39 | code = re_pattern.sub(replace, code)
40 | with open(fname, "w", encoding="utf-8", newline="\n") as f:
41 | f.write(code)
42 |
43 |
44 | def global_version_update(version, patch=False):
45 | """Update the version in all needed files."""
46 | for pattern, fname in REPLACE_FILES.items():
47 | update_version_in_file(fname, version, pattern)
48 |
49 |
50 | def get_version():
51 | """Reads the current version in the __init__."""
52 | with open(REPLACE_FILES["init"], "r") as f:
53 | code = f.read()
54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0]
55 | return packaging.version.parse(default_version)
56 |
57 |
58 | def pre_release_work(patch=False):
59 | """Do all the necessary pre-release steps."""
60 | # First let's get the default version: base version if we are in dev, bump minor otherwise.
61 | default_version = get_version()
62 | if patch and default_version.is_devrelease:
63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!")
64 | if default_version.is_devrelease:
65 | default_version = default_version.base_version
66 | elif patch:
67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}"
68 | else:
69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0"
70 |
71 | # Now let's ask nicely if that's the right one.
72 | version = input(f"Which version are you releasing? [{default_version}]")
73 | if len(version) == 0:
74 | version = default_version
75 |
76 | print(f"Updating version to {version}.")
77 | global_version_update(version, patch=patch)
78 |
79 |
80 | def post_release_work():
81 | """Do all the necessary post-release steps."""
82 | # First let's get the current version
83 | current_version = get_version()
84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0"
85 | current_version = current_version.base_version
86 |
87 | # Check with the user we got that right.
88 | version = input(f"Which version are we developing now? [{dev_version}]")
89 | if len(version) == 0:
90 | version = dev_version
91 |
92 | print(f"Updating version to {version}.")
93 | global_version_update(version)
94 |
95 |
96 | if __name__ == "__main__":
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.")
99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.")
100 | args = parser.parse_args()
101 | if not args.post_release:
102 | pre_release_work(patch=args.patch)
103 | elif args.patch:
104 | print("Nothing to do after a patch :-)")
105 | else:
106 | post_release_work()
107 |
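The bump rules in `pre_release_work` / `post_release_work` are easiest to see on a concrete version string; a small sketch with the same `packaging.version` API:

# Sketch only: how a version string moves through a release cycle under the rules above.
import packaging.version

v = packaging.version.parse("0.3.0.dev0")
print(v.is_devrelease, v.base_version)        # True 0.3.0   -> released as-is
v = packaging.version.parse("0.3.0")
print(f"{v.major}.{v.minor}.{v.micro + 1}")   # 0.3.1        (--patch)
print(f"{v.major}.{v.minor + 1}.0")           # 0.4.0        (default minor bump)
print(f"{v.major}.{v.minor + 1}.0.dev0")      # 0.4.0.dev0   (post-release dev version)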
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
1 | from src.data.data import *
--------------------------------------------------------------------------------
/src/data/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/data.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/data.cpython-310.pyc
--------------------------------------------------------------------------------
/src/data/__pycache__/utils.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/utils.cpython-310.pyc
--------------------------------------------------------------------------------
/src/data/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | from typing import List, Literal, Optional
4 |
5 | import pandas as pd
6 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk, Dataset
7 | from datasets.builder import DatasetGenerationError
8 | from src.alignment import DataArguments
9 |
10 | from src.data.utils import markdown_to_text
11 |
12 |
13 | def apply_chat_template(
14 | example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n"
15 | ):
16 | def _strip_prefix(s, pattern):
17 | # Use re.escape to escape any special characters in the pattern
18 | return re.sub(f"^{re.escape(pattern)}", "", s)
19 |
20 | # if task in ["sft", "generation"]:
21 | if task == "sft":
22 | messages = example["conversations"]
23 | # We add an empty system message if there is none
24 | if messages[0]["from"] != "system":
25 | messages.insert(0, {"from": "system", "value": ""})
26 | example["text"] = tokenizer.apply_chat_template(
27 | messages, tokenize=False, add_generation_prompt=True if task == "generation" else False
28 | )
29 | elif task == "generation":
30 | pass
31 | # if "markdown" in example:
32 | # example["text"] = markdown_to_text(example["markdown"])
33 | # else:
34 | # pass
35 | elif task == "rm":
36 | if all(k in example.keys() for k in ("chosen", "rejected")):
37 | chosen_messages = example["chosen"]
38 | rejected_messages = example["rejected"]
39 | # We add an empty system message if there is none
40 | if chosen_messages[0]["role"] != "system":
41 | chosen_messages.insert(0, {"role": "system", "content": ""})
42 | if rejected_messages[0]["role"] != "system":
43 | rejected_messages.insert(0, {"role": "system", "content": ""})
44 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
45 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
46 | else:
47 | raise ValueError(
48 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
49 | )
50 | elif task == "dpo":
51 | if all(k in example.keys() for k in ("chosen", "rejected")):
52 | # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token
53 | prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]]
54 | # Insert system message
55 | if example["chosen"][0]["role"] != "system":
56 | prompt_messages.insert(0, {"role": "system", "content": ""})
57 | else:
58 | prompt_messages.insert(0, example["chosen"][0])
59 | # TODO: handle case where chosen/rejected also have system messages
60 | chosen_messages = example["chosen"][1:]
61 | rejected_messages = example["rejected"][1:]
62 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False)
63 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False)
64 | example["text_prompt"] = tokenizer.apply_chat_template(
65 | prompt_messages, tokenize=False, add_generation_prompt=True
66 | )
67 | example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix)
68 | example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix)
69 | else:
70 | raise ValueError(
71 | f"Could not format example as dialogue for `dpo` task! Require `[chosen, rejected]` keys but found {list(example.keys())}"
72 | )
73 | else:
74 | raise ValueError(
75 | f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}"
76 | )
77 | return example
78 |
79 |
80 | def get_datasets(
81 | data_config: DataArguments | dict,
82 | splits: List[str] = ["train", "test"],
83 | shuffle: bool = True,
84 | ) -> DatasetDict:
85 | """
86 | Loads one or more datasets with varying training set proportions.
87 |
88 | Args:
89 | data_config (`DataArguments` or `dict`):
90 | Dataset configuration and split proportions.
91 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`):
92 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
93 | shuffle (`bool`, *optional*, defaults to `True`):
94 | Whether to shuffle the training and testing/validation data.
95 |
96 | Returns
97 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets.
98 | """
99 |
100 | if type(data_config) is DataArguments:
101 | # Structure of the config to read the datasets and their mix
102 | # datasets_mixer:
103 | # - 'dataset1': 0.5
104 | # - 'dataset2': 0.3
105 | # - 'dataset3': 0.2
106 | dataset_mixer = data_config.dataset_mixer
107 | elif type(data_config) is dict:
108 | # Structure of the input is:
109 | # dataset_mixer = {
110 | # "dataset1": 0.5,
111 | #             "dataset2": 0.3,
112 | #             "dataset3": 0.2,
113 | # }
114 | dataset_mixer = data_config
115 | else:
116 | raise ValueError(f"Data config {data_config} not recognized.")
117 |
118 | raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle)
119 | return raw_datasets
120 |
121 |
122 | def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict:
123 | """
124 | Loads and mixes datasets according to proportions specified in `dataset_mixer`.
125 |
126 | Args:
127 | dataset_mixer (`dict`):
128 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1.
129 | splits (Optional[List[str]], *optional*, defaults to `None`):
130 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix.
131 | shuffle (`bool`, *optional*, defaults to `True`):
132 | Whether to shuffle the training and testing/validation data.
133 | """
134 | raw_datasets = DatasetDict()
135 | raw_train_datasets = []
136 | raw_val_datasets = []
137 | fracs = []
138 | for ds, frac in dataset_mixer.items():
139 | fracs.append(frac)
140 | for split in splits:
141 | try:
142 | # Try first if dataset on a Hub repo
143 | dataset = load_dataset(ds, split=split)
144 | # except DatasetGenerationError:
145 | # # If not, check local dataset
146 | # # dataset = load_from_disk(os.path.join(ds, split))
147 | except Exception as e:
148 | dataset = Dataset.from_pandas(pd.read_feather(ds))
149 |
150 | if "train" in split:
151 | raw_train_datasets.append(dataset)
152 | elif "test" in split:
153 | raw_val_datasets.append(dataset)
154 | else:
155 | raise ValueError(f"Split type {split} not recognized as one of test or train.")
156 |
157 | if any(frac < 0 for frac in fracs):
158 | raise ValueError("Dataset fractions cannot be negative.")
159 |
160 | if len(raw_train_datasets) > 0:
161 | train_subsets = []
162 | for dataset, frac in zip(raw_train_datasets, fracs):
163 | train_subset = dataset.select(range(int(frac * len(dataset))))
164 | train_subsets.append(train_subset)
165 | if shuffle:
166 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42)
167 | else:
168 | raw_datasets["train"] = concatenate_datasets(train_subsets)
169 | # No subsampling for test datasets to enable fair comparison across models
170 | if len(raw_val_datasets) > 0:
171 | if shuffle:
172 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42)
173 | else:
174 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets)
175 |
176 | if len(raw_datasets) == 0:
177 | raise ValueError(
178 | f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted."
179 | )
180 |
181 | return raw_datasets
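Note the schema difference from `src/alignment/data.py`: for the `sft` task this variant expects ShareGPT-style records (a `conversations` list with `from`/`value` keys, as in SlimOrca), not `messages` with `role`/`content`. A sketch of the expected input (illustrative; the tokenizer call is left commented out because it needs a chat template written for this key layout):

# Sketch only: the record shape this apply_chat_template expects for task="sft".
example = {
    "conversations": [
        {"from": "human", "value": "What does a JCL EXEC statement do?"},
        {"from": "gpt", "value": "It invokes a program or cataloged procedure within a job step."},
    ]
}
# example = apply_chat_template(example, tokenizer, task="sft")
# print(example["text"])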
--------------------------------------------------------------------------------
/src/data/utils.py:
--------------------------------------------------------------------------------
1 | from bs4 import BeautifulSoup
2 | from markdown import markdown
3 | import re
4 |
5 | def markdown_to_text(markdown_string):
6 | """ Converts a markdown string to plaintext """
7 |
8 | # md -> html -> text since BeautifulSoup can extract text cleanly
9 | html = markdown(markdown_string)
10 |
11 | # remove code snippets
12 | html = re.sub(r'<pre>(.*?)</pre>', ' ', html)
13 | html = re.sub(r'<code>(.*?)</code>', ' ', html)
14 |
15 | # extract text
16 | soup = BeautifulSoup(html, "html.parser")
17 | text = ''.join(soup.findAll(text=True))
18 |
19 | return text
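A usage sketch for `markdown_to_text` (illustrative):

# Sketch only: markdown -> HTML -> plain text; single-line <pre>/<code> spans are removed
# by the regexes above.
from src.data.utils import markdown_to_text

sample_md = "# CICS Basics\n\nSome *emphasised* text and `inline code`."
print(markdown_to_text(sample_md))
# -> roughly "CICS Basics\nSome emphasised text and ." (the inline-code span is dropped)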
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
1 | from src.model.tokenizer import *
--------------------------------------------------------------------------------
/src/model/__pycache__/__init__.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/model/__pycache__/__init__.cpython-310.pyc
--------------------------------------------------------------------------------
/src/model/__pycache__/tokenizer.cpython-310.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/model/__pycache__/tokenizer.cpython-310.pyc
--------------------------------------------------------------------------------
/src/model/tokenizer.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, PreTrainedTokenizer
2 | from src.alignment.configs import DataArguments, ModelArguments
3 |
4 |
5 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
6 |
7 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
8 | """Get the tokenizer for the model."""
9 | tokenizer = AutoTokenizer.from_pretrained(
10 | model_args.model_name_or_path,
11 | revision=model_args.model_revision,
12 | trust_remote_code=True
13 | )
14 | if tokenizer.pad_token_id is None:
15 | tokenizer.pad_token_id = tokenizer.eos_token_id
16 |
17 | if data_args.truncation_side is not None:
18 | tokenizer.truncation_side = data_args.truncation_side
19 |
20 | # Set reasonable default for models without max length
21 | # if tokenizer.model_max_length > 100_000:
22 | # tokenizer.model_max_length = 4096
23 | tokenizer.model_max_length = 4096
24 |
25 | if data_args.chat_template is not None:
26 | tokenizer.chat_template = data_args.chat_template
27 | elif tokenizer.chat_template is None:
28 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
29 |
30 | return tokenizer
31 |
32 |
33 | def get_tokenizer_phi2(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
34 | """Get the tokenizer for the model."""
35 | tokenizer = AutoTokenizer.from_pretrained(
36 | model_args.model_name_or_path,
37 | revision=model_args.model_revision,
38 | use_fast=False
39 | )
40 | tokenizer.add_tokens(["<|im_start|>", "<PAD>"])
41 | tokenizer.pad_token = "<PAD>"
42 | tokenizer.add_special_tokens(dict(eos_token="<|im_end|>"))
43 |
44 | tokenizer.model_max_length = 2048
45 |
46 |
47 | return tokenizer
48 |
49 | def get_tokenizer_qwen15(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
50 | """Get the tokenizer for the model."""
51 | tokenizer = AutoTokenizer.from_pretrained(
52 | model_args.model_name_or_path,
53 | revision=model_args.model_revision,
54 | model_max_length=8192,
55 | padding_side="right",
56 | use_fast=False,
57 | trust_remote_code=True
58 | )
59 | # tokenizer.pad_token_id = tokenizer.eod_id
60 |
61 | # tokenizer.model_max_length = 8192
62 |
63 | return tokenizer
64 |
65 |
66 | def get_tokenizer_code_llama(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer:
67 | """Get the tokenizer for the model."""
68 | tokenizer = AutoTokenizer.from_pretrained(
69 | model_args.model_name_or_path,
70 | revision=model_args.model_revision
71 | )
72 | tokenizer.add_eos_token = True
73 | tokenizer.pad_token_id = 0
74 | tokenizer.padding_side = "left"
75 |
76 | return tokenizer
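One detail worth calling out for the ChatML-style setup: `get_tokenizer_phi2` adds new tokens, so the base model's embedding matrix has to be resized to match; `train_instruct.py` handles this with `trl.setup_chat_format` whenever `<|im_start|>` appears in the chat template. A sketch (illustrative; `model_args`/`data_args` are assumed to come from `H4ArgumentParser`):

# Sketch only: pair the ChatML tokenizer with a resized model, mirroring train_instruct.py.
from transformers import AutoModelForCausalLM
from trl import setup_chat_format
from src.model.tokenizer import get_tokenizer_phi2

tokenizer = get_tokenizer_phi2(model_args, data_args)
model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path)
model, tokenizer = setup_chat_format(model, tokenizer)   # adds ChatML special tokens and resizes embeddings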
--------------------------------------------------------------------------------
/train_instruct.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Supervised fine-tuning script for decoder language models.
18 | """
19 |
20 | import logging
21 | import random
22 | import sys
23 | import datetime
24 |
25 | import pandas as pd
26 | import datasets
27 | import torch
28 | import transformers
29 | from transformers import set_seed, DataCollatorForLanguageModeling, AutoModelForCausalLM
30 | from datasets import load_dataset, Dataset, concatenate_datasets
31 |
32 | from accelerate import Accelerator
33 | from src.alignment import (
34 | DataArguments,
35 | H4ArgumentParser,
36 | ModelArguments,
37 | SFTConfig,
38 | # apply_chat_template,
39 | # get_datasets,
40 | get_checkpoint,
41 | get_kbit_device_map,
42 | get_peft_config,
43 | get_quantization_config,
44 | # get_tokenizer,
45 | )
46 | # from trl import SFTConfig
47 | from trl import SFTTrainer, setup_chat_format
48 | from src.data import get_datasets, apply_chat_template
49 | from src.model import get_tokenizer
50 |
51 |
52 | logger = logging.getLogger(__name__)
53 |
54 |
55 | def main():
56 | parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
57 | model_args, data_args, training_args = parser.parse()
58 |
59 | # Set seed for reproducibility
60 | set_seed(training_args.seed)
61 |
62 | # accelerator = Accelerator()
63 |
64 | ###############
65 | # Setup logging
66 | ###############
67 | logging.basicConfig(
68 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
69 | datefmt="%Y-%m-%d %H:%M:%S",
70 | handlers=[logging.StreamHandler(sys.stdout)],
71 | )
72 | log_level = training_args.get_process_log_level()
73 | logger.setLevel(log_level)
74 | datasets.utils.logging.set_verbosity(log_level)
75 | transformers.utils.logging.set_verbosity(log_level)
76 | transformers.utils.logging.enable_default_handler()
77 | transformers.utils.logging.enable_explicit_format()
78 |
79 | # Log on each process a small summary
80 | logger.warning(
81 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
82 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
83 | )
84 | logger.info(f"Model parameters {model_args}")
85 | logger.info(f"Data parameters {data_args}")
86 | logger.info(f"Training/evaluation parameters {training_args}")
87 |
88 | # Check for last checkpoint
89 | last_checkpoint = get_checkpoint(training_args)
90 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
91 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
92 |
93 | ###############
94 | # Load datasets
95 | ###############
96 |
97 | mainframe_instruct_ds = Dataset.from_pandas(
98 | pd.read_feather(
99 | ".data/mainframegpt_instructions_vananh_20240209.feather"
100 | )
101 | )
102 | mainframe_oss_instruct_ds = Dataset.from_pandas(
103 | pd.read_feather(
104 | "./data/mainframegpt_oss_instruct.feather"
105 | )
106 | )
107 | mainframe_self_instruct_ds = Dataset.from_pandas(
108 | pd.read_feather(
109 | "./data/mainframegpt_self_instruct_2151.feather"
110 | )
111 | )
112 | slim_orca_ds = load_dataset("Open-Orca/SlimOrca-Dedup", cache_dir="./hf_cache/datasets")
113 |
114 | raw_datasets = concatenate_datasets([mainframe_instruct_ds, slim_orca_ds["train"]]).shuffle(seed=42)
115 | raw_datasets = raw_datasets.train_test_split(test_size=0.1)
116 | logger.info(
117 | f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
118 | )
119 | column_names = list(raw_datasets["train"].features)
120 |
121 | ################
122 | # Load tokenizer
123 | ################
124 | tokenizer = get_tokenizer(model_args, data_args)
125 |
126 | #####################
127 | # Apply chat template
128 | #####################
129 | raw_datasets = raw_datasets.map(
130 | apply_chat_template,
131 | fn_kwargs={"tokenizer": tokenizer, "task": "sft"},
132 | num_proc=data_args.preprocessing_num_workers,
133 | remove_columns=column_names,
134 | desc="Applying chat template",
135 | )
136 | train_dataset = raw_datasets["train"]
137 | eval_dataset = raw_datasets["test"]
138 |
139 | data_collator = DataCollatorForLanguageModeling(
140 | tokenizer=tokenizer,
141 | mlm=False,
142 | return_tensors='pt'
143 | )
144 |
145 | #######################
146 | # Load pretrained model
147 | #######################
148 | logger.info("*** Load pretrained model ***")
149 | torch_dtype =(
150 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
151 | )
152 | quantization_config = get_quantization_config(model_args)
153 |
154 | model_kwargs = dict(
155 | revision=model_args.model_revision,
156 | trust_remote_code=model_args.trust_remote_code,
157 | use_flash_attention_2=model_args.use_flash_attention_2,
158 | torch_dtype=torch_dtype,
159 | use_cache=False if training_args.gradient_checkpointing else True,
160 | device_map=get_kbit_device_map() if quantization_config is not None else None,
161 | quantization_config=quantization_config,
162 | )
163 | print(model_kwargs)
164 | print(type(torch_dtype))
165 |
166 | model = model_args.model_name_or_path
167 | # For ChatML we need to add special tokens and resize the embedding layer
168 | if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path:
169 | model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
170 | model, tokenizer = setup_chat_format(model, tokenizer)
171 | model_kwargs = None
172 |
173 | logger.info("*** Model loaded! ***")
174 |
175 | ########################
176 | # Initialize the Trainer
177 | ########################
178 | trainer = SFTTrainer(
179 | model=model,
180 | model_init_kwargs=model_kwargs,
181 | args=training_args,
182 | train_dataset=train_dataset,
183 | eval_dataset=eval_dataset,
184 | dataset_text_field="text",
185 | max_seq_length=training_args.max_seq_length,
186 | tokenizer=tokenizer,
187 | packing=True,
188 | peft_config=get_peft_config(model_args),
189 | neftune_noise_alpha=5,
190 | data_collator=data_collator,
191 | )
192 |
193 | ###############
194 | # Training loop
195 | ###############
196 | logger.info("*** Train ***")
197 | checkpoint = None
198 | if training_args.resume_from_checkpoint is not None:
199 | checkpoint = training_args.resume_from_checkpoint
200 | elif last_checkpoint is not None:
201 | checkpoint = last_checkpoint
202 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
203 | metrics = train_result.metrics
204 | metrics["train_samples"] = len(train_dataset)
205 | trainer.log_metrics("train", metrics)
206 | trainer.save_metrics("train", metrics)
207 | trainer.save_state()
208 |
209 | ##########
210 | # Evaluate
211 | ##########
212 | if training_args.do_eval:
213 | logger.info("*** Evaluate ***")
214 | metrics = trainer.evaluate()
215 | metrics["eval_samples"] = len(eval_dataset)
216 | trainer.log_metrics("eval", metrics)
217 | trainer.save_metrics("eval", metrics)
218 |
219 | ##################################
220 | # Save model and create model card
221 | ##################################
222 | logger.info("*** Save model ***")
223 | trainer.save_model(training_args.output_dir)
224 | logger.info(f"Model saved to {training_args.output_dir}")
225 |
226 | kwargs = {
227 | "finetuned_from": model_args.model_name_or_path,
228 | "dataset": list(data_args.dataset_mixer.keys()),
229 | "dataset_tags": list(data_args.dataset_mixer.keys()),
230 | "tags": ["alignment-handbook"],
231 | }
232 | if trainer.accelerator.is_main_process:
233 | trainer.create_model_card(**kwargs)
234 | # Restore k,v cache for fast inference
235 | trainer.model.config.use_cache = True
236 | trainer.model.config.save_pretrained(training_args.output_dir)
237 |
238 | if training_args.push_to_hub is True:
239 | logger.info("Pushing to hub...")
240 | trainer.push_to_hub(**kwargs)
241 |
242 | logger.info("*** Training complete ***")
243 |
244 | if __name__ == "__main__":
245 | main()
246 |
--------------------------------------------------------------------------------
/train_raw.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Supervised fine-tuning script for decoder language models.
18 | """
19 |
20 | import logging
21 | import random
22 | import sys
23 | import datetime
24 |
25 | import pandas as pd
26 | import datasets
27 | import torch
28 | import transformers
29 | from transformers import set_seed, DataCollatorForLanguageModeling, AutoModelForCausalLM
30 | from datasets import load_dataset, Dataset, concatenate_datasets
31 |
32 | # import sys
33 | # sys.path.append('./alignment-handbook/src')
34 | from accelerate import Accelerator
35 | from src.alignment import (
36 | DataArguments,
37 | H4ArgumentParser,
38 | ModelArguments,
39 | SFTConfig,
40 | # apply_chat_template,
41 | # get_datasets,
42 | get_checkpoint,
43 | get_kbit_device_map,
44 | get_peft_config,
45 | get_quantization_config,
46 | # get_tokenizer,
47 | )
48 | from trl import SFTTrainer, setup_chat_format
49 | from src.data import get_datasets, apply_chat_template
50 | from src.model.tokenizer import get_tokenizer, get_tokenizer_phi2, get_tokenizer_qwen15, get_tokenizer_code_llama
51 |
52 | # import deepspeed
53 | # from deepspeed.ops.op_builder import builder
54 |
55 | logger = logging.getLogger(__name__)
56 |
57 |
58 | def main():
59 | parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
60 | model_args, data_args, training_args = parser.parse()
61 |
62 | # Set seed for reproducibility
63 | set_seed(training_args.seed)
64 |
65 | # accelerator = Accelerator()
66 |
67 | ###############
68 | # Setup logging
69 | ###############
70 | logging.basicConfig(
71 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
72 | datefmt="%Y-%m-%d %H:%M:%S",
73 | handlers=[logging.StreamHandler(sys.stdout)],
74 | )
75 | log_level = training_args.get_process_log_level()
76 | logger.setLevel(log_level)
77 | datasets.utils.logging.set_verbosity(log_level)
78 | transformers.utils.logging.set_verbosity(log_level)
79 | transformers.utils.logging.enable_default_handler()
80 | transformers.utils.logging.enable_explicit_format()
81 |
82 | # Log on each process a small summary
83 | logger.warning(
84 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
85 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
86 | )
87 | logger.info(f"Model parameters {model_args}")
88 | logger.info(f"Data parameters {data_args}")
89 | logger.info(f"Training/evaluation parameters {training_args}")
90 |
91 | # Check for last checkpoint
92 | last_checkpoint = get_checkpoint(training_args)
93 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
94 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
95 |
96 | ###############
97 | # Load datasets
98 | ###############
99 | # raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
100 | mainframe_ds = Dataset.from_pandas(pd.read_feather("./data/mainframe_df_v1.1_chunks_4096.feather"))
101 | textbook_ds = Dataset.from_pandas(pd.read_feather("./data/textbook_quality_programming.feather"))
102 | longcontext = Dataset.from_pandas(pd.read_feather("./data/long_context_data/long_data.feather"))
103 | raw_datasets = longcontext.train_test_split(test_size=0.1)
104 | logger.info(
105 | f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
106 | )
107 | column_names = list(raw_datasets["train"].features)
108 |
109 |     ################
110 |     # Load tokenizer
111 |     ################
112 |     tokenizer = get_tokenizer_code_llama(model_args, data_args)
113 |
114 |
115 |     train_dataset = raw_datasets["train"]
116 |     eval_dataset = raw_datasets["test"]
117 |
118 |     with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
119 |         for index in random.sample(range(len(raw_datasets["train"])), 3):
120 |             logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}")
121 |
122 |     data_collator = DataCollatorForLanguageModeling(
123 |         tokenizer=tokenizer,
124 |         mlm=False,
125 |         return_tensors='pt'
126 |     )
127 |
128 |
129 |     #######################
130 |     # Load pretrained model
131 |     #######################
132 |     logger.info("*** Load pretrained model ***")
133 |     torch_dtype = (
134 |         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
135 |     )
136 |     # torch_dtype = 'float16'
137 |     quantization_config = get_quantization_config(model_args)
138 |
139 |     model_kwargs = dict(
140 |         revision=model_args.model_revision,
141 |         trust_remote_code=model_args.trust_remote_code,
142 |         use_flash_attention_2=model_args.use_flash_attention_2,
143 |         torch_dtype=torch_dtype,
144 |         use_cache=False if training_args.gradient_checkpointing else True,
145 |         # device_map=get_kbit_device_map() if quantization_config is not None else None,
146 |         quantization_config=quantization_config,
147 |     )
148 |
149 |     logger.info("*** Model loaded! ***")
150 |
151 |     ########################
152 |     # Initialize the Trainer
153 |     ########################
154 |     trainer = SFTTrainer(
155 |         model=model_args.model_name_or_path,
156 |         model_init_kwargs=model_kwargs,
157 |         args=training_args,
158 |         train_dataset=train_dataset,
159 |         eval_dataset=eval_dataset,
160 |         dataset_text_field="text",
161 |         max_seq_length=training_args.max_seq_length,
162 |         tokenizer=tokenizer,
163 |         packing=True,
164 |         peft_config=get_peft_config(model_args),
165 |         neftune_noise_alpha=5,
166 |         data_collator=data_collator,
167 |     )
168 |
169 |     ###############
170 |     # Training loop
171 |     ###############
172 |     logger.info("*** Train ***")
173 |     # train_result = trainer.train()
174 |     checkpoint = None
175 |     if training_args.resume_from_checkpoint is not None:
176 |         checkpoint = training_args.resume_from_checkpoint
177 |     elif last_checkpoint is not None:
178 |         checkpoint = last_checkpoint
179 |
180 |     train_result = trainer.train(resume_from_checkpoint=checkpoint)
181 |     metrics = train_result.metrics
182 |     metrics["train_samples"] = len(train_dataset)
183 |     trainer.log_metrics("train", metrics)
184 |     trainer.save_metrics("train", metrics)
185 |     trainer.save_state()
186 |
187 |     ##########
188 |     # Evaluate
189 |     ##########
190 |     if training_args.do_eval:
191 |         logger.info("*** Evaluate ***")
192 |         metrics = trainer.evaluate()
193 |         metrics["eval_samples"] = len(eval_dataset)
194 |         trainer.log_metrics("eval", metrics)
195 |         trainer.save_metrics("eval", metrics)
196 |
197 |     ##################################
198 |     # Save model and create model card
199 |     ##################################
200 |     logger.info("*** Save model ***")
201 |     trainer.save_model(training_args.output_dir)
202 |     logger.info(f"Model saved to {training_args.output_dir}")
203 |
204 |     kwargs = {
205 |         "finetuned_from": model_args.model_name_or_path,
206 |         "dataset": list(data_args.dataset_mixer.keys()),
207 |         "dataset_tags": list(data_args.dataset_mixer.keys()),
208 |         "tags": ["alignment-handbook"],
209 |     }
210 |     # if trainer.accelerator.is_main_process:
211 |     trainer.create_model_card(**kwargs)
212 |     # Restore k,v cache for fast inference
213 |     trainer.model.config.use_cache = True
214 |     trainer.model.config.save_pretrained(training_args.output_dir)
215 |
216 |     if training_args.push_to_hub:
217 |         logger.info("Pushing to hub...")
218 |         trainer.push_to_hub(**kwargs)
219 |
220 |     logger.info("*** Training complete ***")
221 |
222 | if __name__ == "__main__":
223 |     main()
224 |
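225 | # ---------------------------------------------------------------------------
226 | # Illustrative launch sketch (not part of the original script). H4ArgumentParser
227 | # accepts a YAML recipe as the first positional argument, so a multi-GPU run of
228 | # this script looks roughly like the command below; the script name, accelerate
229 | # config, and recipe paths are assumptions and should be replaced with the
230 | # files actually used in this repository.
231 | #
232 | #   ACCELERATE_LOG_LEVEL=info accelerate launch \
233 | #       --config_file recipes/accelerate_configs/deepspeed_zero3.yaml \
234 | #       train_raw.py recipes/deepseek/full.yaml
235 | # ---------------------------------------------------------------------------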
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import torch
4 | import multiprocessing
5 | import torch.distributed as dist
6 |
7 |
8 | def setup_for_distributed(is_master):
9 | """
10 | This function disables printing when not in master process
11 | """
12 | import builtins as __builtin__
13 | builtin_print = __builtin__.print
14 |
15 | def print(*args, **kwargs):
16 | force = kwargs.pop('force', False)
17 | if is_master or force:
18 | builtin_print(*args, **kwargs)
19 |
20 | __builtin__.print = print
21 |
22 |
23 | def get_world_size():
24 | if not is_dist_avail_and_initialized():
25 | return 1
26 | return dist.get_world_size()
27 |
28 | def is_dist_avail_and_initialized():
29 | if not dist.is_available():
30 | return False
31 | if not dist.is_initialized():
32 | return False
33 | return True
34 |
35 | def get_rank():
36 | if not is_dist_avail_and_initialized():
37 | return 0
38 | return dist.get_rank()
39 |
40 | def init_distributed_mode(args):  # expects args.dist_url and args.world_size to already be set by the caller
41 | cpu_cont = multiprocessing.cpu_count()
42 | # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
43 | # args.rank = int(os.environ["RANK"])
44 | # args.world_size = int(os.environ['WORLD_SIZE'])
45 | # args.gpu = int(os.environ['LOCAL_RANK'])
46 | # elif 'SLURM_PROCID' in os.environ:
47 | # args.rank = int(os.environ['SLURM_PROCID'])
48 | # args.gpu = args.rank % torch.cuda.device_count()
49 | # else:
50 | # print('Not using distributed mode')
51 | # args.distributed = False
52 | # return
53 |
54 | args.distributed = True
55 | args.rank = get_rank()
56 | # args.world_size = get_world_size()
57 | args.gpu = args.rank % torch.cuda.device_count()
58 |
59 | torch.cuda.set_device(args.gpu)
60 | args.dist_backend = 'nccl'
61 |     print('| distributed init (rank {}, world {}): {}'.format(
62 |         args.rank, args.world_size, args.dist_url), flush=True)
63 |     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
64 |                                          world_size=args.world_size, rank=args.rank)
65 | torch.distributed.barrier()
66 | device = torch.device("cuda", args.gpu)
67 | args.n_gpu = torch.cuda.device_count()
68 | args.device = device
69 | args.cpu_cont = cpu_cont
70 | setup_for_distributed(args.rank == 0)
71 |
72 |
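73 | # ---------------------------------------------------------------------------
74 | # Illustrative usage sketch (not part of the original file). init_distributed_mode()
75 | # reads args.dist_url and args.world_size but never sets them itself, so the caller
76 | # must provide both; the values below are assumptions for a single-node run.
77 | #
78 | #   import argparse
79 | #   args = argparse.Namespace(dist_url="env://", world_size=int(os.environ.get("WORLD_SIZE", "1")))
80 | #   init_distributed_mode(args)  # fills in args.rank, args.gpu, args.device, args.n_gpu, args.cpu_cont
81 | # ---------------------------------------------------------------------------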
--------------------------------------------------------------------------------