├── CoachLM150 Test Set.json ├── Expert Revision Dataset.json ├── LICENSE ├── README.md ├── Technical Report.pdf ├── asset ├── .gitkeep ├── coachLM.png ├── dataset_score.png ├── illustration.png └── win_rates.png ├── data ├── .ipynb_checkpoints │ ├── dataset_info-checkpoint.json │ └── example-checkpoint.json ├── dataset_info.json └── example.json ├── requirements.txt └── src ├── glmtuner ├── __init__.py ├── api │ ├── __init__.py │ ├── app.py │ └── protocol.py ├── chat │ ├── __init__.py │ └── stream_chat.py ├── dsets │ ├── __init__.py │ ├── collator.py │ ├── loader.py │ ├── preprocess.py │ └── utils.py ├── extras │ ├── __init__.py │ ├── callbacks.py │ ├── constants.py │ ├── logging.py │ ├── misc.py │ ├── ploting.py │ └── save_and_load.py ├── hparams │ ├── __init__.py │ ├── data_args.py │ ├── finetuning_args.py │ ├── general_args.py │ ├── generating_args.py │ └── model_args.py ├── tuner │ ├── __init__.py │ ├── core │ │ ├── __init__.py │ │ ├── adapter.py │ │ ├── loader.py │ │ ├── parser.py │ │ └── trainer.py │ ├── ppo │ │ ├── __init__.py │ │ ├── trainer.py │ │ ├── utils.py │ │ └── workflow.py │ ├── rm │ │ ├── __init__.py │ │ ├── collator.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py │ └── sft │ │ ├── __init__.py │ │ ├── metric.py │ │ ├── trainer.py │ │ └── workflow.py └── webui │ ├── __init__.py │ ├── chat.py │ ├── common.py │ ├── components │ ├── __init__.py │ ├── chatbot.py │ ├── data.py │ ├── eval.py │ ├── export.py │ ├── infer.py │ ├── sft.py │ └── top.py │ ├── css.py │ ├── interface.py │ ├── locales.py │ ├── manager.py │ ├── runner.py │ └── utils.py ├── train_bash.py └── web_demo.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CoachLM 2 | 3 | This repo contains codes and the human-created training dataset for CoachLM, an automatic instruction revision approach for LLM instruction tuning. 
4 | 5 | Paper: https://arxiv.org/abs/2311.13246 6 | 7 | See more case analyses of the human revisions in Technical Report.pdf 8 | ## 📰 News 9 | CoachLM has been accepted by the IEEE International Conference on Data Engineering (ICDE) 2024🎉🎉🎉🎉🎉🎉. 10 | ## 📣 Introduction 11 |

12 | 13 |

14 | 15 | Instruction tuning is crucial for enabling Large Language Models (LLMs) to respond to human instructions. The quality of instruction pairs used for tuning greatly affects the performance of LLMs. However, the manual creation of high-quality instruction datasets is costly, leading to the adoption of automatic generation of instruction pairs by LLMs as a popular alternative. To ensure the high quality of LLM-generated instruction datasets, several approaches have been proposed. Nevertheless, existing methods either compromise dataset integrity by filtering a large proportion of samples, or are unsuitable for industrial applications. Instead of discarding low-quality samples, we propose CoachLM, a novel approach to enhance the quality of instruction datasets through automatic revision of samples in the dataset. CoachLM is trained on samples revised by human experts and significantly increases the proportion of high-quality samples in the dataset from 17.7% to 78.9%. The effectiveness of CoachLM is further assessed on various real-world instruction test sets. The results show that CoachLM improves the instruction-following capabilities of the instruction-tuned LLM by an average of 29.9%, which even surpasses larger LLMs with nearly twice the number of parameters. 16 | 17 |

18 | 19 |

20 | 21 |

22 | 23 |

24 | 25 | ## 🔰 Installation 26 | ``` 27 | $ pip install -r requirements.txt 28 | ``` 29 | ## ✨ Expert Revision Dataset 30 | We created a dataset of 2301 samples that pairs raw instruction pairs from the Alpaca52k dataset with their human-revised versions, aiming to improve the quality of these LLM-generated instruction pairs. 31 | ``` 32 | Expert Revision Dataset.json 33 | { 34 | "Raw Instruction": "", 35 | "Raw Input": "", 36 | "Raw Response": "", 37 | "Revised Instruction": "", 38 | "Revised Input": "", 39 | "Revised Response": "", 40 | "Distance": "" 41 | } 42 | ``` 43 | As illustrated above, the dataset contains the raw instruction pairs and their revised versions. It also records the edit distance between each raw sample and its revision. 44 | ## 📝 Training CoachLM 45 | CoachLM in this repository is implemented by fine-tuning ChatGLM2 on our curated expert-revision dataset. Thanks to [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) for providing an efficient tool to fine-tune ChatGLM2. The training steps of CoachLM are as follows: 46 | 47 | 48 | (1) Determine the subset used for training 49 | 50 | 51 | Training CoachLM on the whole dataset may not lead to optimal results. We recommend using roughly the first 30% of samples after sorting the dataset by edit distance from largest to smallest. Other ways of selecting the subset are also possible (a minimal selection-and-formatting sketch is provided at the end of this README). 52 | 53 | 54 | (2) Format the training dataset 55 | 56 | 57 | Format the training dataset according to data/example.json, and register the dataset in data/dataset_info.json. 58 | ``` 59 | example.json 60 | { 61 | "instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.", 62 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.", 63 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins."
64 | } 65 | ``` 66 | ``` 67 | dataset_info.json 68 | { 69 | "example": { 70 | "file_name": "example.json", 71 | "file_sha1": "", 72 | "columns": { 73 | "prompt": "instruction", 74 | "query": "input", 75 | "response": "output" 76 | } 77 | } 78 | } 79 | ``` 80 | 81 | (3) Start training 82 | ``` 83 | python src/train_bash.py \ 84 | --stage sft \ 85 | --model_name_or_path path_to_your_chatglm2_model \ 86 | --do_train \ 87 | --dataset name_of_your_dataset_in_dataset_info_json \ 88 | --finetuning_type lora \ 89 | --output_dir path_to_CoachLM_checkpoint \ 90 | --per_device_train_batch_size 32 \ 91 | --gradient_accumulation_steps 1 \ 92 | --lr_scheduler_type cosine \ 93 | --logging_steps 100 \ 94 | --save_strategy "epoch" \ 95 | --learning_rate 2e-4 \ 96 | --num_train_epochs 7 \ 97 | --fp16 true \ 98 | --lora_rank 64 \ 99 | --lora_alpha 32 \ 100 | --lora_target "query_key_value,dense,dense_h_to_4h,dense_4h_to_h" \ 101 | --use_v2 true 102 | ``` 103 | 104 | (4) Inference 105 | 106 | 107 | The inference dataset should be formatted in the same way as example.json, with the output field left empty. (A programmatic single-sample alternative is sketched at the end of this README.) 108 | ``` 109 | python src/train_bash.py \ 110 | --stage sft \ 111 | --do_predict \ 112 | --finetuning_type lora \ 113 | --dataset dataset_for_inference \ 114 | --model_name_or_path path_to_your_chatglm2_model \ 115 | --checkpoint_dir path_to_CoachLM_checkpoint \ 116 | --output_dir path_to_inference_result \ 117 | --per_device_eval_batch_size 32 \ 118 | --predict_with_generate \ 119 | --eval_num_beams 1 \ 120 | --lora_rank 64 \ 121 | --lora_alpha 32 \ 122 | --lora_target "query_key_value,dense,dense_h_to_4h,dense_4h_to_h" \ 123 | --use_v2 true 124 | ``` 125 | For more information, please refer to [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning). 126 | 127 | ## 🧪 CoachLM150 Test Set 128 | We created an instruction-following test suite for LLMs, containing 150 questions that cover information extraction, scientific inference, dialogue completion, brainstorming, in-domain question answering, and more. For each question, a reference response written by human experts is provided. 129 | ``` 130 | CoachLM150 Test Set.json 131 | { 132 | "instruction": "", 133 | "input": "", 134 | "reference response": "" 135 | } 136 | ``` 137 | 138 | ## ⚠️ Limitations 139 | The current 7B version of CoachLM may still occasionally generate undesirable content, including hallucinations, repeated text and meaningless phrases. To mitigate this and increase reliability, we recommend applying rule-based post-processing to the output of CoachLM, removing non-English characters, repeated strings and excessively long answers. Another solution is to train CoachLM on larger foundation models, such as 13B or 60B variants.
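Below is a minimal sketch of such rule-based post-processing. It only illustrates the cleanup described above: the non-English-character filter, the repetition collapse and the length cut-off are example heuristics, not the exact rules used in our experiments.
```
import re

def clean_coachlm_output(text: str, max_chars: int = 2000) -> str:
    """Illustrative cleanup of a CoachLM revision; all thresholds are example values."""
    # Drop non-English (non-ASCII) characters.
    text = re.sub(r"[^\x00-\x7F]", "", text)
    # Collapse immediately repeated words or short phrases, e.g. "the the the" -> "the".
    text = re.sub(r"\b(\w+(?:\s+\w+){0,4})(?:\s+\1\b)+", r"\1", text, flags=re.IGNORECASE)
    # Truncate excessively long answers.
    return text[:max_chars].strip()
```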
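For steps (1) and (2) of the training procedure, the sketch below shows one way to select the high-edit-distance subset of Expert Revision Dataset.json and convert it into the layout of data/example.json. The field names follow the schemas shown earlier; the 30% cut-off, the output file name and the way the instruction, input and response fields are merged into a single string are assumptions that you may need to adapt.
```
import json

# Task prompt taken from data/example.json.
TASK_PROMPT = (
    "Improve the following instruction, input and response pair to be more specific, "
    "detailed with more logical steps and grammarly corrected."
)

def serialize(instruction, extra_input, response):
    # Assumed serialization: data/example.json only shows the case of an empty input.
    text = "Instruction: {}".format(instruction.strip())
    if extra_input and extra_input.strip():
        text += " Input: {}".format(extra_input.strip())
    return text + " Response: {}".format(response.strip())

with open("Expert Revision Dataset.json", encoding="utf-8") as f:
    samples = json.load(f)  # assumed to be a list of records with the fields shown above

# Sort by edit distance, largest first, and keep roughly the first 30% of samples.
samples.sort(key=lambda s: float(s["Distance"]), reverse=True)
subset = samples[: int(0.3 * len(samples))]

formatted = [
    {
        "instruction": TASK_PROMPT,
        "input": serialize(s["Raw Instruction"], s["Raw Input"], s["Raw Response"]),
        "output": serialize(s["Revised Instruction"], s["Revised Input"], s["Revised Response"]),
    }
    for s in subset
]

with open("data/coachlm_train.json", "w", encoding="utf-8") as f:
    json.dump(formatted, f, ensure_ascii=False, indent=2)
```
Remember to register the generated file (here, the hypothetical data/coachlm_train.json) under its own key in data/dataset_info.json, as shown in the dataset_info.json snippet above.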
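As a programmatic alternative to the batch inference command, the repository's ChatModel (see src/glmtuner/chat/stream_chat.py and src/glmtuner/api/app.py) can revise a single pair in-process. The sketch below assumes the script is launched with the same command-line flags as the inference command (for example --model_name_or_path, --checkpoint_dir, --finetuning_type lora and --use_v2 true), because get_infer_args() reads its arguments from the command line.
```
from glmtuner.chat import ChatModel
from glmtuner.tuner import get_infer_args

# Parse the CLI flags and load the LoRA-tuned ChatGLM2, as in src/glmtuner/api/app.py.
chat_model = ChatModel(*get_infer_args())

task_prompt = (
    "Improve the following instruction, input and response pair to be more specific, "
    "detailed with more logical steps and grammarly corrected."
)
pair = "Instruction: Name three natural elements. Response: Water, air, and fire."

# glmtuner concatenates the "instruction" and "input" columns into one query
# (see format_example in src/glmtuner/dsets/preprocess.py), so we do the same here.
revised, (prompt_len, response_len) = chat_model.chat(task_prompt + pair)
print(revised)
```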
140 | -------------------------------------------------------------------------------- /Technical Report.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/Technical Report.pdf -------------------------------------------------------------------------------- /asset/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /asset/coachLM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/coachLM.png -------------------------------------------------------------------------------- /asset/dataset_score.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/dataset_score.png -------------------------------------------------------------------------------- /asset/illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/illustration.png -------------------------------------------------------------------------------- /asset/win_rates.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/win_rates.png -------------------------------------------------------------------------------- /data/.ipynb_checkpoints/dataset_info-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | "alpaca_en": { 3 | "file_name": "alpaca_data_en_52k.json", 4 | "file_sha1": "607f94a7f581341e59685aef32f531095232cf23" 5 | }, 6 | "alpaca_zh": { 7 | "file_name": "alpaca_data_zh_51k.json", 8 | "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311" 9 | }, 10 | "alpaca_gpt4_en": { 11 | "file_name": "alpaca_gpt4_data_en.json", 12 | "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a" 13 | }, 14 | "alpaca_gpt4_zh": { 15 | "file_name": "alpaca_gpt4_data_zh.json", 16 | "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845" 17 | }, 18 | "self_cognition": { 19 | "file_name": "self_cognition.json", 20 | "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67" 21 | }, 22 | "oaast_sft": { 23 | "file_name": "oaast_sft.json", 24 | "file_sha1": "7baf5d43e67a91f9bbdf4e400dbe033b87e9757e", 25 | "columns": { 26 | "prompt": "instruction", 27 | "query": "input", 28 | "response": "output", 29 | "history": "history" 30 | } 31 | }, 32 | "oaast_sft_zh": { 33 | "file_name": "oaast_sft_zh.json", 34 | "file_sha1": "a6a91f18f80f37b10ded9cf633fb50c033bf7b9f", 35 | "columns": { 36 | "prompt": "instruction", 37 | "query": "input", 38 | "response": "output", 39 | "history": "history" 40 | } 41 | }, 42 | "sharegpt_zh": { 43 | "file_name": "sharegpt_zh_27k.json", 44 | "file_sha1": "baf766bcf3d61f1b783728c14ce695af57a86e6e", 45 | "columns": { 46 | "prompt": "instruction", 47 | "query": "input", 48 | "response": "output", 49 | "history": "history" 50 | } 51 | }, 52 | "refgpt_zh_p1": { 53 | "file_name": "refgpt_zh_50k_p1.json", 54 | "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752", 55 | "columns": { 56 | "prompt": "instruction", 57 | 
"query": "input", 58 | "response": "output", 59 | "history": "history" 60 | } 61 | }, 62 | "refgpt_zh_p2": { 63 | "file_name": "refgpt_zh_50k_p2.json", 64 | "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0", 65 | "columns": { 66 | "prompt": "instruction", 67 | "query": "input", 68 | "response": "output", 69 | "history": "history" 70 | } 71 | }, 72 | "lima": { 73 | "file_name": "lima.json", 74 | "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37", 75 | "columns": { 76 | "prompt": "instruction", 77 | "query": "input", 78 | "response": "output", 79 | "history": "history" 80 | } 81 | }, 82 | "example": { 83 | "script_url": "example_dataset", 84 | "columns": { 85 | "prompt": "instruction", 86 | "query": "input", 87 | "response": "output", 88 | "history": "history" 89 | } 90 | }, 91 | "guanaco": { 92 | "hf_hub_url": "JosephusCheung/GuanacoDataset" 93 | }, 94 | "belle_0.5m": { 95 | "hf_hub_url": "BelleGroup/train_0.5M_CN" 96 | }, 97 | "belle_1m": { 98 | "hf_hub_url": "BelleGroup/train_1M_CN" 99 | }, 100 | "belle_2m": { 101 | "hf_hub_url": "BelleGroup/train_2M_CN" 102 | }, 103 | "belle_dialog": { 104 | "hf_hub_url": "BelleGroup/generated_chat_0.4M" 105 | }, 106 | "belle_math": { 107 | "hf_hub_url": "BelleGroup/school_math_0.25M" 108 | }, 109 | "belle_multiturn": { 110 | "script_url": "belle_multiturn", 111 | "columns": { 112 | "prompt": "instruction", 113 | "query": "", 114 | "response": "output", 115 | "history": "history" 116 | } 117 | }, 118 | "firefly": { 119 | "hf_hub_url": "YeungNLP/firefly-train-1.1M", 120 | "columns": { 121 | "prompt": "input", 122 | "query": "", 123 | "response": "target", 124 | "history": "" 125 | } 126 | }, 127 | "codealpaca": { 128 | "hf_hub_url": "sahil2801/CodeAlpaca-20k" 129 | }, 130 | "alpaca_cot": { 131 | "hf_hub_url": "QingyiSi/Alpaca-CoT" 132 | }, 133 | "webqa": { 134 | "hf_hub_url": "suolyer/webqa", 135 | "columns": { 136 | "prompt": "input", 137 | "query": "", 138 | "response": "output", 139 | "history": "" 140 | } 141 | }, 142 | "ultra_chat": { 143 | "script_url": "ultra_chat", 144 | "columns": { 145 | "prompt": "instruction", 146 | "query": "", 147 | "response": "output", 148 | "history": "history" 149 | } 150 | }, 151 | "novel_tokens512_50k": { 152 | "hf_hub_url": "zxbsmk/webnovel_cn" 153 | }, 154 | "comparison_gpt4_en": { 155 | "file_name": "comparison_gpt4_data_en.json", 156 | "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae" 157 | }, 158 | "comparison_gpt4_zh": { 159 | "file_name": "comparison_gpt4_data_zh.json", 160 | "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd" 161 | }, 162 | "hh_rlhf_en": { 163 | "script_url": "hh_rlhf_en", 164 | "columns": { 165 | "prompt": "instruction", 166 | "query": "", 167 | "response": "output", 168 | "history": "history" 169 | } 170 | }, 171 | "oaast_rm": { 172 | "file_name": "oaast_rm.json", 173 | "file_sha1": "622d420e9b70003b210618253bd3d9d2891d86cb", 174 | "columns": { 175 | "prompt": "instruction", 176 | "query": "input", 177 | "response": "output", 178 | "history": "history" 179 | } 180 | }, 181 | "oaast_rm_zh": { 182 | "file_name": "oaast_rm_zh.json", 183 | "file_sha1": "1065af1f3784dd61be5e79713a35f427b713a232", 184 | "columns": { 185 | "prompt": "instruction", 186 | "query": "input", 187 | "response": "output", 188 | "history": "history" 189 | } 190 | } 191 | } 192 | -------------------------------------------------------------------------------- /data/.ipynb_checkpoints/example-checkpoint.json: -------------------------------------------------------------------------------- 1 | { 2 | 
"instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.", 3 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.", 4 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins." 5 | } -------------------------------------------------------------------------------- /data/dataset_info.json: -------------------------------------------------------------------------------- 1 | { 2 | "example": { 3 | "file_name": "example.json", 4 | "file_sha1": "", 5 | "columns": { 6 | "prompt": "instruction", 7 | "query": "input", 8 | "response": "output" 9 | } 10 | } 11 | } 12 | -------------------------------------------------------------------------------- /data/example.json: -------------------------------------------------------------------------------- 1 | { 2 | "instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.", 3 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.", 4 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins." 
5 | } -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.13.1 2 | transformers>=4.29.1 3 | datasets>=2.12.0 4 | accelerate>=0.21.0 5 | peft>=0.4.0 6 | trl>=0.4.7 7 | sentencepiece 8 | jieba 9 | rouge-chinese 10 | nltk 11 | gradio>=3.36.0 12 | uvicorn 13 | pydantic==1.10.11 14 | fastapi==0.95.1 15 | sse-starlette 16 | matplotlib 17 | protobuf 18 | cpm-kernels 19 | -------------------------------------------------------------------------------- /src/glmtuner/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.chat import ChatModel 2 | 3 | 4 | __version__ = "0.1.5" 5 | -------------------------------------------------------------------------------- /src/glmtuner/api/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/api/__init__.py -------------------------------------------------------------------------------- /src/glmtuner/api/app.py: -------------------------------------------------------------------------------- 1 | import uvicorn 2 | from fastapi import FastAPI, HTTPException 3 | from fastapi.middleware.cors import CORSMiddleware 4 | from contextlib import asynccontextmanager 5 | from sse_starlette import EventSourceResponse 6 | from typing import List, Tuple 7 | 8 | from glmtuner.tuner import get_infer_args 9 | from glmtuner.extras.misc import torch_gc 10 | from glmtuner.chat.stream_chat import ChatModel 11 | from glmtuner.api.protocol import ( 12 | Role, 13 | Finish, 14 | ModelCard, 15 | ModelList, 16 | ChatMessage, 17 | DeltaMessage, 18 | ChatCompletionRequest, 19 | ChatCompletionResponse, 20 | ChatCompletionStreamResponse, 21 | ChatCompletionResponseChoice, 22 | ChatCompletionResponseStreamChoice, 23 | ChatCompletionResponseUsage 24 | ) 25 | 26 | 27 | @asynccontextmanager 28 | async def lifespan(app: FastAPI): # collects GPU memory 29 | yield 30 | torch_gc() 31 | 32 | 33 | def create_app(chat_model: ChatModel) -> FastAPI: 34 | app = FastAPI(lifespan=lifespan) 35 | 36 | app.add_middleware( 37 | CORSMiddleware, 38 | allow_origins=["*"], 39 | allow_credentials=True, 40 | allow_methods=["*"], 41 | allow_headers=["*"], 42 | ) 43 | 44 | @app.get("/v1/models", response_model=ModelList) 45 | async def list_models(): 46 | model_card = ModelCard(id="gpt-3.5-turbo") 47 | return ModelList(data=[model_card]) 48 | 49 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 50 | async def create_chat_completion(request: ChatCompletionRequest): 51 | if request.messages[-1].role != Role.USER: 52 | raise HTTPException(status_code=400, detail="Invalid request") 53 | query = request.messages[-1].content 54 | 55 | prev_messages = request.messages[:-1] 56 | if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: 57 | prefix = prev_messages.pop(0).content 58 | else: 59 | prefix = None 60 | 61 | history = [] 62 | if len(prev_messages) % 2 == 0: 63 | for i in range(0, len(prev_messages), 2): 64 | if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT: 65 | history.append([prev_messages[i].content, prev_messages[i+1].content]) 66 | 67 | if request.stream: 68 | generate = predict(query, history, prefix, request) 69 | return EventSourceResponse(generate, media_type="text/event-stream") 70 | 71 | response, 
(prompt_length, response_length) = chat_model.chat( 72 | query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens 73 | ) 74 | 75 | usage = ChatCompletionResponseUsage( 76 | prompt_tokens=prompt_length, 77 | completion_tokens=response_length, 78 | total_tokens=prompt_length+response_length 79 | ) 80 | 81 | choice_data = ChatCompletionResponseChoice( 82 | index=0, 83 | message=ChatMessage(role=Role.ASSISTANT, content=response), 84 | finish_reason=Finish.STOP 85 | ) 86 | 87 | return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage) 88 | 89 | async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest): 90 | choice_data = ChatCompletionResponseStreamChoice( 91 | index=0, 92 | delta=DeltaMessage(role=Role.ASSISTANT), 93 | finish_reason=None 94 | ) 95 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 96 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 97 | 98 | for new_text in chat_model.stream_chat( 99 | query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens 100 | ): 101 | if len(new_text) == 0: 102 | continue 103 | 104 | choice_data = ChatCompletionResponseStreamChoice( 105 | index=0, 106 | delta=DeltaMessage(content=new_text), 107 | finish_reason=None 108 | ) 109 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 110 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 111 | 112 | choice_data = ChatCompletionResponseStreamChoice( 113 | index=0, 114 | delta=DeltaMessage(), 115 | finish_reason=Finish.STOP 116 | ) 117 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data]) 118 | yield chunk.json(exclude_unset=True, ensure_ascii=False) 119 | yield "[DONE]" 120 | 121 | return app 122 | 123 | 124 | if __name__ == "__main__": 125 | chat_model = ChatModel(*get_infer_args()) 126 | app = create_app(chat_model) 127 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1) 128 | -------------------------------------------------------------------------------- /src/glmtuner/api/protocol.py: -------------------------------------------------------------------------------- 1 | import time 2 | from enum import Enum 3 | from pydantic import BaseModel, Field 4 | from typing import List, Optional 5 | 6 | 7 | class Role(str, Enum): 8 | USER = "user" 9 | ASSISTANT = "assistant" 10 | SYSTEM = "system" 11 | 12 | 13 | class Finish(str, Enum): 14 | STOP = "stop" 15 | LENGTH = "length" 16 | 17 | 18 | class ModelCard(BaseModel): 19 | id: str 20 | object: Optional[str] = "model" 21 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 22 | owned_by: Optional[str] = "owner" 23 | root: Optional[str] = None 24 | parent: Optional[str] = None 25 | permission: Optional[list] = [] 26 | 27 | 28 | class ModelList(BaseModel): 29 | object: Optional[str] = "list" 30 | data: Optional[List[ModelCard]] = [] 31 | 32 | 33 | class ChatMessage(BaseModel): 34 | role: Role 35 | content: str 36 | 37 | 38 | class DeltaMessage(BaseModel): 39 | role: Optional[Role] = None 40 | content: Optional[str] = None 41 | 42 | 43 | class ChatCompletionRequest(BaseModel): 44 | model: str 45 | messages: List[ChatMessage] 46 | temperature: Optional[float] = None 47 | top_p: Optional[float] = None 48 | n: Optional[int] = 1 49 | max_tokens: Optional[int] = None 50 | stream: Optional[bool] = False 51 | 52 | 53 | class 
ChatCompletionResponseChoice(BaseModel): 54 | index: int 55 | message: ChatMessage 56 | finish_reason: Finish 57 | 58 | 59 | class ChatCompletionResponseStreamChoice(BaseModel): 60 | index: int 61 | delta: DeltaMessage 62 | finish_reason: Optional[Finish] = None 63 | 64 | 65 | class ChatCompletionResponseUsage(BaseModel): 66 | prompt_tokens: int 67 | completion_tokens: int 68 | total_tokens: int 69 | 70 | 71 | class ChatCompletionResponse(BaseModel): 72 | id: Optional[str] = "chatcmpl-default" 73 | object: Optional[str] = "chat.completion" 74 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 75 | model: str 76 | choices: List[ChatCompletionResponseChoice] 77 | usage: ChatCompletionResponseUsage 78 | 79 | 80 | class ChatCompletionStreamResponse(BaseModel): 81 | id: Optional[str] = "chatcmpl-default" 82 | object: Optional[str] = "chat.completion.chunk" 83 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 84 | model: str 85 | choices: List[ChatCompletionResponseStreamChoice] 86 | -------------------------------------------------------------------------------- /src/glmtuner/chat/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.chat.stream_chat import ChatModel 2 | -------------------------------------------------------------------------------- /src/glmtuner/chat/stream_chat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Dict, Generator, List, Optional, Tuple 3 | from threading import Thread 4 | from transformers import TextIteratorStreamer 5 | 6 | from glmtuner.extras.misc import dispatch_model, get_logits_processor 7 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments 8 | from glmtuner.tuner import load_model_and_tokenizer 9 | 10 | 11 | class ChatModel: 12 | 13 | def __init__( 14 | self, 15 | model_args: ModelArguments, 16 | data_args: DataArguments, 17 | finetuning_args: FinetuningArguments, 18 | generating_args: GeneratingArguments 19 | ) -> None: 20 | self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args) 21 | self.model = dispatch_model(self.model, use_v2=(self.tokenizer.eos_token_id==2)) 22 | self.source_prefix = data_args.source_prefix 23 | self.generating_args = generating_args 24 | 25 | def get_prompt( 26 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None 27 | ) -> str: 28 | prefix = prefix + "\n" if prefix else "" # add separator for non-empty prefix 29 | history = history or [] 30 | prompt = "" 31 | for i, (old_query, response) in enumerate(history): 32 | prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i+1, old_query, response) 33 | prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query) 34 | prompt = prefix + prompt 35 | return prompt 36 | 37 | def process_args( 38 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs 39 | ) -> Tuple[Dict[str, Any], int]: 40 | prefix = prefix or self.source_prefix 41 | 42 | inputs = self.tokenizer([self.get_prompt(query, history, prefix)], return_tensors="pt") 43 | inputs = inputs.to(self.model.device) 44 | prompt_length = len(inputs["input_ids"][0]) 45 | 46 | do_sample = input_kwargs.pop("do_sample", None) 47 | temperature = input_kwargs.pop("temperature", None) 48 | top_p = input_kwargs.pop("top_p", None) 49 | top_k = input_kwargs.pop("top_k", None) 50 | 
repetition_penalty = input_kwargs.pop("repetition_penalty", None) 51 | max_length = input_kwargs.pop("max_length", None) 52 | max_new_tokens = input_kwargs.pop("max_new_tokens", None) 53 | 54 | gen_kwargs = self.generating_args.to_dict() 55 | gen_kwargs.update(dict( 56 | input_ids=inputs["input_ids"], 57 | do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"], 58 | temperature=temperature or gen_kwargs["temperature"], 59 | top_p=top_p or gen_kwargs["top_p"], 60 | top_k=top_k or gen_kwargs["top_k"], 61 | repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"], 62 | logits_processor=get_logits_processor() 63 | )) 64 | 65 | if max_length: 66 | gen_kwargs.pop("max_new_tokens", None) 67 | gen_kwargs["max_length"] = max_length 68 | 69 | if max_new_tokens: 70 | gen_kwargs.pop("max_length", None) 71 | gen_kwargs["max_new_tokens"] = max_new_tokens 72 | 73 | return gen_kwargs, prompt_length 74 | 75 | @torch.inference_mode() 76 | def chat( 77 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs 78 | ) -> Tuple[str, Tuple[int, int]]: 79 | gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs) 80 | generation_output = self.model.generate(**gen_kwargs) 81 | outputs = generation_output.tolist()[0][prompt_length:] 82 | response = self.tokenizer.decode(outputs, skip_special_tokens=True) 83 | response_length = len(outputs) 84 | return response, (prompt_length, response_length) 85 | 86 | @torch.inference_mode() 87 | def stream_chat( 88 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs 89 | ) -> Generator[str, None, None]: 90 | gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs) 91 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 92 | gen_kwargs["streamer"] = streamer 93 | 94 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs) 95 | thread.start() 96 | 97 | yield from streamer 98 | -------------------------------------------------------------------------------- /src/glmtuner/dsets/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.dsets.collator import DataCollatorForChatGLM 2 | from glmtuner.dsets.loader import get_dataset 3 | from glmtuner.dsets.preprocess import preprocess_dataset 4 | from glmtuner.dsets.utils import split_dataset 5 | -------------------------------------------------------------------------------- /src/glmtuner/dsets/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Dict, Optional, Sequence 3 | 4 | from transformers import DataCollatorWithPadding, BatchEncoding 5 | from transformers.modeling_utils import PreTrainedModel 6 | from transformers.tokenization_utils import PreTrainedTokenizer 7 | 8 | from glmtuner.extras.constants import IGNORE_INDEX 9 | 10 | 11 | class DataCollatorForChatGLM(DataCollatorWithPadding): 12 | r""" 13 | Data collator for ChatGLM. It is capable of dynamically padding for batched data. 
14 | """ 15 | def __init__( 16 | self, 17 | tokenizer: PreTrainedTokenizer, 18 | model: PreTrainedModel, 19 | ignore_pad_token_for_loss: Optional[bool] = False 20 | ): 21 | super().__init__(tokenizer, padding=True) 22 | self.model = model 23 | self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id 24 | if tokenizer.eos_token_id == 130005: 25 | self.get_attention_masks = self.get_attention_masks_v1 26 | self.get_position_ids = self.get_position_ids_v1 27 | else: 28 | self.get_attention_masks = self.get_attention_masks_v2 29 | self.get_position_ids = self.get_position_ids_v2 30 | 31 | def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: 32 | r""" 33 | Generates attention masks for left-padded sequences. 34 | 35 | Note that ChatGLM assigns False on token to be attended in attention mask. In general settings, it should be True. 36 | 37 | According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L680 38 | """ 39 | batch_size, seq_length = input_ids.size() 40 | attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device) 41 | attention_mask.tril_() 42 | 43 | for i, seq in enumerate(input_ids): 44 | attention_mask[i, :, :(seq == self.tokenizer.bos_token_id).nonzero()[0].item()] = 1 # context 45 | attention_mask[i, :, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding 46 | 47 | attention_mask.unsqueeze_(1) 48 | attention_mask = (attention_mask < 0.5).bool() 49 | return attention_mask 50 | 51 | def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: 52 | r""" 53 | Generates position ids for left-padded sequenes. 54 | 55 | According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L692 56 | """ 57 | batch_size, seq_length = input_ids.size() 58 | mask: int = self.model.config.mask_token_id 59 | gmask: int = self.model.config.gmask_token_id 60 | position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device) 61 | block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device) 62 | 63 | for i, seq in enumerate(input_ids): 64 | mask_token = gmask if gmask in seq else mask 65 | context_length = (seq == self.tokenizer.bos_token_id).nonzero()[0].item() 66 | padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item() 67 | position_ids[i, padding_length:] = torch.arange( 68 | seq_length - padding_length, 69 | dtype=torch.long, 70 | device=device 71 | ) 72 | if self.model.position_encoding_2d or (mask_token != gmask): # 2d position encoding or not gMASK 73 | position_ids[i, context_length:] = (seq == mask_token).nonzero()[0].item() - padding_length # mask position 74 | block_position_ids[i, context_length:] = torch.arange( 75 | seq_length - context_length, 76 | dtype=torch.long, 77 | device=device 78 | ) + 1 79 | 80 | if self.model.position_encoding_2d: 81 | position_ids = torch.stack((position_ids, block_position_ids), dim=1) 82 | 83 | return position_ids 84 | 85 | def get_attention_masks_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: 86 | r""" 87 | Generates attention masks for left-padded sequences. 
88 | """ 89 | batch_size, seq_length = input_ids.size() 90 | attention_mask = torch.ones((batch_size, seq_length), device=device) 91 | 92 | for i, seq in enumerate(input_ids): 93 | attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding 94 | 95 | return attention_mask 96 | 97 | def get_position_ids_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor: 98 | r""" 99 | Generates position ids for left-padded sequenes. 100 | """ 101 | batch_size, seq_length = input_ids.size() 102 | position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device) 103 | 104 | for i, seq in enumerate(input_ids): 105 | padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item() 106 | position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long, device=device) 107 | 108 | return position_ids 109 | 110 | def __call__(self, features: Sequence[Dict[str, Any]]) -> BatchEncoding: 111 | r""" 112 | Pads batched data to the longest sequence in the batch. 113 | 114 | We adopt left-padding in both training and evaluation. 115 | """ 116 | if isinstance(features[0]["input_ids"], torch.Tensor): 117 | input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features] 118 | else: 119 | input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features] 120 | 121 | if "labels" in features[0]: 122 | if isinstance(features[0]["labels"], torch.Tensor): 123 | labels = [feature["labels"].clone().detach().flip(0) for feature in features] 124 | else: 125 | labels = [torch.tensor(feature["labels"]).flip(0) for feature in features] 126 | input_ids += labels # pad them to the same length 127 | 128 | input_ids = torch.nn.utils.rnn.pad_sequence( 129 | input_ids, 130 | batch_first=True, 131 | padding_value=self.tokenizer.pad_token_id 132 | ).flip(-1) 133 | 134 | batch = {} 135 | 136 | if "labels" in features[0]: 137 | input_ids, labels = input_ids.split(len(features), dim=0) 138 | labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id) 139 | batch["labels"] = labels 140 | 141 | batch["input_ids"] = input_ids 142 | batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device) 143 | batch["position_ids"] = self.get_position_ids(input_ids, device=input_ids.device) 144 | 145 | return BatchEncoding(batch) 146 | -------------------------------------------------------------------------------- /src/glmtuner/dsets/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import hashlib 3 | from typing import List 4 | 5 | from datasets import Dataset, concatenate_datasets, load_dataset 6 | 7 | from glmtuner.extras.logging import get_logger 8 | from glmtuner.hparams import ModelArguments, DataArguments 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def get_dataset( 15 | model_args: ModelArguments, 16 | data_args: DataArguments 17 | ) -> Dataset: 18 | 19 | def checksum(file_path, hash): 20 | with open(file_path, "rb") as datafile: 21 | binary_data = datafile.read() 22 | sha1 = hashlib.sha1(binary_data).hexdigest() 23 | if sha1 != hash: 24 | logger.warning("Checksum failed for {}. 
It may vary depending on the platform.".format(file_path)) 25 | 26 | ext2type = { 27 | "csv": "csv", 28 | "json": "json", 29 | "jsonl": "json" 30 | } 31 | 32 | max_samples = data_args.max_samples 33 | all_datasets: List[Dataset] = [] # support multiple datasets 34 | 35 | for dataset_attr in data_args.dataset_list: 36 | 37 | logger.info("Loading dataset {}...".format(dataset_attr)) 38 | 39 | if dataset_attr.load_from == "hf_hub": 40 | data_path = dataset_attr.dataset_name 41 | data_files = None 42 | 43 | elif dataset_attr.load_from == "script": 44 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) 45 | data_files = None 46 | 47 | elif dataset_attr.load_from == "file": # support folder or file 48 | data_path = None 49 | data_files: List[str] = [] 50 | 51 | if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # folder 52 | for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): 53 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name)) 54 | 55 | if data_path is None: 56 | data_path = ext2type.get(data_files[0].split(".")[-1], None) 57 | else: 58 | assert ext2type.get(data_files[-1].split(".")[-1], None) == data_path, \ 59 | "more than one file formats found in a single folder" 60 | 61 | elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # file 62 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)) 63 | data_path = ext2type.get(data_files[0].split(".")[-1], None) 64 | 65 | else: 66 | raise ValueError("File not found.") 67 | 68 | assert data_path, "File extension must be csv, json or jsonl." 69 | 70 | if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None: 71 | checksum(data_files[0], dataset_attr.dataset_sha1) 72 | else: 73 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.") 74 | 75 | else: 76 | raise NotImplementedError 77 | 78 | raw_datasets = load_dataset( 79 | data_path, 80 | data_files=data_files, 81 | cache_dir=model_args.cache_dir, 82 | use_auth_token=True if model_args.use_auth_token else None 83 | ) 84 | dataset = raw_datasets[data_args.split] 85 | 86 | if max_samples is not None: 87 | max_samples_temp = min(len(dataset), max_samples) 88 | dataset = dataset.select(range(max_samples_temp)) 89 | 90 | dummy_data = [None] * len(dataset) 91 | 92 | for column_name, target_name in [ 93 | ("prompt_column", "prompt"), 94 | ("query_column", "query"), 95 | ("response_column", "response"), 96 | ("history_column", "history") 97 | ]: # every dataset will have 4 columns same as each other 98 | if getattr(dataset_attr, column_name) != target_name: 99 | if getattr(dataset_attr, column_name): 100 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name) 101 | else: # None or empty string 102 | dataset = dataset.add_column(target_name, dummy_data) 103 | 104 | all_datasets.append(dataset) 105 | 106 | if len(data_args.dataset_list) == 1: 107 | all_datasets = all_datasets[0] 108 | else: 109 | all_datasets = concatenate_datasets(all_datasets) 110 | 111 | return all_datasets 112 | -------------------------------------------------------------------------------- /src/glmtuner/dsets/preprocess.py: -------------------------------------------------------------------------------- 1 | from typing import Literal 2 | from transformers import Seq2SeqTrainingArguments 3 | from transformers.tokenization_utils import PreTrainedTokenizer 4 | 5 
| from datasets import Dataset 6 | 7 | from glmtuner.extras.constants import IGNORE_INDEX 8 | from glmtuner.hparams import DataArguments 9 | 10 | 11 | def preprocess_dataset( 12 | dataset: Dataset, 13 | tokenizer: PreTrainedTokenizer, 14 | data_args: DataArguments, 15 | training_args: Seq2SeqTrainingArguments, 16 | stage: Literal["sft", "rm", "ppo"] 17 | ) -> Dataset: 18 | 19 | column_names = list(dataset.column_names) 20 | prefix = data_args.source_prefix if data_args.source_prefix is not None else "" 21 | 22 | def format_example(examples): # support question with a single answer or multiple answers 23 | for i in range(len(examples["prompt"])): 24 | if examples["prompt"][i] and examples["response"][i]: 25 | query, answer = examples["prompt"][i], examples["response"][i] 26 | query = query + examples["query"][i] if examples["query"][i] else query 27 | history = examples["history"][i] if examples["history"][i] else [] 28 | prompt = "" 29 | for j, (old_query, response) in enumerate(history): 30 | prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(j+1, old_query, response) 31 | prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query) 32 | prompt = prefix + prompt 33 | yield prompt, answer 34 | 35 | def preprocess_supervised_dataset(examples): 36 | # v1: build inputs with format `X [gMASK] Y ` and labels with format `[IGNORE] ... [IGNORE] Y ` 37 | # v2: build inputs with format `[gMASK] sop X Y ` and labels with format `[IGNORE] ... [IGNORE] Y ` 38 | model_inputs = {"input_ids": [], "labels": []} 39 | for prompt, answer in format_example(examples): 40 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 41 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False) 42 | 43 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens 44 | source_ids = source_ids[:data_args.max_source_length - 2] 45 | if len(target_ids) > data_args.max_target_length - 1: # eos token 46 | target_ids = target_ids[:data_args.max_target_length - 1] 47 | 48 | context_length = len(source_ids) + 2 # gmask and sop tokens 49 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids) 50 | labels = [IGNORE_INDEX] * context_length + input_ids[context_length:] 51 | 52 | model_inputs["input_ids"].append(input_ids) 53 | model_inputs["labels"].append(labels) 54 | return model_inputs 55 | 56 | def preprocess_evaluation_dataset(examples): 57 | # v1: build inputs with format `X [gMASK] ` and labels with format `Y [gMASK] ` 58 | # v2: build inputs with format `[gMASK] sop X` and labels with format `[gMASK] sop Y` 59 | model_inputs = {"input_ids": [], "labels": []} 60 | for prompt, answer in format_example(examples): 61 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 62 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False) 63 | 64 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens 65 | source_ids = source_ids[:data_args.max_source_length - 2] 66 | if len(target_ids) > data_args.max_target_length - 2: # gmask and sop tokens 67 | target_ids = target_ids[:data_args.max_target_length - 2] 68 | 69 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids) 70 | labels = tokenizer.build_inputs_with_special_tokens(target_ids) 71 | 72 | model_inputs["input_ids"].append(input_ids) 73 | model_inputs["labels"].append(labels) 74 | return model_inputs 75 | 76 | def preprocess_pairwise_dataset(examples): 77 | # v1: build input pairs with format `X [gMASK] Y1 ` and `X [gMASK] Y2 ` 78 | # v2: 
build input pairs with format `[gMASK] sop X Y1 ` and `[gMASK] sop X Y2 ` 79 | model_inputs = {"accept_ids": [], "reject_ids": []} 80 | for prompt, answer in format_example(examples): 81 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False) 82 | accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False) 83 | reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False) 84 | 85 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens 86 | source_ids = source_ids[:data_args.max_source_length - 2] 87 | if len(accept_ids) > data_args.max_target_length - 1: # eos token 88 | accept_ids = accept_ids[:data_args.max_target_length - 1] 89 | if len(reject_ids) > data_args.max_target_length - 1: # eos token 90 | reject_ids = reject_ids[:data_args.max_target_length - 1] 91 | 92 | accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # avoid copying error 93 | reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids) 94 | 95 | model_inputs["accept_ids"].append(accept_ids) 96 | model_inputs["reject_ids"].append(reject_ids) 97 | return model_inputs 98 | 99 | def print_sft_dataset_example(example): 100 | print("input_ids:\n{}".format(example["input_ids"])) 101 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 102 | print("label_ids:\n{}".format(example["labels"])) 103 | print("labels:\n{}".format( 104 | tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]], 105 | skip_special_tokens=False) 106 | )) 107 | 108 | def print_pairwise_dataset_example(example): 109 | print("accept_ids:\n{}".format(example["accept_ids"])) 110 | print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False))) 111 | print("reject_ids:\n{}".format(example["reject_ids"])) 112 | print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False))) 113 | 114 | def print_ppo_dataset_example(example): 115 | print("input_ids:\n{}".format(example["input_ids"])) 116 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False))) 117 | 118 | if stage == "sft": 119 | if not training_args.predict_with_generate: 120 | preprocess_function = preprocess_supervised_dataset 121 | else: 122 | preprocess_function = preprocess_evaluation_dataset 123 | elif stage == "rm": 124 | preprocess_function = preprocess_pairwise_dataset 125 | elif stage == "ppo": 126 | preprocess_function = preprocess_evaluation_dataset 127 | 128 | with training_args.main_process_first(desc="dataset map pre-processing"): 129 | dataset = dataset.map( 130 | preprocess_function, 131 | batched=True, 132 | num_proc=data_args.preprocessing_num_workers, 133 | remove_columns=column_names, 134 | load_from_cache_file=not data_args.overwrite_cache, 135 | desc="Running tokenizer on dataset" 136 | ) 137 | 138 | if stage == "sft": 139 | print_sft_dataset_example(dataset[0]) 140 | elif stage == "rm": 141 | print_pairwise_dataset_example(dataset[0]) 142 | elif stage == "ppo": 143 | print_ppo_dataset_example(dataset[0]) 144 | 145 | return dataset 146 | -------------------------------------------------------------------------------- /src/glmtuner/dsets/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from datasets import Dataset 3 | 4 | 5 | def split_dataset( 6 | dataset: Dataset, dev_ratio: float, do_train: bool 7 | ) -> Dict[str, 
Dataset]: 8 | # Split the dataset 9 | if do_train: 10 | if dev_ratio > 1e-6: 11 | dataset = dataset.train_test_split(test_size=dev_ratio) 12 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} 13 | else: 14 | return {"train_dataset": dataset} 15 | else: # do_eval or do_predict 16 | return {"eval_dataset": dataset} 17 | -------------------------------------------------------------------------------- /src/glmtuner/extras/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/extras/__init__.py -------------------------------------------------------------------------------- /src/glmtuner/extras/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import time 4 | from datetime import timedelta 5 | 6 | from transformers import ( 7 | TrainerCallback, 8 | TrainerControl, 9 | TrainerState, 10 | TrainingArguments 11 | ) 12 | from transformers.trainer_callback import TrainerControl, TrainerState 13 | from transformers.training_args import TrainingArguments 14 | 15 | 16 | class LogCallback(TrainerCallback): 17 | 18 | def __init__(self, runner=None): 19 | self.runner = runner 20 | self.start_time = time.time() 21 | self.tracker = {} 22 | 23 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 24 | r""" 25 | Event called at the beginning of training. 26 | """ 27 | self.start_time = time.time() 28 | 29 | def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 30 | r""" 31 | Event called at the beginning of a training step. If using gradient accumulation, one training step 32 | might take several inputs. 33 | """ 34 | if self.runner is not None and self.runner.aborted: 35 | control.should_epoch_stop = True 36 | control.should_training_stop = True 37 | 38 | def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 39 | r""" 40 | Event called at the end of an substep during gradient accumulation. 41 | """ 42 | if self.runner is not None and self.runner.aborted: 43 | control.should_epoch_stop = True 44 | control.should_training_stop = True 45 | 46 | def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None: 47 | r""" 48 | Event called after logging the last logs. 
49 | """ 50 | if not state.is_world_process_zero: 51 | return 52 | 53 | cur_time = time.time() 54 | cur_steps = state.log_history[-1].get("step") 55 | elapsed_time = cur_time - self.start_time 56 | avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 57 | remaining_steps = state.max_steps - cur_steps 58 | remaining_time = remaining_steps * avg_time_per_step 59 | self.tracker = { 60 | "current_steps": cur_steps, 61 | "total_steps": state.max_steps, 62 | "loss": state.log_history[-1].get("loss", None), 63 | "eval_loss": state.log_history[-1].get("eval_loss", None), 64 | "predict_loss": state.log_history[-1].get("predict_loss", None), 65 | "reward": state.log_history[-1].get("reward", None), 66 | "learning_rate": state.log_history[-1].get("learning_rate", None), 67 | "epoch": state.log_history[-1].get("epoch", None), 68 | "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100, 69 | "elapsed_time": str(timedelta(seconds=int(elapsed_time))), 70 | "remaining_time": str(timedelta(seconds=int(remaining_time))) 71 | } 72 | os.makedirs(args.output_dir, exist_ok=True) 73 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f: 74 | f.write(json.dumps(self.tracker) + "\n") 75 | -------------------------------------------------------------------------------- /src/glmtuner/extras/constants.py: -------------------------------------------------------------------------------- 1 | IGNORE_INDEX = -100 2 | 3 | VALUE_HEAD_FILE_NAME = "value_head.bin" 4 | 5 | FINETUNING_ARGS_NAME = "finetuning_args.json" 6 | 7 | LAYERNORM_NAMES = ["layernorm"] 8 | 9 | METHODS = ["full", "freeze", "p_tuning", "lora"] 10 | 11 | SUPPORTED_MODELS = { 12 | "ChatGLM-6B": "THUDM/chatglm-6b", 13 | "ChatGLM2-6B": "THUDM/chatglm2-6b" 14 | } 15 | -------------------------------------------------------------------------------- /src/glmtuner/extras/logging.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import logging 3 | 4 | 5 | class LoggerHandler(logging.Handler): 6 | 7 | def __init__(self): 8 | super().__init__() 9 | self.log = "" 10 | 11 | def emit(self, record): 12 | if record.name == "httpx": 13 | return 14 | log_entry = self.format(record) 15 | self.log += log_entry 16 | self.log += "\n\n" 17 | 18 | 19 | def get_logger(name: str) -> logging.Logger: 20 | 21 | formatter = logging.Formatter( 22 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 23 | datefmt="%m/%d/%Y %H:%M:%S" 24 | ) 25 | handler = logging.StreamHandler(sys.stdout) 26 | handler.setFormatter(formatter) 27 | 28 | logger = logging.getLogger(name) 29 | logger.setLevel(logging.INFO) 30 | logger.addHandler(handler) 31 | 32 | return logger 33 | -------------------------------------------------------------------------------- /src/glmtuner/extras/misc.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List, Optional 3 | 4 | from transformers.modeling_utils import PreTrainedModel 5 | from transformers.generation.utils import LogitsProcessorList 6 | from transformers.generation.logits_process import LogitsProcessor 7 | 8 | from glmtuner.extras.constants import LAYERNORM_NAMES 9 | 10 | 11 | class AverageMeter: 12 | r""" 13 | Computes and stores the average and current value. 
14 | """ 15 | def __init__(self): 16 | self.reset() 17 | 18 | def reset(self): 19 | self.val = 0 20 | self.avg = 0 21 | self.sum = 0 22 | self.count = 0 23 | 24 | def update(self, val, n=1): 25 | self.val = val 26 | self.sum += val * n 27 | self.count += n 28 | self.avg = self.sum / self.count 29 | 30 | 31 | # Avoid runtime error in model.generate(do_sample=True). 32 | # Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54 33 | class InvalidScoreLogitsProcessor(LogitsProcessor): 34 | 35 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 36 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 37 | scores.zero_() 38 | scores[..., 5] = 5e4 39 | return scores 40 | 41 | 42 | def get_logits_processor() -> LogitsProcessorList: 43 | logits_processor = LogitsProcessorList() 44 | logits_processor.append(InvalidScoreLogitsProcessor()) 45 | return logits_processor 46 | 47 | 48 | def print_trainable_params(model: torch.nn.Module) -> None: 49 | trainable_params, all_param = 0, 0 50 | for param in model.parameters(): 51 | num_params = param.numel() 52 | # if using DS Zero 3 and the weights are initialized empty 53 | if num_params == 0 and hasattr(param, "ds_numel"): 54 | num_params = param.ds_numel 55 | all_param += num_params 56 | if param.requires_grad: 57 | trainable_params += num_params 58 | print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( 59 | trainable_params, all_param, 100 * trainable_params / all_param)) 60 | 61 | 62 | # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32 63 | # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35 64 | def prepare_model_for_training( 65 | model: PreTrainedModel, 66 | finetuning_type: str, 67 | output_embedding_base_layer: torch.nn.Module, 68 | output_embedding_layer_name: Optional[str] = "lm_head", 69 | use_gradient_checkpointing: Optional[bool] = True, 70 | layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES 71 | ) -> PreTrainedModel: 72 | 73 | for name, param in model.named_parameters(): 74 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 75 | param.data = param.data.to(torch.float32) 76 | 77 | if use_gradient_checkpointing: 78 | model.enable_input_require_grads() 79 | model.gradient_checkpointing_enable() 80 | model.config.use_cache = False # turn off when gradient checkpointing is enabled 81 | 82 | if finetuning_type != "full" and hasattr(output_embedding_base_layer, output_embedding_layer_name): 83 | output_embedding_layer = getattr(output_embedding_base_layer, output_embedding_layer_name) 84 | input_dtype = output_embedding_layer.weight.dtype 85 | 86 | class CastOutputToFloat(torch.nn.Sequential): 87 | 88 | def forward(self, x: torch.Tensor) -> torch.Tensor: 89 | return super().forward(x.to(input_dtype)).to(torch.float32) 90 | 91 | setattr(output_embedding_base_layer, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer)) 92 | 93 | return model 94 | 95 | 96 | def torch_gc() -> None: 97 | r""" 98 | Collects GPU memory. 99 | """ 100 | if torch.cuda.is_available(): 101 | torch.cuda.empty_cache() 102 | torch.cuda.ipc_collect() 103 | 104 | 105 | def auto_configure_device_map(num_gpus: int, use_v2: bool) -> Dict[str, int]: 106 | r""" 107 | Configures device map for ChatGLM. 
108 | 109 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8 110 | """ 111 | num_layers = 28 112 | layers_per_gpu = 30 / num_gpus 113 | if use_v2: 114 | device_map = { 115 | "transformer.embedding.word_embeddings": 0, 116 | "transformer.encoder.final_layernorm": 0, 117 | "transformer.output_layer": 0, 118 | "transformer.rotary_pos_emb": 0, 119 | "transformer.prefix_encoder": 0, 120 | "lm_head": 0 121 | } 122 | else: 123 | device_map = { 124 | "transformer.word_embeddings": 0, 125 | "transformer.final_layernorm": 0, 126 | "transformer.prefix_encoder": 0, 127 | "lm_head": 0 128 | } 129 | 130 | added_layers = 2 131 | target_gpu = 0 132 | 133 | for i in range(num_layers): 134 | if added_layers >= layers_per_gpu: 135 | target_gpu += 1 136 | added_layers = 0 137 | assert target_gpu < num_gpus 138 | if use_v2: 139 | device_map[f"transformer.encoder.layers.{i}"] = target_gpu 140 | else: 141 | device_map[f"transformer.layers.{i}"] = target_gpu 142 | added_layers += 1 143 | 144 | return device_map 145 | 146 | 147 | def dispatch_model(model: PreTrainedModel, use_v2: bool) -> PreTrainedModel: 148 | r""" 149 | Dispatches a pre-trained model to GPUs with balanced memory. 150 | """ 151 | if torch.cuda.device_count() > 1: 152 | from accelerate import dispatch_model 153 | 154 | device_map = auto_configure_device_map(torch.cuda.device_count(), use_v2=use_v2) 155 | model.tie_weights() 156 | return dispatch_model(model, device_map) 157 | else: 158 | return model.cuda() 159 | -------------------------------------------------------------------------------- /src/glmtuner/extras/ploting.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import json 4 | import matplotlib.pyplot as plt 5 | from typing import List, Optional 6 | from transformers.trainer import TRAINER_STATE_NAME 7 | 8 | from glmtuner.extras.logging import get_logger 9 | 10 | 11 | logger = get_logger(__name__) 12 | 13 | 14 | def smooth(scalars: List[float]) -> List[float]: 15 | r""" 16 | EMA implementation according to TensorBoard. 
17 | """ 18 | last = scalars[0] 19 | smoothed = list() 20 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function 21 | for next_val in scalars: 22 | smoothed_val = last * weight + (1 - weight) * next_val 23 | smoothed.append(smoothed_val) 24 | last = smoothed_val 25 | return smoothed 26 | 27 | 28 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None: 29 | 30 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: 31 | data = json.load(f) 32 | 33 | for key in keys: 34 | steps, metrics = [], [] 35 | for i in range(len(data["log_history"])): 36 | if key in data["log_history"][i]: 37 | steps.append(data["log_history"][i]["step"]) 38 | metrics.append(data["log_history"][i][key]) 39 | 40 | if len(metrics) == 0: 41 | logger.warning(f"No metric {key} to plot.") 42 | continue 43 | 44 | plt.figure() 45 | plt.plot(steps, metrics, alpha=0.4, label="original") 46 | plt.plot(steps, smooth(metrics), label="smoothed") 47 | plt.title("training {} of {}".format(key, save_dictionary)) 48 | plt.xlabel("step") 49 | plt.ylabel(key) 50 | plt.legend() 51 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100) 52 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key))) 53 | -------------------------------------------------------------------------------- /src/glmtuner/extras/save_and_load.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import Dict, Optional 4 | 5 | from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME 6 | from transformers.modeling_utils import load_sharded_checkpoint 7 | 8 | from glmtuner.extras.constants import VALUE_HEAD_FILE_NAME 9 | from glmtuner.extras.logging import get_logger 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]: 16 | state_dict = model.state_dict() 17 | filtered_state_dict = {} 18 | 19 | for k, v in model.named_parameters(): 20 | if (not trainable_only) or v.requires_grad: 21 | filtered_state_dict[k] = state_dict[k].cpu().clone().detach() 22 | 23 | return filtered_state_dict 24 | 25 | 26 | def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: 27 | weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME) 28 | if os.path.exists(weights_file): 29 | model_state_dict = torch.load(weights_file, map_location="cpu") 30 | model.load_state_dict(model_state_dict, strict=False) # skip missing keys 31 | elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)): 32 | load_sharded_checkpoint(model, checkpoint_dir, strict=False) 33 | else: 34 | logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir)) 35 | return False 36 | return True 37 | 38 | 39 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool: 40 | valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME) 41 | if not os.path.exists(valuehead_file): 42 | logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir)) 43 | return False 44 | valuehead_state_dict = torch.load(valuehead_file, map_location="cpu") 45 | model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"]) 46 | model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"]) 47 | 
model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"])) 48 | model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"])) 49 | return True 50 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.hparams.data_args import DataArguments 2 | from glmtuner.hparams.finetuning_args import FinetuningArguments 3 | from glmtuner.hparams.general_args import GeneralArguments 4 | from glmtuner.hparams.generating_args import GeneratingArguments 5 | from glmtuner.hparams.model_args import ModelArguments 6 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/data_args.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | from typing import List, Optional 4 | from dataclasses import dataclass, field 5 | 6 | 7 | @dataclass 8 | class DatasetAttr: 9 | 10 | load_from: str 11 | dataset_name: str 12 | dataset_sha1: Optional[str] = None 13 | 14 | def __repr__(self) -> str: 15 | return self.dataset_name 16 | 17 | def __post_init__(self): 18 | self.prompt_column = "instruction" 19 | self.query_column = "input" 20 | self.response_column = "output" 21 | self.history_column = None 22 | 23 | 24 | @dataclass 25 | class DataArguments: 26 | """ 27 | Arguments pertaining to what data we are going to input our model for training and evaluation. 28 | """ 29 | dataset: Optional[str] = field( 30 | default="alpaca_zh", 31 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."} 32 | ) 33 | dataset_dir: Optional[str] = field( 34 | default="data", 35 | metadata={"help": "The name of the folder containing datasets."} 36 | ) 37 | split: Optional[str] = field( 38 | default="train", 39 | metadata={"help": "Which dataset split to use for training and evaluation."} 40 | ) 41 | overwrite_cache: Optional[bool] = field( 42 | default=False, 43 | metadata={"help": "Overwrite the cached training and evaluation sets."} 44 | ) 45 | preprocessing_num_workers: Optional[int] = field( 46 | default=None, 47 | metadata={"help": "The number of processes to use for the preprocessing."} 48 | ) 49 | max_source_length: Optional[int] = field( 50 | default=512, 51 | metadata={"help": "The maximum total input sequence length after tokenization."} 52 | ) 53 | max_target_length: Optional[int] = field( 54 | default=512, 55 | metadata={"help": "The maximum total output sequence length after tokenization."} 56 | ) 57 | max_samples: Optional[int] = field( 58 | default=None, 59 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."} 60 | ) 61 | eval_num_beams: Optional[int] = field( 62 | default=None, 63 | metadata={"help": "Number of beams to use for evaluation. 
This argument will be passed to `model.generate`"} 64 | ) 65 | ignore_pad_token_for_loss: Optional[bool] = field( 66 | default=True, 67 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."} 68 | ) 69 | source_prefix: Optional[str] = field( 70 | default=None, 71 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."} 72 | ) 73 | dev_ratio: Optional[float] = field( 74 | default=0, 75 | metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."} 76 | ) 77 | 78 | def init_for_training(self): # support mixing multiple datasets 79 | dataset_names = [ds.strip() for ds in self.dataset.split(",")] 80 | with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r", encoding="utf-8") as f: 81 | dataset_info = json.load(f) 82 | 83 | self.dataset_list: List[DatasetAttr] = [] 84 | for name in dataset_names: 85 | if name not in dataset_info: 86 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name)) 87 | 88 | if "hf_hub_url" in dataset_info[name]: 89 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"]) 90 | elif "script_url" in dataset_info[name]: 91 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"]) 92 | else: 93 | dataset_attr = DatasetAttr( 94 | "file", 95 | dataset_name=dataset_info[name]["file_name"], 96 | dataset_sha1=dataset_info[name].get("file_sha1", None) 97 | ) 98 | 99 | if "columns" in dataset_info[name]: 100 | dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None) 101 | dataset_attr.query_column = dataset_info[name]["columns"].get("query", None) 102 | dataset_attr.response_column = dataset_info[name]["columns"].get("response", None) 103 | dataset_attr.history_column = dataset_info[name]["columns"].get("history", None) 104 | 105 | self.dataset_list.append(dataset_attr) 106 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/finetuning_args.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Literal, Optional 3 | from dataclasses import asdict, dataclass, field 4 | 5 | 6 | @dataclass 7 | class FinetuningArguments: 8 | """ 9 | Arguments pertaining to which techniques we are going to fine-tuning with. 10 | """ 11 | finetuning_type: Optional[Literal["none", "freeze", "p_tuning", "lora", "full"]] = field( 12 | default="lora", 13 | metadata={"help": "Which fine-tuning method to use."} 14 | ) 15 | num_layer_trainable: Optional[int] = field( 16 | default=3, 17 | metadata={"help": "Number of trainable layers for Freeze fine-tuning."} 18 | ) 19 | name_module_trainable: Optional[Literal["mlp", "qkv"]] = field( 20 | default="mlp", 21 | metadata={"help": "Name of trainable modules for Freeze fine-tuning."} 22 | ) 23 | pre_seq_len: Optional[int] = field( 24 | default=64, 25 | metadata={"help": "Number of prefix tokens to use for P-tuning V2."} 26 | ) 27 | prefix_projection: Optional[bool] = field( 28 | default=False, 29 | metadata={"help": "Whether to add a project layer for the prefix in P-tuning V2 or not."} 30 | ) 31 | lora_rank: Optional[int] = field( 32 | default=8, 33 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."} 34 | ) 35 | lora_alpha: Optional[float] = field( 36 | default=32.0, 37 | metadata={"help": "The scale factor for LoRA fine-tuning. 
(similar with the learning rate)"} 38 | ) 39 | lora_dropout: Optional[float] = field( 40 | default=0.1, 41 | metadata={"help": "Dropout rate for the LoRA fine-tuning."} 42 | ) 43 | lora_target: Optional[str] = field( 44 | default="query_key_value", 45 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."} 46 | ) 47 | 48 | def __post_init__(self): 49 | if isinstance(self.lora_target, str): 50 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA 51 | 52 | if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0 53 | trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)] 54 | else: # fine-tuning the first n layers if num_layer_trainable < 0 55 | trainable_layer_ids = [k for k in range(-self.num_layer_trainable)] 56 | 57 | if self.name_module_trainable == "mlp": 58 | self.trainable_layers = ["{:d}.mlp".format(idx) for idx in trainable_layer_ids] 59 | elif self.name_module_trainable == "qkv": 60 | self.trainable_layers = ["{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids] 61 | 62 | assert self.finetuning_type in ["none", "freeze", "p_tuning", "lora", "full"], "Invalid fine-tuning method." 63 | 64 | def save_to_json(self, json_path: str) -> None: 65 | """Saves the content of this instance in JSON format inside `json_path`.""" 66 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n" 67 | with open(json_path, "w", encoding="utf-8") as f: 68 | f.write(json_string) 69 | 70 | @classmethod 71 | def load_from_json(cls, json_path: str): 72 | """Creates an instance from the content of `json_path`.""" 73 | with open(json_path, "r", encoding="utf-8") as f: 74 | text = f.read() 75 | return cls(**json.loads(text)) 76 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/general_args.py: -------------------------------------------------------------------------------- 1 | from typing import Literal, Optional 2 | from dataclasses import dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneralArguments: 7 | """ 8 | Arguments pertaining to which techniques we are going to fine-tuning with. 9 | """ 10 | stage: Optional[Literal["sft", "rm", "ppo"]] = field( 11 | default="sft", 12 | metadata={"help": "Which stage will be performed in training."} 13 | ) 14 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/generating_args.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, Optional 2 | from dataclasses import asdict, dataclass, field 3 | 4 | 5 | @dataclass 6 | class GeneratingArguments: 7 | """ 8 | Arguments pertaining to specify the decoding parameters. 
9 | """ 10 | do_sample: Optional[bool] = field( 11 | default=True, 12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."} 13 | ) 14 | temperature: Optional[float] = field( 15 | default=0.95, 16 | metadata={"help": "The value used to modulate the next token probabilities."} 17 | ) 18 | top_p: Optional[float] = field( 19 | default=0.7, 20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."} 21 | ) 22 | top_k: Optional[int] = field( 23 | default=50, 24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."} 25 | ) 26 | num_beams: Optional[int] = field( 27 | default=1, 28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."} 29 | ) 30 | max_length: Optional[int] = field( 31 | default=2048, 32 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."} 33 | ) 34 | max_new_tokens: Optional[int] = field( 35 | default=None, 36 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."} 37 | ) 38 | repetition_penalty: Optional[float] = field( 39 | default=1.0, 40 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."} 41 | ) 42 | 43 | def to_dict(self) -> Dict[str, Any]: 44 | return asdict(self) 45 | -------------------------------------------------------------------------------- /src/glmtuner/hparams/model_args.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Literal, Optional 3 | from dataclasses import dataclass, field 4 | 5 | 6 | @dataclass 7 | class ModelArguments: 8 | """ 9 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune. 
10 | """ 11 | model_name_or_path: Optional[str] = field( 12 | default="THUDM/chatglm-6b", 13 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."} 14 | ) 15 | config_name: Optional[str] = field( 16 | default=None, 17 | metadata={"help": "Pretrained config name or path if not the same as model_name."} 18 | ) 19 | tokenizer_name: Optional[str] = field( 20 | default=None, 21 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."} 26 | ) 27 | use_fast_tokenizer: Optional[bool] = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."} 30 | ) 31 | model_revision: Optional[str] = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."} 34 | ) 35 | use_auth_token: Optional[bool] = field( 36 | default=False, 37 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."} 38 | ) 39 | quantization_bit: Optional[int] = field( 40 | default=None, 41 | metadata={"help": "The number of bits to quantize the model."} 42 | ) 43 | quantization_type: Optional[Literal["fp4", "nf4"]] = field( 44 | default="nf4", 45 | metadata={"help": "Quantization data type to use in int4 training."} 46 | ) 47 | double_quantization: Optional[bool] = field( 48 | default=True, 49 | metadata={"help": "Whether to use double quantization in int4 training or not."} 50 | ) 51 | compute_dtype: Optional[torch.dtype] = field( 52 | default=None, 53 | metadata={"help": "Used in quantization configs. Do not specify this argument manually."} 54 | ) 55 | checkpoint_dir: Optional[str] = field( 56 | default=None, 57 | metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."} 58 | ) 59 | reward_model: Optional[str] = field( 60 | default=None, 61 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."} 62 | ) 63 | resume_lora_training: Optional[bool] = field( 64 | default=True, 65 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."} 66 | ) 67 | plot_loss: Optional[bool] = field( 68 | default=False, 69 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."} 70 | ) 71 | 72 | def __post_init__(self): 73 | if not self.checkpoint_dir: 74 | self.checkpoint_dir = None 75 | 76 | if self.checkpoint_dir is not None: # support merging lora weights 77 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")] 78 | 79 | if self.quantization_bit is not None: 80 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization." 
81 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer 2 | from glmtuner.tuner.sft import run_sft 3 | from glmtuner.tuner.rm import run_rm 4 | from glmtuner.tuner.ppo import run_ppo 5 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/core/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.tuner.core.parser import get_train_args, get_infer_args 2 | from glmtuner.tuner.core.loader import load_model_and_tokenizer 3 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/core/adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | from transformers.modeling_utils import PreTrainedModel 5 | from peft import ( 6 | PeftModel, 7 | TaskType, 8 | LoraConfig, 9 | get_peft_model 10 | ) 11 | from peft.utils import CONFIG_NAME, WEIGHTS_NAME 12 | 13 | from glmtuner.extras.logging import get_logger 14 | from glmtuner.extras.save_and_load import load_trainable_params 15 | from glmtuner.hparams import ModelArguments, FinetuningArguments 16 | 17 | 18 | logger = get_logger(__name__) 19 | 20 | 21 | def init_adapter( 22 | model: PreTrainedModel, 23 | model_args: ModelArguments, 24 | finetuning_args: FinetuningArguments, 25 | is_trainable: bool 26 | ) -> PreTrainedModel: 27 | r""" 28 | Initializes the adapters. 29 | 30 | Support full-parameter, freeze, P-Tuning v2 and LoRA training. 31 | 32 | Note that the trainable parameters must be cast to float32. 33 | """ 34 | 35 | if finetuning_args.finetuning_type == "none" and is_trainable: 36 | raise ValueError("You cannot use finetuning_type=none while training.") 37 | 38 | if finetuning_args.finetuning_type == "full": 39 | logger.info("Fine-tuning method: Full") 40 | model = model.float() 41 | 42 | if finetuning_args.finetuning_type == "freeze": 43 | logger.info("Fine-tuning method: Freeze") 44 | 45 | for name, param in model.named_parameters(): 46 | if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers): 47 | param.requires_grad_(False) 48 | else: 49 | param.data = param.data.to(torch.float32) 50 | 51 | if model_args.checkpoint_dir is not None: 52 | assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded." 53 | 54 | if finetuning_args.finetuning_type == "p_tuning": 55 | logger.info("Fine-tuning method: P-Tuning v2") 56 | 57 | if model_args.checkpoint_dir is not None: 58 | assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded." 59 | 60 | if finetuning_args.finetuning_type == "lora": 61 | logger.info("Fine-tuning method: LoRA") 62 | latest_checkpoint = None 63 | 64 | if model_args.checkpoint_dir is not None: 65 | assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \ 66 | "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]) 67 | assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \ 68 | "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/p_tuning/freeze` instead." 
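# Multi-checkpoint handling: when several LoRA checkpoints are given via checkpoint_dir,
# every checkpoint except the newest is merged into the base weights with merge_and_unload();
# the newest one is reloaded as a trainable adapter only if resume_lora_training is enabled
# during training. Otherwise all checkpoints are merged, and (when training) a fresh LoRA
# adapter is created from finetuning_args.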
69 | 70 | if is_trainable and model_args.resume_lora_training: # continually train on the lora weights 71 | checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1] 72 | else: 73 | checkpoints_to_merge = model_args.checkpoint_dir 74 | 75 | for checkpoint in checkpoints_to_merge: 76 | model = PeftModel.from_pretrained(model, checkpoint) 77 | model = model.merge_and_unload() 78 | 79 | if len(checkpoints_to_merge) > 0: 80 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge))) 81 | 82 | if latest_checkpoint is not None: # resume lora training 83 | model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True) 84 | 85 | if is_trainable and latest_checkpoint is None: # create new lora weights while training 86 | lora_config = LoraConfig( 87 | task_type=TaskType.CAUSAL_LM, # we should regard ChatGLM as a causal LM 88 | inference_mode=False, 89 | r=finetuning_args.lora_rank, 90 | lora_alpha=finetuning_args.lora_alpha, 91 | lora_dropout=finetuning_args.lora_dropout, 92 | target_modules=finetuning_args.lora_target 93 | ) 94 | model = get_peft_model(model, lora_config) 95 | 96 | if model_args.checkpoint_dir is not None: 97 | logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir))) 98 | 99 | return model 100 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/core/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import Literal, Optional, Tuple 4 | 5 | from transformers import ( 6 | AutoConfig, 7 | AutoModel, 8 | AutoTokenizer, 9 | BitsAndBytesConfig 10 | ) 11 | from transformers.utils import check_min_version 12 | from transformers.utils.versions import require_version 13 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel 14 | from transformers.tokenization_utils import PreTrainedTokenizerBase 15 | from trl import AutoModelForCausalLMWithValueHead 16 | 17 | from glmtuner.extras.logging import get_logger 18 | from glmtuner.extras.misc import prepare_model_for_training, print_trainable_params 19 | from glmtuner.extras.save_and_load import load_valuehead_params 20 | from glmtuner.hparams import ModelArguments, FinetuningArguments 21 | from glmtuner.tuner.core.adapter import init_adapter 22 | 23 | 24 | logger = get_logger(__name__) 25 | 26 | 27 | check_min_version("4.29.1") 28 | require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0") 29 | require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0") 30 | require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0") 31 | require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7") 32 | 33 | 34 | def load_model_and_tokenizer( 35 | model_args: ModelArguments, 36 | finetuning_args: FinetuningArguments, 37 | is_trainable: Optional[bool] = False, 38 | stage: Optional[Literal["sft", "rm", "ppo"]] = "sft" 39 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]: 40 | r""" 41 | Loads pretrained model and tokenizer. 42 | 43 | Support both training and inference. 
44 | """ 45 | 46 | if (not is_trainable) and model_args.checkpoint_dir is None: 47 | logger.warning("Checkpoint is not found at evaluation, load the original model.") 48 | finetuning_args = FinetuningArguments(finetuning_type="none") 49 | 50 | assert stage == "sft" or finetuning_args.finetuning_type == "lora", \ 51 | "RM and PPO training can only be performed with LoRA method." 52 | 53 | if model_args.quantization_bit is not None: 54 | if is_trainable and finetuning_args.finetuning_type == "lora": 55 | quantization = "bnb" # use bnb's quantization 56 | else: 57 | quantization = "cpm" # use cpm's quantization 58 | else: 59 | quantization = None 60 | 61 | config_kwargs = { 62 | "trust_remote_code": True, 63 | "cache_dir": model_args.cache_dir, 64 | "revision": model_args.model_revision, 65 | "use_auth_token": True if model_args.use_auth_token else None, 66 | } 67 | 68 | tokenizer = AutoTokenizer.from_pretrained( 69 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path, 70 | use_fast=model_args.use_fast_tokenizer, 71 | padding_side="left", 72 | **config_kwargs 73 | ) 74 | 75 | config = AutoConfig.from_pretrained( 76 | model_args.config_name if model_args.config_name else model_args.model_name_or_path, 77 | **config_kwargs 78 | ) 79 | 80 | # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM. 81 | if finetuning_args.finetuning_type == "p_tuning": 82 | config.pre_seq_len = finetuning_args.pre_seq_len # enable this will fix other parameters automatically 83 | config.prefix_projection = finetuning_args.prefix_projection 84 | 85 | # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library). 86 | if quantization == "bnb": 87 | if model_args.quantization_bit == 8: 88 | require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") 89 | config_kwargs["load_in_8bit"] = True 90 | config_kwargs["quantization_config"] = BitsAndBytesConfig( 91 | load_in_8bit=True, 92 | llm_int8_threshold=6.0 93 | ) 94 | elif model_args.quantization_bit == 4: 95 | require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") 96 | config_kwargs["load_in_4bit"] = True 97 | config_kwargs["quantization_config"] = BitsAndBytesConfig( 98 | load_in_4bit=True, 99 | bnb_4bit_compute_dtype=model_args.compute_dtype, 100 | bnb_4bit_use_double_quant=model_args.double_quantization, 101 | bnb_4bit_quant_type=model_args.quantization_type 102 | ) 103 | config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)} 104 | 105 | if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full": 106 | model_to_load = model_args.checkpoint_dir[0] 107 | else: 108 | model_to_load = model_args.model_name_or_path 109 | 110 | # Load and prepare pretrained models (without valuehead). 111 | model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs) 112 | 113 | # Register auto class to save the custom code files. 
114 | if isinstance(config, PretrainedConfig): 115 | config.__class__.register_for_auto_class() 116 | if isinstance(tokenizer, PreTrainedTokenizerBase): 117 | tokenizer.__class__.register_for_auto_class() 118 | if isinstance(model, PreTrainedModel): 119 | model.__class__.register_for_auto_class() 120 | 121 | if tokenizer.eos_token_id == 130005: # ChatGLM-6B 122 | output_embedding_base_layer = model 123 | output_embedding_layer_name = "lm_head" 124 | elif tokenizer.eos_token_id == 2: # ChatGLM2-6B 125 | assert hasattr(model, "transformer"), "Please update the model files of ChatGLM-6B." 126 | model.lm_head = model.transformer.output_layer 127 | output_embedding_base_layer = model.transformer 128 | output_embedding_layer_name = "output_layer" 129 | else: 130 | raise ValueError("Please update the model files of ChatGLM2-6B.") 131 | 132 | # Initialize adapters 133 | model = prepare_model_for_training( 134 | model, 135 | finetuning_args.finetuning_type, 136 | output_embedding_base_layer, 137 | output_embedding_layer_name 138 | ) if is_trainable else model 139 | model = init_adapter(model, model_args, finetuning_args, is_trainable) 140 | 141 | if not is_trainable: 142 | model.requires_grad_(False) # fix all model params 143 | model = model.half() # cast all params to float16 for inference 144 | 145 | # Quantization with the built-in method for P-Tuning v2 training or evaluation. 146 | # Model parameters should be cast to float16 in quantized P-Tuning setting. 147 | if quantization == "cpm": 148 | if is_trainable: # convert all params into half precision except prefix_encoder in training 149 | for name, param in model.named_parameters(): 150 | if "prefix_encoder" not in name: 151 | param.data = param.data.to(torch.float16) 152 | 153 | model.quantize(model_args.quantization_bit) # built-in method in ChatGLM-6B, also an in-place operation 154 | 155 | if quantization is not None: 156 | logger.info("Quantized model to {} bit.".format(model_args.quantization_bit)) 157 | 158 | if stage == "rm" or stage == "ppo": # add value head 159 | model = AutoModelForCausalLMWithValueHead.from_pretrained(model) 160 | 161 | if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model 162 | logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.") 163 | if load_valuehead_params(model, model_args.checkpoint_dir[-1]): 164 | model.v_head.load_state_dict({ 165 | "summary.weight": getattr(model, "reward_head_weight"), 166 | "summary.bias": getattr(model, "reward_head_bias") 167 | }) 168 | 169 | if stage == "ppo": # load reward model 170 | assert is_trainable, "PPO stage cannot be performed at evaluation." 171 | assert model_args.reward_model is not None, "Reward model is necessary for PPO training." 172 | logger.info("Load reward model from {}".format(model_args.reward_model)) 173 | model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False) 174 | assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded." 
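# At this point the backbone has been adapted (and optionally quantized); for the "rm" and
# "ppo" stages it is additionally wrapped with a value head. A minimal usage sketch
# (an assumption, not taken from this file) using the helpers exported by glmtuner.tuner:
#     model_args, data_args, finetuning_args, generating_args = get_infer_args()
#     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False)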
175 | 176 | print_trainable_params(model) 177 | 178 | return model, tokenizer 179 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/core/parser.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import torch 4 | import datasets 5 | import transformers 6 | from typing import Any, Dict, Optional, Tuple 7 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments 8 | 9 | from glmtuner.extras.logging import get_logger 10 | from glmtuner.hparams import ( 11 | ModelArguments, 12 | DataArguments, 13 | FinetuningArguments, 14 | GeneratingArguments, 15 | GeneralArguments 16 | ) 17 | 18 | 19 | logger = get_logger(__name__) 20 | 21 | 22 | def get_train_args( 23 | args: Optional[Dict[str, Any]] = None 24 | ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]: 25 | 26 | parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments)) 27 | 28 | if args is not None: 29 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args) 30 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 31 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 32 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 33 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) 34 | else: 35 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses() 36 | 37 | # Setup logging 38 | if training_args.should_log: 39 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 40 | transformers.utils.logging.set_verbosity_info() 41 | 42 | log_level = training_args.get_process_log_level() 43 | datasets.utils.logging.set_verbosity(log_level) 44 | transformers.utils.logging.set_verbosity(log_level) 45 | transformers.utils.logging.enable_default_handler() 46 | transformers.utils.logging.enable_explicit_format() 47 | 48 | # Check arguments (do not check finetuning_args since it may be loaded from checkpoints) 49 | data_args.init_for_training() 50 | 51 | assert general_args.stage == "sft" or (not training_args.predict_with_generate), \ 52 | "`predict_with_generate` cannot be set as True at PT, RM and PPO stages." 53 | 54 | assert not (training_args.do_train and training_args.predict_with_generate), \ 55 | "`predict_with_generate` cannot be set as True while training." 56 | 57 | assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \ 58 | "Please enable `predict_with_generate` to save model predictions." 59 | 60 | if model_args.quantization_bit is not None: 61 | assert finetuning_args.finetuning_type != "full" and finetuning_args.finetuning_type != "freeze", \ 62 | "Quantization is incompatible with the full-parameter and freeze tuning." 63 | 64 | assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \ 65 | "FP16 training conflicts with quantized P-Tuning." 
66 | 67 | if not training_args.do_train: 68 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") 69 | 70 | assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \ 71 | or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." 72 | 73 | if training_args.do_train and (not training_args.fp16): 74 | logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.") 75 | 76 | if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None: 77 | logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.") 78 | training_args.ddp_find_unused_parameters = False 79 | 80 | training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning 81 | 82 | if model_args.quantization_bit is not None: 83 | if training_args.fp16: 84 | model_args.compute_dtype = torch.float16 85 | elif training_args.bf16: 86 | model_args.compute_dtype = torch.bfloat16 87 | else: 88 | model_args.compute_dtype = torch.float32 89 | 90 | # Log on each process the small summary: 91 | logger.info( 92 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n" 93 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 94 | ) 95 | logger.info(f"Training/evaluation parameters {training_args}") 96 | 97 | # Set seed before initializing model. 98 | transformers.set_seed(training_args.seed) 99 | 100 | return model_args, data_args, training_args, finetuning_args, general_args 101 | 102 | 103 | def get_infer_args( 104 | args: Optional[Dict[str, Any]] = None 105 | ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]: 106 | 107 | parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments)) 108 | 109 | if args is not None: 110 | model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args) 111 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): 112 | model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1])) 113 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 114 | model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1])) 115 | else: 116 | model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses() 117 | 118 | assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \ 119 | or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints." 
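# Like get_train_args(), get_infer_args() also accepts a plain dict in place of CLI flags,
# e.g. (illustrative values only):
#     get_infer_args({"model_name_or_path": "THUDM/chatglm2-6b", "checkpoint_dir": "path/to/lora_checkpoint"})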
120 | 121 | return model_args, data_args, finetuning_args, generating_args 122 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/core/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from typing import Dict, Optional 4 | 5 | from transformers import Seq2SeqTrainer 6 | from transformers.trainer import TRAINING_ARGS_NAME 7 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 8 | from peft import PeftModel 9 | 10 | from glmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME 11 | from glmtuner.extras.logging import get_logger 12 | from glmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params 13 | from glmtuner.hparams import FinetuningArguments 14 | 15 | 16 | logger = get_logger(__name__) 17 | 18 | 19 | class PeftTrainer(Seq2SeqTrainer): 20 | r""" 21 | Inherits Seq2SeqTrainer to support parameter-efficient checkpoints. 22 | """ 23 | 24 | def __init__(self, finetuning_args: FinetuningArguments, **kwargs): 25 | super().__init__(**kwargs) 26 | self.finetuning_args = finetuning_args 27 | self._remove_log() 28 | 29 | def _remove_log(self): 30 | if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")): 31 | logger.warning("Previous log file in this folder will be deleted.") 32 | os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl")) 33 | 34 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None: 35 | r""" 36 | Saves trainable parameters as model checkpoint. 37 | 38 | This function will only be executed at the process zero. 39 | 40 | Subclass and override to inject custom behavior. It should not be directly used by external scripts. 41 | """ 42 | output_dir = output_dir if output_dir is not None else self.args.output_dir 43 | os.makedirs(output_dir, exist_ok=True) 44 | logger.info(f"Saving model checkpoint to {output_dir}") 45 | model = unwrap_model(self.model) 46 | 47 | if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only) 48 | backbone_model = getattr(model, "pretrained_model") 49 | torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME)) 50 | else: 51 | backbone_model = model 52 | 53 | if isinstance(backbone_model, PeftModel): # LoRA tuning 54 | backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model)) 55 | elif isinstance(backbone_model, PreTrainedModel): # freeze/full-tuning or p_tuning 56 | backbone_model.config.use_cache = True 57 | backbone_model.save_pretrained( 58 | output_dir, 59 | state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")), 60 | safe_serialization=self.args.save_safetensors 61 | ) 62 | backbone_model.config.use_cache = False 63 | if self.tokenizer is not None: 64 | self.tokenizer.save_pretrained(output_dir) 65 | else: 66 | logger.warning("No model to save.") 67 | 68 | with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f: 69 | f.write(self.args.to_json_string() + "\n") 70 | self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME)) 71 | 72 | def _load_best_model(self): 73 | r""" 74 | Loads trainable parameters from model checkpoint. 75 | 76 | Subclass and override to inject custom behavior. 
It should not be directly used by external scripts. 77 | """ 78 | logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).") 79 | 80 | model = unwrap_model(self.model) 81 | backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model 82 | 83 | if isinstance(backbone_model, PeftModel): 84 | backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter) 85 | if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint): 86 | model.v_head.load_state_dict({ 87 | "summary.weight": getattr(model, "reward_head_weight"), 88 | "summary.bias": getattr(model, "reward_head_bias") 89 | }) 90 | else: # freeze/full-tuning or p_tuning 91 | load_trainable_params(backbone_model, self.state.best_model_checkpoint) 92 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.tuner.ppo.workflow import run_ppo 2 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/ppo/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import torch 4 | from tqdm import tqdm 5 | from typing import Callable, Dict, List, Optional 6 | 7 | from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl 8 | from transformers.modeling_utils import PreTrainedModel 9 | 10 | from trl import PPOTrainer, AutoModelForCausalLMWithValueHead 11 | from trl.core import LengthSampler 12 | from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits 13 | 14 | from glmtuner.extras.callbacks import LogCallback 15 | from glmtuner.extras.logging import get_logger 16 | from glmtuner.extras.misc import AverageMeter, get_logits_processor 17 | from glmtuner.hparams import FinetuningArguments 18 | from glmtuner.tuner.core.trainer import PeftTrainer 19 | from glmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model 20 | 21 | 22 | logger = get_logger(__name__) 23 | 24 | 25 | class PPOTrainerForChatGLM(PPOTrainer, PeftTrainer): 26 | r""" 27 | Inherits PPOTrainer. 28 | """ 29 | def __init__( 30 | self, 31 | training_args: Seq2SeqTrainingArguments, 32 | finetuning_args: FinetuningArguments, 33 | callbacks: List[LogCallback], 34 | **kwargs 35 | ): 36 | PPOTrainer.__init__(self, **kwargs) 37 | self.args = training_args 38 | self.finetuning_args = finetuning_args 39 | self.log_callback = callbacks[0] 40 | self.state = TrainerState() 41 | self.control = TrainerControl() 42 | self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer 43 | self._remove_log() 44 | 45 | def ppo_train(self, max_target_length: int) -> None: 46 | r""" 47 | Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer. 
48 | """ 49 | total_train_batch_size = ( 50 | self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size 51 | ) 52 | len_dataloader = len(self.dataloader) 53 | num_examples = len(self.dataset) 54 | num_train_epochs = self.args.num_train_epochs 55 | max_steps = math.ceil(num_train_epochs * len_dataloader) 56 | 57 | self.state.max_steps = max_steps 58 | self.state.num_train_epochs = num_train_epochs 59 | self.state.is_local_process_zero = self.is_local_process_zero() 60 | self.state.is_world_process_zero = self.is_world_process_zero() 61 | 62 | if self.is_world_process_zero(): 63 | logger.info("***** Running training *****") 64 | logger.info(f" Num examples = {num_examples}") 65 | logger.info(f" Num Epochs = {num_train_epochs}") 66 | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}") 67 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}") 68 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}") 69 | logger.info(f" Total optimization steps = {max_steps}") 70 | logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}") 71 | 72 | # Keyword arguments for `model.generate` 73 | gen_kwargs = { 74 | "top_k": 0.0, 75 | "top_p": 1.0, 76 | "do_sample": True, 77 | "pad_token_id": self.tokenizer.pad_token_id, 78 | "eos_token_id": self.tokenizer.eos_token_id, 79 | "logits_processor": get_logits_processor() 80 | } 81 | length_sampler = LengthSampler(max_target_length // 2, max_target_length) 82 | unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model) 83 | 84 | dataiter = iter(self.dataloader) 85 | steps_trained = 0 86 | loss_meter = AverageMeter() 87 | reward_meter = AverageMeter() 88 | self.log_callback.on_train_begin(self.args, self.state, self.control) 89 | 90 | for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False): 91 | batch = next(dataiter) 92 | steps_trained += 1 93 | 94 | unwrapped_model.gradient_checkpointing_disable() 95 | unwrapped_model.config.use_cache = True 96 | 97 | # Get responses 98 | query_tensors = batch["input_ids"] 99 | response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs) 100 | 101 | queries, responses = [], [] 102 | for i in range(len(query_tensors)): 103 | query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0] 104 | response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1 105 | queries.append(query_tensors[i, query_length:]) # remove padding from left 106 | responses.append(response_tensors[i, :response_length]) # remove padding from right 107 | 108 | # Compute rewards 109 | replace_model(unwrapped_model, target="reward") 110 | with torch.no_grad(): 111 | _, _, values = self.model( 112 | **self.prepare_model_inputs(queries, responses), 113 | output_hidden_states=True, 114 | return_dict=True 115 | ) 116 | rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type 117 | replace_model(unwrapped_model, target="default") 118 | 119 | # Run PPO step 120 | unwrapped_model.gradient_checkpointing_enable() 121 | unwrapped_model.config.use_cache = False 122 | stats = self.step(queries, responses, rewards) 123 | 124 | loss_meter.update(stats["ppo/loss/total"], n=len(rewards)) 125 | reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards)) 126 | 127 | if 
self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0: 128 | logs = dict( 129 | loss=round(loss_meter.avg, 4), 130 | reward=round(reward_meter.avg, 4), 131 | learning_rate=stats["ppo/learning_rate"], 132 | epoch=round(step / len_dataloader, 2) 133 | ) 134 | print(logs) 135 | logs["step"] = step 136 | self.state.log_history.append(logs) 137 | self.log_callback.on_log(self.args, self.state, self.control) 138 | loss_meter.reset() 139 | reward_meter.reset() 140 | 141 | if (step+1) % self.args.save_steps == 0: # save checkpoint 142 | self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}")) 143 | 144 | if self.control.should_epoch_stop or self.control.should_training_stop: 145 | break 146 | 147 | if steps_trained == len_dataloader: 148 | dataiter = iter(self.dataloader) 149 | steps_trained = 0 150 | 151 | @torch.no_grad() 152 | def generate( 153 | self, 154 | inputs: Dict[str, torch.Tensor], 155 | length_sampler: Optional[Callable] = None, 156 | return_prompt: Optional[bool] = True, 157 | **generation_kwargs 158 | ) -> torch.Tensor: 159 | r""" 160 | Generates model's responses given queries. 161 | 162 | Subclass and override to inject custom behavior. 163 | """ 164 | self.model, layer_norm_params = cast_layernorm_dtype(self.model) 165 | 166 | if length_sampler is not None: 167 | generation_kwargs["max_new_tokens"] = length_sampler() 168 | 169 | unwrapped_model = self.accelerator.unwrap_model(self.model) 170 | 171 | response = unwrapped_model.generate(**inputs, **generation_kwargs) 172 | 173 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop 174 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273 175 | if unwrapped_model.pretrained_model.generation_config._from_model_config: 176 | unwrapped_model.pretrained_model.generation_config._from_model_config = False 177 | 178 | self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params) 179 | 180 | if not return_prompt and not self.is_encoder_decoder: 181 | return response[:, inputs["input_ids"].size(1):] 182 | return response 183 | 184 | @PPODecorators.empty_cuda_cache() 185 | def batched_forward_pass( 186 | self, 187 | model: AutoModelForCausalLMWithValueHead, 188 | queries: torch.Tensor, 189 | responses: torch.Tensor, 190 | model_inputs: dict, 191 | return_logits: bool = False 192 | ): 193 | r""" 194 | Calculates model outputs in multiple batches. 195 | 196 | Subclass and override to inject custom behavior. 
197 | """ 198 | bs = len(queries) 199 | fbs = self.config.mini_batch_size 200 | all_logprobs = [] 201 | all_logits = [] 202 | all_masks = [] 203 | all_values = [] 204 | 205 | for i in range(int(bs / fbs)): 206 | input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} 207 | query_batch = queries[i * fbs : (i + 1) * fbs] 208 | response_batch = responses[i * fbs : (i + 1) * fbs] 209 | input_ids = input_kwargs["input_ids"] # left-padded sequences 210 | 211 | if self.is_distributed: # re-generate them to adapt padded inputs 212 | input_kwargs["attention_mask"] = self.data_collator.get_attention_masks(input_ids, device=self.current_device) 213 | input_kwargs["position_ids"] = self.data_collator.get_position_ids(input_ids, device=self.current_device) 214 | 215 | logits, _, values = model(**input_kwargs, output_hidden_states=True, return_dict=True) 216 | logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) 217 | values = values.transpose(0, 1) 218 | masks = torch.zeros_like(input_ids) 219 | 220 | for j in range(fbs): 221 | start = len(query_batch[j]) - 1 222 | start += (input_ids[j] != self.tokenizer.pad_token_id).nonzero()[0].item() 223 | end = start + len(response_batch[j]) 224 | masks[j][start:end] = 1 225 | 226 | if return_logits: 227 | all_logits.append(logits) 228 | else: 229 | del logits 230 | all_values.append(values) 231 | all_logprobs.append(logprobs) 232 | all_masks.append(masks) 233 | 234 | return ( 235 | torch.cat(all_logprobs), 236 | torch.cat(all_logits)[:, :-1] if return_logits else None, 237 | torch.cat(all_values)[:, :-1], 238 | torch.cat(all_masks)[:, :-1], 239 | ) 240 | 241 | def save_model(self, output_dir: Optional[str] = None) -> None: 242 | r""" 243 | Saves model checkpoint. 244 | 245 | Subclass and override to inject custom behavior. 
246 | """ 247 | 248 | if self.args.should_save: 249 | self._save(output_dir) 250 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/ppo/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Dict, List, Literal, Optional, Tuple 3 | from trl import AutoModelForCausalLMWithValueHead 4 | 5 | from glmtuner.extras.constants import LAYERNORM_NAMES 6 | 7 | 8 | def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None: 9 | if target == "reward": # save default head temporarily 10 | valuehead_state_dict = model.v_head.state_dict() 11 | setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"]) 12 | setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"]) 13 | 14 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active 15 | model.v_head.load_state_dict({ 16 | "summary.weight": getattr(model, "{}_head_weight".format(target)), 17 | "summary.bias": getattr(model, "{}_head_bias".format(target)) 18 | }) 19 | 20 | 21 | def cast_layernorm_dtype( 22 | model: AutoModelForCausalLMWithValueHead, 23 | layer_norm_names: List[str] = LAYERNORM_NAMES, 24 | layer_norm_params: Optional[Dict[str, torch.Tensor]] = None 25 | ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]: 26 | 27 | layer_norm_state_dict = {} 28 | 29 | for name, param in model.named_parameters(): 30 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names): 31 | if layer_norm_params is not None: 32 | param.data = layer_norm_params[name] # restore float32 weights 33 | else: 34 | layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability 35 | param.data = param.data.to(torch.float16) 36 | 37 | return model, layer_norm_state_dict 38 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/ppo/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: 2 | # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py 3 | 4 | import math 5 | from trl import PPOConfig 6 | from torch.optim import AdamW 7 | from typing import Optional, List 8 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 9 | from transformers.optimization import get_scheduler 10 | 11 | from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset 12 | from glmtuner.extras.callbacks import LogCallback 13 | from glmtuner.extras.ploting import plot_loss 14 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 15 | from glmtuner.tuner.core import load_model_and_tokenizer 16 | from glmtuner.tuner.ppo.trainer import PPOTrainerForChatGLM 17 | 18 | 19 | def run_ppo( 20 | model_args: ModelArguments, 21 | data_args: DataArguments, 22 | training_args: Seq2SeqTrainingArguments, 23 | finetuning_args: FinetuningArguments, 24 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()] 25 | ): 26 | dataset = get_dataset(model_args, data_args) 27 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo") 28 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo") 29 | data_collator = DataCollatorForChatGLM(tokenizer, model.pretrained_model) 30 | 31 | ppo_config = PPOConfig( 
32 | model_name=model_args.model_name_or_path, 33 | learning_rate=training_args.learning_rate, 34 | mini_batch_size=training_args.per_device_train_batch_size, 35 | batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps, 36 | gradient_accumulation_steps=training_args.gradient_accumulation_steps, 37 | ppo_epochs=1, 38 | max_grad_norm=training_args.max_grad_norm 39 | ) 40 | 41 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate) 42 | total_train_batch_size = \ 43 | training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size 44 | lr_scheduler = get_scheduler( 45 | training_args.lr_scheduler_type, 46 | optimizer=optimizer, 47 | num_warmup_steps=training_args.warmup_steps, 48 | num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)) 49 | ) 50 | 51 | # Initialize our Trainer 52 | ppo_trainer = PPOTrainerForChatGLM( 53 | training_args=training_args, 54 | finetuning_args=finetuning_args, 55 | callbacks=callbacks, 56 | config=ppo_config, 57 | model=model, 58 | ref_model=None, 59 | tokenizer=tokenizer, 60 | dataset=dataset, 61 | data_collator=data_collator, 62 | optimizer=optimizer, 63 | lr_scheduler=lr_scheduler 64 | ) 65 | 66 | ppo_trainer.ppo_train(max_target_length=data_args.max_target_length) 67 | ppo_trainer.save_model() 68 | ppo_trainer.save_state() # must be after save_model 69 | if ppo_trainer.is_world_process_zero() and model_args.plot_loss: 70 | plot_loss(training_args.output_dir, keys=["loss", "reward"]) 71 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/rm/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.tuner.rm.workflow import run_rm 2 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/rm/collator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Any, Dict, Sequence 3 | 4 | from glmtuner.dsets import DataCollatorForChatGLM 5 | 6 | 7 | class PairwiseDataCollatorForChatGLM(DataCollatorForChatGLM): 8 | r""" 9 | Data collator for pairwise data. 10 | """ 11 | 12 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]: 13 | r""" 14 | Pads batched data to the longest sequence in the batch. 15 | 16 | We generate 2 * n examples where the first n examples represent chosen examples and 17 | the last n examples represent rejected examples. 
18 | """ 19 | 20 | features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features] 21 | return super().__call__(features) 22 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/rm/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from typing import Dict, Sequence, Tuple, Union 3 | 4 | 5 | def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 6 | preds, _ = eval_preds 7 | return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])} 8 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/rm/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | from typing import Dict, List, Optional, Tuple, Union 5 | from transformers.trainer import PredictionOutput 6 | from transformers.modeling_utils import PreTrainedModel 7 | 8 | from glmtuner.extras.logging import get_logger 9 | from glmtuner.tuner.core.trainer import PeftTrainer 10 | 11 | 12 | logger = get_logger(__name__) 13 | 14 | 15 | class PairwiseTrainerForChatGLM(PeftTrainer): 16 | r""" 17 | Inherits PeftTrainer to compute pairwise loss. 18 | """ 19 | 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | self.can_return_loss = True # override property to return eval_loss 23 | 24 | def compute_loss( 25 | self, 26 | model: PreTrainedModel, 27 | inputs: Dict[str, torch.Tensor], 28 | return_outputs: Optional[bool] = False 29 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]: 30 | r""" 31 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected. 32 | 33 | We use score on the EOS token to represent reward of the whole sentence. 34 | 35 | Subclass and override to inject custom behavior. It should not be directly used by external scripts. 36 | 37 | Note that the first element will be removed from the output tuple. 38 | 39 | See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509 40 | """ 41 | batch_size = inputs["input_ids"].size(0) // 2 42 | _, _, values = model(**inputs, output_hidden_states=True, return_dict=True) 43 | r_accept, r_reject = values[-1].split(batch_size, dim=0) 44 | loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean() 45 | return (loss, [loss, r_accept, r_reject]) if return_outputs else loss 46 | 47 | def save_predictions( 48 | self, 49 | predict_results: PredictionOutput 50 | ) -> None: 51 | r""" 52 | Saves model predictions to `output_dir`. 53 | A custom behavior that not contained in Seq2SeqTrainer. 
54 | """ 55 | if not self.is_world_process_zero(): 56 | return 57 | 58 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 59 | logger.info(f"Saving prediction results to {output_prediction_file}") 60 | 61 | acc_scores, rej_scores = predict_results.predictions 62 | 63 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 64 | res: List[str] = [] 65 | for acc_score, rej_score in zip(acc_scores, rej_scores): 66 | res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)})) 67 | writer.write("\n".join(res)) 68 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/rm/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: 2 | # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py 3 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py 4 | 5 | from typing import Optional, List 6 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 7 | 8 | from glmtuner.dsets import get_dataset, preprocess_dataset, split_dataset 9 | from glmtuner.extras.callbacks import LogCallback 10 | from glmtuner.extras.ploting import plot_loss 11 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 12 | from glmtuner.tuner.core import load_model_and_tokenizer 13 | from glmtuner.tuner.rm.metric import compute_accuracy 14 | from glmtuner.tuner.rm.collator import PairwiseDataCollatorForChatGLM 15 | from glmtuner.tuner.rm.trainer import PairwiseTrainerForChatGLM 16 | 17 | 18 | def run_rm( 19 | model_args: ModelArguments, 20 | data_args: DataArguments, 21 | training_args: Seq2SeqTrainingArguments, 22 | finetuning_args: FinetuningArguments, 23 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()] 24 | ): 25 | dataset = get_dataset(model_args, data_args) 26 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm") 27 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm") 28 | data_collator = PairwiseDataCollatorForChatGLM(tokenizer, model.pretrained_model) 29 | 30 | training_args.remove_unused_columns = False # Important for pairwise dataset 31 | 32 | # Initialize our Trainer 33 | trainer = PairwiseTrainerForChatGLM( 34 | finetuning_args=finetuning_args, 35 | model=model, 36 | args=training_args, 37 | tokenizer=tokenizer, 38 | data_collator=data_collator, 39 | callbacks=callbacks, 40 | compute_metrics=compute_accuracy, 41 | **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) 42 | ) 43 | 44 | # Training 45 | if training_args.do_train: 46 | train_result = trainer.train() 47 | trainer.log_metrics("train", train_result.metrics) 48 | trainer.save_metrics("train", train_result.metrics) 49 | trainer.save_state() 50 | trainer.save_model() 51 | if trainer.is_world_process_zero() and model_args.plot_loss: 52 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 53 | 54 | # Evaluation 55 | if training_args.do_eval: 56 | metrics = trainer.evaluate(metric_key_prefix="eval") 57 | trainer.log_metrics("eval", metrics) 58 | trainer.save_metrics("eval", metrics) 59 | 60 | # Predict 61 | if training_args.do_predict: 62 | predict_results = trainer.predict(dataset, metric_key_prefix="predict") 63 | trainer.log_metrics("predict", predict_results.metrics) 64 | 
trainer.save_metrics("predict", predict_results.metrics) 65 | trainer.save_predictions(predict_results) 66 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/sft/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.tuner.sft.workflow import run_sft 2 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/sft/metric.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from dataclasses import dataclass 3 | from typing import Dict, Sequence, Tuple, Union 4 | from transformers.tokenization_utils import PreTrainedTokenizer 5 | 6 | import jieba 7 | from rouge_chinese import Rouge 8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction 9 | 10 | from glmtuner.extras.constants import IGNORE_INDEX 11 | 12 | 13 | @dataclass 14 | class ComputeMetrics: 15 | r""" 16 | Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM. 17 | """ 18 | 19 | tokenizer: PreTrainedTokenizer 20 | 21 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: 22 | r""" 23 | Uses the model predictions to compute metrics. 24 | """ 25 | preds, labels = eval_preds 26 | score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} 27 | 28 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) 29 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) 30 | 31 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 32 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 33 | 34 | for pred, label in zip(decoded_preds, decoded_labels): 35 | hypothesis = list(jieba.cut(pred)) 36 | reference = list(jieba.cut(label)) 37 | 38 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: 39 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} 40 | else: 41 | rouge = Rouge() 42 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) 43 | result = scores[0] 44 | 45 | for k, v in result.items(): 46 | score_dict[k].append(round(v["f"] * 100, 4)) 47 | 48 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) 49 | score_dict["bleu-4"].append(round(bleu_score * 100, 4)) 50 | score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label)) 51 | 52 | return {k: float(np.mean(v)) for k, v in score_dict.items()} 53 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/sft/trainer.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import torch 4 | import numpy as np 5 | import torch.nn as nn 6 | from typing import Any, Dict, List, Optional, Tuple, Union 7 | from transformers.trainer import PredictionOutput 8 | 9 | from glmtuner.extras.constants import IGNORE_INDEX 10 | from glmtuner.extras.logging import get_logger 11 | from glmtuner.tuner.core.trainer import PeftTrainer 12 | 13 | 14 | logger = get_logger(__name__) 15 | 16 | 17 | class Seq2SeqTrainerForChatGLM(PeftTrainer): 18 | r""" 19 | Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE. 
20 | """ 21 | 22 | def prediction_step( 23 | self, 24 | model: nn.Module, 25 | inputs: Dict[str, Union[torch.Tensor, Any]], 26 | prediction_loss_only: bool, 27 | ignore_keys: Optional[List[str]] = None, 28 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: 29 | r""" 30 | Removes the prompt part in the generated tokens. 31 | 32 | Subclass and override to inject custom behavior. 33 | """ 34 | input_ids = inputs["input_ids"] 35 | loss, generated_tokens, labels = super().prediction_step( 36 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys 37 | ) 38 | generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None 39 | return (loss, generated_tokens, labels) 40 | 41 | def save_predictions( 42 | self, 43 | predict_results: PredictionOutput 44 | ) -> None: 45 | r""" 46 | Saves model predictions to `output_dir`. 47 | 48 | A custom behavior that not contained in Seq2SeqTrainer. 49 | """ 50 | if not self.is_world_process_zero(): 51 | return 52 | 53 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") 54 | logger.info(f"Saving prediction results to {output_prediction_file}") 55 | 56 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id) 57 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id) 58 | 59 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) 60 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) 61 | 62 | with open(output_prediction_file, "w", encoding="utf-8") as writer: 63 | res: List[str] = [] 64 | for pred, label in zip(decoded_preds, decoded_labels): 65 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) 66 | writer.write("\n".join(res)) 67 | -------------------------------------------------------------------------------- /src/glmtuner/tuner/sft/workflow.py: -------------------------------------------------------------------------------- 1 | # Inspired by: https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py 2 | 3 | from typing import Optional, List 4 | from transformers import Seq2SeqTrainingArguments, TrainerCallback 5 | 6 | from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset, split_dataset 7 | from glmtuner.extras.callbacks import LogCallback 8 | from glmtuner.extras.misc import get_logits_processor 9 | from glmtuner.extras.ploting import plot_loss 10 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments 11 | from glmtuner.tuner.core import load_model_and_tokenizer 12 | from glmtuner.tuner.sft.metric import ComputeMetrics 13 | from glmtuner.tuner.sft.trainer import Seq2SeqTrainerForChatGLM 14 | 15 | 16 | def run_sft( 17 | model_args: ModelArguments, 18 | data_args: DataArguments, 19 | training_args: Seq2SeqTrainingArguments, 20 | finetuning_args: FinetuningArguments, 21 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()] 22 | ): 23 | dataset = get_dataset(model_args, data_args) 24 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft") 25 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft") 26 | data_collator = DataCollatorForChatGLM( 27 | tokenizer=tokenizer, 28 | model=model, 29 | ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not 
training_args.predict_with_generate) 30 | ) 31 | 32 | # Override the decoding parameters of Seq2SeqTrainer 33 | training_args.generation_max_length = training_args.generation_max_length if \ 34 | training_args.generation_max_length is not None else data_args.max_target_length 35 | training_args.generation_num_beams = data_args.eval_num_beams if \ 36 | data_args.eval_num_beams is not None else training_args.generation_num_beams 37 | 38 | # Initialize our Trainer 39 | trainer = Seq2SeqTrainerForChatGLM( 40 | finetuning_args=finetuning_args, 41 | model=model, 42 | args=training_args, 43 | tokenizer=tokenizer, 44 | data_collator=data_collator, 45 | callbacks=callbacks, 46 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, 47 | **split_dataset(dataset, data_args.dev_ratio, training_args.do_train) 48 | ) 49 | 50 | # Keyword arguments for `model.generate` 51 | gen_kwargs = { 52 | "do_sample": True, 53 | "top_p": 0.7, 54 | "max_new_tokens": data_args.max_target_length + 1, 55 | "temperature": 0.95, 56 | "logits_processor": get_logits_processor() 57 | } 58 | 59 | # Training 60 | if training_args.do_train: 61 | train_result = trainer.train() 62 | trainer.log_metrics("train", train_result.metrics) 63 | trainer.save_metrics("train", train_result.metrics) 64 | trainer.save_state() 65 | trainer.save_model() 66 | if trainer.is_world_process_zero() and model_args.plot_loss: 67 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) 68 | 69 | # Evaluation 70 | if training_args.do_eval: 71 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) 72 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled 73 | metrics.pop("eval_loss", None) 74 | trainer.log_metrics("eval", metrics) 75 | trainer.save_metrics("eval", metrics) 76 | 77 | # Predict 78 | if training_args.do_predict: 79 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) 80 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled 81 | predict_results.metrics.pop("predict_loss", None) 82 | trainer.log_metrics("predict", predict_results.metrics) 83 | trainer.save_metrics("predict", predict_results.metrics) 84 | trainer.save_predictions(predict_results) 85 | -------------------------------------------------------------------------------- /src/glmtuner/webui/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/webui/__init__.py -------------------------------------------------------------------------------- /src/glmtuner/webui/chat.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List, Tuple 3 | 4 | from glmtuner.chat.stream_chat import ChatModel 5 | from glmtuner.extras.misc import torch_gc 6 | from glmtuner.hparams import GeneratingArguments 7 | from glmtuner.tuner import get_infer_args 8 | from glmtuner.webui.common import get_model_path, get_save_dir 9 | from glmtuner.webui.locales import ALERTS 10 | 11 | 12 | class WebChatModel(ChatModel): 13 | 14 | def __init__(self, *args): 15 | self.model = None 16 | self.tokenizer = None 17 | self.generating_args = GeneratingArguments() 18 | if len(args) != 0: 19 | super().__init__(*args) 20 | 21 | def load_model( 22 | self, 23 | lang: str, 24 | model_name: str, 25 | checkpoints: List[str], 26 | 
finetuning_type: str, 27 | quantization_bit: str, 28 | source_prefix: str 29 | ): 30 | if self.model is not None: 31 | yield ALERTS["err_exists"][lang] 32 | return 33 | 34 | if not model_name: 35 | yield ALERTS["err_no_model"][lang] 36 | return 37 | 38 | model_name_or_path = get_model_path(model_name) 39 | if not model_name_or_path: 40 | yield ALERTS["err_no_path"][lang] 41 | return 42 | 43 | if checkpoints: 44 | checkpoint_dir = ",".join( 45 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints] 46 | ) 47 | else: 48 | checkpoint_dir = None 49 | 50 | yield ALERTS["info_loading"][lang] 51 | args = dict( 52 | model_name_or_path=model_name_or_path, 53 | checkpoint_dir=checkpoint_dir, 54 | finetuning_type=finetuning_type, 55 | quantization_bit=int(quantization_bit) if quantization_bit else None, 56 | source_prefix=source_prefix 57 | ) 58 | super().__init__(*get_infer_args(args)) 59 | yield ALERTS["info_loaded"][lang] 60 | 61 | def unload_model(self, lang: str): 62 | yield ALERTS["info_unloading"][lang] 63 | self.model = None 64 | self.tokenizer = None 65 | torch_gc() 66 | yield ALERTS["info_unloaded"][lang] 67 | 68 | def predict( 69 | self, 70 | chatbot: List[Tuple[str, str]], 71 | query: str, 72 | history: List[Tuple[str, str]], 73 | prefix: str, 74 | max_length: int, 75 | top_p: float, 76 | temperature: float 77 | ): 78 | chatbot.append([query, ""]) 79 | response = "" 80 | for new_text in self.stream_chat( 81 | query, history, prefix, max_length=max_length, top_p=top_p, temperature=temperature 82 | ): 83 | response += new_text 84 | response = self.postprocess(response) 85 | new_history = history + [(query, response)] 86 | chatbot[-1] = [query, response] 87 | yield chatbot, new_history 88 | 89 | def postprocess(self, response: str) -> str: 90 | blocks = response.split("```") 91 | for i, block in enumerate(blocks): 92 | if i % 2 == 0: 93 | blocks[i] = block.replace("<", "<").replace(">", ">") 94 | return "```".join(blocks) 95 | -------------------------------------------------------------------------------- /src/glmtuner/webui/common.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Any, Dict, Optional 4 | 5 | import gradio as gr 6 | from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME 7 | from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME 8 | 9 | from glmtuner.extras.constants import SUPPORTED_MODELS 10 | 11 | 12 | DEFAULT_CACHE_DIR = "cache" 13 | DEFAULT_DATA_DIR = "data" 14 | DEFAULT_SAVE_DIR = "saves" 15 | USER_CONFIG = "user.config" 16 | DATA_CONFIG = "dataset_info.json" 17 | 18 | 19 | def get_save_dir(model_name: str) -> str: 20 | return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1]) 21 | 22 | 23 | def get_config_path() -> os.PathLike: 24 | return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG) 25 | 26 | 27 | def load_config() -> Dict[str, Any]: 28 | try: 29 | with open(get_config_path(), "r", encoding="utf-8") as f: 30 | return json.load(f) 31 | except: 32 | return {"last_model": "", "path_dict": {}} 33 | 34 | 35 | def save_config(model_name: str, model_path: str) -> None: 36 | os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True) 37 | user_config = load_config() 38 | user_config["last_model"] = model_name 39 | user_config["path_dict"][model_name] = model_path 40 | with open(get_config_path(), "w", encoding="utf-8") as f: 41 | json.dump(user_config, f, indent=2, ensure_ascii=False) 42 | 43 | 44 | def get_model_path(model_name: str) 
-> str: 45 | user_config = load_config() 46 | return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, "")) 47 | 48 | 49 | def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]: 50 | checkpoints = [] 51 | save_dir = os.path.join(get_save_dir(model_name), finetuning_type) 52 | if save_dir and os.path.isdir(save_dir): 53 | for checkpoint in os.listdir(save_dir): 54 | if ( 55 | os.path.isdir(os.path.join(save_dir, checkpoint)) 56 | and any([ 57 | os.path.isfile(os.path.join(save_dir, checkpoint, name)) 58 | for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME) 59 | ]) 60 | ): 61 | checkpoints.append(checkpoint) 62 | return gr.update(value=[], choices=checkpoints) 63 | 64 | 65 | def load_dataset_info(dataset_dir: str) -> Dict[str, Any]: 66 | try: 67 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f: 68 | return json.load(f) 69 | except: 70 | return {} 71 | 72 | 73 | def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]: 74 | dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR) 75 | return gr.update(value=[], choices=list(dataset_info.keys())) 76 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/__init__.py: -------------------------------------------------------------------------------- 1 | from glmtuner.webui.components.top import create_top 2 | from glmtuner.webui.components.sft import create_sft_tab 3 | from glmtuner.webui.components.eval import create_eval_tab 4 | from glmtuner.webui.components.infer import create_infer_tab 5 | from glmtuner.webui.components.export import create_export_tab 6 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/chatbot.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, Optional, Tuple 2 | 3 | import gradio as gr 4 | from gradio.blocks import Block 5 | from gradio.components import Component 6 | 7 | from glmtuner.webui.chat import WebChatModel 8 | 9 | 10 | def create_chat_box( 11 | chat_model: WebChatModel, 12 | visible: Optional[bool] = False 13 | ) -> Tuple[Block, Component, Component, Dict[str, Component]]: 14 | with gr.Box(visible=visible) as chat_box: 15 | chatbot = gr.Chatbot() 16 | 17 | with gr.Row(): 18 | with gr.Column(scale=4): 19 | prefix = gr.Dropdown(show_label=False) 20 | query = gr.Textbox(show_label=False, lines=8) 21 | submit_btn = gr.Button(variant="primary") 22 | 23 | with gr.Column(scale=1): 24 | clear_btn = gr.Button() 25 | max_length = gr.Slider(10, 2048, value=chat_model.generating_args.max_length, step=1) 26 | top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01) 27 | temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01) 28 | 29 | history = gr.State([]) 30 | 31 | submit_btn.click( 32 | chat_model.predict, 33 | [chatbot, query, history, prefix, max_length, top_p, temperature], 34 | [chatbot, history], 35 | show_progress=True 36 | ).then( 37 | lambda: gr.update(value=""), outputs=[query] 38 | ) 39 | 40 | clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True) 41 | 42 | return chat_box, chatbot, history, dict( 43 | prefix=prefix, 44 | query=query, 45 | submit_btn=submit_btn, 46 | clear_btn=clear_btn, 47 | max_length=max_length, 48 | top_p=top_p, 49 | temperature=temperature 50 | ) 51 | 
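# --- Illustrative usage sketch --------------------------------------------
# A minimal, hedged example of how the create_chat_box() helper defined above
# could be mounted in a standalone Gradio Blocks app. In this repository the
# helper is actually wired up through webui/components/infer.py
# (create_infer_tab), so treat this only as a sketch; the launch call below
# simply mirrors the settings used in webui/interface.py.
#
#   import gradio as gr
#   from glmtuner.webui.chat import WebChatModel
#   from glmtuner.webui.components.chatbot import create_chat_box
#
#   chat_model = WebChatModel()  # weights are loaded later via chat_model.load_model()
#   with gr.Blocks() as demo:
#       # chat_box is hidden until a model is loaded; pass visible=True to show it immediately
#       chat_box, chatbot, history, elems = create_chat_box(chat_model, visible=True)
#   demo.queue()
#   demo.launch(server_name="0.0.0.0", server_port=7860)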
-------------------------------------------------------------------------------- /src/glmtuner/webui/components/data.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from gradio.blocks import Block 3 | from gradio.components import Component 4 | from typing import Tuple 5 | 6 | 7 | def create_preview_box() -> Tuple[Block, Component, Component, Component]: 8 | with gr.Box(visible=False, elem_classes="modal-box") as preview_box: 9 | with gr.Row(): 10 | preview_count = gr.Number(interactive=False) 11 | 12 | with gr.Row(): 13 | preview_samples = gr.JSON(interactive=False) 14 | 15 | close_btn = gr.Button() 16 | 17 | close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box]) 18 | 19 | return preview_box, preview_count, preview_samples, close_btn 20 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/eval.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import gradio as gr 3 | from gradio.components import Component 4 | 5 | from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR 6 | from glmtuner.webui.components.data import create_preview_box 7 | from glmtuner.webui.runner import Runner 8 | from glmtuner.webui.utils import can_preview, get_preview 9 | 10 | 11 | def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: 12 | with gr.Row(): 13 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 14 | dataset = gr.Dropdown(multiselect=True, scale=4) 15 | preview_btn = gr.Button(interactive=False, scale=1) 16 | 17 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 18 | 19 | dataset_dir.change(list_dataset, [dataset_dir], [dataset]) 20 | dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) 21 | preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) 22 | 23 | with gr.Row(): 24 | max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) 25 | max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) 26 | max_samples = gr.Textbox(value="100000") 27 | batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1) 28 | predict = gr.Checkbox(value=True) 29 | 30 | with gr.Row(): 31 | start_btn = gr.Button() 32 | stop_btn = gr.Button() 33 | 34 | with gr.Box(): 35 | output_box = gr.Markdown() 36 | 37 | start_btn.click( 38 | runner.run_eval, 39 | [ 40 | top_elems["lang"], 41 | top_elems["model_name"], 42 | top_elems["checkpoints"], 43 | top_elems["finetuning_type"], 44 | top_elems["quantization_bit"], 45 | top_elems["source_prefix"], 46 | dataset_dir, 47 | dataset, 48 | max_source_length, 49 | max_target_length, 50 | max_samples, 51 | batch_size, 52 | predict 53 | ], 54 | [output_box] 55 | ) 56 | stop_btn.click(runner.set_abort, queue=False) 57 | 58 | return dict( 59 | dataset_dir=dataset_dir, 60 | dataset=dataset, 61 | preview_btn=preview_btn, 62 | preview_count=preview_count, 63 | preview_samples=preview_samples, 64 | close_btn=close_btn, 65 | max_source_length=max_source_length, 66 | max_target_length=max_target_length, 67 | max_samples=max_samples, 68 | batch_size=batch_size, 69 | predict=predict, 70 | start_btn=start_btn, 71 | stop_btn=stop_btn, 72 | output_box=output_box 73 | ) 74 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/export.py: 
-------------------------------------------------------------------------------- 1 | from typing import Dict 2 | import gradio as gr 3 | from gradio.components import Component 4 | 5 | from glmtuner.webui.utils import export_model 6 | 7 | 8 | def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: 9 | with gr.Row(): 10 | save_dir = gr.Textbox() 11 | max_shard_size = gr.Slider(value=10, minimum=1, maximum=100) 12 | 13 | export_btn = gr.Button() 14 | info_box = gr.Textbox(show_label=False, interactive=False) 15 | 16 | export_btn.click( 17 | export_model, 18 | [ 19 | top_elems["lang"], 20 | top_elems["model_name"], 21 | top_elems["checkpoints"], 22 | top_elems["finetuning_type"], 23 | max_shard_size, 24 | save_dir 25 | ], 26 | [info_box] 27 | ) 28 | 29 | return dict( 30 | save_dir=save_dir, 31 | max_shard_size=max_shard_size, 32 | export_btn=export_btn, 33 | info_box=info_box 34 | ) 35 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/infer.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import gradio as gr 4 | from gradio.components import Component 5 | 6 | from glmtuner.webui.chat import WebChatModel 7 | from glmtuner.webui.components.chatbot import create_chat_box 8 | 9 | 10 | def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]: 11 | with gr.Row(): 12 | load_btn = gr.Button() 13 | unload_btn = gr.Button() 14 | 15 | info_box = gr.Textbox(show_label=False, interactive=False) 16 | 17 | chat_model = WebChatModel() 18 | chat_box, chatbot, history, chat_elems = create_chat_box(chat_model) 19 | 20 | load_btn.click( 21 | chat_model.load_model, 22 | [ 23 | top_elems["lang"], 24 | top_elems["model_name"], 25 | top_elems["checkpoints"], 26 | top_elems["finetuning_type"], 27 | top_elems["quantization_bit"], 28 | top_elems["source_prefix"] 29 | ], 30 | [info_box] 31 | ).then( 32 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] 33 | ) 34 | 35 | unload_btn.click( 36 | chat_model.unload_model, [top_elems["lang"]], [info_box] 37 | ).then( 38 | lambda: ([], []), outputs=[chatbot, history] 39 | ).then( 40 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box] 41 | ) 42 | 43 | return dict( 44 | info_box=info_box, 45 | load_btn=load_btn, 46 | unload_btn=unload_btn, 47 | **chat_elems 48 | ) 49 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/sft.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | from transformers.trainer_utils import SchedulerType 3 | 4 | import gradio as gr 5 | from gradio.components import Component 6 | 7 | from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR 8 | from glmtuner.webui.components.data import create_preview_box 9 | from glmtuner.webui.runner import Runner 10 | from glmtuner.webui.utils import can_preview, get_preview, gen_plot 11 | 12 | 13 | def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]: 14 | with gr.Row(): 15 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2) 16 | dataset = gr.Dropdown(multiselect=True, scale=4) 17 | preview_btn = gr.Button(interactive=False, scale=1) 18 | 19 | preview_box, preview_count, preview_samples, close_btn = create_preview_box() 20 | 21 | dataset_dir.change(list_dataset, [dataset_dir], [dataset]) 22 | 
dataset.change(can_preview, [dataset_dir, dataset], [preview_btn]) 23 | preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box]) 24 | 25 | with gr.Row(): 26 | max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) 27 | max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1) 28 | learning_rate = gr.Textbox(value="5e-5") 29 | num_train_epochs = gr.Textbox(value="3.0") 30 | max_samples = gr.Textbox(value="100000") 31 | 32 | with gr.Row(): 33 | batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1) 34 | gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1) 35 | lr_scheduler_type = gr.Dropdown( 36 | value="cosine", choices=[scheduler.value for scheduler in SchedulerType] 37 | ) 38 | max_grad_norm = gr.Textbox(value="1.0") 39 | dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001) 40 | 41 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 42 | with gr.Row(): 43 | logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5) 44 | save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10) 45 | warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1) 46 | compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16") 47 | 48 | with gr.Accordion(label="LoRA config", open=False) as lora_tab: 49 | with gr.Row(): 50 | lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1) 51 | lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1) 52 | lora_target = gr.Textbox(scale=2) 53 | 54 | with gr.Row(): 55 | start_btn = gr.Button() 56 | stop_btn = gr.Button() 57 | 58 | with gr.Row(): 59 | with gr.Column(scale=4): 60 | output_dir = gr.Textbox() 61 | 62 | with gr.Box(): 63 | output_box = gr.Markdown() 64 | 65 | with gr.Column(scale=1): 66 | loss_viewer = gr.Plot() 67 | 68 | start_btn.click( 69 | runner.run_train, 70 | [ 71 | top_elems["lang"], 72 | top_elems["model_name"], 73 | top_elems["checkpoints"], 74 | top_elems["finetuning_type"], 75 | top_elems["quantization_bit"], 76 | top_elems["source_prefix"], 77 | dataset_dir, 78 | dataset, 79 | max_source_length, 80 | max_target_length, 81 | learning_rate, 82 | num_train_epochs, 83 | max_samples, 84 | batch_size, 85 | gradient_accumulation_steps, 86 | lr_scheduler_type, 87 | max_grad_norm, 88 | dev_ratio, 89 | logging_steps, 90 | save_steps, 91 | warmup_steps, 92 | compute_type, 93 | lora_rank, 94 | lora_dropout, 95 | lora_target, 96 | output_dir 97 | ], 98 | [output_box] 99 | ) 100 | stop_btn.click(runner.set_abort, queue=False) 101 | 102 | output_box.change( 103 | gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False 104 | ) 105 | 106 | return dict( 107 | dataset_dir=dataset_dir, 108 | dataset=dataset, 109 | preview_btn=preview_btn, 110 | preview_count=preview_count, 111 | preview_samples=preview_samples, 112 | close_btn=close_btn, 113 | max_source_length=max_source_length, 114 | max_target_length=max_target_length, 115 | learning_rate=learning_rate, 116 | num_train_epochs=num_train_epochs, 117 | max_samples=max_samples, 118 | batch_size=batch_size, 119 | gradient_accumulation_steps=gradient_accumulation_steps, 120 | lr_scheduler_type=lr_scheduler_type, 121 | max_grad_norm=max_grad_norm, 122 | dev_ratio=dev_ratio, 123 | advanced_tab=advanced_tab, 124 | logging_steps=logging_steps, 125 | save_steps=save_steps, 126 | warmup_steps=warmup_steps, 127 | compute_type=compute_type, 128 | 
lora_tab=lora_tab, 129 | lora_rank=lora_rank, 130 | lora_dropout=lora_dropout, 131 | lora_target=lora_target, 132 | start_btn=start_btn, 133 | stop_btn=stop_btn, 134 | output_dir=output_dir, 135 | output_box=output_box, 136 | loss_viewer=loss_viewer 137 | ) 138 | -------------------------------------------------------------------------------- /src/glmtuner/webui/components/top.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import gradio as gr 4 | from gradio.components import Component 5 | 6 | from glmtuner.extras.constants import METHODS, SUPPORTED_MODELS 7 | from glmtuner.webui.common import list_checkpoint, get_model_path, save_config 8 | from glmtuner.webui.utils import can_quantize 9 | 10 | 11 | def create_top() -> Dict[str, Component]: 12 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"] 13 | 14 | with gr.Row(): 15 | lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1) 16 | model_name = gr.Dropdown(choices=available_models, scale=3) 17 | model_path = gr.Textbox(scale=3) 18 | 19 | with gr.Row(): 20 | finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1) 21 | checkpoints = gr.Dropdown(multiselect=True, scale=5) 22 | refresh_btn = gr.Button(scale=1) 23 | 24 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab: 25 | with gr.Row(): 26 | quantization_bit = gr.Dropdown([8, 4], scale=1) 27 | source_prefix = gr.Textbox(scale=4) 28 | 29 | model_name.change( 30 | get_model_path, [model_name], [model_path] 31 | ).then( 32 | list_checkpoint, [model_name, finetuning_type], [checkpoints] 33 | ) # do not save config since the below line will save 34 | model_path.change(save_config, [model_name, model_path]) 35 | 36 | finetuning_type.change( 37 | list_checkpoint, [model_name, finetuning_type], [checkpoints] 38 | ).then( 39 | can_quantize, [finetuning_type], [quantization_bit] 40 | ) 41 | 42 | refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints]) 43 | 44 | return dict( 45 | lang=lang, 46 | model_name=model_name, 47 | model_path=model_path, 48 | finetuning_type=finetuning_type, 49 | checkpoints=checkpoints, 50 | refresh_btn=refresh_btn, 51 | advanced_tab=advanced_tab, 52 | quantization_bit=quantization_bit, 53 | source_prefix=source_prefix 54 | ) 55 | -------------------------------------------------------------------------------- /src/glmtuner/webui/css.py: -------------------------------------------------------------------------------- 1 | CSS = r""" 2 | .modal-box { 3 | position: fixed !important; 4 | top: 50%; 5 | left: 50%; 6 | transform: translate(-50%, -50%); /* center horizontally */ 7 | max-width: 1000px; 8 | max-height: 750px; 9 | overflow-y: scroll !important; 10 | background-color: var(--input-background-fill); 11 | border: 2px solid black !important; 12 | z-index: 1000; 13 | } 14 | 15 | .dark .modal-box { 16 | border: 2px solid white !important; 17 | } 18 | """ 19 | -------------------------------------------------------------------------------- /src/glmtuner/webui/interface.py: -------------------------------------------------------------------------------- 1 | import gradio as gr 2 | from transformers.utils.versions import require_version 3 | 4 | from glmtuner.webui.components import ( 5 | create_top, 6 | create_sft_tab, 7 | create_eval_tab, 8 | create_infer_tab, 9 | create_export_tab 10 | ) 11 | from glmtuner.webui.css import CSS 12 | from glmtuner.webui.manager import Manager 13 | from glmtuner.webui.runner import Runner 14 | 15 
| 16 | require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0") 17 | 18 | 19 | def create_ui() -> gr.Blocks: 20 | runner = Runner() 21 | 22 | with gr.Blocks(title="Web Tuner", css=CSS) as demo: 23 | top_elems = create_top() 24 | 25 | with gr.Tab("SFT"): 26 | sft_elems = create_sft_tab(top_elems, runner) 27 | 28 | with gr.Tab("Evaluate"): 29 | eval_elems = create_eval_tab(top_elems, runner) 30 | 31 | with gr.Tab("Chat"): 32 | infer_elems = create_infer_tab(top_elems) 33 | 34 | with gr.Tab("Export"): 35 | export_elems = create_export_tab(top_elems) 36 | 37 | elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems] 38 | manager = Manager(elem_list) 39 | 40 | demo.load( 41 | manager.gen_label, 42 | [top_elems["lang"]], 43 | [elem for elems in elem_list for elem in elems.values()], 44 | ) 45 | 46 | top_elems["lang"].change( 47 | manager.gen_label, 48 | [top_elems["lang"]], 49 | [elem for elems in elem_list for elem in elems.values()], 50 | ) 51 | 52 | return demo 53 | 54 | 55 | if __name__ == "__main__": 56 | demo = create_ui() 57 | demo.queue() 58 | demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True) 59 | -------------------------------------------------------------------------------- /src/glmtuner/webui/locales.py: -------------------------------------------------------------------------------- 1 | LOCALES = { 2 | "lang": { 3 | "en": { 4 | "label": "Lang" 5 | }, 6 | "zh": { 7 | "label": "语言" 8 | } 9 | }, 10 | "model_name": { 11 | "en": { 12 | "label": "Model name" 13 | }, 14 | "zh": { 15 | "label": "模型名称" 16 | } 17 | }, 18 | "model_path": { 19 | "en": { 20 | "label": "Model path", 21 | "info": "Path to pretrained model or model identifier from Hugging Face." 22 | }, 23 | "zh": { 24 | "label": "模型路径", 25 | "info": "本地模型的文件路径或 Hugging Face 的模型标识符。" 26 | } 27 | }, 28 | "finetuning_type": { 29 | "en": { 30 | "label": "Finetuning method" 31 | }, 32 | "zh": { 33 | "label": "微调方法" 34 | } 35 | }, 36 | "checkpoints": { 37 | "en": { 38 | "label": "Checkpoints" 39 | }, 40 | "zh": { 41 | "label": "模型断点" 42 | } 43 | }, 44 | "refresh_btn": { 45 | "en": { 46 | "value": "Refresh checkpoints" 47 | }, 48 | "zh": { 49 | "value": "刷新断点" 50 | } 51 | }, 52 | "advanced_tab": { 53 | "en": { 54 | "label": "Advanced configurations" 55 | }, 56 | "zh": { 57 | "label": "高级设置" 58 | } 59 | }, 60 | "quantization_bit": { 61 | "en": { 62 | "label": "Quantization bit (optional)", 63 | "info": "Enable 4/8-bit model quantization." 64 | }, 65 | "zh": { 66 | "label": "量化等级(非必填)", 67 | "info": "启用 4/8 比特模型量化。" 68 | } 69 | }, 70 | "template": { 71 | "en": { 72 | "label": "Prompt template", 73 | "info": "The template used in constructing prompts." 74 | }, 75 | "zh": { 76 | "label": "提示模板", 77 | "info": "构建提示词时使用的模板" 78 | } 79 | }, 80 | "source_prefix": { 81 | "en": { 82 | "label": "System prompt (optional)", 83 | "info": "A sequence used as the default system prompt." 84 | }, 85 | "zh": { 86 | "label": "系统提示词(非必填)", 87 | "info": "默认使用的系统提示词" 88 | } 89 | }, 90 | "dataset_dir": { 91 | "en": { 92 | "label": "Data dir", 93 | "info": "Path of the data directory." 
94 | }, 95 | "zh": { 96 | "label": "数据路径", 97 | "info": "数据文件夹的路径。" 98 | } 99 | }, 100 | "dataset": { 101 | "en": { 102 | "label": "Dataset" 103 | }, 104 | "zh": { 105 | "label": "数据集" 106 | } 107 | }, 108 | "preview_btn": { 109 | "en": { 110 | "value": "Preview" 111 | }, 112 | "zh": { 113 | "value": "预览" 114 | } 115 | }, 116 | "preview_count": { 117 | "en": { 118 | "label": "Count" 119 | }, 120 | "zh": { 121 | "label": "数量" 122 | } 123 | }, 124 | "preview_samples": { 125 | "en": { 126 | "label": "Samples" 127 | }, 128 | "zh": { 129 | "label": "样例" 130 | } 131 | }, 132 | "close_btn": { 133 | "en": { 134 | "value": "Close" 135 | }, 136 | "zh": { 137 | "value": "关闭" 138 | } 139 | }, 140 | "max_source_length": { 141 | "en": { 142 | "label": "Max source length", 143 | "info": "Max tokens in source sequence." 144 | }, 145 | "zh": { 146 | "label": "输入序列最大长度", 147 | "info": "输入序列分词后的最大长度。" 148 | } 149 | }, 150 | "max_target_length": { 151 | "en": { 152 | "label": "Max target length", 153 | "info": "Max tokens in target sequence." 154 | }, 155 | "zh": { 156 | "label": "输出序列最大长度", 157 | "info": "输出序列分词后的最大长度。" 158 | } 159 | }, 160 | "learning_rate": { 161 | "en": { 162 | "label": "Learning rate", 163 | "info": "Initial learning rate for AdamW." 164 | }, 165 | "zh": { 166 | "label": "学习率", 167 | "info": "AdamW 优化器的初始学习率。" 168 | } 169 | }, 170 | "num_train_epochs": { 171 | "en": { 172 | "label": "Epochs", 173 | "info": "Total number of training epochs to perform." 174 | }, 175 | "zh": { 176 | "label": "训练轮数", 177 | "info": "需要执行的训练总轮数。" 178 | } 179 | }, 180 | "max_samples": { 181 | "en": { 182 | "label": "Max samples", 183 | "info": "Maximum samples per dataset." 184 | }, 185 | "zh": { 186 | "label": "最大样本数", 187 | "info": "每个数据集最多使用的样本数。" 188 | } 189 | }, 190 | "batch_size": { 191 | "en": { 192 | "label": "Batch size", 193 | "info": "Number of samples to process per GPU." 194 | }, 195 | "zh":{ 196 | "label": "批处理大小", 197 | "info": "每块 GPU 上处理的样本数量。" 198 | } 199 | }, 200 | "gradient_accumulation_steps": { 201 | "en": { 202 | "label": "Gradient accumulation", 203 | "info": "Number of gradient accumulation steps." 204 | }, 205 | "zh": { 206 | "label": "梯度累积", 207 | "info": "梯度累积的步数。" 208 | } 209 | }, 210 | "lr_scheduler_type": { 211 | "en": { 212 | "label": "LR Scheduler", 213 | "info": "Name of learning rate scheduler.", 214 | }, 215 | "zh": { 216 | "label": "学习率调节器", 217 | "info": "采用的学习率调节器名称。" 218 | } 219 | }, 220 | "max_grad_norm": { 221 | "en": { 222 | "label": "Maximum gradient norm", 223 | "info": "Norm for gradient clipping.." 224 | }, 225 | "zh": { 226 | "label": "最大梯度范数", 227 | "info": "用于梯度裁剪的范数。" 228 | } 229 | }, 230 | "dev_ratio": { 231 | "en": { 232 | "label": "Dev ratio", 233 | "info": "Proportion of data in the dev set." 234 | }, 235 | "zh": { 236 | "label": "验证集比例", 237 | "info": "验证集占全部样本的百分比。" 238 | } 239 | }, 240 | "logging_steps": { 241 | "en": { 242 | "label": "Logging steps", 243 | "info": "Number of steps between two logs." 244 | }, 245 | "zh": { 246 | "label": "日志间隔", 247 | "info": "每两次日志输出间的更新步数。" 248 | } 249 | }, 250 | "save_steps": { 251 | "en": { 252 | "label": "Save steps", 253 | "info": "Number of steps between two checkpoints." 254 | }, 255 | "zh": { 256 | "label": "保存间隔", 257 | "info": "每两次断点保存间的更新步数。" 258 | } 259 | }, 260 | "warmup_steps": { 261 | "en": { 262 | "label": "Warmup steps", 263 | "info": "Number of steps used for warmup." 
        },
        "zh": {
            "label": "预热步数",
            "info": "学习率预热采用的步数。"
        }
    },
    "compute_type": {
        "en": {
            "label": "Compute type",
            "info": "Whether to use fp16 or bf16 mixed precision training."
        },
        "zh": {
            "label": "计算类型",
            "info": "是否启用 FP16 或 BF16 混合精度训练。"
        }
    },
    "lora_tab": {
        "en": {
            "label": "LoRA configurations"
        },
        "zh": {
            "label": "LoRA 参数设置"
        }
    },
    "lora_rank": {
        "en": {
            "label": "LoRA rank",
            "info": "The rank of LoRA matrices."
        },
        "zh": {
            "label": "LoRA 秩",
            "info": "LoRA 矩阵的秩。"
        }
    },
    "lora_dropout": {
        "en": {
            "label": "LoRA Dropout",
            "info": "Dropout ratio of LoRA weights."
        },
        "zh": {
            "label": "LoRA 随机丢弃",
            "info": "LoRA 权重随机丢弃的概率。"
        }
    },
    "lora_target": {
        "en": {
            "label": "LoRA modules (optional)",
            "info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
        },
        "zh": {
            "label": "LoRA 作用层(非必填)",
            "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
        }
    },
    "start_btn": {
        "en": {
            "value": "Start"
        },
        "zh": {
            "value": "开始"
        }
    },
    "stop_btn": {
        "en": {
            "value": "Abort"
        },
        "zh": {
            "value": "中断"
        }
    },
    "output_dir": {
        "en": {
            "label": "Checkpoint name",
            "info": "Directory to save checkpoint."
        },
        "zh": {
            "label": "断点名称",
            "info": "保存模型断点的文件夹名称。"
        }
    },
    "output_box": {
        "en": {
            "value": "Ready."
        },
        "zh": {
            "value": "准备就绪。"
        }
    },
    "loss_viewer": {
        "en": {
            "label": "Loss"
        },
        "zh": {
            "label": "损失"
        }
    },
    "predict": {
        "en": {
            "label": "Save predictions"
        },
        "zh": {
            "label": "保存预测结果"
        }
    },
    "load_btn": {
        "en": {
            "value": "Load model"
        },
        "zh": {
            "value": "加载模型"
        }
    },
    "unload_btn": {
        "en": {
            "value": "Unload model"
        },
        "zh": {
            "value": "卸载模型"
        }
    },
    "info_box": {
        "en": {
            "value": "Model not loaded, please load a model first."
        },
        "zh": {
            "value": "模型未加载,请先加载模型。"
        }
    },
    "prefix": {
        "en": {
            "placeholder": "System prompt (optional)"
        },
        "zh": {
            "placeholder": "系统提示词(非必填)"
        }
    },
    "query": {
        "en": {
            "placeholder": "Input..."
        },
        "zh": {
            "placeholder": "输入..."
        }
    },
    "submit_btn": {
        "en": {
            "value": "Submit"
        },
        "zh": {
            "value": "提交"
        }
    },
    "clear_btn": {
        "en": {
            "value": "Clear history"
        },
        "zh": {
            "value": "清空历史"
        }
    },
    "max_length": {
        "en": {
            "label": "Maximum length"
        },
        "zh": {
            "label": "最大长度"
        }
    },
    "max_new_tokens": {
        "en": {
            "label": "Maximum new tokens"
        },
        "zh": {
            "label": "最大生成长度"
        }
    },
    "top_p": {
        "en": {
            "label": "Top-p"
        },
        "zh": {
            "label": "Top-p 采样值"
        }
    },
    "temperature": {
        "en": {
            "label": "Temperature"
        },
        "zh": {
            "label": "温度系数"
        }
    },
    "save_dir": {
        "en": {
            "label": "Export dir",
            "info": "Directory to save exported model."
        },
        "zh": {
            "label": "导出目录",
            "info": "保存导出模型的文件夹路径。"
        }
    },
    "max_shard_size": {
        "en": {
            "label": "Max shard size (GB)",
            "info": "The maximum size for a model file."
        },
        "zh": {
            "label": "最大分块大小(GB)",
            "info": "模型文件的最大大小。"
        }
    },
    "export_btn": {
        "en": {
            "value": "Export"
        },
        "zh": {
            "value": "开始导出"
        }
    }
}


ALERTS = {
    "err_conflict": {
        "en": "A process is already running, please abort it first.",
        "zh": "任务已存在,请先中断训练。"
    },
    "err_exists": {
        "en": "You have loaded a model, please unload it first.",
        "zh": "模型已存在,请先卸载模型。"
    },
    "err_no_model": {
        "en": "Please select a model.",
        "zh": "请选择模型。"
    },
    "err_no_path": {
        "en": "Model not found.",
        "zh": "模型未找到。"
    },
    "err_no_dataset": {
        "en": "Please choose a dataset.",
        "zh": "请选择数据集。"
    },
    "err_no_checkpoint": {
        "en": "Please select a checkpoint.",
        "zh": "请选择断点。"
    },
    "err_no_save_dir": {
        "en": "Please provide an export directory.",
        "zh": "请填写导出目录。"
    },
    "info_aborting": {
        "en": "Aborted, waiting for the thread to terminate...",
        "zh": "训练中断,正在等待线程结束……"
    },
    "info_aborted": {
        "en": "Ready.",
        "zh": "准备就绪。"
    },
    "info_finished": {
        "en": "Finished.",
        "zh": "训练完毕。"
    },
    "info_loading": {
        "en": "Loading model...",
        "zh": "加载中……"
    },
    "info_unloading": {
        "en": "Unloading model...",
        "zh": "卸载中……"
    },
    "info_loaded": {
        "en": "Model loaded, now you can chat with your model!",
        "zh": "模型已加载,可以开始聊天了!"
    },
    "info_unloaded": {
        "en": "Model unloaded.",
        "zh": "模型已卸载。"
    },
    "info_exporting": {
        "en": "Exporting model...",
        "zh": "正在导出模型……"
    },
    "info_exported": {
        "en": "Model exported.",
        "zh": "模型导出完成。"
    }
}
--------------------------------------------------------------------------------
/src/glmtuner/webui/manager.py:
--------------------------------------------------------------------------------
import gradio as gr
from typing import Any, Dict, List
from gradio.components import Component

from glmtuner.webui.common import get_model_path, list_dataset, load_config
from glmtuner.webui.locales import LOCALES
from glmtuner.webui.utils import get_time


class Manager:

    def __init__(self, elem_list: List[Dict[str, Component]]):
        self.elem_list = elem_list

    def gen_refresh(self) -> Dict[str, Any]:
        refresh_dict = {
            "dataset": {"choices": list_dataset()["choices"]},
            "output_dir": {"value": get_time()}
        }
        user_config = load_config()
        if user_config["last_model"]:
            refresh_dict["model_name"] = {"value": user_config["last_model"]}
            refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}

        return refresh_dict

    def gen_label(self, lang: str) -> Dict[Component, dict]:
        update_dict = {}
        refresh_dict = self.gen_refresh()

        for elems in self.elem_list:
            for name, component in elems.items():
                update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))

        return update_dict
--------------------------------------------------------------------------------
/src/glmtuner/webui/runner.py:
--------------------------------------------------------------------------------
import logging
import os
import threading
import time
import transformers
from typing import Generator, List, Optional, Tuple

from glmtuner.extras.callbacks import LogCallback
from glmtuner.extras.logging import LoggerHandler
from glmtuner.extras.misc import torch_gc
from glmtuner.tuner import get_train_args, run_sft
from glmtuner.webui.common import get_model_path, get_save_dir
from glmtuner.webui.locales import ALERTS
from glmtuner.webui.utils import format_info, get_eval_results


class Runner:

    def __init__(self):
        self.aborted = False
        self.running = False

    def set_abort(self):
        self.aborted = True
        self.running = False

    def initialize(
        self, lang: str, model_name: str, dataset: List[str]
    ) -> Tuple[str, str, LoggerHandler, LogCallback]:
        if self.running:
            return None, ALERTS["err_conflict"][lang], None, None

        if not model_name:
            return None, ALERTS["err_no_model"][lang], None, None

        model_name_or_path = get_model_path(model_name)
        if not model_name_or_path:
            return None, ALERTS["err_no_path"][lang], None, None

        if len(dataset) == 0:
            return None, ALERTS["err_no_dataset"][lang], None, None

        self.aborted = False
        self.running = True

        logger_handler = LoggerHandler()
        logger_handler.setLevel(logging.INFO)
        logging.root.addHandler(logger_handler)
        transformers.logging.add_handler(logger_handler)
        trainer_callback = LogCallback(self)

        return model_name_or_path, "", logger_handler, trainer_callback

    def finalize(
        self, lang: str,
        finish_info: Optional[str] = None
    ) -> str:
        self.running = False
        torch_gc()
        if self.aborted:
            return ALERTS["info_aborted"][lang]
        else:
            return finish_info if finish_info is not None else ALERTS["info_finished"][lang]

    def run_train(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        source_prefix: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
        max_target_length: int,
        learning_rate: str,
        num_train_epochs: str,
        max_samples: str,
        batch_size: int,
        gradient_accumulation_steps: int,
        lr_scheduler_type: str,
        max_grad_norm: str,
        dev_ratio: float,
        logging_steps: int,
        save_steps: int,
        warmup_steps: int,
        compute_type: str,
        lora_rank: int,
        lora_dropout: float,
        lora_target: str,
        output_dir: str
    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
        else:
            checkpoint_dir = None

        args = dict(
            model_name_or_path=model_name_or_path,
            do_train=True,
            overwrite_cache=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            learning_rate=float(learning_rate),
            num_train_epochs=float(num_train_epochs),
            max_samples=int(max_samples),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            lr_scheduler_type=lr_scheduler_type,
            max_grad_norm=float(max_grad_norm),
            logging_steps=logging_steps,
            save_steps=save_steps,
            warmup_steps=warmup_steps,
            fp16=(compute_type == "fp16"),
            bf16=(compute_type == "bf16"),
            lora_rank=lora_rank,
            lora_dropout=lora_dropout,
            lora_target=lora_target or "query_key_value",
            output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
        )

        if dev_ratio > 1e-6:
            args["dev_ratio"] = dev_ratio
            args["evaluation_strategy"] = "steps"
            args["eval_steps"] = save_steps
            args["load_best_model_at_end"] = True

        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        yield self.finalize(lang)

    def run_eval(
        self,
        lang: str,
        model_name: str,
        checkpoints: List[str],
        finetuning_type: str,
        quantization_bit: str,
        source_prefix: str,
        dataset_dir: str,
        dataset: List[str],
        max_source_length: int,
        max_target_length: int,
        max_samples: str,
        batch_size: int,
        predict: bool
    ) -> Generator[str, None, None]:
        model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
        if error:
            yield error
            return

        if checkpoints:
            checkpoint_dir = ",".join(
                [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
            )
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
        else:
            checkpoint_dir = None
            output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")

        args = dict(
            model_name_or_path=model_name_or_path,
            do_eval=True,
            overwrite_cache=True,
            predict_with_generate=True,
            checkpoint_dir=checkpoint_dir,
            finetuning_type=finetuning_type,
            quantization_bit=int(quantization_bit) if quantization_bit else None,
            source_prefix=source_prefix,
            dataset_dir=dataset_dir,
            dataset=",".join(dataset),
            max_source_length=max_source_length,
            max_target_length=max_target_length,
            max_samples=int(max_samples),
            per_device_eval_batch_size=batch_size,
            output_dir=output_dir
        )

        if predict:
            args.pop("do_eval", None)
            args["do_predict"] = True

        model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)

        run_args = dict(
            model_args=model_args,
            data_args=data_args,
            training_args=training_args,
            finetuning_args=finetuning_args,
            callbacks=[trainer_callback]
        )
        thread = threading.Thread(target=run_sft, kwargs=run_args)
        thread.start()

        while thread.is_alive():
            time.sleep(1)
            if self.aborted:
                yield ALERTS["info_aborting"][lang]
            else:
                yield format_info(logger_handler.log, trainer_callback.tracker)

        yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
--------------------------------------------------------------------------------
/src/glmtuner/webui/utils.py:
--------------------------------------------------------------------------------
import os
import json
import gradio as gr
import matplotlib.figure
import matplotlib.pyplot as plt
from typing import Any, Dict, Generator, List, Tuple
from datetime import datetime

from glmtuner.extras.ploting import smooth
from glmtuner.tuner import get_infer_args, load_model_and_tokenizer
from glmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
from glmtuner.webui.locales import ALERTS


def format_info(log: str, tracker: dict) -> str:
    info = log
    if "current_steps" in tracker:
        info += "Running **{:d}/{:d}**: {} < {}\n".format(
            tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
        )
    return info


def get_time() -> str:
    return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')


def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    if (
        len(dataset) > 0
        and "file_name" in dataset_info[dataset[0]]
        and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
    ):
        return gr.update(interactive=True)
    else:
        return gr.update(interactive=False)


def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
    with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
        dataset_info = json.load(f)
    data_file = dataset_info[dataset[0]]["file_name"]
    with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
        data = json.load(f)
    return len(data), data[:2], gr.update(visible=True)


def can_quantize(finetuning_type: str) -> Dict[str, Any]:
    if finetuning_type not in ["p_tuning", "lora"]:
        return gr.update(value="", interactive=False)
    else:
        return gr.update(interactive=True)


def get_eval_results(path: os.PathLike) -> str:
    with open(path, "r", encoding="utf-8") as f:
        result = json.dumps(json.load(f), indent=4)
    return "```json\n{}\n```\n".format(result)


def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
    log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
    if not os.path.isfile(log_file):
        return None

    plt.close("all")
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    with open(log_file, "r", encoding="utf-8") as f:
        for line in f:
            log_info = json.loads(line)
            if log_info.get("loss", None):
                steps.append(log_info["current_steps"])
                losses.append(log_info["loss"])

    if len(losses) == 0:
        return None

    ax.plot(steps, losses, alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig


def export_model(
    lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
) -> Generator[str, None, None]:
    if not model_name:
        yield ALERTS["err_no_model"][lang]
        return

    model_name_or_path = get_model_path(model_name)
    if not model_name_or_path:
        yield ALERTS["err_no_path"][lang]
        return

    if not checkpoints:
        yield ALERTS["err_no_checkpoint"][lang]
        return

    checkpoint_dir = ",".join(
        [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
    )

    if not save_dir:
        yield ALERTS["err_no_save_dir"][lang]
        return

    args = dict(
        model_name_or_path=model_name_or_path,
        checkpoint_dir=checkpoint_dir,
        finetuning_type=finetuning_type
    )

    yield ALERTS["info_exporting"][lang]
    model_args, _, finetuning_args, _ = get_infer_args(args)
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
    model.save_pretrained(save_dir, max_shard_size=str(max_shard_size) + "GB")
    tokenizer.save_pretrained(save_dir)
    yield ALERTS["info_exported"][lang]
--------------------------------------------------------------------------------
/src/train_bash.py:
--------------------------------------------------------------------------------
from glmtuner.tuner import get_train_args, run_sft, run_rm, run_ppo


def main():
    model_args, data_args, training_args, finetuning_args, general_args = get_train_args()

    if general_args.stage == "sft":
        run_sft(model_args, data_args, training_args, finetuning_args)
    elif general_args.stage == "rm":
        run_rm(model_args, data_args, training_args, finetuning_args)
    elif general_args.stage == "ppo":
        run_ppo(model_args, data_args, training_args, finetuning_args)


def _mp_fn(index):
    # For xla_spawn (TPUs)
    main()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
/src/web_demo.py:
--------------------------------------------------------------------------------
# coding=utf-8
# Implements a user interface in the browser for ChatGLM fine-tuned with PEFT.
# Usage: python web_demo.py --checkpoint_dir path_to_checkpoint [--quantization_bit 4]

import gradio as gr
from transformers.utils.versions import require_version

from glmtuner.tuner import get_infer_args
from glmtuner.webui.chat import WebChatModel
from glmtuner.webui.components.chatbot import create_chat_box
from glmtuner.webui.manager import Manager


require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")


def main():
    chat_model = WebChatModel(*get_infer_args())

    with gr.Blocks(title="Web Demo") as demo:
        lang = gr.Dropdown(choices=["en", "zh"], value="en")

        _, _, _, chat_elems = create_chat_box(chat_model, visible=True)

        manager = Manager([{"lang": lang}, chat_elems])

        demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

        lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))

    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
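Editor's note (illustrative, not part of the repository): Runner.run_train above simply assembles a plain argument dictionary and passes it to get_train_args / run_sft, which is the same code path that src/train_bash.py drives from the command line. The sketch below shows a minimal programmatic SFT launch under those assumptions; the model identifier, dataset key, and output directory are placeholders, and the exact accepted argument names should be checked against glmtuner/hparams before use.

# Minimal sketch of a programmatic SFT launch, mirroring Runner.run_train.
# Assumptions: "THUDM/chatglm2-6b", "example", and the output_dir value are
# placeholders; verify argument names against glmtuner/hparams before relying on this.
from glmtuner.tuner import get_train_args, run_sft

args = dict(
    model_name_or_path="THUDM/chatglm2-6b",  # placeholder model path or identifier
    do_train=True,
    finetuning_type="lora",
    lora_rank=8,
    lora_target="query_key_value",           # default used by the web UI
    dataset_dir="data",
    dataset="example",                       # placeholder key from dataset_info.json
    max_source_length=512,
    max_target_length=512,
    learning_rate=5e-5,
    num_train_epochs=3.0,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    lr_scheduler_type="cosine",
    fp16=True,
    output_dir="saves/chatglm2-6b/lora/sft-demo",  # placeholder output directory
)

model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_sft(model_args, data_args, training_args, finetuning_args)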