├── CoachLM150 Test Set.json
├── Expert Revision Dataset.json
├── LICENSE
├── README.md
├── Technical Report.pdf
├── asset
│   ├── .gitkeep
│   ├── coachLM.png
│   ├── dataset_score.png
│   ├── illustration.png
│   └── win_rates.png
├── data
│   ├── .ipynb_checkpoints
│   │   ├── dataset_info-checkpoint.json
│   │   └── example-checkpoint.json
│   ├── dataset_info.json
│   └── example.json
├── requirements.txt
└── src
    ├── glmtuner
    │   ├── __init__.py
    │   ├── api
    │   │   ├── __init__.py
    │   │   ├── app.py
    │   │   └── protocol.py
    │   ├── chat
    │   │   ├── __init__.py
    │   │   └── stream_chat.py
    │   ├── dsets
    │   │   ├── __init__.py
    │   │   ├── collator.py
    │   │   ├── loader.py
    │   │   ├── preprocess.py
    │   │   └── utils.py
    │   ├── extras
    │   │   ├── __init__.py
    │   │   ├── callbacks.py
    │   │   ├── constants.py
    │   │   ├── logging.py
    │   │   ├── misc.py
    │   │   ├── ploting.py
    │   │   └── save_and_load.py
    │   ├── hparams
    │   │   ├── __init__.py
    │   │   ├── data_args.py
    │   │   ├── finetuning_args.py
    │   │   ├── general_args.py
    │   │   ├── generating_args.py
    │   │   └── model_args.py
    │   ├── tuner
    │   │   ├── __init__.py
    │   │   ├── core
    │   │   │   ├── __init__.py
    │   │   │   ├── adapter.py
    │   │   │   ├── loader.py
    │   │   │   ├── parser.py
    │   │   │   └── trainer.py
    │   │   ├── ppo
    │   │   │   ├── __init__.py
    │   │   │   ├── trainer.py
    │   │   │   ├── utils.py
    │   │   │   └── workflow.py
    │   │   ├── rm
    │   │   │   ├── __init__.py
    │   │   │   ├── collator.py
    │   │   │   ├── metric.py
    │   │   │   ├── trainer.py
    │   │   │   └── workflow.py
    │   │   └── sft
    │   │       ├── __init__.py
    │   │       ├── metric.py
    │   │       ├── trainer.py
    │   │       └── workflow.py
    │   └── webui
    │       ├── __init__.py
    │       ├── chat.py
    │       ├── common.py
    │       ├── components
    │       │   ├── __init__.py
    │       │   ├── chatbot.py
    │       │   ├── data.py
    │       │   ├── eval.py
    │       │   ├── export.py
    │       │   ├── infer.py
    │       │   ├── sft.py
    │       │   └── top.py
    │       ├── css.py
    │       ├── interface.py
    │       ├── locales.py
    │       ├── manager.py
    │       ├── runner.py
    │       └── utils.py
    ├── train_bash.py
    └── web_demo.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # CoachLM
2 |
3 | This repo contains the code and the human-created training dataset for CoachLM, an automatic instruction-revision approach for LLM instruction tuning.
4 |
5 | Paper: https://arxiv.org/abs/2311.13246
6 |
7 | See more case analyses of the human revisions in Technical Report.pdf.
8 | ## 📰 News
9 | CoachLM has been accepted at the IEEE International Conference on Data Engineering (ICDE) 2024🎉🎉🎉🎉🎉🎉.
10 | ## 📣 Introduction
11 |
12 |
13 |
14 |
15 | Instruction tuning is crucial for enabling large language models (LLMs) to respond to human instructions. The quality of the instruction pairs used for tuning greatly affects the performance of LLMs. However, manually creating high-quality instruction datasets is costly, so automatic generation of instruction pairs by LLMs has become a popular alternative. To ensure the high quality of LLM-generated instruction datasets, several approaches have been proposed. Nevertheless, existing methods either compromise dataset integrity by filtering out a large proportion of samples, or are unsuitable for industrial applications. Instead of discarding low-quality samples, we propose CoachLM, a novel approach that enhances the quality of instruction datasets through automatic revision of the samples in the dataset. CoachLM is trained on samples revised by human experts and significantly increases the proportion of high-quality samples in the dataset from 17.7% to 78.9%. The effectiveness of CoachLM is further assessed on various real-world instruction test sets. The results show that CoachLM improves the instruction-following capabilities of the instruction-tuned LLM by an average of 29.9%, even surpassing larger LLMs with nearly twice the number of parameters.
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 | ## 🔰 Installation
26 | ```
27 | $ pip install -r requirements.txt
28 | ```
29 | ## ✨ Expert Revision Dataset
30 | We created a dataset of 2301 samples that pairs raw instruction pairs from the Alpaca52k dataset with their human-revised counterparts, aimed at improving the quality of these LLM-generated instruction pairs.
31 | ```
32 | Expert Revision Dataset.json
33 | {
34 | "Raw Instruction": "",
35 | "Raw Input": "",
36 | "Raw Response": "",
37 | "Revised Instruction": "",
38 | "Revised Input": "",
39 | "Revised Response": "",
40 | "Distance": ""
41 | }
42 | ```
43 | As illustrated above, the dataset contains the raw instruction pairs and their revised versions. It also records the edit distance between each raw sample and its revision.
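A minimal sketch for loading the dataset and inspecting the recorded edit distances, assuming the file is a JSON list of records with the fields shown above:
```
import json

# Load the expert revision dataset (assumed to be a JSON list of records).
with open("Expert Revision Dataset.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

print("Number of samples:", len(samples))
first = samples[0]
print("Raw instruction:", first["Raw Instruction"])
print("Revised instruction:", first["Revised Instruction"])

# Edit distance between the raw sample and its revision, recorded per record.
distances = [float(s["Distance"]) for s in samples]
print("Mean edit distance:", sum(distances) / len(distances))
```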
44 | ## 📝 Training CoachLM
45 | CoachLM in this repository is implemented by fine-tuning ChatGLM2 on our curated expert-revision dataset. Thanks to [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning) for providing an efficient tool to fine-tune ChatGLM2. The training steps of CoachLM are as follows:
46 |
47 |
48 | (1) Determine the subset used for training
49 |
50 |
51 | Training CoachLM on the whole dataset may not lead to optimal results. We recommend using roughly the first 30% of samples when the dataset is sorted by edit distance from largest to smallest (see the sketch below). Other ways of selecting the training subset are also possible.
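A minimal sketch of this selection step, assuming `Expert Revision Dataset.json` is a JSON list of records as shown above and that the `Distance` field parses as a number (the output file name is arbitrary):
```
import json

with open("Expert Revision Dataset.json", "r", encoding="utf-8") as f:
    samples = json.load(f)

# Sort by edit distance, largest first, and keep roughly the first 30%.
samples.sort(key=lambda s: float(s["Distance"]), reverse=True)
subset = samples[: int(0.3 * len(samples))]
print("Selected {} of {} samples for training".format(len(subset), len(samples)))

with open("revision_subset.json", "w", encoding="utf-8") as f:
    json.dump(subset, f, ensure_ascii=False, indent=2)
```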
52 |
53 |
54 | (2) Format the training dataset
55 |
56 |
57 | Format the training dataset according to data/example.json, and register the dataset in data/dataset_info.json.
58 | ```
59 | example.json
60 | {
61 | "instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.",
62 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.",
63 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins."
64 | }
65 | ```
66 | ```
67 | dataset_info.json
68 | {
69 | "example": {
70 | "file_name": "example.json",
71 | "file_sha1": "",
72 | "columns": {
73 | "prompt": "instruction",
74 | "query": "input",
75 | "response": "output"
76 | }
77 | }
78 | }
79 | ```
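The selected samples then need to be converted into the format above. A possible conversion sketch, continuing from the selection step in (1): the `revision_subset.json` input, the `coachlm_train.json` output name, and the way the optional input is appended to the instruction are assumptions, while the fixed CoachLM prompt is copied from example.json.
```
import json

# Fixed CoachLM prompt, copied from data/example.json.
PROMPT = ("Improve the following instruction, input and response pair to be more "
          "specific, detailed with more logical steps and grammarly corrected.")

def pair_to_text(instruction, extra_input, response):
    # Flatten one instruction pair into a single string; appending the optional
    # input after the instruction is an assumption about the expected layout.
    text = "Instruction: " + instruction
    if extra_input:
        text += " Input: " + extra_input
    return text + " Response: " + response

with open("revision_subset.json", "r", encoding="utf-8") as f:
    subset = json.load(f)

records = [
    {
        "instruction": PROMPT,
        "input": pair_to_text(s["Raw Instruction"], s["Raw Input"], s["Raw Response"]),
        "output": pair_to_text(s["Revised Instruction"], s["Revised Input"], s["Revised Response"]),
    }
    for s in subset
]

with open("data/coachlm_train.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)
# Remember to register the new file under a dataset name (e.g. "coachlm_train")
# in data/dataset_info.json, as in the example above.
```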
80 |
81 | (3) Start training
82 | ```
83 | python src/train_bash.py \
84 | --stage sft \
85 | --model_name_or_path path_to_your_chatglm2_model \
86 | --do_train \
87 | --dataset name_of_your_dataset_in_dataset_info_json \
88 | --finetuning_type lora \
89 | --output_dir path_to_CoachLM_checkpoint \
90 | --per_device_train_batch_size 32 \
91 | --gradient_accumulation_steps 1 \
92 | --lr_scheduler_type cosine \
93 | --logging_steps 100 \
94 | --save_strategy "epoch" \
95 | --learning_rate 2e-4 \
96 | --num_train_epochs 7 \
97 | --fp16 true \
98 | --lora_rank 64 \
99 | --lora_alpha 32 \
100 | --lora_target "query_key_value,dense,dense_h_to_4h,dense_4h_to_h" \
101 | --use_v2 true \
102 | ```
103 |
104 | (4) Inference
105 |
106 |
107 | The inference dataset should be formatted in the same way as example.json, with the output field left empty.
108 | ```
109 | python src/train_bash.py \
110 | --stage sft \
111 | --do_predict \
112 | --finetuning_type lora \
113 | --dataset dataset_for_inference \
114 | --model_name_or_path path_to_your_chatglm2_model \
115 | --checkpoint_dir path_to_CoachLM_checkpoint \
116 | --output_dir path_to_inference_result \
117 | --per_device_eval_batch_size 32 \
118 | --predict_with_generate \
119 | --eval_num_beams 1 \
120 | --lora_rank 64 \
121 | --lora_alpha 32 \
122 | --lora_target "query_key_value,dense,dense_h_to_4h,dense_4h_to_h" \
123 | --use_v2 true \
124 | ```
125 | For more information, please refer to [ChatGLM-Efficient-Tuning](https://github.com/hiyouga/ChatGLM-Efficient-Tuning).
126 |
127 | ## 🧪 CoachLM150 Test Set
128 | We created an instruction-following test suite for LLMs containing 150 questions that cover topics such as information extraction, scientific inference, dialogue completion, brainstorming, in-domain question answering, and more. For each question, a reference response is provided by a human.
129 | ```
130 | CoachLM150 Test Set.json
131 | {
132 | "instruction": "",
133 | "input": "",
134 | "reference response": ""
135 | }
136 | ```
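A minimal loading sketch, assuming the file is a JSON list of records with the fields shown above:
```
import json

# Load the CoachLM150 test set (assumed to be a JSON list of records).
with open("CoachLM150 Test Set.json", "r", encoding="utf-8") as f:
    test_set = json.load(f)

print("Number of test questions:", len(test_set))

# Build a full prompt for each question by joining instruction and input.
for item in test_set[:3]:
    prompt = item["instruction"] + ("\n" + item["input"] if item["input"] else "")
    print(prompt)
    print("Reference:", item["reference response"])
```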
137 |
138 | ## ⚠️ Limitations
139 | The current 7B version of CoachLM may still occasionally generate undesirable content, including hallucinated content, repeated text and meaningless phrases. To reduce this noise and increase reliability, we recommend applying a rule-based post-processing step to the output of CoachLM that removes non-English characters, repeated strings and excessively long answers (a sketch is given below). Another solution is to train CoachLM on larger foundation models, such as 13B or 60B variants.
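A minimal sketch of such a rule-based post-processing step; the regular expressions and the length threshold are illustrative assumptions, not the exact rules used in the paper:
```
import re

def postprocess(text, max_chars=4000):
    # Drop characters outside the basic ASCII range (removes non-English characters).
    text = re.sub(r"[^\x00-\x7F]+", " ", text)
    # Collapse immediately repeated words, e.g. "the the the" -> "the".
    text = re.sub(r"\b(\w+)(?:\s+\1\b)+", r"\1", text)
    # Collapse runs of whitespace introduced by the removals above.
    text = re.sub(r"\s+", " ", text).strip()
    # Truncate excessively long answers.
    return text[:max_chars]

print(postprocess("This is is is a demo 示例 answer answer."))
# -> "This is a demo answer."
```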
140 |
--------------------------------------------------------------------------------
/Technical Report.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/Technical Report.pdf
--------------------------------------------------------------------------------
/asset/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/asset/coachLM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/coachLM.png
--------------------------------------------------------------------------------
/asset/dataset_score.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/dataset_score.png
--------------------------------------------------------------------------------
/asset/illustration.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/illustration.png
--------------------------------------------------------------------------------
/asset/win_rates.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/asset/win_rates.png
--------------------------------------------------------------------------------
/data/.ipynb_checkpoints/dataset_info-checkpoint.json:
--------------------------------------------------------------------------------
1 | {
2 | "alpaca_en": {
3 | "file_name": "alpaca_data_en_52k.json",
4 | "file_sha1": "607f94a7f581341e59685aef32f531095232cf23"
5 | },
6 | "alpaca_zh": {
7 | "file_name": "alpaca_data_zh_51k.json",
8 | "file_sha1": "e655af3db557a4197f7b0cf92e1986b08fae6311"
9 | },
10 | "alpaca_gpt4_en": {
11 | "file_name": "alpaca_gpt4_data_en.json",
12 | "file_sha1": "647f4ad447bd993e4b6b6223d1be15208bab694a"
13 | },
14 | "alpaca_gpt4_zh": {
15 | "file_name": "alpaca_gpt4_data_zh.json",
16 | "file_sha1": "3eaa3bda364ccdd59925d7448a698256c31ef845"
17 | },
18 | "self_cognition": {
19 | "file_name": "self_cognition.json",
20 | "file_sha1": "6287a730ada924fc5d9eadc6d8f865e01b7a6f67"
21 | },
22 | "oaast_sft": {
23 | "file_name": "oaast_sft.json",
24 | "file_sha1": "7baf5d43e67a91f9bbdf4e400dbe033b87e9757e",
25 | "columns": {
26 | "prompt": "instruction",
27 | "query": "input",
28 | "response": "output",
29 | "history": "history"
30 | }
31 | },
32 | "oaast_sft_zh": {
33 | "file_name": "oaast_sft_zh.json",
34 | "file_sha1": "a6a91f18f80f37b10ded9cf633fb50c033bf7b9f",
35 | "columns": {
36 | "prompt": "instruction",
37 | "query": "input",
38 | "response": "output",
39 | "history": "history"
40 | }
41 | },
42 | "sharegpt_zh": {
43 | "file_name": "sharegpt_zh_27k.json",
44 | "file_sha1": "baf766bcf3d61f1b783728c14ce695af57a86e6e",
45 | "columns": {
46 | "prompt": "instruction",
47 | "query": "input",
48 | "response": "output",
49 | "history": "history"
50 | }
51 | },
52 | "refgpt_zh_p1": {
53 | "file_name": "refgpt_zh_50k_p1.json",
54 | "file_sha1": "b40f4f4d0ffacd16da7c275b056d5b6670021752",
55 | "columns": {
56 | "prompt": "instruction",
57 | "query": "input",
58 | "response": "output",
59 | "history": "history"
60 | }
61 | },
62 | "refgpt_zh_p2": {
63 | "file_name": "refgpt_zh_50k_p2.json",
64 | "file_sha1": "181f32b2c60264a29f81f59d3c76095793eae1b0",
65 | "columns": {
66 | "prompt": "instruction",
67 | "query": "input",
68 | "response": "output",
69 | "history": "history"
70 | }
71 | },
72 | "lima": {
73 | "file_name": "lima.json",
74 | "file_sha1": "9db59f6b7007dc4b17529fc63379b9cd61640f37",
75 | "columns": {
76 | "prompt": "instruction",
77 | "query": "input",
78 | "response": "output",
79 | "history": "history"
80 | }
81 | },
82 | "example": {
83 | "script_url": "example_dataset",
84 | "columns": {
85 | "prompt": "instruction",
86 | "query": "input",
87 | "response": "output",
88 | "history": "history"
89 | }
90 | },
91 | "guanaco": {
92 | "hf_hub_url": "JosephusCheung/GuanacoDataset"
93 | },
94 | "belle_0.5m": {
95 | "hf_hub_url": "BelleGroup/train_0.5M_CN"
96 | },
97 | "belle_1m": {
98 | "hf_hub_url": "BelleGroup/train_1M_CN"
99 | },
100 | "belle_2m": {
101 | "hf_hub_url": "BelleGroup/train_2M_CN"
102 | },
103 | "belle_dialog": {
104 | "hf_hub_url": "BelleGroup/generated_chat_0.4M"
105 | },
106 | "belle_math": {
107 | "hf_hub_url": "BelleGroup/school_math_0.25M"
108 | },
109 | "belle_multiturn": {
110 | "script_url": "belle_multiturn",
111 | "columns": {
112 | "prompt": "instruction",
113 | "query": "",
114 | "response": "output",
115 | "history": "history"
116 | }
117 | },
118 | "firefly": {
119 | "hf_hub_url": "YeungNLP/firefly-train-1.1M",
120 | "columns": {
121 | "prompt": "input",
122 | "query": "",
123 | "response": "target",
124 | "history": ""
125 | }
126 | },
127 | "codealpaca": {
128 | "hf_hub_url": "sahil2801/CodeAlpaca-20k"
129 | },
130 | "alpaca_cot": {
131 | "hf_hub_url": "QingyiSi/Alpaca-CoT"
132 | },
133 | "webqa": {
134 | "hf_hub_url": "suolyer/webqa",
135 | "columns": {
136 | "prompt": "input",
137 | "query": "",
138 | "response": "output",
139 | "history": ""
140 | }
141 | },
142 | "ultra_chat": {
143 | "script_url": "ultra_chat",
144 | "columns": {
145 | "prompt": "instruction",
146 | "query": "",
147 | "response": "output",
148 | "history": "history"
149 | }
150 | },
151 | "novel_tokens512_50k": {
152 | "hf_hub_url": "zxbsmk/webnovel_cn"
153 | },
154 | "comparison_gpt4_en": {
155 | "file_name": "comparison_gpt4_data_en.json",
156 | "file_sha1": "96fa18313544e22444fe20eead7754b17da452ae"
157 | },
158 | "comparison_gpt4_zh": {
159 | "file_name": "comparison_gpt4_data_zh.json",
160 | "file_sha1": "515b18ed497199131ddcc1af950345c11dc5c7fd"
161 | },
162 | "hh_rlhf_en": {
163 | "script_url": "hh_rlhf_en",
164 | "columns": {
165 | "prompt": "instruction",
166 | "query": "",
167 | "response": "output",
168 | "history": "history"
169 | }
170 | },
171 | "oaast_rm": {
172 | "file_name": "oaast_rm.json",
173 | "file_sha1": "622d420e9b70003b210618253bd3d9d2891d86cb",
174 | "columns": {
175 | "prompt": "instruction",
176 | "query": "input",
177 | "response": "output",
178 | "history": "history"
179 | }
180 | },
181 | "oaast_rm_zh": {
182 | "file_name": "oaast_rm_zh.json",
183 | "file_sha1": "1065af1f3784dd61be5e79713a35f427b713a232",
184 | "columns": {
185 | "prompt": "instruction",
186 | "query": "input",
187 | "response": "output",
188 | "history": "history"
189 | }
190 | }
191 | }
192 |
--------------------------------------------------------------------------------
/data/.ipynb_checkpoints/example-checkpoint.json:
--------------------------------------------------------------------------------
1 | {
2 | "instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.",
3 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.",
4 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins."
5 | }
--------------------------------------------------------------------------------
/data/dataset_info.json:
--------------------------------------------------------------------------------
1 | {
2 | "example": {
3 | "file_name": "example.json",
4 | "file_sha1": "",
5 | "columns": {
6 | "prompt": "instruction",
7 | "query": "input",
8 | "response": "output"
9 | }
10 | }
11 | }
12 |
--------------------------------------------------------------------------------
/data/example.json:
--------------------------------------------------------------------------------
1 | {
2 | "instruction": "Improve the following instruction, input and response pair to be more specific, detailed with more logical steps and grammarly corrected.",
3 | "input": "Instruction: Name three natural elements. Response: Water, air, and fire.",
4 | "output": "Instruction: Name three natural elements. Response: Some examples of natural elements are:\n\n- Oxygen: This is the most abundant element in the Earth's crust and the second most abundant element in the atmosphere. Oxygen is essential for life, as it is involved in cellular respiration and other metabolic processes. \n\n- Iron: This is the most abundant element in the Earth's core and the fourth most abundant element in the Earth's crust. Iron is a metal that has many physical and chemical properties, such as strength, magnetism, and conductivity. Iron is also important for life, as it is a component of hemoglobin, which transports oxygen in the blood. \n\n- Gold: This is one of the rarest and most valuable elements on Earth. Gold is a metal that has a shiny yellow color and a high density. Gold is resistant to corrosion and oxidation, which makes it suitable for jewelry and coins."
5 | }
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.13.1
2 | transformers>=4.29.1
3 | datasets>=2.12.0
4 | accelerate>=0.21.0
5 | peft>=0.4.0
6 | trl>=0.4.7
7 | sentencepiece
8 | jieba
9 | rouge-chinese
10 | nltk
11 | gradio>=3.36.0
12 | uvicorn
13 | pydantic==1.10.11
14 | fastapi==0.95.1
15 | sse-starlette
16 | matplotlib
17 | protobuf
18 | cpm-kernels
19 |
--------------------------------------------------------------------------------
/src/glmtuner/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.chat import ChatModel
2 |
3 |
4 | __version__ = "0.1.5"
5 |
--------------------------------------------------------------------------------
/src/glmtuner/api/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/api/__init__.py
--------------------------------------------------------------------------------
/src/glmtuner/api/app.py:
--------------------------------------------------------------------------------
1 | import uvicorn
2 | from fastapi import FastAPI, HTTPException
3 | from fastapi.middleware.cors import CORSMiddleware
4 | from contextlib import asynccontextmanager
5 | from sse_starlette import EventSourceResponse
6 | from typing import List, Tuple
7 |
8 | from glmtuner.tuner import get_infer_args
9 | from glmtuner.extras.misc import torch_gc
10 | from glmtuner.chat.stream_chat import ChatModel
11 | from glmtuner.api.protocol import (
12 | Role,
13 | Finish,
14 | ModelCard,
15 | ModelList,
16 | ChatMessage,
17 | DeltaMessage,
18 | ChatCompletionRequest,
19 | ChatCompletionResponse,
20 | ChatCompletionStreamResponse,
21 | ChatCompletionResponseChoice,
22 | ChatCompletionResponseStreamChoice,
23 | ChatCompletionResponseUsage
24 | )
25 |
26 |
27 | @asynccontextmanager
28 | async def lifespan(app: FastAPI): # collects GPU memory
29 | yield
30 | torch_gc()
31 |
32 |
33 | def create_app(chat_model: ChatModel) -> FastAPI:
34 | app = FastAPI(lifespan=lifespan)
35 |
36 | app.add_middleware(
37 | CORSMiddleware,
38 | allow_origins=["*"],
39 | allow_credentials=True,
40 | allow_methods=["*"],
41 | allow_headers=["*"],
42 | )
43 |
44 | @app.get("/v1/models", response_model=ModelList)
45 | async def list_models():
46 | model_card = ModelCard(id="gpt-3.5-turbo")
47 | return ModelList(data=[model_card])
48 |
49 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
50 | async def create_chat_completion(request: ChatCompletionRequest):
51 | if request.messages[-1].role != Role.USER:
52 | raise HTTPException(status_code=400, detail="Invalid request")
53 | query = request.messages[-1].content
54 |
55 | prev_messages = request.messages[:-1]
56 | if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
57 | prefix = prev_messages.pop(0).content
58 | else:
59 | prefix = None
60 |
61 | history = []
62 | if len(prev_messages) % 2 == 0:
63 | for i in range(0, len(prev_messages), 2):
64 | if prev_messages[i].role == Role.USER and prev_messages[i+1].role == Role.ASSISTANT:
65 | history.append([prev_messages[i].content, prev_messages[i+1].content])
66 |
67 | if request.stream:
68 | generate = predict(query, history, prefix, request)
69 | return EventSourceResponse(generate, media_type="text/event-stream")
70 |
71 | response, (prompt_length, response_length) = chat_model.chat(
72 | query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
73 | )
74 |
75 | usage = ChatCompletionResponseUsage(
76 | prompt_tokens=prompt_length,
77 | completion_tokens=response_length,
78 | total_tokens=prompt_length+response_length
79 | )
80 |
81 | choice_data = ChatCompletionResponseChoice(
82 | index=0,
83 | message=ChatMessage(role=Role.ASSISTANT, content=response),
84 | finish_reason=Finish.STOP
85 | )
86 |
87 | return ChatCompletionResponse(model=request.model, choices=[choice_data], usage=usage)
88 |
89 | async def predict(query: str, history: List[Tuple[str, str]], prefix: str, request: ChatCompletionRequest):
90 | choice_data = ChatCompletionResponseStreamChoice(
91 | index=0,
92 | delta=DeltaMessage(role=Role.ASSISTANT),
93 | finish_reason=None
94 | )
95 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
96 | yield chunk.json(exclude_unset=True, ensure_ascii=False)
97 |
98 | for new_text in chat_model.stream_chat(
99 | query, history, prefix, temperature=request.temperature, top_p=request.top_p, max_new_tokens=request.max_tokens
100 | ):
101 | if len(new_text) == 0:
102 | continue
103 |
104 | choice_data = ChatCompletionResponseStreamChoice(
105 | index=0,
106 | delta=DeltaMessage(content=new_text),
107 | finish_reason=None
108 | )
109 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
110 | yield chunk.json(exclude_unset=True, ensure_ascii=False)
111 |
112 | choice_data = ChatCompletionResponseStreamChoice(
113 | index=0,
114 | delta=DeltaMessage(),
115 | finish_reason=Finish.STOP
116 | )
117 | chunk = ChatCompletionStreamResponse(model=request.model, choices=[choice_data])
118 | yield chunk.json(exclude_unset=True, ensure_ascii=False)
119 | yield "[DONE]"
120 |
121 | return app
122 |
123 |
124 | if __name__ == "__main__":
125 | chat_model = ChatModel(*get_infer_args())
126 | app = create_app(chat_model)
127 | uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
128 |
--------------------------------------------------------------------------------
/src/glmtuner/api/protocol.py:
--------------------------------------------------------------------------------
1 | import time
2 | from enum import Enum
3 | from pydantic import BaseModel, Field
4 | from typing import List, Optional
5 |
6 |
7 | class Role(str, Enum):
8 | USER = "user"
9 | ASSISTANT = "assistant"
10 | SYSTEM = "system"
11 |
12 |
13 | class Finish(str, Enum):
14 | STOP = "stop"
15 | LENGTH = "length"
16 |
17 |
18 | class ModelCard(BaseModel):
19 | id: str
20 | object: Optional[str] = "model"
21 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
22 | owned_by: Optional[str] = "owner"
23 | root: Optional[str] = None
24 | parent: Optional[str] = None
25 | permission: Optional[list] = []
26 |
27 |
28 | class ModelList(BaseModel):
29 | object: Optional[str] = "list"
30 | data: Optional[List[ModelCard]] = []
31 |
32 |
33 | class ChatMessage(BaseModel):
34 | role: Role
35 | content: str
36 |
37 |
38 | class DeltaMessage(BaseModel):
39 | role: Optional[Role] = None
40 | content: Optional[str] = None
41 |
42 |
43 | class ChatCompletionRequest(BaseModel):
44 | model: str
45 | messages: List[ChatMessage]
46 | temperature: Optional[float] = None
47 | top_p: Optional[float] = None
48 | n: Optional[int] = 1
49 | max_tokens: Optional[int] = None
50 | stream: Optional[bool] = False
51 |
52 |
53 | class ChatCompletionResponseChoice(BaseModel):
54 | index: int
55 | message: ChatMessage
56 | finish_reason: Finish
57 |
58 |
59 | class ChatCompletionResponseStreamChoice(BaseModel):
60 | index: int
61 | delta: DeltaMessage
62 | finish_reason: Optional[Finish] = None
63 |
64 |
65 | class ChatCompletionResponseUsage(BaseModel):
66 | prompt_tokens: int
67 | completion_tokens: int
68 | total_tokens: int
69 |
70 |
71 | class ChatCompletionResponse(BaseModel):
72 | id: Optional[str] = "chatcmpl-default"
73 | object: Optional[str] = "chat.completion"
74 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
75 | model: str
76 | choices: List[ChatCompletionResponseChoice]
77 | usage: ChatCompletionResponseUsage
78 |
79 |
80 | class ChatCompletionStreamResponse(BaseModel):
81 | id: Optional[str] = "chatcmpl-default"
82 | object: Optional[str] = "chat.completion.chunk"
83 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
84 | model: str
85 | choices: List[ChatCompletionResponseStreamChoice]
86 |
--------------------------------------------------------------------------------
/src/glmtuner/chat/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.chat.stream_chat import ChatModel
2 |
--------------------------------------------------------------------------------
/src/glmtuner/chat/stream_chat.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Generator, List, Optional, Tuple
3 | from threading import Thread
4 | from transformers import TextIteratorStreamer
5 |
6 | from glmtuner.extras.misc import dispatch_model, get_logits_processor
7 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
8 | from glmtuner.tuner import load_model_and_tokenizer
9 |
10 |
11 | class ChatModel:
12 |
13 | def __init__(
14 | self,
15 | model_args: ModelArguments,
16 | data_args: DataArguments,
17 | finetuning_args: FinetuningArguments,
18 | generating_args: GeneratingArguments
19 | ) -> None:
20 | self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
21 | self.model = dispatch_model(self.model, use_v2=(self.tokenizer.eos_token_id==2))
22 | self.source_prefix = data_args.source_prefix
23 | self.generating_args = generating_args
24 |
25 | def get_prompt(
26 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None
27 | ) -> str:
28 | prefix = prefix + "\n" if prefix else "" # add separator for non-empty prefix
29 | history = history or []
30 | prompt = ""
31 | for i, (old_query, response) in enumerate(history):
32 | prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(i+1, old_query, response)
33 | prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query)
34 | prompt = prefix + prompt
35 | return prompt
36 |
37 | def process_args(
38 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
39 | ) -> Tuple[Dict[str, Any], int]:
40 | prefix = prefix or self.source_prefix
41 |
42 | inputs = self.tokenizer([self.get_prompt(query, history, prefix)], return_tensors="pt")
43 | inputs = inputs.to(self.model.device)
44 | prompt_length = len(inputs["input_ids"][0])
45 |
46 | do_sample = input_kwargs.pop("do_sample", None)
47 | temperature = input_kwargs.pop("temperature", None)
48 | top_p = input_kwargs.pop("top_p", None)
49 | top_k = input_kwargs.pop("top_k", None)
50 | repetition_penalty = input_kwargs.pop("repetition_penalty", None)
51 | max_length = input_kwargs.pop("max_length", None)
52 | max_new_tokens = input_kwargs.pop("max_new_tokens", None)
53 |
54 | gen_kwargs = self.generating_args.to_dict()
55 | gen_kwargs.update(dict(
56 | input_ids=inputs["input_ids"],
57 | do_sample=do_sample if do_sample is not None else gen_kwargs["do_sample"],
58 | temperature=temperature or gen_kwargs["temperature"],
59 | top_p=top_p or gen_kwargs["top_p"],
60 | top_k=top_k or gen_kwargs["top_k"],
61 | repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
62 | logits_processor=get_logits_processor()
63 | ))
64 |
65 | if max_length:
66 | gen_kwargs.pop("max_new_tokens", None)
67 | gen_kwargs["max_length"] = max_length
68 |
69 | if max_new_tokens:
70 | gen_kwargs.pop("max_length", None)
71 | gen_kwargs["max_new_tokens"] = max_new_tokens
72 |
73 | return gen_kwargs, prompt_length
74 |
75 | @torch.inference_mode()
76 | def chat(
77 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
78 | ) -> Tuple[str, Tuple[int, int]]:
79 | gen_kwargs, prompt_length = self.process_args(query, history, prefix, **input_kwargs)
80 | generation_output = self.model.generate(**gen_kwargs)
81 | outputs = generation_output.tolist()[0][prompt_length:]
82 | response = self.tokenizer.decode(outputs, skip_special_tokens=True)
83 | response_length = len(outputs)
84 | return response, (prompt_length, response_length)
85 |
86 | @torch.inference_mode()
87 | def stream_chat(
88 | self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
89 | ) -> Generator[str, None, None]:
90 | gen_kwargs, _ = self.process_args(query, history, prefix, **input_kwargs)
91 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
92 | gen_kwargs["streamer"] = streamer
93 |
94 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
95 | thread.start()
96 |
97 | yield from streamer
98 |
--------------------------------------------------------------------------------
/src/glmtuner/dsets/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.dsets.collator import DataCollatorForChatGLM
2 | from glmtuner.dsets.loader import get_dataset
3 | from glmtuner.dsets.preprocess import preprocess_dataset
4 | from glmtuner.dsets.utils import split_dataset
5 |
--------------------------------------------------------------------------------
/src/glmtuner/dsets/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Optional, Sequence
3 |
4 | from transformers import DataCollatorWithPadding, BatchEncoding
5 | from transformers.modeling_utils import PreTrainedModel
6 | from transformers.tokenization_utils import PreTrainedTokenizer
7 |
8 | from glmtuner.extras.constants import IGNORE_INDEX
9 |
10 |
11 | class DataCollatorForChatGLM(DataCollatorWithPadding):
12 | r"""
13 | Data collator for ChatGLM. It is capable of dynamically padding for batched data.
14 | """
15 | def __init__(
16 | self,
17 | tokenizer: PreTrainedTokenizer,
18 | model: PreTrainedModel,
19 | ignore_pad_token_for_loss: Optional[bool] = False
20 | ):
21 | super().__init__(tokenizer, padding=True)
22 | self.model = model
23 | self.label_pad_token_id = IGNORE_INDEX if ignore_pad_token_for_loss else tokenizer.pad_token_id
24 | if tokenizer.eos_token_id == 130005:
25 | self.get_attention_masks = self.get_attention_masks_v1
26 | self.get_position_ids = self.get_position_ids_v1
27 | else:
28 | self.get_attention_masks = self.get_attention_masks_v2
29 | self.get_position_ids = self.get_position_ids_v2
30 |
31 | def get_attention_masks_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
32 | r"""
33 | Generates attention masks for left-padded sequences.
34 |
35 |         Note that ChatGLM assigns False to tokens that should be attended to in the attention mask, whereas in general settings attended tokens are marked True.
36 |
37 | According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L680
38 | """
39 | batch_size, seq_length = input_ids.size()
40 | attention_mask = torch.ones((batch_size, seq_length, seq_length), device=device)
41 | attention_mask.tril_()
42 |
43 | for i, seq in enumerate(input_ids):
44 | attention_mask[i, :, :(seq == self.tokenizer.bos_token_id).nonzero()[0].item()] = 1 # context
45 | attention_mask[i, :, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
46 |
47 | attention_mask.unsqueeze_(1)
48 | attention_mask = (attention_mask < 0.5).bool()
49 | return attention_mask
50 |
51 | def get_position_ids_v1(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
52 | r"""
53 |         Generates position ids for left-padded sequences.
54 |
55 | According to: https://huggingface.co/THUDM/chatglm-6b/blob/v1.1.0/modeling_chatglm.py#L692
56 | """
57 | batch_size, seq_length = input_ids.size()
58 | mask: int = self.model.config.mask_token_id
59 | gmask: int = self.model.config.gmask_token_id
60 | position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
61 | block_position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
62 |
63 | for i, seq in enumerate(input_ids):
64 | mask_token = gmask if gmask in seq else mask
65 | context_length = (seq == self.tokenizer.bos_token_id).nonzero()[0].item()
66 | padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
67 | position_ids[i, padding_length:] = torch.arange(
68 | seq_length - padding_length,
69 | dtype=torch.long,
70 | device=device
71 | )
72 | if self.model.position_encoding_2d or (mask_token != gmask): # 2d position encoding or not gMASK
73 | position_ids[i, context_length:] = (seq == mask_token).nonzero()[0].item() - padding_length # mask position
74 | block_position_ids[i, context_length:] = torch.arange(
75 | seq_length - context_length,
76 | dtype=torch.long,
77 | device=device
78 | ) + 1
79 |
80 | if self.model.position_encoding_2d:
81 | position_ids = torch.stack((position_ids, block_position_ids), dim=1)
82 |
83 | return position_ids
84 |
85 | def get_attention_masks_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
86 | r"""
87 | Generates attention masks for left-padded sequences.
88 | """
89 | batch_size, seq_length = input_ids.size()
90 | attention_mask = torch.ones((batch_size, seq_length), device=device)
91 |
92 | for i, seq in enumerate(input_ids):
93 | attention_mask[i, :(seq != self.tokenizer.pad_token_id).nonzero()[0].item()] = 0 # padding
94 |
95 | return attention_mask
96 |
97 | def get_position_ids_v2(self, input_ids: torch.Tensor, device: torch.device) -> torch.Tensor:
98 | r"""
99 |         Generates position ids for left-padded sequences.
100 | """
101 | batch_size, seq_length = input_ids.size()
102 | position_ids = torch.zeros((batch_size, seq_length), dtype=torch.long, device=device)
103 |
104 | for i, seq in enumerate(input_ids):
105 | padding_length = (seq != self.tokenizer.pad_token_id).nonzero()[0].item()
106 | position_ids[i, padding_length:] = torch.arange(seq_length - padding_length, dtype=torch.long, device=device)
107 |
108 | return position_ids
109 |
110 | def __call__(self, features: Sequence[Dict[str, Any]]) -> BatchEncoding:
111 | r"""
112 | Pads batched data to the longest sequence in the batch.
113 |
114 | We adopt left-padding in both training and evaluation.
115 | """
116 | if isinstance(features[0]["input_ids"], torch.Tensor):
117 | input_ids = [feature["input_ids"].clone().detach().flip(0) for feature in features]
118 | else:
119 | input_ids = [torch.tensor(feature["input_ids"]).flip(0) for feature in features]
120 |
121 | if "labels" in features[0]:
122 | if isinstance(features[0]["labels"], torch.Tensor):
123 | labels = [feature["labels"].clone().detach().flip(0) for feature in features]
124 | else:
125 | labels = [torch.tensor(feature["labels"]).flip(0) for feature in features]
126 | input_ids += labels # pad them to the same length
127 |
128 | input_ids = torch.nn.utils.rnn.pad_sequence(
129 | input_ids,
130 | batch_first=True,
131 | padding_value=self.tokenizer.pad_token_id
132 | ).flip(-1)
133 |
134 | batch = {}
135 |
136 | if "labels" in features[0]:
137 | input_ids, labels = input_ids.split(len(features), dim=0)
138 | labels = torch.where(labels != self.tokenizer.pad_token_id, labels, self.label_pad_token_id)
139 | batch["labels"] = labels
140 |
141 | batch["input_ids"] = input_ids
142 | batch["attention_mask"] = self.get_attention_masks(input_ids, device=input_ids.device)
143 | batch["position_ids"] = self.get_position_ids(input_ids, device=input_ids.device)
144 |
145 | return BatchEncoding(batch)
146 |
--------------------------------------------------------------------------------
/src/glmtuner/dsets/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import hashlib
3 | from typing import List
4 |
5 | from datasets import Dataset, concatenate_datasets, load_dataset
6 |
7 | from glmtuner.extras.logging import get_logger
8 | from glmtuner.hparams import ModelArguments, DataArguments
9 |
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def get_dataset(
15 | model_args: ModelArguments,
16 | data_args: DataArguments
17 | ) -> Dataset:
18 |
19 | def checksum(file_path, hash):
20 | with open(file_path, "rb") as datafile:
21 | binary_data = datafile.read()
22 | sha1 = hashlib.sha1(binary_data).hexdigest()
23 | if sha1 != hash:
24 | logger.warning("Checksum failed for {}. It may vary depending on the platform.".format(file_path))
25 |
26 | ext2type = {
27 | "csv": "csv",
28 | "json": "json",
29 | "jsonl": "json"
30 | }
31 |
32 | max_samples = data_args.max_samples
33 | all_datasets: List[Dataset] = [] # support multiple datasets
34 |
35 | for dataset_attr in data_args.dataset_list:
36 |
37 | logger.info("Loading dataset {}...".format(dataset_attr))
38 |
39 | if dataset_attr.load_from == "hf_hub":
40 | data_path = dataset_attr.dataset_name
41 | data_files = None
42 |
43 | elif dataset_attr.load_from == "script":
44 | data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
45 | data_files = None
46 |
47 | elif dataset_attr.load_from == "file": # support folder or file
48 | data_path = None
49 | data_files: List[str] = []
50 |
51 | if os.path.isdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # folder
52 | for file_name in os.listdir(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)):
53 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name, file_name))
54 |
55 | if data_path is None:
56 | data_path = ext2type.get(data_files[0].split(".")[-1], None)
57 | else:
58 | assert ext2type.get(data_files[-1].split(".")[-1], None) == data_path, \
59 |                     "more than one file format found in a single folder"
60 |
61 | elif os.path.isfile(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)): # file
62 | data_files.append(os.path.join(data_args.dataset_dir, dataset_attr.dataset_name))
63 | data_path = ext2type.get(data_files[0].split(".")[-1], None)
64 |
65 | else:
66 | raise ValueError("File not found.")
67 |
68 | assert data_path, "File extension must be csv, json or jsonl."
69 |
70 | if len(data_files) == 1 and dataset_attr.dataset_sha1 is not None:
71 | checksum(data_files[0], dataset_attr.dataset_sha1)
72 | else:
73 | logger.warning("Checksum failed: missing SHA-1 hash value in dataset_info.json or too many files.")
74 |
75 | else:
76 | raise NotImplementedError
77 |
78 | raw_datasets = load_dataset(
79 | data_path,
80 | data_files=data_files,
81 | cache_dir=model_args.cache_dir,
82 | use_auth_token=True if model_args.use_auth_token else None
83 | )
84 | dataset = raw_datasets[data_args.split]
85 |
86 | if max_samples is not None:
87 | max_samples_temp = min(len(dataset), max_samples)
88 | dataset = dataset.select(range(max_samples_temp))
89 |
90 | dummy_data = [None] * len(dataset)
91 |
92 | for column_name, target_name in [
93 | ("prompt_column", "prompt"),
94 | ("query_column", "query"),
95 | ("response_column", "response"),
96 | ("history_column", "history")
97 |         ]: # every dataset ends up with the same 4 columns
98 | if getattr(dataset_attr, column_name) != target_name:
99 | if getattr(dataset_attr, column_name):
100 | dataset = dataset.rename_column(getattr(dataset_attr, column_name), target_name)
101 | else: # None or empty string
102 | dataset = dataset.add_column(target_name, dummy_data)
103 |
104 | all_datasets.append(dataset)
105 |
106 | if len(data_args.dataset_list) == 1:
107 | all_datasets = all_datasets[0]
108 | else:
109 | all_datasets = concatenate_datasets(all_datasets)
110 |
111 | return all_datasets
112 |
--------------------------------------------------------------------------------
/src/glmtuner/dsets/preprocess.py:
--------------------------------------------------------------------------------
1 | from typing import Literal
2 | from transformers import Seq2SeqTrainingArguments
3 | from transformers.tokenization_utils import PreTrainedTokenizer
4 |
5 | from datasets import Dataset
6 |
7 | from glmtuner.extras.constants import IGNORE_INDEX
8 | from glmtuner.hparams import DataArguments
9 |
10 |
11 | def preprocess_dataset(
12 | dataset: Dataset,
13 | tokenizer: PreTrainedTokenizer,
14 | data_args: DataArguments,
15 | training_args: Seq2SeqTrainingArguments,
16 | stage: Literal["sft", "rm", "ppo"]
17 | ) -> Dataset:
18 |
19 | column_names = list(dataset.column_names)
20 | prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
21 |
22 | def format_example(examples): # support question with a single answer or multiple answers
23 | for i in range(len(examples["prompt"])):
24 | if examples["prompt"][i] and examples["response"][i]:
25 | query, answer = examples["prompt"][i], examples["response"][i]
26 | query = query + examples["query"][i] if examples["query"][i] else query
27 | history = examples["history"][i] if examples["history"][i] else []
28 | prompt = ""
29 | for j, (old_query, response) in enumerate(history):
30 | prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(j+1, old_query, response)
31 | prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history)+1, query)
32 | prompt = prefix + prompt
33 | yield prompt, answer
34 |
35 | def preprocess_supervised_dataset(examples):
36 | # v1: build inputs with format `X [gMASK] Y ` and labels with format `[IGNORE] ... [IGNORE] Y `
37 | # v2: build inputs with format `[gMASK] sop X Y ` and labels with format `[IGNORE] ... [IGNORE] Y `
38 | model_inputs = {"input_ids": [], "labels": []}
39 | for prompt, answer in format_example(examples):
40 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
41 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
42 |
43 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
44 | source_ids = source_ids[:data_args.max_source_length - 2]
45 | if len(target_ids) > data_args.max_target_length - 1: # eos token
46 | target_ids = target_ids[:data_args.max_target_length - 1]
47 |
48 | context_length = len(source_ids) + 2 # gmask and sop tokens
49 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids, target_ids)
50 | labels = [IGNORE_INDEX] * context_length + input_ids[context_length:]
51 |
52 | model_inputs["input_ids"].append(input_ids)
53 | model_inputs["labels"].append(labels)
54 | return model_inputs
55 |
56 | def preprocess_evaluation_dataset(examples):
57 | # v1: build inputs with format `X [gMASK] ` and labels with format `Y [gMASK] `
58 | # v2: build inputs with format `[gMASK] sop X` and labels with format `[gMASK] sop Y`
59 | model_inputs = {"input_ids": [], "labels": []}
60 | for prompt, answer in format_example(examples):
61 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
62 | target_ids = tokenizer.encode(text=answer, add_special_tokens=False)
63 |
64 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
65 | source_ids = source_ids[:data_args.max_source_length - 2]
66 | if len(target_ids) > data_args.max_target_length - 2: # gmask and sop tokens
67 | target_ids = target_ids[:data_args.max_target_length - 2]
68 |
69 | input_ids = tokenizer.build_inputs_with_special_tokens(source_ids)
70 | labels = tokenizer.build_inputs_with_special_tokens(target_ids)
71 |
72 | model_inputs["input_ids"].append(input_ids)
73 | model_inputs["labels"].append(labels)
74 | return model_inputs
75 |
76 | def preprocess_pairwise_dataset(examples):
77 | # v1: build input pairs with format `X [gMASK] Y1 ` and `X [gMASK] Y2 `
78 | # v2: build input pairs with format `[gMASK] sop X Y1 ` and `[gMASK] sop X Y2 `
79 | model_inputs = {"accept_ids": [], "reject_ids": []}
80 | for prompt, answer in format_example(examples):
81 | source_ids = tokenizer.encode(text=prompt, add_special_tokens=False)
82 | accept_ids = tokenizer.encode(text=answer[0], add_special_tokens=False)
83 | reject_ids = tokenizer.encode(text=answer[1], add_special_tokens=False)
84 |
85 | if len(source_ids) > data_args.max_source_length - 2: # gmask and sop tokens
86 | source_ids = source_ids[:data_args.max_source_length - 2]
87 | if len(accept_ids) > data_args.max_target_length - 1: # eos token
88 | accept_ids = accept_ids[:data_args.max_target_length - 1]
89 | if len(reject_ids) > data_args.max_target_length - 1: # eos token
90 | reject_ids = reject_ids[:data_args.max_target_length - 1]
91 |
92 | accept_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], accept_ids) # pass a copy so source_ids stays unmodified for the next call
93 | reject_ids = tokenizer.build_inputs_with_special_tokens(source_ids[:], reject_ids)
94 |
95 | model_inputs["accept_ids"].append(accept_ids)
96 | model_inputs["reject_ids"].append(reject_ids)
97 | return model_inputs
98 |
99 | def print_sft_dataset_example(example):
100 | print("input_ids:\n{}".format(example["input_ids"]))
101 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
102 | print("label_ids:\n{}".format(example["labels"]))
103 | print("labels:\n{}".format(
104 | tokenizer.decode([d if d != IGNORE_INDEX else tokenizer.pad_token_id for d in example["labels"]],
105 | skip_special_tokens=False)
106 | ))
107 |
108 | def print_pairwise_dataset_example(example):
109 | print("accept_ids:\n{}".format(example["accept_ids"]))
110 | print("accepts:\n{}".format(tokenizer.decode(example["accept_ids"], skip_special_tokens=False)))
111 | print("reject_ids:\n{}".format(example["reject_ids"]))
112 | print("rejects:\n{}".format(tokenizer.decode(example["reject_ids"], skip_special_tokens=False)))
113 |
114 | def print_ppo_dataset_example(example):
115 | print("input_ids:\n{}".format(example["input_ids"]))
116 | print("inputs:\n{}".format(tokenizer.decode(example["input_ids"], skip_special_tokens=False)))
117 |
118 | if stage == "sft":
119 | if not training_args.predict_with_generate:
120 | preprocess_function = preprocess_supervised_dataset
121 | else:
122 | preprocess_function = preprocess_evaluation_dataset
123 | elif stage == "rm":
124 | preprocess_function = preprocess_pairwise_dataset
125 | elif stage == "ppo":
126 | preprocess_function = preprocess_evaluation_dataset
127 |
128 | with training_args.main_process_first(desc="dataset map pre-processing"):
129 | dataset = dataset.map(
130 | preprocess_function,
131 | batched=True,
132 | num_proc=data_args.preprocessing_num_workers,
133 | remove_columns=column_names,
134 | load_from_cache_file=not data_args.overwrite_cache,
135 | desc="Running tokenizer on dataset"
136 | )
137 |
138 | if stage == "sft":
139 | print_sft_dataset_example(dataset[0])
140 | elif stage == "rm":
141 | print_pairwise_dataset_example(dataset[0])
142 | elif stage == "ppo":
143 | print_ppo_dataset_example(dataset[0])
144 |
145 | return dataset
146 |
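A minimal sketch of the multi-turn prompt that format_example assembles before tokenization; the record below is hypothetical and only illustrates the template:

# Illustration only: rebuilds the ChatGLM-style prompt produced by format_example
# for a hypothetical record with one turn of history.
history = [("Hello", "Hi, how can I help you?")]
query = "Tell me about ChatGLM"

prompt = ""
for j, (old_query, response) in enumerate(history):
    prompt += "[Round {}]\n\n问:{}\n\n答:{}\n\n".format(j + 1, old_query, response)
prompt += "[Round {}]\n\n问:{}\n\n答:".format(len(history) + 1, query)
print(prompt)
# [Round 1]\n\n问:Hello\n\n答:Hi, how can I help you?\n\n[Round 2]\n\n问:Tell me about ChatGLM\n\n答: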
--------------------------------------------------------------------------------
/src/glmtuner/dsets/utils.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from datasets import Dataset
3 |
4 |
5 | def split_dataset(
6 | dataset: Dataset, dev_ratio: float, do_train: bool
7 | ) -> Dict[str, Dataset]:
8 | # Split the dataset
9 | if do_train:
10 | if dev_ratio > 1e-6:
11 | dataset = dataset.train_test_split(test_size=dev_ratio)
12 | return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
13 | else:
14 | return {"train_dataset": dataset}
15 | else: # do_eval or do_predict
16 | return {"eval_dataset": dataset}
17 |
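A usage sketch for split_dataset with a small in-memory dataset; the toy data and ratio are illustrative:

from datasets import Dataset

from glmtuner.dsets.utils import split_dataset

# Hypothetical 100-row dataset; with do_train=True and dev_ratio=0.1,
# roughly 90 rows land in train_dataset and 10 in eval_dataset.
toy = Dataset.from_dict({"text": ["example {}".format(i) for i in range(100)]})
splits = split_dataset(toy, dev_ratio=0.1, do_train=True)
print(len(splits["train_dataset"]), len(splits["eval_dataset"]))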
--------------------------------------------------------------------------------
/src/glmtuner/extras/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/extras/__init__.py
--------------------------------------------------------------------------------
/src/glmtuner/extras/callbacks.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import time
4 | from datetime import timedelta
5 |
6 | from transformers import (
7 | TrainerCallback,
8 | TrainerControl,
9 | TrainerState,
10 | TrainingArguments
11 | )
12 | from transformers.trainer_callback import TrainerControl, TrainerState
13 | from transformers.training_args import TrainingArguments
14 |
15 |
16 | class LogCallback(TrainerCallback):
17 |
18 | def __init__(self, runner=None):
19 | self.runner = runner
20 | self.start_time = time.time()
21 | self.tracker = {}
22 |
23 | def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
24 | r"""
25 | Event called at the beginning of training.
26 | """
27 | self.start_time = time.time()
28 |
29 | def on_step_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
30 | r"""
31 | Event called at the beginning of a training step. If using gradient accumulation, one training step
32 | might take several inputs.
33 | """
34 | if self.runner is not None and self.runner.aborted:
35 | control.should_epoch_stop = True
36 | control.should_training_stop = True
37 |
38 | def on_substep_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
39 | r"""
40 | Event called at the end of a substep during gradient accumulation.
41 | """
42 | if self.runner is not None and self.runner.aborted:
43 | control.should_epoch_stop = True
44 | control.should_training_stop = True
45 |
46 | def on_log(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs) -> None:
47 | r"""
48 | Event called after logging the last logs.
49 | """
50 | if not state.is_world_process_zero:
51 | return
52 |
53 | cur_time = time.time()
54 | cur_steps = state.log_history[-1].get("step")
55 | elapsed_time = cur_time - self.start_time
56 | avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0
57 | remaining_steps = state.max_steps - cur_steps
58 | remaining_time = remaining_steps * avg_time_per_step
59 | self.tracker = {
60 | "current_steps": cur_steps,
61 | "total_steps": state.max_steps,
62 | "loss": state.log_history[-1].get("loss", None),
63 | "eval_loss": state.log_history[-1].get("eval_loss", None),
64 | "predict_loss": state.log_history[-1].get("predict_loss", None),
65 | "reward": state.log_history[-1].get("reward", None),
66 | "learning_rate": state.log_history[-1].get("learning_rate", None),
67 | "epoch": state.log_history[-1].get("epoch", None),
68 | "percentage": round(cur_steps / state.max_steps * 100, 2) if state.max_steps != 0 else 100,
69 | "elapsed_time": str(timedelta(seconds=int(elapsed_time))),
70 | "remaining_time": str(timedelta(seconds=int(remaining_time)))
71 | }
72 | os.makedirs(args.output_dir, exist_ok=True)
73 | with open(os.path.join(args.output_dir, "trainer_log.jsonl"), "a", encoding="utf-8") as f:
74 | f.write(json.dumps(self.tracker) + "\n")
75 |
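LogCallback appends one JSON object per logging event to trainer_log.jsonl in the output directory; a small sketch that reads the progress back (the path is an assumption, substitute your own --output_dir):

import json

# Assumed location: <output_dir>/trainer_log.jsonl, written by LogCallback.on_log.
with open("output/trainer_log.jsonl", "r", encoding="utf-8") as f:
    for line in f:
        record = json.loads(line)
        print("step {}/{} ({}%), loss={}".format(
            record["current_steps"], record["total_steps"],
            record["percentage"], record["loss"]
        ))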
--------------------------------------------------------------------------------
/src/glmtuner/extras/constants.py:
--------------------------------------------------------------------------------
1 | IGNORE_INDEX = -100
2 |
3 | VALUE_HEAD_FILE_NAME = "value_head.bin"
4 |
5 | FINETUNING_ARGS_NAME = "finetuning_args.json"
6 |
7 | LAYERNORM_NAMES = ["layernorm"]
8 |
9 | METHODS = ["full", "freeze", "p_tuning", "lora"]
10 |
11 | SUPPORTED_MODELS = {
12 | "ChatGLM-6B": "THUDM/chatglm-6b",
13 | "ChatGLM2-6B": "THUDM/chatglm2-6b"
14 | }
15 |
--------------------------------------------------------------------------------
/src/glmtuner/extras/logging.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import logging
3 |
4 |
5 | class LoggerHandler(logging.Handler):
6 |
7 | def __init__(self):
8 | super().__init__()
9 | self.log = ""
10 |
11 | def emit(self, record):
12 | if record.name == "httpx":
13 | return
14 | log_entry = self.format(record)
15 | self.log += log_entry
16 | self.log += "\n\n"
17 |
18 |
19 | def get_logger(name: str) -> logging.Logger:
20 |
21 | formatter = logging.Formatter(
22 | fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
23 | datefmt="%m/%d/%Y %H:%M:%S"
24 | )
25 | handler = logging.StreamHandler(sys.stdout)
26 | handler.setFormatter(formatter)
27 |
28 | logger = logging.getLogger(name)
29 | logger.setLevel(logging.INFO)
30 | logger.addHandler(handler)
31 |
32 | return logger
33 |
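A sketch of attaching LoggerHandler to capture log text in memory (for example, to display it in a UI) in addition to the stdout handler installed by get_logger:

import logging

from glmtuner.extras.logging import LoggerHandler, get_logger

handler = LoggerHandler()
handler.setFormatter(logging.Formatter("%(levelname)s - %(message)s"))

logger = get_logger("demo")
logger.addHandler(handler)
logger.info("fine-tuning started")

print(handler.log)  # accumulated records, separated by blank lines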
--------------------------------------------------------------------------------
/src/glmtuner/extras/misc.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Dict, List, Optional
3 |
4 | from transformers.modeling_utils import PreTrainedModel
5 | from transformers.generation.utils import LogitsProcessorList
6 | from transformers.generation.logits_process import LogitsProcessor
7 |
8 | from glmtuner.extras.constants import LAYERNORM_NAMES
9 |
10 |
11 | class AverageMeter:
12 | r"""
13 | Computes and stores the average and current value.
14 | """
15 | def __init__(self):
16 | self.reset()
17 |
18 | def reset(self):
19 | self.val = 0
20 | self.avg = 0
21 | self.sum = 0
22 | self.count = 0
23 |
24 | def update(self, val, n=1):
25 | self.val = val
26 | self.sum += val * n
27 | self.count += n
28 | self.avg = self.sum / self.count
29 |
30 |
31 | # Avoid runtime error in model.generate(do_sample=True).
32 | # Borrowed from: https://huggingface.co/THUDM/chatglm-6b/blob/658202d88ac4bb782b99e99ac3adff58b4d0b813/modeling_chatglm.py#L54
33 | class InvalidScoreLogitsProcessor(LogitsProcessor):
34 |
35 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
36 | if torch.isnan(scores).any() or torch.isinf(scores).any():
37 | scores.zero_()
38 | scores[..., 5] = 5e4
39 | return scores
40 |
41 |
42 | def get_logits_processor() -> LogitsProcessorList:
43 | logits_processor = LogitsProcessorList()
44 | logits_processor.append(InvalidScoreLogitsProcessor())
45 | return logits_processor
46 |
47 |
48 | def print_trainable_params(model: torch.nn.Module) -> None:
49 | trainable_params, all_param = 0, 0
50 | for param in model.parameters():
51 | num_params = param.numel()
52 | # if using DS Zero 3 and the weights are initialized empty
53 | if num_params == 0 and hasattr(param, "ds_numel"):
54 | num_params = param.ds_numel
55 | all_param += num_params
56 | if param.requires_grad:
57 | trainable_params += num_params
58 | print("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
59 | trainable_params, all_param, 100 * trainable_params / all_param))
60 |
61 |
62 | # Includes: (1) cast the layernorm in fp32 (2) make output embedding layer require grads (3) upcast the lm_head to fp32
63 | # Inspired by: https://github.com/huggingface/peft/blob/c0209c35abbf88c63aa267800d98a8e212ed0a42/src/peft/utils/other.py#L35
64 | def prepare_model_for_training(
65 | model: PreTrainedModel,
66 | finetuning_type: str,
67 | output_embedding_base_layer: torch.nn.Module,
68 | output_embedding_layer_name: Optional[str] = "lm_head",
69 | use_gradient_checkpointing: Optional[bool] = True,
70 | layer_norm_names: Optional[List[str]] = LAYERNORM_NAMES
71 | ) -> PreTrainedModel:
72 |
73 | for name, param in model.named_parameters():
74 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
75 | param.data = param.data.to(torch.float32)
76 |
77 | if use_gradient_checkpointing:
78 | model.enable_input_require_grads()
79 | model.gradient_checkpointing_enable()
80 | model.config.use_cache = False # turn off when gradient checkpointing is enabled
81 |
82 | if finetuning_type != "full" and hasattr(output_embedding_base_layer, output_embedding_layer_name):
83 | output_embedding_layer = getattr(output_embedding_base_layer, output_embedding_layer_name)
84 | input_dtype = output_embedding_layer.weight.dtype
85 |
86 | class CastOutputToFloat(torch.nn.Sequential):
87 |
88 | def forward(self, x: torch.Tensor) -> torch.Tensor:
89 | return super().forward(x.to(input_dtype)).to(torch.float32)
90 |
91 | setattr(output_embedding_base_layer, output_embedding_layer_name, CastOutputToFloat(output_embedding_layer))
92 |
93 | return model
94 |
95 |
96 | def torch_gc() -> None:
97 | r"""
98 | Collects GPU memory.
99 | """
100 | if torch.cuda.is_available():
101 | torch.cuda.empty_cache()
102 | torch.cuda.ipc_collect()
103 |
104 |
105 | def auto_configure_device_map(num_gpus: int, use_v2: bool) -> Dict[str, int]:
106 | r"""
107 | Configures device map for ChatGLM.
108 |
109 | Borrowed from: https://github.com/THUDM/ChatGLM-6B/blob/dev_multi_gpu/utils.py#L8
110 | """
111 | num_layers = 28
112 | layers_per_gpu = 30 / num_gpus  # the embedding and output modules count as 2 extra layers, 30 in total
113 | if use_v2:
114 | device_map = {
115 | "transformer.embedding.word_embeddings": 0,
116 | "transformer.encoder.final_layernorm": 0,
117 | "transformer.output_layer": 0,
118 | "transformer.rotary_pos_emb": 0,
119 | "transformer.prefix_encoder": 0,
120 | "lm_head": 0
121 | }
122 | else:
123 | device_map = {
124 | "transformer.word_embeddings": 0,
125 | "transformer.final_layernorm": 0,
126 | "transformer.prefix_encoder": 0,
127 | "lm_head": 0
128 | }
129 |
130 | added_layers = 2
131 | target_gpu = 0
132 |
133 | for i in range(num_layers):
134 | if added_layers >= layers_per_gpu:
135 | target_gpu += 1
136 | added_layers = 0
137 | assert target_gpu < num_gpus
138 | if use_v2:
139 | device_map[f"transformer.encoder.layers.{i}"] = target_gpu
140 | else:
141 | device_map[f"transformer.layers.{i}"] = target_gpu
142 | added_layers += 1
143 |
144 | return device_map
145 |
146 |
147 | def dispatch_model(model: PreTrainedModel, use_v2: bool) -> PreTrainedModel:
148 | r"""
149 | Dispatches a pre-trained model to GPUs with balanced memory.
150 | """
151 | if torch.cuda.device_count() > 1:
152 | from accelerate import dispatch_model
153 |
154 | device_map = auto_configure_device_map(torch.cuda.device_count(), use_v2=use_v2)
155 | model.tie_weights()
156 | return dispatch_model(model, device_map)
157 | else:
158 | return model.cuda()
159 |
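A quick sketch of the device map that auto_configure_device_map produces for two GPUs with ChatGLM2 (use_v2=True): the non-layer modules stay on GPU 0 and the 28 transformer layers are split between the two cards:

from glmtuner.extras.misc import auto_configure_device_map

device_map = auto_configure_device_map(num_gpus=2, use_v2=True)
print(device_map["transformer.embedding.word_embeddings"])  # 0
print(device_map["transformer.encoder.layers.0"])           # 0
print(device_map["transformer.encoder.layers.27"])          # 1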
--------------------------------------------------------------------------------
/src/glmtuner/extras/ploting.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import json
4 | import matplotlib.pyplot as plt
5 | from typing import List, Optional
6 | from transformers.trainer import TRAINER_STATE_NAME
7 |
8 | from glmtuner.extras.logging import get_logger
9 |
10 |
11 | logger = get_logger(__name__)
12 |
13 |
14 | def smooth(scalars: List[float]) -> List[float]:
15 | r"""
16 | EMA implementation according to TensorBoard.
17 | """
18 | last = scalars[0]
19 | smoothed = list()
20 | weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
21 | for next_val in scalars:
22 | smoothed_val = last * weight + (1 - weight) * next_val
23 | smoothed.append(smoothed_val)
24 | last = smoothed_val
25 | return smoothed
26 |
27 |
28 | def plot_loss(save_dictionary: os.PathLike, keys: Optional[List[str]] = ["loss"]) -> None:
29 |
30 | with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
31 | data = json.load(f)
32 |
33 | for key in keys:
34 | steps, metrics = [], []
35 | for i in range(len(data["log_history"])):
36 | if key in data["log_history"][i]:
37 | steps.append(data["log_history"][i]["step"])
38 | metrics.append(data["log_history"][i][key])
39 |
40 | if len(metrics) == 0:
41 | logger.warning(f"No metric {key} to plot.")
42 | continue
43 |
44 | plt.figure()
45 | plt.plot(steps, metrics, alpha=0.4, label="original")
46 | plt.plot(steps, smooth(metrics), label="smoothed")
47 | plt.title("training {} of {}".format(key, save_dictionary))
48 | plt.xlabel("step")
49 | plt.ylabel(key)
50 | plt.legend()
51 | plt.savefig(os.path.join(save_dictionary, "training_{}.png".format(key)), format="png", dpi=100)
52 | print("Figure saved:", os.path.join(save_dictionary, "training_{}.png".format(key)))
53 |
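The EMA weight in smooth grows with the number of logged points, so longer curves are smoothed more aggressively; a standalone replica of the formula for illustration:

import math

# Same weight formula as smooth() above, evaluated for a few curve lengths.
for n in (10, 100, 1000):
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * n)) - 0.5)
    print(n, round(weight, 3))  # 10 -> ~0.22, 100 -> ~0.888, 1000 -> ~0.9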
--------------------------------------------------------------------------------
/src/glmtuner/extras/save_and_load.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import Dict, Optional
4 |
5 | from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
6 | from transformers.modeling_utils import load_sharded_checkpoint
7 |
8 | from glmtuner.extras.constants import VALUE_HEAD_FILE_NAME
9 | from glmtuner.extras.logging import get_logger
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | def get_state_dict(model: torch.nn.Module, trainable_only: Optional[bool] = True) -> Dict[str, torch.Tensor]:
16 | state_dict = model.state_dict()
17 | filtered_state_dict = {}
18 |
19 | for k, v in model.named_parameters():
20 | if (not trainable_only) or v.requires_grad:
21 | filtered_state_dict[k] = state_dict[k].cpu().clone().detach()
22 |
23 | return filtered_state_dict
24 |
25 |
26 | def load_trainable_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
27 | weights_file = os.path.join(checkpoint_dir, WEIGHTS_NAME)
28 | if os.path.exists(weights_file):
29 | model_state_dict = torch.load(weights_file, map_location="cpu")
30 | model.load_state_dict(model_state_dict, strict=False) # skip missing keys
31 | elif os.path.exists(os.path.join(checkpoint_dir, WEIGHTS_INDEX_NAME)):
32 | load_sharded_checkpoint(model, checkpoint_dir, strict=False)
33 | else:
34 | logger.warning("Provided path ({}) does not contain pre-trained weights.".format(checkpoint_dir))
35 | return False
36 | return True
37 |
38 |
39 | def load_valuehead_params(model: torch.nn.Module, checkpoint_dir: os.PathLike) -> bool:
40 | valuehead_file = os.path.join(checkpoint_dir, VALUE_HEAD_FILE_NAME)
41 | if not os.path.exists(valuehead_file):
42 | logger.warning("Provided path ({}) does not contain valuehead weights.".format(checkpoint_dir))
43 | return False
44 | valuehead_state_dict = torch.load(valuehead_file, map_location="cpu")
45 | model.register_buffer("reward_head_weight", valuehead_state_dict["summary.weight"])
46 | model.register_buffer("reward_head_bias", valuehead_state_dict["summary.bias"])
47 | model.register_buffer("default_head_weight", torch.zeros_like(valuehead_state_dict["summary.weight"]))
48 | model.register_buffer("default_head_bias", torch.zeros_like(valuehead_state_dict["summary.bias"]))
49 | return True
50 |
--------------------------------------------------------------------------------
/src/glmtuner/hparams/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.hparams.data_args import DataArguments
2 | from glmtuner.hparams.finetuning_args import FinetuningArguments
3 | from glmtuner.hparams.general_args import GeneralArguments
4 | from glmtuner.hparams.generating_args import GeneratingArguments
5 | from glmtuner.hparams.model_args import ModelArguments
6 |
--------------------------------------------------------------------------------
/src/glmtuner/hparams/data_args.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | from typing import List, Optional
4 | from dataclasses import dataclass, field
5 |
6 |
7 | @dataclass
8 | class DatasetAttr:
9 |
10 | load_from: str
11 | dataset_name: str
12 | dataset_sha1: Optional[str] = None
13 |
14 | def __repr__(self) -> str:
15 | return self.dataset_name
16 |
17 | def __post_init__(self):
18 | self.prompt_column = "instruction"
19 | self.query_column = "input"
20 | self.response_column = "output"
21 | self.history_column = None
22 |
23 |
24 | @dataclass
25 | class DataArguments:
26 | """
27 | Arguments pertaining to what data we are going to input to our model for training and evaluation.
28 | """
29 | dataset: Optional[str] = field(
30 | default="alpaca_zh",
31 | metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}
32 | )
33 | dataset_dir: Optional[str] = field(
34 | default="data",
35 | metadata={"help": "The name of the folder containing datasets."}
36 | )
37 | split: Optional[str] = field(
38 | default="train",
39 | metadata={"help": "Which dataset split to use for training and evaluation."}
40 | )
41 | overwrite_cache: Optional[bool] = field(
42 | default=False,
43 | metadata={"help": "Overwrite the cached training and evaluation sets."}
44 | )
45 | preprocessing_num_workers: Optional[int] = field(
46 | default=None,
47 | metadata={"help": "The number of processes to use for the preprocessing."}
48 | )
49 | max_source_length: Optional[int] = field(
50 | default=512,
51 | metadata={"help": "The maximum total input sequence length after tokenization."}
52 | )
53 | max_target_length: Optional[int] = field(
54 | default=512,
55 | metadata={"help": "The maximum total output sequence length after tokenization."}
56 | )
57 | max_samples: Optional[int] = field(
58 | default=None,
59 | metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}
60 | )
61 | eval_num_beams: Optional[int] = field(
62 | default=None,
63 | metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}
64 | )
65 | ignore_pad_token_for_loss: Optional[bool] = field(
66 | default=True,
67 | metadata={"help": "Whether to ignore the tokens corresponding to padded labels in the loss computation or not."}
68 | )
69 | source_prefix: Optional[str] = field(
70 | default=None,
71 | metadata={"help": "A prefix to add before every source text (useful for T5 models)."}
72 | )
73 | dev_ratio: Optional[float] = field(
74 | default=0,
75 | metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
76 | )
77 |
78 | def init_for_training(self): # support mixing multiple datasets
79 | dataset_names = [ds.strip() for ds in self.dataset.split(",")]
80 | with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r", encoding="utf-8") as f:
81 | dataset_info = json.load(f)
82 |
83 | self.dataset_list: List[DatasetAttr] = []
84 | for name in dataset_names:
85 | if name not in dataset_info:
86 | raise ValueError("Undefined dataset {} in dataset_info.json.".format(name))
87 |
88 | if "hf_hub_url" in dataset_info[name]:
89 | dataset_attr = DatasetAttr("hf_hub", dataset_name=dataset_info[name]["hf_hub_url"])
90 | elif "script_url" in dataset_info[name]:
91 | dataset_attr = DatasetAttr("script", dataset_name=dataset_info[name]["script_url"])
92 | else:
93 | dataset_attr = DatasetAttr(
94 | "file",
95 | dataset_name=dataset_info[name]["file_name"],
96 | dataset_sha1=dataset_info[name].get("file_sha1", None)
97 | )
98 |
99 | if "columns" in dataset_info[name]:
100 | dataset_attr.prompt_column = dataset_info[name]["columns"].get("prompt", None)
101 | dataset_attr.query_column = dataset_info[name]["columns"].get("query", None)
102 | dataset_attr.response_column = dataset_info[name]["columns"].get("response", None)
103 | dataset_attr.history_column = dataset_info[name]["columns"].get("history", None)
104 |
105 | self.dataset_list.append(dataset_attr)
106 |
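init_for_training resolves each dataset name through dataset_dir/dataset_info.json; a sketch of a hypothetical entry for a local file whose columns need remapping (the file and dataset names are assumptions):

import json

# Hypothetical entry to merge into data/dataset_info.json; "columns" maps the loader's
# prompt/query/response/history attributes to the column names used in my_sft_data.json.
entry = {
    "my_sft_data": {
        "file_name": "my_sft_data.json",
        "columns": {
            "prompt": "instruction",
            "query": "input",
            "response": "output",
            "history": "history"
        }
    }
}
print(json.dumps(entry, ensure_ascii=False, indent=2))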
--------------------------------------------------------------------------------
/src/glmtuner/hparams/finetuning_args.py:
--------------------------------------------------------------------------------
1 | import json
2 | from typing import Literal, Optional
3 | from dataclasses import asdict, dataclass, field
4 |
5 |
6 | @dataclass
7 | class FinetuningArguments:
8 | """
9 | Arguments pertaining to which techniques we are going to fine-tune with.
10 | """
11 | finetuning_type: Optional[Literal["none", "freeze", "p_tuning", "lora", "full"]] = field(
12 | default="lora",
13 | metadata={"help": "Which fine-tuning method to use."}
14 | )
15 | num_layer_trainable: Optional[int] = field(
16 | default=3,
17 | metadata={"help": "Number of trainable layers for Freeze fine-tuning."}
18 | )
19 | name_module_trainable: Optional[Literal["mlp", "qkv"]] = field(
20 | default="mlp",
21 | metadata={"help": "Name of trainable modules for Freeze fine-tuning."}
22 | )
23 | pre_seq_len: Optional[int] = field(
24 | default=64,
25 | metadata={"help": "Number of prefix tokens to use for P-tuning V2."}
26 | )
27 | prefix_projection: Optional[bool] = field(
28 | default=False,
29 | metadata={"help": "Whether to add a project layer for the prefix in P-tuning V2 or not."}
30 | )
31 | lora_rank: Optional[int] = field(
32 | default=8,
33 | metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}
34 | )
35 | lora_alpha: Optional[float] = field(
36 | default=32.0,
37 | metadata={"help": "The scale factor for LoRA fine-tuning. (similar with the learning rate)"}
38 | )
39 | lora_dropout: Optional[float] = field(
40 | default=0.1,
41 | metadata={"help": "Dropout rate for the LoRA fine-tuning."}
42 | )
43 | lora_target: Optional[str] = field(
44 | default="query_key_value",
45 | metadata={"help": "Name(s) of target modules to apply LoRA. Use commas to separate multiple modules."}
46 | )
47 |
48 | def __post_init__(self):
49 | if isinstance(self.lora_target, str):
50 | self.lora_target = [target.strip() for target in self.lora_target.split(",")] # support custom target modules of LoRA
51 |
52 | if self.num_layer_trainable > 0: # fine-tuning the last n layers if num_layer_trainable > 0
53 | trainable_layer_ids = [27 - k for k in range(self.num_layer_trainable)]
54 | else: # fine-tuning the first n layers if num_layer_trainable < 0
55 | trainable_layer_ids = [k for k in range(-self.num_layer_trainable)]
56 |
57 | if self.name_module_trainable == "mlp":
58 | self.trainable_layers = ["{:d}.mlp".format(idx) for idx in trainable_layer_ids]
59 | elif self.name_module_trainable == "qkv":
60 | self.trainable_layers = ["{:d}.attention.query_key_value".format(idx) for idx in trainable_layer_ids]
61 |
62 | assert self.finetuning_type in ["none", "freeze", "p_tuning", "lora", "full"], "Invalid fine-tuning method."
63 |
64 | def save_to_json(self, json_path: str) -> None:
65 | """Saves the content of this instance in JSON format inside `json_path`."""
66 | json_string = json.dumps(asdict(self), indent=2, sort_keys=True) + "\n"
67 | with open(json_path, "w", encoding="utf-8") as f:
68 | f.write(json_string)
69 |
70 | @classmethod
71 | def load_from_json(cls, json_path: str):
72 | """Creates an instance from the content of `json_path`."""
73 | with open(json_path, "r", encoding="utf-8") as f:
74 | text = f.read()
75 | return cls(**json.loads(text))
76 |
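A round-trip sketch for the JSON helpers above, mirroring how the trainer saves the fine-tuning configuration at checkpoint time (the file name is illustrative):

from glmtuner.hparams import FinetuningArguments

args = FinetuningArguments(finetuning_type="lora", lora_rank=16, lora_target="query_key_value")
args.save_to_json("finetuning_args.json")

restored = FinetuningArguments.load_from_json("finetuning_args.json")
print(restored.lora_rank, restored.lora_target)  # 16 ['query_key_value']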
--------------------------------------------------------------------------------
/src/glmtuner/hparams/general_args.py:
--------------------------------------------------------------------------------
1 | from typing import Literal, Optional
2 | from dataclasses import dataclass, field
3 |
4 |
5 | @dataclass
6 | class GeneralArguments:
7 | """
8 | Arguments pertaining to which stage of training will be performed.
9 | """
10 | stage: Optional[Literal["sft", "rm", "ppo"]] = field(
11 | default="sft",
12 | metadata={"help": "Which stage will be performed in training."}
13 | )
14 |
--------------------------------------------------------------------------------
/src/glmtuner/hparams/generating_args.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, Optional
2 | from dataclasses import asdict, dataclass, field
3 |
4 |
5 | @dataclass
6 | class GeneratingArguments:
7 | """
8 | Arguments pertaining to specify the decoding parameters.
9 | """
10 | do_sample: Optional[bool] = field(
11 | default=True,
12 | metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}
13 | )
14 | temperature: Optional[float] = field(
15 | default=0.95,
16 | metadata={"help": "The value used to modulate the next token probabilities."}
17 | )
18 | top_p: Optional[float] = field(
19 | default=0.7,
20 | metadata={"help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept."}
21 | )
22 | top_k: Optional[int] = field(
23 | default=50,
24 | metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}
25 | )
26 | num_beams: Optional[int] = field(
27 | default=1,
28 | metadata={"help": "Number of beams for beam search. 1 means no beam search."}
29 | )
30 | max_length: Optional[int] = field(
31 | default=2048,
32 | metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
33 | )
34 | max_new_tokens: Optional[int] = field(
35 | default=None,
36 | metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
37 | )
38 | repetition_penalty: Optional[float] = field(
39 | default=1.0,
40 | metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}
41 | )
42 |
43 | def to_dict(self) -> Dict[str, Any]:
44 | return asdict(self)
45 |
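A small sketch of turning these decoding defaults into plain keyword arguments, e.g. for a model.generate call (the overrides are illustrative, not the repository's exact call site):

from glmtuner.hparams import GeneratingArguments

gen_args = GeneratingArguments(temperature=0.7, max_new_tokens=256)
gen_kwargs = gen_args.to_dict()  # e.g. model.generate(**inputs, **gen_kwargs)
print(gen_kwargs["do_sample"], gen_kwargs["temperature"], gen_kwargs["max_new_tokens"])  # True 0.7 256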
--------------------------------------------------------------------------------
/src/glmtuner/hparams/model_args.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Literal, Optional
3 | from dataclasses import dataclass, field
4 |
5 |
6 | @dataclass
7 | class ModelArguments:
8 | """
9 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune.
10 | """
11 | model_name_or_path: Optional[str] = field(
12 | default="THUDM/chatglm-6b",
13 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models."}
14 | )
15 | config_name: Optional[str] = field(
16 | default=None,
17 | metadata={"help": "Pretrained config name or path if not the same as model_name."}
18 | )
19 | tokenizer_name: Optional[str] = field(
20 | default=None,
21 | metadata={"help": "Pretrained tokenizer name or path if not the same as model_name."}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co."}
26 | )
27 | use_fast_tokenizer: Optional[bool] = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}
30 | )
31 | model_revision: Optional[str] = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
34 | )
35 | use_auth_token: Optional[bool] = field(
36 | default=False,
37 | metadata={"help": "Will use the token generated when running `huggingface-cli login`."}
38 | )
39 | quantization_bit: Optional[int] = field(
40 | default=None,
41 | metadata={"help": "The number of bits to quantize the model."}
42 | )
43 | quantization_type: Optional[Literal["fp4", "nf4"]] = field(
44 | default="nf4",
45 | metadata={"help": "Quantization data type to use in int4 training."}
46 | )
47 | double_quantization: Optional[bool] = field(
48 | default=True,
49 | metadata={"help": "Whether to use double quantization in int4 training or not."}
50 | )
51 | compute_dtype: Optional[torch.dtype] = field(
52 | default=None,
53 | metadata={"help": "Used in quantization configs. Do not specify this argument manually."}
54 | )
55 | checkpoint_dir: Optional[str] = field(
56 | default=None,
57 | metadata={"help": "Path to the directory containing the model checkpoints as well as the configurations."}
58 | )
59 | reward_model: Optional[str] = field(
60 | default=None,
61 | metadata={"help": "Path to the directory containing the checkpoints of the reward model."}
62 | )
63 | resume_lora_training: Optional[bool] = field(
64 | default=True,
65 | metadata={"help": "Whether to resume training from the last LoRA weights or create new weights after merging them."}
66 | )
67 | plot_loss: Optional[bool] = field(
68 | default=False,
69 | metadata={"help": "Whether to plot the training loss after fine-tuning or not."}
70 | )
71 |
72 | def __post_init__(self):
73 | if not self.checkpoint_dir:
74 | self.checkpoint_dir = None
75 |
76 | if self.checkpoint_dir is not None: # support merging lora weights
77 | self.checkpoint_dir = [cd.strip() for cd in self.checkpoint_dir.split(",")]
78 |
79 | if self.quantization_bit is not None:
80 | assert self.quantization_bit in [4, 8], "We only accept 4-bit or 8-bit quantization."
81 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer
2 | from glmtuner.tuner.sft import run_sft
3 | from glmtuner.tuner.rm import run_rm
4 | from glmtuner.tuner.ppo import run_ppo
5 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/core/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner.core.parser import get_train_args, get_infer_args
2 | from glmtuner.tuner.core.loader import load_model_and_tokenizer
3 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/core/adapter.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 |
4 | from transformers.modeling_utils import PreTrainedModel
5 | from peft import (
6 | PeftModel,
7 | TaskType,
8 | LoraConfig,
9 | get_peft_model
10 | )
11 | from peft.utils import CONFIG_NAME, WEIGHTS_NAME
12 |
13 | from glmtuner.extras.logging import get_logger
14 | from glmtuner.extras.save_and_load import load_trainable_params
15 | from glmtuner.hparams import ModelArguments, FinetuningArguments
16 |
17 |
18 | logger = get_logger(__name__)
19 |
20 |
21 | def init_adapter(
22 | model: PreTrainedModel,
23 | model_args: ModelArguments,
24 | finetuning_args: FinetuningArguments,
25 | is_trainable: bool
26 | ) -> PreTrainedModel:
27 | r"""
28 | Initializes the adapters.
29 |
30 | Support full-parameter, freeze, P-Tuning v2 and LoRA training.
31 |
32 | Note that the trainable parameters must be cast to float32.
33 | """
34 |
35 | if finetuning_args.finetuning_type == "none" and is_trainable:
36 | raise ValueError("You cannot use finetuning_type=none while training.")
37 |
38 | if finetuning_args.finetuning_type == "full":
39 | logger.info("Fine-tuning method: Full")
40 | model = model.float()
41 |
42 | if finetuning_args.finetuning_type == "freeze":
43 | logger.info("Fine-tuning method: Freeze")
44 |
45 | for name, param in model.named_parameters():
46 | if not any(trainable_layer in name for trainable_layer in finetuning_args.trainable_layers):
47 | param.requires_grad_(False)
48 | else:
49 | param.data = param.data.to(torch.float32)
50 |
51 | if model_args.checkpoint_dir is not None:
52 | assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
53 |
54 | if finetuning_args.finetuning_type == "p_tuning":
55 | logger.info("Fine-tuning method: P-Tuning v2")
56 |
57 | if model_args.checkpoint_dir is not None:
58 | assert load_trainable_params(model, model_args.checkpoint_dir[0]), "Model checkpoint is not correctly loaded."
59 |
60 | if finetuning_args.finetuning_type == "lora":
61 | logger.info("Fine-tuning method: LoRA")
62 | latest_checkpoint = None
63 |
64 | if model_args.checkpoint_dir is not None:
65 | assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
66 | "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
67 | assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
68 | "The given checkpoint may be not a LoRA checkpoint, please specify `--finetuning_type full/p_tuning/freeze` instead."
69 |
70 | if is_trainable and model_args.resume_lora_training: # continually train on the lora weights
71 | checkpoints_to_merge, latest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
72 | else:
73 | checkpoints_to_merge = model_args.checkpoint_dir
74 |
75 | for checkpoint in checkpoints_to_merge:
76 | model = PeftModel.from_pretrained(model, checkpoint)
77 | model = model.merge_and_unload()
78 |
79 | if len(checkpoints_to_merge) > 0:
80 | logger.info("Merged {} model checkpoint(s).".format(len(checkpoints_to_merge)))
81 |
82 | if latest_checkpoint is not None: # resume lora training
83 | model = PeftModel.from_pretrained(model, latest_checkpoint, is_trainable=True)
84 |
85 | if is_trainable and latest_checkpoint is None: # create new lora weights while training
86 | lora_config = LoraConfig(
87 | task_type=TaskType.CAUSAL_LM, # we should regard ChatGLM as a causal LM
88 | inference_mode=False,
89 | r=finetuning_args.lora_rank,
90 | lora_alpha=finetuning_args.lora_alpha,
91 | lora_dropout=finetuning_args.lora_dropout,
92 | target_modules=finetuning_args.lora_target
93 | )
94 | model = get_peft_model(model, lora_config)
95 |
96 | if model_args.checkpoint_dir is not None:
97 | logger.info("Loaded fine-tuned model from checkpoint(s): {}".format(",".join(model_args.checkpoint_dir)))
98 |
99 | return model
100 |
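With LoRA and multiple checkpoints, init_adapter merges all but the last adapter into the base weights and, if resume_lora_training stays enabled, resumes training from the last one; an illustrative setup (paths are hypothetical and the base model must already be loaded):

from glmtuner.hparams import ModelArguments, FinetuningArguments
from glmtuner.tuner.core.adapter import init_adapter

model_args = ModelArguments(
    model_name_or_path="THUDM/chatglm2-6b",
    checkpoint_dir="output/lora_stage1,output/lora_stage2"  # split on commas in __post_init__
)
finetuning_args = FinetuningArguments(finetuning_type="lora")

# model = init_adapter(model, model_args, finetuning_args, is_trainable=True)
# -> merges output/lora_stage1 into the base model, then resumes training on output/lora_stage2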
--------------------------------------------------------------------------------
/src/glmtuner/tuner/core/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import Literal, Optional, Tuple
4 |
5 | from transformers import (
6 | AutoConfig,
7 | AutoModel,
8 | AutoTokenizer,
9 | BitsAndBytesConfig
10 | )
11 | from transformers.utils import check_min_version
12 | from transformers.utils.versions import require_version
13 | from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
14 | from transformers.tokenization_utils import PreTrainedTokenizerBase
15 | from trl import AutoModelForCausalLMWithValueHead
16 |
17 | from glmtuner.extras.logging import get_logger
18 | from glmtuner.extras.misc import prepare_model_for_training, print_trainable_params
19 | from glmtuner.extras.save_and_load import load_valuehead_params
20 | from glmtuner.hparams import ModelArguments, FinetuningArguments
21 | from glmtuner.tuner.core.adapter import init_adapter
22 |
23 |
24 | logger = get_logger(__name__)
25 |
26 |
27 | check_min_version("4.29.1")
28 | require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
29 | require_version("accelerate>=0.21.0", "To fix: pip install accelerate>=0.21.0")
30 | require_version("peft>=0.4.0", "To fix: pip install peft>=0.4.0")
31 | require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
32 |
33 |
34 | def load_model_and_tokenizer(
35 | model_args: ModelArguments,
36 | finetuning_args: FinetuningArguments,
37 | is_trainable: Optional[bool] = False,
38 | stage: Optional[Literal["sft", "rm", "ppo"]] = "sft"
39 | ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
40 | r"""
41 | Loads pretrained model and tokenizer.
42 |
43 | Support both training and inference.
44 | """
45 |
46 | if (not is_trainable) and model_args.checkpoint_dir is None:
47 | logger.warning("Checkpoint is not found at evaluation, load the original model.")
48 | finetuning_args = FinetuningArguments(finetuning_type="none")
49 |
50 | assert stage == "sft" or finetuning_args.finetuning_type == "lora", \
51 | "RM and PPO training can only be performed with LoRA method."
52 |
53 | if model_args.quantization_bit is not None:
54 | if is_trainable and finetuning_args.finetuning_type == "lora":
55 | quantization = "bnb" # use bnb's quantization
56 | else:
57 | quantization = "cpm" # use cpm's quantization
58 | else:
59 | quantization = None
60 |
61 | config_kwargs = {
62 | "trust_remote_code": True,
63 | "cache_dir": model_args.cache_dir,
64 | "revision": model_args.model_revision,
65 | "use_auth_token": True if model_args.use_auth_token else None,
66 | }
67 |
68 | tokenizer = AutoTokenizer.from_pretrained(
69 | model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
70 | use_fast=model_args.use_fast_tokenizer,
71 | padding_side="left",
72 | **config_kwargs
73 | )
74 |
75 | config = AutoConfig.from_pretrained(
76 | model_args.config_name if model_args.config_name else model_args.model_name_or_path,
77 | **config_kwargs
78 | )
79 |
80 | # P-Tuning v2 configurations. Use the built-in p-tuning method of ChatGLM.
81 | if finetuning_args.finetuning_type == "p_tuning":
82 | config.pre_seq_len = finetuning_args.pre_seq_len # setting this makes ChatGLM freeze the other parameters automatically
83 | config.prefix_projection = finetuning_args.prefix_projection
84 |
85 | # Quantization configurations for Full, Freeze and LoRA in training (using bitsandbytes library).
86 | if quantization == "bnb":
87 | if model_args.quantization_bit == 8:
88 | require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0")
89 | config_kwargs["load_in_8bit"] = True
90 | config_kwargs["quantization_config"] = BitsAndBytesConfig(
91 | load_in_8bit=True,
92 | llm_int8_threshold=6.0
93 | )
94 | elif model_args.quantization_bit == 4:
95 | require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0")
96 | config_kwargs["load_in_4bit"] = True
97 | config_kwargs["quantization_config"] = BitsAndBytesConfig(
98 | load_in_4bit=True,
99 | bnb_4bit_compute_dtype=model_args.compute_dtype,
100 | bnb_4bit_use_double_quant=model_args.double_quantization,
101 | bnb_4bit_quant_type=model_args.quantization_type
102 | )
103 | config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK") or 0)}
104 |
105 | if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
106 | model_to_load = model_args.checkpoint_dir[0]
107 | else:
108 | model_to_load = model_args.model_name_or_path
109 |
110 | # Load and prepare pretrained models (without valuehead).
111 | model = AutoModel.from_pretrained(model_to_load, config=config, **config_kwargs)
112 |
113 | # Register auto class to save the custom code files.
114 | if isinstance(config, PretrainedConfig):
115 | config.__class__.register_for_auto_class()
116 | if isinstance(tokenizer, PreTrainedTokenizerBase):
117 | tokenizer.__class__.register_for_auto_class()
118 | if isinstance(model, PreTrainedModel):
119 | model.__class__.register_for_auto_class()
120 |
121 | if tokenizer.eos_token_id == 130005: # ChatGLM-6B
122 | output_embedding_base_layer = model
123 | output_embedding_layer_name = "lm_head"
124 | elif tokenizer.eos_token_id == 2: # ChatGLM2-6B
125 | assert hasattr(model, "transformer"), "Please update the model files of ChatGLM-6B."
126 | model.lm_head = model.transformer.output_layer
127 | output_embedding_base_layer = model.transformer
128 | output_embedding_layer_name = "output_layer"
129 | else:
130 | raise ValueError("Please update the model files of ChatGLM2-6B.")
131 |
132 | # Initialize adapters
133 | model = prepare_model_for_training(
134 | model,
135 | finetuning_args.finetuning_type,
136 | output_embedding_base_layer,
137 | output_embedding_layer_name
138 | ) if is_trainable else model
139 | model = init_adapter(model, model_args, finetuning_args, is_trainable)
140 |
141 | if not is_trainable:
142 | model.requires_grad_(False) # fix all model params
143 | model = model.half() # cast all params to float16 for inference
144 |
145 | # Quantization with the built-in method for P-Tuning v2 training or evaluation.
146 | # Model parameters should be cast to float16 in quantized P-Tuning setting.
147 | if quantization == "cpm":
148 | if is_trainable: # convert all params into half precision except prefix_encoder in training
149 | for name, param in model.named_parameters():
150 | if "prefix_encoder" not in name:
151 | param.data = param.data.to(torch.float16)
152 |
153 | model.quantize(model_args.quantization_bit) # built-in method in ChatGLM-6B, also an in-place operation
154 |
155 | if quantization is not None:
156 | logger.info("Quantized model to {} bit.".format(model_args.quantization_bit))
157 |
158 | if stage == "rm" or stage == "ppo": # add value head
159 | model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
160 |
161 | if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
162 | logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
163 | if load_valuehead_params(model, model_args.checkpoint_dir[-1]):
164 | model.v_head.load_state_dict({
165 | "summary.weight": getattr(model, "reward_head_weight"),
166 | "summary.bias": getattr(model, "reward_head_bias")
167 | })
168 |
169 | if stage == "ppo": # load reward model
170 | assert is_trainable, "PPO stage cannot be performed at evaluation."
171 | assert model_args.reward_model is not None, "Reward model is necessary for PPO training."
172 | logger.info("Load reward model from {}".format(model_args.reward_model))
173 | model.pretrained_model.load_adapter(model_args.reward_model, "reward", is_trainable=False)
174 | assert load_valuehead_params(model, model_args.reward_model), "Reward model is not correctly loaded."
175 |
176 | print_trainable_params(model)
177 |
178 | return model, tokenizer
179 |
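A sketch of inference-time loading through get_infer_args and load_model_and_tokenizer; the argument values are illustrative and the checkpoint path is an assumption:

from glmtuner.tuner import get_infer_args, load_model_and_tokenizer

model_args, data_args, finetuning_args, generating_args = get_infer_args({
    "model_name_or_path": "THUDM/chatglm2-6b",
    "checkpoint_dir": "output/sft",  # hypothetical LoRA checkpoint directory
    "finetuning_type": "lora"
})
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, stage="sft")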
--------------------------------------------------------------------------------
/src/glmtuner/tuner/core/parser.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import torch
4 | import datasets
5 | import transformers
6 | from typing import Any, Dict, Optional, Tuple
7 | from transformers import HfArgumentParser, Seq2SeqTrainingArguments
8 |
9 | from glmtuner.extras.logging import get_logger
10 | from glmtuner.hparams import (
11 | ModelArguments,
12 | DataArguments,
13 | FinetuningArguments,
14 | GeneratingArguments,
15 | GeneralArguments
16 | )
17 |
18 |
19 | logger = get_logger(__name__)
20 |
21 |
22 | def get_train_args(
23 | args: Optional[Dict[str, Any]] = None
24 | ) -> Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments]:
25 |
26 | parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneralArguments))
27 |
28 | if args is not None:
29 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_dict(args)
30 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
31 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
32 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
33 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
34 | else:
35 | model_args, data_args, training_args, finetuning_args, general_args = parser.parse_args_into_dataclasses()
36 |
37 | # Setup logging
38 | if training_args.should_log:
39 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
40 | transformers.utils.logging.set_verbosity_info()
41 |
42 | log_level = training_args.get_process_log_level()
43 | datasets.utils.logging.set_verbosity(log_level)
44 | transformers.utils.logging.set_verbosity(log_level)
45 | transformers.utils.logging.enable_default_handler()
46 | transformers.utils.logging.enable_explicit_format()
47 |
48 | # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
49 | data_args.init_for_training()
50 |
51 | assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
52 | "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
53 |
54 | assert not (training_args.do_train and training_args.predict_with_generate), \
55 | "`predict_with_generate` cannot be set as True while training."
56 |
57 | assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
58 | "Please enable `predict_with_generate` to save model predictions."
59 |
60 | if model_args.quantization_bit is not None:
61 | assert finetuning_args.finetuning_type != "full" and finetuning_args.finetuning_type != "freeze", \
62 | "Quantization is incompatible with the full-parameter and freeze tuning."
63 |
64 | assert not (finetuning_args.finetuning_type == "p_tuning" and training_args.fp16), \
65 | "FP16 training conflicts with quantized P-Tuning."
66 |
67 | if not training_args.do_train:
68 | logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
69 |
70 | assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
71 | or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
72 |
73 | if training_args.do_train and (not training_args.fp16):
74 | logger.warning("We recommend enable fp16 mixed precision training for ChatGLM-6B.")
75 |
76 | if training_args.local_rank != -1 and training_args.ddp_find_unused_parameters is None:
77 | logger.warning("`ddp_find_unused_parameters` needs to be set as False in DDP training.")
78 | training_args.ddp_find_unused_parameters = False
79 |
80 | training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
81 |
82 | if model_args.quantization_bit is not None:
83 | if training_args.fp16:
84 | model_args.compute_dtype = torch.float16
85 | elif training_args.bf16:
86 | model_args.compute_dtype = torch.bfloat16
87 | else:
88 | model_args.compute_dtype = torch.float32
89 |
90 | # Log on each process the small summary:
91 | logger.info(
92 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}\n"
93 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
94 | )
95 | logger.info(f"Training/evaluation parameters {training_args}")
96 |
97 | # Set seed before initializing model.
98 | transformers.set_seed(training_args.seed)
99 |
100 | return model_args, data_args, training_args, finetuning_args, general_args
101 |
102 |
103 | def get_infer_args(
104 | args: Optional[Dict[str, Any]] = None
105 | ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
106 |
107 | parser = HfArgumentParser((ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments))
108 |
109 | if args is not None:
110 | model_args, data_args, finetuning_args, generating_args = parser.parse_dict(args)
111 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"):
112 | model_args, data_args, finetuning_args, generating_args = parser.parse_yaml_file(os.path.abspath(sys.argv[1]))
113 | elif len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
114 | model_args, data_args, finetuning_args, generating_args = parser.parse_json_file(os.path.abspath(sys.argv[1]))
115 | else:
116 | model_args, data_args, finetuning_args, generating_args = parser.parse_args_into_dataclasses()
117 |
118 | assert model_args.checkpoint_dir is None or finetuning_args.finetuning_type == "lora" \
119 | or len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
120 |
121 | return model_args, data_args, finetuning_args, generating_args
122 |
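get_train_args also accepts a plain dictionary, which is convenient for programmatic launches; a minimal, illustrative LoRA SFT configuration, assuming a CUDA device and that the dataset name exists in data/dataset_info.json:

from glmtuner.tuner import get_train_args

# Illustrative values only; the output directory is hypothetical.
model_args, data_args, training_args, finetuning_args, general_args = get_train_args({
    "stage": "sft",
    "model_name_or_path": "THUDM/chatglm2-6b",
    "do_train": True,
    "dataset": "alpaca_zh",
    "finetuning_type": "lora",
    "output_dir": "output/sft",
    "fp16": True
})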
--------------------------------------------------------------------------------
/src/glmtuner/tuner/core/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from typing import Dict, Optional
4 |
5 | from transformers import Seq2SeqTrainer
6 | from transformers.trainer import TRAINING_ARGS_NAME
7 | from transformers.modeling_utils import PreTrainedModel, unwrap_model
8 | from peft import PeftModel
9 |
10 | from glmtuner.extras.constants import FINETUNING_ARGS_NAME, VALUE_HEAD_FILE_NAME
11 | from glmtuner.extras.logging import get_logger
12 | from glmtuner.extras.save_and_load import get_state_dict, load_trainable_params, load_valuehead_params
13 | from glmtuner.hparams import FinetuningArguments
14 |
15 |
16 | logger = get_logger(__name__)
17 |
18 |
19 | class PeftTrainer(Seq2SeqTrainer):
20 | r"""
21 | Inherits Seq2SeqTrainer to support parameter-efficient checkpoints.
22 | """
23 |
24 | def __init__(self, finetuning_args: FinetuningArguments, **kwargs):
25 | super().__init__(**kwargs)
26 | self.finetuning_args = finetuning_args
27 | self._remove_log()
28 |
29 | def _remove_log(self):
30 | if self.is_world_process_zero() and os.path.exists(os.path.join(self.args.output_dir, "trainer_log.jsonl")):
31 | logger.warning("Previous log file in this folder will be deleted.")
32 | os.remove(os.path.join(self.args.output_dir, "trainer_log.jsonl"))
33 |
34 | def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, torch.Tensor]] = None) -> None:
35 | r"""
36 | Saves trainable parameters as model checkpoint.
37 |
38 | This function will only be executed at the process zero.
39 |
40 | Subclass and override to inject custom behavior. It should not be directly used by external scripts.
41 | """
42 | output_dir = output_dir if output_dir is not None else self.args.output_dir
43 | os.makedirs(output_dir, exist_ok=True)
44 | logger.info(f"Saving model checkpoint to {output_dir}")
45 | model = unwrap_model(self.model)
46 |
47 | if hasattr(model, "pretrained_model"): # for models with valuehead (currently using LoRA only)
48 | backbone_model = getattr(model, "pretrained_model")
49 | torch.save(get_state_dict(getattr(model, "v_head")), os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
50 | else:
51 | backbone_model = model
52 |
53 | if isinstance(backbone_model, PeftModel): # LoRA tuning
54 | backbone_model.save_pretrained(output_dir, state_dict=get_state_dict(backbone_model))
55 | elif isinstance(backbone_model, PreTrainedModel): # freeze/full-tuning or p_tuning
56 | backbone_model.config.use_cache = True
57 | backbone_model.save_pretrained(
58 | output_dir,
59 | state_dict=get_state_dict(backbone_model, trainable_only=(self.finetuning_args.finetuning_type != "full")),
60 | safe_serialization=self.args.save_safetensors
61 | )
62 | backbone_model.config.use_cache = False
63 | if self.tokenizer is not None:
64 | self.tokenizer.save_pretrained(output_dir)
65 | else:
66 | logger.warning("No model to save.")
67 |
68 | with open(os.path.join(output_dir, TRAINING_ARGS_NAME), "w", encoding="utf-8") as f:
69 | f.write(self.args.to_json_string() + "\n")
70 | self.finetuning_args.save_to_json(os.path.join(output_dir, FINETUNING_ARGS_NAME))
71 |
72 | def _load_best_model(self):
73 | r"""
74 | Loads trainable parameters from model checkpoint.
75 |
76 | Subclass and override to inject custom behavior. It should not be directly used by external scripts.
77 | """
78 | logger.info(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
79 |
80 | model = unwrap_model(self.model)
81 | backbone_model = getattr(model, "pretrained_model") if hasattr(model, "pretrained_model") else model
82 |
83 | if isinstance(backbone_model, PeftModel):
84 | backbone_model.load_adapter(self.state.best_model_checkpoint, backbone_model.active_adapter)
85 | if hasattr(model, "v_head") and load_valuehead_params(model, self.state.best_model_checkpoint):
86 | model.v_head.load_state_dict({
87 | "summary.weight": getattr(model, "reward_head_weight"),
88 | "summary.bias": getattr(model, "reward_head_bias")
89 | })
90 | else: # freeze/full-tuning or p_tuning
91 | load_trainable_params(backbone_model, self.state.best_model_checkpoint)
92 |
--------------------------------------------------------------------------------
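A minimal standalone sketch (illustrative only, not repository code) of the save/load pairing that _save and _load_best_model rely on: the value head is serialized to its own file so a LoRA checkpoint stays small, and the same keys are mapped back on load. The file name and layer shape below are placeholders.

import os
import torch
import torch.nn as nn

# Stand-in for the value head (the "summary" linear layer referenced above).
value_head = nn.Linear(4096, 1)

output_dir = "outputs/checkpoint-100"
os.makedirs(output_dir, exist_ok=True)

# Save: move tensors to CPU so the checkpoint is device-independent.
state_dict = {k: v.detach().cpu() for k, v in value_head.state_dict().items()}
torch.save(state_dict, os.path.join(output_dir, "value_head.bin"))

# Load: restore the same keys; the trainer above maps them to summary.weight / summary.bias.
value_head.load_state_dict(torch.load(os.path.join(output_dir, "value_head.bin")))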
/src/glmtuner/tuner/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner.ppo.workflow import run_ppo
2 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/ppo/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import torch
4 | from tqdm import tqdm
5 | from typing import Callable, Dict, List, Optional
6 |
7 | from transformers import Seq2SeqTrainingArguments, TrainerState, TrainerControl
8 | from transformers.modeling_utils import PreTrainedModel
9 |
10 | from trl import PPOTrainer, AutoModelForCausalLMWithValueHead
11 | from trl.core import LengthSampler
12 | from trl.trainer.ppo_trainer import PPODecorators, logprobs_from_logits
13 |
14 | from glmtuner.extras.callbacks import LogCallback
15 | from glmtuner.extras.logging import get_logger
16 | from glmtuner.extras.misc import AverageMeter, get_logits_processor
17 | from glmtuner.hparams import FinetuningArguments
18 | from glmtuner.tuner.core.trainer import PeftTrainer
19 | from glmtuner.tuner.ppo.utils import cast_layernorm_dtype, replace_model
20 |
21 |
22 | logger = get_logger(__name__)
23 |
24 |
25 | class PPOTrainerForChatGLM(PPOTrainer, PeftTrainer):
26 | r"""
27 | Inherits PPOTrainer.
28 | """
29 | def __init__(
30 | self,
31 | training_args: Seq2SeqTrainingArguments,
32 | finetuning_args: FinetuningArguments,
33 | callbacks: List[LogCallback],
34 | **kwargs
35 | ):
36 | PPOTrainer.__init__(self, **kwargs)
37 | self.args = training_args
38 | self.finetuning_args = finetuning_args
39 | self.log_callback = callbacks[0]
40 | self.state = TrainerState()
41 | self.control = TrainerControl()
42 | self.data_collator = self.accelerator.prepare(kwargs["data_collator"]) # override the data collator of PPOTrainer
43 | self._remove_log()
44 |
45 | def ppo_train(self, max_target_length: int) -> None:
46 | r"""
47 |         Implements the training loop for the PPO stage, analogous to _inner_training_loop() in Hugging Face's Trainer.
48 | """
49 | total_train_batch_size = (
50 | self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
51 | )
52 | len_dataloader = len(self.dataloader)
53 | num_examples = len(self.dataset)
54 | num_train_epochs = self.args.num_train_epochs
55 | max_steps = math.ceil(num_train_epochs * len_dataloader)
56 |
57 | self.state.max_steps = max_steps
58 | self.state.num_train_epochs = num_train_epochs
59 | self.state.is_local_process_zero = self.is_local_process_zero()
60 | self.state.is_world_process_zero = self.is_world_process_zero()
61 |
62 | if self.is_world_process_zero():
63 | logger.info("***** Running training *****")
64 | logger.info(f" Num examples = {num_examples}")
65 | logger.info(f" Num Epochs = {num_train_epochs}")
66 | logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
67 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
68 | logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
69 | logger.info(f" Total optimization steps = {max_steps}")
70 | logger.info(f" Number of trainable parameters = {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}")
71 |
72 | # Keyword arguments for `model.generate`
73 | gen_kwargs = {
74 | "top_k": 0.0,
75 | "top_p": 1.0,
76 | "do_sample": True,
77 | "pad_token_id": self.tokenizer.pad_token_id,
78 | "eos_token_id": self.tokenizer.eos_token_id,
79 | "logits_processor": get_logits_processor()
80 | }
81 | length_sampler = LengthSampler(max_target_length // 2, max_target_length)
82 | unwrapped_model: PreTrainedModel = self.accelerator.unwrap_model(self.model)
83 |
84 | dataiter = iter(self.dataloader)
85 | steps_trained = 0
86 | loss_meter = AverageMeter()
87 | reward_meter = AverageMeter()
88 | self.log_callback.on_train_begin(self.args, self.state, self.control)
89 |
90 | for step in tqdm(range(max_steps), disable=not self.is_world_process_zero(), leave=False):
91 | batch = next(dataiter)
92 | steps_trained += 1
93 |
94 | unwrapped_model.gradient_checkpointing_disable()
95 | unwrapped_model.config.use_cache = True
96 |
97 | # Get responses
98 | query_tensors = batch["input_ids"]
99 | response_tensors = self.generate(batch, length_sampler, return_prompt=False, **gen_kwargs)
100 |
101 | queries, responses = [], []
102 | for i in range(len(query_tensors)):
103 | query_length = (query_tensors[i] != self.tokenizer.pad_token_id).nonzero()[0]
104 | response_length = (response_tensors[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
105 | queries.append(query_tensors[i, query_length:]) # remove padding from left
106 | responses.append(response_tensors[i, :response_length]) # remove padding from right
107 |
108 | # Compute rewards
109 | replace_model(unwrapped_model, target="reward")
110 | with torch.no_grad():
111 | _, _, values = self.model(
112 | **self.prepare_model_inputs(queries, responses),
113 | output_hidden_states=True,
114 | return_dict=True
115 | )
116 | rewards = [reward for reward in values[-1].to(torch.float32)] # use float32 type
117 | replace_model(unwrapped_model, target="default")
118 |
119 | # Run PPO step
120 | unwrapped_model.gradient_checkpointing_enable()
121 | unwrapped_model.config.use_cache = False
122 | stats = self.step(queries, responses, rewards)
123 |
124 | loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
125 | reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
126 |
127 | if self.is_world_process_zero() and (step+1) % self.args.logging_steps == 0:
128 | logs = dict(
129 | loss=round(loss_meter.avg, 4),
130 | reward=round(reward_meter.avg, 4),
131 | learning_rate=stats["ppo/learning_rate"],
132 | epoch=round(step / len_dataloader, 2)
133 | )
134 | print(logs)
135 | logs["step"] = step
136 | self.state.log_history.append(logs)
137 | self.log_callback.on_log(self.args, self.state, self.control)
138 | loss_meter.reset()
139 | reward_meter.reset()
140 |
141 | if (step+1) % self.args.save_steps == 0: # save checkpoint
142 | self.save_model(os.path.join(self.args.output_dir, f"checkpoint-{step+1}"))
143 |
144 | if self.control.should_epoch_stop or self.control.should_training_stop:
145 | break
146 |
147 | if steps_trained == len_dataloader:
148 | dataiter = iter(self.dataloader)
149 | steps_trained = 0
150 |
151 | @torch.no_grad()
152 | def generate(
153 | self,
154 | inputs: Dict[str, torch.Tensor],
155 | length_sampler: Optional[Callable] = None,
156 | return_prompt: Optional[bool] = True,
157 | **generation_kwargs
158 | ) -> torch.Tensor:
159 | r"""
160 | Generates model's responses given queries.
161 |
162 | Subclass and override to inject custom behavior.
163 | """
164 | self.model, layer_norm_params = cast_layernorm_dtype(self.model)
165 |
166 | if length_sampler is not None:
167 | generation_kwargs["max_new_tokens"] = length_sampler()
168 |
169 | unwrapped_model = self.accelerator.unwrap_model(self.model)
170 |
171 | response = unwrapped_model.generate(**inputs, **generation_kwargs)
172 |
173 | # Temporary hack to ensure the generation config is not initialized for each iteration of the evaluation loop
174 | # Inspired by: https://github.com/huggingface/transformers/blob/v4.28.1/src/transformers/trainer_seq2seq.py#L273
175 | if unwrapped_model.pretrained_model.generation_config._from_model_config:
176 | unwrapped_model.pretrained_model.generation_config._from_model_config = False
177 |
178 | self.model, _ = cast_layernorm_dtype(self.model, layer_norm_params)
179 |
180 | if not return_prompt and not self.is_encoder_decoder:
181 | return response[:, inputs["input_ids"].size(1):]
182 | return response
183 |
184 | @PPODecorators.empty_cuda_cache()
185 | def batched_forward_pass(
186 | self,
187 | model: AutoModelForCausalLMWithValueHead,
188 | queries: torch.Tensor,
189 | responses: torch.Tensor,
190 | model_inputs: dict,
191 | return_logits: bool = False
192 | ):
193 | r"""
194 | Calculates model outputs in multiple batches.
195 |
196 | Subclass and override to inject custom behavior.
197 | """
198 | bs = len(queries)
199 | fbs = self.config.mini_batch_size
200 | all_logprobs = []
201 | all_logits = []
202 | all_masks = []
203 | all_values = []
204 |
205 | for i in range(int(bs / fbs)):
206 | input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
207 | query_batch = queries[i * fbs : (i + 1) * fbs]
208 | response_batch = responses[i * fbs : (i + 1) * fbs]
209 | input_ids = input_kwargs["input_ids"] # left-padded sequences
210 |
211 | if self.is_distributed: # re-generate them to adapt padded inputs
212 | input_kwargs["attention_mask"] = self.data_collator.get_attention_masks(input_ids, device=self.current_device)
213 | input_kwargs["position_ids"] = self.data_collator.get_position_ids(input_ids, device=self.current_device)
214 |
215 | logits, _, values = model(**input_kwargs, output_hidden_states=True, return_dict=True)
216 | logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
217 | values = values.transpose(0, 1)
218 | masks = torch.zeros_like(input_ids)
219 |
220 | for j in range(fbs):
221 | start = len(query_batch[j]) - 1
222 | start += (input_ids[j] != self.tokenizer.pad_token_id).nonzero()[0].item()
223 | end = start + len(response_batch[j])
224 | masks[j][start:end] = 1
225 |
226 | if return_logits:
227 | all_logits.append(logits)
228 | else:
229 | del logits
230 | all_values.append(values)
231 | all_logprobs.append(logprobs)
232 | all_masks.append(masks)
233 |
234 | return (
235 | torch.cat(all_logprobs),
236 | torch.cat(all_logits)[:, :-1] if return_logits else None,
237 | torch.cat(all_values)[:, :-1],
238 | torch.cat(all_masks)[:, :-1],
239 | )
240 |
241 | def save_model(self, output_dir: Optional[str] = None) -> None:
242 | r"""
243 | Saves model checkpoint.
244 |
245 | Subclass and override to inject custom behavior.
246 | """
247 |
248 | if self.args.should_save:
249 | self._save(output_dir)
250 |
--------------------------------------------------------------------------------
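The query/response trimming in ppo_train (drop left padding from queries, right padding from responses), shown in isolation; a toy sketch with a made-up pad_token_id rather than the tokenizer's real one.

import torch

pad_token_id = 0  # assumption for the sketch; the trainer reads it from the tokenizer

query = torch.tensor([0, 0, 0, 11, 12, 13])   # left-padded query from the data collator
response = torch.tensor([21, 22, 23, 0, 0])   # right-padded response from generate()

query_start = (query != pad_token_id).nonzero()[0].item()           # first real token
response_end = (response != pad_token_id).nonzero()[-1].item() + 1  # one past the last real token

print(query[query_start:])      # tensor([11, 12, 13]) -> padding removed from the left
print(response[:response_end])  # tensor([21, 22, 23]) -> padding removed from the right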
/src/glmtuner/tuner/ppo/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Dict, List, Literal, Optional, Tuple
3 | from trl import AutoModelForCausalLMWithValueHead
4 |
5 | from glmtuner.extras.constants import LAYERNORM_NAMES
6 |
7 |
8 | def replace_model(model: AutoModelForCausalLMWithValueHead, target: Literal["default", "reward"]) -> None:
9 | if target == "reward": # save default head temporarily
10 | valuehead_state_dict = model.v_head.state_dict()
11 | setattr(model, "default_head_weight", valuehead_state_dict["summary.weight"])
12 | setattr(model, "default_head_bias", valuehead_state_dict["summary.bias"])
13 |
14 | model.pretrained_model.set_adapter(target) # set the LoRA adapter to be active
15 | model.v_head.load_state_dict({
16 | "summary.weight": getattr(model, "{}_head_weight".format(target)),
17 | "summary.bias": getattr(model, "{}_head_bias".format(target))
18 | })
19 |
20 |
21 | def cast_layernorm_dtype(
22 | model: AutoModelForCausalLMWithValueHead,
23 | layer_norm_names: List[str] = LAYERNORM_NAMES,
24 | layer_norm_params: Optional[Dict[str, torch.Tensor]] = None
25 | ) -> Tuple[AutoModelForCausalLMWithValueHead, Dict[str, torch.Tensor]]:
26 |
27 | layer_norm_state_dict = {}
28 |
29 | for name, param in model.named_parameters():
30 | if param.ndim == 1 and any(layer_norm_name in name for layer_norm_name in layer_norm_names):
31 | if layer_norm_params is not None:
32 | param.data = layer_norm_params[name] # restore float32 weights
33 | else:
34 | layer_norm_state_dict[name] = param.data.detach().clone() # store float32 weights for stability
35 | param.data = param.data.to(torch.float16)
36 |
37 | return model, layer_norm_state_dict
38 |
--------------------------------------------------------------------------------
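How cast_layernorm_dtype round-trips: the first call stashes fp32 copies of the norm weights and downcasts them to fp16 for generation, and the second call (with the saved dict passed back in) restores them before the PPO update. A toy module stands in for the real model, and the name filter is a simplified stand-in for LAYERNORM_NAMES.

import torch
import torch.nn as nn

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.layernorm = nn.LayerNorm(8)

model = TinyBlock()
layer_norm_names = ["layernorm"]  # simplified stand-in for LAYERNORM_NAMES

# First call: store fp32 copies and downcast the 1-D norm parameters.
saved = {}
for name, param in model.named_parameters():
    if param.ndim == 1 and any(n in name for n in layer_norm_names):
        saved[name] = param.data.detach().clone()
        param.data = param.data.to(torch.float16)

# ... generation would run here with fp16 norms ...

# Second call: restore the fp32 weights for numerically stable training.
for name, param in model.named_parameters():
    if name in saved:
        param.data = saved[name]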
/src/glmtuner/tuner/ppo/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by:
2 | # https://github.com/lvwerra/trl/blob/main/examples/sentiment/scripts/gpt-neox-20b_peft/gpt-neo-20b_sentiment_peft.py
3 |
4 | import math
5 | from trl import PPOConfig
6 | from torch.optim import AdamW
7 | from typing import Optional, List
8 | from transformers import Seq2SeqTrainingArguments, TrainerCallback
9 | from transformers.optimization import get_scheduler
10 |
11 | from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset
12 | from glmtuner.extras.callbacks import LogCallback
13 | from glmtuner.extras.ploting import plot_loss
14 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
15 | from glmtuner.tuner.core import load_model_and_tokenizer
16 | from glmtuner.tuner.ppo.trainer import PPOTrainerForChatGLM
17 |
18 |
19 | def run_ppo(
20 | model_args: ModelArguments,
21 | data_args: DataArguments,
22 | training_args: Seq2SeqTrainingArguments,
23 | finetuning_args: FinetuningArguments,
24 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
25 | ):
26 | dataset = get_dataset(model_args, data_args)
27 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
28 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
29 | data_collator = DataCollatorForChatGLM(tokenizer, model.pretrained_model)
30 |
31 | ppo_config = PPOConfig(
32 | model_name=model_args.model_name_or_path,
33 | learning_rate=training_args.learning_rate,
34 | mini_batch_size=training_args.per_device_train_batch_size,
35 | batch_size=training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps,
36 | gradient_accumulation_steps=training_args.gradient_accumulation_steps,
37 | ppo_epochs=1,
38 | max_grad_norm=training_args.max_grad_norm
39 | )
40 |
41 | optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
42 | total_train_batch_size = \
43 | training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
44 | lr_scheduler = get_scheduler(
45 | training_args.lr_scheduler_type,
46 | optimizer=optimizer,
47 | num_warmup_steps=training_args.warmup_steps,
48 | num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
49 | )
50 |
51 | # Initialize our Trainer
52 | ppo_trainer = PPOTrainerForChatGLM(
53 | training_args=training_args,
54 | finetuning_args=finetuning_args,
55 | callbacks=callbacks,
56 | config=ppo_config,
57 | model=model,
58 | ref_model=None,
59 | tokenizer=tokenizer,
60 | dataset=dataset,
61 | data_collator=data_collator,
62 | optimizer=optimizer,
63 | lr_scheduler=lr_scheduler
64 | )
65 |
66 | ppo_trainer.ppo_train(max_target_length=data_args.max_target_length)
67 | ppo_trainer.save_model()
68 | ppo_trainer.save_state() # must be after save_model
69 | if ppo_trainer.is_world_process_zero() and model_args.plot_loss:
70 | plot_loss(training_args.output_dir, keys=["loss", "reward"])
71 |
--------------------------------------------------------------------------------
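How the batch sizes passed to PPOConfig and the scheduler horizon computed in run_ppo relate, worked through with illustrative numbers (not defaults):

import math

per_device_train_batch_size = 2
gradient_accumulation_steps = 4
world_size = 1
num_train_epochs = 3.0
dataset_size = 1000

mini_batch_size = per_device_train_batch_size                            # forward/backward chunk: 2
batch_size = per_device_train_batch_size * gradient_accumulation_steps   # one PPO optimizer batch: 8

total_train_batch_size = per_device_train_batch_size * gradient_accumulation_steps * world_size
num_training_steps = num_train_epochs * math.ceil(dataset_size / total_train_batch_size)
print(mini_batch_size, batch_size, num_training_steps)  # 2 8 375.0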
/src/glmtuner/tuner/rm/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner.rm.workflow import run_rm
2 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/rm/collator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Any, Dict, Sequence
3 |
4 | from glmtuner.dsets import DataCollatorForChatGLM
5 |
6 |
7 | class PairwiseDataCollatorForChatGLM(DataCollatorForChatGLM):
8 | r"""
9 | Data collator for pairwise data.
10 | """
11 |
12 | def __call__(self, features: Sequence[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
13 | r"""
14 | Pads batched data to the longest sequence in the batch.
15 |
16 | We generate 2 * n examples where the first n examples represent chosen examples and
17 | the last n examples represent rejected examples.
18 | """
19 |
20 | features = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]
21 | return super().__call__(features)
22 |
--------------------------------------------------------------------------------
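What the collator's flattening produces on toy features: the comprehension iterates over keys first, so all chosen (accept_ids) sequences come before all rejected (reject_ids) sequences, matching the "first n / last n" convention used by the pairwise trainer.

features = [
    {"accept_ids": [1, 2, 3], "reject_ids": [4, 5]},
    {"accept_ids": [6, 7], "reject_ids": [8, 9, 10]},
]

flattened = [{"input_ids": feature[key]} for key in ("accept_ids", "reject_ids") for feature in features]

print([f["input_ids"] for f in flattened])
# [[1, 2, 3], [6, 7], [4, 5], [8, 9, 10]]  -> first n chosen, last n rejected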
/src/glmtuner/tuner/rm/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from typing import Dict, Sequence, Tuple, Union
3 |
4 |
5 | def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
6 | preds, _ = eval_preds
7 | return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
8 |
--------------------------------------------------------------------------------
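compute_accuracy simply counts how often the chosen response outscores the rejected one; on toy score arrays:

import numpy as np

# preds as produced by the pairwise trainer: (scores of chosen, scores of rejected)
preds = (np.array([1.2, 0.3, 0.9]), np.array([0.4, 0.7, 0.1]))
print((preds[0] > preds[1]).sum() / len(preds[0]))  # 0.666...: 2 of 3 pairs ranked correctly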
/src/glmtuner/tuner/rm/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | from typing import Dict, List, Optional, Tuple, Union
5 | from transformers.trainer import PredictionOutput
6 | from transformers.modeling_utils import PreTrainedModel
7 |
8 | from glmtuner.extras.logging import get_logger
9 | from glmtuner.tuner.core.trainer import PeftTrainer
10 |
11 |
12 | logger = get_logger(__name__)
13 |
14 |
15 | class PairwiseTrainerForChatGLM(PeftTrainer):
16 | r"""
17 | Inherits PeftTrainer to compute pairwise loss.
18 | """
19 |
20 | def __init__(self, *args, **kwargs):
21 | super().__init__(*args, **kwargs)
22 | self.can_return_loss = True # override property to return eval_loss
23 |
24 | def compute_loss(
25 | self,
26 | model: PreTrainedModel,
27 | inputs: Dict[str, torch.Tensor],
28 | return_outputs: Optional[bool] = False
29 | ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
30 | r"""
31 | Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
32 |
33 |         We use the score on the EOS token to represent the reward of the whole sentence.
34 |
35 | Subclass and override to inject custom behavior. It should not be directly used by external scripts.
36 |
37 | Note that the first element will be removed from the output tuple.
38 |
39 | See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
40 | """
41 | batch_size = inputs["input_ids"].size(0) // 2
42 | _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
43 | r_accept, r_reject = values[-1].split(batch_size, dim=0)
44 | loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
45 | return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
46 |
47 | def save_predictions(
48 | self,
49 | predict_results: PredictionOutput
50 | ) -> None:
51 | r"""
52 | Saves model predictions to `output_dir`.
53 |         A custom behavior that is not contained in Seq2SeqTrainer.
54 | """
55 | if not self.is_world_process_zero():
56 | return
57 |
58 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
59 | logger.info(f"Saving prediction results to {output_prediction_file}")
60 |
61 | acc_scores, rej_scores = predict_results.predictions
62 |
63 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
64 | res: List[str] = []
65 | for acc_score, rej_score in zip(acc_scores, rej_scores):
66 | res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
67 | writer.write("\n".join(res))
68 |
--------------------------------------------------------------------------------
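The pairwise loss in compute_loss on toy reward values (plain tensors, not model outputs): a larger margin between the chosen and rejected scores drives the loss toward zero, while a mis-ranked pair increases it.

import torch

r_accept = torch.tensor([1.5, 0.2])  # rewards of the chosen responses
r_reject = torch.tensor([0.5, 0.8])  # rewards of the rejected responses

loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
print(loss)  # ~0.6754: the first pair is ranked correctly, the second is not

# Mathematically equivalent, numerically safer form:
print(torch.nn.functional.softplus(r_reject - r_accept).mean())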
/src/glmtuner/tuner/rm/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by:
2 | # https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
3 | # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
4 |
5 | from typing import Optional, List
6 | from transformers import Seq2SeqTrainingArguments, TrainerCallback
7 |
8 | from glmtuner.dsets import get_dataset, preprocess_dataset, split_dataset
9 | from glmtuner.extras.callbacks import LogCallback
10 | from glmtuner.extras.ploting import plot_loss
11 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
12 | from glmtuner.tuner.core import load_model_and_tokenizer
13 | from glmtuner.tuner.rm.metric import compute_accuracy
14 | from glmtuner.tuner.rm.collator import PairwiseDataCollatorForChatGLM
15 | from glmtuner.tuner.rm.trainer import PairwiseTrainerForChatGLM
16 |
17 |
18 | def run_rm(
19 | model_args: ModelArguments,
20 | data_args: DataArguments,
21 | training_args: Seq2SeqTrainingArguments,
22 | finetuning_args: FinetuningArguments,
23 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
24 | ):
25 | dataset = get_dataset(model_args, data_args)
26 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="rm")
27 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
28 | data_collator = PairwiseDataCollatorForChatGLM(tokenizer, model.pretrained_model)
29 |
30 | training_args.remove_unused_columns = False # Important for pairwise dataset
31 |
32 | # Initialize our Trainer
33 | trainer = PairwiseTrainerForChatGLM(
34 | finetuning_args=finetuning_args,
35 | model=model,
36 | args=training_args,
37 | tokenizer=tokenizer,
38 | data_collator=data_collator,
39 | callbacks=callbacks,
40 | compute_metrics=compute_accuracy,
41 | **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
42 | )
43 |
44 | # Training
45 | if training_args.do_train:
46 | train_result = trainer.train()
47 | trainer.log_metrics("train", train_result.metrics)
48 | trainer.save_metrics("train", train_result.metrics)
49 | trainer.save_state()
50 | trainer.save_model()
51 | if trainer.is_world_process_zero() and model_args.plot_loss:
52 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
53 |
54 | # Evaluation
55 | if training_args.do_eval:
56 | metrics = trainer.evaluate(metric_key_prefix="eval")
57 | trainer.log_metrics("eval", metrics)
58 | trainer.save_metrics("eval", metrics)
59 |
60 | # Predict
61 | if training_args.do_predict:
62 | predict_results = trainer.predict(dataset, metric_key_prefix="predict")
63 | trainer.log_metrics("predict", predict_results.metrics)
64 | trainer.save_metrics("predict", predict_results.metrics)
65 | trainer.save_predictions(predict_results)
66 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/sft/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner.sft.workflow import run_sft
2 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/sft/metric.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from dataclasses import dataclass
3 | from typing import Dict, Sequence, Tuple, Union
4 | from transformers.tokenization_utils import PreTrainedTokenizer
5 |
6 | import jieba
7 | from rouge_chinese import Rouge
8 | from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
9 |
10 | from glmtuner.extras.constants import IGNORE_INDEX
11 |
12 |
13 | @dataclass
14 | class ComputeMetrics:
15 | r"""
16 | Wraps the tokenizer into metric functions, used in Seq2SeqTrainerForChatGLM.
17 | """
18 |
19 | tokenizer: PreTrainedTokenizer
20 |
21 | def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
22 | r"""
23 | Uses the model predictions to compute metrics.
24 | """
25 | preds, labels = eval_preds
26 | score_dict = {"accuracy": [], "rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []}
27 |
28 | preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id)
29 | labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id)
30 |
31 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
32 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
33 |
34 | for pred, label in zip(decoded_preds, decoded_labels):
35 | hypothesis = list(jieba.cut(pred))
36 | reference = list(jieba.cut(label))
37 |
38 | if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0:
39 | result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}}
40 | else:
41 | rouge = Rouge()
42 | scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference))
43 | result = scores[0]
44 |
45 | for k, v in result.items():
46 | score_dict[k].append(round(v["f"] * 100, 4))
47 |
48 | bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3)
49 | score_dict["bleu-4"].append(round(bleu_score * 100, 4))
50 | score_dict["accuracy"].append(float(len(label) != 0 and pred[:len(label)] == label))
51 |
52 | return {k: float(np.mean(v)) for k, v in score_dict.items()}
53 |
--------------------------------------------------------------------------------
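The "accuracy" entry above is an exact prefix match between the decoded prediction and the label (the ROUGE and BLEU entries additionally need jieba, rouge_chinese and nltk); a toy illustration of the prefix check on plain strings:

pred = "北京是中国的首都。欢迎你来!"
label = "北京是中国的首都。"

# Counts as correct only if the prediction starts with the full label text.
accuracy = float(len(label) != 0 and pred[:len(label)] == label)
print(accuracy)  # 1.0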
/src/glmtuner/tuner/sft/trainer.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import torch
4 | import numpy as np
5 | import torch.nn as nn
6 | from typing import Any, Dict, List, Optional, Tuple, Union
7 | from transformers.trainer import PredictionOutput
8 |
9 | from glmtuner.extras.constants import IGNORE_INDEX
10 | from glmtuner.extras.logging import get_logger
11 | from glmtuner.tuner.core.trainer import PeftTrainer
12 |
13 |
14 | logger = get_logger(__name__)
15 |
16 |
17 | class Seq2SeqTrainerForChatGLM(PeftTrainer):
18 | r"""
19 | Inherits PeftTrainer to compute generative metrics such as BLEU and ROUGE.
20 | """
21 |
22 | def prediction_step(
23 | self,
24 | model: nn.Module,
25 | inputs: Dict[str, Union[torch.Tensor, Any]],
26 | prediction_loss_only: bool,
27 | ignore_keys: Optional[List[str]] = None,
28 | ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
29 | r"""
30 | Removes the prompt part in the generated tokens.
31 |
32 | Subclass and override to inject custom behavior.
33 | """
34 | input_ids = inputs["input_ids"]
35 | loss, generated_tokens, labels = super().prediction_step(
36 | model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
37 | )
38 | generated_tokens = generated_tokens[:, input_ids.size(-1):] if generated_tokens is not None else None
39 | return (loss, generated_tokens, labels)
40 |
41 | def save_predictions(
42 | self,
43 | predict_results: PredictionOutput
44 | ) -> None:
45 | r"""
46 | Saves model predictions to `output_dir`.
47 |
48 |         A custom behavior that is not contained in Seq2SeqTrainer.
49 | """
50 | if not self.is_world_process_zero():
51 | return
52 |
53 | output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
54 | logger.info(f"Saving prediction results to {output_prediction_file}")
55 |
56 | preds = np.where(predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id)
57 | labels = np.where(predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id)
58 |
59 | decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
60 | decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
61 |
62 | with open(output_prediction_file, "w", encoding="utf-8") as writer:
63 | res: List[str] = []
64 | for pred, label in zip(decoded_preds, decoded_labels):
65 | res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False))
66 | writer.write("\n".join(res))
67 |
--------------------------------------------------------------------------------
/src/glmtuner/tuner/sft/workflow.py:
--------------------------------------------------------------------------------
1 | # Inspired by: https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py
2 |
3 | from typing import Optional, List
4 | from transformers import Seq2SeqTrainingArguments, TrainerCallback
5 |
6 | from glmtuner.dsets import DataCollatorForChatGLM, get_dataset, preprocess_dataset, split_dataset
7 | from glmtuner.extras.callbacks import LogCallback
8 | from glmtuner.extras.misc import get_logits_processor
9 | from glmtuner.extras.ploting import plot_loss
10 | from glmtuner.hparams import ModelArguments, DataArguments, FinetuningArguments
11 | from glmtuner.tuner.core import load_model_and_tokenizer
12 | from glmtuner.tuner.sft.metric import ComputeMetrics
13 | from glmtuner.tuner.sft.trainer import Seq2SeqTrainerForChatGLM
14 |
15 |
16 | def run_sft(
17 | model_args: ModelArguments,
18 | data_args: DataArguments,
19 | training_args: Seq2SeqTrainingArguments,
20 | finetuning_args: FinetuningArguments,
21 | callbacks: Optional[List[TrainerCallback]] = [LogCallback()]
22 | ):
23 | dataset = get_dataset(model_args, data_args)
24 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
25 | dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
26 | data_collator = DataCollatorForChatGLM(
27 | tokenizer=tokenizer,
28 | model=model,
29 | ignore_pad_token_for_loss=(data_args.ignore_pad_token_for_loss and not training_args.predict_with_generate)
30 | )
31 |
32 | # Override the decoding parameters of Seq2SeqTrainer
33 | training_args.generation_max_length = training_args.generation_max_length if \
34 | training_args.generation_max_length is not None else data_args.max_target_length
35 | training_args.generation_num_beams = data_args.eval_num_beams if \
36 | data_args.eval_num_beams is not None else training_args.generation_num_beams
37 |
38 | # Initialize our Trainer
39 | trainer = Seq2SeqTrainerForChatGLM(
40 | finetuning_args=finetuning_args,
41 | model=model,
42 | args=training_args,
43 | tokenizer=tokenizer,
44 | data_collator=data_collator,
45 | callbacks=callbacks,
46 | compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
47 | **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
48 | )
49 |
50 | # Keyword arguments for `model.generate`
51 | gen_kwargs = {
52 | "do_sample": True,
53 | "top_p": 0.7,
54 | "max_new_tokens": data_args.max_target_length + 1,
55 | "temperature": 0.95,
56 | "logits_processor": get_logits_processor()
57 | }
58 |
59 | # Training
60 | if training_args.do_train:
61 | train_result = trainer.train()
62 | trainer.log_metrics("train", train_result.metrics)
63 | trainer.save_metrics("train", train_result.metrics)
64 | trainer.save_state()
65 | trainer.save_model()
66 | if trainer.is_world_process_zero() and model_args.plot_loss:
67 | plot_loss(training_args.output_dir, keys=["loss", "eval_loss"])
68 |
69 | # Evaluation
70 | if training_args.do_eval:
71 | metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs)
72 | if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled
73 | metrics.pop("eval_loss", None)
74 | trainer.log_metrics("eval", metrics)
75 | trainer.save_metrics("eval", metrics)
76 |
77 | # Predict
78 | if training_args.do_predict:
79 | predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs)
80 | if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled
81 | predict_results.metrics.pop("predict_loss", None)
82 | trainer.log_metrics("predict", predict_results.metrics)
83 | trainer.save_metrics("predict", predict_results.metrics)
84 | trainer.save_predictions(predict_results)
85 |
--------------------------------------------------------------------------------
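The decoding-parameter override above in plain terms: generation_max_length falls back to max_target_length when unset, and generation_num_beams is replaced by eval_num_beams when one is provided (toy values, not defaults):

generation_max_length = None   # from training_args
generation_num_beams = 1       # from training_args
max_target_length = 512        # from data_args
eval_num_beams = None          # from data_args

generation_max_length = generation_max_length if generation_max_length is not None else max_target_length
generation_num_beams = eval_num_beams if eval_num_beams is not None else generation_num_beams
print(generation_max_length, generation_num_beams)  # 512 1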
/src/glmtuner/webui/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lunyiliu/CoachLM/37617299916c3ae550b2f6da713835e21bea6e60/src/glmtuner/webui/__init__.py
--------------------------------------------------------------------------------
/src/glmtuner/webui/chat.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import List, Tuple
3 |
4 | from glmtuner.chat.stream_chat import ChatModel
5 | from glmtuner.extras.misc import torch_gc
6 | from glmtuner.hparams import GeneratingArguments
7 | from glmtuner.tuner import get_infer_args
8 | from glmtuner.webui.common import get_model_path, get_save_dir
9 | from glmtuner.webui.locales import ALERTS
10 |
11 |
12 | class WebChatModel(ChatModel):
13 |
14 | def __init__(self, *args):
15 | self.model = None
16 | self.tokenizer = None
17 | self.generating_args = GeneratingArguments()
18 | if len(args) != 0:
19 | super().__init__(*args)
20 |
21 | def load_model(
22 | self,
23 | lang: str,
24 | model_name: str,
25 | checkpoints: List[str],
26 | finetuning_type: str,
27 | quantization_bit: str,
28 | source_prefix: str
29 | ):
30 | if self.model is not None:
31 | yield ALERTS["err_exists"][lang]
32 | return
33 |
34 | if not model_name:
35 | yield ALERTS["err_no_model"][lang]
36 | return
37 |
38 | model_name_or_path = get_model_path(model_name)
39 | if not model_name_or_path:
40 | yield ALERTS["err_no_path"][lang]
41 | return
42 |
43 | if checkpoints:
44 | checkpoint_dir = ",".join(
45 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
46 | )
47 | else:
48 | checkpoint_dir = None
49 |
50 | yield ALERTS["info_loading"][lang]
51 | args = dict(
52 | model_name_or_path=model_name_or_path,
53 | checkpoint_dir=checkpoint_dir,
54 | finetuning_type=finetuning_type,
55 | quantization_bit=int(quantization_bit) if quantization_bit else None,
56 | source_prefix=source_prefix
57 | )
58 | super().__init__(*get_infer_args(args))
59 | yield ALERTS["info_loaded"][lang]
60 |
61 | def unload_model(self, lang: str):
62 | yield ALERTS["info_unloading"][lang]
63 | self.model = None
64 | self.tokenizer = None
65 | torch_gc()
66 | yield ALERTS["info_unloaded"][lang]
67 |
68 | def predict(
69 | self,
70 | chatbot: List[Tuple[str, str]],
71 | query: str,
72 | history: List[Tuple[str, str]],
73 | prefix: str,
74 | max_length: int,
75 | top_p: float,
76 | temperature: float
77 | ):
78 | chatbot.append([query, ""])
79 | response = ""
80 | for new_text in self.stream_chat(
81 | query, history, prefix, max_length=max_length, top_p=top_p, temperature=temperature
82 | ):
83 | response += new_text
84 | response = self.postprocess(response)
85 | new_history = history + [(query, response)]
86 | chatbot[-1] = [query, response]
87 | yield chatbot, new_history
88 |
89 | def postprocess(self, response: str) -> str:
90 | blocks = response.split("```")
91 | for i, block in enumerate(blocks):
92 | if i % 2 == 0:
93 |                 blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
94 | return "```".join(blocks)
95 |
--------------------------------------------------------------------------------
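What postprocess does, as a standalone copy with a usage example: text outside ``` fences has angle brackets HTML-escaped so the Gradio chatbot does not render them as tags, while the contents of fenced code blocks are left untouched.

def postprocess(response: str) -> str:
    blocks = response.split("```")
    for i, block in enumerate(blocks):
        if i % 2 == 0:  # even indices lie outside code fences
            blocks[i] = block.replace("<", "&lt;").replace(">", "&gt;")
    return "```".join(blocks)

print(postprocess("Use <b>bold</b> like this:\n```\nif a < b: pass\n```"))
# Use &lt;b&gt;bold&lt;/b&gt; like this:
# ```
# if a < b: pass
# ```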
/src/glmtuner/webui/common.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Any, Dict, Optional
4 |
5 | import gradio as gr
6 | from peft.utils import WEIGHTS_NAME as PEFT_WEIGHTS_NAME
7 | from transformers.trainer import WEIGHTS_NAME, WEIGHTS_INDEX_NAME
8 |
9 | from glmtuner.extras.constants import SUPPORTED_MODELS
10 |
11 |
12 | DEFAULT_CACHE_DIR = "cache"
13 | DEFAULT_DATA_DIR = "data"
14 | DEFAULT_SAVE_DIR = "saves"
15 | USER_CONFIG = "user.config"
16 | DATA_CONFIG = "dataset_info.json"
17 |
18 |
19 | def get_save_dir(model_name: str) -> str:
20 | return os.path.join(DEFAULT_SAVE_DIR, os.path.split(model_name)[-1])
21 |
22 |
23 | def get_config_path() -> os.PathLike:
24 | return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
25 |
26 |
27 | def load_config() -> Dict[str, Any]:
28 | try:
29 | with open(get_config_path(), "r", encoding="utf-8") as f:
30 | return json.load(f)
31 | except:
32 | return {"last_model": "", "path_dict": {}}
33 |
34 |
35 | def save_config(model_name: str, model_path: str) -> None:
36 | os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
37 | user_config = load_config()
38 | user_config["last_model"] = model_name
39 | user_config["path_dict"][model_name] = model_path
40 | with open(get_config_path(), "w", encoding="utf-8") as f:
41 | json.dump(user_config, f, indent=2, ensure_ascii=False)
42 |
43 |
44 | def get_model_path(model_name: str) -> str:
45 | user_config = load_config()
46 | return user_config["path_dict"].get(model_name, SUPPORTED_MODELS.get(model_name, ""))
47 |
48 |
49 | def list_checkpoint(model_name: str, finetuning_type: str) -> Dict[str, Any]:
50 | checkpoints = []
51 | save_dir = os.path.join(get_save_dir(model_name), finetuning_type)
52 | if save_dir and os.path.isdir(save_dir):
53 | for checkpoint in os.listdir(save_dir):
54 | if (
55 | os.path.isdir(os.path.join(save_dir, checkpoint))
56 | and any([
57 | os.path.isfile(os.path.join(save_dir, checkpoint, name))
58 | for name in (WEIGHTS_NAME, WEIGHTS_INDEX_NAME, PEFT_WEIGHTS_NAME)
59 | ])
60 | ):
61 | checkpoints.append(checkpoint)
62 | return gr.update(value=[], choices=checkpoints)
63 |
64 |
65 | def load_dataset_info(dataset_dir: str) -> Dict[str, Any]:
66 | try:
67 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
68 | return json.load(f)
69 | except:
70 | return {}
71 |
72 |
73 | def list_dataset(dataset_dir: Optional[str] = None) -> Dict[str, Any]:
74 | dataset_info = load_dataset_info(dataset_dir if dataset_dir is not None else DEFAULT_DATA_DIR)
75 | return gr.update(value=[], choices=list(dataset_info.keys()))
76 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/__init__.py:
--------------------------------------------------------------------------------
1 | from glmtuner.webui.components.top import create_top
2 | from glmtuner.webui.components.sft import create_sft_tab
3 | from glmtuner.webui.components.eval import create_eval_tab
4 | from glmtuner.webui.components.infer import create_infer_tab
5 | from glmtuner.webui.components.export import create_export_tab
6 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/chatbot.py:
--------------------------------------------------------------------------------
1 | from typing import Dict, Optional, Tuple
2 |
3 | import gradio as gr
4 | from gradio.blocks import Block
5 | from gradio.components import Component
6 |
7 | from glmtuner.webui.chat import WebChatModel
8 |
9 |
10 | def create_chat_box(
11 | chat_model: WebChatModel,
12 | visible: Optional[bool] = False
13 | ) -> Tuple[Block, Component, Component, Dict[str, Component]]:
14 | with gr.Box(visible=visible) as chat_box:
15 | chatbot = gr.Chatbot()
16 |
17 | with gr.Row():
18 | with gr.Column(scale=4):
19 | prefix = gr.Dropdown(show_label=False)
20 | query = gr.Textbox(show_label=False, lines=8)
21 | submit_btn = gr.Button(variant="primary")
22 |
23 | with gr.Column(scale=1):
24 | clear_btn = gr.Button()
25 | max_length = gr.Slider(10, 2048, value=chat_model.generating_args.max_length, step=1)
26 | top_p = gr.Slider(0.01, 1, value=chat_model.generating_args.top_p, step=0.01)
27 | temperature = gr.Slider(0.01, 1.5, value=chat_model.generating_args.temperature, step=0.01)
28 |
29 | history = gr.State([])
30 |
31 | submit_btn.click(
32 | chat_model.predict,
33 | [chatbot, query, history, prefix, max_length, top_p, temperature],
34 | [chatbot, history],
35 | show_progress=True
36 | ).then(
37 | lambda: gr.update(value=""), outputs=[query]
38 | )
39 |
40 | clear_btn.click(lambda: ([], []), outputs=[chatbot, history], show_progress=True)
41 |
42 | return chat_box, chatbot, history, dict(
43 | prefix=prefix,
44 | query=query,
45 | submit_btn=submit_btn,
46 | clear_btn=clear_btn,
47 | max_length=max_length,
48 | top_p=top_p,
49 | temperature=temperature
50 | )
51 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/data.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from gradio.blocks import Block
3 | from gradio.components import Component
4 | from typing import Tuple
5 |
6 |
7 | def create_preview_box() -> Tuple[Block, Component, Component, Component]:
8 | with gr.Box(visible=False, elem_classes="modal-box") as preview_box:
9 | with gr.Row():
10 | preview_count = gr.Number(interactive=False)
11 |
12 | with gr.Row():
13 | preview_samples = gr.JSON(interactive=False)
14 |
15 | close_btn = gr.Button()
16 |
17 | close_btn.click(lambda: gr.update(visible=False), outputs=[preview_box])
18 |
19 | return preview_box, preview_count, preview_samples, close_btn
20 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/eval.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import gradio as gr
3 | from gradio.components import Component
4 |
5 | from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
6 | from glmtuner.webui.components.data import create_preview_box
7 | from glmtuner.webui.runner import Runner
8 | from glmtuner.webui.utils import can_preview, get_preview
9 |
10 |
11 | def create_eval_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
12 | with gr.Row():
13 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
14 | dataset = gr.Dropdown(multiselect=True, scale=4)
15 | preview_btn = gr.Button(interactive=False, scale=1)
16 |
17 | preview_box, preview_count, preview_samples, close_btn = create_preview_box()
18 |
19 | dataset_dir.change(list_dataset, [dataset_dir], [dataset])
20 | dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
21 | preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
22 |
23 | with gr.Row():
24 | max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
25 | max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
26 | max_samples = gr.Textbox(value="100000")
27 | batch_size = gr.Slider(value=8, minimum=1, maximum=512, step=1)
28 | predict = gr.Checkbox(value=True)
29 |
30 | with gr.Row():
31 | start_btn = gr.Button()
32 | stop_btn = gr.Button()
33 |
34 | with gr.Box():
35 | output_box = gr.Markdown()
36 |
37 | start_btn.click(
38 | runner.run_eval,
39 | [
40 | top_elems["lang"],
41 | top_elems["model_name"],
42 | top_elems["checkpoints"],
43 | top_elems["finetuning_type"],
44 | top_elems["quantization_bit"],
45 | top_elems["source_prefix"],
46 | dataset_dir,
47 | dataset,
48 | max_source_length,
49 | max_target_length,
50 | max_samples,
51 | batch_size,
52 | predict
53 | ],
54 | [output_box]
55 | )
56 | stop_btn.click(runner.set_abort, queue=False)
57 |
58 | return dict(
59 | dataset_dir=dataset_dir,
60 | dataset=dataset,
61 | preview_btn=preview_btn,
62 | preview_count=preview_count,
63 | preview_samples=preview_samples,
64 | close_btn=close_btn,
65 | max_source_length=max_source_length,
66 | max_target_length=max_target_length,
67 | max_samples=max_samples,
68 | batch_size=batch_size,
69 | predict=predict,
70 | start_btn=start_btn,
71 | stop_btn=stop_btn,
72 | output_box=output_box
73 | )
74 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/export.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | import gradio as gr
3 | from gradio.components import Component
4 |
5 | from glmtuner.webui.utils import export_model
6 |
7 |
8 | def create_export_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
9 | with gr.Row():
10 | save_dir = gr.Textbox()
11 | max_shard_size = gr.Slider(value=10, minimum=1, maximum=100)
12 |
13 | export_btn = gr.Button()
14 | info_box = gr.Textbox(show_label=False, interactive=False)
15 |
16 | export_btn.click(
17 | export_model,
18 | [
19 | top_elems["lang"],
20 | top_elems["model_name"],
21 | top_elems["checkpoints"],
22 | top_elems["finetuning_type"],
23 | max_shard_size,
24 | save_dir
25 | ],
26 | [info_box]
27 | )
28 |
29 | return dict(
30 | save_dir=save_dir,
31 | max_shard_size=max_shard_size,
32 | export_btn=export_btn,
33 | info_box=info_box
34 | )
35 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/infer.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import gradio as gr
4 | from gradio.components import Component
5 |
6 | from glmtuner.webui.chat import WebChatModel
7 | from glmtuner.webui.components.chatbot import create_chat_box
8 |
9 |
10 | def create_infer_tab(top_elems: Dict[str, Component]) -> Dict[str, Component]:
11 | with gr.Row():
12 | load_btn = gr.Button()
13 | unload_btn = gr.Button()
14 |
15 | info_box = gr.Textbox(show_label=False, interactive=False)
16 |
17 | chat_model = WebChatModel()
18 | chat_box, chatbot, history, chat_elems = create_chat_box(chat_model)
19 |
20 | load_btn.click(
21 | chat_model.load_model,
22 | [
23 | top_elems["lang"],
24 | top_elems["model_name"],
25 | top_elems["checkpoints"],
26 | top_elems["finetuning_type"],
27 | top_elems["quantization_bit"],
28 | top_elems["source_prefix"]
29 | ],
30 | [info_box]
31 | ).then(
32 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
33 | )
34 |
35 | unload_btn.click(
36 | chat_model.unload_model, [top_elems["lang"]], [info_box]
37 | ).then(
38 | lambda: ([], []), outputs=[chatbot, history]
39 | ).then(
40 | lambda: gr.update(visible=(chat_model.model is not None)), outputs=[chat_box]
41 | )
42 |
43 | return dict(
44 | info_box=info_box,
45 | load_btn=load_btn,
46 | unload_btn=unload_btn,
47 | **chat_elems
48 | )
49 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/sft.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 | from transformers.trainer_utils import SchedulerType
3 |
4 | import gradio as gr
5 | from gradio.components import Component
6 |
7 | from glmtuner.webui.common import list_dataset, DEFAULT_DATA_DIR
8 | from glmtuner.webui.components.data import create_preview_box
9 | from glmtuner.webui.runner import Runner
10 | from glmtuner.webui.utils import can_preview, get_preview, gen_plot
11 |
12 |
13 | def create_sft_tab(top_elems: Dict[str, Component], runner: Runner) -> Dict[str, Component]:
14 | with gr.Row():
15 | dataset_dir = gr.Textbox(value=DEFAULT_DATA_DIR, scale=2)
16 | dataset = gr.Dropdown(multiselect=True, scale=4)
17 | preview_btn = gr.Button(interactive=False, scale=1)
18 |
19 | preview_box, preview_count, preview_samples, close_btn = create_preview_box()
20 |
21 | dataset_dir.change(list_dataset, [dataset_dir], [dataset])
22 | dataset.change(can_preview, [dataset_dir, dataset], [preview_btn])
23 | preview_btn.click(get_preview, [dataset_dir, dataset], [preview_count, preview_samples, preview_box])
24 |
25 | with gr.Row():
26 | max_source_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
27 | max_target_length = gr.Slider(value=512, minimum=4, maximum=4096, step=1)
28 | learning_rate = gr.Textbox(value="5e-5")
29 | num_train_epochs = gr.Textbox(value="3.0")
30 | max_samples = gr.Textbox(value="100000")
31 |
32 | with gr.Row():
33 | batch_size = gr.Slider(value=4, minimum=1, maximum=512, step=1)
34 | gradient_accumulation_steps = gr.Slider(value=4, minimum=1, maximum=512, step=1)
35 | lr_scheduler_type = gr.Dropdown(
36 | value="cosine", choices=[scheduler.value for scheduler in SchedulerType]
37 | )
38 | max_grad_norm = gr.Textbox(value="1.0")
39 | dev_ratio = gr.Slider(value=0, minimum=0, maximum=1, step=0.001)
40 |
41 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
42 | with gr.Row():
43 | logging_steps = gr.Slider(value=5, minimum=5, maximum=1000, step=5)
44 | save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
45 | warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
46 | compute_type = gr.Radio(choices=["fp16", "bf16", "fp32"], value="fp16")
47 |
48 | with gr.Accordion(label="LoRA config", open=False) as lora_tab:
49 | with gr.Row():
50 | lora_rank = gr.Slider(value=8, minimum=1, maximum=1024, step=1, scale=1)
51 | lora_dropout = gr.Slider(value=0, minimum=0, maximum=1, step=0.01, scale=1)
52 | lora_target = gr.Textbox(scale=2)
53 |
54 | with gr.Row():
55 | start_btn = gr.Button()
56 | stop_btn = gr.Button()
57 |
58 | with gr.Row():
59 | with gr.Column(scale=4):
60 | output_dir = gr.Textbox()
61 |
62 | with gr.Box():
63 | output_box = gr.Markdown()
64 |
65 | with gr.Column(scale=1):
66 | loss_viewer = gr.Plot()
67 |
68 | start_btn.click(
69 | runner.run_train,
70 | [
71 | top_elems["lang"],
72 | top_elems["model_name"],
73 | top_elems["checkpoints"],
74 | top_elems["finetuning_type"],
75 | top_elems["quantization_bit"],
76 | top_elems["source_prefix"],
77 | dataset_dir,
78 | dataset,
79 | max_source_length,
80 | max_target_length,
81 | learning_rate,
82 | num_train_epochs,
83 | max_samples,
84 | batch_size,
85 | gradient_accumulation_steps,
86 | lr_scheduler_type,
87 | max_grad_norm,
88 | dev_ratio,
89 | logging_steps,
90 | save_steps,
91 | warmup_steps,
92 | compute_type,
93 | lora_rank,
94 | lora_dropout,
95 | lora_target,
96 | output_dir
97 | ],
98 | [output_box]
99 | )
100 | stop_btn.click(runner.set_abort, queue=False)
101 |
102 | output_box.change(
103 | gen_plot, [top_elems["model_name"], top_elems["finetuning_type"], output_dir], loss_viewer, queue=False
104 | )
105 |
106 | return dict(
107 | dataset_dir=dataset_dir,
108 | dataset=dataset,
109 | preview_btn=preview_btn,
110 | preview_count=preview_count,
111 | preview_samples=preview_samples,
112 | close_btn=close_btn,
113 | max_source_length=max_source_length,
114 | max_target_length=max_target_length,
115 | learning_rate=learning_rate,
116 | num_train_epochs=num_train_epochs,
117 | max_samples=max_samples,
118 | batch_size=batch_size,
119 | gradient_accumulation_steps=gradient_accumulation_steps,
120 | lr_scheduler_type=lr_scheduler_type,
121 | max_grad_norm=max_grad_norm,
122 | dev_ratio=dev_ratio,
123 | advanced_tab=advanced_tab,
124 | logging_steps=logging_steps,
125 | save_steps=save_steps,
126 | warmup_steps=warmup_steps,
127 | compute_type=compute_type,
128 | lora_tab=lora_tab,
129 | lora_rank=lora_rank,
130 | lora_dropout=lora_dropout,
131 | lora_target=lora_target,
132 | start_btn=start_btn,
133 | stop_btn=stop_btn,
134 | output_dir=output_dir,
135 | output_box=output_box,
136 | loss_viewer=loss_viewer
137 | )
138 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/components/top.py:
--------------------------------------------------------------------------------
1 | from typing import Dict
2 |
3 | import gradio as gr
4 | from gradio.components import Component
5 |
6 | from glmtuner.extras.constants import METHODS, SUPPORTED_MODELS
7 | from glmtuner.webui.common import list_checkpoint, get_model_path, save_config
8 | from glmtuner.webui.utils import can_quantize
9 |
10 |
11 | def create_top() -> Dict[str, Component]:
12 | available_models = list(SUPPORTED_MODELS.keys()) + ["Custom"]
13 |
14 | with gr.Row():
15 | lang = gr.Dropdown(choices=["en", "zh"], value="en", scale=1)
16 | model_name = gr.Dropdown(choices=available_models, scale=3)
17 | model_path = gr.Textbox(scale=3)
18 |
19 | with gr.Row():
20 | finetuning_type = gr.Dropdown(value="lora", choices=METHODS, scale=1)
21 | checkpoints = gr.Dropdown(multiselect=True, scale=5)
22 | refresh_btn = gr.Button(scale=1)
23 |
24 | with gr.Accordion(label="Advanced config", open=False) as advanced_tab:
25 | with gr.Row():
26 | quantization_bit = gr.Dropdown([8, 4], scale=1)
27 | source_prefix = gr.Textbox(scale=4)
28 |
29 | model_name.change(
30 | get_model_path, [model_name], [model_path]
31 | ).then(
32 | list_checkpoint, [model_name, finetuning_type], [checkpoints]
33 |     ) # do not save the config here; the line below will save it
34 | model_path.change(save_config, [model_name, model_path])
35 |
36 | finetuning_type.change(
37 | list_checkpoint, [model_name, finetuning_type], [checkpoints]
38 | ).then(
39 | can_quantize, [finetuning_type], [quantization_bit]
40 | )
41 |
42 | refresh_btn.click(list_checkpoint, [model_name, finetuning_type], [checkpoints])
43 |
44 | return dict(
45 | lang=lang,
46 | model_name=model_name,
47 | model_path=model_path,
48 | finetuning_type=finetuning_type,
49 | checkpoints=checkpoints,
50 | refresh_btn=refresh_btn,
51 | advanced_tab=advanced_tab,
52 | quantization_bit=quantization_bit,
53 | source_prefix=source_prefix
54 | )
55 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/css.py:
--------------------------------------------------------------------------------
1 | CSS = r"""
2 | .modal-box {
3 | position: fixed !important;
4 | top: 50%;
5 | left: 50%;
6 | transform: translate(-50%, -50%); /* center horizontally */
7 | max-width: 1000px;
8 | max-height: 750px;
9 | overflow-y: scroll !important;
10 | background-color: var(--input-background-fill);
11 | border: 2px solid black !important;
12 | z-index: 1000;
13 | }
14 |
15 | .dark .modal-box {
16 | border: 2px solid white !important;
17 | }
18 | """
19 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/interface.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from transformers.utils.versions import require_version
3 |
4 | from glmtuner.webui.components import (
5 | create_top,
6 | create_sft_tab,
7 | create_eval_tab,
8 | create_infer_tab,
9 | create_export_tab
10 | )
11 | from glmtuner.webui.css import CSS
12 | from glmtuner.webui.manager import Manager
13 | from glmtuner.webui.runner import Runner
14 |
15 |
16 | require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
17 |
18 |
19 | def create_ui() -> gr.Blocks:
20 | runner = Runner()
21 |
22 | with gr.Blocks(title="Web Tuner", css=CSS) as demo:
23 | top_elems = create_top()
24 |
25 | with gr.Tab("SFT"):
26 | sft_elems = create_sft_tab(top_elems, runner)
27 |
28 | with gr.Tab("Evaluate"):
29 | eval_elems = create_eval_tab(top_elems, runner)
30 |
31 | with gr.Tab("Chat"):
32 | infer_elems = create_infer_tab(top_elems)
33 |
34 | with gr.Tab("Export"):
35 | export_elems = create_export_tab(top_elems)
36 |
37 | elem_list = [top_elems, sft_elems, eval_elems, infer_elems, export_elems]
38 | manager = Manager(elem_list)
39 |
40 | demo.load(
41 | manager.gen_label,
42 | [top_elems["lang"]],
43 | [elem for elems in elem_list for elem in elems.values()],
44 | )
45 |
46 | top_elems["lang"].change(
47 | manager.gen_label,
48 | [top_elems["lang"]],
49 | [elem for elems in elem_list for elem in elems.values()],
50 | )
51 |
52 | return demo
53 |
54 |
55 | if __name__ == "__main__":
56 | demo = create_ui()
57 | demo.queue()
58 | demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
59 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/locales.py:
--------------------------------------------------------------------------------
1 | LOCALES = {
2 | "lang": {
3 | "en": {
4 | "label": "Lang"
5 | },
6 | "zh": {
7 | "label": "语言"
8 | }
9 | },
10 | "model_name": {
11 | "en": {
12 | "label": "Model name"
13 | },
14 | "zh": {
15 | "label": "模型名称"
16 | }
17 | },
18 | "model_path": {
19 | "en": {
20 | "label": "Model path",
21 | "info": "Path to pretrained model or model identifier from Hugging Face."
22 | },
23 | "zh": {
24 | "label": "模型路径",
25 | "info": "本地模型的文件路径或 Hugging Face 的模型标识符。"
26 | }
27 | },
28 | "finetuning_type": {
29 | "en": {
30 | "label": "Finetuning method"
31 | },
32 | "zh": {
33 | "label": "微调方法"
34 | }
35 | },
36 | "checkpoints": {
37 | "en": {
38 | "label": "Checkpoints"
39 | },
40 | "zh": {
41 | "label": "模型断点"
42 | }
43 | },
44 | "refresh_btn": {
45 | "en": {
46 | "value": "Refresh checkpoints"
47 | },
48 | "zh": {
49 | "value": "刷新断点"
50 | }
51 | },
52 | "advanced_tab": {
53 | "en": {
54 | "label": "Advanced configurations"
55 | },
56 | "zh": {
57 | "label": "高级设置"
58 | }
59 | },
60 | "quantization_bit": {
61 | "en": {
62 | "label": "Quantization bit (optional)",
63 | "info": "Enable 4/8-bit model quantization."
64 | },
65 | "zh": {
66 | "label": "量化等级(非必填)",
67 | "info": "启用 4/8 比特模型量化。"
68 | }
69 | },
70 | "template": {
71 | "en": {
72 | "label": "Prompt template",
73 | "info": "The template used in constructing prompts."
74 | },
75 | "zh": {
76 | "label": "提示模板",
77 | "info": "构建提示词时使用的模板"
78 | }
79 | },
80 | "source_prefix": {
81 | "en": {
82 | "label": "System prompt (optional)",
83 | "info": "A sequence used as the default system prompt."
84 | },
85 | "zh": {
86 | "label": "系统提示词(非必填)",
87 | "info": "默认使用的系统提示词"
88 | }
89 | },
90 | "dataset_dir": {
91 | "en": {
92 | "label": "Data dir",
93 | "info": "Path of the data directory."
94 | },
95 | "zh": {
96 | "label": "数据路径",
97 | "info": "数据文件夹的路径。"
98 | }
99 | },
100 | "dataset": {
101 | "en": {
102 | "label": "Dataset"
103 | },
104 | "zh": {
105 | "label": "数据集"
106 | }
107 | },
108 | "preview_btn": {
109 | "en": {
110 | "value": "Preview"
111 | },
112 | "zh": {
113 | "value": "预览"
114 | }
115 | },
116 | "preview_count": {
117 | "en": {
118 | "label": "Count"
119 | },
120 | "zh": {
121 | "label": "数量"
122 | }
123 | },
124 | "preview_samples": {
125 | "en": {
126 | "label": "Samples"
127 | },
128 | "zh": {
129 | "label": "样例"
130 | }
131 | },
132 | "close_btn": {
133 | "en": {
134 | "value": "Close"
135 | },
136 | "zh": {
137 | "value": "关闭"
138 | }
139 | },
140 | "max_source_length": {
141 | "en": {
142 | "label": "Max source length",
143 | "info": "Max tokens in source sequence."
144 | },
145 | "zh": {
146 | "label": "输入序列最大长度",
147 | "info": "输入序列分词后的最大长度。"
148 | }
149 | },
150 | "max_target_length": {
151 | "en": {
152 | "label": "Max target length",
153 | "info": "Max tokens in target sequence."
154 | },
155 | "zh": {
156 | "label": "输出序列最大长度",
157 | "info": "输出序列分词后的最大长度。"
158 | }
159 | },
160 | "learning_rate": {
161 | "en": {
162 | "label": "Learning rate",
163 | "info": "Initial learning rate for AdamW."
164 | },
165 | "zh": {
166 | "label": "学习率",
167 | "info": "AdamW 优化器的初始学习率。"
168 | }
169 | },
170 | "num_train_epochs": {
171 | "en": {
172 | "label": "Epochs",
173 | "info": "Total number of training epochs to perform."
174 | },
175 | "zh": {
176 | "label": "训练轮数",
177 | "info": "需要执行的训练总轮数。"
178 | }
179 | },
180 | "max_samples": {
181 | "en": {
182 | "label": "Max samples",
183 | "info": "Maximum samples per dataset."
184 | },
185 | "zh": {
186 | "label": "最大样本数",
187 | "info": "每个数据集最多使用的样本数。"
188 | }
189 | },
190 | "batch_size": {
191 | "en": {
192 | "label": "Batch size",
193 | "info": "Number of samples to process per GPU."
194 | },
195 | "zh":{
196 | "label": "批处理大小",
197 | "info": "每块 GPU 上处理的样本数量。"
198 | }
199 | },
200 | "gradient_accumulation_steps": {
201 | "en": {
202 | "label": "Gradient accumulation",
203 | "info": "Number of gradient accumulation steps."
204 | },
205 | "zh": {
206 | "label": "梯度累积",
207 | "info": "梯度累积的步数。"
208 | }
209 | },
210 | "lr_scheduler_type": {
211 | "en": {
212 | "label": "LR Scheduler",
213 | "info": "Name of learning rate scheduler.",
214 | },
215 | "zh": {
216 | "label": "学习率调节器",
217 | "info": "采用的学习率调节器名称。"
218 | }
219 | },
220 | "max_grad_norm": {
221 | "en": {
222 | "label": "Maximum gradient norm",
223 | "info": "Norm for gradient clipping.."
224 | },
225 | "zh": {
226 | "label": "最大梯度范数",
227 | "info": "用于梯度裁剪的范数。"
228 | }
229 | },
230 | "dev_ratio": {
231 | "en": {
232 | "label": "Dev ratio",
233 | "info": "Proportion of data in the dev set."
234 | },
235 | "zh": {
236 | "label": "验证集比例",
237 | "info": "验证集占全部样本的百分比。"
238 | }
239 | },
240 | "logging_steps": {
241 | "en": {
242 | "label": "Logging steps",
243 | "info": "Number of steps between two logs."
244 | },
245 | "zh": {
246 | "label": "日志间隔",
247 | "info": "每两次日志输出间的更新步数。"
248 | }
249 | },
250 | "save_steps": {
251 | "en": {
252 | "label": "Save steps",
253 | "info": "Number of steps between two checkpoints."
254 | },
255 | "zh": {
256 | "label": "保存间隔",
257 | "info": "每两次断点保存间的更新步数。"
258 | }
259 | },
260 | "warmup_steps": {
261 | "en": {
262 | "label": "Warmup steps",
263 | "info": "Number of steps used for warmup."
264 | },
265 | "zh": {
266 | "label": "预热步数",
267 | "info": "学习率预热采用的步数。"
268 | }
269 | },
270 | "compute_type": {
271 | "en": {
272 | "label": "Compute type",
273 | "info": "Whether to use fp16 or bf16 mixed precision training."
274 | },
275 | "zh": {
276 | "label": "计算类型",
277 | "info": "是否启用 FP16 或 BF16 混合精度训练。"
278 | }
279 | },
280 | "lora_tab": {
281 | "en": {
282 | "label": "LoRA configurations"
283 | },
284 | "zh": {
285 | "label": "LoRA 参数设置"
286 | }
287 | },
288 | "lora_rank": {
289 | "en": {
290 | "label": "LoRA rank",
291 | "info": "The rank of LoRA matrices."
292 | },
293 | "zh": {
294 | "label": "LoRA 秩",
295 | "info": "LoRA 矩阵的秩。"
296 | }
297 | },
298 | "lora_dropout": {
299 | "en": {
300 | "label": "LoRA Dropout",
301 | "info": "Dropout ratio of LoRA weights."
302 | },
303 | "zh": {
304 | "label": "LoRA 随机丢弃",
305 | "info": "LoRA 权重随机丢弃的概率。"
306 | }
307 | },
308 | "lora_target": {
309 | "en": {
310 | "label": "LoRA modules (optional)",
311 | "info": "The name(s) of target modules to apply LoRA. Use commas to separate multiple modules."
312 | },
313 | "zh": {
314 | "label": "LoRA 作用层(非必填)",
315 | "info": "应用 LoRA 的线性层名称。使用英文逗号分隔多个名称。"
316 | }
317 | },
318 | "start_btn": {
319 | "en": {
320 | "value": "Start"
321 | },
322 | "zh": {
323 | "value": "开始"
324 | }
325 | },
326 | "stop_btn": {
327 | "en": {
328 | "value": "Abort"
329 | },
330 | "zh": {
331 | "value": "中断"
332 | }
333 | },
334 | "output_dir": {
335 | "en": {
336 | "label": "Checkpoint name",
337 | "info": "Directory to save checkpoint."
338 | },
339 | "zh": {
340 | "label": "断点名称",
341 | "info": "保存模型断点的文件夹名称。"
342 | }
343 | },
344 | "output_box": {
345 | "en": {
346 | "value": "Ready."
347 | },
348 | "zh": {
349 | "value": "准备就绪。"
350 | }
351 | },
352 | "loss_viewer": {
353 | "en": {
354 | "label": "Loss"
355 | },
356 | "zh": {
357 | "label": "损失"
358 | }
359 | },
360 | "predict": {
361 | "en": {
362 | "label": "Save predictions"
363 | },
364 | "zh": {
365 | "label": "保存预测结果"
366 | }
367 | },
368 | "load_btn": {
369 | "en": {
370 | "value": "Load model"
371 | },
372 | "zh": {
373 | "value": "加载模型"
374 | }
375 | },
376 | "unload_btn": {
377 | "en": {
378 | "value": "Unload model"
379 | },
380 | "zh": {
381 | "value": "卸载模型"
382 | }
383 | },
384 | "info_box": {
385 | "en": {
386 | "value": "Model unloaded, please load a model first."
387 | },
388 | "zh": {
389 | "value": "模型未加载,请先加载模型。"
390 | }
391 | },
392 | "prefix": {
393 | "en": {
394 | "placeholder": "System prompt (optional)"
395 | },
396 | "zh": {
397 | "placeholder": "系统提示词(非必填)"
398 | }
399 | },
400 | "query": {
401 | "en": {
402 | "placeholder": "Input..."
403 | },
404 | "zh": {
405 | "placeholder": "输入..."
406 | }
407 | },
408 | "submit_btn": {
409 | "en": {
410 | "value": "Submit"
411 | },
412 | "zh": {
413 | "value": "提交"
414 | }
415 | },
416 | "clear_btn": {
417 | "en": {
418 | "value": "Clear history"
419 | },
420 | "zh": {
421 | "value": "清空历史"
422 | }
423 | },
424 | "max_length": {
425 | "en": {
426 | "label": "Maximum length"
427 | },
428 | "zh": {
429 | "label": "最大长度"
430 | }
431 | },
432 | "max_new_tokens": {
433 | "en": {
434 | "label": "Maximum new tokens"
435 | },
436 | "zh": {
437 | "label": "最大生成长度"
438 | }
439 | },
440 | "top_p": {
441 | "en": {
442 | "label": "Top-p"
443 | },
444 | "zh": {
445 | "label": "Top-p 采样值"
446 | }
447 | },
448 | "temperature": {
449 | "en": {
450 | "label": "Temperature"
451 | },
452 | "zh": {
453 | "label": "温度系数"
454 | }
455 | },
456 | "save_dir": {
457 | "en": {
458 | "label": "Export dir",
459 | "info": "Directory to save exported model."
460 | },
461 | "zh": {
462 | "label": "导出目录",
463 | "info": "保存导出模型的文件夹路径。"
464 | }
465 | },
466 | "max_shard_size": {
467 | "en": {
468 | "label": "Max shard size (GB)",
469 | "info": "The maximum size for a model file."
470 | },
471 | "zh": {
472 | "label": "最大分块大小(GB)",
473 | "info": "模型文件的最大大小。"
474 | }
475 | },
476 | "export_btn": {
477 | "en": {
478 | "value": "Export"
479 | },
480 | "zh": {
481 | "value": "开始导出"
482 | }
483 | }
484 | }
485 |
486 |
487 | ALERTS = {
488 | "err_conflict": {
489 | "en": "A process is in running, please abort it firstly.",
490 | "zh": "任务已存在,请先中断训练。"
491 | },
492 | "err_exists": {
493 | "en": "You have loaded a model, please unload it first.",
494 | "zh": "模型已存在,请先卸载模型。"
495 | },
496 | "err_no_model": {
497 | "en": "Please select a model.",
498 | "zh": "请选择模型。"
499 | },
500 | "err_no_path": {
501 | "en": "Model not found.",
502 | "zh": "模型未找到。"
503 | },
504 | "err_no_dataset": {
505 | "en": "Please choose a dataset.",
506 | "zh": "请选择数据集。"
507 | },
508 | "err_no_checkpoint": {
509 | "en": "Please select a checkpoint.",
510 | "zh": "请选择断点。"
511 | },
512 | "err_no_save_dir": {
513 | "en": "Please provide export dir.",
514 | "zh": "请填写导出目录"
515 | },
516 | "info_aborting": {
517 | "en": "Aborted, wait for terminating...",
518 | "zh": "训练中断,正在等待线程结束……"
519 | },
520 | "info_aborted": {
521 | "en": "Ready.",
522 | "zh": "准备就绪。"
523 | },
524 | "info_finished": {
525 | "en": "Finished.",
526 | "zh": "训练完毕。"
527 | },
528 | "info_loading": {
529 | "en": "Loading model...",
530 | "zh": "加载中……"
531 | },
532 | "info_unloading": {
533 | "en": "Unloading model...",
534 | "zh": "卸载中……"
535 | },
536 | "info_loaded": {
537 | "en": "Model loaded, now you can chat with your model!",
538 | "zh": "模型已加载,可以开始聊天了!"
539 | },
540 | "info_unloaded": {
541 | "en": "Model unloaded.",
542 | "zh": "模型已卸载。"
543 | },
544 | "info_exporting": {
545 | "en": "Exporting model...",
546 | "zh": "正在导出模型……"
547 | },
548 | "info_exported": {
549 | "en": "Model exported.",
550 | "zh": "模型导出完成。"
551 | }
552 | }
553 |
--------------------------------------------------------------------------------
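Each LOCALES entry maps a component name to per-language keyword arguments (label, info, value, placeholder), and ALERTS maps a message key to per-language strings. Manager.gen_label in the next file unpacks a LOCALES entry directly into gr.update, so the keys must be valid update fields of the target component. A minimal sketch of that consumption, with the component name and language chosen arbitrarily:

import gradio as gr

from glmtuner.webui.locales import LOCALES

# The Chinese "dataset" entry becomes keyword arguments for gr.update;
# the result is an update dict carrying {"label": "数据集"}.
print(gr.update(**LOCALES["dataset"]["zh"]))
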
/src/glmtuner/webui/manager.py:
--------------------------------------------------------------------------------
1 | import gradio as gr
2 | from typing import Any, Dict, List
3 | from gradio.components import Component
4 |
5 | from glmtuner.webui.common import get_model_path, list_dataset, load_config
6 | from glmtuner.webui.locales import LOCALES
7 | from glmtuner.webui.utils import get_time
8 |
9 |
10 | class Manager:
11 |
12 | def __init__(self, elem_list: List[Dict[str, Component]]):
13 | self.elem_list = elem_list
14 |
15 | def gen_refresh(self) -> Dict[str, Any]:
16 | refresh_dict = {
17 | "dataset": {"choices": list_dataset()["choices"]},
18 | "output_dir": {"value": get_time()}
19 | }
20 | user_config = load_config()
21 | if user_config["last_model"]:
22 | refresh_dict["model_name"] = {"value": user_config["last_model"]}
23 | refresh_dict["model_path"] = {"value": get_model_path(user_config["last_model"])}
24 |
25 | return refresh_dict
26 |
27 | def gen_label(self, lang: str) -> Dict[Component, dict]:
28 | update_dict = {}
29 | refresh_dict = self.gen_refresh()
30 |
31 | for elems in self.elem_list:
32 | for name, component in elems.items():
33 | update_dict[component] = gr.update(**LOCALES[name][lang], **refresh_dict.get(name, {}))
34 |
35 | return update_dict
36 |
--------------------------------------------------------------------------------
/src/glmtuner/webui/runner.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | import threading
4 | import time
5 | import transformers
6 | from typing import Generator, List, Optional, Tuple
7 |
8 | from glmtuner.extras.callbacks import LogCallback
9 | from glmtuner.extras.logging import LoggerHandler
10 | from glmtuner.extras.misc import torch_gc
11 | from glmtuner.tuner import get_train_args, run_sft
12 | from glmtuner.webui.common import get_model_path, get_save_dir
13 | from glmtuner.webui.locales import ALERTS
14 | from glmtuner.webui.utils import format_info, get_eval_results
15 |
16 |
17 | class Runner:
18 |
19 | def __init__(self):
20 | self.aborted = False
21 | self.running = False
22 |
23 | def set_abort(self):
24 | self.aborted = True
25 | self.running = False
26 |
27 | def initialize(
28 | self, lang: str, model_name: str, dataset: List[str]
29 | ) -> Tuple[str, str, LoggerHandler, LogCallback]:
30 | if self.running:
31 | return None, ALERTS["err_conflict"][lang], None, None
32 |
33 | if not model_name:
34 | return None, ALERTS["err_no_model"][lang], None, None
35 |
36 | model_name_or_path = get_model_path(model_name)
37 | if not model_name_or_path:
38 | return None, ALERTS["err_no_path"][lang], None, None
39 |
40 | if len(dataset) == 0:
41 | return None, ALERTS["err_no_dataset"][lang], None, None
42 |
43 | self.aborted = False
44 | self.running = True
45 |
46 | logger_handler = LoggerHandler()
47 | logger_handler.setLevel(logging.INFO)
48 | logging.root.addHandler(logger_handler)
49 | transformers.logging.add_handler(logger_handler)
50 | trainer_callback = LogCallback(self)
51 |
52 | return model_name_or_path, "", logger_handler, trainer_callback
53 |
54 | def finalize(
55 | self, lang: str, finish_info: Optional[str] = None
56 | ) -> str:
57 | self.running = False
58 | torch_gc()
59 | if self.aborted:
60 | return ALERTS["info_aborted"][lang]
61 | else:
62 | return finish_info if finish_info is not None else ALERTS["info_finished"][lang]
63 |
64 | def run_train(
65 | self,
66 | lang: str,
67 | model_name: str,
68 | checkpoints: List[str],
69 | finetuning_type: str,
70 | quantization_bit: str,
71 | source_prefix: str,
72 | dataset_dir: str,
73 | dataset: List[str],
74 | max_source_length: int,
75 | max_target_length: int,
76 | learning_rate: str,
77 | num_train_epochs: str,
78 | max_samples: str,
79 | batch_size: int,
80 | gradient_accumulation_steps: int,
81 | lr_scheduler_type: str,
82 | max_grad_norm: str,
83 | dev_ratio: float,
84 | logging_steps: int,
85 | save_steps: int,
86 | warmup_steps: int,
87 | compute_type: str,
88 | lora_rank: int,
89 | lora_dropout: float,
90 | lora_target: str,
91 | output_dir: str
92 | ) -> Generator[str, None, None]:
93 | model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
94 | if error:
95 | yield error
96 | return
97 |
98 | if checkpoints:
99 | checkpoint_dir = ",".join(
100 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
101 | )
102 | else:
103 | checkpoint_dir = None
104 |
105 | args = dict(
106 | model_name_or_path=model_name_or_path,
107 | do_train=True,
108 | overwrite_cache=True,
109 | checkpoint_dir=checkpoint_dir,
110 | finetuning_type=finetuning_type,
111 | quantization_bit=int(quantization_bit) if quantization_bit else None,
112 | source_prefix=source_prefix,
113 | dataset_dir=dataset_dir,
114 | dataset=",".join(dataset),
115 | max_source_length=max_source_length,
116 | max_target_length=max_target_length,
117 | learning_rate=float(learning_rate),
118 | num_train_epochs=float(num_train_epochs),
119 | max_samples=int(max_samples),
120 | per_device_train_batch_size=batch_size,
121 | gradient_accumulation_steps=gradient_accumulation_steps,
122 | lr_scheduler_type=lr_scheduler_type,
123 | max_grad_norm=float(max_grad_norm),
124 | logging_steps=logging_steps,
125 | save_steps=save_steps,
126 | warmup_steps=warmup_steps,
127 | fp16=(compute_type == "fp16"),
128 | bf16=(compute_type == "bf16"),
129 | lora_rank=lora_rank,
130 | lora_dropout=lora_dropout,
131 | lora_target=lora_target or "query_key_value",
132 | output_dir=os.path.join(get_save_dir(model_name), finetuning_type, output_dir)
133 | )
134 |
135 |         if dev_ratio > 1e-6:  # hold out a validation split and evaluate it every save_steps
136 | args["dev_ratio"] = dev_ratio
137 | args["evaluation_strategy"] = "steps"
138 | args["eval_steps"] = save_steps
139 | args["load_best_model_at_end"] = True
140 |
141 | model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
142 |
143 | run_args = dict(
144 | model_args=model_args,
145 | data_args=data_args,
146 | training_args=training_args,
147 | finetuning_args=finetuning_args,
148 | callbacks=[trainer_callback]
149 | )
150 | thread = threading.Thread(target=run_sft, kwargs=run_args)
151 | thread.start()
152 |
153 |         while thread.is_alive():  # poll the training thread every second and stream its log to the UI
154 | time.sleep(1)
155 | if self.aborted:
156 | yield ALERTS["info_aborting"][lang]
157 | else:
158 | yield format_info(logger_handler.log, trainer_callback.tracker)
159 |
160 | yield self.finalize(lang)
161 |
162 | def run_eval(
163 | self,
164 | lang: str,
165 | model_name: str,
166 | checkpoints: List[str],
167 | finetuning_type: str,
168 | quantization_bit: str,
169 | source_prefix: str,
170 | dataset_dir: str,
171 | dataset: List[str],
172 | max_source_length: int,
173 | max_target_length: int,
174 | max_samples: str,
175 | batch_size: int,
176 | predict: bool
177 | ) -> Generator[str, None, None]:
178 | model_name_or_path, error, logger_handler, trainer_callback = self.initialize(lang, model_name, dataset)
179 | if error:
180 | yield error
181 | return
182 |
183 | if checkpoints:
184 | checkpoint_dir = ",".join(
185 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
186 | )
187 | output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_" + "_".join(checkpoints))
188 | else:
189 | checkpoint_dir = None
190 | output_dir = os.path.join(get_save_dir(model_name), finetuning_type, "eval_base")
191 |
192 | args = dict(
193 | model_name_or_path=model_name_or_path,
194 | do_eval=True,
195 | overwrite_cache=True,
196 | predict_with_generate=True,
197 | checkpoint_dir=checkpoint_dir,
198 | finetuning_type=finetuning_type,
199 | quantization_bit=int(quantization_bit) if quantization_bit else None,
200 | source_prefix=source_prefix,
201 | dataset_dir=dataset_dir,
202 | dataset=",".join(dataset),
203 | max_source_length=max_source_length,
204 | max_target_length=max_target_length,
205 | max_samples=int(max_samples),
206 | per_device_eval_batch_size=batch_size,
207 | output_dir=output_dir
208 | )
209 |
210 | if predict:
211 | args.pop("do_eval", None)
212 | args["do_predict"] = True
213 |
214 | model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
215 |
216 | run_args = dict(
217 | model_args=model_args,
218 | data_args=data_args,
219 | training_args=training_args,
220 | finetuning_args=finetuning_args,
221 | callbacks=[trainer_callback]
222 | )
223 | thread = threading.Thread(target=run_sft, kwargs=run_args)
224 | thread.start()
225 |
226 | while thread.is_alive():
227 | time.sleep(1)
228 | if self.aborted:
229 | yield ALERTS["info_aborting"][lang]
230 | else:
231 | yield format_info(logger_handler.log, trainer_callback.tracker)
232 |
233 | yield self.finalize(lang, get_eval_results(os.path.join(output_dir, "all_results.json")))
234 |
--------------------------------------------------------------------------------
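run_train and run_eval are generator functions: they launch run_sft in a background thread and yield a status string roughly once per second, which Gradio streams into the output textbox. A minimal, self-contained sketch of that pattern (the job and its shared state dict are illustrative, not from the repo):

import threading
import time

def long_job(state):
    # stand-in for run_sft: update shared state as the work progresses
    for i in range(5):
        time.sleep(2)
        state["progress"] = i + 1

def run():
    state = {"progress": 0}
    thread = threading.Thread(target=long_job, args=(state,))
    thread.start()
    while thread.is_alive():  # poll once per second, like run_train does
        time.sleep(1)
        yield "Running step {}/5".format(state["progress"])
    yield "Finished."

for status in run():
    print(status)
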
/src/glmtuner/webui/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import gradio as gr
4 | import matplotlib.figure
5 | import matplotlib.pyplot as plt
6 | from typing import Any, Dict, Generator, List, Tuple
7 | from datetime import datetime
8 |
9 | from glmtuner.extras.ploting import smooth
10 | from glmtuner.tuner import get_infer_args, load_model_and_tokenizer
11 | from glmtuner.webui.common import get_model_path, get_save_dir, DATA_CONFIG
12 | from glmtuner.webui.locales import ALERTS
13 |
14 |
15 | def format_info(log: str, tracker: dict) -> str:
16 | info = log
17 | if "current_steps" in tracker:
18 | info += "Running **{:d}/{:d}**: {} < {}\n".format(
19 | tracker["current_steps"], tracker["total_steps"], tracker["elapsed_time"], tracker["remaining_time"]
20 | )
21 | return info
22 |
23 |
24 | def get_time() -> str:
25 | return datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
26 |
27 |
28 | def can_preview(dataset_dir: str, dataset: list) -> Dict[str, Any]:
29 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
30 | dataset_info = json.load(f)
31 | if (
32 | len(dataset) > 0
33 | and "file_name" in dataset_info[dataset[0]]
34 | and os.path.isfile(os.path.join(dataset_dir, dataset_info[dataset[0]]["file_name"]))
35 | ):
36 | return gr.update(interactive=True)
37 | else:
38 | return gr.update(interactive=False)
39 |
40 |
41 | def get_preview(dataset_dir: str, dataset: list) -> Tuple[int, list, Dict[str, Any]]:
42 | with open(os.path.join(dataset_dir, DATA_CONFIG), "r", encoding="utf-8") as f:
43 | dataset_info = json.load(f)
44 | data_file = dataset_info[dataset[0]]["file_name"]
45 | with open(os.path.join(dataset_dir, data_file), "r", encoding="utf-8") as f:
46 | data = json.load(f)
47 | return len(data), data[:2], gr.update(visible=True)
48 |
49 |
50 | def can_quantize(finetuning_type: str) -> Dict[str, Any]:
51 | if finetuning_type not in ["p_tuning", "lora"]:
52 | return gr.update(value="", interactive=False)
53 | else:
54 | return gr.update(interactive=True)
55 |
56 |
57 | def get_eval_results(path: os.PathLike) -> str:
58 | with open(path, "r", encoding="utf-8") as f:
59 | result = json.dumps(json.load(f), indent=4)
60 | return "```json\n{}\n```\n".format(result)
61 |
62 |
63 | def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure:
64 | log_file = os.path.join(get_save_dir(base_model), finetuning_type, output_dir, "trainer_log.jsonl")
65 | if not os.path.isfile(log_file):
66 | return None
67 |
68 | plt.close("all")
69 | fig = plt.figure()
70 | ax = fig.add_subplot(111)
71 | steps, losses = [], []
72 | with open(log_file, "r", encoding="utf-8") as f:
73 | for line in f:
74 | log_info = json.loads(line)
75 | if log_info.get("loss", None):
76 | steps.append(log_info["current_steps"])
77 | losses.append(log_info["loss"])
78 |
79 | if len(losses) == 0:
80 | return None
81 |
82 | ax.plot(steps, losses, alpha=0.4, label="original")
83 | ax.plot(steps, smooth(losses), label="smoothed")
84 | ax.legend()
85 | ax.set_xlabel("step")
86 | ax.set_ylabel("loss")
87 | return fig
88 |
89 |
90 | def export_model(
91 | lang: str, model_name: str, checkpoints: List[str], finetuning_type: str, max_shard_size: int, save_dir: str
92 | ) -> Generator[str, None, None]:
93 | if not model_name:
94 | yield ALERTS["err_no_model"][lang]
95 | return
96 |
97 | model_name_or_path = get_model_path(model_name)
98 | if not model_name_or_path:
99 | yield ALERTS["err_no_path"][lang]
100 | return
101 |
102 | if not checkpoints:
103 | yield ALERTS["err_no_checkpoint"][lang]
104 | return
105 |
106 | checkpoint_dir = ",".join(
107 | [os.path.join(get_save_dir(model_name), finetuning_type, checkpoint) for checkpoint in checkpoints]
108 | )
109 |
110 | if not save_dir:
111 | yield ALERTS["err_no_save_dir"][lang]
112 | return
113 |
114 | args = dict(
115 | model_name_or_path=model_name_or_path,
116 | checkpoint_dir=checkpoint_dir,
117 | finetuning_type=finetuning_type
118 | )
119 |
120 | yield ALERTS["info_exporting"][lang]
121 | model_args, _, finetuning_args, _ = get_infer_args(args)
122 | model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
123 | model.save_pretrained(save_dir, max_shard_size=str(max_shard_size)+"GB")
124 | tokenizer.save_pretrained(save_dir)
125 | yield ALERTS["info_exported"][lang]
126 |
--------------------------------------------------------------------------------
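gen_plot reads trainer_log.jsonl from the checkpoint directory and returns a matplotlib Figure, or None when the log is missing or contains no loss values, while export_model is a generator that yields alert strings as it loads the checkpoint and saves the model to the export directory. A minimal sketch of plotting a finished run, with the model name and checkpoint name chosen for illustration:

from glmtuner.webui.utils import gen_plot

# Assumes <save_dir>/lora/<output_dir>/trainer_log.jsonl exists for this run.
fig = gen_plot("chatglm2-6b", "lora", "2023-08-01-12-00-00")
if fig is not None:
    fig.savefig("loss_curve.png")
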
/src/train_bash.py:
--------------------------------------------------------------------------------
1 | from glmtuner.tuner import get_train_args, run_sft, run_rm, run_ppo
2 |
3 |
4 | def main():
5 | model_args, data_args, training_args, finetuning_args, general_args = get_train_args()
6 |
7 | if general_args.stage == "sft":
8 | run_sft(model_args, data_args, training_args, finetuning_args)
9 | elif general_args.stage == "rm":
10 | run_rm(model_args, data_args, training_args, finetuning_args)
11 | elif general_args.stage == "ppo":
12 | run_ppo(model_args, data_args, training_args, finetuning_args)
13 |
14 |
15 | def _mp_fn(index):
16 | # For xla_spawn (TPUs)
17 | main()
18 |
19 |
20 | if __name__ == "__main__":
21 | main()
22 |
--------------------------------------------------------------------------------
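train_bash.py dispatches to run_sft, run_rm, or run_ppo based on the parsed stage argument. get_train_args also accepts a plain dict of hyperparameters, which is how runner.py drives training from the web UI; a minimal programmatic sketch of an SFT run along the same lines (the model identifier, dataset name, and paths are illustrative):

from glmtuner.tuner import get_train_args, run_sft

args = dict(
    model_name_or_path="THUDM/chatglm2-6b",  # illustrative model path or identifier
    do_train=True,
    dataset="example",                       # a key defined in data/dataset_info.json
    dataset_dir="data",
    finetuning_type="lora",
    lora_target="query_key_value",
    learning_rate=5e-5,
    num_train_epochs=3.0,
    fp16=True,
    output_dir="checkpoints/sft-example",
)
model_args, data_args, training_args, finetuning_args, _ = get_train_args(args)
run_sft(model_args, data_args, training_args, finetuning_args)
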
/src/web_demo.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements a browser-based user interface for ChatGLM models fine-tuned with PEFT.
3 | # Usage: python web_demo.py --checkpoint_dir path_to_checkpoint [--quantization_bit 4]
4 |
5 | import gradio as gr
6 | from transformers.utils.versions import require_version
7 |
8 | from glmtuner.tuner import get_infer_args
9 | from glmtuner.webui.chat import WebChatModel
10 | from glmtuner.webui.components.chatbot import create_chat_box
11 | from glmtuner.webui.manager import Manager
12 |
13 |
14 | require_version("gradio>=3.36.0", "To fix: pip install gradio>=3.36.0")
15 |
16 |
17 | def main():
18 | chat_model = WebChatModel(*get_infer_args())
19 |
20 | with gr.Blocks(title="Web Demo") as demo:
21 | lang = gr.Dropdown(choices=["en", "zh"], value="en")
22 |
23 | _, _, _, chat_elems = create_chat_box(chat_model, visible=True)
24 |
25 | manager = Manager([{"lang": lang}, chat_elems])
26 |
27 | demo.load(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
28 |
29 | lang.change(manager.gen_label, [lang], [lang] + list(chat_elems.values()))
30 |
31 | demo.queue()
32 | demo.launch(server_name="0.0.0.0", server_port=7860, share=False, inbrowser=True)
33 |
34 |
35 | if __name__ == "__main__":
36 | main()
37 |
--------------------------------------------------------------------------------