├── LICENSE.txt ├── README.md ├── asset ├── .DS_Store ├── XMAiNframe.png ├── cobol_diagram-Page-1.drawio.png ├── cobol_diagram-Page-8.drawio.pdf ├── cobol_diagram-Page-8.drawio.png ├── sample_1.png ├── sample_2.png └── sample_3.png ├── recipes ├── accelerate_configs │ ├── deepspeed_zero3.yaml │ ├── deepspeed_zero3_lora.yaml │ ├── fsdp.yaml │ ├── fsdp_qlora.yaml │ └── multi_gpu.yaml └── deepseek │ ├── full.yaml │ ├── full_instruct.yaml │ ├── lora_instruct.yaml │ └── lora_sft.yaml ├── requirements.txt ├── scripts ├── ft.sh └── instruct.sh ├── src ├── alignment │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── __init__.cpython-38.pyc │ │ ├── configs.cpython-310.pyc │ │ ├── configs.cpython-38.pyc │ │ ├── data.cpython-310.pyc │ │ ├── data.cpython-38.pyc │ │ ├── decontaminate.cpython-310.pyc │ │ ├── decontaminate.cpython-38.pyc │ │ ├── model_utils.cpython-310.pyc │ │ └── model_utils.cpython-38.pyc │ ├── configs.py │ ├── data.py │ ├── decontaminate.py │ ├── model_utils.py │ └── release.py ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-310.pyc │ │ ├── data.cpython-310.pyc │ │ └── utils.cpython-310.pyc │ ├── data.py │ └── utils.py └── model │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-310.pyc │ └── tokenizer.cpython-310.pyc │ └── tokenizer.py ├── train_instruct.py ├── train_raw.py └── utils.py /LICENSE.txt: -------------------------------------------------------------------------------- 1 | Apache License Version 2.0 2 | 3 | Copyright (c) 2024 FPT Software, Inc. 4 | All rights reserved. 5 | 6 | Apache License 7 | Version 2.0, January 2004 8 | http://www.apache.org/licenses/ 9 | 10 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 11 | 12 | 1. Definitions. 13 | 14 | "License" shall mean the terms and conditions for use, reproduction, 15 | and distribution as defined by Sections 1 through 9 of this document. 16 | 17 | "Licensor" shall mean the copyright owner or entity authorized by 18 | the copyright owner that is granting the License. 19 | 20 | "Legal Entity" shall mean the union of the acting entity and all 21 | other entities that control, are controlled by, or are under common 22 | control with that entity. For the purposes of this definition, 23 | "control" means (i) the power, direct or indirect, to cause the 24 | direction or management of such entity, whether by contract or 25 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 26 | outstanding shares, or (iii) beneficial ownership of such entity. 27 | 28 | "You" (or "Your") shall mean an individual or Legal Entity 29 | exercising permissions granted by this License. 30 | 31 | "Source" form shall mean the preferred form for making modifications, 32 | including but not limited to software source code, documentation 33 | source, and configuration files. 34 | 35 | "Object" form shall mean any form resulting from mechanical 36 | transformation or translation of a Source form, including but 37 | not limited to compiled object code, generated documentation, 38 | and conversions to other media types. 39 | 40 | "Work" shall mean the work of authorship, whether in Source or 41 | Object form, made available under the License, as indicated by a 42 | copyright notice that is included in or attached to the work 43 | (an example is provided in the Appendix below). 
44 | 45 | "Derivative Works" shall mean any work, whether in Source or Object 46 | form, that is based on (or derived from) the Work and for which the 47 | editorial revisions, annotations, elaborations, or other modifications 48 | represent, as a whole, an original work of authorship. For the purposes 49 | of this License, Derivative Works shall not include works that remain 50 | separable from, or merely link (or bind by name) to the interfaces of, 51 | the Work and Derivative Works thereof. 52 | 53 | "Contribution" shall mean any work of authorship, including 54 | the original version of the Work and any modifications or additions 55 | to that Work or Derivative Works thereof, that is intentionally 56 | submitted to Licensor for inclusion in the Work by the copyright owner 57 | or by an individual or Legal Entity authorized to submit on behalf of 58 | the copyright owner. For the purposes of this definition, "submitted" 59 | means any form of electronic, verbal, or written communication sent 60 | to the Licensor or its representatives, including but not limited to 61 | communication on electronic mailing lists, source code control systems, 62 | and issue tracking systems that are managed by, or on behalf of, the 63 | Licensor for the purpose of discussing and improving the Work, but 64 | excluding communication that is conspicuously marked or otherwise 65 | designated in writing by the copyright owner as "Not a Contribution." 66 | 67 | "Contributor" shall mean Licensor and any individual or Legal Entity 68 | on behalf of whom a Contribution has been received by Licensor and 69 | subsequently incorporated within the Work. 70 | 71 | 2. Grant of Copyright License. Subject to the terms and conditions of 72 | this License, each Contributor hereby grants to You a perpetual, 73 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 74 | copyright license to reproduce, prepare Derivative Works of, 75 | publicly display, publicly perform, sublicense, and distribute the 76 | Work and such Derivative Works in Source or Object form. 77 | 78 | 3. Grant of Patent License. Subject to the terms and conditions of 79 | this License, each Contributor hereby grants to You a perpetual, 80 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 81 | (except as stated in this section) patent license to make, have made, 82 | use, offer to sell, sell, import, and otherwise transfer the Work, 83 | where such license applies only to those patent claims licensable 84 | by such Contributor that are necessarily infringed by their 85 | Contribution(s) alone or by combination of their Contribution(s) 86 | with the Work to which such Contribution(s) was submitted. If You 87 | institute patent litigation against any entity (including a 88 | cross-claim or counterclaim in a lawsuit) alleging that the Work 89 | or a Contribution incorporated within the Work constitutes direct 90 | or contributory patent infringement, then any patent licenses 91 | granted to You under this License for that Work shall terminate 92 | as of the date such litigation is filed. 93 | 94 | 4. Redistribution. 
You may reproduce and distribute copies of the 95 | Work or Derivative Works thereof in any medium, with or without 96 | modifications, and in Source or Object form, provided that You 97 | meet the following conditions: 98 | 99 | (a) You must give any other recipients of the Work or 100 | Derivative Works a copy of this License; and 101 | 102 | (b) You must cause any modified files to carry prominent notices 103 | stating that You changed the files; and 104 | 105 | (c) You must retain, in the Source form of any Derivative Works 106 | that You distribute, all copyright, patent, trademark, and 107 | attribution notices from the Source form of the Work, 108 | excluding those notices that do not pertain to any part of 109 | the Derivative Works; and 110 | 111 | (d) If the Work includes a "NOTICE" text file as part of its 112 | distribution, then any Derivative Works that You distribute must 113 | include a readable copy of the attribution notices contained 114 | within such NOTICE file, excluding those notices that do not 115 | pertain to any part of the Derivative Works, in at least one 116 | of the following places: within a NOTICE text file distributed 117 | as part of the Derivative Works; within the Source form or 118 | documentation, if provided along with the Derivative Works; or, 119 | within a display generated by the Derivative Works, if and 120 | wherever such third-party notices normally appear. The contents 121 | of the NOTICE file are for informational purposes only and 122 | do not modify the License. You may add Your own attribution 123 | notices within Derivative Works that You distribute, alongside 124 | or as an addendum to the NOTICE text from the Work, provided 125 | that such additional attribution notices cannot be construed 126 | as modifying the License. 127 | 128 | You may add Your own copyright statement to Your modifications and 129 | may provide additional or different license terms and conditions 130 | for use, reproduction, or distribution of Your modifications, or 131 | for any such Derivative Works as a whole, provided Your use, 132 | reproduction, and distribution of the Work otherwise complies with 133 | the conditions stated in this License. 134 | 135 | 5. Submission of Contributions. Unless You explicitly state otherwise, 136 | any Contribution intentionally submitted for inclusion in the Work 137 | by You to the Licensor shall be under the terms and conditions of 138 | this License, without any additional terms or conditions. 139 | Notwithstanding the above, nothing herein shall supersede or modify 140 | the terms of any separate license agreement you may have executed 141 | with Licensor regarding such Contributions. 142 | 143 | 6. Trademarks. This License does not grant permission to use the trade 144 | names, trademarks, service marks, or product names of the Licensor, 145 | except as required for reasonable and customary use in describing the 146 | origin of the Work and reproducing the content of the NOTICE file. 147 | 148 | 7. Disclaimer of Warranty. Unless required by applicable law or 149 | agreed to in writing, Licensor provides the Work (and each 150 | Contributor provides its Contributions) on an "AS IS" BASIS, 151 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 152 | implied, including, without limitation, any warranties or conditions 153 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 154 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 155 | appropriateness of using or redistributing the Work and assume any 156 | risks associated with Your exercise of permissions under this License. 157 | 158 | 8. Limitation of Liability. In no event and under no legal theory, 159 | whether in tort (including negligence), contract, or otherwise, 160 | unless required by applicable law (such as deliberate and grossly 161 | negligent acts) or agreed to in writing, shall any Contributor be 162 | liable to You for damages, including any direct, indirect, special, 163 | incidental, or consequential damages of any character arising as a 164 | result of this License or out of the use or inability to use the 165 | Work (including but not limited to damages for loss of goodwill, 166 | work stoppage, computer failure or malfunction, or any and all 167 | other commercial damages or losses), even if such Contributor 168 | has been advised of the possibility of such damages. 169 | 170 | 9. Accepting Warranty or Additional Liability. While redistributing 171 | the Work or Derivative Works thereof, You may choose to offer, 172 | and charge a fee for, acceptance of support, warranty, indemnity, 173 | or other liability obligations and/or rights consistent with this 174 | License. However, in accepting such obligations, You may act only 175 | on Your own behalf and on Your sole responsibility, not on behalf 176 | of any other Contributor, and only if You agree to indemnify, 177 | defend, and hold each Contributor harmless for any liability 178 | incurred by, or claims asserted against, such Contributor by reason 179 | of your accepting any such warranty or additional liability. 180 | 181 | END OF TERMS AND CONDITIONS 182 | 183 | APPENDIX: How to apply the Apache License to your work. 184 | 185 | To apply the Apache License to your work, attach the following 186 | boilerplate notice, with the fields enclosed by brackets "{}" 187 | replaced with your own identifying information. (Don't include 188 | the brackets!) The text should be enclosed in the appropriate 189 | comment syntax for the file format. We also recommend that a 190 | file or class name and description of purpose be included on the 191 | same "printed page" as the copyright notice for easier 192 | identification within third-party archives. 193 | 194 | Copyright {yyyy} {name of copyright owner} 195 | 196 | Licensed under the Apache License, Version 2.0 (the "License"); 197 | you may not use this file except in compliance with the License. 198 | You may obtain a copy of the License at 199 | 200 | http://www.apache.org/licenses/LICENSE-2.0 201 | 202 | Unless required by applicable law or agreed to in writing, software 203 | distributed under the License is distributed on an "AS IS" BASIS, 204 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 205 | See the License for the specific language governing permissions and 206 | limitations under the License. 207 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # XMAiNframe: A Large Language Model for Mainframe Modernization 4 | [![License: Apache 2.0](https://img.shields.io/badge/License-Apache_2.0-green.svg)](https://www.apache.org/licenses/LICENSE-2.0) 5 | [![arXiv](https://img.shields.io/badge/2408.04660-red?style=flat&label=arXiv)](https://arxiv.org/abs/2408.04660) 6 | [![XMAiNframe on Huggingface](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-XMAiNframe-ffc107?color=ffc107&logoColor=white%22%20style=%22display:%20inline-block;%20vertical-align:%20middle;)](https://huggingface.co/collections/Fsoft-AIC/xmainframe-66aca02d5b552e62033dc2bc) 7 | [![Python 3.10+](https://img.shields.io/badge/python-3.10+-blue.svg)](https://www.python.org/downloads/release/python-3100/) 8 | 9 |
10 | 11 | 12 | ## Table of Contents 13 | - [Introduction](#introduction) 14 | - [Demonstration](#demonstration) 15 | - [Procedure of Data Construction](#procedure-of-data-construction) 16 | - [Mainframe-Training](#mainframe-training) 17 | - [Mainframe-Instruct](#mainframe-instruct) 18 | - [Model Download](#model-download) 19 | - [Evaluation Results](#evaluation-results) 20 | - [Usage](#usage) 21 | - [Fine-tune XMAiNframe](#fine-tune-xmainframe) 22 | - [Inference](#inference) 23 | - [License](#license) 24 | - [Acknowledgements](#acknowledgements) 25 | - [Contact Us](#contact-us) 26 | - [Citation Information](#citation-information) 27 | 28 | 29 | 30 | # Introduction 31 | 32 | We introduce **XMAiNframe**, a state-of-the-art large language model (LLM) specifically designed with knowledge of mainframe legacy systems and COBOL codebases. XMAiNframe is built on top of DeepSeek-Coder 7B and is available with 7B and 10.5B parameters. 33 | Additionally, we present [MainframeBench](https://huggingface.co/datasets/Fsoft-AIC/MainframeBench), a comprehensive benchmark for assessing mainframe knowledge that covers multiple-choice questions, question answering, and COBOL code summarization. Our empirical evaluations show that XMAiNframe consistently outperforms existing state-of-the-art LLMs on these tasks. Specifically, XMAiNframe achieves 30% higher accuracy than DeepSeek-Coder on multiple-choice questions, doubles the BLEU score of Mixtral-Instruct 8x7B on question answering, and scores six times higher than GPT-3.5 on COBOL summarization. These results highlight the potential of XMAiNframe to drive significant advances in managing and modernizing legacy systems, saving time and improving productivity for software developers. 34 | 35 | # Demonstration 36 | 37 | In this section, we demonstrate the capabilities of XMAiNframe by comparing it with its base model, DeepSeek-Coder-7B. We compare the two models' responses to a series of realistic questions about mainframe knowledge. The images below show how each model handles identical prompts. As illustrated, the responses generated by XMAiNframe are not only accurate but also more detailed and comprehensive than those of the base model. This makes XMAiNframe particularly valuable for developers who need a reliable and thorough AI assistant in a mainframe environment. 38 |
40 | 41 | 42 | 43 | 44 | 45 | 46 |
47 | 48 | 49 | # Procedure of Data Construction 50 | ## Mainframe-Training 51 | 52 | We drew on two sources: COBOL projects collected from GitHub via the GitHub API, and online documents relevant to mainframes. In total, the Mainframe-Training Dataset contains 236 million tokens from documents about mainframe technology and COBOL constructs. During pre-training, we combined the Mainframe-Training Dataset with [SlimOrca-Dedup](https://huggingface.co/datasets/Open-Orca/SlimOrca-Dedup) to enrich the model’s mainframe knowledge while retaining its general capabilities. 53 | 54 | ## Mainframe-Instruct 55 | 56 |
57 | 58 | 59 | 60 |
61 | 62 | Mainframe-Instruct is a high-quality synthetic dataset created through 5 steps: 63 | 64 | - Step 1: 300 seed data instances about Mainframe and COBOL are gathered and annotated by our domain experts. 65 | 66 | - Step 2: Using popular LLMs to enrich Mainframe-Instruct from seed data. 67 | 68 | - Step 3: Utilizing GPT-4 as an evaluator to judge model responses, scoring the outputs and ranking responses in a pairwise manner. 69 | 70 | - Step 4: Filtering and manually checking. 71 | 72 | - Step 5: Dividing Mainframe-Instruct into three tasks: Multiple Choice Questions, Question Answering, and COBOL summarization. 73 | 74 | Below are the statistics of Mainframe-Instruct Dataset: 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 |
| Task | Training Samples | Validating Samples | Testing Samples |
| :--- | :---: | :---: | :---: |
| Multiple Choice Questions | 13,894 | 1,544 | 1,931 |
| Question Answering | 18,692 | 2,078 | 2,598 |
| COBOL Summarization | 9,081 | 1,010 | 2,523 |
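Step 3 above relies on GPT-4 as an automatic judge that scores and pairwise-ranks candidate responses. The snippet below is only a minimal sketch of what such a judging call could look like; the prompt wording, model name, and scoring scale are illustrative assumptions rather than the exact setup used to build Mainframe-Instruct.

```python
from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment


def judge_pair(question: str, answer_a: str, answer_b: str) -> str:
    """Ask GPT-4 which of two candidate answers to a mainframe question is better."""
    prompt = (
        "You are grading answers about mainframe systems and COBOL.\n"
        f"Question: {question}\n\n"
        f"Answer A: {answer_a}\n\n"
        f"Answer B: {answer_b}\n\n"
        "Give each answer a score from 1 to 10, then state which answer is better."
    )
    # Single judging call; in practice the two answer orders would be swapped and
    # averaged to reduce position bias.
    response = client.chat.completions.create(
        model="gpt-4",
        messages=[{"role": "user", "content": prompt}],
    )
    return response.choices[0].message.content
```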
102 | 103 | 104 | [MainframeBench](https://huggingface.co/datasets/Fsoft-AIC/MainframeBench), our benchmark for mainframe knowledge, is the testing split of the Mainframe-Instruct Dataset. It is publicly available on the Hugging Face Hub and is used to compare XMAiNframe against other LLMs. 105 | 106 | ```python 107 | from datasets import load_dataset 108 | 109 | # Load each sub-set in MainframeBench 110 | QA_set = load_dataset("Fsoft-AIC/MainframeBench", 'question_answering') 111 | MC_set = load_dataset("Fsoft-AIC/MainframeBench", 'multiple_choice_question') 112 | Summarization_set = load_dataset("Fsoft-AIC/MainframeBench", 'COBOL_code_summarization') 113 | ``` 114 | 115 | # Model Download 116 | We release XMAiNframe with 7B and 10.5B parameters, including both base and instruct models, to the public. XMAiNframe 10.5B is expanded from DeepSeek-Coder 7B via depth up-scaling, without introducing additional modules or dynamic expert selection methods. 117 | 118 |
119 | 120 | | **Model** | **Download** | 121 | | :-----------------------------: | :----------------------------------------------------------: | 122 | | XMAiNframe-base-7b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-base-7b) | 123 | | XMAiNframe-instruct-7b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-instruct-7b) | 124 | | XMAiNframe-base-10.5b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-base-10.5b) | 125 | | XMAiNframe-instruct-10.5b | [🤗 HuggingFace](https://huggingface.co/Fsoft-AIC/XMAiNframe-instruct-10.5b) | 126 | 127 |
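The depth up-scaling mentioned above deepens the 7B network by re-using its existing decoder layers and then continuing pre-training, rather than adding new modules or mixture-of-experts routing. The sketch below illustrates the general idea only; the layer slice and model name are placeholders, not the actual configuration used to build XMAiNframe-base-10.5b.

```python
import copy

import torch.nn as nn
from transformers import AutoModelForCausalLM

# Load the 7B base model; its decoder blocks live in model.model.layers (LLaMA-style).
model = AutoModelForCausalLM.from_pretrained("deepseek-ai/deepseek-coder-7b-base-v1.5")
layers = model.model.layers

# Duplicate a slice of decoder blocks and stack them onto the network to increase depth.
# The slice [8:24] is a placeholder; the enlarged model is then pre-trained further.
# (A real implementation would also fix per-layer cache indices before training.)
extra = [copy.deepcopy(layer) for layer in layers[8:24]]
model.model.layers = nn.ModuleList(list(layers) + extra)
model.config.num_hidden_layers = len(model.model.layers)
```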
128 | 129 | 130 | # Evaluation Results 131 | ## Multiple Choice Question Task 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 |
| Model | Accuracy (%) |
| :--- | :---: |
| GPT-4 | 73.90 |
| GPT-3.5 | 74.56 |
| Mixtral-Instruct 8x7B | 68.12 |
| Mistral-Instruct 7B | 69.29 |
| Neural-Chat | 66.35 |
| DeepSeek-Coder-Instruct 6.7B | 47.49 |
| DeepSeek-Coder-Instruct 33B | 53.29 |
| XMAiNframe-Instruct 7B | 68.57 |
| XMAiNframe-Instruct 10.5B | 77.89 |
175 | 176 | ## Question Answering Task 177 | 178 | 179 | 180 | 181 | 182 | 183 | 184 | 185 | 186 | 187 | 188 | 189 | 190 | 191 | 192 | 193 | 194 | 195 | 196 | 197 | 198 | 199 | 200 | 201 | 202 | 203 | 204 | 205 | 206 | 207 | 208 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 217 | 218 | 219 | 220 | 221 | 222 | 223 | 224 | 225 | 226 | 227 | 228 | 229 | 230 | 231 | 232 | 233 | 234 | 235 | 236 | 237 | 238 | 239 | 240 | 241 | 242 | 243 | 244 | 245 | 246 | 247 | 248 | 249 | 250 | 251 | 252 | 253 | 254 | 255 | 256 | 257 | 258 | 259 | 260 | 261 | 262 | 263 | 264 | 265 | 266 | 267 | 268 | 269 |
| Models | MAP | F1-Score | BERTScore | RougeL | Meteor | BLEU-4 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
| GPT-4 | 0.12 | 0.19 | 0.88 | 0.18 | 0.34 | 5.71 |
| GPT-3.5 | 0.14 | 0.22 | 0.89 | 0.21 | 0.38 | 7.36 |
| Mixtral-Instruct 8x7B | 0.27 | 0.31 | 0.90 | 0.29 | 0.38 | 11.39 |
| Mistral-Instruct 7B | 0.12 | 0.19 | 0.87 | 0.18 | 0.34 | 5.74 |
| Neural-Chat | 0.13 | 0.21 | 0.88 | 0.20 | 0.36 | 6.45 |
| DeepSeek-Coder-Instruct 6.7B | 0.09 | 0.15 | 0.86 | 0.14 | 0.30 | 4.09 |
| DeepSeek-Coder-Instruct 33B | 0.09 | 0.15 | 0.86 | 0.15 | 0.31 | 4.41 |
| XMAiNframe-Instruct 7B | 0.45 | 0.42 | 0.92 | 0.40 | 0.42 | 20.43 |
| XMAiNframe-Instruct 10.5B | 0.43 | 0.42 | 0.92 | 0.40 | 0.42 | 20.93 |
270 | 271 | ## COBOL Code Summarization 272 | 273 | 274 | 275 | 276 | 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | 297 | 298 | 299 | 300 | 301 | 302 | 303 | 304 | 305 | 306 | 307 | 308 | 309 | 310 | 311 | 312 | 313 | 314 | 315 | 316 | 317 | 318 | 319 | 320 | 321 | 322 | 323 | 324 | 325 | 326 | 327 | 328 | 329 | 330 | 331 | 332 | 333 | 334 | 335 | 336 | 337 | 338 | 339 | 340 | 341 | 342 | 343 | 344 | 345 | 346 | 347 | 348 | 349 | 350 | 351 | 352 | 353 | 354 | 355 | 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 |
| Models | MAP | F1-Score | BERTScore | RougeL | Meteor | BLEU-4 |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: |
| GPT-4 | 0.12 | 0.19 | 0.88 | 0.18 | 0.34 | 5.71 |
| GPT-3.5 | 0.14 | 0.22 | 0.89 | 0.21 | 0.38 | 7.36 |
| Mixtral-Instruct 8x7B | 0.27 | 0.31 | 0.90 | 0.29 | 0.38 | 11.39 |
| Mistral-Instruct 7B | 0.12 | 0.19 | 0.87 | 0.18 | 0.34 | 5.74 |
| Neural-Chat | 0.13 | 0.21 | 0.88 | 0.20 | 0.36 | 6.45 |
| DeepSeek-Coder-Instruct 6.7B | 0.09 | 0.15 | 0.86 | 0.14 | 0.30 | 4.09 |
| DeepSeek-Coder-Instruct 33B | 0.09 | 0.15 | 0.86 | 0.15 | 0.31 | 4.41 |
| XMAiNframe-Instruct 7B | 0.45 | 0.42 | 0.92 | 0.40 | 0.42 | 20.43 |
| XMAiNframe-Instruct 10.5B | 0.43 | 0.42 | 0.92 | 0.40 | 0.42 | 20.93 |
364 | 365 | For more evaluation details and settings, please check our paper. 366 | 367 | 368 | # Usage 369 | ## Fine-tune XMAiNframe 370 | To run the code in this project, first create a Python virtual environment, for example with Conda: 371 | 372 | ```shell 373 | conda create -n xmainframe python=3.10 && conda activate xmainframe 374 | ``` 375 | 376 | You can then install the remaining package dependencies as follows: 377 | 378 | ```shell 379 | git clone https://github.com/FSoft-AI4Code/XMainframe.git 380 | cd XMainframe 381 | pip install -r requirements.txt 382 | ``` 383 | You can now check out the `scripts` and `recipes` directories for instructions on how to fine-tune our model 🪁! 384 | 385 | 386 | ## Inference 387 | 388 | Here is a code snippet that uses `apply_chat_template` to show how to load the tokenizer and model and generate a response. 389 | 390 | 391 | ```python 392 | from transformers import AutoTokenizer, AutoModelForCausalLM 393 | tokenizer = AutoTokenizer.from_pretrained("Fsoft-AIC/XMAiNframe-instruct-7b") 394 | model = AutoModelForCausalLM.from_pretrained("Fsoft-AIC/XMAiNframe-instruct-7b") 395 | messages = [ 396 | {'from': 'system', 'value': "You are a helpful assistant"}, 397 | {'from': 'human', 'value': 'What is the future of Mainframe?'} 398 | ] 399 | inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device) 400 | 401 | outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id) 402 | print(tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)) 403 | ``` 404 | 405 | 406 | 407 | 408 | 409 | 410 | # License 411 | This code repository is licensed under the [Apache License 2.0](LICENSE.txt). 412 | 413 | # Acknowledgements 414 | This codebase is adapted from: 415 | - [alignment-handbook](https://github.com/huggingface/alignment-handbook) 416 | 417 | # Contact us 418 | If you have any questions, comments, or suggestions, please do not hesitate to contact us. 419 | - Website: [fpt-aicenter](https://www.fpt-aicenter.com/ai-residency/) 420 | - Email: support.ailab@fpt.com 421 | 422 | # Citation Information 423 | More details can be found in our [technical report](https://arxiv.org/abs/2408.04660).
424 | 425 | If you're using XMAiNframe, please cite using this BibTeX: 426 | ```bibtex 427 | @article{dau2024xmainframe, 428 | title={XMainframe: A Large Language Model for Mainframe Modernization}, 429 | author={Dau, Anh TV and Dao, Hieu Trung and Nguyen, Anh Tuan and Tran, Hieu Trung and Nguyen, Phong X and Bui, Nghi DQ}, 430 | journal={arXiv preprint arXiv:2408.04660}, 431 | year={2024} 432 | } 433 | ``` 434 | -------------------------------------------------------------------------------- /asset/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/.DS_Store -------------------------------------------------------------------------------- /asset/XMAiNframe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/XMAiNframe.png -------------------------------------------------------------------------------- /asset/cobol_diagram-Page-1.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-1.drawio.png -------------------------------------------------------------------------------- /asset/cobol_diagram-Page-8.drawio.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-8.drawio.pdf -------------------------------------------------------------------------------- /asset/cobol_diagram-Page-8.drawio.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/cobol_diagram-Page-8.drawio.png -------------------------------------------------------------------------------- /asset/sample_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_1.png -------------------------------------------------------------------------------- /asset/sample_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_2.png -------------------------------------------------------------------------------- /asset/sample_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/asset/sample_3.png -------------------------------------------------------------------------------- /recipes/accelerate_configs/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | 
mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /recipes/accelerate_configs/deepspeed_zero3_lora.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: cpu #none 6 | offload_param_device: cpu #none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: false 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 4 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false -------------------------------------------------------------------------------- /recipes/accelerate_configs/fsdp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: true 11 | fsdp_offload_params: false 12 | fsdp_sharding_strategy: FULL_SHARD 13 | fsdp_state_dict_type: SHARDED_STATE_DICT 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: true 16 | machine_rank: 0 17 | main_training_function: main 18 | mixed_precision: bf16 19 | num_machines: 1 20 | num_processes: 4 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /recipes/accelerate_configs/fsdp_qlora.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | # cpu_ram_efficient_loading: "true" 6 | fsdp_config: 7 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 8 | fsdp_backward_prefetch: BACKWARD_PRE 9 | fsdp_cpu_ram_efficient_loading: true 10 | fsdp_forward_prefetch: false 11 | fsdp_offload_params: true 12 | fsdp_sharding_strategy: FULL_SHARD 13 | fsdp_state_dict_type: SHARDED_STATE_DICT 14 | fsdp_sync_module_states: true 15 | fsdp_use_orig_params: false 16 | machine_rank: 0 17 | main_training_function: main 18 | mixed_precision: 'no' 19 | num_machines: 1 20 | num_processes: 2 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false -------------------------------------------------------------------------------- /recipes/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | 
tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /recipes/deepseek/full.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | # model_name_or_path: deepseek-ai/deepseek-coder-7b-base-v1.5 3 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full 4 | model_revision: main 5 | torch_dtype: bfloat16 6 | use_flash_attention_2: false 7 | trust_remote_code: true 8 | 9 | # Data training arguments 10 | dataset_mixer: 11 | # HuggingFaceH4/ultrachat_200k: 1.0 12 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1_chunks_8096.feather: 0.1 13 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 0.1 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 12 18 | 19 | # SFT trainer config 20 | bf16: true 21 | do_eval: false 22 | evaluation_strategy: steps 23 | gradient_accumulation_steps: 2 24 | gradient_checkpointing: true 25 | hub_model_id: deepseek-7b-ft-full_longcontext 26 | hub_strategy: every_save 27 | learning_rate: 2.0e-05 28 | log_level: info 29 | logging_steps: 5 30 | logging_strategy: steps 31 | lr_scheduler_type: cosine 32 | max_seq_length: 16000 33 | max_steps: -1 34 | num_train_epochs: 1 35 | output_dir: data/deepseek-7b-ft-full_longcontext 36 | overwrite_output_dir: true 37 | per_device_eval_batch_size: 4 38 | per_device_train_batch_size: 1 39 | push_to_hub: false 40 | remove_unused_columns: true 41 | report_to: 42 | - tensorboard 43 | # save_strategy: "epoch" 44 | save_strategy: "steps" 45 | save_steps: 100 46 | save_total_limit: null 47 | seed: 42 48 | tf32: true 49 | -------------------------------------------------------------------------------- /recipes/deepseek/full_instruct.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full/checkpoint-5425 3 | # model_name_or_path: mistralai/Mistral-7B-v0.1 4 | model_revision: main 5 | torch_dtype: bfloat16 6 | use_flash_attention_2: true 7 | 8 | # Data training arguments 9 | chat_template: "{{ bos_token }}{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value']}}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value']}}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 10 | dataset_mixer: 11 | # HuggingFaceH4/ultrachat_200k: 1.0 12 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1_chunks_8096.feather: 0.1 13 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 0.1 14 | dataset_splits: 15 | - train 16 | - test 17 | preprocessing_num_workers: 12 18 | 19 | # SFT trainer config 20 | bf16: true 21 | do_eval: false 22 | evaluation_strategy: epoch 23 | gradient_accumulation_steps: 2 24 | gradient_checkpointing: true 25 | gradient_checkpointing_kwargs: 26 | use_reentrant: False 27 | hub_model_id: deepseek_instruct_full_data_pretrained 28 | hub_strategy: every_save 29 | learning_rate: 2.0e-05 30 | log_level: info 31 | logging_steps: 3 32 | logging_strategy: steps 33 | lr_scheduler_type: cosine 34 | max_seq_length: 4096 35 | max_steps: -1 36 | num_train_epochs: 5 37 | output_dir: data/deepseek_instruct_full_data_pretrained 38 | overwrite_output_dir: true 39 | 
per_device_eval_batch_size: 4 40 | per_device_train_batch_size: 2 41 | push_to_hub: false 42 | remove_unused_columns: true 43 | report_to: 44 | - tensorboard 45 | save_strategy: "epoch" 46 | save_total_limit: null 47 | seed: 45 48 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /recipes/deepseek/lora_instruct.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek_16b_lora 3 | model_revision: main 4 | torch_dtype: float16 5 | use_flash_attention_2: true 6 | trust_remote_code: true 7 | 8 | 9 | 10 | # LoRA arguments 11 | load_in_4bit: true 12 | use_peft: true 13 | lora_r: 16 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | lora_target_modules: 17 | - q_proj 18 | - k_proj 19 | - v_proj 20 | - o_proj 21 | - gate_proj 22 | - lm_head 23 | - up_proj 24 | - down_proj 25 | 26 | # Data training arguments 27 | chat_template: "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 28 | dataset_mixer: 29 | # HuggingFaceH4/ultrachat_200k: 1.0 30 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1.feather: 1 31 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 1 32 | dataset_splits: 33 | - train 34 | - test 35 | preprocessing_num_workers: 12 36 | 37 | # SFT trainer config 38 | bf16: true 39 | do_eval: true 40 | evaluation_strategy: epoch 41 | gradient_accumulation_steps: 4 42 | gradient_checkpointing: true 43 | gradient_checkpointing_kwargs: 44 | use_reentrant: false 45 | hub_model_id: deepseek_instruct_lora_16b_multigpu 46 | hub_strategy: every_save 47 | learning_rate: 2.0e-04 48 | log_level: info 49 | logging_steps: 5 50 | logging_strategy: steps 51 | lr_scheduler_type: cosine 52 | max_seq_length: 4096 53 | max_steps: -1 54 | num_train_epochs: 3 55 | output_dir: data/deepseek_instruct_lora_16b_multigpu 56 | overwrite_output_dir: true 57 | per_device_eval_batch_size: 4 58 | per_device_train_batch_size: 3 59 | push_to_hub: false 60 | # dataset_num_proc: 4 61 | report_to: 62 | - tensorboard 63 | save_strategy: "epoch" 64 | save_total_limit: null 65 | seed: 42 66 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /recipes/deepseek/lora_sft.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: /cm/shared/anhdtv7/mainframe_gpt/data/deepseek-7b-ft-full 3 | model_revision: main 4 | torch_dtype: 'float16' 5 | # trust_remote_code: true 6 | use_flash_attention_2: true 7 | 8 | 9 | # LoRA arguments 10 | # load_in_8bit: true 11 | load_in_4bit: true 12 | use_peft: true 13 | lora_r: 8 14 | lora_alpha: 16 15 | lora_dropout: 0.1 16 | lora_target_modules: 17 | - q_proj 18 | - k_proj 19 | - v_proj 20 | - o_proj 21 | - gate_proj 22 | - lm_head 23 | - up_proj 24 | - down_proj 25 | 26 | # Data training arguments 27 | chat_template: "{% for message in messages %}\n{% if message['from'] == 'human' %}\n{{ '<|user|>\n' + message['value'] + eos_token }}\n{% elif message['from'] == 'system' %}\n{{ '<|system|>\n' + message['value'] + eos_token }}\n{% 
elif message['from'] == 'gpt' %}\n{{ '<|assistant|>\n' + message['value'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 28 | dataset_mixer: 29 | # HuggingFaceH4/ultrachat_200k: 1.0 30 | /cm/archive/hieudt47/workspace/data/mainframe_df_v1.1.feather: 1 31 | /cm/archive/hieudt47/workspace/data/textbook_quality_programming.feather: 1 32 | dataset_splits: 33 | - train 34 | - test 35 | preprocessing_num_workers: 12 36 | 37 | # SFT trainer config 38 | bf16: true 39 | do_eval: true 40 | evaluation_strategy: epoch 41 | gradient_accumulation_steps: 2 42 | gradient_checkpointing: true 43 | gradient_checkpointing_kwargs: 44 | use_reentrant: false 45 | hub_model_id: deepseek-7b-ft-lora_longcontext 46 | hub_strategy: every_save 47 | # optimizer: paged_adamw_8bit 48 | learning_rate: 1.0e-05 49 | log_level: info 50 | logging_steps: 5 51 | logging_strategy: steps 52 | lr_scheduler_type: cosine 53 | max_seq_length: 16000 54 | max_steps: -1 55 | num_train_epochs: 5 56 | output_dir: data/deepseek-7b-ft-lora_longcontext 57 | overwrite_output_dir: true 58 | per_device_eval_batch_size: 2 59 | per_device_train_batch_size: 1 60 | # save_strategy: "steps" 61 | # save_steps: 500 62 | push_to_hub: false 63 | report_to: 64 | - tensorboard 65 | save_strategy: "epoch" 66 | save_total_limit: null 67 | seed: 42 68 | warmup_ratio: 0.1 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.1.0 2 | accelerate==0.33.0 3 | aiohttp==3.9.5 4 | aiosignal==1.3.1 5 | alignment-handbook 6 | annotated-types==0.7.0 7 | async-timeout==4.0.3 8 | attrs==23.2.0 9 | bitsandbytes==0.43.2 10 | cachetools==5.4.0 11 | certifi==2024.7.4 12 | charset-normalizer==3.3.2 13 | datasets==2.20.0 14 | deepspeed==0.12.2 15 | dill==0.3.8 16 | docstring_parser==0.16 17 | einops==0.8.0 18 | evaluate==0.4.0 19 | filelock==3.15.4 20 | flash_attn==2.6.3 21 | frozenlist==1.4.1 22 | fsspec==2024.5.0 23 | grpcio==1.65.1 24 | hf_transfer==0.1.8 25 | hjson==3.1.0 26 | huggingface-hub==0.24.2 27 | idna==3.7 28 | Jinja2==3.1.4 29 | Markdown==3.6 30 | markdown-it-py==3.0.0 31 | MarkupSafe==2.1.5 32 | mdurl==0.1.2 33 | mpmath==1.3.0 34 | multidict==6.0.5 35 | multiprocess==0.70.16 36 | networkx==3.3 37 | ninja==1.11.1.1 38 | numpy==1.26.4 39 | nvidia-cublas-cu12==12.1.3.1 40 | nvidia-cuda-cupti-cu12==12.1.105 41 | nvidia-cuda-nvrtc-cu12==12.1.105 42 | nvidia-cuda-runtime-cu12==12.1.105 43 | nvidia-cudnn-cu12==9.1.0.70 44 | nvidia-cufft-cu12==11.0.2.54 45 | nvidia-curand-cu12==10.3.2.106 46 | nvidia-cusolver-cu12==11.4.5.107 47 | nvidia-cusparse-cu12==12.1.0.106 48 | nvidia-ml-py==12.535.161 49 | nvidia-nccl-cu12==2.20.5 50 | nvidia-nvjitlink-cu12==12.5.82 51 | nvidia-nvtx-cu12==12.1.105 52 | nvitop==1.3.2 53 | packaging==24.1 54 | pandas==2.2.2 55 | peft==0.12.0 56 | protobuf==3.20.2 57 | psutil==6.0.0 58 | py-cpuinfo==9.0.0 59 | pyarrow==17.0.0 60 | pyarrow-hotfix==0.6 61 | pydantic==2.8.2 62 | pydantic_core==2.20.1 63 | Pygments==2.18.0 64 | pynvml==11.5.3 65 | python-dateutil==2.9.0.post0 66 | pytz==2024.1 67 | PyYAML==6.0.1 68 | regex==2024.7.24 69 | requests==2.32.3 70 | responses==0.18.0 71 | rich==13.7.1 72 | safetensors==0.4.3 73 | scipy==1.14.0 74 | sentencepiece==0.2.0 75 | shtab==1.7.1 76 | six==1.16.0 77 | sympy==1.13.1 78 | tensorboard==2.17.0 79 | tensorboard-data-server==0.7.2 80 | termcolor==2.4.0 81 | tokenizers==0.19.1 82 | 
torch==2.4.0 83 | tqdm==4.66.4 84 | transformers==4.40.2 85 | triton==3.0.0 86 | trl==0.9.6 87 | typing_extensions==4.12.2 88 | tyro==0.8.5 89 | tzdata==2024.1 90 | urllib3==2.2.2 91 | Werkzeug==3.0.3 92 | xxhash==3.4.1 93 | yarl==1.9.4 94 | 95 | 96 | -------------------------------------------------------------------------------- /scripts/ft.sh: -------------------------------------------------------------------------------- 1 | 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file recipes/accelerate_configs/fsdp_qlora.yaml \ 4 | --num_processes=4 \ 5 | --main_process_port 9506 \ 6 | train_raw.py recipes/deepseek/lora_sft.yaml \ 7 | --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 \ 8 | --use_4bit_quantization=True \ 9 | --use_nested_quant=True \ 10 | --bnb_4bit_quant_type="nf4" \ 11 | --bnb_4bit_compute_dtype=bfloat16 \ 12 | --bnb_4bit_quant_storage_dtype=bfloat16 13 | # --load_in_4bit=true 14 | -------------------------------------------------------------------------------- /scripts/instruct.sh: -------------------------------------------------------------------------------- 1 | 2 | ACCELERATE_LOG_LEVEL=info accelerate launch \ 3 | --config_file recipes/accelerate_configs/multi_gpu.yaml \ 4 | --num_processes=4 \ 5 | --main_process_port 9505 \ 6 | train_instruct.py recipes/deepseek/lora_instruct.yaml \ 7 | # --torch_dtype=bfloat16 --bnb_4bit_quant_storage=bfloat16 \ 8 | # --use_4bit_quantization=True \ 9 | # --use_nested_quant=True \ 10 | # --bnb_4bit_quant_type="nf4" \ 11 | # --bnb_4bit_compute_dtype=bfloat16 \ 12 | # --bnb_4bit_quant_storage_dtype=bfloat16 13 | -------------------------------------------------------------------------------- /src/alignment/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.3.0.dev0" 2 | 3 | from .configs import DataArguments, DPOConfig, H4ArgumentParser, ModelArguments, SFTConfig 4 | from .data import apply_chat_template, get_datasets 5 | from .decontaminate import decontaminate_humaneval 6 | from .model_utils import ( 7 | get_checkpoint, 8 | get_kbit_device_map, 9 | get_peft_config, 10 | get_quantization_config, 11 | get_tokenizer, 12 | is_adapter_model, 13 | ) 14 | -------------------------------------------------------------------------------- /src/alignment/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/configs.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/configs.cpython-310.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/configs.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/configs.cpython-38.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/data.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/data.cpython-38.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/decontaminate.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/decontaminate.cpython-310.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/decontaminate.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/decontaminate.cpython-38.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/model_utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/model_utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/alignment/__pycache__/model_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/alignment/__pycache__/model_utils.cpython-38.pyc -------------------------------------------------------------------------------- /src/alignment/configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | import dataclasses 16 | import os 17 | import sys 18 | from dataclasses import dataclass, field 19 | from typing import Any, Dict, List, NewType, Optional, Tuple 20 | 21 | import transformers 22 | from transformers import MODEL_FOR_CAUSAL_LM_MAPPING, HfArgumentParser 23 | 24 | 25 | MODEL_CONFIG_CLASSES = list(MODEL_FOR_CAUSAL_LM_MAPPING.keys()) 26 | MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES) 27 | 28 | 29 | DataClassType = NewType("DataClassType", Any) 30 | 31 | 32 | class H4ArgumentParser(HfArgumentParser): 33 | def parse_yaml_and_args(self, yaml_arg: str, other_args: Optional[List[str]] = None) -> List[dataclass]: 34 | """ 35 | Parse a YAML file and overwrite the default/loaded values with the values provided to the command line. 36 | 37 | Args: 38 | yaml_arg (`str`): 39 | The path to the config file used 40 | other_args (`List[str]`, *optional`): 41 | A list of strings to parse as command line arguments, e.g. ['--arg=val', '--arg2=val2']. 42 | 43 | Returns: 44 | [`List[dataclass]`]: a list of dataclasses with the values from the YAML file and the command line 45 | """ 46 | arg_list = self.parse_yaml_file(os.path.abspath(yaml_arg)) 47 | 48 | outputs = [] 49 | # strip other args list into dict of key-value pairs 50 | other_args = {arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args} 51 | used_args = {} 52 | 53 | # overwrite the default/loaded value with the value provided to the command line 54 | # adapted from https://github.com/huggingface/transformers/blob/d0b5002378daabf62769159add3e7d66d3f83c3b/src/transformers/hf_argparser.py#L327 55 | for data_yaml, data_class in zip(arg_list, self.dataclass_types): 56 | keys = {f.name for f in dataclasses.fields(data_yaml) if f.init} 57 | inputs = {k: v for k, v in vars(data_yaml).items() if k in keys} 58 | for arg, val in other_args.items(): 59 | # add only if in keys 60 | 61 | if arg in keys: 62 | base_type = data_yaml.__dataclass_fields__[arg].type 63 | inputs[arg] = val 64 | 65 | # cast type for ints, floats (default to strings) 66 | if base_type in [int, float]: 67 | inputs[arg] = base_type(val) 68 | 69 | if base_type == List[str]: 70 | inputs[arg] = [str(v) for v in val.split(",")] 71 | 72 | # bool of a non-empty string is True, so we manually check for bools 73 | if base_type == bool: 74 | if val in ["true", "True"]: 75 | inputs[arg] = True 76 | else: 77 | inputs[arg] = False 78 | 79 | # add to used-args so we can check if double add 80 | if arg not in used_args: 81 | used_args[arg] = val 82 | else: 83 | raise ValueError(f"Duplicate argument provided: {arg}, may cause unexpected behavior") 84 | 85 | obj = data_class(**inputs) 86 | outputs.append(obj) 87 | 88 | return outputs 89 | 90 | def parse(self): #-> DataClassType | Tuple[DataClassType]: 91 | if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 92 | # If we pass only one argument to the script and it's the path to a YAML file, 93 | # let's parse it to get our arguments. 
94 | output = self.parse_yaml_file(os.path.abspath(sys.argv[1])) 95 | # parse command line args and yaml file 96 | elif len(sys.argv) > 2 and sys.argv[1].endswith(".yaml"): 97 | output = self.parse_yaml_and_args(os.path.abspath(sys.argv[1]), sys.argv[2:]) 98 | # parse command line args only 99 | else: 100 | output = self.parse_args_into_dataclasses() 101 | 102 | if len(output) == 1: 103 | output = output[0] 104 | return output 105 | 106 | 107 | @dataclass 108 | class ModelArguments: 109 | """ 110 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 111 | """ 112 | 113 | base_model_revision: Optional[str] = field( 114 | default=None, 115 | metadata={"help": ("The base model checkpoint for weights initialization with PEFT adapters.")}, 116 | ) 117 | model_name_or_path: Optional[str] = field( 118 | default=None, 119 | metadata={ 120 | "help": ( 121 | "The model checkpoint for weights initialization. Don't set if you want to train a model from scratch." 122 | ) 123 | }, 124 | ) 125 | model_revision: str = field( 126 | default="main", 127 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 128 | ) 129 | model_code_revision: str = field(default=None, metadata={"help": "The branch of the IFT model"}) 130 | torch_dtype: Optional[str] = field( 131 | default=None, 132 | metadata={ 133 | "help": ( 134 | "Override the default `torch.dtype` and load the model under this dtype. If `auto` is passed, the " 135 | "dtype will be automatically derived from the model's weights." 136 | ), 137 | "choices": ["auto", "bfloat16", "float16", "float32"], 138 | }, 139 | ) 140 | tokenizer_name_or_path: Optional[str] = field( 141 | default=None, 142 | metadata={ 143 | "help": ( 144 | "The path to the tokenizer. Useful if you want to use a different tokenizer to the one stored in `model_name_or_path`." 145 | ) 146 | }, 147 | ) 148 | trust_remote_code: bool = field(default=False, metadata={"help": "Trust remote code when loading a model."}) 149 | use_flash_attention_2: bool = field( 150 | default=False, 151 | metadata={ 152 | "help": ( 153 | "Whether to use flash attention 2. 
You must install this manually by running `pip install flash-attn --no-build-isolation`" 154 | ) 155 | }, 156 | ) 157 | use_peft: bool = field( 158 | default=False, 159 | metadata={"help": ("Whether to use PEFT or not for training.")}, 160 | ) 161 | lora_r: Optional[int] = field( 162 | default=16, 163 | metadata={"help": ("LoRA R value.")}, 164 | ) 165 | lora_alpha: Optional[int] = field( 166 | default=32, 167 | metadata={"help": ("LoRA alpha.")}, 168 | ) 169 | lora_dropout: Optional[float] = field( 170 | default=0.05, 171 | metadata={"help": ("LoRA dropout.")}, 172 | ) 173 | lora_target_modules: Optional[List[str]] = field( 174 | default=None, 175 | metadata={"help": ("LoRA target modules.")}, 176 | ) 177 | lora_modules_to_save: Optional[List[str]] = field( 178 | default=None, 179 | metadata={"help": ("Model layers to unfreeze & train")}, 180 | ) 181 | load_in_8bit: bool = field(default=False, metadata={"help": "use 8 bit precision"}) 182 | load_in_4bit: bool = field(default=False, metadata={"help": "use 4 bit precision"}) 183 | 184 | bnb_4bit_quant_type: Optional[str] = field( 185 | default="nf4", metadata={"help": "precise the quantization type (fp4 or nf4)"} 186 | ) 187 | use_bnb_nested_quant: bool = field(default=False, metadata={"help": "use nested quantization"}) 188 | bnb_4bit_quant_storage: Optional[str] = field( 189 | default="uint8", metadata={"help": "storage type to pack the quanitzed 4-bit prarams."} 190 | ) 191 | 192 | def __post_init__(self): 193 | if self.load_in_8bit and self.load_in_4bit: 194 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 195 | 196 | 197 | @dataclass 198 | class DataArguments: 199 | """ 200 | Arguments pertaining to what data we are going to input our model for training and eval. 201 | """ 202 | 203 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) 204 | dataset_mixer: Optional[Dict[str, float]] = field( 205 | default=None, 206 | metadata={"help": ("Datasets and their proportions to be used for training ift/rl.")}, 207 | ) 208 | text_column: Optional[str] = field( 209 | default="text", 210 | metadata={"help": "The column name to use for the text in the dataset (only used for continued pretraining)."}, 211 | ) 212 | dataset_splits: Optional[List[str]] = field( 213 | default_factory=lambda: ["train", "test"], 214 | metadata={"help": ("List of train test splits to use in the dataset")}, 215 | ) 216 | dataset_configs: Optional[List[str]] = field( 217 | default=None, 218 | metadata={"help": "List of dataset config names. If given must be the same length as 'dataset_mixer' keys."}, 219 | ) 220 | preprocessing_num_workers: Optional[int] = field( 221 | default=None, 222 | metadata={"help": "The number of processes to use for the preprocessing."}, 223 | ) 224 | truncation_side: Optional[str] = field( 225 | default=None, metadata={"help": "Truncation side to use for the tokenizer."} 226 | ) 227 | auto_insert_empty_system_msg: bool = field( 228 | default=True, 229 | metadata={ 230 | "help": ( 231 | "Whether to automatically insert an empty system message as the first message if `system` is mentioned in the chat template." 232 | ) 233 | }, 234 | ) 235 | 236 | 237 | @dataclass 238 | class SFTConfig(transformers.TrainingArguments): 239 | """ 240 | Arguments related to the training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments 241 | Also used for the continued pretraining task. 
242 | """ 243 | 244 | dataset_kwargs: Optional[Dict[str, Any]] = field( 245 | default=None, metadata={"help": "Dataset kwargs for the SFTTrainer"} 246 | ) 247 | max_seq_length: Optional[int] = field( 248 | default=None, 249 | metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")}, 250 | ) 251 | logging_first_step: bool = field( 252 | default=True, 253 | metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, 254 | ) 255 | optim: Optional[str] = field(default="adamw_torch") 256 | 257 | 258 | @dataclass 259 | class DPOConfig(transformers.TrainingArguments): 260 | """ 261 | Arguments related to the DPO training process itself. For all parameters, see: https://huggingface.co/docs/transformers/v4.26.1/en/main_classes/trainer#transformers.TrainingArguments 262 | """ 263 | 264 | beta: Optional[float] = field( 265 | default=0.1, 266 | metadata={"help": "The beta factor in DPO loss. Higher beta means less divergence from the initial policy."}, 267 | ) 268 | hub_model_revision: Optional[str] = field( 269 | default="main", 270 | metadata={"help": ("The Hub model branch to push the model to.")}, 271 | ) 272 | logging_first_step: bool = field( 273 | default=True, 274 | metadata={"help": ("Whether to log and evaluate the first global_step or not.")}, 275 | ) 276 | max_prompt_length: Optional[int] = field( 277 | default=None, 278 | metadata={"help": ("For DPO, the maximum length of the prompt to use for conditioning the model.")}, 279 | ) 280 | max_length: Optional[int] = field( 281 | default=None, 282 | metadata={"help": ("Used by TRL for reward model training, which tries to read this parameter in init.")}, 283 | ) 284 | optim: Optional[str] = field(default="rmsprop") 285 | remove_unused_columns: bool = field(default=False) 286 | loss_type: Optional[str] = field(default="sigmoid", metadata={"help": ("The loss type for DPO.")}) 287 | 288 | 289 | @dataclass 290 | class ORPOConfig(transformers.TrainingArguments): 291 | max_length: Optional[int] = field( 292 | default=None, 293 | metadata={"help": "The maximum length of the sequences in the batch."}, 294 | ) 295 | max_prompt_length: Optional[int] = field( 296 | default=None, 297 | metadata={"help": "The maximum length of the prompt."}, 298 | ) 299 | max_completion_length: Optional[int] = field( 300 | default=None, 301 | metadata={"help": "The maximum length of the completions."}, 302 | ) 303 | 304 | beta: float = field( 305 | default=0.1, 306 | metadata={ 307 | "help": "The beta factor in ORPO loss (lambda/alpha in paper/code) that is the weight of the relative loss ratio in the SFT loss." 
308 | }, 309 | ) 310 | disable_dropout: bool = field( 311 | default=True, 312 | metadata={"help": "Whether or not to disable dropouts in `model`."}, 313 | ) 314 | 315 | label_pad_token_id: int = field( 316 | default=-100, 317 | metadata={"help": "The label pad token id."}, 318 | ) 319 | padding_value: Optional[int] = field( 320 | default=None, 321 | metadata={"help": "The padding value if it is different to the tokenizer's pad_token_id."}, 322 | ) 323 | truncation_mode: str = field( 324 | default="keep_end", 325 | metadata={"help": "The truncation mode to use, either `keep_end` or `keep_start`."}, 326 | ) 327 | 328 | generate_during_eval: bool = field( 329 | default=False, 330 | metadata={"help": "Whether to sample and log generations during evaluation step."}, 331 | ) 332 | is_encoder_decoder: Optional[bool] = field( 333 | default=None, 334 | metadata={"help": ("If no model is provided, we need to know if the model_init returns an encoder-decoder.")}, 335 | ) 336 | 337 | model_init_kwargs: Optional[Dict] = field( 338 | default=None, 339 | metadata={"help": ("Dict of Optional kwargs to pass when instantiating the model from a string")}, 340 | ) 341 | 342 | dataset_num_proc: Optional[int] = field( 343 | default=None, 344 | metadata={"help": ("The number of workers to use to tokenize the data.")}, 345 | ) 346 | -------------------------------------------------------------------------------- /src/alignment/data.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
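# For intuition, DEFAULT_CHAT_TEMPLATE below wraps each turn in a role marker and closes it
# with the tokenizer's eos_token. A rough example (output approximate, assuming
# eos_token == "</s>" and add_generation_prompt=True):
#
#   messages = [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Hello!"},
#   ]
#   tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
#   # -> "<|system|>\nYou are a helpful assistant.</s>\n<|user|>\nHello!</s>\n<|assistant|>"
#   #    (plus the template's own whitespace)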
15 | 16 | import os 17 | from typing import Any, List, Literal, Optional 18 | 19 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk 20 | from datasets.builder import DatasetGenerationError 21 | 22 | from .configs import DataArguments 23 | 24 | 25 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 26 | 27 | 28 | def maybe_insert_system_message(messages, tokenizer): 29 | if messages[0]["role"] == "system": 30 | return 31 | 32 | # chat template can be one of two attributes, we check in order 33 | chat_template = tokenizer.chat_template 34 | if chat_template is None: 35 | chat_template = tokenizer.default_chat_template 36 | 37 | # confirm the jinja template refers to a system message before inserting 38 | if "system" in chat_template or "<|im_start|>" in chat_template: 39 | messages.insert(0, {"role": "system", "content": ""}) 40 | 41 | 42 | def apply_chat_template( 43 | example, 44 | tokenizer, 45 | task: Literal["sft", "generation", "rm", "dpo"], 46 | auto_insert_empty_system_msg: bool = True, 47 | ): 48 | if task in ["sft", "generation"]: 49 | messages = example["messages"] 50 | # We add an empty system message if there is none 51 | if auto_insert_empty_system_msg: 52 | maybe_insert_system_message(messages, tokenizer) 53 | example["text"] = tokenizer.apply_chat_template( 54 | messages, 55 | tokenize=False, 56 | add_generation_prompt=True if task == "generation" else False, 57 | ) 58 | elif task == "rm": 59 | if all(k in example.keys() for k in ("chosen", "rejected")): 60 | chosen_messages = example["chosen"] 61 | rejected_messages = example["rejected"] 62 | # We add an empty system message if there is none 63 | if auto_insert_empty_system_msg: 64 | maybe_insert_system_message(chosen_messages, tokenizer) 65 | maybe_insert_system_message(rejected_messages, tokenizer) 66 | 67 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 68 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 69 | else: 70 | raise ValueError( 71 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" 72 | ) 73 | elif task in ["dpo", "orpo"]: 74 | if all(k in example.keys() for k in ("chosen", "rejected")): 75 | if not is_openai_format(example["chosen"]) or not is_openai_format(example["rejected"]): 76 | raise ValueError( 77 | f"Could not format example as dialogue for `{task}` task! 
Require OpenAI format for all messages" 78 | ) 79 | 80 | # For DPO/ORPO, the inputs are triples of (prompt, chosen, rejected), where `chosen` and `rejected` are the final turn of a dialogue 81 | # We therefore need to extract the N-1 turns to form the prompt 82 | if "prompt" in example and is_openai_format(example["prompt"]): 83 | prompt_messages = example["prompt"] 84 | chosen_messages = example["chosen"] 85 | rejected_messages = example["rejected"] 86 | else: 87 | prompt_messages = example["chosen"][:-1] 88 | # Now we extract the final turn to define chosen/rejected responses 89 | chosen_messages = example["chosen"][-1:] 90 | rejected_messages = example["rejected"][-1:] 91 | 92 | # Prepend a system message if the first message is not a system message 93 | if auto_insert_empty_system_msg: 94 | maybe_insert_system_message(prompt_messages, tokenizer) 95 | 96 | example["text_prompt"] = tokenizer.apply_chat_template(prompt_messages, tokenize=False) 97 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 98 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 99 | else: 100 | raise ValueError( 101 | f"Could not format example as dialogue for `{task}` task! Require either the " 102 | f"`[chosen, rejected]` or `[prompt, chosen, rejected]` keys but found {list(example.keys())}" 103 | ) 104 | else: 105 | raise ValueError( 106 | f"Task {task} not supported, please ensure that the provided task is one of ['sft', 'generation', 'rm', 'dpo', 'orpo']" 107 | ) 108 | return example 109 | 110 | 111 | def is_openai_format(messages: Any) -> bool: 112 | """ 113 | Check if the input messages are in OpenAI format. 114 | Args: 115 | messages (`Any`): 116 | Messages to check. 117 | Returns: 118 | `bool`: Whether the messages are in OpenAI format. 119 | """ 120 | if isinstance(messages, list) and all(isinstance(message, dict) for message in messages): 121 | return all("role" in message and "content" in message for message in messages) 122 | return False 123 | 124 | 125 | def get_datasets( 126 | data_config,#: DataArguments | dict, 127 | splits: Optional[List[str]] = None, 128 | configs: Optional[List[str]] = None, 129 | columns_to_keep: Optional[List[str]] = None, 130 | shuffle: bool = True, 131 | ):# -> DatasetDict: 132 | """ 133 | Loads one or more datasets with varying training set proportions. 134 | 135 | Args: 136 | data_config (`DataArguments` or `dict`): 137 | Dataset configuration and split proportions. 138 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`): 139 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 140 | configs (Optional[List[str]], *optional*, defaults to `None`): 141 | List of dataset config names. If given must be the same length as 'data_config' keys. 142 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): 143 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, 144 | and for cpt this should be (at least) the text column. 145 | shuffle (`bool`, *optional*, defaults to `True`): 146 | Whether to shuffle the training and testing/validation data. 147 | 148 | Returns 149 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets. 
150 | """ 151 | if type(data_config) is DataArguments: 152 | # Structure of the config to read the datasets and their mix 153 | # datasets_mixer: 154 | # - 'dataset1': 0.5 155 | # - 'dataset2': 0.3 156 | # - 'dataset3': 0.2 157 | dataset_mixer = data_config.dataset_mixer 158 | elif isinstance(data_config, dict): 159 | # Structure of the input is: 160 | # dataset_mixer = { 161 | # "dataset1": 0.5, 162 | # "dataset1": 0.3, 163 | # "dataset1": 0.2, 164 | # } 165 | dataset_mixer = data_config 166 | else: 167 | raise ValueError(f"Data config {data_config} not recognized.") 168 | 169 | raw_datasets = mix_datasets( 170 | dataset_mixer, 171 | splits=splits, 172 | configs=configs, 173 | columns_to_keep=columns_to_keep, 174 | shuffle=shuffle, 175 | ) 176 | return raw_datasets 177 | 178 | 179 | def mix_datasets( 180 | dataset_mixer: dict, 181 | splits: Optional[List[str]] = None, 182 | configs: Optional[List[str]] = None, 183 | columns_to_keep: Optional[List[str]] = None, 184 | shuffle=True, 185 | ) -> DatasetDict: 186 | """ 187 | Loads and mixes datasets according to proportions specified in `dataset_mixer`. 188 | 189 | Args: 190 | dataset_mixer (`dict`): 191 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. 192 | splits (Optional[List[str]], *optional*, defaults to `None`): 193 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 194 | configs (Optional[List[str]], *optional*, defaults to `None`): 195 | List of dataset config names. If given must be the same length as 'dataset_mixer' keys. 196 | columns_to_keep (Optional[List[str]], *optional*, defaults to `None`): 197 | Column names to keep in the dataset. Useful in the datamixer to avoid schema conflicts, 198 | and for cpt this should be (at least) the text column. 199 | shuffle (`bool`, *optional*, defaults to `True`): 200 | Whether to shuffle the training and testing/validation data. 
201 | """ 202 | splits = ["train", "test"] if splits is None else splits 203 | configs = [None] * len(dataset_mixer) if not configs else configs 204 | columns_to_keep = [] if columns_to_keep is None else columns_to_keep 205 | 206 | if configs is not None and len(configs) != len(dataset_mixer): 207 | raise ValueError("The number of given dataset config names must be the same as the given number of datasets.") 208 | 209 | raw_datasets = DatasetDict() 210 | raw_train_datasets = [] 211 | raw_val_datasets = [] 212 | fracs = [] 213 | for (ds, frac), ds_config in zip(dataset_mixer.items(), configs): 214 | fracs.append(frac) 215 | for split in splits: 216 | try: 217 | # Try first if dataset on a Hub repo 218 | dataset = load_dataset(ds, ds_config, split=split) 219 | except DatasetGenerationError: 220 | # If not, check local dataset 221 | dataset = load_from_disk(os.path.join(ds, split)) 222 | 223 | # Remove redundant columns to avoid schema conflicts on load 224 | dataset = dataset.remove_columns([col for col in dataset.column_names if col not in columns_to_keep]) 225 | if "train" in split: 226 | raw_train_datasets.append(dataset) 227 | elif "test" in split: 228 | raw_val_datasets.append(dataset) 229 | else: 230 | raise ValueError(f"Split type {split} not recognized as one of test or train.") 231 | 232 | if any(frac < 0 for frac in fracs): 233 | raise ValueError("Dataset fractions cannot be negative.") 234 | 235 | if len(raw_train_datasets) > 0: 236 | train_subsets = [] 237 | for dataset, frac in zip(raw_train_datasets, fracs): 238 | train_subset = dataset.select(range(int(frac * len(dataset)))) 239 | train_subsets.append(train_subset) 240 | if shuffle: 241 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) 242 | else: 243 | raw_datasets["train"] = concatenate_datasets(train_subsets) 244 | # No subsampling for test datasets to enable fair comparison across models 245 | if len(raw_val_datasets) > 0: 246 | if shuffle: 247 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) 248 | else: 249 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets) 250 | 251 | if len(raw_datasets) == 0: 252 | raise ValueError( 253 | f"Dataset {dataset_mixer} not recognized with splits {splits}. Check the dataset has been correctly formatted." 254 | ) 255 | 256 | return raw_datasets 257 | -------------------------------------------------------------------------------- /src/alignment/decontaminate.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
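# Usage sketch (illustrative; this module only defines the filter): `decontaminate_humaneval`
# below returns a boolean keep-mask for a batch of samples, so it composes directly with
# `datasets.Dataset.filter` in batched mode, e.g.
#
#   ds = ds.filter(decontaminate_humaneval, batched=True, batch_size=10_000, num_proc=4)
#
# Samples whose normalized text contains a HumanEval docstring or canonical solution
# (see FILTER_OUT below) are dropped; everything else is kept.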
15 | 16 | from typing import Any, Dict, List 17 | 18 | from datasets import load_dataset 19 | 20 | 21 | # HumanEval solutions that are considered simple/generic enough to be kept in the training dataset 22 | HUMAN_EVAL_STRINGS_OK = ["return x + y", "return len(string)", "return n**2", "return " ".join(strings)"] 23 | 24 | 25 | def extract_docstring(prompt: str) -> str: 26 | if '"""' in prompt: 27 | if prompt.count('"""') == 2: 28 | return prompt.split('"""')[1].strip() 29 | elif prompt.count('"""') == 4: 30 | return prompt.split('"""')[3].strip() 31 | else: 32 | raise ValueError() 33 | elif "'''" in prompt: 34 | assert prompt.count("'''") == 2 35 | return prompt.split("'''")[1].strip() 36 | else: 37 | raise ValueError() 38 | 39 | 40 | def human_eval_docstrings() -> List[str]: 41 | # ds = load_dataset("openai_humaneval", split="test") 42 | # docstrings = [extract_docstring(v["prompt"]) for v in ds] 43 | # return docstrings 44 | return [] 45 | 46 | 47 | def load_dataset_column(dataset: str, column: str, split: str, name=None) -> List[str]: 48 | # ds = load_dataset(dataset, split=split, name=name) 49 | # res = [sample[column].strip() for sample in ds] 50 | # # Only return non-empty strings 51 | # return [sample for sample in res if len(sample) > 0] 52 | return [] 53 | 54 | 55 | FILTER_OUT = { 56 | "human_eval_docstrings": human_eval_docstrings(), 57 | "human_eval_solutions": [ 58 | s 59 | for s in load_dataset_column("openai_humaneval", "canonical_solution", "test") 60 | if s not in HUMAN_EVAL_STRINGS_OK 61 | ], 62 | } 63 | 64 | 65 | def normalize_whitespace(text: str) -> str: 66 | return " ".join(text.split()) 67 | 68 | 69 | def decontaminate_humaneval( 70 | samples: List[Dict[str, Any]], text_column: str = "text", filter_out: Dict[str, List[str]] = FILTER_OUT 71 | ) -> List[Dict[str, Any]]: 72 | """ 73 | filter_out: Dict[str, List[str]] mapping from benchmark name to list of strings that need to be 74 | filtered-out. 75 | Return a list where each element is True if the corresponding file should be included in the dataset. 76 | Otherwise, the element is False. 77 | """ 78 | output = [] 79 | 80 | for content in samples[text_column]: 81 | content = normalize_whitespace(content.lower()) 82 | matched = False 83 | for _, substrings in filter_out.items(): 84 | for substring in substrings: 85 | if normalize_whitespace(substring.lower()) in content: 86 | matched = True 87 | break 88 | if matched: 89 | break 90 | # we keep files that are not matched 91 | output.append(not matched) 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /src/alignment/model_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
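# Illustrative wiring (a sketch; these helpers do not load a model themselves): a typical
# caller assembles its `from_pretrained` kwargs from the utilities below, e.g.
#
#   quantization_config = get_quantization_config(model_args)
#   model = AutoModelForCausalLM.from_pretrained(
#       model_args.model_name_or_path,
#       revision=model_args.model_revision,
#       device_map=get_kbit_device_map() if quantization_config is not None else None,
#       quantization_config=quantization_config,
#   )
#
# and passes get_peft_config(model_args) to the trainer when LoRA fine-tuning is enabled.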
15 | import os 16 | from pathlib import Path 17 | from typing import Dict 18 | 19 | import torch 20 | from transformers import AutoTokenizer, BitsAndBytesConfig, PreTrainedTokenizer 21 | from transformers.trainer_utils import get_last_checkpoint 22 | 23 | from accelerate import Accelerator 24 | from huggingface_hub import list_repo_files 25 | from huggingface_hub.utils._errors import RepositoryNotFoundError 26 | from huggingface_hub.utils._validators import HFValidationError 27 | from peft import LoraConfig, PeftConfig 28 | 29 | from .configs import DataArguments, DPOConfig, ModelArguments, SFTConfig 30 | from .data import DEFAULT_CHAT_TEMPLATE 31 | 32 | 33 | def get_current_device():# -> int: 34 | """Get the current device. For GPU we return the local process index to enable multiple GPU training.""" 35 | return Accelerator().local_process_index if torch.cuda.is_available() else "cpu" 36 | 37 | 38 | def get_kbit_device_map():# -> Dict[str, int] | None: 39 | """Useful for running inference with quantized models by setting `device_map=get_peft_device_map()`""" 40 | return {"": get_current_device()} if torch.cuda.is_available() else None 41 | 42 | 43 | def get_quantization_config(model_args: ModelArguments):# -> BitsAndBytesConfig | None: 44 | if model_args.load_in_4bit: 45 | compute_dtype = torch.float16 46 | if model_args.torch_dtype not in {"auto", None}: 47 | compute_dtype = getattr(torch, model_args.torch_dtype) 48 | 49 | quantization_config = BitsAndBytesConfig( 50 | load_in_4bit=True, 51 | bnb_4bit_compute_dtype=compute_dtype, 52 | bnb_4bit_quant_type=model_args.bnb_4bit_quant_type, 53 | bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant, 54 | bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage, 55 | ) 56 | elif model_args.load_in_8bit: 57 | quantization_config = BitsAndBytesConfig( 58 | load_in_8bit=True, 59 | ) 60 | else: 61 | quantization_config = None 62 | 63 | return quantization_config 64 | 65 | 66 | def get_tokenizer( 67 | model_args: ModelArguments, data_args: DataArguments, auto_set_chat_template: bool = True 68 | ):# -> PreTrainedTokenizer: 69 | """Get the tokenizer for the model.""" 70 | tokenizer = AutoTokenizer.from_pretrained( 71 | model_args.model_name_or_path 72 | if model_args.tokenizer_name_or_path is None 73 | else model_args.tokenizer_name_or_path, 74 | revision=model_args.model_revision, 75 | trust_remote_code=model_args.trust_remote_code, 76 | ) 77 | if tokenizer.pad_token_id is None: 78 | tokenizer.pad_token_id = tokenizer.eos_token_id 79 | 80 | if data_args.truncation_side is not None: 81 | tokenizer.truncation_side = data_args.truncation_side 82 | 83 | # Set reasonable default for models without max length 84 | if tokenizer.model_max_length > 100_000: 85 | tokenizer.model_max_length = 2048 86 | 87 | if data_args.chat_template is not None: 88 | tokenizer.chat_template = data_args.chat_template 89 | elif auto_set_chat_template and tokenizer.chat_template is None and tokenizer.default_chat_template is None: 90 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 91 | 92 | return tokenizer 93 | 94 | 95 | def get_peft_config(model_args: ModelArguments) :# -> PeftConfig | None: 96 | if model_args.use_peft is False: 97 | return None 98 | 99 | peft_config = LoraConfig( 100 | r=model_args.lora_r, 101 | lora_alpha=model_args.lora_alpha, 102 | lora_dropout=model_args.lora_dropout, 103 | bias="none", 104 | task_type="CAUSAL_LM", 105 | target_modules=model_args.lora_target_modules, 106 | modules_to_save=model_args.lora_modules_to_save, 107 | ) 108 | 109 | return 
peft_config 110 | 111 | 112 | def is_adapter_model(model_name_or_path: str, revision: str = "main") :# -> bool: 113 | try: 114 | # Try first if model on a Hub repo 115 | repo_files = list_repo_files(model_name_or_path, revision=revision) 116 | except (HFValidationError, RepositoryNotFoundError): 117 | # If not, check local repo 118 | repo_files = os.listdir(model_name_or_path) 119 | return "adapter_model.safetensors" in repo_files or "adapter_model.bin" in repo_files 120 | 121 | 122 | def get_checkpoint(training_args) :# -> Path | None: 123 | last_checkpoint = None 124 | if os.path.isdir(training_args.output_dir): 125 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 126 | return last_checkpoint 127 | -------------------------------------------------------------------------------- /src/alignment/release.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import argparse 17 | import re 18 | 19 | import packaging.version 20 | 21 | 22 | REPLACE_PATTERNS = { 23 | "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), 24 | "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), 25 | } 26 | REPLACE_FILES = { 27 | "init": "src/alignment/__init__.py", 28 | "setup": "setup.py", 29 | } 30 | README_FILE = "README.md" 31 | 32 | 33 | def update_version_in_file(fname, version, pattern): 34 | """Update the version in one file using a specific pattern.""" 35 | with open(fname, "r", encoding="utf-8", newline="\n") as f: 36 | code = f.read() 37 | re_pattern, replace = REPLACE_PATTERNS[pattern] 38 | replace = replace.replace("VERSION", version) 39 | code = re_pattern.sub(replace, code) 40 | with open(fname, "w", encoding="utf-8", newline="\n") as f: 41 | f.write(code) 42 | 43 | 44 | def global_version_update(version, patch=False): 45 | """Update the version in all needed files.""" 46 | for pattern, fname in REPLACE_FILES.items(): 47 | update_version_in_file(fname, version, pattern) 48 | 49 | 50 | def get_version(): 51 | """Reads the current version in the __init__.""" 52 | with open(REPLACE_FILES["init"], "r") as f: 53 | code = f.read() 54 | default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] 55 | return packaging.version.parse(default_version) 56 | 57 | 58 | def pre_release_work(patch=False): 59 | """Do all the necessary pre-release steps.""" 60 | # First let's get the default version: base version if we are in dev, bump minor otherwise. 
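    # e.g. 1.2.0.dev0 -> 1.2.0 when on a dev release, 1.2.0 -> 1.2.1 for a patch, 1.2.0 -> 1.3.0 otherwise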
61 | default_version = get_version() 62 | if patch and default_version.is_devrelease: 63 | raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") 64 | if default_version.is_devrelease: 65 | default_version = default_version.base_version 66 | elif patch: 67 | default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" 68 | else: 69 | default_version = f"{default_version.major}.{default_version.minor + 1}.0" 70 | 71 | # Now let's ask nicely if that's the right one. 72 | version = input(f"Which version are you releasing? [{default_version}]") 73 | if len(version) == 0: 74 | version = default_version 75 | 76 | print(f"Updating version to {version}.") 77 | global_version_update(version, patch=patch) 78 | 79 | 80 | def post_release_work(): 81 | """Do all the necessary post-release steps.""" 82 | # First let's get the current version 83 | current_version = get_version() 84 | dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" 85 | current_version = current_version.base_version 86 | 87 | # Check with the user we got that right. 88 | version = input(f"Which version are we developing now? [{dev_version}]") 89 | if len(version) == 0: 90 | version = dev_version 91 | 92 | print(f"Updating version to {version}.") 93 | global_version_update(version) 94 | 95 | 96 | if __name__ == "__main__": 97 | parser = argparse.ArgumentParser() 98 | parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") 99 | parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") 100 | args = parser.parse_args() 101 | if not args.post_release: 102 | pre_release_work(patch=args.patch) 103 | elif args.patch: 104 | print("Nothing to do after a patch :-)") 105 | else: 106 | post_release_work() 107 | -------------------------------------------------------------------------------- /src/data/__init__.py: -------------------------------------------------------------------------------- 1 | from src.data.data import * -------------------------------------------------------------------------------- /src/data/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/data.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/data.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/data/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /src/data/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | from typing import List, Literal, Optional 4 | 5 | import pandas as pd 6 | from datasets import DatasetDict, concatenate_datasets, load_dataset, load_from_disk, Dataset 7 | from datasets.builder 
import DatasetGenerationError 8 | from src.alignment import DataArguments 9 | 10 | from src.data.utils import markdown_to_text 11 | 12 | 13 | def apply_chat_template( 14 | example, tokenizer, task: Literal["sft", "generation", "rm", "dpo"] = "sft", assistant_prefix="<|assistant|>\n" 15 | ): 16 | def _strip_prefix(s, pattern): 17 | # Use re.escape to escape any special characters in the pattern 18 | return re.sub(f"^{re.escape(pattern)}", "", s) 19 | 20 | # if task in ["sft", "generation"]: 21 | if task == "sft": 22 | messages = example["conversations"] 23 | # We add an empty system message if there is none 24 | if messages[0]["from"] != "system": 25 | messages.insert(0, {"from": "system", "value": ""}) 26 | example["text"] = tokenizer.apply_chat_template( 27 | messages, tokenize=False, add_generation_prompt=True if task == "generation" else False 28 | ) 29 | elif task == "generation": 30 | pass 31 | # if "markdown" in example: 32 | # example["text"] = markdown_to_text(example["markdown"]) 33 | # else: 34 | # pass 35 | elif task == "rm": 36 | if all(k in example.keys() for k in ("chosen", "rejected")): 37 | chosen_messages = example["chosen"] 38 | rejected_messages = example["rejected"] 39 | # We add an empty system message if there is none 40 | if chosen_messages[0]["role"] != "system": 41 | chosen_messages.insert(0, {"role": "system", "content": ""}) 42 | if rejected_messages[0]["role"] != "system": 43 | rejected_messages.insert(0, {"role": "system", "content": ""}) 44 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 45 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 46 | else: 47 | raise ValueError( 48 | f"Could not format example as dialogue for `rm` task! Require `[chosen, rejected]` keys but found {list(example.keys())}" 49 | ) 50 | elif task == "dpo": 51 | if all(k in example.keys() for k in ("chosen", "rejected")): 52 | # Compared to reward modeling, we filter out the prompt, so the text is everything after the last assistant token 53 | prompt_messages = [[msg for msg in example["chosen"] if msg["role"] == "user"][0]] 54 | # Insert system message 55 | if example["chosen"][0]["role"] != "system": 56 | prompt_messages.insert(0, {"role": "system", "content": ""}) 57 | else: 58 | prompt_messages.insert(0, example["chosen"][0]) 59 | # TODO: handle case where chosen/rejected also have system messages 60 | chosen_messages = example["chosen"][1:] 61 | rejected_messages = example["rejected"][1:] 62 | example["text_chosen"] = tokenizer.apply_chat_template(chosen_messages, tokenize=False) 63 | example["text_rejected"] = tokenizer.apply_chat_template(rejected_messages, tokenize=False) 64 | example["text_prompt"] = tokenizer.apply_chat_template( 65 | prompt_messages, tokenize=False, add_generation_prompt=True 66 | ) 67 | example["text_chosen"] = _strip_prefix(example["text_chosen"], assistant_prefix) 68 | example["text_rejected"] = _strip_prefix(example["text_rejected"], assistant_prefix) 69 | else: 70 | raise ValueError( 71 | f"Could not format example as dialogue for `dpo` task! 
Require `[chosen, rejected]` keys but found {list(example.keys())}" 72 | ) 73 | else: 74 | raise ValueError( 75 | f"Task {task} not supported, please ensure that the provided task is one of {['sft', 'generation', 'rm', 'dpo']}" 76 | ) 77 | return example 78 | 79 | 80 | def get_datasets( 81 | data_config: DataArguments | dict, 82 | splits: List[str] = ["train", "test"], 83 | shuffle: bool = True, 84 | ) -> DatasetDict: 85 | """ 86 | Loads one or more datasets with varying training set proportions. 87 | 88 | Args: 89 | data_config (`DataArguments` or `dict`): 90 | Dataset configuration and split proportions. 91 | splits (`List[str]`, *optional*, defaults to `['train', 'test']`): 92 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 93 | shuffle (`bool`, *optional*, defaults to `True`): 94 | Whether to shuffle the training and testing/validation data. 95 | 96 | Returns 97 | [`DatasetDict`]: The dataset dictionary containing the loaded datasets. 98 | """ 99 | 100 | if type(data_config) is DataArguments: 101 | # Structure of the config to read the datasets and their mix 102 | # datasets_mixer: 103 | # - 'dataset1': 0.5 104 | # - 'dataset2': 0.3 105 | # - 'dataset3': 0.2 106 | dataset_mixer = data_config.dataset_mixer 107 | elif type(data_config) is dict: 108 | # Structure of the input is: 109 | # dataset_mixer = { 110 | # "dataset1": 0.5, 111 | # "dataset1": 0.3, 112 | # "dataset1": 0.2, 113 | # } 114 | dataset_mixer = data_config 115 | else: 116 | raise ValueError(f"Data config {data_config} not recognized.") 117 | 118 | raw_datasets = mix_datasets(dataset_mixer, splits=splits, shuffle=shuffle) 119 | return raw_datasets 120 | 121 | 122 | def mix_datasets(dataset_mixer: dict, splits: Optional[List[str]] = None, shuffle=True) -> DatasetDict: 123 | """ 124 | Loads and mixes datasets according to proportions specified in `dataset_mixer`. 125 | 126 | Args: 127 | dataset_mixer (`dict`): 128 | Dictionary containing the dataset names and their training proportions. By default, all test proportions are 1. 129 | splits (Optional[List[str]], *optional*, defaults to `None`): 130 | Dataset splits to load and mix. Assumes the splits exist in all datasets and have a `train_` or `test_` prefix. 131 | shuffle (`bool`, *optional*, defaults to `True`): 132 | Whether to shuffle the training and testing/validation data. 
133 | """ 134 | raw_datasets = DatasetDict() 135 | raw_train_datasets = [] 136 | raw_val_datasets = [] 137 | fracs = [] 138 | for ds, frac in dataset_mixer.items(): 139 | fracs.append(frac) 140 | for split in splits: 141 | try: 142 | # Try first if dataset on a Hub repo 143 | dataset = load_dataset(ds, split=split) 144 | # except DatasetGenerationError: 145 | # # If not, check local dataset 146 | # # dataset = load_from_disk(os.path.join(ds, split)) 147 | except Exception as e: 148 | dataset = Dataset.from_pandas(pd.read_feather(ds)) 149 | 150 | if "train" in split: 151 | raw_train_datasets.append(dataset) 152 | elif "test" in split: 153 | raw_val_datasets.append(dataset) 154 | else: 155 | raise ValueError(f"Split type {split} not recognized as one of test or train.") 156 | 157 | if any(frac < 0 for frac in fracs): 158 | raise ValueError("Dataset fractions cannot be negative.") 159 | 160 | if len(raw_train_datasets) > 0: 161 | train_subsets = [] 162 | for dataset, frac in zip(raw_train_datasets, fracs): 163 | train_subset = dataset.select(range(int(frac * len(dataset)))) 164 | train_subsets.append(train_subset) 165 | if shuffle: 166 | raw_datasets["train"] = concatenate_datasets(train_subsets).shuffle(seed=42) 167 | else: 168 | raw_datasets["train"] = concatenate_datasets(train_subsets) 169 | # No subsampling for test datasets to enable fair comparison across models 170 | if len(raw_val_datasets) > 0: 171 | if shuffle: 172 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets).shuffle(seed=42) 173 | else: 174 | raw_datasets["test"] = concatenate_datasets(raw_val_datasets) 175 | 176 | if len(raw_datasets) == 0: 177 | raise ValueError( 178 | f"Dataset {dataset_mixer} not recognized with split {split}. Check the dataset has been correctly formatted." 179 | ) 180 | 181 | return raw_datasets -------------------------------------------------------------------------------- /src/data/utils.py: -------------------------------------------------------------------------------- 1 | from bs4 import BeautifulSoup 2 | from markdown import markdown 3 | import re 4 | 5 | def markdown_to_text(markdown_string): 6 | """ Converts a markdown string to plaintext """ 7 | 8 | # md -> html -> text since BeautifulSoup can extract text cleanly 9 | html = markdown(markdown_string) 10 | 11 | # remove code snippets 12 | html = re.sub(r'
<pre>(.*?)</pre>
', ' ', html) 13 | html = re.sub(r'(.*?)', ' ', html) 14 | 15 | # extract text 16 | soup = BeautifulSoup(html, "html.parser") 17 | text = ''.join(soup.findAll(text=True)) 18 | 19 | return text -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from src.model.tokenizer import * -------------------------------------------------------------------------------- /src/model/__pycache__/__init__.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/model/__pycache__/__init__.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/__pycache__/tokenizer.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FSoft-AI4Code/XMainframe/801033385e0457667aff301dc07df1c1b8ca4b04/src/model/__pycache__/tokenizer.cpython-310.pyc -------------------------------------------------------------------------------- /src/model/tokenizer.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, PreTrainedTokenizer 2 | from src.alignment.configs import DataArguments, ModelArguments 3 | 4 | 5 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 6 | 7 | def get_tokenizer(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 8 | """Get the tokenizer for the model.""" 9 | tokenizer = AutoTokenizer.from_pretrained( 10 | model_args.model_name_or_path, 11 | revision=model_args.model_revision, 12 | trust_remote_code=True 13 | ) 14 | if tokenizer.pad_token_id is None: 15 | tokenizer.pad_token_id = tokenizer.eos_token_id 16 | 17 | if data_args.truncation_side is not None: 18 | tokenizer.truncation_side = data_args.truncation_side 19 | 20 | # Set reasonable default for models without max length 21 | # if tokenizer.model_max_length > 100_000: 22 | # tokenizer.model_max_length = 4096 23 | tokenizer.model_max_length = 4096 24 | 25 | if data_args.chat_template is not None: 26 | tokenizer.chat_template = data_args.chat_template 27 | elif tokenizer.chat_template is None: 28 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 29 | 30 | return tokenizer 31 | 32 | 33 | def get_tokenizer_phi2(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 34 | """Get the tokenizer for the model.""" 35 | tokenizer = AutoTokenizer.from_pretrained( 36 | model_args.model_name_or_path, 37 | revision=model_args.model_revision, 38 | use_fast=False 39 | ) 40 | tokenizer.add_tokens(["<|im_start|>", ""]) 41 | tokenizer.pad_token = "" 42 | tokenizer.add_special_tokens(dict(eos_token="<|im_end|>")) 43 | 44 | tokenizer.model_max_length = 2048 45 | 46 | 47 | return tokenizer 48 | 49 | def get_tokenizer_qwen15(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 50 | """Get the tokenizer for the model.""" 51 | tokenizer = 
AutoTokenizer.from_pretrained( 52 | model_args.model_name_or_path, 53 | revision=model_args.model_revision, 54 | model_max_length=8192, 55 | padding_side="right", 56 | use_fast=False, 57 | trust_remote_code=True 58 | ) 59 | # tokenizer.pad_token_id = tokenizer.eod_id 60 | 61 | # tokenizer.model_max_length = 8192 62 | 63 | return tokenizer 64 | 65 | 66 | def get_tokenizer_code_llama(model_args: ModelArguments, data_args: DataArguments) -> PreTrainedTokenizer: 67 | """Get the tokenizer for the model.""" 68 | tokenizer = AutoTokenizer.from_pretrained( 69 | model_args.model_name_or_path, 70 | revision=model_args.model_revision 71 | ) 72 | tokenizer.add_eos_token = True 73 | tokenizer.pad_token_id = 0 74 | tokenizer.padding_side = "left" 75 | 76 | return tokenizer -------------------------------------------------------------------------------- /train_instruct.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Supervised fine-tuning script for decoder language models. 18 | """ 19 | 20 | import logging 21 | import random 22 | import sys 23 | import datetime 24 | 25 | import pandas as pd 26 | import datasets 27 | import torch 28 | import transformers 29 | from transformers import set_seed, DataCollatorForLanguageModeling,AutoModelForCausalLM 30 | from datasets import load_dataset, Dataset, concatenate_datasets 31 | 32 | from accelerate import Accelerator 33 | from src.alignment import ( 34 | DataArguments, 35 | H4ArgumentParser, 36 | ModelArguments, 37 | SFTConfig, 38 | # apply_chat_template, 39 | # get_datasets, 40 | get_checkpoint, 41 | get_kbit_device_map, 42 | get_peft_config, 43 | get_quantization_config, 44 | # get_tokenizer, 45 | ) 46 | # from trl import SFTConfig 47 | from trl import SFTTrainer, setup_chat_format 48 | from src.data import get_datasets, apply_chat_template 49 | from src.model import get_tokenizer 50 | 51 | 52 | logger = logging.getLogger(__name__) 53 | 54 | 55 | def main(): 56 | parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) 57 | model_args, data_args, training_args = parser.parse() 58 | 59 | # Set seed for reproducibility 60 | set_seed(training_args.seed) 61 | 62 | # accelerator = Accelerator() 63 | 64 | ############### 65 | # Setup logging 66 | ############### 67 | logging.basicConfig( 68 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 69 | datefmt="%Y-%m-%d %H:%M:%S", 70 | handlers=[logging.StreamHandler(sys.stdout)], 71 | ) 72 | log_level = training_args.get_process_log_level() 73 | logger.setLevel(log_level) 74 | datasets.utils.logging.set_verbosity(log_level) 75 | transformers.utils.logging.set_verbosity(log_level) 76 | transformers.utils.logging.enable_default_handler() 77 | transformers.utils.logging.enable_explicit_format() 78 | 79 | # Log on each process a small summary 80 | 
logger.warning( 81 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 82 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 83 | ) 84 | logger.info(f"Model parameters {model_args}") 85 | logger.info(f"Data parameters {data_args}") 86 | logger.info(f"Training/evaluation parameters {training_args}") 87 | 88 | # Check for last checkpoint 89 | last_checkpoint = get_checkpoint(training_args) 90 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 91 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") 92 | 93 | ############### 94 | # Load datasets 95 | ############### 96 | 97 | mainframe_instruct_ds = Dataset.from_pandas( 98 | pd.read_feather( 99 | ".data/mainframegpt_instructions_vananh_20240209.feather" 100 | ) 101 | ) 102 | mainframe_oss_instruct_ds = Dataset.from_pandas( 103 | pd.read_feather( 104 | "./data/mainframegpt_oss_instruct.feather" 105 | ) 106 | ) 107 | mainframe_self_instruct_ds = Dataset.from_pandas( 108 | pd.read_feather( 109 | "./data/mainframegpt_self_instruct_2151.feather" 110 | ) 111 | ) 112 | slim_orca_ds = load_dataset("Open-Orca/SlimOrca-Dedup", cache_dir="./hf_cache/datasets") 113 | 114 | raw_datasets = concatenate_datasets([mainframe_instruct_ds, slim_orca_ds["train"]]).shuffle(seed=42) 115 | raw_datasets = raw_datasets.train_test_split(test_size=0.1) 116 | logger.info( 117 | f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" 118 | ) 119 | column_names = list(raw_datasets["train"].features) 120 | 121 | ################ 122 | # Load tokenizer 123 | ################ 124 | tokenizer = get_tokenizer(model_args, data_args) 125 | 126 | ##################### 127 | # Apply chat template 128 | ##################### 129 | raw_datasets = raw_datasets.map( 130 | apply_chat_template, 131 | fn_kwargs={"tokenizer": tokenizer, "task": "sft"}, 132 | num_proc=data_args.preprocessing_num_workers, 133 | remove_columns=column_names, 134 | desc="Applying chat template", 135 | ) 136 | train_dataset = raw_datasets["train"] 137 | eval_dataset = raw_datasets["test"] 138 | 139 | data_collator = DataCollatorForLanguageModeling( 140 | tokenizer=tokenizer, 141 | mlm=False, 142 | return_tensors='pt' 143 | ) 144 | 145 | ####################### 146 | # Load pretrained model 147 | ####################### 148 | logger.info("*** Load pretrained model ***") 149 | torch_dtype =( 150 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 151 | ) 152 | quantization_config = get_quantization_config(model_args) 153 | 154 | model_kwargs = dict( 155 | revision=model_args.model_revision, 156 | trust_remote_code=model_args.trust_remote_code, 157 | use_flash_attention_2=model_args.use_flash_attention_2, 158 | torch_dtype=torch_dtype, 159 | use_cache=False if training_args.gradient_checkpointing else True, 160 | device_map=get_kbit_device_map() if quantization_config is not None else None, 161 | quantization_config=quantization_config, 162 | ) 163 | print(model_kwargs) 164 | print(type(torch_dtype)) 165 | 166 | model = model_args.model_name_or_path 167 | # For ChatML we need to add special tokens and resize the embedding layer 168 | if "<|im_start|>" in tokenizer.chat_template and "gemma-tokenizer-chatml" not in tokenizer.name_or_path: 169 | model = 
AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) 170 | model, tokenizer = setup_chat_format(model, tokenizer) 171 | model_kwargs = None 172 | 173 | logger.info("*** Model loaded! ***") 174 | 175 | ######################## 176 | # Initialize the Trainer 177 | ######################## 178 | trainer = SFTTrainer( 179 | model=model, 180 | model_init_kwargs=model_kwargs, 181 | args=training_args, 182 | train_dataset=train_dataset, 183 | eval_dataset=eval_dataset, 184 | dataset_text_field="text", 185 | max_seq_length=training_args.max_seq_length, 186 | tokenizer=tokenizer, 187 | packing=True, 188 | peft_config=get_peft_config(model_args), 189 | neftune_noise_alpha=5, 190 | data_collator=data_collator, 191 | ) 192 | 193 | ############### 194 | # Training loop 195 | ############### 196 | logger.info("*** Train ***") 197 | checkpoint = None 198 | if training_args.resume_from_checkpoint is not None: 199 | checkpoint = training_args.resume_from_checkpoint 200 | elif last_checkpoint is not None: 201 | checkpoint = last_checkpoint 202 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 203 | metrics = train_result.metrics 204 | metrics["train_samples"] = len(train_dataset) 205 | trainer.log_metrics("train", metrics) 206 | trainer.save_metrics("train", metrics) 207 | trainer.save_state() 208 | 209 | ########## 210 | # Evaluate 211 | ########## 212 | if training_args.do_eval: 213 | logger.info("*** Evaluate ***") 214 | metrics = trainer.evaluate() 215 | metrics["eval_samples"] = len(eval_dataset) 216 | trainer.log_metrics("eval", metrics) 217 | trainer.save_metrics("eval", metrics) 218 | 219 | ################################## 220 | # Save model and create model card 221 | ################################## 222 | logger.info("*** Save model ***") 223 | trainer.save_model(training_args.output_dir) 224 | logger.info(f"Model saved to {training_args.output_dir}") 225 | 226 | kwargs = { 227 | "finetuned_from": model_args.model_name_or_path, 228 | "dataset": list(data_args.dataset_mixer.keys()), 229 | "dataset_tags": list(data_args.dataset_mixer.keys()), 230 | "tags": ["alignment-handbook"], 231 | } 232 | if trainer.accelerator.is_main_process: 233 | trainer.create_model_card(**kwargs) 234 | # Restore k,v cache for fast inference 235 | trainer.model.config.use_cache = True 236 | trainer.model.config.save_pretrained(training_args.output_dir) 237 | 238 | if training_args.push_to_hub is True: 239 | logger.info("Pushing to hub...") 240 | trainer.push_to_hub(**kwargs) 241 | 242 | logger.info("*** Training complete ***") 243 | 244 | if __name__ == "__main__": 245 | main() 246 | -------------------------------------------------------------------------------- /train_raw.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2023 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
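# Typical launch (a sketch; <accelerate_config>.yaml and <recipe>.yaml are placeholders):
#
#   ACCELERATE_LOG_LEVEL=info accelerate launch \
#       --config_file <accelerate_config>.yaml train_raw.py <recipe>.yaml [--key value ...]
#
# H4ArgumentParser.parse() (src/alignment/configs.py) parses the recipe YAML when the first
# argument ends in ".yaml", merges any trailing `--key value` pairs via parse_yaml_and_args,
# and otherwise falls back to plain command-line parsing.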
16 | """ 17 | Supervised fine-tuning script for decoder language models. 18 | """ 19 | 20 | import logging 21 | import random 22 | import sys 23 | import datetime 24 | 25 | import pandas as pd 26 | import datasets 27 | import torch 28 | import transformers 29 | from transformers import set_seed, DataCollatorForLanguageModeling,AutoModelForCausalLM 30 | from datasets import load_dataset, Dataset, concatenate_datasets 31 | 32 | # import sys 33 | # sys.path.append('./alignment-handbook/src') 34 | from accelerate import Accelerator 35 | from src.alignment import ( 36 | DataArguments, 37 | H4ArgumentParser, 38 | ModelArguments, 39 | SFTConfig, 40 | # apply_chat_template, 41 | # get_datasets, 42 | get_checkpoint, 43 | get_kbit_device_map, 44 | get_peft_config, 45 | get_quantization_config, 46 | # get_tokenizer, 47 | ) 48 | from trl import SFTTrainer,setup_chat_format 49 | from src.data import get_datasets, apply_chat_template 50 | from src.model.tokenizer import get_tokenizer, get_tokenizer_phi2, get_tokenizer_qwen15, get_tokenizer_code_llama 51 | 52 | # import deepspeed 53 | # from deepspeed.ops.op_builder import builder 54 | 55 | logger = logging.getLogger(__name__) 56 | 57 | 58 | def main(): 59 | parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig)) 60 | model_args, data_args, training_args = parser.parse() 61 | 62 | # Set seed for reproducibility 63 | set_seed(training_args.seed) 64 | 65 | # accelerator = Accelerator() 66 | 67 | ############### 68 | # Setup logging 69 | ############### 70 | logging.basicConfig( 71 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 72 | datefmt="%Y-%m-%d %H:%M:%S", 73 | handlers=[logging.StreamHandler(sys.stdout)], 74 | ) 75 | log_level = training_args.get_process_log_level() 76 | logger.setLevel(log_level) 77 | datasets.utils.logging.set_verbosity(log_level) 78 | transformers.utils.logging.set_verbosity(log_level) 79 | transformers.utils.logging.enable_default_handler() 80 | transformers.utils.logging.enable_explicit_format() 81 | 82 | # Log on each process a small summary 83 | logger.warning( 84 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 85 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 86 | ) 87 | logger.info(f"Model parameters {model_args}") 88 | logger.info(f"Data parameters {data_args}") 89 | logger.info(f"Training/evaluation parameters {training_args}") 90 | 91 | # Check for last checkpoint 92 | last_checkpoint = get_checkpoint(training_args) 93 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 94 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") 95 | 96 | ############### 97 | # Load datasets 98 | ############### 99 | # raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits) 100 | mainframe_ds = Dataset.from_pandas(pd.read_feather("./data/mainframe_df_v1.1_chunks_4096.feather")) 101 | textbook_ds = Dataset.from_pandas(pd.read_feather("./data/textbook_quality_programming.feather")) 102 | longcontext = Dataset.from_pandas(pd.read_feather(".data/long_context_data/long_data.feather")) 103 | raw_datasets = longcontext.train_test_split(test_size=0.1) 104 | logger.info( 105 | f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}" 106 | ) 107 | column_names = list(raw_datasets["train"].features) 108 | 109 | ################ 110 | # Load 
tokenizer 111 | ################ 112 | tokenizer = get_tokenizer_code_llama(model_args, data_args) 113 | 114 | 115 | train_dataset = raw_datasets["train"] 116 | eval_dataset = raw_datasets["test"] 117 | 118 | with training_args.main_process_first(desc="Log a few random samples from the processed training set"): 119 | for index in random.sample(range(len(raw_datasets["train"])), 3): 120 | logger.info(f"Sample {index} of the processed training set:\n\n{raw_datasets['train'][index]['text']}") 121 | 122 | data_collator = DataCollatorForLanguageModeling( 123 | tokenizer=tokenizer, 124 | mlm=False, 125 | return_tensors='pt' 126 | ) 127 | 128 | 129 | ####################### 130 | # Load pretrained model 131 | ####################### 132 | logger.info("*** Load pretrained model ***") 133 | torch_dtype = ( 134 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 135 | ) 136 | # torch_dtype = 'float16' 137 | quantization_config = get_quantization_config(model_args) 138 | 139 | model_kwargs = dict( 140 | revision=model_args.model_revision, 141 | trust_remote_code=model_args.trust_remote_code, 142 | use_flash_attention_2=model_args.use_flash_attention_2, 143 | torch_dtype=torch_dtype, 144 | use_cache=False if training_args.gradient_checkpointing else True, 145 | # device_map=get_kbit_device_map() if quantization_config is not None else None, 146 | quantization_config=quantization_config, 147 | ) 148 | 149 | logger.info("*** Model loaded! ***") 150 | 151 | ######################## 152 | # Initialize the Trainer 153 | ######################## 154 | trainer = SFTTrainer( 155 | model=model_args.model_name_or_path, 156 | model_init_kwargs=model_kwargs, 157 | args=training_args, 158 | train_dataset=train_dataset, 159 | eval_dataset=eval_dataset, 160 | dataset_text_field="text", 161 | max_seq_length=training_args.max_seq_length, 162 | tokenizer=tokenizer, 163 | packing=True, 164 | peft_config=get_peft_config(model_args), 165 | neftune_noise_alpha=5, 166 | data_collator=data_collator, 167 | ) 168 | 169 | ############### 170 | # Training loop 171 | ############### 172 | logger.info("*** Train ***") 173 | # train_result = trainer.train() 174 | checkpoint = None 175 | if training_args.resume_from_checkpoint is not None: 176 | checkpoint = training_args.resume_from_checkpoint 177 | elif last_checkpoint is not None: 178 | checkpoint = last_checkpoint 179 | 180 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 181 | metrics = train_result.metrics 182 | metrics["train_samples"] = len(train_dataset) 183 | trainer.log_metrics("train", metrics) 184 | trainer.save_metrics("train", metrics) 185 | trainer.save_state() 186 | 187 | ########## 188 | # Evaluate 189 | ########## 190 | if training_args.do_eval: 191 | logger.info("*** Evaluate ***") 192 | metrics = trainer.evaluate() 193 | metrics["eval_samples"] = len(eval_dataset) 194 | trainer.log_metrics("eval", metrics) 195 | trainer.save_metrics("eval", metrics) 196 | 197 | ################################## 198 | # Save model and create model card 199 | ################################## 200 | logger.info("*** Save model ***") 201 | trainer.save_model(training_args.output_dir) 202 | logger.info(f"Model saved to {training_args.output_dir}") 203 | 204 | kwargs = { 205 | "finetuned_from": model_args.model_name_or_path, 206 | "dataset": list(data_args.dataset_mixer.keys()), 207 | "dataset_tags": list(data_args.dataset_mixer.keys()), 208 | "tags": ["alignment-handbook"], 209 | } 210 | # if 
trainer.accelerator.is_main_process: 211 | trainer.create_model_card(**kwargs) 212 | # Restore k,v cache for fast inference 213 | trainer.model.config.use_cache = True 214 | trainer.model.config.save_pretrained(training_args.output_dir) 215 | 216 | if training_args.push_to_hub is True: 217 | logger.info("Pushing to hub...") 218 | trainer.push_to_hub(**kwargs) 219 | 220 | logger.info("*** Training complete ***") 221 | 222 | if __name__ == "__main__": 223 | main() 224 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import torch 4 | import multiprocessing 5 | import torch.distributed as dist 6 | 7 | 8 | def setup_for_distributed(is_master): 9 | """ 10 | This function disables printing when not in master process 11 | """ 12 | import builtins as __builtin__ 13 | builtin_print = __builtin__.print 14 | 15 | def print(*args, **kwargs): 16 | force = kwargs.pop('force', False) 17 | if is_master or force: 18 | builtin_print(*args, **kwargs) 19 | 20 | __builtin__.print = print 21 | 22 | 23 | def get_world_size(): 24 | if not is_dist_avail_and_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | def is_dist_avail_and_initialized(): 29 | if not dist.is_available(): 30 | return False 31 | if not dist.is_initialized(): 32 | return False 33 | return True 34 | 35 | def get_rank(): 36 | if not is_dist_avail_and_initialized(): 37 | return 0 38 | return dist.get_rank() 39 | 40 | def init_distributed_mode(args): 41 | cpu_cont = multiprocessing.cpu_count() 42 | # if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 43 | # args.rank = int(os.environ["RANK"]) 44 | # args.world_size = int(os.environ['WORLD_SIZE']) 45 | # args.gpu = int(os.environ['LOCAL_RANK']) 46 | # elif 'SLURM_PROCID' in os.environ: 47 | # args.rank = int(os.environ['SLURM_PROCID']) 48 | # args.gpu = args.rank % torch.cuda.device_count() 49 | # else: 50 | # print('Not using distributed mode') 51 | # args.distributed = False 52 | # return 53 | 54 | args.distributed = True 55 | args.rank = get_rank() 56 | # args.world_size = get_world_size() 57 | args.gpu = args.rank % torch.cuda.device_count() 58 | 59 | torch.cuda.set_device(args.gpu) 60 | args.dist_backend = 'nccl' 61 | print('| distributed init (rank {}, word {}): {}'.format( 62 | args.rank, args.world_size, args.dist_url), flush=True) 63 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 64 | world_size=args.world_size, rank=args.rank) 65 | torch.distributed.barrier() 66 | device = torch.device("cuda", args.gpu) 67 | args.n_gpu = torch.cuda.device_count() 68 | args.device = device 69 | args.cpu_cont = cpu_cont 70 | setup_for_distributed(args.rank == 0) 71 | 72 | --------------------------------------------------------------------------------
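# Minimal sketch of driving utils.init_distributed_mode (field names taken from the function
# body above; values are placeholders for a single-node, 2-GPU run launched with torchrun,
# which provides the RANK/WORLD_SIZE environment variables that "env://" expects):
#
#   import argparse
#   import utils
#
#   args = argparse.Namespace(world_size=2, dist_url="env://")
#   utils.init_distributed_mode(args)   # fills in args.rank, args.gpu, args.device, args.n_gpu,
#                                       # initialises the NCCL process group and silences print()
#                                       # on non-master ranks via setup_for_distributed()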