├── .gitignore ├── LICENSE ├── README.MD ├── args.MD ├── assets ├── 1.png └── react_prompt.md ├── config ├── __init__.py ├── colossalai_strategy.yaml ├── constant_map.py ├── deepspeed.yaml ├── deepspeed_offload.yaml ├── global.yaml ├── main.py ├── petl.yaml ├── train_ac.yaml ├── train_cl.yaml ├── train_hf.yaml └── train_pl.yaml ├── data ├── finetune_train_conversations.json ├── finetune_train_paragraph.json └── make_data_example.py ├── data_processer.py ├── data_tools.py ├── data_utils.py ├── infer ├── __init__.py ├── api_lora_demo.py ├── evaluate.py ├── infer.py ├── infer_finetuning.py ├── infer_lora_finetuning.py └── infer_muti_lora_finetuning.py ├── requirements.txt ├── scripts ├── train_full.sh ├── train_lora.sh ├── train_lora_int4.sh ├── train_lora_int8.sh └── train_ptv2.sh ├── train.py └── training ├── __init__.py ├── train_ac.py ├── train_cl.py ├── train_hf.py └── train_pl.py /.gitignore: -------------------------------------------------------------------------------- 1 | /.idea 2 | /__pycache__ 3 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.MD: -------------------------------------------------------------------------------- 1 | ## statement 2 | - [deep_training](https://github.com/ssbuild/deep_training) 3 | 4 | ```text 5 | 2024-04-22 简化 6 | 2023-12-02 update qwen model 1.8b 7b 12b 72b 7 | 2023-10-09 support accelerator trainer 8 | 2023-10-07 support colossalai trainer 9 | 2023-09-26 support transformers trainer 10 | 2023-09-25 0.2.4 support qwen-7b 新版 和 qwen-14b , 旧版不再支持,旧版可以安装 deep_training <= 0.2.3 11 | support transformers trainer 12 | 2023-08-11 aigc-zoo 0.1.17.post0 update config , 更新下官方权重配置文件 13 | dev 分支加一些新功能和想法 如果求稳定,请使用 stable分支 14 | 15 | ``` 16 | 17 | 18 | 19 | ## install 20 | - pip install -U -r requirements.txt 21 | - 如果无法安装 , 可以切换官方源 pip install -i https://pypi.org/simple -U -r requirements.txt 22 | 23 | ```text 24 | 25 | # flash-attention 对显卡算力要求 7.5 以上 , 下面可选安装 , 如果显卡不支持可以不安装。 26 | git clone https://github.com/Dao-AILab/flash-attention 27 | cd flash-attention && pip install . 28 | pip install csrc/layer_norm 29 | pip install csrc/rotary 30 | ``` 31 | 32 | 33 | ## weight 34 | - [Qwen2.5-1.5B-Instruct](https://huggingface.co/Qwen/Qwen2.5-1.5B-Instruct) 35 | 36 | - DeepSeek-R1-Distill-Qwen-1.5B (base: Qwen2.5-Math-1.5B) 🤗 HuggingFace 37 | - DeepSeek-R1-Distill-Qwen-7B (base: Qwen2.5-Math-7B) 🤗 HuggingFace 38 | - DeepSeek-R1-Distill-Qwen-14B (base: Qwen2.5-14B) 🤗 HuggingFace 39 | - DeepSeek-R1-Distill-Qwen-32B (base: Qwen2.5-32B) 🤗 HuggingFace 40 | 41 | ## data sample 42 | - [open_data 不定时开放新数据集](https://github.com/ssbuild/open_data) 43 | - [react_prompt](assets/react_prompt.md) 44 | 45 | ```text 46 | 数据示例 47 | 例子依次分别是 工具,对话,对话,对话 48 | 数据构建sample 参考 data/make_data_example.py 49 | 数组组成 50 | role: 可选字段(str) , 标识 q 字段的角色 , one of user, system, observation ; 取 system 时表示该条为 system prompt , 此时 a 字段为空 51 | q: 问题 52 | a: 回答 53 | 54 | 注意事项: 55 | a字段:对于普通对话,a即为回答。 56 | 细节可以参考 assets/react_prompt.md 57 | ``` 58 | 59 | ```json 60 | {"id": 1, "paragraph": [{"role": "system", "q": "You are a helpful assistant.", "a": ""}, {"role": "user", "q": "Answer the following questions as best you can. You have access to the following tools:\n\nquark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{\"name\": \"search_query\", \"description\": \"搜索关键词或短语\", \"required\": true, \"schema\": {\"type\": \"string\"}}] Format the arguments as a JSON object.\n\nimage_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{\"name\": \"query\", \"description\": \"中文关键词,描述了希望图像具有什么内容\", \"required\": true, \"schema\": {\"type\": \"string\"}}] Format the arguments as a JSON object.\n\nUse the following format:\n\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [quark_search,image_gen]\nAction Input: the input to the action\nObservation: the result of the action\n... 
(this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n\nBegin!\n\nQuestion: 现在给我画个五彩斑斓的黑", "a": "\nThought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。\nAction: image_gen\nAction Input: {\"query\": \"五彩斑斓的黑\"}\n "}, {"role": "observation", "q": "Observation: \n{\"status_code\": 200, \"request_id\": \"3d894da2-0e26-9b7c-bd90-102e5250ae03\", \"code\": null, \"message\": \"\", \"output\": {\"task_id\": \"2befaa09-a8b3-4740-ada9-4d00c2758b05\", \"task_status\": \"SUCCEEDED\", \"results\": [{\"url\": \"https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png\"}], \"task_metrics\": {\"TOTAL\": 1, \"SUCCEEDED\": 1, \"FAILED\": 0}}, \"usage\": {\"image_count\": 1}}\n", "a": "\nThought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。\nFinal Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。\n "}]} 61 | {"id": 2, "paragraph": [{"role": "system", "q": "You are a helpful assistant.", "a": ""}, {"q": "从南京到上海的路线", "a": "你好,南京到上海的路线如下:\n1. 南京到上海,可以乘坐南京地铁1号线,在南京站乘坐轨道交通1号线。\n2. 南京到浦东机场,可以搭乘上海地铁1号,在陆家嘴站乘坐地铁1线,在浦东国际机场站乘坐机场快线,前往上海浦东国际机场。\n3. 上海到南京,可以换乘上海地铁2号线,从南京站换乘地铁2线,再从南京南站换乘地铁1路,然后到达上海站"}]} 62 | {"id": 3, "paragraph": [{"role": "system", "q": "You are a helpful assistant.", "a": ""}, {"q": "写一个诗歌,关于冬天", "a": "冬夜寂静冷,\n云在天边飘,\n冰封白雪上, \n寒冷像一场雪。\n \n雪花融化成冰,\n像那雪花飘洒,\n在寒冷的冬天,\n感受春天的喜悦。\n 冬日里,\n风雪渐消,\n一片寂静,\n把快乐和温暖带回家。"}]} 63 | {"id": 4, "paragraph": [{"role": "system", "q": "You are a helpful assistant.", "a": ""}, {"q": "晚上睡不着应该怎么办", "a": "晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:\n\n1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。\n2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。\n3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。\n4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。\n5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。\n6. 尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。\n\n如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。"}]} 64 | ``` 65 | 或者 66 | ```json 67 | {"id": 1, "conversations": [{"from": "system", "value": "You are a helpful assistant."}, {"from": "user", "value": "Answer the following questions as best you can. You have access to the following tools:\n\nquark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{\"name\": \"search_query\", \"description\": \"搜索关键词或短语\", \"required\": true, \"schema\": {\"type\": \"string\"}}] Format the arguments as a JSON object.\n\nimage_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{\"name\": \"query\", \"description\": \"中文关键词,描述了希望图像具有什么内容\", \"required\": true, \"schema\": {\"type\": \"string\"}}] Format the arguments as a JSON object.\n\nUse the following format:\n\nQuestion: the input question you must answer\nThought: you should always think about what to do\nAction: the action to take, should be one of [quark_search,image_gen]\nAction Input: the input to the action\nObservation: the result of the action\n... 
(this Thought/Action/Action Input/Observation can be repeated zero or more times)\nThought: I now know the final answer\nFinal Answer: the final answer to the original input question\n\nBegin!\n\nQuestion: 现在给我画个五彩斑斓的黑"}, {"from": "assistant", "value": "\nThought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。\nAction: image_gen\nAction Input: {\"query\": \"五彩斑斓的黑\"}\n "}, {"from": "observation", "value": "Observation: \n{\"status_code\": 200, \"request_id\": \"3d894da2-0e26-9b7c-bd90-102e5250ae03\", \"code\": null, \"message\": \"\", \"output\": {\"task_id\": \"2befaa09-a8b3-4740-ada9-4d00c2758b05\", \"task_status\": \"SUCCEEDED\", \"results\": [{\"url\": \"https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png\"}], \"task_metrics\": {\"TOTAL\": 1, \"SUCCEEDED\": 1, \"FAILED\": 0}}, \"usage\": {\"image_count\": 1}}\n"}, {"from": "assistant", "value": "\nThought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。\nFinal Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。\n "}]} 68 | {"id": 2, "conversations": [{"from": "system", "value": "You are a helpful assistant."}, {"from": "user", "value": "从南京到上海的路线"}, {"from": "assistant", "value": "你好,南京到上海的路线如下:\n1. 南京到上海,可以乘坐南京地铁1号线,在南京站乘坐轨道交通1号线。\n2. 南京到浦东机场,可以搭乘上海地铁1号,在陆家嘴站乘坐地铁1线,在浦东国际机场站乘坐机场快线,前往上海浦东国际机场。\n3. 上海到南京,可以换乘上海地铁2号线,从南京站换乘地铁2线,再从南京南站换乘地铁1路,然后到达上海站"}]} 69 | {"id": 3, "conversations": [{"from": "system", "value": "You are a helpful assistant."}, {"from": "user", "value": "写一个诗歌,关于冬天"}, {"from": "assistant", "value": "冬夜寂静冷,\n云在天边飘,\n冰封白雪上, \n寒冷像一场雪。\n \n雪花融化成冰,\n像那雪花飘洒,\n在寒冷的冬天,\n感受春天的喜悦。\n 冬日里,\n风雪渐消,\n一片寂静,\n把快乐和温暖带回家。"}]} 70 | {"id": 4, "conversations": [{"from": "system", "value": "You are a helpful assistant."}, {"from": "user", "value": "晚上睡不着应该怎么办"}, {"from": "assistant", "value": "晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:\n\n1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。\n2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。\n3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。\n4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。\n5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。\n6. 
尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。\n\n如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。"}]} 71 | ``` 72 | 73 | 74 | 75 | ## infer 76 | # infer.py 推理预训练模型 77 | # infer_finetuning.py 推理微调模型 78 | # infer_lora_finetuning.py 推理lora微调模型 79 | python infer.py 80 | 81 | 82 | | **量化等级** | **最低 GPU 显存** | 83 | | -------------- | ----------------- | 84 | | FP16(无量化) | 13 GB | 85 | | INT8 | 10 GB | 86 | | INT4 | 6 GB | 87 | 88 | 89 | 90 | ![inference](assets/1.png) 91 | 92 | 93 | 94 | 95 | ## training 96 | ```text 97 | # 制作数据 98 | cd scripts 99 | bash train_full.sh -m dataset 100 | or 101 | bash train_lora.sh -m dataset 102 | or 103 | bash train_ptv2.sh -m dataset 104 | 105 | 注: num_process_worker 为多进程制作数据 , 如果数据量较大 , 适当调大至cpu数量 106 | dataHelper.make_dataset_with_args(data_args.train_file,mixed_data=False, shuffle=True,mode='train',num_process_worker=0) 107 | 108 | # 全参数训练 109 | bash train_full.sh -m train 110 | 111 | # lora adalora ia3 112 | bash train_lora.sh -m train 113 | 114 | # ptv2 115 | bash train_ptv2.sh -m train 116 | ``` 117 | 118 | ## aigc-serving 119 | 120 | 部署qwen之后 , 可测试工具函数 121 | - [quad_calculator.py](https://github.com/ssbuild/aigc_serving/blob/main/tests/quad_calculator.py) 122 | 123 | 124 | 125 | ## 训练参数 126 | [训练参数](args.MD) 127 | 128 | 129 | 130 | 131 | ## 友情链接 132 | 133 | - [pytorch-task-example](https://github.com/ssbuild/pytorch-task-example) 134 | - [moss_finetuning](https://github.com/ssbuild/chatmoss_finetuning) 135 | - [chatglm_finetuning](https://github.com/ssbuild/chatglm_finetuning) 136 | - [t5_finetuning](https://github.com/ssbuild/t5_finetuning) 137 | - [llm_finetuning](https://github.com/ssbuild/llm_finetuning) 138 | - [llm_rlhf](https://github.com/ssbuild/llm_rlhf) 139 | - [chatglm_rlhf](https://github.com/ssbuild/chatglm_rlhf) 140 | - [t5_rlhf](https://github.com/ssbuild/t5_rlhf) 141 | - [rwkv_finetuning](https://github.com/ssbuild/rwkv_finetuning) 142 | - [baichuan_finetuning](https://github.com/ssbuild/baichuan_finetuning) 143 | - [xverse_finetuning](https://github.com/ssbuild/xverse_finetuning) 144 | - [internlm_finetuning](https://github.com/ssbuild/internlm_finetuning) 145 | - [qwen_finetuning](https://github.com/ssbuild/qwen_finetuning) 146 | - [skywork_finetuning](https://github.com/ssbuild/skywork_finetuning) 147 | - [bluelm_finetuning](https://github.com/ssbuild/bluelm_finetuning) 148 | - [yi_finetuning](https://github.com/ssbuild/yi_finetuning) 149 | 150 | 151 | ## 152 | 纯粹而干净的代码 153 | 154 | 155 | 156 | ## Reference 157 | https://github.com/QwenLM/Qwen-7B 158 | 159 | 160 | 161 | 162 | ## Star History 163 | 164 | [![Star History Chart](https://api.star-history.com/svg?repos=ssbuild/qwen_finetuning&type=Date)](https://star-history.com/#ssbuild/qwen_finetuning&Date) 165 | 166 | -------------------------------------------------------------------------------- /args.MD: -------------------------------------------------------------------------------- 1 | 2 | ## 切换训练模式配置 3 | 修改 config/main.py 4 | # 模块配置, 默认启用lora 5 | enable_deepspeed = False 6 | enable_ptv2 = False 7 | enable_lora = True 8 | load_in_bit = 0 # 4 load_in_4bit, 8 load_in_8bit other 0 9 | 10 | ## optimizer 11 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 12 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 13 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 14 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 15 | 16 | ## scheduler 17 | # 
linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial, 18 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 19 | 20 | ### 单机多卡 21 | ```text 22 | 可见的前两块卡 23 | config_args = { 24 | 'devices': 2, 25 | } 26 | 27 | # 第一块 和 第三块卡 28 | config_args = { 29 | 'devices': [0,2], 30 | } 31 | ``` 32 | 33 | ### 多机多卡训练 34 | ```text 35 | 例子 3个机器 每个机器 4个卡 36 | 修改train.py Trainer num_nodes = 3 37 | MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 38 | MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 39 | MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 40 | ``` 41 | 42 | 43 | ### 超大数据集 44 | 修改data_utils.py "data_backend": "lmdb" 45 | 46 | ## 精度训练 47 | Trainer.precision = '16' # 半精度训练 "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed" 48 | 49 | 50 | 51 | ### lora finetuning 52 | ```text 53 | global_args = { 54 | "load_in_8bit": False, # lora 如果显卡支持int8 可以开启 , 需安装依赖 pip install bitsandbytes 55 | "num_layers_freeze": -1, # 非lora,非p-tuning 模式 , <= config.json num_layers 56 | "num_layers": -1, # 是否使用骨干网络的全部层数 最大1-28, -1 表示全层, 否则只用只用N层 57 | } 58 | lora_info_args = { 59 | 'with_lora': True, # 是否启用lora模块 60 | 'r': 8, 61 | 'target_modules': ['query_key_value'], 62 | 'target_dtype': None, 63 | 'lora_alpha': 32, 64 | 'lora_dropout': 0.1, 65 | 'bias': 'none', # Bias type for Lora. Can be 'none', 'all' or 'lora_only'" 66 | 'modules_to_save' : None, # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 67 | } 68 | ``` 69 | 70 | 71 | ### ptuning v2 72 | 73 | 74 | ```text 75 | global_args = { 76 | "load_in_8bit": False, # lora 如果显卡支持int8 可以开启 , 需安装依赖 pip install bitsandbytes 77 | "num_layers_freeze": -1, # 非lora,非p-tuning 模式 , <= config.json num_layers 78 | "num_layers": -1, # 是否使用骨干网络的全部层数 最大1-28, -1 表示全层, 否则只用只用N层 79 | } 80 | 81 | ``` 82 | 83 | 84 | ### freeze 85 | 86 | 87 | ```text 88 | global_args = { 89 | "load_in_8bit": False, # lora 如果显卡支持int8 可以开启 , 需安装依赖 pip install bitsandbytes 90 | "num_layers_freeze": 14, # 非lora,非p-tuning 模式 , <= config.json num_layers 91 | "num_layers": -1, # 是否使用骨干网络的全部层数 最大1-28, -1 表示全层, 否则只用只用N层 92 | } 93 | 94 | 95 | 96 | ``` 97 | 98 | ## 全参数微调 99 | 100 | ```text 101 | global_args = { 102 | "load_in_8bit": False, # lora 如果显卡支持int8 可以开启 , 需安装依赖 pip install bitsandbytes 103 | "num_layers_freeze": -1, # 非lora,非p-tuning 模式 , <= config.json num_layers 104 | "num_layers": -1, # 是否使用骨干网络的全部层数 最大1-28, -1 表示全层, 否则只用只用N层 105 | } 106 | lora_info_args = { 107 | 'with_lora': False, # 是否启用lora模块 108 | ... 109 | } 110 | adalora_info_args = { 111 | 'with_lora': False, # 是否启用lora模块 112 | ... 
113 | } 114 | ``` 115 | 116 | 117 | -------------------------------------------------------------------------------- /assets/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ssbuild/qwen_finetuning/35ad9c25dab59c750cfc01d78b2defd27f06a092/assets/1.png -------------------------------------------------------------------------------- /assets/react_prompt.md: -------------------------------------------------------------------------------- 1 | # ReAct Prompting 示例 2 | 3 | 本文档将介绍如何用 ReAct Prompting 技术命令千问使用工具。 4 | 5 | 本文档主要基本的原理概念介绍,并在文末附上了一些具体实现相关的 FAQ,但不含被调用插件的实际实现。如果您更喜欢一边调试实际可执行的代码、一边理解原理,可以转而阅读整合了 LangChain 常用工具的这个 [ipython notebook](https://github.com/QwenLM/Qwen-7B/blob/main/examples/langchain_tooluse.ipynb)。 6 | 7 | 此外,本文档和前述的 ipython notebook 都仅介绍单轮对话的实现。如果想了解多轮对话下的实现,可参见 [react_demo.py](https://github.com/QwenLM/Qwen-7B/blob/main/examples/react_demo.py)。 8 | 9 | ## 准备工作一:样例问题、样例工具 10 | 11 | 假设我们有如下的一个适合用工具处理的 query,以及有夸克搜索、通义万相文生图这两个工具: 12 | 13 | ```py 14 | query = '现在给我画个五彩斑斓的黑。' 15 | 16 | TOOLS = [ 17 | { 18 | 'name_for_human': 19 | '夸克搜索', 20 | 'name_for_model': 21 | 'quark_search', 22 | 'description_for_model': 23 | '夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。', 24 | 'parameters': [{ 25 | 'name': 'search_query', 26 | 'description': '搜索关键词或短语', 27 | 'required': True, 28 | 'schema': { 29 | 'type': 'string' 30 | }, 31 | }], 32 | }, 33 | { 34 | 'name_for_human': 35 | '通义万相', 36 | 'name_for_model': 37 | 'image_gen', 38 | 'description_for_model': 39 | '通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL', 40 | 'parameters': [{ 41 | 'name': 'query', 42 | 'description': '中文关键词,描述了希望图像具有什么内容', 43 | 'required': True, 44 | 'schema': { 45 | 'type': 'string' 46 | }, 47 | }], 48 | }, 49 | ] 50 | ``` 51 | 52 | ## 准备工作二:ReAct 模版 53 | 54 | 我们将使用如下的 ReAct prompt 模版来激发千问使用工具的能力。 55 | 56 | ```py 57 | TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object.""" 58 | 59 | REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools: 60 | 61 | {tool_descs} 62 | 63 | Use the following format: 64 | 65 | Question: the input question you must answer 66 | Thought: you should always think about what to do 67 | Action: the action to take, should be one of [{tool_names}] 68 | Action Input: the input to the action 69 | Observation: the result of the action 70 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 71 | Thought: I now know the final answer 72 | Final Answer: the final answer to the original input question 73 | 74 | Begin! 
75 | 76 | Question: {query}""" 77 | ``` 78 | 79 | ## 步骤一:让千问判断要调用什么工具、生成工具入参 80 | 81 | 首先我们需要根据 ReAct prompt 模版、query、工具的信息构建 prompt: 82 | 83 | ```py 84 | tool_descs = [] 85 | tool_names = [] 86 | for info in TOOLS: 87 | tool_descs.append( 88 | TOOL_DESC.format( 89 | name_for_model=info['name_for_model'], 90 | name_for_human=info['name_for_human'], 91 | description_for_model=info['description_for_model'], 92 | parameters=json.dumps( 93 | info['parameters'], ensure_ascii=False), 94 | ) 95 | ) 96 | tool_names.append(info['name_for_model']) 97 | tool_descs = '\n\n'.join(tool_descs) 98 | tool_names = ','.join(tool_names) 99 | 100 | prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query) 101 | print(prompt) 102 | ``` 103 | 104 | 打印出来的、构建好的 prompt 如下: 105 | 106 | ``` 107 | Answer the following questions as best you can. You have access to the following tools: 108 | 109 | quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object. 110 | 111 | image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object. 112 | 113 | Use the following format: 114 | 115 | Question: the input question you must answer 116 | Thought: you should always think about what to do 117 | Action: the action to take, should be one of [quark_search,image_gen] 118 | Action Input: the input to the action 119 | Observation: the result of the action 120 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 121 | Thought: I now know the final answer 122 | Final Answer: the final answer to the original input question 123 | 124 | Begin! 125 | 126 | Question: 现在给我画个五彩斑斓的黑。 127 | ``` 128 | 129 | 将这个 prompt 送入千问,并记得设置 "Observation" 为 stop word (见本文末尾的 FAQ)—— 即让千问在预测到要生成的下一个词是 "Observation" 时马上停止生成 —— 则千问在得到这个 prompt 后会生成如下的结果: 130 | 131 | ![](../assets/react_tutorial_001.png) 132 | 133 | ``` 134 | Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。 135 | Action: image_gen 136 | Action Input: {"query": "五彩斑斓的黑"} 137 | ``` 138 | 139 | 在得到这个结果后,调用千问的开发者可以通过简单的解析提取出 `{"query": "五彩斑斓的黑"}` 并基于这个解析结果调用文生图服务 —— 这部分逻辑需要开发者自行实现,或者也可以使用千问商业版,商业版本将内部集成相关逻辑。 140 | 141 | ## 步骤二:让千问根据插件返回结果继续作答 142 | 143 | 让我们假设文生图插件返回了如下结果: 144 | 145 | ``` 146 | {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}} 147 | ``` 148 | 149 | ![](../assets/wanx_colorful_black.png) 150 | 151 | 接下来,我们可以将之前首次请求千问时用的 prompt 和 调用文生图插件的结果拼接成如下的新 prompt: 152 | 153 | ``` 154 | Answer the following questions as best you can. You have access to the following tools: 155 | 156 | quark_search: Call this tool to interact with the 夸克搜索 API. What is the 夸克搜索 API useful for? 
夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。 Parameters: [{"name": "search_query", "description": "搜索关键词或短语", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object. 157 | 158 | image_gen: Call this tool to interact with the 通义万相 API. What is the 通义万相 API useful for? 通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL Parameters: [{"name": "query", "description": "中文关键词,描述了希望图像具有什么内容", "required": true, "schema": {"type": "string"}}] Format the arguments as a JSON object. 159 | 160 | Use the following format: 161 | 162 | Question: the input question you must answer 163 | Thought: you should always think about what to do 164 | Action: the action to take, should be one of [quark_search,image_gen] 165 | Action Input: the input to the action 166 | Observation: the result of the action 167 | ... (this Thought/Action/Action Input/Observation can be repeated zero or more times) 168 | Thought: I now know the final answer 169 | Final Answer: the final answer to the original input question 170 | 171 | Begin! 172 | 173 | Question: 现在给我画个五彩斑斓的黑。 174 | Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。 175 | Action: image_gen 176 | Action Input: {"query": "五彩斑斓的黑"} 177 | Observation: {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": "https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}} 178 | ``` 179 | 180 | 用这个新的拼接了文生图插件结果的新 prompt 去调用千问,将得到如下的最终回复: 181 | 182 | ![](../assets/react_tutorial_002.png) 183 | 184 | ``` 185 | Thought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。 186 | Final Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。 187 | ``` 188 | 189 | 虽然对于文生图来说,这个第二次调用千问的步骤显得多余。但是对于搜索插件、代码执行插件、计算器插件等别的插件来说,这个第二次调用千问的步骤给了千问提炼、总结插件返回结果的机会。 190 | 191 | ## FAQ 192 | 193 | **怎么配置 "Observation" 这个 stop word?** 194 | 195 | 通过 chat 接口的 stop_words_ids 指定: 196 | ```py 197 | react_stop_words = [ 198 | # tokenizer.encode('Observation'), # [37763, 367] 199 | tokenizer.encode('Observation:'), # [37763, 367, 25] 200 | tokenizer.encode('Observation:\n'), # [37763, 367, 510] 201 | ] 202 | response, history = model.chat( 203 | tokenizer, query, history, 204 | stop_words_ids=react_stop_words # 此接口用于增加 stop words 205 | ) 206 | ``` 207 | 208 | 如果报错称不存在 stop_words_ids 此参数,可能是因为您用了老的代码,请重新执行 from_pretrained 拉取新的代码和模型。 209 | 210 | 需要注意的是,当前的 tokenizer 对 `\n` 有一系列较复杂的聚合操作。比如例子中的`:\n`这两个字符便被聚合成了一个 token。因此配置 stop words 需要非常细致地预估 tokenizer 的行为。 211 | 212 | **对 top_p 等推理参数有调参建议吗?** 213 | 214 | 通常来讲,较低的 top_p 会有更高的准确度,但会牺牲回答的多样性、且更易出现重复某个词句的现象。 215 | 216 | 可以按如下方式调整 top_p 为 0.5: 217 | ```py 218 | model.generation_config.top_p = 0.5 219 | ``` 220 | 221 | 特别的,可以用如下方式关闭 top-p sampling,改用 greedy sampling,效果上相当于 top_p=0 或 temperature=0: 222 | ```py 223 | model.generation_config.do_sample = False # greedy decoding 224 | ``` 225 | 226 | 此外,我们在 `model.chat()` 接口也提供了调整 top_p 等参数的接口。 227 | 228 | **有解析Action、Action Input的参考代码吗?** 229 | 230 | 有的,可以参考: 231 | ```py 232 | def parse_latest_plugin_call(text: str) -> Tuple[str, str]: 233 | i = text.rfind('\nAction:') 234 | j = text.rfind('\nAction Input:') 235 | k = text.rfind('\nObservation:') 236 | if 0 <= i < j: # If the text has `Action` and `Action input`, 237 | if k < 
j: # but does not contain `Observation`, 238 | # then it is likely that `Observation` is ommited by the LLM, 239 | # because the output text may have discarded the stop word. 240 | text = text.rstrip() + '\nObservation:' # Add it back. 241 | k = text.rfind('\nObservation:') 242 | if 0 <= i < j < k: 243 | plugin_name = text[i + len('\nAction:'):j].strip() 244 | plugin_args = text[j + len('\nAction Input:'):k].strip() 245 | return plugin_name, plugin_args 246 | return '', '' 247 | ``` 248 | 249 | 此外,如果输出的 Action Input 内容是一段表示 JSON 对象的文本,我们建议使用 `json5` 包的 `json5.loads(...)` 方法加载。 250 | -------------------------------------------------------------------------------- /config/__init__.py: -------------------------------------------------------------------------------- 1 | # coding=utf8 2 | # @Time : 2023/5/12 20:05 3 | # @Author : tk 4 | # @FileName: chatglm_config 5 | import json 6 | import os 7 | 8 | # 切换配置 9 | from config.main import * 10 | 11 | -------------------------------------------------------------------------------- /config/colossalai_strategy.yaml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | colossalai_strategy: 5 | "ddp": 6 | name: "ddp" 7 | broadcast_buffers: True 8 | bucket_cap_mb: 25 9 | find_unused_parameters: False 10 | check_reduction: False 11 | gradient_as_bucket_view: False 12 | static_graph: False 13 | "gemini": 14 | name: "gemini" 15 | chunk_config_dict: None 16 | chunk_init_device: None 17 | placement_policy: "static" 18 | shard_param_frac: 1.0 # only for static placement 19 | offload_optim_frac: 0.0 # only for static placement 20 | offload_param_frac: 0.0 # only for static placement 21 | warmup_non_model_data_ratio: 0.8 # only for auto placement 22 | steady_cuda_cap_ratio: 0.9 # only for auto placement 23 | precision: "fp16" 24 | pin_memory: False 25 | force_outputs_fp32: False 26 | strict_ddp_mode: False 27 | search_range_m: 32 28 | hidden_dim: None 29 | min_chunk_size_m: 32 30 | memstats: None 31 | gpu_margin_mem_ratio: 0.0 32 | initial_scale: 2 ** 16 33 | min_scale: 1 34 | growth_factor: 2 35 | backoff_factor: 0.5 36 | growth_interval: 1000 37 | hysteresis: 2 38 | max_scale: 2 ** 32 39 | max_norm: 1.0 40 | norm_type: 2.0 41 | verbose: False 42 | "zero2": 43 | name: zero2 44 | stage: 2 45 | precision: "fp16" 46 | initial_scale: 2 ** 32 47 | min_scale: 1 48 | growth_factor: 2 49 | backoff_factor: 0.5 50 | growth_interval: 1000 51 | hysteresis: 2 52 | max_scale: 2 ** 32 53 | max_norm: 1.0 54 | norm_type: 2.0 55 | reduce_bucket_size_in_m: 12 56 | communication_dtype: None 57 | overlap_communication: True 58 | cpu_offload: False 59 | verbose: False 60 | 61 | "zero2_cpu": 62 | name: zero2_cpu 63 | stage: 2 64 | precision: "fp16" 65 | initial_scale: 2 ** 32 66 | min_scale: 1 67 | growth_factor: 2 68 | backoff_factor: 0.5 69 | growth_interval: 1000 70 | hysteresis: 2 71 | max_scale: 2 ** 32 72 | max_norm: 1.0 73 | norm_type: 2.0 74 | reduce_bucket_size_in_m: 12 75 | communication_dtype: None 76 | overlap_communication: True 77 | cpu_offload: True 78 | verbose: False 79 | 80 | "3d": 81 | name: "3d" 82 | tp_size: 1 83 | pp_size: 1 84 | precision: "fp16" 85 | zero_stage: 0 86 | enable_all_optimization: False 87 | enable_fused_normalization: False 88 | enable_flash_attention: False 89 | enable_jit_fused: False 90 | enable_sequence_parallelism: False 91 | enable_sequence_overlap: False 92 | num_microbatches: None 93 | microbatch_size: None 94 | initial_scale: 2 ** 16 95 | min_scale: 1 96 | growth_factor: 2 97 | backoff_factor: 
0.5 98 | growth_interval: 1000 99 | hysteresis: 2 100 | max_scale: 2 ** 32 101 | max_norm: 0 102 | broadcast_buffers: True 103 | ddp_bucket_cap_mb: 25 104 | find_unused_parameters: False 105 | check_reduction: False 106 | gradient_as_bucket_view: False 107 | static_graph: False 108 | zero_bucket_size_in_m: 12 109 | cpu_offload: False 110 | communication_dtype: None 111 | overlap_communication: True 112 | custom_policy: None 113 | -------------------------------------------------------------------------------- /config/constant_map.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time: 23:20 3 | # @Author: tk 4 | # @File:model_maps 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /config/deepspeed.yaml: -------------------------------------------------------------------------------- 1 | "zero_allow_untested_optimizer": true, 2 | "fp16": 3 | "enabled": true, 4 | "auto_cast": false, 5 | "loss_scale": 0, 6 | "initial_scale_power": 16, 7 | "loss_scale_window": 1000, 8 | "hysteresis": 2, 9 | "min_loss_scale": 1 10 | 11 | "zero_optimization": 12 | "stage": 2, 13 | "allgather_partitions": true, 14 | "allgather_bucket_size": 5e8, 15 | "overlap_comm": false, 16 | "reduce_scatter": true, 17 | "reduce_bucket_size": 5e8, 18 | "contiguous_gradients" : true, 19 | 20 | "stage3_max_live_parameters": 1e9, 21 | "stage3_max_reuse_distance" : 1e9, 22 | "stage3_prefetch_bucket_size" : 5e8, 23 | "stage3_param_persistence_threshold" : 1e6, 24 | "sub_group_size" : 1e12, 25 | "elastic_checkpoint" : true, 26 | "stage3_gather_16bit_weights_on_model_save": true, 27 | "ignore_unused_parameters": true, 28 | "round_robin_gradients": true 29 | -------------------------------------------------------------------------------- /config/deepspeed_offload.yaml: -------------------------------------------------------------------------------- 1 | 2 | "optimizer": 3 | "type": "AdamW" 4 | "params": 5 | "lr": 2e-5 6 | "betas": [0.9, 0.999] 7 | "eps": 1e-8 8 | "weight_decay": 0 9 | 10 | "zero_allow_untested_optimizer": true 11 | "fp16": 12 | "enabled": true 13 | "auto_cast": false 14 | "loss_scale": 0 15 | "initial_scale_power": 16 16 | "loss_scale_window": 1000 17 | "hysteresis": 2 18 | "min_loss_scale": 1 19 | "zero_optimization": 20 | "stage": 2 21 | "allgather_partitions": true 22 | "allgather_bucket_size": 5e8 23 | "overlap_comm": false 24 | "reduce_scatter": true 25 | "reduce_bucket_size": 5e8 26 | "contiguous_gradients": true 27 | "stage3_max_live_parameters": 1e9 28 | "stage3_max_reuse_distance": 1e9 29 | "stage3_prefetch_bucket_size": 5e8 30 | "stage3_param_persistence_threshold": 1e6 31 | "sub_group_size": 1e12 32 | "elastic_checkpoint": true 33 | "stage3_gather_16bit_weights_on_model_save": true 34 | "ignore_unused_parameters": true 35 | "round_robin_gradients": true 36 | "offload_optimizer": 37 | "device": "cpu" 38 | "pin_memory": true 39 | -------------------------------------------------------------------------------- /config/global.yaml: -------------------------------------------------------------------------------- 1 | global_args: 2 | trainer_backend: pl 3 | enable_deepspeed: false 4 | enable_ptv2: false 5 | enable_lora: true 6 | load_in_bit: 0 7 | config_merge: {} 8 | # 模型权重 , 对应 config.constant_map.py 9 | model_name: Qwen2.5-1.5B 10 | 11 | # one of auto 16 bf16 32 12 | precision: auto 13 | quantization_config: 14 | load_in_8bit: false 15 | load_in_4bit: false 16 | llm_int8_threshold: 6.0 17 | 
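  # 补充注释(非原始文件内容): 上面的 load_in_8bit / load_in_4bit 只是默认值,
  # config/main.py 的 patch_args 会按 global_args.load_in_bit(0/4/8)重新赋值;
  # load_in_bit 为 0 时整个 quantization_config 会被置为 None, 即不做量化加载。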
llm_int8_has_fp16_weight: false 18 | bnb_4bit_compute_dtype: float16 # one of float16 bfloat16 float32 19 | bnb_4bit_use_double_quant: true 20 | bnb_4bit_quant_type: nf4 21 | 22 | 23 | # qwen 模型 >= 1.5 均可在下面添加 24 | 25 | global_models_mapper: 26 | 27 | Qwen2.5-1.5B: 28 | model_type: qwen 29 | model_name_or_path: /data/nlp/pre_models/torch/qwen/Qwen2.5-1.5B-Instruct 30 | config_name: /data/nlp/pre_models/torch/qwen/Qwen2.5-1.5B-Instruct 31 | tokenizer_name: /data/nlp/pre_models/torch/qwen/Qwen2.5-1.5B-Instruct 32 | 33 | Qwen1.5-1.8B-Chat: 34 | model_type: qwen2 35 | model_name_or_path: /data/nlp/pre_models/torch/qwen2/Qwen1.5-1.8B-Chat 36 | config_name: /data/nlp/pre_models/torch/qwen2/Qwen1.5-1.8B-Chat 37 | tokenizer_name: /data/nlp/pre_models/torch/qwen2/Qwen1.5-1.8B-Chat 38 | 39 | Qwen1.5-MoE-A2.7B: 40 | model_type: qwen2_moe 41 | model_name_or_path: /data/nlp/pre_models/torch/qwen2/Qwen1.5-MoE-A2.7B 42 | config_name: /data/nlp/pre_models/torch/qwen2/Qwen1.5-MoE-A2.7B 43 | tokenizer_name: /data/nlp/pre_models/torch/qwen2/Qwen1.5-MoE-A2.7B 44 | 45 | -------------------------------------------------------------------------------- /config/main.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/5/31 14:43 4 | import json 5 | import os 6 | import torch 7 | import yaml 8 | from transformers import BitsAndBytesConfig 9 | from transformers.utils import strtobool 10 | 11 | from deep_training.zoo.constants.define import (TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING, 12 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING, 13 | TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING, 14 | TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING) 15 | 16 | # 按需修改 17 | # TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING 18 | # TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING 19 | # TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING 20 | # TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING 21 | 22 | 23 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING['qwen'] = ['q_proj','k_proj','v_proj'] 24 | TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING['qwen'] = ['q_proj','k_proj','v_proj'] 25 | TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING['qwen'] = ['q_proj','k_proj','v_proj'] 26 | 27 | from deep_training.utils.wrapper import load_yaml 28 | 29 | 30 | 31 | # 加载 32 | __CUR_PATH__ = os.path.abspath(os.path.dirname(__file__)) 33 | 34 | 35 | config_args = load_yaml(os.environ.get('train_file', os.path.join(__CUR_PATH__, 'train_pl.yaml'))) 36 | global_args = config_args.pop("global_args") 37 | global_models_mapper = config_args.pop("global_models_mapper") 38 | colossalai_strategy = config_args.pop("colossalai_strategy", {}) 39 | train_model_config = global_models_mapper[global_args["model_name"]] 40 | 41 | 42 | def merge_from_env(global_args): 43 | merge_config = {} 44 | if "trainer_backend" in os.environ: 45 | merge_config["trainer_backend"] = str(os.environ["trainer_backend"]) 46 | if "enable_deepspeed" in os.environ: 47 | merge_config["enable_deepspeed"] = strtobool(os.environ["enable_deepspeed"]) 48 | if "enable_ptv2" in os.environ: 49 | merge_config["enable_ptv2"] = strtobool(os.environ["enable_ptv2"]) 50 | if "enable_lora" in os.environ: 51 | merge_config["enable_lora"] = strtobool(os.environ["enable_lora"]) 52 | if "load_in_bit" in os.environ: 53 | merge_config["load_in_bit"] = int(os.environ["load_in_bit"]) 54 | if merge_config: 55 | global_args.update(merge_config) 56 | 57 | 
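# 补充示例(非原始注释, 假设从 shell 启动训练脚本): 下面的调用会把
#   trainer_backend / enable_deepspeed / enable_ptv2 / enable_lora / load_in_bit
# 等环境变量合并进 global_args, 例如: enable_lora=true load_in_bit=4 trainer_backend=pl python train.py
# 而 train_file 环境变量在上方 load_yaml 处生效, 用于切换 config/train_*.yaml 配置文件。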
merge_from_env(global_args) 58 | 59 | def patch_args(config_args): 60 | assert global_args["trainer_backend"] in ["pl", "hf", "cl", "ac"] 61 | global_args["precision"] = str(global_args["precision"]) 62 | 63 | if global_args["quantization_config"]: 64 | # 精度 65 | if global_args["precision"] == "auto": 66 | global_args["quantization_config"][ 67 | "bnb_4bit_compute_dtype"] = "bfloat16" if torch.cuda.is_bf16_supported() else "float16" 68 | 69 | global_args["quantization_config"] = BitsAndBytesConfig(**global_args["quantization_config"]) 70 | 71 | assert global_args["enable_lora"] + global_args["enable_ptv2"] <= 1 , ValueError("lora ptv2 cannot open at same time") 72 | 73 | #更新模型配置 74 | config_args.update(train_model_config) 75 | 76 | if global_args["trainer_backend"] == "cl": 77 | config_args["strategy"] = colossalai_strategy[config_args["strategy"]] 78 | 79 | if global_args['quantization_config'] is not None: 80 | global_args['quantization_config'].load_in_4bit = global_args["load_in_bit"] == 4 81 | global_args['quantization_config'].load_in_8bit = global_args["load_in_bit"] == 8 82 | if global_args["load_in_bit"] == 0: 83 | global_args["quantization_config"] = None 84 | 85 | 86 | 87 | if global_args["enable_lora"]: 88 | # 检查lora adalora是否开启 89 | assert config_args.get('lora', {}).get('with_lora', False) + \ 90 | config_args.get('adalora', {}).get('with_lora', False) + \ 91 | config_args.get('ia3', {}).get('with_lora', False) == 1, ValueError( 92 | 'lora adalora ia3 can set one at same time !') 93 | 94 | model_type = train_model_config['model_type'] 95 | if config_args.get('lora', {}).get('with_lora', False): 96 | config_args["lora"]["target_modules"] = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_type] 97 | elif config_args.get('adalora', {}).get('with_lora', False): 98 | config_args["adalora"]["target_modules"] = TRANSFORMERS_MODELS_TO_ADALORA_TARGET_MODULES_MAPPING[model_type] 99 | else: 100 | config_args["ia3"]["target_modules"] = TRANSFORMERS_MODELS_TO_IA3_TARGET_MODULES_MAPPING[model_type] 101 | config_args["ia3"]["feedforward_modules"] = TRANSFORMERS_MODELS_TO_IA3_FEEDFORWARD_MODULES_MAPPING[model_type] 102 | 103 | config_args.pop('prompt', None) 104 | 105 | elif global_args["enable_ptv2"]: 106 | config_args.pop('lora', None) 107 | config_args.pop('adalora', None) 108 | config_args.pop('ia3', None) 109 | if "gradient_checkpointing" in config_args: 110 | config_args[ "gradient_checkpointing" ] = False 111 | 112 | assert "prompt" in config_args 113 | config_args["prompt"]["with_prompt"] = True 114 | else: 115 | config_args.pop('lora',None) 116 | config_args.pop('adalora', None) 117 | config_args.pop('ia3', None) 118 | config_args.pop('prompt', None) 119 | 120 | # 预处理 121 | if 'rwkv' in (config_args['model_type'] or config_args['model_name_or_path']).lower(): 122 | config_args['use_fast_tokenizer'] = True 123 | 124 | 125 | 126 | patch_args(config_args) 127 | 128 | def get_deepspeed_config(precision='fp16'): 129 | ''' 130 | lora prompt finetuning deepspeed_offload.json 131 | 普通 finetuning deepspeed.json 132 | ''' 133 | # 是否开启deepspeed 134 | if not global_args["enable_deepspeed"]: 135 | return None 136 | precision = str(precision).lower() 137 | # 选择 deepspeed 配置文件 138 | is_need_update_config = False 139 | if global_args["enable_lora"] or global_args["enable_ptv2"]: 140 | is_need_update_config = True 141 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed_offload.json') 142 | else: 143 | filename = os.path.join(os.path.dirname(__file__), 'deepspeed.json') 144 | 145 
| 146 | with open(filename, mode='r', encoding='utf-8') as f: 147 | deepspeed_config = json.loads(f.read()) 148 | 149 | #lora offload 同步优化器配置 150 | if is_need_update_config: 151 | optimizer = deepspeed_config.get('optimizer',None) 152 | if optimizer: 153 | if global_args["trainer_backend"] == 'hf': 154 | optimizer[ 'params' ][ 'betas' ] = (config_args.get('adam_beta1', 0.9),config_args.get('adam_beta2', 0.999),) 155 | optimizer[ 'params' ][ 'lr' ] = config_args.get('learning_rate', 2e-5) 156 | optimizer[ 'params' ][ 'eps' ] = config_args.get('adam_epsilon', 1e-8) 157 | # deepspeed_offload 优化器有效 158 | config_args[ 'optim' ] = optimizer[ 'type' ] 159 | else: 160 | optimizer['params']['betas'] = config_args.get('optimizer_betas', (0.9, 0.999)) 161 | optimizer['params']['lr'] = config_args.get('learning_rate', 2e-5) 162 | optimizer['params']['eps'] = config_args.get('adam_epsilon', 1e-8) 163 | # deepspeed_offload 优化器有效 164 | config_args['optimizer'] = optimizer['type'] 165 | 166 | if precision == 'bf16': 167 | if 'fp16' in deepspeed_config: 168 | deepspeed_config["fp16"]["enabled"] = False 169 | if 'bf16' in deepspeed_config: 170 | deepspeed_config["bf16"]["enabled"] = True 171 | else: 172 | deepspeed_config['bf16'] = {"enabled": True} 173 | elif precision == 'fp16': 174 | if 'bf16' in deepspeed_config: 175 | deepspeed_config["bf16"]["enabled"] = False 176 | 177 | return deepspeed_config 178 | 179 | -------------------------------------------------------------------------------- /config/petl.yaml: -------------------------------------------------------------------------------- 1 | 2 | ############## lora模块 3 | 4 | lora: 5 | with_lora: true # 是否启用模块 6 | lora_type: lora 7 | r: 8 8 | lora_alpha: 32 9 | lora_dropout: 0.1 10 | fan_in_fan_out: false 11 | # Bias type for Lora. Can be 'none', 'all' or 'lora_only'" 12 | bias: none 13 | # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 14 | modules_to_save: null 15 | layers_to_transform: null 16 | layers_pattern: null 17 | 18 | # "The mapping from layer names or regexp expression to ranks which are different from the default rank specified by `r`. " 19 | # "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 8`}" 20 | rank_pattern: {} 21 | 22 | # "The mapping from layer names or regexp expression to alphas which are different from the default alpha specified by `lora_alpha`. " 23 | # "For example, `{model.decoder.layers.0.encoder_attn.k_proj: 32`}" 24 | 25 | alpha_pattern: {} 26 | adalora: 27 | with_lora: false # 是否启用模块 28 | lora_type: adalora 29 | r: 8 30 | lora_alpha: 32 31 | lora_dropout: 0.1 32 | fan_in_fan_out: false 33 | # Bias type for Lora. Can be 'none', 'all' or 'lora_only'" 34 | bias: none 35 | # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 36 | modules_to_save: null 37 | layers_to_transform: null 38 | layers_pattern: null 39 | alpha_pattern: {} 40 | 41 | # Target Lora matrix dimension. 42 | target_r: 8 43 | #Initial Lora matrix dimension. 44 | init_r: 12 45 | #The steps of initial warmup. 46 | tinit: 0 47 | #The steps of final warmup 48 | tfinal: 0 49 | #Step interval of rank allocation. 50 | deltaT: 1 51 | #Hyperparameter of EMA. 52 | beta1: 0.85 53 | #Hyperparameter of EMA. 54 | beta2: 0.85 55 | #The orthogonal regularization coefficient. 56 | orth_reg_weight: 0.5 57 | 58 | #The total training steps. 59 | total_step: null 60 | 61 | #The saved rank pattern. 
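  # 补充说明(非原始注释): AdaLoRA 训练中会动态重新分配各层的秩,
  # 分配结果即保存在下面的 rank_pattern 中; 一般保持 null, 由框架在训练时自动写入。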
62 | rank_pattern: null 63 | 64 | ia3: 65 | with_lora: false # 是否启用模块 66 | fan_in_fan_out: false 67 | # "help": "List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint. " 68 | modules_to_save: null 69 | init_ia3_weights: true 70 | 71 | ############## ptv2模块 72 | prompt: 73 | with_prompt: true 74 | prompt_type: prefix_tuning 75 | task_type: causal_lm 76 | prefix_projection: false 77 | num_virtual_tokens: 32 78 | -------------------------------------------------------------------------------- /config/train_ac.yaml: -------------------------------------------------------------------------------- 1 | includes: [global.yaml, petl.yaml] 2 | 3 | # one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 4 | data_backend: parquet 5 | 6 | output_dir: ./outputs_ac 7 | overwrite_output_dir: true 8 | num_train_epochs: 20 9 | max_steps: -1 10 | save_safetensors: false 11 | save_strategy: steps 12 | save_steps: 1000 13 | save_total_limit: 10 14 | seed: 42 15 | fp16: true 16 | do_train: true 17 | train_file: 18 | - ../data/*.json 19 | 20 | do_eval: false 21 | do_predict: false 22 | per_device_train_batch_size: 2 23 | per_device_eval_batch_size: 2 24 | gradient_accumulation_steps: 1 25 | evaluation_strategy: 'no' 26 | eval_steps: 100 27 | 28 | # adamw_hf , adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 29 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 30 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 31 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 32 | 33 | optim: adamw_torch 34 | 35 | # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 36 | lr_scheduler_type: cosine 37 | torch_compile: false 38 | learning_rate: 2.0e-05 39 | adam_beta1: 0.9 40 | adam_beta2: 0.999 41 | adam_epsilon: 1.0e-08 42 | max_grad_norm: 1.0 43 | weight_decay: 0.0 44 | warmup_ratio: 0.03 45 | logging_strategy: steps 46 | logging_steps: 10 47 | tf32: false 48 | gradient_checkpointing: false 49 | max_seq_length: 512 50 | max_target_length: 100 51 | do_lower_case: null 52 | 53 | use_fast_tokenizer: false 54 | dataloader_drop_last: true 55 | dataloader_pin_memory: true 56 | dataloader_num_workers: 0 57 | log_level: info 58 | -------------------------------------------------------------------------------- /config/train_cl.yaml: -------------------------------------------------------------------------------- 1 | includes: [global.yaml, petl.yaml,colossalai_strategy.yaml] 2 | 3 | # one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 4 | data_backend: parquet 5 | 6 | 7 | # 目前仅ddp 支持lora 8 | # one of ddp,gemini,zero2,zero2_cpu,3d 9 | strategy: ddp 10 | 11 | output_dir: ./outputs_cl 12 | overwrite_output_dir: true 13 | num_train_epochs: 20 14 | max_steps: -1 15 | save_safetensors: false 16 | save_strategy: steps 17 | save_steps: 1000 18 | save_total_limit: 10 19 | seed: 42 20 | fp16: true 21 | do_train: true 22 | train_file: 23 | - ../data/*.json 24 | 25 | do_eval: false 26 | do_predict: false 27 | per_device_train_batch_size: 2 28 | per_device_eval_batch_size: 2 29 | gradient_accumulation_steps: 1 30 | evaluation_strategy: 'no' 31 | eval_steps: 100 32 | 33 | # 优化器,如果策略使用 gemini , 则 optim one of adam_hybrid_cl,adam_cpu_cl,adam_fused_cl 34 | # 如果策略使用非 gemini ,则 optim one of follow 35 | # one of 
adam_hybrid_cl,adam_cpu_cl,adam_fused_cl,lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 36 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 37 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 38 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 39 | 40 | optim: adam_hybrid_cl 41 | 42 | # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 43 | lr_scheduler_type: cosine 44 | torch_compile: false 45 | learning_rate: 2.0e-05 46 | adam_beta1: 0.9 47 | adam_beta2: 0.999 48 | adam_epsilon: 1.0e-08 49 | max_grad_norm: 1.0 50 | weight_decay: 0.0 51 | warmup_ratio: 0.03 52 | logging_strategy: steps 53 | logging_steps: 10 54 | tf32: false 55 | gradient_checkpointing: false 56 | max_seq_length: 512 57 | max_target_length: 100 58 | 59 | do_lower_case: null 60 | use_fast_tokenizer: false 61 | dataloader_drop_last: true 62 | dataloader_pin_memory: true 63 | dataloader_num_workers: 0 64 | log_level: info 65 | -------------------------------------------------------------------------------- /config/train_hf.yaml: -------------------------------------------------------------------------------- 1 | includes: [global.yaml, petl.yaml] 2 | 3 | # one of record lmdb arrow_stream arrow_file,parquet, 超大数据集可以使用 lmdb , 注 lmdb 存储空间比record大 4 | data_backend: parquet 5 | output_dir: ./outputs_hf 6 | overwrite_output_dir: true 7 | num_train_epochs: 20 8 | max_steps: -1 9 | save_safetensors: false 10 | save_strategy: steps 11 | save_steps: 1000 12 | save_total_limit: 10 13 | seed: 42 14 | fp16: true 15 | do_train: true 16 | train_file: 17 | - ../data/*.json 18 | 19 | do_eval: false 20 | do_predict: false 21 | per_device_train_batch_size: 2 22 | per_device_eval_batch_size: 2 23 | gradient_accumulation_steps: 1 24 | evaluation_strategy: 'no' 25 | eval_steps: 100 26 | 27 | 28 | # adamw_hf , adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 29 | # adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 30 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 31 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 32 | 33 | optim: adamw_torch 34 | 35 | # one of linear,cosine,cosine_with_restarts,polynomial,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 36 | lr_scheduler_type: cosine 37 | torch_compile: false 38 | learning_rate: 2.0e-05 39 | adam_beta1: 0.9 40 | adam_beta2: 0.999 41 | adam_epsilon: 1.0e-08 42 | max_grad_norm: 1.0 43 | weight_decay: 0.0 44 | warmup_ratio: 0.03 45 | logging_strategy: steps 46 | logging_steps: 10 47 | tf32: false 48 | gradient_checkpointing: false 49 | max_seq_length: 512 50 | max_target_length: 100 51 | 52 | do_lower_case: null 53 | use_fast_tokenizer: false 54 | dataloader_drop_last: true 55 | dataloader_pin_memory: true 56 | dataloader_num_workers: 0 57 | log_level: info 58 | -------------------------------------------------------------------------------- /config/train_pl.yaml: -------------------------------------------------------------------------------- 1 | includes: [global.yaml, petl.yaml] 2 | 3 | devices: 1 4 | data_backend: parquet 5 | convert_onnx: false 6 | do_train: true 7 | train_file: 8 | - ../data/*.json 9 | 10 | max_epochs: 20 11 | max_steps: -1 12 | 13 | # *** optimizer 14 | # lamb,adamw_hf,adamw,adamw_torch,adamw_torch_fused,adamw_torch_xla,adamw_apex_fused, 15 | # 
adafactor,adamw_anyprecision,sgd,adagrad,adamw_bnb_8bit,adamw_8bit,lion,lion_8bit,lion_32bit, 16 | # paged_adamw_32bit,paged_adamw_8bit,paged_lion_32bit,paged_lion_8bit, 17 | # lamb_fused_dp adagrad_cpu_dp adam_cpu_dp adam_fused_dp 18 | 19 | # *** scheduler 20 | # linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau, cosine,cosine_with_restarts,polynomial, 21 | # constant,constant_with_warmup,inverse_sqrt,reduce_lr_on_plateau 22 | 23 | # 'scheduler_type': 'linear',# one of [linear,WarmupCosine,CAWR,CAL,Step,ReduceLROnPlateau 24 | # 'scheduler': None, 25 | # 切换scheduler类型 26 | # 'scheduler_type': 'WarmupCosine', 27 | # 'scheduler': None, 28 | 29 | # 'scheduler_type': 'ReduceLROnPlateau', 30 | # 'scheduler': None, 31 | 32 | # 'scheduler_type': 'Step', 33 | # 'scheduler':{ 'decay_rate': 0.999,'decay_steps': 100,'verbose': True}, 34 | 35 | # 'scheduler_type': 'CAWR', 36 | # 'scheduler':{'T_mult': 1, 'rewarm_epoch_num': 2, 'verbose': True}, 37 | 38 | # 'scheduler_type': 'CAL', 39 | # 'scheduler': {'rewarm_epoch_num': 2,'verbose': True}, 40 | 41 | optimizer: lion 42 | scheduler_type: CAWR 43 | scheduler: 44 | T_mult: 1 45 | rewarm_epoch_num: 0.5 46 | verbose: false 47 | optimizer_betas: !!python/tuple 48 | - 0.9 49 | - 0.999 50 | train_batch_size: 2 51 | eval_batch_size: 2 52 | test_batch_size: 2 53 | learning_rate: 2.0e-05 54 | adam_epsilon: 1.0e-08 55 | gradient_accumulation_steps: 1 56 | max_grad_norm: 1.0 57 | weight_decay: 0 58 | warmup_steps: 0 59 | output_dir: ./outputs_pl 60 | max_seq_length: 512 61 | do_lower_case: null 62 | 63 | # 预测最大长度, 保留字段 64 | max_target_length: 100 65 | use_fast_tokenizer: false 66 | dataloader_drop_last: true 67 | dataloader_pin_memory: true 68 | dataloader_num_workers: 0 69 | -------------------------------------------------------------------------------- /data/make_data_example.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/2/24 12:50 3 | 4 | import json 5 | import sys 6 | import os 7 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 8 | 9 | # 导入ToolsBuilder , 从兼容性考虑,还是从数据源直接构建, 10 | from data_tools import ToolsBuilder 11 | 12 | x0 = [ 13 | { 14 | "role": "system", 15 | "q": "You are a helpful assistant.", 16 | "a": "" 17 | }, 18 | { 19 | "role": "user", 20 | "q": "现在给我画个五彩斑斓的黑", 21 | "tools": """ 22 | [ 23 | { 24 | "name_for_human": "夸克搜索", 25 | "name_for_model": "quark_search", 26 | "description_for_model": "夸克搜索是一个通用搜索引擎,可用于访问互联网、查询百科知识、了解时事新闻等。", 27 | "parameters": [ 28 | { 29 | "name": "search_query", 30 | "description": "搜索关键词或短语", 31 | "required": true, 32 | "schema": { 33 | "type": "string" 34 | } 35 | } 36 | ] 37 | }, 38 | { 39 | "name_for_human": "通义万相", 40 | "name_for_model": "image_gen", 41 | "description_for_model": "通义万相是一个AI绘画(图像生成)服务,输入文本描述,返回根据文本作画得到的图片的URL", 42 | "parameters": [ 43 | { 44 | "name": "query", 45 | "description": "中文关键词,描述了希望图像具有什么内容", 46 | "required": true, 47 | "schema": { 48 | "type": "string" 49 | } 50 | } 51 | ] 52 | } 53 | ] 54 | """, 55 | "a": ''' 56 | Thought: 我应该使用通义万相API来生成一张五彩斑斓的黑的图片。 57 | Action: image_gen 58 | Action Input: {"query": "五彩斑斓的黑"} 59 | ''' 60 | }, 61 | { 62 | "role": 'observation', 63 | "q": ''' 64 | {"status_code": 200, "request_id": "3d894da2-0e26-9b7c-bd90-102e5250ae03", "code": null, "message": "", "output": {"task_id": "2befaa09-a8b3-4740-ada9-4d00c2758b05", "task_status": "SUCCEEDED", "results": [{"url": 
"https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png"}], "task_metrics": {"TOTAL": 1, "SUCCEEDED": 1, "FAILED": 0}}, "usage": {"image_count": 1}} 65 | ''', 66 | "a": ''' 67 | Thought: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片。 68 | Final Answer: 我已经成功使用通义万相API生成了一张五彩斑斓的黑的图片https://dashscope-result-sh.oss-cn-shanghai.aliyuncs.com/1e5e2015/20230801/1509/6b26bb83-469e-4c70-bff4-a9edd1e584f3-1.png。 69 | ''' 70 | } 71 | ] 72 | 73 | 74 | x1 = [ 75 | { 76 | "role": "system", 77 | "q": "You are a helpful assistant.", 78 | "a": "" 79 | }, 80 | { 81 | "q": "从南京到上海的路线", 82 | "a": "\n".join([ 83 | "你好,南京到上海的路线如下:", 84 | "1. 南京到上海,可以乘坐南京地铁1号线,在南京站乘坐轨道交通1号线。", 85 | "2. 南京到浦东机场,可以搭乘上海地铁1号,在陆家嘴站乘坐地铁1线,在浦东国际机场站乘坐机场快线,前往上海浦东国际机场。", 86 | "3. 上海到南京,可以换乘上海地铁2号线,从南京站换乘地铁2线,再从南京南站换乘地铁1路,然后到达上海站" 87 | ]) 88 | } 89 | ] 90 | 91 | 92 | x2 = [ 93 | { 94 | "role": "system", 95 | "q": "You are a helpful assistant.", 96 | "a": "" 97 | }, 98 | { 99 | "q": "写一个诗歌,关于冬天", 100 | "a": "\n".join([ 101 | "冬夜寂静冷,", 102 | "云在天边飘,", "冰封白雪上, ", "寒冷像一场雪。", 103 | " ", 104 | "雪花融化成冰,", 105 | "像那雪花飘洒,", 106 | "在寒冷的冬天,", 107 | "感受春天的喜悦。", 108 | " 冬日里,", 109 | "风雪渐消,", 110 | "一片寂静,", 111 | "把快乐和温暖带回家。" 112 | ]) 113 | } 114 | ] 115 | 116 | 117 | x3 = [ 118 | { 119 | "role": "system", 120 | "q": "You are a helpful assistant.", 121 | "a": "" 122 | }, 123 | { 124 | "q": "晚上睡不着应该怎么办", 125 | "a": "\n".join([ 126 | "晚上睡不着可能会让你感到焦虑或不舒服,但以下是一些可以帮助你入睡的方法:", 127 | "", 128 | "1. 制定规律的睡眠时间表:保持规律的睡眠时间表可以帮助你建立健康的睡眠习惯,使你更容易入睡。尽量在每天的相同时间上床,并在同一时间起床。", 129 | "2. 创造一个舒适的睡眠环境:确保睡眠环境舒适,安静,黑暗且温度适宜。可以使用舒适的床上用品,并保持房间通风。", 130 | "3. 放松身心:在睡前做些放松的活动,例如泡个热水澡,听些轻柔的音乐,阅读一些有趣的书籍等,有助于缓解紧张和焦虑,使你更容易入睡。", 131 | "4. 避免饮用含有咖啡因的饮料:咖啡因是一种刺激性物质,会影响你的睡眠质量。尽量避免在睡前饮用含有咖啡因的饮料,例如咖啡,茶和可乐。", 132 | "5. 避免在床上做与睡眠无关的事情:在床上做些与睡眠无关的事情,例如看电影,玩游戏或工作等,可能会干扰你的睡眠。", 133 | "6. 
尝试呼吸技巧:深呼吸是一种放松技巧,可以帮助你缓解紧张和焦虑,使你更容易入睡。试着慢慢吸气,保持几秒钟,然后缓慢呼气。", 134 | "", 135 | "如果这些方法无法帮助你入睡,你可以考虑咨询医生或睡眠专家,寻求进一步的建议。" 136 | ]) 137 | } 138 | ] 139 | 140 | 141 | 142 | 143 | x = [x0,x1,x2,x3] 144 | 145 | # 重新构建数据集 146 | for paragraph in x: 147 | for node in paragraph: 148 | role = node.get("role", "user") 149 | tools = node.pop("tools",None) 150 | q = node["q"] 151 | if tools is not None: 152 | q = ToolsBuilder.build(tools,query=q) 153 | node["q"] = q 154 | if role in ["observation","function"]: 155 | node["q"] = f'Observation: {q}' 156 | 157 | 158 | with open('./finetune_train_paragraph.json',mode='w',encoding='utf-8',newline='\n') as f: 159 | index = 0 160 | for i in range(50): 161 | for j in range(len(x)): 162 | index += 1 163 | 164 | conversations = { 165 | "id": index, 166 | "paragraph": x[j] 167 | } 168 | f.write(json.dumps(conversations,ensure_ascii=False) + '\n' ) 169 | 170 | 171 | with open('./finetune_train_conversations.json',mode='w',encoding='utf-8',newline='\n') as f: 172 | index = 0 173 | for i in range(50): 174 | for j in range(len(x)): 175 | index += 1 176 | 177 | conversation = [] 178 | for item in x[j]: 179 | role = item.get("role","user") 180 | if role == "system": 181 | conversation.append( { 182 | "from": item.get("role","user"), 183 | "value": item["q"] 184 | }) 185 | else: 186 | conversation.append({ 187 | "from": item.get("role", "user"), 188 | "value": item["q"] 189 | }) 190 | conversation.append({ 191 | "from": "assistant", 192 | "value": item["a"] 193 | }) 194 | 195 | conversations = { 196 | "id": index, 197 | "conversations": conversation 198 | } 199 | f.write(json.dumps(conversations,ensure_ascii=False) + '\n' ) -------------------------------------------------------------------------------- /data_processer.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/3/25 18:36 2 | # @Author : tk 3 | import copy 4 | import json 5 | import random 6 | from distutils.command.config import config 7 | from enum import Enum 8 | from typing import Tuple, List 9 | 10 | import numpy as np 11 | # from deep_training.zoo.model_zoo.llm.qwen_generation_utils import make_context 12 | from transformers import PreTrainedTokenizer 13 | 14 | 15 | class DataStrategy(Enum): 16 | truncation = 1 17 | siding = 2 18 | 19 | 20 | def make_context( 21 | tokenizer: PreTrainedTokenizer, 22 | query: str, 23 | history: List[Tuple[str, str]] = None, 24 | system: str = "", 25 | max_window_size: int = 6144, 26 | chat_format: str = "chatml", 27 | ): 28 | if history is None: 29 | history = [] 30 | 31 | if chat_format == "chatml": 32 | im_start, im_end = "<|im_start|>", "<|im_end|>" 33 | im_start_tokens = tokenizer.encode(im_start, add_special_tokens=False) 34 | im_end_tokens = tokenizer.encode(im_end, add_special_tokens=False) 35 | nl_tokens = tokenizer.encode("\n", add_special_tokens=False) 36 | 37 | def _tokenize_str(role, content): 38 | return f"{role}\n{content}", tokenizer.encode( 39 | role 40 | ) + nl_tokens + tokenizer.encode(content) 41 | 42 | system_text, system_tokens_part = _tokenize_str("system", system) 43 | system_tokens = im_start_tokens + system_tokens_part + im_end_tokens 44 | 45 | raw_text = "" 46 | context_tokens = [] 47 | 48 | for turn_query, turn_response in reversed(history): 49 | query_text, query_tokens_part = _tokenize_str("user", turn_query) 50 | query_tokens = im_start_tokens + query_tokens_part + im_end_tokens 51 | response_text, response_tokens_part = _tokenize_str( 52 | "assistant", turn_response 53 | ) 54 | response_tokens 
= im_start_tokens + response_tokens_part + im_end_tokens 55 | 56 | next_context_tokens = nl_tokens + query_tokens + nl_tokens + response_tokens 57 | prev_chat = ( 58 | f"\n{im_start}{query_text}{im_end}\n{im_start}{response_text}{im_end}" 59 | ) 60 | 61 | current_context_size = ( 62 | len(system_tokens) + len(next_context_tokens) + len(context_tokens) 63 | ) 64 | if current_context_size < max_window_size: 65 | context_tokens = next_context_tokens + context_tokens 66 | raw_text = prev_chat + raw_text 67 | else: 68 | break 69 | 70 | context_tokens = system_tokens + context_tokens 71 | raw_text = f"{im_start}{system_text}{im_end}" + raw_text 72 | context_tokens += ( 73 | nl_tokens 74 | + im_start_tokens 75 | + _tokenize_str("user", query)[1] 76 | + im_end_tokens 77 | + nl_tokens 78 | + im_start_tokens 79 | + tokenizer.encode("assistant") 80 | + nl_tokens 81 | ) 82 | raw_text += f"\n{im_start}user\n{query}{im_end}\n{im_start}assistant\n" 83 | 84 | elif chat_format == "raw": 85 | raw_text = query 86 | context_tokens = tokenizer.encode(raw_text) 87 | else: 88 | raise NotImplementedError(f"Unknown chat format {chat_format!r}") 89 | 90 | return raw_text, context_tokens 91 | 92 | 93 | 94 | 95 | class TokenIdsMaker: 96 | 97 | @classmethod 98 | def final(cls, tokenizer,config, input_ids, labels, max_seq_length): 99 | seqlen = np.asarray(len(input_ids), dtype=np.int32) 100 | pad_len = max_seq_length - seqlen 101 | input_ids = np.asarray(input_ids, dtype=np.int32) 102 | labels = np.asarray(labels, dtype=np.int32) 103 | if pad_len: 104 | pad_val = tokenizer.eos_token_id 105 | input_ids = np.pad(input_ids, (0, pad_len), 'constant', constant_values=(pad_val, pad_val)) 106 | labels = np.pad(labels, (0, pad_len), 'constant', constant_values=(-100, -100)) 107 | d = { 108 | 'input_ids': input_ids, 109 | 'labels': labels, 110 | 'seqlen': seqlen 111 | } 112 | return d 113 | 114 | @classmethod 115 | def tunction(cls, tokenizer: PreTrainedTokenizer,config, paragraph, max_seq_length, sup=True): 116 | sptoken = [] 117 | ds = [] 118 | prefix = None 119 | history = [] 120 | for sid,(role,q,a) in enumerate(paragraph): 121 | if role == 'system': 122 | prefix = q 123 | continue 124 | # 从兼容性考虑,预处理从数据源构建 125 | # if tools is not None: 126 | # q = ToolsBuilder.build(tools,query=q) 127 | # 128 | # if role in ['observation','Observation']: 129 | # q = f'Observation: {q}' 130 | 131 | history += [(q,a)] 132 | _,a_ids = make_context(tokenizer=tokenizer,query=q,history=history[:-1], 133 | system = prefix or "You are a helpful assistant." 
, 134 | max_window_size = 6144, 135 | chat_format = "chatml",) 136 | b_ids = tokenizer.encode(a,add_special_tokens=False) 137 | 138 | while len(a_ids) + len(b_ids) > max_seq_length - len(sptoken) - 1: 139 | if len(b_ids) > len(a_ids): 140 | b_ids.pop(-1) 141 | else: 142 | a_ids.pop(0) 143 | b_ids += [ tokenizer.eos_token_id ] 144 | input_ids = a_ids + b_ids 145 | labels = copy.deepcopy(input_ids) if not sup else [ -100 ] * len(a_ids) + copy.deepcopy(b_ids) 146 | input_ids = sptoken + input_ids 147 | labels = sptoken + labels if not sup else [ -100 ] * len(sptoken) + labels 148 | assert len(input_ids) <= max_seq_length 149 | ds.append(cls.final(tokenizer,config, input_ids, labels, max_seq_length)) 150 | return ds 151 | 152 | 153 | @classmethod 154 | def slidding(cls, tokenizer: PreTrainedTokenizer,config, paragraph, max_seq_length, sliding_size = None,src_max_length=None,dst_max_length=None,sup=True): 155 | if sliding_size is None: 156 | sliding_size = max_seq_length 157 | ds = [] 158 | sptoken = [] 159 | prefix = None 160 | history = [] 161 | for sid, (role, q, a) in enumerate(paragraph): 162 | if role == 'system': 163 | prefix = q 164 | continue 165 | 166 | # 从兼容性考虑,预处理从数据源构建 167 | # if tools is not None: 168 | # q = ToolsBuilder.build(tools, query=q) 169 | 170 | # if role in ['observation', 'Observation']: 171 | # q = f'Observation: {q}' 172 | 173 | history += [(q, a)] 174 | _, a_ids = make_context(tokenizer=tokenizer, query=q, history=history[:-1], 175 | system=prefix or "You are a helpful assistant.", 176 | max_window_size=6144, 177 | chat_format="chatml", ) 178 | b_ids = tokenizer.encode(a,add_special_tokens=False) 179 | 180 | if src_max_length and src_max_length > 0: 181 | a_ids = a_ids[ :src_max_length ] 182 | if dst_max_length and dst_max_length > 0: 183 | b_ids = b_ids[ :dst_max_length ] 184 | 185 | input_ids_qa = a_ids + b_ids + [tokenizer.eos_token_id] 186 | if sup: 187 | labels_all = [-100] * len(a_ids) + b_ids 188 | else: 189 | labels_all = copy.deepcopy(input_ids_qa) 190 | 191 | pos = 0 192 | while pos < len(input_ids_qa): 193 | input_ids = input_ids_qa[pos:pos + max_seq_length - len(sptoken)] 194 | labels = labels_all[pos:pos + max_seq_length - len(sptoken)] 195 | 196 | pos += sliding_size 197 | if np.all(np.asarray(labels) == -100): 198 | continue 199 | input_ids = sptoken + input_ids 200 | labels = sptoken + labels if not sup else [ -100 ] * len(sptoken) + labels 201 | ds.append(cls.final(tokenizer, config,input_ids, labels, max_seq_length)) 202 | return ds 203 | -------------------------------------------------------------------------------- /data_tools.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/11/9 10:40 4 | import json 5 | 6 | class ToolsBuilder: 7 | TOOL_DESC = """{name_for_model}: Call this tool to interact with the {name_for_human} API. What is the {name_for_human} API useful for? {description_for_model} Parameters: {parameters} Format the arguments as a JSON object.""" 8 | 9 | REACT_PROMPT = """Answer the following questions as best you can. You have access to the following tools: 10 | 11 | {tool_descs} 12 | 13 | Use the following format: 14 | 15 | Question: the input question you must answer 16 | Thought: you should always think about what to do 17 | Action: the action to take, should be one of [{tool_names}] 18 | Action Input: the input to the action 19 | Observation: the result of the action 20 | ... 
(this Thought/Action/Action Input/Observation can be repeated zero or more times) 21 | Thought: I now know the final answer 22 | Final Answer: the final answer to the original input question 23 | 24 | Begin! 25 | 26 | Question: {query}""" 27 | 28 | @classmethod 29 | def build(cls,tools,query): 30 | tools = json.loads(tools) 31 | TOOL_DESC = ToolsBuilder.TOOL_DESC 32 | REACT_PROMPT = ToolsBuilder.REACT_PROMPT 33 | tool_descs = [] 34 | tool_names = [] 35 | for info in tools: 36 | tool_descs.append( 37 | TOOL_DESC.format( 38 | name_for_model=info['name_for_model'], 39 | name_for_human=info['name_for_human'], 40 | description_for_model=info['description_for_model'], 41 | parameters=json.dumps( 42 | info['parameters'], ensure_ascii=False), 43 | ) 44 | ) 45 | tool_names.append(info['name_for_model']) 46 | tool_descs = '\n\n'.join(tool_descs) 47 | tool_names = ','.join(tool_names) 48 | 49 | prompt = REACT_PROMPT.format(tool_descs=tool_descs, tool_names=tool_names, query=query) 50 | return prompt 51 | -------------------------------------------------------------------------------- /data_utils.py: -------------------------------------------------------------------------------- 1 | # @Time : 2023/1/22 16:22 2 | # @Author : tk 3 | # @FileName: data_utils.py 4 | import glob 5 | import sys 6 | import os 7 | from functools import cache 8 | 9 | sys.path.append(os.path.abspath(os.path.dirname(__file__))) 10 | 11 | import copy 12 | import json 13 | import typing 14 | import numpy as np 15 | import torch 16 | from deep_training.data_helper import DataHelper, ModelArguments, TrainingArguments, DataArguments, TrainingArgumentsHF, \ 17 | TrainingArgumentsCL, TrainingArgumentsAC 18 | from fastdatasets.record import load_dataset as Loader, RECORD, WriterObject, gfile 19 | from tqdm import tqdm 20 | from transformers import HfArgumentParser, PreTrainedTokenizer,PretrainedConfig 21 | from data_processer import DataStrategy, TokenIdsMaker 22 | from deep_training.zoo.model_zoo.llm.llm_model import PetlArguments,PromptArguments 23 | from config import * 24 | data_conf = { 25 | 'strategy': DataStrategy.truncation, # 数据策略选项 26 | DataStrategy.truncation: { 27 | 'sup': True, # 是否监督训练 28 | }, 29 | DataStrategy.siding: { 30 | 'stride': int(config_args['max_seq_length'] / 3 * 2), 31 | 'sup': True, # 是否监督模式 32 | "src_max_length": config_args['max_seq_length'] - 10, 33 | "dst_max_length": None, 34 | }, 35 | 36 | } 37 | 38 | 39 | def preprocess(text): 40 | #text = text.replace("\n", "\\n").replace("\t", "\\t") 41 | return text 42 | 43 | def postprocess(text): 44 | # return text.replace("\\n", "\n").replace("\\t", "\t") 45 | return text 46 | 47 | 48 | 49 | 50 | class NN_DataHelper(DataHelper): 51 | index = 1 52 | 53 | def load_tokenizer_and_config(self, *args, tokenizer_kwargs=None, config_kwargs=None, **kwargs): 54 | if config_kwargs is None: 55 | config_kwargs = {} 56 | 57 | if tokenizer_kwargs is None: 58 | tokenizer_kwargs = {} 59 | 60 | model_args = self.model_args 61 | base_path = model_args.config_name or model_args.model_name_or_path 62 | if os.path.isfile(base_path): 63 | base_path = os.path.dirname(base_path) 64 | 65 | # last_name = base_path.rsplit('/')[-1].lower() 66 | # if "yi" in last_name: 67 | gen_file = os.path.join(base_path, "generation_config.json") 68 | if os.path.exists(gen_file): 69 | with open(gen_file, mode='r', encoding='utf-8') as f: 70 | gen_args = json.loads(f.read()) 71 | gen_args_new = {} 72 | if "bos_token_id" in gen_args: 73 | gen_args_new["bos_token_id"] = gen_args["bos_token_id"] 74 | 75 | if 
"pad_token_id" in gen_args: 76 | gen_args_new["pad_token_id"] = gen_args["pad_token_id"] 77 | 78 | if "eos_token_id" in gen_args: 79 | gen_args_new["eos_token_id"] = gen_args["eos_token_id"] 80 | 81 | config_kwargs.update(gen_args_new) 82 | 83 | # if 'trust_remote_code' not in config_kwargs: 84 | # config_kwargs.update({"trust_remote_code": True, "local_files_only": True}) 85 | # if 'trust_remote_code' not in tokenizer_kwargs: 86 | # tokenizer_kwargs.update({"trust_remote_code": True, "local_files_only": True}) 87 | 88 | return super().load_tokenizer_and_config(*args, tokenizer_kwargs=tokenizer_kwargs, config_kwargs=config_kwargs, 89 | **kwargs) 90 | 91 | 92 | 93 | def on_data_ready(self): 94 | self.index = -1 95 | 96 | # 切分词 97 | def on_data_process(self, data: typing.Any, mode: str): 98 | self.index += 1 99 | 100 | 101 | max_seq_length = self.max_seq_length_dict[mode] 102 | tokenizer = self.tokenizer # noqa 103 | config = self.config # noqa 104 | 105 | strategy = data_conf['strategy'] 106 | if strategy == DataStrategy.truncation: 107 | ds = TokenIdsMaker.tunction(tokenizer,config,data, max_seq_length,**data_conf[strategy]) 108 | elif strategy == DataStrategy.siding: 109 | ds = TokenIdsMaker.slidding(tokenizer,config,data, max_seq_length, **data_conf[strategy]) 110 | else: 111 | raise ValueError('Invlid strategy',strategy) 112 | 113 | if not ds: 114 | return None 115 | 116 | if self.index < 3: 117 | print(ds[0]) 118 | return ds 119 | 120 | def _get_paragraph(self, lines): 121 | D = [] 122 | for line_id, line in enumerate(lines): 123 | jd = json.loads(line) 124 | if not jd: 125 | continue 126 | paragraph = jd['paragraph'] 127 | if line_id < 10: 128 | print(paragraph) 129 | 130 | paragraph = [(session.get("role", ""), preprocess(session['q']), 131 | preprocess('\n'.join(session['a'])) if isinstance(session['a'], list) else preprocess( 132 | session['a'])) 133 | for session in paragraph] 134 | sub = [] 135 | # 自行做模板 136 | for (role, q, a) in paragraph: 137 | # 不是system prompt answer 必须存在 138 | if role != "system": 139 | assert len(a), ValueError('answer cannot empty') 140 | sub.append((role, q, a)) 141 | D.append(copy.deepcopy(sub)) 142 | sub.clear() 143 | return D 144 | 145 | def _get_messages(self, lines): 146 | D = [] 147 | for line_id, line in enumerate(lines): 148 | jd = json.loads(line) 149 | if not jd: 150 | continue 151 | conversations = jd['conversations'] 152 | if line_id < 10: 153 | print(conversations) 154 | 155 | cid = 0 156 | sub = [] 157 | while cid < len(conversations): 158 | m = conversations[cid] 159 | cid += 1 160 | role = m["from"] 161 | q = preprocess(m["value"]) 162 | if role == "system": 163 | a = "" 164 | sub.append((role, q, a)) 165 | continue 166 | assert role in ['user', 'observation', 'function'] 167 | m = conversations[cid] 168 | cid += 1 169 | assert m["from"] == "assistant" 170 | a = preprocess(m["value"]) 171 | assert len(a), ValueError('answer cannot empty') 172 | sub.append((role, q, a)) 173 | D.append(sub) 174 | return D 175 | # 读取文件 176 | 177 | def on_get_corpus(self, files: typing.List, mode: str): 178 | D = [] 179 | files = sum([glob.glob(file) for file in files], []) 180 | for file in files: 181 | with open(file, mode='r', encoding='utf-8', newline='\n') as f: 182 | lines = f.readlines() 183 | is_new = False 184 | if len(lines) > 0: 185 | is_new = 'conversations' in json.loads(lines[0]) 186 | if is_new: 187 | D.extend(self._get_messages(lines)) 188 | else: 189 | D.extend(self._get_paragraph(lines)) 190 | return D 191 | 192 | def collate_fn(self,batch): 
193 | o = {} 194 | for i, b in enumerate(batch): 195 | if i == 0: 196 | for k in b: 197 | o[k] = [torch.tensor(b[k])] 198 | else: 199 | for k in b: 200 | o[k].append(torch.tensor(b[k])) 201 | for k in o: 202 | o[k] = torch.stack(o[k]) 203 | 204 | seqlens = o.pop('seqlen') 205 | max_len = torch.max(seqlens).tolist() 206 | input_ids = o['input_ids'][:, :max_len] 207 | attention_mask = torch.zeros_like(input_ids,dtype=torch.bool) 208 | for i,seqlen in enumerate(seqlens): 209 | attention_mask[i,:seqlen] = 1 210 | o['input_ids'] = input_ids.long() 211 | o['attention_mask'] = attention_mask 212 | o['labels'] = o['labels'][:, :max_len].long() 213 | return o 214 | 215 | def make_dataset_all(self): 216 | data_args = self.data_args 217 | 218 | # schema for arrow parquet 219 | schema = { 220 | "input_ids": "int32_list", 221 | "labels": "int32_list", 222 | "seqlen": "int32_list", 223 | } 224 | # 缓存数据集 225 | if data_args.do_train: 226 | self.make_dataset_with_args(data_args.train_file, mixed_data=False, shuffle=True, 227 | mode='train',schema=schema) 228 | if data_args.do_eval: 229 | self.make_dataset_with_args(data_args.eval_file, mode='eval',schema=schema) 230 | if data_args.do_test: 231 | self.make_dataset_with_args(data_args.test_file, mode='test',schema=schema) 232 | 233 | # 记录缓存文件 234 | with open(os.path.join(data_args.output_dir, 'intermediate_file_index.json'), mode='w', 235 | encoding='utf-8') as f: 236 | f.write(json.dumps({ 237 | "train_files": self.train_files, 238 | "eval_files": self.eval_files, 239 | "test_files": self.test_files, 240 | }, ensure_ascii=False)) 241 | 242 | @cache 243 | def load_dataset_files(self): 244 | data_args = self.data_args 245 | if not data_args.convert_file: 246 | return { 247 | "train_files": self.train_files, 248 | "eval_files": self.eval_files, 249 | "test_files": self.test_files, 250 | } 251 | filename = os.path.join(data_args.output_dir, 'intermediate_file_index.json') 252 | assert os.path.exists(filename), 'make you dataset firstly' 253 | with open(filename, mode='r', encoding='utf-8') as f: 254 | return json.loads(f.read()) 255 | 256 | if __name__ == '__main__': 257 | if global_args["trainer_backend"] == "hf": 258 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments), 259 | conflict_handler='resolve') 260 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args,allow_extra_keys=True, ) 261 | elif global_args[ "trainer_backend" ] == "pl": 262 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments, PromptArguments)) 263 | model_args, training_args, data_args, lora_args, _ = parser.parse_dict(config_args) 264 | elif global_args["trainer_backend"] == "cl": 265 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsCL, DataArguments, PetlArguments, PromptArguments), 266 | conflict_handler='resolve') 267 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args, 268 | allow_extra_keys=True, ) 269 | else: 270 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsAC, DataArguments, PetlArguments, PromptArguments), 271 | conflict_handler='resolve') 272 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args, 273 | allow_extra_keys=True, ) 274 | 275 | lora_args = lora_args.config 276 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 277 | tokenizer, config, _,_ = dataHelper.load_tokenizer_and_config() 278 | 279 | # 缓存数据集 280 | 
print(f'to make dataset is overwrite_cache {data_args.overwrite_cache}') 281 | dataHelper.make_dataset_all() 282 | 283 | print('make dataset complete!') 284 | print('check data !') 285 | dataset = dataHelper.load_sequential_sampler(dataHelper.load_dataset_files()["train_files"], 286 | with_load_memory=data_args.data_backend == 'record', 287 | batch_size=1, 288 | collate_fn=dataHelper.collate_fn) 289 | 290 | print('total', len(dataset)) 291 | for i, d in enumerate(dataset): 292 | print(d) 293 | if i > 3: 294 | break 295 | 296 | -------------------------------------------------------------------------------- /infer/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/10/12 16:35 4 | -------------------------------------------------------------------------------- /infer/api_lora_demo.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/4/4 14:46 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 6 | 7 | from fastapi import FastAPI, Request 8 | import uvicorn, json, datetime 9 | import torch 10 | 11 | from deep_training.data_helper import ModelArguments, DataArguments 12 | from transformers import PreTrainedTokenizer, HfArgumentParser, PretrainedConfig, GenerationConfig 13 | from deep_training.zoo.model_zoo.llm.llm_model import PetlArguments,LoraConfig,PromptArguments 14 | from transformers import HfArgumentParser 15 | 16 | from data_utils import config_args, NN_DataHelper,global_args 17 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer 18 | 19 | DEVICE = "cuda" 20 | DEVICE_ID = "0" 21 | CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE 22 | 23 | 24 | def torch_gc(): 25 | if torch.cuda.is_available(): 26 | with torch.cuda.device(CUDA_DEVICE): 27 | torch.cuda.empty_cache() 28 | torch.cuda.ipc_collect() 29 | 30 | 31 | app = FastAPI() 32 | 33 | 34 | @app.post("/") 35 | async def create_item(request: Request): 36 | global model, tokenizer 37 | json_post_raw = await request.json() 38 | json_post = json.dumps(json_post_raw) 39 | json_args= json.loads(json_post) 40 | prompt = json_args.pop('prompt') 41 | history = json_args.pop('history',None) 42 | 43 | gen_args = { 44 | "chat_format": "chatml", 45 | "decay_bound": 0.0, 46 | "decay_factor": 1.0, 47 | "eos_token_id": 151643, 48 | "factual_nucleus_sampling": False, 49 | "max_context_size": 1024, 50 | "max_generate_size": 512, 51 | "max_new_tokens": 512, 52 | "pad_token_id": 151643, 53 | # "stop_words_ids": [[151643]], 54 | "do_sample": True, 55 | "top_k": 0, 56 | "top_p": 0.8, 57 | } 58 | gen_args.update(json_args) 59 | generation_config = GenerationConfig(**gen_args) 60 | 61 | response, history = model.chat(tokenizer, prompt, history=[],generation_config=generation_config) 62 | now = datetime.datetime.now() 63 | time = now.strftime("%Y-%m-%d %H:%M:%S") 64 | answer = { 65 | "response": response, 66 | "history": history, 67 | "status": 200, 68 | "time": time 69 | } 70 | log = "[" + time + "] " + '", prompt:"' + prompt + '", response:"' + repr(response) + '"' 71 | print(log) 72 | torch_gc() 73 | return answer 74 | 75 | 76 | if __name__ == '__main__': 77 | config_args['seed'] = None 78 | parser = HfArgumentParser((ModelArguments, )) 79 | (model_args,) = parser.parse_dict(config_args,allow_extra_keys=True) 80 | 81 | 82 | dataHelper = NN_DataHelper(model_args, None, None) 83 | 84 | tokenizer, 
_, _, _ = dataHelper.load_tokenizer_and_config() 85 | 86 | ckpt_dir = './best_ckpt/last' 87 | config = PretrainedConfig.from_pretrained(ckpt_dir) 88 | 89 | 90 | lora_args = PetlArguments.from_pretrained(ckpt_dir) 91 | 92 | assert lora_args.inference_mode == True 93 | 94 | # new_num_tokens = config.vocab_size 95 | # if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None: 96 | # config.vocab_size = config.task_specific_params['vocab_size'] 97 | 98 | pl_model = MyTransformer(config=config, model_args=model_args, lora_args=lora_args, 99 | torch_dtype=torch.float16, 100 | # new_num_tokens=new_num_tokens, # 扩充词 101 | 102 | # # device_map="auto", 103 | # device_map={"": 0}, # 第一块卡 104 | ) 105 | # 加载lora权重 106 | pl_model.load_sft_weight(ckpt_dir) 107 | 108 | model = pl_model.get_llm_model() 109 | # 按需修改 110 | model.half().cuda() 111 | model = model.eval() 112 | 113 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1) -------------------------------------------------------------------------------- /infer/evaluate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/3/29 11:25 3 | import numpy as np 4 | from sacrebleu.metrics import BLEU 5 | from rouge import Rouge 6 | 7 | 8 | def evaluate(data): 9 | bleu_scorer_obj = BLEU() 10 | rouge_scorer_obj = Rouge() 11 | bleu_score = [] 12 | for d in data: 13 | score = bleu_scorer_obj.sentence_score( 14 | hypothesis=d['text'], 15 | references=d['ref'], 16 | ) 17 | bleu_score.append(score.score) 18 | 19 | bleu_score = np.average(np.asarray(bleu_score)) 20 | 21 | rouge_score = [] 22 | for d in data: 23 | score = rouge_scorer_obj.get_scores( 24 | hyps=[d['text']], 25 | refs=d['ref'], 26 | ) 27 | rouge_score.append(score[0]["rouge-l"]["f"]) 28 | 29 | rouge_score = np.average(np.asarray(rouge_score)) 30 | 31 | return { 32 | "bleu_score": bleu_score, 33 | "rouge-l_score": rouge_score 34 | } 35 | 36 | 37 | 38 | if __name__ == '__main__': 39 | data = [ 40 | { 41 | "text": "to make people trustworthy you need to trust them", 42 | "ref": ["the way to make people trustworthy is to trust them"] 43 | }, 44 | ] 45 | 46 | result = evaluate(data) 47 | print(result) 48 | 49 | 50 | -------------------------------------------------------------------------------- /infer/infer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/3/9 15:29 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 6 | 7 | import torch 8 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 9 | from transformers import HfArgumentParser, BitsAndBytesConfig, GenerationConfig, PreTrainedTokenizer 10 | from data_utils import config_args, NN_DataHelper 11 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments 12 | from data_processer import make_context 13 | 14 | 15 | 16 | if __name__ == '__main__': 17 | config_args['seed'] = None 18 | parser = HfArgumentParser((ModelArguments,)) 19 | (model_args,) = parser.parse_dict(config_args, allow_extra_keys=True) 20 | 21 | dataHelper = NN_DataHelper(model_args) 22 | tokenizer: PreTrainedTokenizer 23 | tokenizer, config, _,_ = dataHelper.load_tokenizer_and_config() 24 | 25 | # quantization configuration for NF4 (4 bits) 26 | quantization_config = BitsAndBytesConfig( 27 | load_in_4bit=True, 28 | bnb_4bit_quant_type='nf4', 29 | 
bnb_4bit_compute_dtype=torch.bfloat16 30 | ) 31 | 32 | # # quantization configuration for Int8 (8 bits) 33 | # quantization_config = BitsAndBytesConfig(load_in_8bit=True) 34 | 35 | 36 | pl_model = MyTransformer(config=config, model_args=model_args, 37 | torch_dtype=torch.float16, 38 | # device_map="cuda:0", 39 | # quantization_config=quantization_config, 40 | ) 41 | 42 | model = pl_model.get_llm_model() 43 | 44 | # if hasattr(model,'is_loaded_in_4bit') or hasattr(model,'is_loaded_in_8bit'): 45 | # model.eval().cuda() 46 | # else: 47 | # model.half().eval().cuda() 48 | 49 | model = model.eval() 50 | model.requires_grad_(False) 51 | 52 | model.half().cuda() 53 | model = model.eval() 54 | 55 | text_list = [ 56 | "写一个诗歌,关于冬天", 57 | "晚上睡不着应该怎么办", 58 | ] 59 | generation_config = GenerationConfig(**{ 60 | "chat_format": "chatml", 61 | "eos_token_id": tokenizer.eos_token_id, 62 | "max_new_tokens": 512, 63 | "pad_token_id": tokenizer.eos_token_id, 64 | #"stop_words_ids": [[151643]], 65 | "do_sample": True, 66 | "top_k": 0, 67 | "top_p": 0.8, 68 | }) 69 | 70 | for input in text_list: 71 | _, input_ids = make_context(tokenizer, input) 72 | input_ids = torch.tensor(input_ids) 73 | input_ids = input_ids.unsqueeze(0) 74 | response = model.generate(inputs=input_ids.cuda(), ) 75 | outputs = response.tolist()[0][len(input_ids[0]):] 76 | response = tokenizer.decode(outputs, skip_special_tokens=True) 77 | 78 | print("input", input) 79 | print("response", response) 80 | 81 | # response, history = base_model.chat(tokenizer, "写一个诗歌,关于冬天", history=[],max_length=30) 82 | # print('写一个诗歌,关于冬天',' ',response) 83 | 84 | -------------------------------------------------------------------------------- /infer/infer_finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/3/9 15:29 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 6 | 7 | import torch 8 | from deep_training.data_helper import ModelArguments, TrainingArguments, DataArguments 9 | from transformers import HfArgumentParser, GenerationConfig,AutoConfig 10 | from data_utils import config_args, NN_DataHelper, get_deepspeed_config 11 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments 12 | from data_processer import make_context 13 | 14 | deep_config = get_deepspeed_config() 15 | 16 | 17 | if __name__ == '__main__': 18 | config_args['seed'] = None 19 | config_args['model_name_or_path'] = None 20 | 21 | parser = HfArgumentParser((ModelArguments, )) 22 | (model_args,) = parser.parse_dict(config_args,allow_extra_keys=True) 23 | 24 | 25 | dataHelper = NN_DataHelper(model_args) 26 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config() 27 | 28 | ###################### 注意 选最新权重 29 | #选择最新的权重 , 根据时间排序 选最新的 30 | config = AutoConfig.from_pretrained('./best_ckpt') 31 | 32 | # new_num_tokens = config.vocab_size 33 | # if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None: 34 | # config.vocab_size = config.task_specific_params['vocab_size'] 35 | 36 | pl_model = MyTransformer(config=config, model_args=model_args, 37 | torch_dtype=torch.float16, 38 | # new_num_tokens=new_num_tokens,#扩充词 39 | ) 40 | if deep_config is None: 41 | train_weight = './best_ckpt/last-v3.ckpt' 42 | else: 43 | #使用转换脚本命令 生成 ./best_ckpt/last/best.pt 权重文件 44 | # cd best_ckpt/last 45 | # python zero_to_fp32.py . 
best.pt 46 | train_weight = './best_ckpt/last/best.pt' 47 | 48 | #加载微调权重 49 | pl_model.load_sft_weight(train_weight,strict=True) 50 | 51 | model = pl_model.get_llm_model() 52 | #保存hf权重 53 | #config.save_pretrained('convert/') 54 | 55 | # 保存sft p-tuning-v2 权重 56 | # pl_model.save_sft_weight('convert/pytorch_model_sft_ptv2.bin') 57 | 58 | #保存sft权重 59 | # pl_model.save_sft_weight('convert/pytorch_model_sft.bin') 60 | 61 | 62 | model.half().cuda() 63 | model = model.eval() 64 | 65 | text_list = [ 66 | "写一个诗歌,关于冬天", 67 | "晚上睡不着应该怎么办", 68 | ] 69 | generation_config = GenerationConfig(**{ 70 | "chat_format": "chatml", 71 | "eos_token_id": tokenizer.eos_token_id, 72 | "max_new_tokens": 512, 73 | "pad_token_id": tokenizer.eos_token_id, 74 | "do_sample": True, 75 | "top_k": 0, 76 | "top_p": 0.8, 77 | }) 78 | for input in text_list: 79 | _, input_ids = make_context(tokenizer, input) 80 | input_ids = torch.tensor(input_ids) 81 | input_ids = input_ids.unsqueeze(0) 82 | response = model.generate(inputs=input_ids.cuda(), ) 83 | outputs = response.tolist()[0][len(input_ids[0]):] 84 | response = tokenizer.decode(outputs, skip_special_tokens=True) 85 | 86 | print("input", input) 87 | print("response", response) 88 | 89 | -------------------------------------------------------------------------------- /infer/infer_lora_finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/3/9 15:29 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 6 | 7 | import torch 8 | from deep_training.data_helper import ModelArguments, DataArguments 9 | from transformers import HfArgumentParser, GenerationConfig,AutoConfig 10 | from data_utils import config_args, NN_DataHelper,global_args 11 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer,PetlArguments 12 | from data_processer import make_context 13 | 14 | 15 | if __name__ == '__main__': 16 | config_args['seed'] = None 17 | parser = HfArgumentParser((ModelArguments,)) 18 | (model_args,) = parser.parse_dict(config_args, allow_extra_keys=True) 19 | dataHelper = NN_DataHelper(model_args) 20 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config() 21 | 22 | weight_dir = '../scripts/best_ckpt' 23 | lora_weight_dir = os.path.join(weight_dir, "last") 24 | 25 | config = AutoConfig.from_pretrained(weight_dir) 26 | lora_args = PetlArguments.from_pretrained(lora_weight_dir) 27 | 28 | assert lora_args.inference_mode == True 29 | 30 | # new_num_tokens = config.vocab_size 31 | # if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None: 32 | # config.vocab_size = config.task_specific_params['vocab_size'] 33 | 34 | pl_model = MyTransformer(config=config, model_args=model_args, lora_args=lora_args, 35 | torch_dtype=torch.float16, 36 | # new_num_tokens=new_num_tokens,#扩充词 37 | 38 | # # device_map="auto", 39 | # device_map = {"":0} # 第一块卡 40 | ) 41 | # 加载lora权重 42 | pl_model.load_sft_weight(lora_weight_dir) 43 | 44 | pl_model.eval().half().cuda() 45 | 46 | enable_merge_weight = False 47 | if enable_merge_weight: 48 | # 合并lora 权重 保存 49 | pl_model.save_sft_weight(os.path.join(lora_weight_dir,'pytorch_model_merge.bin'),merge_lora_weight=True) 50 | 51 | else: 52 | model = pl_model.get_llm_model() 53 | 54 | text_list = [ 55 | "写一个诗歌,关于冬天", 56 | "晚上睡不着应该怎么办", 57 | ] 58 | 59 | model.generation_config = GenerationConfig(**{ 60 | "chat_format": "chatml", 61 | "eos_token_id": tokenizer.eos_token_id, 62 
| "max_new_tokens": 512, 63 | "pad_token_id": tokenizer.eos_token_id, 64 | #"stop_words_ids": [[151643]], 65 | "do_sample": True, 66 | "top_k": 0, 67 | "top_p": 0.8, 68 | }) 69 | for input in text_list: 70 | _, input_ids = make_context(tokenizer, input) 71 | input_ids = torch.tensor(input_ids) 72 | input_ids = input_ids.unsqueeze(0) 73 | response = model.generate(inputs = input_ids.cuda(),) 74 | outputs = response.tolist()[0][len(input_ids[0]):] 75 | response = tokenizer.decode(outputs, skip_special_tokens=True) 76 | 77 | print("input", input) 78 | print("response", response) 79 | 80 | 81 | 82 | -------------------------------------------------------------------------------- /infer/infer_muti_lora_finetuning.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/3/9 15:29 3 | import os 4 | import sys 5 | sys.path.append(os.path.join(os.path.dirname(__file__),'..')) 6 | 7 | import torch 8 | from deep_training.data_helper import ModelArguments, DataArguments 9 | from transformers import HfArgumentParser, GenerationConfig, AutoConfig 10 | from data_utils import config_args, NN_DataHelper, global_args 11 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments,PetlModel 12 | from data_processer import make_context 13 | 14 | if __name__ == '__main__': 15 | config_args['seed'] = None 16 | parser = HfArgumentParser((ModelArguments,)) 17 | (model_args,) = parser.parse_dict(config_args, allow_extra_keys=True) 18 | dataHelper = NN_DataHelper(model_args) 19 | tokenizer, _, _, _ = dataHelper.load_tokenizer_and_config() 20 | 21 | ckpt_dir = './best_ckpt/last' 22 | config = AutoConfig.from_pretrained(ckpt_dir) 23 | 24 | lora_args = PetlArguments.from_pretrained(ckpt_dir) 25 | 26 | assert lora_args.inference_mode == True 27 | 28 | # new_num_tokens = config.vocab_size 29 | # if config.task_specific_params is not None and config.task_specific_params.get('vocab_size', None) is not None: 30 | # config.vocab_size = config.task_specific_params['vocab_size'] 31 | 32 | pl_model = MyTransformer(config=config, model_args=model_args, lora_args=lora_args, 33 | torch_dtype=torch.float16, 34 | # new_num_tokens=new_num_tokens,#扩充词 35 | 36 | # # device_map="auto", 37 | # device_map = {"":0} # 第一块卡 38 | ) 39 | # 加载多个lora权重 40 | pl_model.load_sft_weight(ckpt_dir,adapter_name="default") 41 | 42 | # 加载多个lora权重 43 | #pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname") 44 | 45 | # 加载多个lora权重 46 | #pl_model.load_sft_weight(ckpt_dir, adapter_name="yourname") 47 | 48 | 49 | pl_model.eval().half().cuda() 50 | 51 | # backbone model replaced PetlModel 52 | lora_model: PetlModel = pl_model.backbone 53 | 54 | text_list = [ 55 | "写一个诗歌,关于冬天", 56 | "晚上睡不着应该怎么办", 57 | ] 58 | generation_config = GenerationConfig(**{ 59 | "chat_format": "chatml", 60 | "eos_token_id": tokenizer.eos_token_id, 61 | "max_new_tokens": 512, 62 | "pad_token_id": tokenizer.eos_token_id, 63 | #"stop_words_ids": [[151643]], 64 | "do_sample": True, 65 | "top_k": 0, 66 | "top_p": 0.8, 67 | }) 68 | # 基准模型推理 69 | with lora_model.disable_adapter(): 70 | for input in text_list: 71 | #lora_model 调用子对象方法 72 | response, history = lora_model.chat(tokenizer, input, history=[],generation_config=generation_config ) 73 | print("input", input) 74 | print("response", response) 75 | 76 | lora_model.set_adapter(adapter_name='default') 77 | 78 | for input in text_list: 79 | _, input_ids = make_context(tokenizer, input) 80 | input_ids = torch.tensor(input_ids) 81 | 
input_ids = input_ids.unsqueeze(0) 82 | response = lora_model.generate(inputs=input_ids.cuda(), ) 83 | outputs = response.tolist()[0][len(input_ids[0]):] 84 | response = tokenizer.decode(outputs, skip_special_tokens=True) 85 | 86 | print("input", input) 87 | print("response", response) 88 | 89 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | einops 3 | transformers>=4.30 4 | deepspeed 5 | cpm-kernels 6 | bitsandbytes>=0.39 7 | accelerate>=0.20 8 | 9 | git+https://github.com/ssbuild/deep_training#egg=deep_training -------------------------------------------------------------------------------- /scripts/train_full.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export trainer_backend=pl 4 | 5 | train_file="../config/train_${trainer_backend}.yaml" 6 | 7 | # 强制覆盖配置文件 8 | export train_file=${train_file} 9 | export enable_deepspeed=false 10 | export enable_ptv2=false 11 | export enable_lora=false 12 | export load_in_bit=0 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | 16 | 17 | usage() { echo "Usage: $0 [-m ]" 1>&2; exit 1; } 18 | 19 | 20 | while getopts m: opt 21 | do 22 | case "${opt}" in 23 | m) mode=${OPTARG};; 24 | *) 25 | usage 26 | ;; 27 | esac 28 | done 29 | 30 | if [ "${mode}" != "dataset" ] && [ "${mode}" != "train" ] ; then 31 | usage 32 | fi 33 | 34 | if [[ "${mode}" == "dataset" ]] ; then 35 | python ../data_utils.py 36 | exit 0 37 | fi 38 | 39 | if [[ "${trainer_backend}" == "pl" ]] ; then 40 | # pl 多卡 修改配置文件 devices 41 | 42 | ### 多机多卡训练 43 | 44 | # 例子 3个机器 每个机器 4个卡 45 | # 修改train.py Trainer num_nodes = 3 46 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 47 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 48 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 49 | 50 | # pl 多卡 修改配置文件 devices 51 | 52 | python ../train.py 53 | elif [[ "${trainer_backend}" == "cl" ]] ; then 54 | # 多机多卡 55 | # colossalai run --nproc_per_node 1 --num_nodes 1 --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 56 | 57 | colossalai run --nproc_per_node 1 --num_nodes 1 ../train.py 58 | else 59 | # 多机多卡 60 | # --nproc_per_node=1 nnodes=1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 61 | torchrun --nproc_per_node 1 --nnodes 1 ../train.py 62 | fi -------------------------------------------------------------------------------- /scripts/train_lora.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export trainer_backend=pl 4 | 5 | train_file="../config/train_${trainer_backend}.yaml" 6 | 7 | # 强制覆盖配置文件 8 | export train_file=${train_file} 9 | export enable_deepspeed=false 10 | export enable_ptv2=false 11 | export enable_lora=true 12 | export load_in_bit=0 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | 16 | usage() { echo "Usage: $0 [-m ]" 1>&2; exit 1; } 17 | 18 | 19 | while getopts m: opt 20 | do 21 | case "${opt}" in 22 | m) mode=${OPTARG};; 23 | *) 24 | usage 25 | ;; 26 | esac 27 | done 28 | 29 | if [ "${mode}" != "dataset" ] && [ "${mode}" != "train" ] ; then 30 | usage 31 | fi 32 | 33 | if [[ "${mode}" == "dataset" ]] ; then 34 | python ../data_utils.py 35 | exit 0 36 | fi 37 | 38 | if [[ "${trainer_backend}" == "pl" ]] ; then 39 | # pl 多卡 修改配置文件 devices 40 
| 41 | ### 多机多卡训练 42 | 43 | # 例子 3个机器 每个机器 4个卡 44 | # 修改train.py Trainer num_nodes = 3 45 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 46 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 47 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 48 | 49 | # pl 多卡 修改配置文件 devices 50 | 51 | python ../train.py 52 | elif [[ "${trainer_backend}" == "cl" ]] ; then 53 | # 多机多卡 54 | # colossalai run --nproc_per_node 1 --num_nodes 1 --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 55 | 56 | colossalai run --nproc_per_node 1 --num_nodes 1 ../train.py 57 | else 58 | # 多机多卡 59 | # --nproc_per_node=1 nnodes=1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 60 | torchrun --nproc_per_node 1 --nnodes 1 ../train.py 61 | fi -------------------------------------------------------------------------------- /scripts/train_lora_int4.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export trainer_backend=pl 4 | 5 | train_file="../config/train_${trainer_backend}.yaml" 6 | 7 | # 强制覆盖配置文件 8 | export train_file=${train_file} 9 | export enable_deepspeed=false 10 | export enable_ptv2=false 11 | export enable_lora=true 12 | export load_in_bit=4 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | 16 | usage() { echo "Usage: $0 [-m ]" 1>&2; exit 1; } 17 | 18 | 19 | while getopts m: opt 20 | do 21 | case "${opt}" in 22 | m) mode=${OPTARG};; 23 | *) 24 | usage 25 | ;; 26 | esac 27 | done 28 | 29 | if [ "${mode}" != "dataset" ] && [ "${mode}" != "train" ] ; then 30 | usage 31 | fi 32 | 33 | if [[ "${mode}" == "dataset" ]] ; then 34 | python ../data_utils.py 35 | exit 0 36 | fi 37 | 38 | if [[ "${trainer_backend}" == "pl" ]] ; then 39 | # pl 多卡 修改配置文件 devices 40 | 41 | ### 多机多卡训练 42 | 43 | # 例子 3个机器 每个机器 4个卡 44 | # 修改train.py Trainer num_nodes = 3 45 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 46 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 47 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 48 | 49 | # pl 多卡 修改配置文件 devices 50 | 51 | python ../train.py 52 | elif [[ "${trainer_backend}" == "cl" ]] ; then 53 | # 多机多卡 54 | # colossalai run --nproc_per_node 1 --num_nodes 1 --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 55 | 56 | colossalai run --nproc_per_node 1 --num_nodes 1 ../train.py 57 | else 58 | # 多机多卡 59 | # --nproc_per_node=1 nnodes=1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 60 | torchrun --nproc_per_node 1 --nnodes 1 ../train.py 61 | fi -------------------------------------------------------------------------------- /scripts/train_lora_int8.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export trainer_backend=pl 4 | 5 | train_file="../config/train_${trainer_backend}.yaml" 6 | 7 | # 强制覆盖配置文件 8 | export train_file=${train_file} 9 | export enable_deepspeed=false 10 | export enable_ptv2=false 11 | export enable_lora=true 12 | export load_in_bit=8 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | 16 | usage() { echo "Usage: $0 [-m ]" 1>&2; exit 1; } 17 | 18 | 19 | while getopts m: opt 20 | do 21 | case "${opt}" in 22 | m) mode=${OPTARG};; 23 | *) 24 | usage 25 | ;; 26 | esac 27 | done 28 | 29 | if [ "${mode}" != "dataset" ] && [ "${mode}" != "train" ] 
; then 30 | usage 31 | fi 32 | 33 | if [[ "${mode}" == "dataset" ]] ; then 34 | python ../data_utils.py 35 | exit 0 36 | fi 37 | 38 | if [[ "${trainer_backend}" == "pl" ]] ; then 39 | # pl 多卡 修改配置文件 devices 40 | 41 | ### 多机多卡训练 42 | 43 | # 例子 3个机器 每个机器 4个卡 44 | # 修改train.py Trainer num_nodes = 3 45 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 46 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 47 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 48 | 49 | # pl 多卡 修改配置文件 devices 50 | 51 | python ../train.py 52 | elif [[ "${trainer_backend}" == "cl" ]] ; then 53 | # 多机多卡 54 | # colossalai run --nproc_per_node 1 --num_nodes 1 --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 55 | 56 | colossalai run --nproc_per_node 1 --num_nodes 1 ../train.py 57 | else 58 | # 多机多卡 59 | # --nproc_per_node=1 nnodes=1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 60 | torchrun --nproc_per_node 1 --nnodes 1 ../train.py 61 | fi -------------------------------------------------------------------------------- /scripts/train_ptv2.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | export trainer_backend=pl 4 | 5 | train_file="../config/train_${trainer_backend}.yaml" 6 | 7 | # 强制覆盖配置文件 8 | export train_file=${train_file} 9 | export enable_deepspeed=false 10 | export enable_ptv2=true 11 | export enable_lora=false 12 | export load_in_bit=0 13 | 14 | # export CUDA_VISIBLE_DEVICES=0,1,2,3 15 | 16 | usage() { echo "Usage: $0 [-m ]" 1>&2; exit 1; } 17 | 18 | 19 | while getopts m: opt 20 | do 21 | case "${opt}" in 22 | m) mode=${OPTARG};; 23 | *) 24 | usage 25 | ;; 26 | esac 27 | done 28 | 29 | if [ "${mode}" != "dataset" ] && [ "${mode}" != "train" ] ; then 30 | usage 31 | fi 32 | 33 | 34 | if [[ "${mode}" == "dataset" ]] ; then 35 | python ../data_utils.py 36 | exit 0 37 | fi 38 | 39 | if [[ "${trainer_backend}" == "pl" ]] ; then 40 | # pl 多卡 修改配置文件 devices 41 | 42 | ### 多机多卡训练 43 | 44 | # 例子 3个机器 每个机器 4个卡 45 | # 修改train.py Trainer num_nodes = 3 46 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=0 python train.py 47 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=1 python train.py 48 | # MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 WORLD_SIZE=12 NODE_RANK=2 python train.py 49 | 50 | # pl 多卡 修改配置文件 devices 51 | 52 | python ../train.py 53 | elif [[ "${trainer_backend}" == "cl" ]] ; then 54 | # 多机多卡 55 | # colossalai run --nproc_per_node 1 --num_nodes 1 --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 56 | 57 | colossalai run --nproc_per_node 1 --num_nodes 1 ../train.py 58 | else 59 | # 多机多卡 60 | # --nproc_per_node=1 nnodes=1 --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT ../train.py 61 | torchrun --nproc_per_node 1 --nnodes 1 ../train.py 62 | fi -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/10/12 10:50 4 | 5 | import os 6 | from config import global_args 7 | 8 | def main(): 9 | trainer_backend = global_args["trainer_backend"] 10 | if trainer_backend == "pl": 11 | from training.train_pl import main as main_execute 12 | elif trainer_backend == "hf": 13 | from training.train_hf import main as main_execute 14 | elif trainer_backend 
== "cl": 15 | from training.train_cl import main as main_execute 16 | elif trainer_backend == "ac": 17 | from training.train_ac import main as main_execute 18 | else: 19 | raise ValueError(f"{trainer_backend} NotImplemented ") 20 | 21 | main_execute() 22 | 23 | def _mp_fn(index): 24 | # For xla_spawn (TPUs) 25 | main() 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /training/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/10/12 16:33 4 | -------------------------------------------------------------------------------- /training/train_ac.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/9/25 12:29 4 | import sys 5 | import os 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 7 | 8 | import logging 9 | import math 10 | import datasets 11 | import torch 12 | import transformers 13 | from deep_training.trainer.ac.trainer import TrainerAC 14 | from transformers import ( 15 | HfArgumentParser, 16 | default_data_collator, 17 | set_seed, 18 | ) 19 | from transformers.trainer_utils import get_last_checkpoint 20 | from transformers.utils import check_min_version, send_example_telemetry 21 | from transformers.utils.versions import require_version 22 | from data_utils import NN_DataHelper, config_args, get_deepspeed_config, global_args 23 | from transformers import HfArgumentParser, PreTrainedTokenizer,PretrainedConfig 24 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments 25 | from deep_training.data_helper import ModelArguments, DataArguments,TrainingArgumentsAC 26 | 27 | assert global_args["trainer_backend"] == "ac" 28 | 29 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 30 | check_min_version("4.33.2") 31 | 32 | logger = logging.getLogger(__name__) 33 | 34 | # Setup logging 35 | logging.basicConfig( 36 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 37 | datefmt="%m/%d/%Y %H:%M:%S", 38 | handlers=[logging.StreamHandler(sys.stdout)], 39 | ) 40 | 41 | def main(): 42 | 43 | 44 | training_args: TrainingArgumentsAC 45 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsAC, DataArguments, PetlArguments, PromptArguments), 46 | conflict_handler='resolve') 47 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args,allow_extra_keys=True,) 48 | lora_args = lora_args.config 49 | prompt_args = prompt_args.config 50 | 51 | if training_args.should_log: 52 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 
53 | transformers.utils.logging.set_verbosity_info() 54 | 55 | log_level = training_args.get_process_log_level() 56 | logger.setLevel(log_level) 57 | datasets.utils.logging.set_verbosity(log_level) 58 | transformers.utils.logging.set_verbosity(log_level) 59 | transformers.utils.logging.enable_default_handler() 60 | transformers.utils.logging.enable_explicit_format() 61 | 62 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 63 | config_kwargs = {"torch_dtype": torch.float16} 64 | if global_args['config_merge']: 65 | config_kwargs.update(global_args['config_merge']) 66 | 67 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs) 68 | 69 | with training_args.main_process_first(desc="make_dataset_all"): 70 | dataHelper.make_dataset_all() 71 | 72 | is_bf16_supported = torch.cuda.is_bf16_supported() 73 | precision = global_args["precision"] 74 | if precision == "auto": 75 | # Precision: adjust to your actual setup 76 | if is_bf16_supported: 77 | precision = 'bf16' 78 | else: 79 | precision = '16' 80 | 81 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit: 82 | precision = "32" 83 | 84 | 85 | # Precision: adjust to your actual setup 86 | if precision == "bf16": 87 | config.bf16 = True 88 | config.fp16 = False 89 | config.fp32 = False 90 | elif precision == "16": 91 | config.bf16 = False 92 | config.fp16 = True 93 | config.fp32 = False 94 | elif precision == "32": 95 | config.bf16 = False 96 | config.fp16 = False 97 | config.fp32 = True 98 | else: 99 | raise NotImplementedError(precision) 100 | 101 | 102 | 103 | if str(precision) == '16': 104 | training_args.fp16 = True 105 | elif str(precision) == 'bf16': 106 | training_args.bf16 = True 107 | else: 108 | training_args.fp16 = False 109 | training_args.bf16 = False 110 | 111 | 112 | deepspeed_config = get_deepspeed_config(precision) 113 | if deepspeed_config: 114 | training_args.deepspeed = deepspeed_config 115 | 116 | # Log on each process the small summary: 117 | logger.warning( 118 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 119 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 120 | ) 121 | logger.info(f"Training/evaluation parameters {training_args}") 122 | 123 | # Detecting last checkpoint. 124 | last_checkpoint = None 125 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 126 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 127 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 128 | raise ValueError( 129 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 130 | "Use --overwrite_output_dir to overcome." 131 | ) 132 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 133 | logger.info( 134 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 135 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 136 | ) 137 | 138 | # Set seed before initializing model.
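# ---------------------------------------------------------------------------
# Illustrative sketch (an assumption summarising the precision rule above, not
# an API of this repo): "auto" resolves to bf16 when the GPU supports it,
# otherwise fp16, and 8-bit quantization falls back to fp32.
import torch

def _resolve_precision(precision: str, load_in_8bit: bool = False) -> str:
    if precision == "auto":
        precision = "bf16" if torch.cuda.is_bf16_supported() else "16"
    if load_in_8bit:
        precision = "32"
    if precision not in ("bf16", "16", "32"):
        raise NotImplementedError(precision)
    return precision
# e.g. _resolve_precision("auto") -> "bf16" on bf16-capable GPUs, "16" otherwise
# ---------------------------------------------------------------------------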
139 | set_seed(training_args.seed) 140 | 141 | world_size,local_rank,process_index = training_args.world_size,training_args.local_rank,training_args.process_index 142 | 143 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, 144 | prompt_args=prompt_args, 145 | quantization_config=global_args["quantization_config"], 146 | device_map={"": local_rank} if world_size > 1 else "auto", 147 | torch_dtype=torch.float16, 148 | # new_num_tokens=len(tokenizer), # 可能扩充词 149 | auto_prepare_kbit_training=True, 150 | use_gradient_checkpointing=False 151 | ) 152 | 153 | if transformer_args["quantization_config"] is None: 154 | transformer_args.pop("device_map") 155 | 156 | pl_model = MyTransformer(**transformer_args) 157 | 158 | config.save_pretrained(training_args.output_dir) 159 | 160 | # 加载sft权重 161 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True) 162 | 163 | # # Finetune 164 | # if config.bf16: 165 | # pl_model = pl_model.bfloat16() 166 | # else: 167 | # pl_model = pl_model.float() 168 | 169 | train_datasets = None 170 | if training_args.do_train: 171 | train_datasets = dataHelper.load_distributed_random_sampler( 172 | dataHelper.load_dataset_files()["train_files"], 173 | with_load_memory=data_args.data_backend == 'record', 174 | collate_fn=dataHelper.collate_fn, 175 | batch_size=training_args.per_device_train_batch_size, 176 | drop_last=training_args.dataloader_drop_last, # 多卡建议扔掉 177 | num_processes=world_size, process_index=process_index, 178 | num_workers = training_args.dataloader_num_workers, 179 | pin_memory = training_args.dataloader_pin_memory, 180 | ) 181 | 182 | 183 | 184 | # Initialize our Trainer 185 | trainer = TrainerAC( 186 | model=pl_model, 187 | args=training_args, 188 | train_dataset=train_datasets, 189 | tokenizer=tokenizer, 190 | # Data collator will default to DataCollatorWithPadding, so we change it. 
191 | data_collator=default_data_collator, 192 | ) 193 | 194 | # Training 195 | if training_args.do_train: 196 | checkpoint = None 197 | if training_args.resume_from_checkpoint is not None: 198 | checkpoint = training_args.resume_from_checkpoint 199 | elif last_checkpoint is not None: 200 | checkpoint = last_checkpoint 201 | trainer.train(resume_from_checkpoint=checkpoint) 202 | 203 | 204 | 205 | 206 | def _mp_fn(index): 207 | # For xla_spawn (TPUs) 208 | main() 209 | 210 | 211 | if __name__ == "__main__": 212 | main() 213 | -------------------------------------------------------------------------------- /training/train_cl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/9/25 12:29 4 | import sys 5 | import os 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 7 | 8 | import logging 9 | import math 10 | from contextlib import nullcontext 11 | import datasets 12 | import torch 13 | import transformers 14 | from deep_training.trainer.cl.trainer import TrainerCL 15 | from transformers import ( 16 | HfArgumentParser, 17 | default_data_collator, 18 | set_seed, 19 | ) 20 | from transformers.trainer_utils import get_last_checkpoint 21 | from transformers.utils import check_min_version, send_example_telemetry 22 | from transformers.utils.versions import require_version 23 | from data_utils import NN_DataHelper, config_args, get_deepspeed_config, global_args 24 | from transformers import HfArgumentParser, PreTrainedTokenizer,PretrainedConfig 25 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments 26 | from deep_training.data_helper import ModelArguments, DataArguments,TrainingArgumentsCL 27 | 28 | assert global_args["trainer_backend"] == "cl" 29 | 30 | # Will error if the minimal version of Transformers is not installed. Remove at your own risks. 
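# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this file): the resume rule used by the
# training blocks in these scripts — an explicit resume_from_checkpoint wins,
# otherwise the newest checkpoint-* folder in output_dir is used, else None.
import os
from typing import Optional
from transformers.trainer_utils import get_last_checkpoint

def _pick_checkpoint(output_dir: str, resume_from_checkpoint: Optional[str]) -> Optional[str]:
    if resume_from_checkpoint is not None:
        return resume_from_checkpoint
    if os.path.isdir(output_dir):
        return get_last_checkpoint(output_dir)  # newest "checkpoint-<step>" dir, or None
    return None
# ---------------------------------------------------------------------------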
31 | check_min_version("4.33.2") 32 | 33 | 34 | logger = logging.getLogger(__name__) 35 | 36 | # Setup logging 37 | logging.basicConfig( 38 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 39 | datefmt="%m/%d/%Y %H:%M:%S", 40 | handlers=[logging.StreamHandler(sys.stdout)], 41 | ) 42 | 43 | def main(): 44 | 45 | world_size, local_rank, process_index = int(os.environ.get("WORLD_SIZE", 1)), int( 46 | os.environ.get("LOCAL_RANK", 0)), int(os.environ.get("RANK", 0)) 47 | 48 | 49 | training_args: TrainingArgumentsCL 50 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsCL, DataArguments, PetlArguments, PromptArguments), 51 | conflict_handler='resolve') 52 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args,allow_extra_keys=True,) 53 | lora_args = lora_args.config 54 | prompt_args = prompt_args.config 55 | 56 | 57 | 58 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 59 | config_kwargs = {"torch_dtype": torch.float16} 60 | if global_args['config_merge']: 61 | config_kwargs.update(global_args['config_merge']) 62 | 63 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs) 64 | 65 | if process_index == 0: 66 | dataHelper.make_dataset_all() 67 | 68 | is_bf16_supported = torch.cuda.is_bf16_supported() 69 | precision = global_args["precision"] 70 | if precision == "auto": 71 | # 精度 根据实际情况做调整 72 | if is_bf16_supported: 73 | precision = 'bf16' 74 | else: 75 | precision = '16' 76 | 77 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit: 78 | precision = "32" 79 | 80 | # 精度 根据实际情况做调整 81 | if precision == "bf16": 82 | config.bf16 = True 83 | config.fp16 = False 84 | config.fp32 = False 85 | elif precision == "16": 86 | config.bf16 = False 87 | config.fp16 = True 88 | config.fp32 = False 89 | elif precision == "32": 90 | config.bf16 = False 91 | config.fp16 = False 92 | config.fp32 = True 93 | else: 94 | raise NotImplemented 95 | 96 | if str(precision) == '16': 97 | training_args.fp16 = True 98 | elif str(precision) == 'bf16': 99 | training_args.bf16 = True 100 | else: 101 | training_args.fp16 = False 102 | training_args.bf16 = False 103 | 104 | 105 | # Log on each process the small summary: 106 | logger.warning( 107 | f"Process rank: {training_args.local_rank}" 108 | + f"16-bits training: {training_args.fp16}" 109 | ) 110 | logger.info(f"Training/evaluation parameters {training_args}") 111 | 112 | # Detecting last checkpoint. 113 | last_checkpoint = None 114 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 115 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 116 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 117 | raise ValueError( 118 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 119 | "Use --overwrite_output_dir to overcome." 120 | ) 121 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 122 | logger.info( 123 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 124 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 125 | ) 126 | 127 | # Set seed before initializing model. 
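# ---------------------------------------------------------------------------
# Illustrative sketch (assumption, not this repo's API): the "main process
# first" idea behind make_dataset_all() in these scripts — the hf/ac backends
# wrap it in training_args.main_process_first(), while this cl backend only
# lets rank 0 build the cache. A simplified version with an explicit barrier:
import torch.distributed as dist

def _main_process_first(build_fn, rank: int) -> None:
    if dist.is_available() and dist.is_initialized():
        if rank == 0:
            build_fn()      # rank 0 writes the cached dataset files
        dist.barrier()      # other ranks wait until the cache exists
        if rank != 0:
            build_fn()      # fast path: the cache is already on disk
    else:
        build_fn()          # single-process run
# ---------------------------------------------------------------------------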
128 | set_seed(training_args.seed) 129 | 130 | 131 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, 132 | prompt_args=prompt_args, 133 | quantization_config=global_args["quantization_config"], 134 | device_map={"": local_rank} if world_size > 1 else "auto", 135 | torch_dtype=torch.float16, 136 | # new_num_tokens=len(tokenizer), # 可能扩充词 137 | auto_prepare_kbit_training=True, 138 | use_gradient_checkpointing=False 139 | ) 140 | 141 | if transformer_args["quantization_config"] is None: 142 | transformer_args.pop("device_map") 143 | 144 | with nullcontext(): 145 | pl_model = MyTransformer(**transformer_args) 146 | 147 | config.save_pretrained(training_args.output_dir) 148 | 149 | # 加载sft权重 150 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True) 151 | 152 | # # Finetune 153 | # if config.bf16: 154 | # pl_model = pl_model.bfloat16() 155 | # else: 156 | # pl_model = pl_model.float() 157 | 158 | train_datasets = None 159 | if training_args.do_train: 160 | train_datasets = dataHelper.load_distributed_random_sampler( 161 | dataHelper.load_dataset_files()["train_files"], 162 | with_load_memory=data_args.data_backend == 'record', 163 | collate_fn=dataHelper.collate_fn, 164 | batch_size=training_args.per_device_train_batch_size, 165 | drop_last=training_args.dataloader_drop_last, # 多卡建议扔掉 166 | num_processes=world_size, process_index=process_index, 167 | num_workers = training_args.dataloader_num_workers, 168 | pin_memory = training_args.dataloader_pin_memory, 169 | ) 170 | 171 | 172 | 173 | # Initialize our Trainer 174 | trainer = TrainerCL( 175 | model=pl_model, 176 | args=training_args, 177 | train_dataset=train_datasets, 178 | tokenizer=tokenizer, 179 | # Data collator will default to DataCollatorWithPadding, so we change it. 
180 | data_collator=default_data_collator, 181 | ) 182 | 183 | # Training 184 | if training_args.do_train: 185 | checkpoint = None 186 | if training_args.resume_from_checkpoint is not None: 187 | checkpoint = training_args.resume_from_checkpoint 188 | elif last_checkpoint is not None: 189 | checkpoint = last_checkpoint 190 | trainer.train(resume_from_checkpoint=checkpoint) 191 | 192 | 193 | 194 | 195 | def _mp_fn(index): 196 | # For xla_spawn (TPUs) 197 | main() 198 | 199 | 200 | if __name__ == "__main__": 201 | main() 202 | -------------------------------------------------------------------------------- /training/train_hf.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author : ssbuild 3 | # @Time : 2023/9/25 12:29 4 | import sys 5 | import os 6 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 7 | 8 | 9 | import logging 10 | import math 11 | import datasets 12 | import torch 13 | import transformers 14 | from deep_training.trainer.hf.trainer import TrainerHF 15 | from transformers import ( 16 | HfArgumentParser, 17 | default_data_collator, 18 | set_seed, 19 | ) 20 | from transformers.trainer_utils import get_last_checkpoint 21 | from transformers.utils import check_min_version, send_example_telemetry 22 | from transformers.utils.versions import require_version 23 | from data_utils import NN_DataHelper, config_args, get_deepspeed_config, global_args 24 | from transformers import HfArgumentParser, PreTrainedTokenizer,PretrainedConfig 25 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments, LoraConfig, PromptArguments 26 | from deep_training.data_helper import ModelArguments, DataArguments,TrainingArgumentsHF 27 | 28 | assert global_args["trainer_backend"] == "hf" 29 | 30 | # Will error if the minimal version of Transformers is not installed. Remove at your own risk. 31 | check_min_version("4.33.2") 32 | 33 | logger = logging.getLogger(__name__) 34 | 35 | # Setup logging 36 | logging.basicConfig( 37 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 38 | datefmt="%m/%d/%Y %H:%M:%S", 39 | handlers=[logging.StreamHandler(sys.stdout)], 40 | ) 41 | 42 | def main(): 43 | # setup_model_profile()  # not imported in this script; calling it would raise a NameError 44 | training_args: TrainingArgumentsHF 45 | parser = HfArgumentParser((ModelArguments, TrainingArgumentsHF, DataArguments, PetlArguments, PromptArguments), 46 | conflict_handler='resolve') 47 | model_args, training_args, data_args, lora_args, prompt_args = parser.parse_dict(config_args,allow_extra_keys=True,) 48 | lora_args = lora_args.config 49 | prompt_args = prompt_args.config 50 | 51 | if training_args.should_log: 52 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
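# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this file): the device_map rule these
# scripts use when a quantization_config is present. With several processes,
# each rank pins the whole quantized model to its own GPU via {"": local_rank};
# a single process lets accelerate place it with "auto". Without quantization
# the device_map key is popped from transformer_args entirely.
from typing import Dict, Union

def _device_map_for(world_size: int, local_rank: int) -> Union[str, Dict[str, int]]:
    return {"": local_rank} if world_size > 1 else "auto"
# e.g. _device_map_for(4, 2) -> {"": 2}; _device_map_for(1, 0) -> "auto"
# ---------------------------------------------------------------------------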
53 | transformers.utils.logging.set_verbosity_info() 54 | 55 | log_level = training_args.get_process_log_level() 56 | logger.setLevel(log_level) 57 | datasets.utils.logging.set_verbosity(log_level) 58 | transformers.utils.logging.set_verbosity(log_level) 59 | transformers.utils.logging.enable_default_handler() 60 | transformers.utils.logging.enable_explicit_format() 61 | 62 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 63 | config_kwargs = {"torch_dtype": torch.float16} 64 | if global_args['config_merge']: 65 | config_kwargs.update(global_args['config_merge']) 66 | 67 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs) 68 | 69 | with training_args.main_process_first(desc="make_dataset_all"): 70 | dataHelper.make_dataset_all() 71 | 72 | is_bf16_supported = torch.cuda.is_bf16_supported() 73 | precision = global_args["precision"] 74 | if precision == "auto": 75 | # 精度 根据实际情况做调整 76 | if is_bf16_supported: 77 | precision = 'bf16' 78 | else: 79 | precision = '16' 80 | 81 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit: 82 | precision = "32" 83 | 84 | # 精度 根据实际情况做调整 85 | if precision == "bf16": 86 | config.bf16 = True 87 | config.fp16 = False 88 | config.fp32 = False 89 | elif precision == "16": 90 | config.bf16 = False 91 | config.fp16 = True 92 | config.fp32 = False 93 | elif precision == "32": 94 | config.bf16 = False 95 | config.fp16 = False 96 | config.fp32 = True 97 | else: 98 | raise NotImplemented 99 | 100 | if str(precision) == '16': 101 | training_args.fp16 = True 102 | elif str(precision) == 'bf16': 103 | training_args.bf16 = True 104 | else: 105 | training_args.fp16 = False 106 | training_args.bf16 = False 107 | 108 | deepspeed_config = get_deepspeed_config(precision) 109 | if deepspeed_config: 110 | training_args.deepspeed = deepspeed_config 111 | 112 | # Log on each process the small summary: 113 | logger.warning( 114 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}" 115 | + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" 116 | ) 117 | logger.info(f"Training/evaluation parameters {training_args}") 118 | 119 | # Detecting last checkpoint. 120 | last_checkpoint = None 121 | if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir: 122 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 123 | if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: 124 | raise ValueError( 125 | f"Output directory ({training_args.output_dir}) already exists and is not empty. " 126 | "Use --overwrite_output_dir to overcome." 127 | ) 128 | elif last_checkpoint is not None and training_args.resume_from_checkpoint is None: 129 | logger.info( 130 | f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " 131 | "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." 132 | ) 133 | 134 | # Set seed before initializing model. 
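# ---------------------------------------------------------------------------
# Illustrative sketch: training_args.deepspeed (set above when
# get_deepspeed_config returns something) accepts either a path to a JSON file
# or a config dict. A minimal ZeRO-2 dict might look like the one below; this
# is a generic example, not the contents of config/deepspeed.yaml.
_example_deepspeed_config = {
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,
        "contiguous_gradients": True,
    },
    "bf16": {"enabled": "auto"},  # follows the precision chosen above
    "fp16": {"enabled": "auto"},
}
# training_args.deepspeed = _example_deepspeed_config
# ---------------------------------------------------------------------------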
135 | set_seed(training_args.seed) 136 | 137 | world_size,local_rank,process_index = training_args.world_size,training_args.local_rank,training_args.process_index 138 | 139 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, 140 | prompt_args=prompt_args, 141 | quantization_config=global_args["quantization_config"], 142 | device_map={"": local_rank} if world_size > 1 else "auto", 143 | torch_dtype=torch.float16, 144 | # new_num_tokens=len(tokenizer), # 可能扩充词 145 | auto_prepare_kbit_training=True, 146 | use_gradient_checkpointing=False 147 | ) 148 | 149 | if transformer_args["quantization_config"] is None: 150 | transformer_args.pop("device_map") 151 | 152 | pl_model = MyTransformer(**transformer_args) 153 | 154 | config.save_pretrained(training_args.output_dir) 155 | 156 | # 加载sft权重 157 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True) 158 | 159 | # # Finetune 160 | # if config.bf16: 161 | # pl_model = pl_model.bfloat16() 162 | # else: 163 | # pl_model = pl_model.float() 164 | 165 | train_datasets = None 166 | if training_args.do_train: 167 | train_datasets = dataHelper.load_distributed_random_sampler( 168 | dataHelper.load_dataset_files()["train_files"], 169 | with_load_memory=data_args.data_backend == 'record', 170 | collate_fn=dataHelper.collate_fn, 171 | batch_size=training_args.per_device_train_batch_size, 172 | drop_last=training_args.dataloader_drop_last, # 多卡建议扔掉 173 | num_processes=world_size, process_index=process_index, 174 | num_workers = training_args.dataloader_num_workers, 175 | pin_memory = training_args.dataloader_pin_memory, 176 | ) 177 | 178 | 179 | 180 | # Initialize our Trainer 181 | trainer = TrainerHF( 182 | model=pl_model, 183 | args=training_args, 184 | train_dataset=train_datasets, 185 | tokenizer=tokenizer, 186 | # Data collator will default to DataCollatorWithPadding, so we change it. 
187 | data_collator=default_data_collator, 188 | ) 189 | 190 | # Training 191 | if training_args.do_train: 192 | checkpoint = None 193 | if training_args.resume_from_checkpoint is not None: 194 | checkpoint = training_args.resume_from_checkpoint 195 | elif last_checkpoint is not None: 196 | checkpoint = last_checkpoint 197 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 198 | trainer.save_model() # Saves the tokenizer too for easy upload 199 | 200 | metrics = train_result.metrics 201 | metrics["train_samples"] = len(train_datasets) 202 | trainer.log_metrics("train", metrics) 203 | trainer.save_metrics("train", metrics) 204 | trainer.save_state() 205 | 206 | 207 | 208 | 209 | def _mp_fn(index): 210 | # For xla_spawn (TPUs) 211 | main() 212 | 213 | 214 | if __name__ == "__main__": 215 | main() 216 | -------------------------------------------------------------------------------- /training/train_pl.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import sys 3 | import os 4 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),'..'))) 5 | 6 | import logging 7 | import torch 8 | from deep_training.data_helper import ModelArguments, DataArguments, TrainingArguments 9 | from deep_training.trainer.pl.modelcheckpoint import ModelCheckpointEx 10 | from lightning import Trainer 11 | from lightning.pytorch.callbacks import LearningRateMonitor 12 | from lightning.pytorch.strategies import DeepSpeedStrategy 13 | from transformers import HfArgumentParser 14 | from data_utils import NN_DataHelper, config_args, get_deepspeed_config,global_args 15 | from transformers import HfArgumentParser, PreTrainedTokenizer,PretrainedConfig 16 | from deep_training.zoo.model_zoo.llm.llm_model import MyTransformer, PetlArguments,PromptArguments 17 | 18 | assert global_args["trainer_backend"] == "pl" 19 | 20 | def main(): 21 | parser = HfArgumentParser((ModelArguments, TrainingArguments, DataArguments, PetlArguments,PromptArguments)) 22 | model_args, training_args, data_args, lora_args,prompt_args = parser.parse_dict(config_args) 23 | lora_args = lora_args.config 24 | prompt_args = prompt_args.config 25 | 26 | output_weight_dir = './best_ckpt' 27 | 28 | dataHelper = NN_DataHelper(model_args, training_args, data_args) 29 | config_kwargs = {"torch_dtype": torch.float16} 30 | if global_args["config_merge"]: 31 | config_kwargs.update(global_args["config_merge"]) 32 | tokenizer, config, _, _ = dataHelper.load_tokenizer_and_config(config_kwargs=config_kwargs) 33 | 34 | 35 | 36 | dataHelper.make_dataset_all() 37 | 38 | 39 | is_bf16_supported = torch.cuda.is_bf16_supported() 40 | precision = global_args["precision"] 41 | if precision == "auto": 42 | # 精度 根据实际情况做调整 43 | if is_bf16_supported: 44 | precision = 'bf16' 45 | else: 46 | precision = '16' 47 | 48 | if global_args["quantization_config"] is not None and global_args["quantization_config"].load_in_8bit: 49 | precision = "32" 50 | 51 | # 精度 根据实际情况做调整 52 | if precision == "bf16": 53 | config.bf16 = True 54 | config.fp16 = False 55 | config.fp32 = False 56 | elif precision == "16": 57 | config.bf16 = False 58 | config.fp16 = True 59 | config.fp32 = False 60 | elif precision == "32": 61 | config.bf16 = False 62 | config.fp16 = False 63 | config.fp32 = True 64 | else: 65 | raise NotImplemented 66 | 67 | 68 | deepspeed_config = get_deepspeed_config(precision) 69 | strategy = 'ddp' if torch.cuda.device_count() > 1 else 'auto' 70 | if deepspeed_config is not None and 
len(deepspeed_config): 71 | strategy = DeepSpeedStrategy(config=deepspeed_config, ) 72 | 73 | checkpoint_callback = ModelCheckpointEx( 74 | # monitor='loss', 75 | dirpath=output_weight_dir, 76 | save_weights_only=True, 77 | save_last=True, 78 | # every_n_train_steps=2000 // training_args.gradient_accumulation_steps, 79 | every_n_epochs=1, 80 | lora_args=lora_args, 81 | # monitor="loss", mode="min", save_top_k=10  # keep the 10 best checkpoints by loss 82 | monitor="step", mode="max", 83 | save_top_k=10, # keep the last 10 checkpoints by step 84 | ) 85 | 86 | 87 | trainer = Trainer( 88 | callbacks=[checkpoint_callback,LearningRateMonitor(logging_interval='step')], 89 | max_epochs=training_args.max_epochs, 90 | max_steps=training_args.max_steps, 91 | accelerator="gpu", 92 | devices=data_args.devices, 93 | enable_progress_bar=True, 94 | default_root_dir=data_args.output_dir, 95 | gradient_clip_val=training_args.max_grad_norm, 96 | accumulate_grad_batches=training_args.gradient_accumulation_steps, 97 | num_sanity_val_steps=0, 98 | strategy=strategy, 99 | #lora int8 precision='32' 100 | # you can also try "32": "32-true", "16": "16-mixed", "bf16": "bf16-mixed" 101 | precision=precision, 102 | 103 | ) 104 | 105 | transformer_args = dict(config=config, model_args=model_args, training_args=training_args, lora_args=lora_args, 106 | prompt_args=prompt_args, 107 | quantization_config=global_args["quantization_config"], 108 | device_map={"": trainer.local_rank} if trainer.world_size > 1 else "auto", 109 | torch_dtype=torch.float16, 110 | # new_num_tokens=len(tokenizer), # the vocabulary may be extended (plus some hidden tokens); comment out if not needed 111 | auto_prepare_kbit_training=True, 112 | use_gradient_checkpointing=False 113 | ) 114 | 115 | if transformer_args["quantization_config"] is None: 116 | transformer_args.pop("device_map") 117 | 118 | pl_model = MyTransformer(**transformer_args) 119 | 120 | config.save_pretrained(output_weight_dir) 121 | 122 | # Resume training from saved weights 123 | # pl_model.load_sft_weight('./best_ckpt/best.pt',is_trainable=True) 124 | 125 | 126 | # # Finetune 127 | # if config.bf16: 128 | # pl_model = pl_model.bfloat16() 129 | # else: 130 | # pl_model = pl_model.float() 131 | 132 | 133 | def dataset_loader_filter_fn(dataset): 134 | print('*' * 30, 'total', len(dataset)) 135 | return dataset 136 | 137 | 138 | train_datasets = dataHelper.load_distributed_random_sampler( 139 | dataHelper.load_dataset_files()["train_files"], 140 | with_load_memory=data_args.data_backend == 'record', 141 | collate_fn=dataHelper.collate_fn, 142 | batch_size=training_args.train_batch_size, 143 | drop_last=True, # recommended with multiple GPUs 144 | num_processes=trainer.world_size, process_index=trainer.global_rank, 145 | dataset_loader_filter_fn=dataset_loader_filter_fn, 146 | num_workers=0 147 | ) 148 | 149 | if train_datasets is not None: 150 | trainer.fit(pl_model, train_dataloaders=train_datasets) 151 | 152 | def _mp_fn(index): 153 | # For xla_spawn (TPUs) 154 | main() 155 | 156 | 157 | if __name__ == "__main__": 158 | main() 159 | --------------------------------------------------------------------------------
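Note: the multi-node commands in the scripts above assume the Lightning (pl) backend picks up MASTER_ADDR, MASTER_PORT and NODE_RANK from the environment, one launch per machine. A minimal sketch of the matching Trainer setup, illustrative only (the real values come from the config files and train_pl.py above):

# Sketch: 3 machines with 4 GPUs each, launched once per machine as in the
# script comments, e.g. MASTER_ADDR=10.0.0.1 MASTER_PORT=6667 NODE_RANK=0 python train.py
from lightning import Trainer

trainer = Trainer(
    accelerator="gpu",
    devices=4,      # GPUs per machine
    num_nodes=3,    # machines participating in the run
    strategy="ddp",
)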