├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── datasets ├── expand_vocab │ ├── common-char-3500.txt │ └── common-word-6000.txt ├── finetune │ └── example │ │ ├── eval │ │ └── example.json │ │ ├── test │ │ └── example.json │ │ └── train │ │ └── example.json ├── pretrain │ └── example │ │ └── train │ │ └── tianlongbabu.txt └── rm │ └── example │ ├── eval │ └── example.json │ ├── test │ └── example.json │ └── train │ └── example.json ├── deepspeed_configs ├── zero_stage2_config.json └── zero_stage3_config.json ├── engines ├── __init__.py ├── data.py ├── models.py ├── predict.py ├── train.py └── utils │ ├── __init__.py │ ├── check_load_config.py │ ├── convert_rwkv_to_hf.py │ ├── cpm_quantizer.py │ ├── dispatch_to_multi_gpu.py │ ├── expand_vocab.py │ ├── logits_process.py │ ├── metrics.py │ ├── parse_text.py │ ├── print_parameters.py │ ├── prompt_template.py │ └── trainer.py ├── main.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | .idea/ 161 | checkpoint/ 162 | workspace* 163 | .vscode 164 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 
11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. 
Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. 
In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. 
We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLMs Tool 2 | ![Author](https://img.shields.io/badge/Author-stanleylsx-red.svg) 3 | [![License Apache 2.0](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](LICENSE) 4 | ![python_version](https://img.shields.io/badge/Python-3.10%2B-green.svg) 5 | [![torch_version](https://img.shields.io/badge/torch-2.0%2B-pink.svg)](requirements.txt) 6 | 7 | 8 | ## Introduction 9 | 一个基于🤗[HuggingFace](https://huggingface.co/)开发的大语言模型训练、测试工具。支持不同模型的webui、终端预测,支持各模型的低参数量及全参数模型的预训练、奖励模型训练以及RLHF训练(PPO和DPO两种方法)。同时支持deepspeed分布式训练。 10 | 11 | 作者习惯于把配置和要做的事情都写在一个配置文件里面,然后以一个主函数作为入口直接运行,所以才有了这个项目,喜欢指令的朋友们可以改回去使用。 12 | 13 | 14 | ## Updates 15 | Date| Detail 16 | :---|--- 17 | 2023-10-30|通过attention_sinks支持[StreamingLLM](https://arxiv.org/abs/2309.17453) 18 | 2023-10-25|基于sentencepiece实现词表扩充功能 19 | 2023-10-24|支持使用[NEFTune](https://arxiv.org/abs/2310.05914)对LLM进行noise tune 20 | 2023-10-09|增加扩充词表后Embedding初始化方式 21 | 2023-10-08|LLama和Falcon两类模型支持Flash Attention2 22 | 
2023-09-26|支持模型预训练 23 | 2023-09-11|多轮对话的[Firefly的loss](https://mp.weixin.qq.com/s/nhogoWnzl3nrs_77r38_UA)训练函数集成 24 | 2023-09-04|支持部分可以从配置修改使用NTK的模型 25 | 2023-08-24|支持deepspeed-ZeRo2分布式训练 26 | 2023-08-23|RLHF的DPO方法对各个模型的训练支持 27 | 2023-08-21|RLHF的PPO方法对各个模型的训练支持 28 | 2023-08-08|奖励模型训练 29 | 2023-07-25|初始仓库 30 | 31 | ## Requirement 32 | 几个重要环境: 33 | * python:3.10+ 34 | * torch:2.0.1+ 35 | * bitsandbytes:不同操作系统下需要对应安装不同的包(Linux下0.39.0+,Windows下要专门下载对应的wheel本地安装) 36 | 37 | 其它环境见requirements.txt 38 | 目前FlashAttention作者未主动兼容和测试Windows操作环境[issues](https://github.com/Dao-AILab/flash-attention/issues/565),若在Windows上则无需安装flash-attn这个包。 39 | 40 | ## Feature 41 | 42 | ### Supported models 43 | 大模型经过SFT(然后做RLHF)之后可用于对话任务Chat,面世的Chat大部分都没有重新训练基座,或者是基于同样的基座结构用数据重新预训练了一个基座,下表是验证过的被此项目支持的基座,相应的也支持同样结构的衍生和Chat模型。 44 | 45 | Model | Scale | Series 46 | :--------|--------------|-------- 47 | ChatGLM1 | 6B |[chatglm1](https://huggingface.co/THUDM/chatglm-6b) 48 | ChatGLM2 | 6B |[chatglm2](https://huggingface.co/THUDM/chatglm2-6b) 49 | ChatGLM3 | 6B |[chatglm3](https://huggingface.co/THUDM/chatglm3-6b) 50 | Qwen | 1.8B、7B、14B |[Qwen](https://huggingface.co/Qwen) 51 | Bloom | 560M、9B、7B1M|[bloom](https://huggingface.co/bigscience/bloom)、[bloomz](https://huggingface.co/bigscience/bloomz) 52 | LLama1 | 3B、7B、13B |[openllama](https://huggingface.co/openlm-research)、[chinese-alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca)、[ziya](https://huggingface.co/IDEA-CCNL) 53 | LLama2 | 7B、13B |[llama2](https://huggingface.co/meta-llama)、[orca-2](https://huggingface.co/microsoft/Orca-2-7b) 54 | Baichuan | 7B、13B |[baichuan](https://huggingface.co/baichuan-inc) 55 | Baichuan2| 7B、13B |[baichuan2](https://huggingface.co/baichuan-inc) 56 | Falcon | 7B |[falcon](https://huggingface.co/tiiuae/falcon-7b)、[Linly](https://huggingface.co/Linly-AI) 57 | Aquila | 7B |[aquila](https://huggingface.co/BAAI) 58 | Aquila2 | 7B |[aquila](https://huggingface.co/BAAI) 59 | InternLM | 7B、20B 
|[internlm](https://huggingface.co/internlm) 60 | MOSS | 16B |[MOSS](https://huggingface.co/fnlp) 61 | XVERSE | 13B |[XVERSE](https://huggingface.co/xverse/XVERSE-13B-Chat) 62 | Mistral | 7B |[Mistral](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1) 63 | Yi | 6B |[Yi](https://huggingface.co/01-ai/Yi-6B-Chat) 64 | 65 | * 未进入上表的模型或参数规模暂时没有使用该项目测试过。 66 | 67 | ### Template Prompt 68 | 因为很多训练者都是基于上述的基座模型或者Chat模型继续训练,但是它们采用了不同的template prompt,所以下载相关的模型后,需要根据这些模型的要求新加入和它们相适配的template prompt,除了加载这些模型官方需要的template prompt外,本项目还给了一些template prompt,比如ziya、openbuddy等等的模板。 69 | 70 | Template Prompt|Website 71 | :--------------|--------- 72 | chatglm | [chatglm2](https://huggingface.co/THUDM/chatglm2-6b) 73 | chatglm3 | [chatglm3](https://huggingface.co/THUDM/chatglm3-6b) 74 | alpaca | [Chinese-LLaMA-Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) 75 | vicuna | [Chinese-Vicuna](https://github.com/Facico/Chinese-Vicuna) 76 | belle | [BELLE](https://github.com/LianjiaTech/BELLE) 77 | ziya | [Ziya](https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1) 78 | aquila | [AquilaChat](https://huggingface.co/BAAI/AquilaChat-7B) 79 | firefly | [Firefly](https://github.com/yangjianxin1/Firefly) 80 | openbuddy | [OpenBuddy](https://huggingface.co/OpenBuddy) 81 | internlm | [Internlm](https://huggingface.co/internlm) 82 | baichuan | [Baichuan](https://huggingface.co/baichuan-inc/Baichuan-13B-Chat) 83 | baichuan2 | [Baichuan2](https://github.com/baichuan-inc/Baichuan2) 84 | qwen | [Qwen](https://github.com/QwenLM/Qwen-7B) 85 | moss | [MOSS](https://github.com/OpenLMLab/MOSS) 86 | linksoul | [LinkSoul](https://huggingface.co/LinkSoul) 87 | xverse | [XVERSE](https://huggingface.co/xverse/XVERSE-13B-Chat) 88 | tigerbot | [TigerBot](https://github.com/TigerResearch/TigerBot) 89 | flagalpha | [FlagAlpha](https://github.com/FlagAlpha/Llama2-Chinese) 90 | orca-2 | [Orca-2](https://huggingface.co/microsoft/Orca-2-7b) 91 | yi | [yi](https://huggingface.co/01-ai/Yi-6B-Chat) 92 | 93 | ### 
Training methods 94 | 95 | Method |Supported| 96 | :-------------|---------| 97 | Full Parameter| ✅ | 98 | Lora | ✅ | 99 | AdaLora | ✅ | 100 | QLora | ✅ | 101 | Prompt Tuning | ✅ | 102 | P Tuning | ✅ | 103 | Prefix Tuning | ✅ | 104 | 105 | * 使用Lora和AdaLora都支持QLora训练,但是量化方式需要选择基于bitsandbytes的bnb量化方式,可支持4bit和8bit量化训练。以下是开启Qlora训练的是必要配置参数(ModelArguments中): 106 | ``` 107 | quantization: Optional[str] = field( 108 | default='bnb', 109 | metadata={ 110 | # 如果使用qlora只能选择bnb,两种量化方式区别不大。 111 | 'help': 'The specific model version to use (can be a branch name, tag name or commit id).', 112 | 'choices': ['cpm', 'bnb'], 113 | } 114 | ) 115 | quantization_bit: Optional[int] = field( 116 | default=None, 117 | metadata={ 118 | # 使用8bit量化还是4bit量化? 119 | 'help': 'The number of bits to quantize the model.', 120 | 'choices': [4, 8], 121 | } 122 | ) 123 | ``` 124 | 125 | ### Quantization 126 | 127 | 两种量化方式分别为基于bitsandbytes的bnb和cpm_kernels组件的cpm,其中cpm量化脚本来自[quantization.py](https://huggingface.co/THUDM/chatglm2-6b/blob/main/quantization.py)。 128 | 129 | ### Metric 130 | 不同训练阶段跑测试集时会输出下面一些常规的生成模型评估结果,结果仅限参考,大模型的事实性评估目前没有更好的办法,都是各个模型出品方或评测机构在各维度上制作数据集做评测,相对比较主观。 131 | 132 | Metric |Supported| Training Stage | 133 | :-------|---------|---------------------| 134 | Rouge-1 | ✅ |SFT Training | 135 | Rouge-2 | ✅ |SFT Training | 136 | Rouge-l | ✅ |SFT Training | 137 | ppl | ✅ |Pretrain、SFT Training| 138 | accuracy| ✅ |PPO-RM Training | 139 | 140 | ## Getting start 141 | 开始之前,需要确定试验的模型,并把整个模型文件从huggingface上下载下来,完成两步: 142 | 1. 
在ModelArguments中配置好model_type和model_path两个参数,如果除了model_path的基座模型外还有adapter模型,则需将adapter模型的地址配置到checkpoint_dir中。 143 | 144 | ``` 145 | model_type: str = field( 146 | default='internlm', 147 | metadata={ 148 | # 模型类型 149 | 'help': 'Model type.', 150 | 'choices': ['chatglm', 'qwen', 'llama', 'falcon', 'baichuan', 'aquila', 'internlm', 'moss', 'bloom', 'rwkv'], 151 | } 152 | ) 153 | model_path: str = field( 154 | default='/home/XXXXX/llm_models/internLM/intern-chat-7b', 155 | metadata={ 156 | # 从huggingface.co/models上下载的模型保存到本地的路径。 157 | 'help': 'Local path to pretrained model or model identifier from huggingface.co/models.' 158 | } 159 | ) 160 | checkpoint_dir: Optional[str] = field( 161 | default=None, 162 | metadata={ 163 | # 保存下载的或者自己训练的adapter增量模型的地方。 164 | 'help': 'Path to save the (delta) model checkpoints as well as the configurations automatically.', 165 | } 166 | ) 167 | ``` 168 | 2. 在DataTrainingArguments中修改prompt_template使用和该模型配套的template,这个template一般是SFT之后的模型才会有,且与训练者有关。所以如果该项目未提供的,则需要自己修改engines/utils/prompt_template.py文件,添加新的template。 169 | ``` 170 | prompt_template: Optional[str] = field( 171 | default='internlm', 172 | metadata={ 173 | # 选择对应模型的模板prompt,一般Chat模型的出品方都会有一个固定的prompt。 174 | 'help': 'Which template to use for constructing prompts in training and inference.' 
175 | } 176 | ) 177 | ``` 178 | 179 | ### Inference 180 | 此处提供两种预测方式,分别是基于gradio的webUI预测和终端预测。需要在config.py中对应修改mode,然后运行main.py。 181 | 182 | Mode | Inference Type | 183 | :-----------------|----------------| 184 | web_inference | WebUI | 185 | terminal_inference| Terminal | 186 | 187 | * 预测的时候,模型会优先从你定义的ModelArguments中的checkpoint_dir读取,如果该文件下没有参数文件,则从TrainingArguments的output_dir文件夹加载,如果都没有则只加载最初的基座模型。 188 | 189 | #### NTK 190 | 目前原生的config就能支持NTK方法的有[chatglm2-6b-32k](https://huggingface.co/THUDM/chatglm2-6b-32k)、LLama系列、Falcon系列和[Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat): 191 | 192 | Model |Position Encoder|Support NTK Type| 193 | :--------------|----------------|----------------| 194 | chatglm2-6b-32k| Rope | Linear | 195 | Qwen-7B-Chat | Rope | Dynamic | 196 | LLama系列 | Rope |Dynamic、Linear | 197 | Falcon系列 | Rope |Dynamic、Linear | 198 | 199 | * 其他的模型需要自己更改原始的模型文件去支持NTK方法,比如可用于Alibi编码的模型Baichuan、Falcon、Bloom系列的[NTK-ALibi](https://github.com/keezen/ntk_alibi)。一般来说,NTK主要用在推断的时候突破模型的输入token限制,但是训练的时候打开NTK可能会得不到想要的效果。 200 | * Falcon系列的模型HF官方提供了两种编码方式,分别是Rope和Alibi,但是tiiuae官方目前只有Alibi的实现,不知道此举为何,所以此处仅支持使用Rope编码方式的NTK方法。 201 | 202 | ## Train 203 | 204 | ### Pretrain 205 | 预训练数据参考datasets/pretrain/example/train下面的文件,数据为txt格式存储,制作数据集最好能够像例子给的一样,一行为一句话,但是最好不大于模型接收的最大token长度。然后把数据路径填写到DataTrainingArguments配置里面: 206 | ``` 207 | train_file_dir: Optional[str] = field( 208 | default='datasets/pretrain/example/train', 209 | metadata={ 210 | # 训练集保存的路径。 211 | 'help': 'The train json data file folder.' 212 | } 213 | ) 214 | validation_file_dir: Optional[str] = field( 215 | default='datasets/pretrain/example/train', 216 | metadata={ 217 | # 验证集保存的路径。 218 | 'help': 'The evaluation json file folder.' 
219 | } 220 | ) 221 | ``` 222 | 开启训练的时候,需要在config.py中将mode修改为**pretrain**,然后运行main.py。 223 | 224 | ### SFT training 225 | 226 | 指令微调数据参考datasets/finetune/example/train下面的文件,数据由instruction、input、output和history四个字段组成。 227 | ``` 228 | [ 229 | { 230 | "instruction": "10乘以10等于多少?", 231 | "input": "", 232 | "output": "10乘以10等于100。", 233 | "history": [ 234 | "你好呀。", 235 | "你好,请问您有什么需要帮助的吗?", 236 | "好的,我想问下你是谁?", 237 | "我是一个AI模型,能够解决你提出的问题。" 238 | ] 239 | }, 240 | ... 241 | ] 242 | ``` 243 | 如上面所示history字段需要按照一问一答的格式存储对话历史,用于模型训练。如果没有历史对话需要让history为空列表: 244 | ``` 245 | [ 246 | { 247 | "instruction": "你身份是什么?", 248 | "input": "", 249 | "output": "我是一个AI智能助手,由XX公司训练,我将力所能及的帮助你解决问题。", 250 | "history": [] 251 | }, 252 | ... 253 | ] 254 | ``` 255 | 256 | 使用的时候把数据路径填写到DataTrainingArguments配置里面: 257 | ``` 258 | train_file_dir: Optional[str] = field( 259 | default='datasets/finetune/example/train', 260 | metadata={ 261 | # 训练集保存的路径。 262 | 'help': 'The train json data file folder.' 263 | } 264 | ) 265 | validation_file_dir: Optional[str] = field( 266 | default='datasets/finetune/example/eval', 267 | metadata={ 268 | # 验证集保存的路径。 269 | 'help': 'The evaluation json file folder.' 270 | } 271 | ) 272 | ``` 273 | 274 | 开启训练的时候,需要在config.py中对应修改mode为**sft_train**,然后在TrainingArguments中配置好各项训练参数,然后运行main.py。 275 | 框架支持测试SFT训练的效果,测试前在DataTrainingArguments中配置test_file为测试数据集所在的路径,然后在config.py中将mode修改为**sft_batch_test**,然后运行main.py。 276 | ``` 277 | test_file: Optional[str] = field( 278 | default='datasets/finetune/test', 279 | metadata={ 280 | # 测试集保存的路径。 281 | 'help': 'The test file.' 
282 | } 283 | ) 284 | ``` 285 | 286 | ### RM training 287 | 奖励模型训练数据参考datasets/rm/example/train下面的文件,数据由instruction、input、output三个字段组成。output是一个两元素列表,第一个元素是采纳的答案,第二个是拒绝的答案。使用的时候把训练奖励模型的数据一样填写到DataTrainingArguments配置里面。然后需要在config.py中对应修改mode为**rm_train**,在TrainingArguments中配置好各项训练参数,运行main.py。 288 | ``` 289 | train_file_dir: Optional[str] = field( 290 | default='datasets/rm/example/train', 291 | metadata={ 292 | # 训练集保存的路径。 293 | 'help': 'The train json data file folder.' 294 | } 295 | ) 296 | validation_file_dir: Optional[str] = field( 297 | default='datasets/rm/example/eval', 298 | metadata={ 299 | # 验证集保存的路径。 300 | 'help': 'The evaluation json file folder.' 301 | } 302 | ) 303 | ``` 304 | 305 | 框架支持测试奖励模型训练的效果,首先需要在DataTrainingArguments中配置test_file为测试数据集所在的路径,然后在config.py中将mode修改为**rm_batch_test**,运行main.py,奖励模型测试只会输出模型的准确率。 306 | 307 | * 奖励模型训练不支持第一代ChatGLM6B,因为项目用trl的AutoModelForCausalLMWithValueHead组件是基于CausalLM模型的。ChatGLM6B是基于Prefix LM实现的。 308 | 309 | ### RLHF training 310 | #### PPO 311 | 在进行基于PPO模型的RLHF训练之前,需要一个奖励模型和一个需要被RLHF微调的SFT模型,需要把他们配置到ModelArguments中如下: 312 | ``` 313 | checkpoint_dir: Optional[str] = field( 314 | default='checkpoint/sft', 315 | metadata={ 316 | # 保存下载的或者自己训练的adapter增量模型的地方,在RLHF时候,此处需要填写指令微调后模型所在的文件地址。 317 | 'help': 'Path to save the (delta) model checkpoints as well as the configurations automatically.', 318 | } 319 | ) 320 | reward_model_checkpoint: str = field( 321 | default='checkpoint/rm', 322 | metadata={ 323 | # 在RLHF时候,此处需要填写奖励模型所在的文件地址 324 | 'help': 'The checkpoint of reward model.' 
325 | } 326 | ) 327 | ``` 328 | PPO方法对模型进行强化学习训练的数据和SFT阶段训练的数据的格式是一致的,此外使用的时候还需要在TrainingArguments中把PPO的配置填写好,在config.py中将mode修改为ppo_train,然后运行main.py。训练的结果将会通过wandb的格式记录在训练输出的文件夹中。 329 | 330 | #### DPO 331 | 在进行基于DPO模型的RLHF训练之前,只需要一个被RLHF微调的SFT模型,如果是基于adapter的模型还需要把adapter配置到ModelArguments中如下: 332 | ``` 333 | model_path: str = field( 334 | default='/home/XXX/ChatGLM/ChatGLM2-6B-32k', 335 | metadata={ 336 | # 从huggingface.co/models上下载的模型保存到本地的路径或者自己的模型。 337 | 'help': 'Local path to pretrained model or model identifier from huggingface.co/models.' 338 | } 339 | ) 340 | checkpoint_dir: Optional[str] = field( 341 | default='checkpoint/sft', 342 | metadata={ 343 | # 保存下载的或者自己训练的adapter增量模型的地方,在RLHF时候,此处需要填写指令微调后模型所在的文件地址。 344 | 'help': 'Path to save the (delta) model checkpoints as well as the configurations automatically.', 345 | } 346 | ) 347 | ``` 348 | DPO方法对模型进行强化学习训练的数据和奖励模型的数据是一致的,在config.py中将mode修改为dpo_train,然后运行main.py。训练的结果将会通过wandb的格式记录在训练输出的文件夹中。 349 | 350 | * 如果前面使用的是adapter在SFT模型上训练的模型,RLHF的时候项目会融合前面的adapter后创建新的adapter继续训练。 351 | 352 | ### Training Arguments 353 | 常用的一些参数如下: 354 | 355 | Arguments | Describe | 356 | :----------------------------|-------------------------| 357 | fine_tuning_type | 训练方式 | 358 | use_firefly_loss | 使用Firefly loss训练模型 | 359 | output_dir | 训练结果输出的文件夹 | 360 | num_train_epochs | 训练的轮次 | 361 | gradient_accumulation_steps | 梯度累积 | 362 | per_device_train_batch_size | 每个设备上的批大小 | 363 | learning_rate | 学习率 | 364 | fp16 | 设置True为开混合精度运算 | 365 | 366 | 367 | * Lora和其它adapter训练方式的配置参数也在TrainingArguments中,这里面要注意lora_target的设置要根据自己的模型结构来,配置中给了一些参考。 368 | * Firefly Loss仅作用在SFT训练阶段且不支持ChatGLM6B等Prefix LM模型。 369 | 370 | ### DeepSpeed 371 | 使用deepspeed进行训练需要在TrainingArguments指定deepspeed的config文件(项目中提供了stage2的deepspeed配置): 372 | ``` 373 | deepspeed: Optional[str] = field( 374 | default='deepspeed_configs/zero_stage2_config.json', 375 | metadata={ 376 | 'help': 'Enable deepspeed and pass the path to deepspeed json config file (e.g. 
ds_config.json) ' 377 | 'or an already loaded json file as a dict' 378 | } 379 | ) 380 | ``` 381 | 配置好后在终端输入(单机多卡): 382 | ``` 383 | deepspeed --num_gpus 3 --master_port=9901 main.py 384 | ``` 385 | 386 | * 多机多卡需要指定更多的参数,可以参考huggingface的deepspeed文档。 387 | 388 | ## Others 389 | Mode | Describe 390 | :------------------- | ------------------------------------------------------------ 391 | merge_lora_model | 将lora模型和基座模型融合,支持lora和adalora之后的权重合并,其它的训练方法产生的adapter直接通过peft加载即可,不支持合并 392 | show_model_info | 打印模型的结构和模型的参数 393 | save_quantized_model | 量化并保存量化模型 394 | expand_vocab | 根据给定语料扩充词表(如扩充中文词表、垂域词表等) 395 | 396 | * merge_lora_model和save_quantized_model需要在ModelArguments设置输出地址。 397 | ``` 398 | quantized_or_merged_output_dir: Optional[str] = field( 399 | default=None, 400 | metadata={ 401 | # 当你想保存量化后的模型或者融合后的模型时,处理后的模型保存的地址。 402 | 'help': 'Path to save the quantized or merged model checkpoints as well as the configurations manually.', 403 | } 404 | ) 405 | ``` 406 | * 使用bnb和cpm量化将会默认对除了输出层的所有线性层进行量化。 407 | * 使用expand_vocab方法进行词表扩充时,需要指定训练词表的语料路径(文件或文件夹均可)。仅支持 `.txt` 与 `.tsv` 格式。词表扩充后,一般需要继续预训练。 408 | 409 | ## Todo 410 | - [x] 奖励模型训练 411 | - [x] PPO模型训练 412 | - [x] DPO模型训练 413 | - [x] 支持Deepspeed训练 414 | - [x] [NTK-Aware Scaled RoPE](https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/?rdt=35901)集成 415 | - [x] 多轮对话的[Firefly的loss](https://mp.weixin.qq.com/s/nhogoWnzl3nrs_77r38_UA)函数集成 416 | - [x] 支持LLM增量预训练 417 | - [x] 对LLama和Falcon增加Flash Attention2 418 | - [ ] mmlu、cmmlu和C-Eval自动化评估 419 | 420 | 421 | ## Citation 422 | 423 | 如果你在研究中使用了该项目,请按如下格式引用: 424 | 425 | ```latex 426 | @misc{LLMs Tool, 427 | title={LLMs Tool: a tool for large language models}, 428 | author={Shouxian Li}, 429 | year={2023}, 430 | howpublished={\url{https://github.com/stanleylsx/llms_tool}}, 431 | } 432 | ``` 433 | 434 | ## Star History 435 | 436 | ![Star History Chart](https://api.star-history.com/svg?repos=stanleylsx/llms_tool&type=Date) 437 | 
-------------------------------------------------------------------------------- /datasets/finetune/example/eval/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "愚公移山的成语出处,用这个成语造句", 4 | "input": "", 5 | "output": "我们应该有着愚公移山的精神,一点点克服我们所面对的困难。", 6 | "history": [] 7 | }, 8 | { 9 | "instruction": "愚公移山的成语出处,用这个成语造句", 10 | "input": "", 11 | "output": "“愚公移山”(Gùgōng yí shān)是一个成语,起源于一位名叫愚公的老人试图移走一座挡住他回家路的山的故事。\n\n这是一个使用这个成语的句子:\n\"他有愚公移山的精神,对于不可能的事也不会放弃。\"", 12 | "history": [] 13 | }, 14 | { 15 | "instruction": "how to improve my powerpoint skill", 16 | "input": "", 17 | "output": "Improving your PowerPoint skills can be a great way to create more effective and engaging presentations. Here are some tips that you can follow to improve your skills:\n\nKnow your audience: Before you start creating your presentation, consider your audience's needs, preferences, and expectations. This will help you create a presentation that is relevant and tailored to your audience.\n\nPlan your content: Plan the content of your presentation and structure it in a logical and easy-to-follow manner. Use headings, bullet points, and other visual aids to break up the text and make the presentation more easily digestible.\n\nUse appropriate visuals: Use appropriate visuals, such as charts, graphs, images, and videos, to help illustrate your points and keep your audience engaged. Be sure to choose visuals that are relevant to your message and use them effectively.\n\nKeep it simple: Avoid clutter and keep your slides simple and uncluttered. Use a consistent design and color scheme to make your presentation look professional and polished.\n\nPractice: Practice your presentation several times to build confidence and improve your delivery. 
Time yourself and try to stick to your allotted time, and consider using visual aids like cue cards or slides to help you stay on track.\n\nSeek feedback: Seek feedback from others, such as friends, family, or colleagues, to get an idea of what works well and what needs improvement.\n\nBy following these tips and continually practicing, you'll be well on your way to becoming a skilled PowerPoint presenter!", 18 | "history": [] 19 | }, 20 | { 21 | "instruction": "how to improve my powerpoint skill", 22 | "input": "", 23 | "output": "以下是有关如何提高 PowerPoint 技能的一些提示:\n 使用模板:利用模板为您的演示文稿创建具有凝聚力的外观。 这也将节省您的时间,因为您不必从头开始创建设计。\n 使用视觉效果:结合图像、图表和视频等视觉效果来增强您的演示文稿并使其对您的观众更具吸引力。\n 保持简单:避免在幻灯片上使用过多的文字,并使用项目符号来分解信息。 这将帮助您的听众专注于您的信息并避免信息过载。\n 练习您的演讲:事先排练您的演讲以确保演讲顺利进行。 这将帮助您对内容更加熟悉,并提高您对演示的信心。\n 使用幻灯片切换和动画:结合幻灯片切换和动画可以增加演示文稿的视觉趣味性,但要谨慎使用,以免分散您的信息。\n 编辑和校对:仔细检查您的演示文稿以发现任何错别字或错误。 这将有助于确保专业和精美的最终产品。\n 获得反馈:请同事或朋友审阅您的演示文稿并提供反馈。 这将帮助您确定需要改进的地方,并确保您的信息清晰有效。", 24 | "history": [] 25 | }, 26 | { 27 | "instruction": "发生脱水时最合理的急救方案是什么?", 28 | "input": "", 29 | "output": "患者发生脱水时最好把患者送到阴凉处,并让他引用盐汽水。这既可以降温也可以补充电解质。同时也需要拨打120急救热线,将患者送到医院进行进一步治疗。", 30 | "history": [ 31 | [ 32 | "试简述脱水对身体的影响", 33 | "脱水会对身体造成广泛的影响,包括疲劳、头晕、头晕、头痛、口唇干燥、排尿减少、尿液颜色变深、皮肤干燥和口渴。 脱水还会导致更严重的情况,例如中暑、中暑,甚至器官损伤。 当身体脱水时,它无法适当地调节温度,并且可能会出现电解质失衡。 在严重的情况下,脱水会导致休克、昏迷,甚至死亡。" 34 | ] 35 | ] 36 | }, 37 | { 38 | "instruction": "发生脱水时最合理的急救方案是什么?", 39 | "input": "", 40 | "output": "患者发生脱水时最好把患者送到阴凉处,并让他引用盐汽水。这既可以降温也可以补充电解质。同时也需要拨打120急救热线,将患者送到医院进行进一步治疗。", 41 | "history": [ 42 | [ 43 | "试简述脱水对身体的影响", 44 | "脱水会对身体造成广泛的影响,包括疲劳、头晕、头晕、头痛、口唇干燥、排尿减少、尿液颜色变深、皮肤干燥和口渴。 脱水还会导致更严重的情况,例如中暑、中暑,甚至器官损伤。 当身体脱水时,它无法适当地调节温度,并且可能会出现电解质失衡。 在严重的情况下,脱水会导致休克、昏迷,甚至死亡。" 45 | ] 46 | ] 47 | }, 48 | { 49 | "instruction": "假设你现在是一个AI方向的博士生。", 50 | "input": "", 51 | "output": "好的,我现在是一个AI方向的博士生。请问你有什么需要帮助的吗?", 52 | "history": [] 53 | }, 54 | { 55 | "instruction": "假设你现在是一个AI方向的博士生。", 56 | "input": "", 57 | "output": "好的,我现在是一个AI方向的博士生。请问你有什么需要帮助的吗?", 58 | "history": [] 59 | 
}, 60 | { 61 | "instruction": "你可以演示一下这个程序运行时的界面吗?", 62 | "input": "", 63 | "output": "这个扫雷游戏的运行时界面如下:\n```\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\nEnter row and column (separated by space): \n```\n它表示一个10x10的扫雷棋盘,`-`表示该格子还没有被翻开,所以现在是游戏的开局。\n\n您可以输入0开始的行号和列号,以空格隔开,比如`0 3`或`9 9`,以翻开某一行某一列的格子。翻开一个非地雷的格子后,如果周围8格有地雷,就会用数字`1-8`显示当前格子周围的地雷数量,如果没有地雷,就会递归地翻开周围的没有地雷的格子。\n\n比如在某一场游戏中,输入`5 5`(第6行第6列)后,游戏界面更新为如下:\n```\n0 1 - 1 0 1 - - - -\n0 1 - 1 0 1 1 2 - -\n0 1 1 1 0 0 0 1 1 1\n0 0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 1 1 2 1\n0 0 1 1 1 0 1 - - -\n0 0 1 - 1 0 2 - - -\n2 2 3 - 2 0 1 - - -\n- - - - 1 0 1 - - -\nEnter row and column (separated by space):\n```\n\n当您翻开所有非地雷的格子后,就能取胜。", 64 | "history": [ 65 | [ 66 | "扫雷代码", 67 | "以下是一个简单的扫雷游戏的Python代码:\n```python\nimport random\n\nclass Minesweeper:\n def __init__(self, size, mines):\n self.size = size # 棋盘大小\n self.mines = mines # 地雷数量\n self.board = [[0 for _ in range(size)] for _ in range(size)] # 棋盘,用二维列表表示\n self.visible_board = [['-' for _ in range(size)] for _ in range(size)] # 可见的棋盘,用二维列表表示\n self.game_over = False # 游戏是否结束\n self.generate_mines() # 生成地雷\n\n def generate_mines(self):\n mines_placed = 0 # 记录已经放置的地雷数量\n while mines_placed < self.mines:\n row = random.randint(0, self.size-1) # 随机生成行坐标\n col = random.randint(0, self.size-1) # 随机生成列坐标\n if self.board[row][col] != '*': # 如果该位置没有地雷\n self.board[row][col] = '*' # 在该位置放置地雷\n mines_placed += 1 # 已放置地雷数量加1\n # 遍历该位置周围的8个位置,如果不是地雷,则将数字加1\n for r in range(row-1, row+2):\n for c in range(col-1, col+2):\n if (r >= 0 and r < self.size and c >= 0 and c < self.size\n and self.board[r][c] != '*'):\n self.board[r][c] += 1\n\n def play(self, row, col):\n if self.game_over:\n print('Game Over')\n return\n if self.board[row][col] == '*': # 如果该位置是地雷\n self.visible_board[row][col] = '*' # 
在可见棋盘中显示地雷\n self.game_over = True # 游戏结束\n print('Game Over')\n else: # 如果该位置不是地雷\n self.reveal(row, col) # 揭示该位置周围的位置\n self.print_board() # 打印可见棋盘\n if self.check_win(): # 检查是否获胜\n print('You Win!')\n\n def reveal(self, row, col):\n # 如果该位置已经揭示过,或者超出棋盘范围,返回\n if (row < 0 or row >= self.size or col < 0 or col >= self.size\n or self.visible_board[row][col] != '-'):\n return\n # 在可见棋盘中显示该位置的数字\n self.visible_board[row][col] = str(self.board[row][col])\n # 如果该位置数字为0,揭示周围的8个位置\n if self.board[row][col] == 0:\n for r in range(row-1, row+2):\n for c in range(col-1, col+2):\n self.reveal(r, c)\n\n def check_win(self):\n # 遍历整个棋盘,如果有未揭示的非地雷位置,返回False\n for row in range(self.size):\n for col in range(self.size):\n if self.visible_board[row][col] == '-' and self.board[row][col] != '*':\n return False\n return True\n\n def print_board(self):\n for row in range(self.size):\n print(' '.join(self.visible_board[row]))\n\nif __name__ == '__main__':\n size = 10\n mines = 10\n game = Minesweeper(size, mines)\n game.print_board()\n while not game.game_over:\n row, col = input('Enter row and column (separated by space): ').strip().split()\n row, col = int(row), int(col)\n game.play(row, col)\n```\n该代码实现了一个10x10的扫雷游戏,有10个地雷。它在命令行中运行,每次提示用户输入要揭示的行和列。输入回车后,程序会更新可见的棋盘并检查是否获胜或失败。" 68 | ] 69 | ] 70 | }, 71 | { 72 | "instruction": "你的Python代码缩进是乱的,你能修好吗?", 73 | "input": "", 74 | "output": "抱歉,给您输出了错误的代码。现在我已经修复了代码,它应该能正常运行:\n```python\nimport random\n\n# 游戏设置\nROWS = 10\nCOLS = 10\nMINES = 10\n\n# 初始化游戏板\nboard = [[0 for _ in range(COLS)] for _ in range(ROWS)]\nfor _ in range(MINES):\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n while board[row][col] == -1:\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n board[row][col] = -1\n\n# 计算每个方格周围的地雷数量\nfor row in range(ROWS):\n for col in range(COLS):\n if board[row][col] != -1:\n count = 0\n for r in range(max(0, row-1), min(row+2, ROWS)):\n for c in range(max(0, col-1), min(col+2, COLS)):\n if 
board[r][c] == -1:\n count += 1\n board[row][col] = count\n\n# 显示游戏板(方便测试)\nfor row in board:\n print(row)\n```\n预期的运行结果是,程序能够正常地初始化一个扫雷游戏板,并在控制台上输出其状态。输出的结果应该是一个10x10的二维列表,其中每个元素代表对应方格的状态,状态值可能为-1(地雷方格)或0~8(周围地雷数量)。\n\n不过这段代码中并没有用户交互功能,它只是实现了扫雷游戏板的初始化过程,即生成一个随机的游戏板,并确定每个方格的状态(包括地雷和周围地雷数量)。这个初始化过程是在程序启动时自动完成的,而没有涉及到用户的输入或操作。", 75 | "history": [ 76 | [ 77 | "扫雷代码", 78 | "import random\n\n# 游戏设置\nROWS = 10\nCOLS = 10\nMINES = 10\n\n# 初始化游戏板\nboard = [[0 for _ in range(COLS)] for _ in range(ROWS)]\nfor _ in range(MINES):\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n while board[row][col] == -1:\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n board[row][col] = -1\n\n# 计算每个方格周围的地雷数量\nfor row in range(ROWS):\n for col in range(COLS):\n if board[row][col] != -1:\n count = 0\n for r in range(max(0, row-1), min(row+2, ROWS)):\n for c in range(max(0, col-1), min(col+2, COLS)):\n if board[r][c] == -1:\n count += 1\n board[row][col] = count\n\n# 显示游戏板(方便测试)\nfor row in board:\n print(row)\n\n这个示例代码首先创建一个ROWS x COLS的游戏板,然后随机放置MINES个地雷。接下来,它计算每个方格周围的地雷数量,并在游戏板上显示结果。\n\n请注意,此示例代码不包括游戏界面或交互性,但可以作为构建更复杂游戏的基础。例如,您可以使用Tkinter或其他GUI库创建游戏界面,并添加逻辑以处理玩家的单击操作。" 79 | ] 80 | ] 81 | }, 82 | { 83 | "instruction": "单面玻璃的运作原理", 84 | "input": "", 85 | "output": "单面玻璃是一种一面部分反射,另一面透明的玻璃。它是通过在玻璃的一侧涂上一层薄薄的反射材料(如铝)而形成的。该层的应用方式使其仅部分反射光线,允许部分光线穿过玻璃,同时仍将部分光线反射向观察者。这使得观察者可以看到室内,而室内的人看不到室外。", 86 | "history": [] 87 | }, 88 | { 89 | "instruction": "单面玻璃的运作原理", 90 | "input": "", 91 | "output": "单面玻璃巧妙运用了光照强度的原理,透过单侧反光的设计使光线返回其中一边,光线弱的一边可以看到光照强的一侧,而光线强的一侧只能看到返回的光线所呈现的镜像", 92 | "history": [] 93 | }, 94 | { 95 | "instruction": "帮我把这句话倒过来说“我喜欢吃苹果”", 96 | "input": "", 97 | "output": "苹果喜欢吃我", 98 | "history": [] 99 | }, 100 | { 101 | "instruction": "帮我把这句话倒过来说“我喜欢吃苹果”", 102 | "input": "", 103 | "output": "果苹吃欢喜我", 104 | "history": [] 105 | }, 106 | { 107 | "instruction": "1+1 等于几", 108 | "input": "", 109 | "output": "等于2", 110 | "history": [] 111 | }, 112 | { 113 | 
"instruction": "1+1 等于几", 114 | "input": "", 115 | "output": "2", 116 | "history": [] 117 | }, 118 | { 119 | "instruction": "原来如此,那可以告诉我更多操作二进制的语法吗?", 120 | "input": "", 121 | "output": "可以的,C语言中还有其他操作二进制位的语法,例如:\n1. 位移运算符:<< 和 >>,用于将一个变量的二进制位向左或向右移动。\n2. 位反转运算符:~,用于将一个变量的二进制位反转。\n3. 位异或运算符:^,用于将两个变量的二进制位进行异或运算。", 122 | "history": [ 123 | [ 124 | "C语言中,“ &= |= ”是什么意思?", 125 | "在 C 语言中,\"&=\" 和 \"|=\" 是位运算符,它们用于对一个变量的二进制位进行操作。\n\n\"&=\" 表示按位与赋值,它将一个变量与另一个值进行按位与运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx &= y;\n\n在这个例子中,x 的值将被设置为 10 & 5,即 0。\n\n\"|=\" 表示按位或赋值,它将一个变量与另一个值进行按位或运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx |= y;\n\n在这个例子中,x 的值将被设置为 10 | 5,即 15。\n\n总之,\"&=\" 和 \"|=\" 是操作二进制位的方便语法,它们可以简化一些位运算的代码。" 126 | ] 127 | ] 128 | }, 129 | { 130 | "instruction": "原来如此,那可以告诉我更多操作二进制的语法吗?", 131 | "input": "", 132 | "output": "“<<”和\">>\"分别代表左移和右移。对于整形来说,每一次左移相当于将原数值乘2,而每一次右移相当于将原数值除以2并忽略小数部分。\n比如“int x = 3 << 2”表示将3左移2位,得到的x的值是12。\n“~”表示将一个数的每一位翻转。以八位整形来举例,“5”对应的二进制表示方法是00000101,“~5”则对应11111010,也就是-6", 133 | "history": [ 134 | [ 135 | "C语言中,“ &= |= ”是什么意思?", 136 | "在 C 语言中,\"&=\" 和 \"|=\" 是位运算符,它们用于对一个变量的二进制位进行操作。\n\n\"&=\" 表示按位与赋值,它将一个变量与另一个值进行按位与运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx &= y;\n\n在这个例子中,x 的值将被设置为 10 & 5,即 0。\n\n\"|=\" 表示按位或赋值,它将一个变量与另一个值进行按位或运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx |= y;\n\n在这个例子中,x 的值将被设置为 10 | 5,即 15。\n\n总之,\"&=\" 和 \"|=\" 是操作二进制位的方便语法,它们可以简化一些位运算的代码。" 137 | ] 138 | ] 139 | }, 140 | { 141 | "instruction": "厦门大学怎么样", 142 | "input": "", 143 | "output": "厦门大学是一所高水平的研究型大学,在中国乃至世界享有很高的声誉。该校在国内外的各项排名中都处于领先地位,是全国重点支持的一批“211工程”高校之一。厦门大学的学术研究实力十分雄厚,在各个学科领域都有着显著的成绩。此外,该校也注重学生的全面发展,在师生关系、校园文化、学生活动等方面都提供了良好的环境和平台。总的来说,厦门大学是一所很好的高等学府。", 144 | "history": [] 145 | }, 146 | { 147 | "instruction": "好,谢谢", 148 | "input": "", 149 | "output": "不客气,获取信息的方式有很多种,建议多多宽展信息获取的渠道哦。", 150 | "history": [ 151 | [ 152 | "厦门大学怎么样", 153 | 
"厦门大学是一所高水平的研究型大学,在中国乃至世界享有很高的声誉。该校在国内外的各项排名中都处于领先地位,是全国重点支持的一批“211工程”高校之一。厦门大学的学术研究实力十分雄厚,在各个学科领域都有着显著的成绩。此外,该校也注重学生的全面发展,在师生关系、校园文化、学生活动等方面都提供了良好的环境和平台。总的来说,厦门大学是一所很好的高等学府。" 154 | ] 155 | ] 156 | }, 157 | { 158 | "instruction": "为什么黑暗中突然开灯会让眼睛感觉到不适?", 159 | "input": "", 160 | "output": "眼睛遇强光则无法看到东西,是因为眼睛的瞳孔需要一段时间才能做出反应,通过瞳孔收缩来适应强光。此外在遇到强光的瞬间,视网膜的感觉细胞中含有太多视紫质,使其对强光非常敏感,需要一段时间后视紫质才会减少。", 161 | "history": [ 162 | [ 163 | "为什么不能关灯玩手机?", 164 | "如果关了灯,屋子里就会很暗。这样,人类的瞳孔就会扩张。在这时候玩手机,手机的过量光线会进去视网膜,诱导病变。因此,关灯时候最高不要玩手机。" 165 | ] 166 | ] 167 | }, 168 | { 169 | "instruction": "你回答的很好啊", 170 | "input": "", 171 | "output": "谢谢您对我的鼓励,我会加倍努力,为您提供更高质量的服务。", 172 | "history": [ 173 | [ 174 | "为什么不能关灯玩手机?", 175 | "长期关灯看手机可能导致青光眼。\n昏暗的光线下用眼,会造成瞳孔长时间张大,堵塞眼内液体循环流通,很容易导致青光眼的发生。此外,眼睛在黑暗中对比敏感度低,会引起眼睛的调节系统负担加剧,进一步还会导致近视加深" 176 | ] 177 | ] 178 | } 179 | ] -------------------------------------------------------------------------------- /datasets/finetune/example/test/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "愚公移山的成语出处,用这个成语造句", 4 | "input": "", 5 | "output": "我们应该有着愚公移山的精神,一点点克服我们所面对的困难。", 6 | "history": [] 7 | }, 8 | { 9 | "instruction": "愚公移山的成语出处,用这个成语造句", 10 | "input": "", 11 | "output": "“愚公移山”(Gùgōng yí shān)是一个成语,起源于一位名叫愚公的老人试图移走一座挡住他回家路的山的故事。\n\n这是一个使用这个成语的句子:\n\"他有愚公移山的精神,对于不可能的事也不会放弃。\"", 12 | "history": [] 13 | }, 14 | { 15 | "instruction": "how to improve my powerpoint skill", 16 | "input": "", 17 | "output": "Improving your PowerPoint skills can be a great way to create more effective and engaging presentations. Here are some tips that you can follow to improve your skills:\n\nKnow your audience: Before you start creating your presentation, consider your audience's needs, preferences, and expectations. 
This will help you create a presentation that is relevant and tailored to your audience.\n\nPlan your content: Plan the content of your presentation and structure it in a logical and easy-to-follow manner. Use headings, bullet points, and other visual aids to break up the text and make the presentation more easily digestible.\n\nUse appropriate visuals: Use appropriate visuals, such as charts, graphs, images, and videos, to help illustrate your points and keep your audience engaged. Be sure to choose visuals that are relevant to your message and use them effectively.\n\nKeep it simple: Avoid clutter and keep your slides simple and uncluttered. Use a consistent design and color scheme to make your presentation look professional and polished.\n\nPractice: Practice your presentation several times to build confidence and improve your delivery. Time yourself and try to stick to your allotted time, and consider using visual aids like cue cards or slides to help you stay on track.\n\nSeek feedback: Seek feedback from others, such as friends, family, or colleagues, to get an idea of what works well and what needs improvement.\n\nBy following these tips and continually practicing, you'll be well on your way to becoming a skilled PowerPoint presenter!", 18 | "history": [] 19 | }, 20 | { 21 | "instruction": "how to improve my powerpoint skill", 22 | "input": "", 23 | "output": "以下是有关如何提高 PowerPoint 技能的一些提示:\n 使用模板:利用模板为您的演示文稿创建具有凝聚力的外观。 这也将节省您的时间,因为您不必从头开始创建设计。\n 使用视觉效果:结合图像、图表和视频等视觉效果来增强您的演示文稿并使其对您的观众更具吸引力。\n 保持简单:避免在幻灯片上使用过多的文字,并使用项目符号来分解信息。 这将帮助您的听众专注于您的信息并避免信息过载。\n 练习您的演讲:事先排练您的演讲以确保演讲顺利进行。 这将帮助您对内容更加熟悉,并提高您对演示的信心。\n 使用幻灯片切换和动画:结合幻灯片切换和动画可以增加演示文稿的视觉趣味性,但要谨慎使用,以免分散您的信息。\n 编辑和校对:仔细检查您的演示文稿以发现任何错别字或错误。 这将有助于确保专业和精美的最终产品。\n 获得反馈:请同事或朋友审阅您的演示文稿并提供反馈。 这将帮助您确定需要改进的地方,并确保您的信息清晰有效。", 24 | "history": [] 25 | }, 26 | { 27 | "instruction": "发生脱水时最合理的急救方案是什么?", 28 | "input": "", 29 | "output": "患者发生脱水时最好把患者送到阴凉处,并让他引用盐汽水。这既可以降温也可以补充电解质。同时也需要拨打120急救热线,将患者送到医院进行进一步治疗。", 30 | "history": 
[ 31 | [ 32 | "试简述脱水对身体的影响", 33 | "脱水会对身体造成广泛的影响,包括疲劳、头晕、头晕、头痛、口唇干燥、排尿减少、尿液颜色变深、皮肤干燥和口渴。 脱水还会导致更严重的情况,例如中暑、中暑,甚至器官损伤。 当身体脱水时,它无法适当地调节温度,并且可能会出现电解质失衡。 在严重的情况下,脱水会导致休克、昏迷,甚至死亡。" 34 | ] 35 | ] 36 | }, 37 | { 38 | "instruction": "发生脱水时最合理的急救方案是什么?", 39 | "input": "", 40 | "output": "患者发生脱水时最好把患者送到阴凉处,并让他引用盐汽水。这既可以降温也可以补充电解质。同时也需要拨打120急救热线,将患者送到医院进行进一步治疗。", 41 | "history": [ 42 | [ 43 | "试简述脱水对身体的影响", 44 | "脱水会对身体造成广泛的影响,包括疲劳、头晕、头晕、头痛、口唇干燥、排尿减少、尿液颜色变深、皮肤干燥和口渴。 脱水还会导致更严重的情况,例如中暑、中暑,甚至器官损伤。 当身体脱水时,它无法适当地调节温度,并且可能会出现电解质失衡。 在严重的情况下,脱水会导致休克、昏迷,甚至死亡。" 45 | ] 46 | ] 47 | }, 48 | { 49 | "instruction": "假设你现在是一个AI方向的博士生。", 50 | "input": "", 51 | "output": "好的,我现在是一个AI方向的博士生。请问你有什么需要帮助的吗?", 52 | "history": [] 53 | }, 54 | { 55 | "instruction": "假设你现在是一个AI方向的博士生。", 56 | "input": "", 57 | "output": "好的,我现在是一个AI方向的博士生。请问你有什么需要帮助的吗?", 58 | "history": [] 59 | }, 60 | { 61 | "instruction": "你可以演示一下这个程序运行时的界面吗?", 62 | "input": "", 63 | "output": "这个扫雷游戏的运行时界面如下:\n```\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\n- - - - - - - - - -\nEnter row and column (separated by space): \n```\n它表示一个10x10的扫雷棋盘,`-`表示该格子还没有被翻开,所以现在是游戏的开局。\n\n您可以输入0开始的行号和列号,以空格隔开,比如`0 3`或`9 9`,以翻开某一行某一列的格子。翻开一个非地雷的格子后,如果周围8格有地雷,就会用数字`1-8`显示当前格子周围的地雷数量,如果没有地雷,就会递归地翻开周围的没有地雷的格子。\n\n比如在某一场游戏中,输入`5 5`(第6行第6列)后,游戏界面更新为如下:\n```\n0 1 - 1 0 1 - - - -\n0 1 - 1 0 1 1 2 - -\n0 1 1 1 0 0 0 1 1 1\n0 0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 0 0 0 0\n0 0 0 0 0 0 1 1 2 1\n0 0 1 1 1 0 1 - - -\n0 0 1 - 1 0 2 - - -\n2 2 3 - 2 0 1 - - -\n- - - - 1 0 1 - - -\nEnter row and column (separated by space):\n```\n\n当您翻开所有非地雷的格子后,就能取胜。", 64 | "history": [ 65 | [ 66 | "扫雷代码", 67 | "以下是一个简单的扫雷游戏的Python代码:\n```python\nimport random\n\nclass Minesweeper:\n def __init__(self, size, mines):\n self.size = size # 棋盘大小\n self.mines = mines # 地雷数量\n self.board = [[0 for _ in range(size)] for _ in range(size)] # 棋盘,用二维列表表示\n 
self.visible_board = [['-' for _ in range(size)] for _ in range(size)] # 可见的棋盘,用二维列表表示\n self.game_over = False # 游戏是否结束\n self.generate_mines() # 生成地雷\n\n def generate_mines(self):\n mines_placed = 0 # 记录已经放置的地雷数量\n while mines_placed < self.mines:\n row = random.randint(0, self.size-1) # 随机生成行坐标\n col = random.randint(0, self.size-1) # 随机生成列坐标\n if self.board[row][col] != '*': # 如果该位置没有地雷\n self.board[row][col] = '*' # 在该位置放置地雷\n mines_placed += 1 # 已放置地雷数量加1\n # 遍历该位置周围的8个位置,如果不是地雷,则将数字加1\n for r in range(row-1, row+2):\n for c in range(col-1, col+2):\n if (r >= 0 and r < self.size and c >= 0 and c < self.size\n and self.board[r][c] != '*'):\n self.board[r][c] += 1\n\n def play(self, row, col):\n if self.game_over:\n print('Game Over')\n return\n if self.board[row][col] == '*': # 如果该位置是地雷\n self.visible_board[row][col] = '*' # 在可见棋盘中显示地雷\n self.game_over = True # 游戏结束\n print('Game Over')\n else: # 如果该位置不是地雷\n self.reveal(row, col) # 揭示该位置周围的位置\n self.print_board() # 打印可见棋盘\n if self.check_win(): # 检查是否获胜\n print('You Win!')\n\n def reveal(self, row, col):\n # 如果该位置已经揭示过,或者超出棋盘范围,返回\n if (row < 0 or row >= self.size or col < 0 or col >= self.size\n or self.visible_board[row][col] != '-'):\n return\n # 在可见棋盘中显示该位置的数字\n self.visible_board[row][col] = str(self.board[row][col])\n # 如果该位置数字为0,揭示周围的8个位置\n if self.board[row][col] == 0:\n for r in range(row-1, row+2):\n for c in range(col-1, col+2):\n self.reveal(r, c)\n\n def check_win(self):\n # 遍历整个棋盘,如果有未揭示的非地雷位置,返回False\n for row in range(self.size):\n for col in range(self.size):\n if self.visible_board[row][col] == '-' and self.board[row][col] != '*':\n return False\n return True\n\n def print_board(self):\n for row in range(self.size):\n print(' '.join(self.visible_board[row]))\n\nif __name__ == '__main__':\n size = 10\n mines = 10\n game = Minesweeper(size, mines)\n game.print_board()\n while not game.game_over:\n row, col = input('Enter row and column (separated by space): ').strip().split()\n row, col = 
int(row), int(col)\n game.play(row, col)\n```\n该代码实现了一个10x10的扫雷游戏,有10个地雷。它在命令行中运行,每次提示用户输入要揭示的行和列。输入回车后,程序会更新可见的棋盘并检查是否获胜或失败。" 68 | ] 69 | ] 70 | }, 71 | { 72 | "instruction": "你的Python代码缩进是乱的,你能修好吗?", 73 | "input": "", 74 | "output": "抱歉,给您输出了错误的代码。现在我已经修复了代码,它应该能正常运行:\n```python\nimport random\n\n# 游戏设置\nROWS = 10\nCOLS = 10\nMINES = 10\n\n# 初始化游戏板\nboard = [[0 for _ in range(COLS)] for _ in range(ROWS)]\nfor _ in range(MINES):\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n while board[row][col] == -1:\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n board[row][col] = -1\n\n# 计算每个方格周围的地雷数量\nfor row in range(ROWS):\n for col in range(COLS):\n if board[row][col] != -1:\n count = 0\n for r in range(max(0, row-1), min(row+2, ROWS)):\n for c in range(max(0, col-1), min(col+2, COLS)):\n if board[r][c] == -1:\n count += 1\n board[row][col] = count\n\n# 显示游戏板(方便测试)\nfor row in board:\n print(row)\n```\n预期的运行结果是,程序能够正常地初始化一个扫雷游戏板,并在控制台上输出其状态。输出的结果应该是一个10x10的二维列表,其中每个元素代表对应方格的状态,状态值可能为-1(地雷方格)或0~8(周围地雷数量)。\n\n不过这段代码中并没有用户交互功能,它只是实现了扫雷游戏板的初始化过程,即生成一个随机的游戏板,并确定每个方格的状态(包括地雷和周围地雷数量)。这个初始化过程是在程序启动时自动完成的,而没有涉及到用户的输入或操作。", 75 | "history": [ 76 | [ 77 | "扫雷代码", 78 | "import random\n\n# 游戏设置\nROWS = 10\nCOLS = 10\nMINES = 10\n\n# 初始化游戏板\nboard = [[0 for _ in range(COLS)] for _ in range(ROWS)]\nfor _ in range(MINES):\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n while board[row][col] == -1:\n row, col = random.randint(0, ROWS-1), random.randint(0, COLS-1)\n board[row][col] = -1\n\n# 计算每个方格周围的地雷数量\nfor row in range(ROWS):\n for col in range(COLS):\n if board[row][col] != -1:\n count = 0\n for r in range(max(0, row-1), min(row+2, ROWS)):\n for c in range(max(0, col-1), min(col+2, COLS)):\n if board[r][c] == -1:\n count += 1\n board[row][col] = count\n\n# 显示游戏板(方便测试)\nfor row in board:\n print(row)\n\n这个示例代码首先创建一个ROWS x 
COLS的游戏板,然后随机放置MINES个地雷。接下来,它计算每个方格周围的地雷数量,并在游戏板上显示结果。\n\n请注意,此示例代码不包括游戏界面或交互性,但可以作为构建更复杂游戏的基础。例如,您可以使用Tkinter或其他GUI库创建游戏界面,并添加逻辑以处理玩家的单击操作。" 79 | ] 80 | ] 81 | }, 82 | { 83 | "instruction": "单面玻璃的运作原理", 84 | "input": "", 85 | "output": "单面玻璃是一种一面部分反射,另一面透明的玻璃。它是通过在玻璃的一侧涂上一层薄薄的反射材料(如铝)而形成的。该层的应用方式使其仅部分反射光线,允许部分光线穿过玻璃,同时仍将部分光线反射向观察者。这使得观察者可以看到室内,而室内的人看不到室外。", 86 | "history": [] 87 | }, 88 | { 89 | "instruction": "单面玻璃的运作原理", 90 | "input": "", 91 | "output": "单面玻璃巧妙运用了光照强度的原理,透过单侧反光的设计使光线返回其中一边,光线弱的一边可以看到光照强的一侧,而光线强的一侧只能看到返回的光线所呈现的镜像", 92 | "history": [] 93 | }, 94 | { 95 | "instruction": "帮我把这句话倒过来说“我喜欢吃苹果”", 96 | "input": "", 97 | "output": "苹果喜欢吃我", 98 | "history": [] 99 | }, 100 | { 101 | "instruction": "帮我把这句话倒过来说“我喜欢吃苹果”", 102 | "input": "", 103 | "output": "果苹吃欢喜我", 104 | "history": [] 105 | }, 106 | { 107 | "instruction": "1+1 等于几", 108 | "input": "", 109 | "output": "等于2", 110 | "history": [] 111 | }, 112 | { 113 | "instruction": "1+1 等于几", 114 | "input": "", 115 | "output": "2", 116 | "history": [] 117 | }, 118 | { 119 | "instruction": "原来如此,那可以告诉我更多操作二进制的语法吗?", 120 | "input": "", 121 | "output": "可以的,C语言中还有其他操作二进制位的语法,例如:\n1. 位移运算符:<< 和 >>,用于将一个变量的二进制位向左或向右移动。\n2. 位反转运算符:~,用于将一个变量的二进制位反转。\n3. 
位异或运算符:^,用于将两个变量的二进制位进行异或运算。", 122 | "history": [ 123 | [ 124 | "C语言中,“ &= |= ”是什么意思?", 125 | "在 C 语言中,\"&=\" 和 \"|=\" 是位运算符,它们用于对一个变量的二进制位进行操作。\n\n\"&=\" 表示按位与赋值,它将一个变量与另一个值进行按位与运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx &= y;\n\n在这个例子中,x 的值将被设置为 10 & 5,即 0。\n\n\"|=\" 表示按位或赋值,它将一个变量与另一个值进行按位或运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx |= y;\n\n在这个例子中,x 的值将被设置为 10 | 5,即 15。\n\n总之,\"&=\" 和 \"|=\" 是操作二进制位的方便语法,它们可以简化一些位运算的代码。" 126 | ] 127 | ] 128 | }, 129 | { 130 | "instruction": "原来如此,那可以告诉我更多操作二进制的语法吗?", 131 | "input": "", 132 | "output": "“<<”和\">>\"分别代表左移和右移。对于整形来说,每一次左移相当于将原数值乘2,而每一次右移相当于将原数值除以2并忽略小数部分。\n比如“int x = 3 << 2”表示将3左移2位,得到的x的值是12。\n“~”表示将一个数的每一位翻转。以八位整形来举例,“5”对应的二进制表示方法是00000101,“~5”则对应11111010,也就是-6", 133 | "history": [ 134 | [ 135 | "C语言中,“ &= |= ”是什么意思?", 136 | "在 C 语言中,\"&=\" 和 \"|=\" 是位运算符,它们用于对一个变量的二进制位进行操作。\n\n\"&=\" 表示按位与赋值,它将一个变量与另一个值进行按位与运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx &= y;\n\n在这个例子中,x 的值将被设置为 10 & 5,即 0。\n\n\"|=\" 表示按位或赋值,它将一个变量与另一个值进行按位或运算,并将结果赋值给该变量。例如:\n\nint x = 10;\nint y = 5;\nx |= y;\n\n在这个例子中,x 的值将被设置为 10 | 5,即 15。\n\n总之,\"&=\" 和 \"|=\" 是操作二进制位的方便语法,它们可以简化一些位运算的代码。" 137 | ] 138 | ] 139 | }, 140 | { 141 | "instruction": "厦门大学怎么样", 142 | "input": "", 143 | "output": "厦门大学是一所高水平的研究型大学,在中国乃至世界享有很高的声誉。该校在国内外的各项排名中都处于领先地位,是全国重点支持的一批“211工程”高校之一。厦门大学的学术研究实力十分雄厚,在各个学科领域都有着显著的成绩。此外,该校也注重学生的全面发展,在师生关系、校园文化、学生活动等方面都提供了良好的环境和平台。总的来说,厦门大学是一所很好的高等学府。", 144 | "history": [] 145 | }, 146 | { 147 | "instruction": "好,谢谢", 148 | "input": "", 149 | "output": "不客气,获取信息的方式有很多种,建议多多宽展信息获取的渠道哦。", 150 | "history": [ 151 | [ 152 | "厦门大学怎么样", 153 | "厦门大学是一所高水平的研究型大学,在中国乃至世界享有很高的声誉。该校在国内外的各项排名中都处于领先地位,是全国重点支持的一批“211工程”高校之一。厦门大学的学术研究实力十分雄厚,在各个学科领域都有着显著的成绩。此外,该校也注重学生的全面发展,在师生关系、校园文化、学生活动等方面都提供了良好的环境和平台。总的来说,厦门大学是一所很好的高等学府。" 154 | ] 155 | ] 156 | }, 157 | { 158 | "instruction": "为什么黑暗中突然开灯会让眼睛感觉到不适?", 159 | "input": "", 160 | "output": "眼睛遇强光则无法看到东西,是因为眼睛的瞳孔需要一段时间才能做出反应,通过瞳孔收缩来适应强光。此外在遇到强光的瞬间,视网膜的感觉细胞中含有太多视紫质,使其对强光非常敏感,需要一段时间后视紫质才会减少。", 
161 | "history": [ 162 | [ 163 | "为什么不能关灯玩手机?", 164 | "如果关了灯,屋子里就会很暗。这样,人类的瞳孔就会扩张。在这时候玩手机,手机的过量光线会进去视网膜,诱导病变。因此,关灯时候最高不要玩手机。" 165 | ] 166 | ] 167 | }, 168 | { 169 | "instruction": "你回答的很好啊", 170 | "input": "", 171 | "output": "谢谢您对我的鼓励,我会加倍努力,为您提供更高质量的服务。", 172 | "history": [ 173 | [ 174 | "为什么不能关灯玩手机?", 175 | "长期关灯看手机可能导致青光眼。\n昏暗的光线下用眼,会造成瞳孔长时间张大,堵塞眼内液体循环流通,很容易导致青光眼的发生。此外,眼睛在黑暗中对比敏感度低,会引起眼睛的调节系统负担加剧,进一步还会导致近视加深" 176 | ] 177 | ] 178 | } 179 | ] -------------------------------------------------------------------------------- /datasets/rm/example/eval/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "确定给定句子是以祈使句还是疑问句形式书写的。\n你能帮帮我吗?", 4 | "input": "", 5 | "output": [ 6 | "给定的句子是用疑问句形式写的。", 7 | "至关重要的" 8 | ] 9 | }, 10 | { 11 | "instruction": "给定一个动词列表,为每个动词形成过去时。\n吃,玩,走", 12 | "input": "", 13 | "output": [ 14 | "吃过玩过", 15 | "吃" 16 | ] 17 | }, 18 | { 19 | "instruction": "使用给定的单词生成双关语。\n书", 20 | "input": "", 21 | "output": [ 22 | "你永远不能凭封面来判断一个钩子。", 23 | "我喜欢这本书" 24 | ] 25 | }, 26 | { 27 | "instruction": "告诉我这是什么动物。\n它有白色的皮毛和长长的耳朵。", 28 | "input": "", 29 | "output": [ 30 | "这可能是一只白兔。", 31 | "它是一只狗。" 32 | ] 33 | }, 34 | { 35 | "instruction": "确定改善客户服务的策略。", 36 | "input": "", 37 | "output": [ 38 | "为了改善客户服务,公司可以考虑以下策略:\n\n1. 了解客户需求:进行调查或焦点小组,以了解客户在服务方面的需求。\n\n2. 使用技术让互动更轻松:实施聊天机器人、自助服务选项或在线资源来简化客户互动。\n\n3. 投资于员工培训:实施培训计划以确保员工具备有效处理客户查询和需求的知识和技能。\n\n4. 倾听反馈:积极征求客户的反馈,并利用这些反馈来改善客户服务体验。\n\n5. 营造以客户为中心的文化:通过激励和奖励超越自我提供优质服务的员工,将客户服务作为整个公司的优先事项。\n\n6. 及时回应:确保及时处理客户的询问和投诉,以避免挫败感,并表明您致力于提供优质服务。\n\n7. 
超越期望:超越期望提供超出客户期望的服务,以提高忠诚度并产生积极的口碑营销。\n\n通过实施这些战略,一家公司可以改善他们的客户服务,并树立提供卓越体验的声誉,让客户回头客。", 39 | "我想改善客户服务。" 40 | ] 41 | }, 42 | { 43 | "instruction": "想出一个有创意的回应,鼓励客户继续购买。\n客户不确定是否购买某种产品。", 44 | "input": "", 45 | "output": [ 46 | "我知道做出购买决定可能很艰难,但我向您保证,该产品已收到其他客户的好评,它确实可以丰富您的生活。另外,生命太短暂,不能错过可以带给我们快乐和满足的事物,那么为什么不试一试呢?请记住,如果由于某种原因它不符合您的期望,您始终可以选择退货或换货。所以,继续前进吧,你不会后悔的!", 47 | "我可以帮你。" 48 | ] 49 | }, 50 | { 51 | "instruction": "用几句话为动画师创建角色草图。", 52 | "input": "", 53 | "output": [ 54 | "我们的角色是一位动画师:充满激情、敬业且充满创造力。凭借其狂野的头发、富有表现力的特征和精心策划的图形 T 恤系列,这个人散发出一种古怪的艺术气息。他们总是努力工作,熟练地挥动笔或平板电脑,创造出奇幻的生物和引人入胜的新世界。他们对工作的热爱具有感染力,当他们与他人分享时,他们会感到无比快乐。", 55 | "我给你画个人物素描。" 56 | ] 57 | }, 58 | { 59 | "instruction": "解释快时尚对环境的影响。", 60 | "input": "", 61 | "output": [ 62 | "快时尚正在以多种方式影响环境。它通常涉及不可持续的材料采购,服装生产过剩会很快导致过剩,从而导致大量浪费。此外,在生产服装时使用危险化学品会对空气、水和土壤污染产生不利影响。此外,与快时尚相关的包装和运输也助长了环境退化。", 63 | "我不知道。" 64 | ] 65 | }, 66 | { 67 | "instruction": "评估给定的参数。\n素食主义是过真正道德生活的唯一途径。", 68 | "input": "", 69 | "output": [ 70 | "这个论点是有问题的,因为素食主义可能不是过道德生活的唯一途径。还有许多其他方法可以促进更符合道德的生活方式,例如减少浪费、注意消费主义和可持续饮食。吃素可能是一种让生活更有道德的方式,但不一定是唯一的方式。", 71 | "素食主义是过真正道德生活的唯一途径。" 72 | ] 73 | }, 74 | { 75 | "instruction": "确定 2 个农业补贴类型的例子。", 76 | "input": "", 77 | "output": [ 78 | "农业补贴类型的两个例子是直接支付和价格支持。直接付款是政府为特定目的向农民支付的补贴,例如保护或生产维护。价格支持代表市场价格与政府设定的人为固定价格之间的差异,允许种植者从他们的产品中获得比没有支持时更多的产品。", 79 | "向农民支付补贴以鼓励他们种植农作物。" 80 | ] 81 | }, 82 | { 83 | "instruction": "强大的社交媒体帖子的关键要素是什么?", 84 | "input": "", 85 | "output": [ 86 | "强大的社交媒体帖子有几个关键要素:\n\n\n1. 清晰的信息:你的帖子应该有一个清晰简洁的信息,让你的观众容易理解。\n\n2. 引人注目的视觉效果:包括高质量的图像或视频,使您的帖子具有视觉吸引力和吸引力。\n\n3. 相关主题标签:使用相关主题标签来提高帖子的可见度并覆盖更广泛的受众。\n\n4. 参与:通过提问或邀请反馈来鼓励参与,并及时回复评论和消息。\n\n5. 时机:在你的观众最活跃的时候发帖,根据时区和日常作息考虑发帖的最佳时间。\n\n6. 号召性用语:包括号召性用语以鼓励您的受众采取特定行动,例如访问您的网站或注册时事通讯。\n\n7. 
真实性:分享真实、真实的内容,反映您的品牌价值和声音。\n\n牢记这些要素可以帮助您创建强大的社交媒体帖子,从而引起人们的注意并推动与受众的互动。", 87 | "我不知道。" 88 | ] 89 | }, 90 | { 91 | "instruction": "计算以下信息的净现值\n贴现率 = 6%\n现金流量 = -$1000\n年 = 15", 92 | "input": "", 93 | "output": [ 94 | "给定信息的净现值为 -$819.32。", 95 | "该信息的净现值为 1000 美元。" 96 | ] 97 | }, 98 | { 99 | "instruction": "用 2 个段落编造一个关于巫师发现秘密洞穴的虚构故事。", 100 | "input": "", 101 | "output": [ 102 | "Morgath,一个强大的巫师,在黑暗的森林深处徘徊时,偶然发现了一个隐藏在一棵古树根部后面的奇怪入口。他毫不犹豫的施展了一道护盾法术,踏入其中。狭窄的通道将他引向一个晶莹剔透、充满萤火虫暖光的洞穴。怀着敬畏之情,他走近了洞穴中央,那里静静地躺着一个水晶般清澈的水池,倒映着上面的星空。\n\n当莫加斯凝视着池塘时,他突然听到一个声音在他的脑海中说话。这是洞穴本身的声音,向巫师揭示了它的秘密。几个世纪以来,这个洞穴一直在从星星中收集知识并将其储存在它的墙壁内。现在它向莫加斯提供了这些知识,以换取使用它造福所有人的承诺。巫师对浩瀚的知识​​和它所承载的责任感到不知所措。但是,他下定决心接受了这个提议并离开了洞穴,准备利用这种新发现的力量为世界带来光明。", 103 | "我是一个巫师。我发现了一个秘密洞穴。" 104 | ] 105 | }, 106 | { 107 | "instruction": "为以下文章建议一个标签名称。\n文章:企业的十项气候行动战略", 108 | "input": "", 109 | "output": [ 110 | "本文的合适标签名称可以是“商业气候行动战略”或“环保商业实践”。", 111 | "我不知道" 112 | ] 113 | }, 114 | { 115 | "instruction": "您需要为名为“拯救我们的树木”的新活动开发一个具有代表性的徽标。", 116 | "input": "", 117 | "output": [ 118 | "作为人工智能,我无法创建图形;但是,我可以为代表“拯救我们的树木”活动的徽标提供描述和一些想法。\n\n一个想法可能是制作一棵树根可见的图像,象征着保护树木根基和健康的重要性。这棵树可以有绿色和棕色的配色方案,强调它的自然元素。 “Save Our Trees”的字样可以写在树上方的一个拱门上,字体醒目醒目。该标志还可以带有一个小的回收符号或地球图像,以强调该活动的环境方面。\n\n另一个想法可能是用一只手轻轻地握住一棵树苗,象征着对幼树的培育和保护。手可以是绿色的,代表着人类保护环境的责任。 “拯救我们的树木”这个词可以用粗体写在标志的底部,“我们的”中的“O”看起来像一棵树或一片叶子,将信息与活动的主题联系起来.\n\n总的来说,标志应该传达保护我们的树木和爱护环境的重要性。它应该简单,但有影响力,并且容易被公众识别。", 119 | "我去做。" 120 | ] 121 | }, 122 | { 123 | "instruction": "写一首诗,包括给定的押韵和世界和平的主题。\n结局:白天,留下", 124 | "input": "", 125 | "output": [ 126 | "世界总是那么慢,\n整天充满了争吵和不和。\n我们看到战争和暴力增长,\n和平的念头似乎遥不可及。\n\n但让我们不要对此失去希望,\n决心做出更好的方法。\n让我们先表明立场,\n并坚持到今天。\n\n让我们从这里弥合分歧,\n努力让我们团结起来。\n让我们塑造我们的每一个关系,\n寻求宽恕并结束战斗。\n\n我们将真诚地和希望地留下来,\n为我们所有人建设和平。\n让我们都确保这一天,\n不仅仅是我们,而是通过我们所有人。", 127 | "我会永远留在你的心里。" 128 | ] 129 | }, 130 | { 131 | "instruction": "阅读给定的新闻文章后,提及给定文章中突出显示的两个要点。\n这篇文章是关于公司如何越来越多地投资于碳抵消计划以实现其可持续发展目标。", 132 | "input": "", 133 | "output": [ 134 | "文章中强调的两个要点包括:\n1. 公司通过投资碳抵消计划显示出对可持续性的更多关注。\n2. 
碳抵消计划可以帮助公司减少排放并实现其可持续发展目标。", 135 | "我同意这篇文章。" 136 | ] 137 | }, 138 | { 139 | "instruction": "起草一篇鼓励人们节约用水的社交媒体帖子。", 140 | "input": "", 141 | "output": [ 142 | "🌊💦 水是我们常常习以为常的宝贵资源。每一滴水都很重要,简单的行动可以对保护这一宝贵资产产生重大影响。让我们都为保护水资源尽自己的一份力量,缩短淋浴时间,解决漏水问题,并仅满负荷运行洗碗水和洗衣机。我们可以共同确保为子孙后代提供干净的水。 #ConserveWater #EveryDropCounts 🚰🌍", 143 | "我会尽力做到这一点。" 144 | ] 145 | }, 146 | { 147 | "instruction": "分析下面的一段文字并解释为什么它会被认为具有说服力。\n文本:您是否知道亚马逊自 1994 年成立以来一直是客户服务的黄金标准?", 148 | "input": "", 149 | "output": [ 150 | "这段文字具有说服力,因为它暗示亚马逊的客户服务是黄金标准,但没有任何证据支持。这可能会让读者相信 Amazon 是值得信赖的,并在读者的脑海中产生一种联想,即 Amazon 是客户服务方面的权威。这种没有证据的陈述是许多营销人员用来试图说服读者购买他们的产品的常用技术。它会让读者思考,“如果亚马逊是客户服务的黄金标准,那么他们的产品一定是最好的!”因此,这段文字既具有说服力,又旨在在读者的脑海中与亚马逊的产品建立积极的联想。", 151 | "我不知道亚马逊自 1994 年成立以来一直是客户服务的黄金标准。" 152 | ] 153 | }, 154 | { 155 | "instruction": "解释销售周期内的主要活动。", 156 | "input": "", 157 | "output": [ 158 | "销售周期内的主要活动通常包括勘探、确定客户需求、了解客户期望、建立关系、提出解决方案、谈判合同和完成销售。勘探包括研究潜在客户并准备营销材料以吸引他们。识别客户需求涉及了解他们的需求,而了解客户期望意味着开发个性化服务以满足这些需求。建立关系涉及培养与潜在客户的关系并继续保持这些关系。呈现解决方案是了解客户需求并将最佳解决方案呈现给客户的过程。谈判合同涉及在客户和销售人员之间制定协议,而完成销售涉及达成协议并建立客户关系。", 159 | "我不能那样做。" 160 | ] 161 | }, 162 | { 163 | "instruction": "为小型企业生成包含五个营销技巧的列表。", 164 | "input": "", 165 | "output": [ 166 | "1. 利用数字营销工具:社交媒体、SEO、PPC 和电子邮件等平台可以为小型企业创造更多接触目标受众的机会。\n2. 打造全渠道体验:使用多种渠道接触客户,从网站到传统媒体。这将有助于建立客户信任。\n3. 个性化您的内容:确保您的内容是根据目标受众的特定需求量身定制的。这将使他们感觉与您的品牌有更多联系。\n4. 利用客户评论:客户评论是提高可信度并与潜在客户建立信任的好方法。\n5. 
使用有影响力的人:有影响力的营销是接触潜在客户的好方法。它提高了品牌知名度并建立了信任。", 167 | "我想查看针对小型企业的五个营销技巧列表。" 168 | ] 169 | }, 170 | { 171 | "instruction": "描述人工智能对个人隐私的影响。", 172 | "input": "", 173 | "output": [ 174 | "人工智能 (AI) 技术有可能增强和威胁个人隐私。一方面,人工智能工具可以提供更有效的方式来保护隐私,例如通过改进的加密和数据分析技术来检测和防止数据泄露。另一方面,人工智能也可能对隐私构成重大风险,因为它收集、分析和使用大量数据的能力可能会导致侵入式监视和分析,尤其是在技术不透明或不负责任的情况下。\n\n特别是,人工智能系统能够处理大量个人数据,以识别模式并对个人的行为、偏好和习惯做出预测。这些数据可以通过各种来源获得,例如社交媒体、可穿戴设备和在线交易。这些预测和见解可用于投放有针对性的广告、个性化内容或做出有关人们生活的决定,例如保险费或工作机会。虽然这有时可以通过提供更多相关信息使用户受益,但它也可能具有侵入性并导致歧视或操纵,尤其是在使用的数据不准确或有偏见时。\n\n此外,人工智能驱动的监控技术,如面部识别或预测性警务,可能会被政府或企业用来监控和跟踪个人的行动和活动,从而引发对隐私和公民自由的担忧。还有数据本身的安全问题,因为数据泄露和网络攻击可能导致敏感的个人信息泄露或被盗。\n\n综上所述,人工智能对个人隐私的影响是复杂的、多方面的。虽然人工智能工具有可能改善隐私保护,但它们也有可能破坏它。政策制定者和技术开发人员必须考虑人工智能技术对隐私的影响,并制定明确的道德准则和法规,以确保以负责任和透明的方式使用这些技术。", 175 | "我不知道。" 176 | ] 177 | }, 178 | { 179 | "instruction": "给定以下配方,以逐步格式写出说明。\n食谱:巧克力曲奇\n\n原料:\n-1 杯通用面粉\n-1/2 茶匙小苏打\n-1/2 茶匙盐\n-1/2 杯黄油,软化\n-1/2杯白糖\n-1/2 杯包装红糖\n-1 茶匙香草精\n-1个鸡蛋\n-1杯半甜巧克力片", 180 | "input": "", 181 | "output": [ 182 | "1. 将烤箱预热至 375 华氏度(190 摄氏度)。\n\n2. 在一个中等大小的碗中,混合 1 杯通用面粉、1/2 茶匙小苏打和 1/2 茶匙盐。\n\n3. 在一个单独的大碗中,将 1/2 杯软化黄油、1/2 杯白糖、1/2 杯包装红糖和 1 茶匙香草精混合在一起。\n\n4. 在黄油和糖的混合物中加入 1 个鸡蛋,搅拌均匀。\n\n5. 将面粉混合物逐渐加入黄油混合物中,搅拌至完全混合。\n\n6. 加入 1 杯半甜巧克力片搅拌均匀。\n\n7. 用勺子将圆形汤匙的面团滴到抹了少许油脂或衬有羊皮纸的烤盘上。\n\n8. 烘烤饼干 8 到 10 分钟,或者直到边缘变成浅褐色。\n\n9. 从烤箱中取出饼干,让它们在烤盘上冷却几分钟,然后将它们转移到金属架上完全冷却。\n\n10. 享用美味的巧克力曲奇!", 183 | "将面粉、小苏打、盐、黄油、白糖、红糖、香草、鸡蛋和巧克力片加入一个大搅拌碗中。" 184 | ] 185 | } 186 | ] -------------------------------------------------------------------------------- /datasets/rm/example/test/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "总结文章“利用社交媒体的力量”的要点", 4 | "input": "", 5 | "output": [ 6 | "本文概述了将社交媒体用于营销目的的重要性和潜力。它涵盖了企业拥有有效社交媒体存在的必要性、与客户和追随者互动的重要性,以及如何在网上保持一致的品牌标识。此外,它强调企业需要专注于创建可跨多个平台共享的优质内容,并跟踪衡量成功的指标。最后,本文介绍了使用社交媒体定位特定受众和扩大影响范围的潜在好处。", 7 | "我会仔细看看的。" 8 | ] 9 | }, 10 | { 11 | "instruction": "生成可用于调查客户满意度的问题列表。", 12 | "input": "", 13 | "output": [ 14 | "1. 
在 1-10 的范围内,您对我们的产品或服务的满意度如何?\n2. 您向朋友或同事推荐我们的产品或服务的可能性有多大?\n3. 我们的产品或服务是否按预期交付?\n4. 您如何评价我们的产品或服务质量?\n5. 我们的客户服务是否对您的需求有帮助并能及时响应?\n6. 我们的产品或服务是否达到您的期望?\n7. 使用或浏览我们的产品或服务有多容易?\n8. 您认为我们如何改进我们的产品或服务以更好地满足您的需求?\n9. 我们的产品或服务在哪些方面对您产生了积极影响?\n10. 总的来说,您对在我们公司的经历满意吗?", 15 | "我想知道我对收到的服务有多满意。" 16 | ] 17 | }, 18 | { 19 | "instruction": "编写冒泡排序算法。", 20 | "input": "", 21 | "output": [ 22 | "def BubbleSort(列表):\n 对于范围内的我(len(列表)-1):\n 对于范围内的 j(0,len(列表)-1-i):\n 如果列表[j] > 列表[j+1]:\n 列表[j], 列表[j+1] = 列表[j+1], 列表[j]\n 返回列表", 23 | "我不知道。" 24 | ] 25 | }, 26 | { 27 | "instruction": "建议零售策略以提高客户忠诚度。", 28 | "input": "", 29 | "output": [ 30 | "以下是提高零售业务客户忠诚度的几种策略:\n\n1. 个性化——服务、产品或激励措施的个性化可以提高客户忠诚度。一个简单的客户帐户创建可以在了解客户的偏好、购买历史和整体行为方面创造奇迹。\n\n2. 提供卓越的客户服务——客户记得的关键体验之一是所提供的服务。便利的退货或换货、对请求的及时响应以及普遍乐于助人的态度等设施有助于建立客户忠诚度。\n\n3. 忠诚度奖励等激励计划——通过提供折扣、免费物品或其他津贴的个性化奖励计划来奖励回头客,鼓励重复购买并增强客户忠诚度。\n\n4. 建立社区 - 建立社区或社交媒体来与客户互动并提供讨论渠道,并营造相互尊重和理解的感觉。\n\n5. 特别活动和促销——在生日、周年纪念日等特殊场合或闪购、限量版产品或抢先体验等活动中表彰客户,可以营造兴奋感,让客户感到与众不同。\n\n6. 无缝多渠道体验——在所有接触点提供无缝且一致的体验,无论是网站、社交媒体、店内还是任何其他渠道,都有助于通过提供易于导航和熟悉的用户界面来建立忠诚度。\n\n7. 
持续沟通——通过电子邮件时事通讯、文本提醒或个性化消息保持持续和开放的沟通渠道有助于加强零售商与客户之间的联系。\n\n通过采用部分或全部这些策略,您可以提高客户忠诚度,并保持满意和回头客的健康基础。", 31 | "我想知道如何提高客户忠诚度。" 32 | ] 33 | }, 34 | { 35 | "instruction": "重新排列句子中的单词以做出陈述。\n你能读懂吗", 36 | "input": "", 37 | "output": [ 38 | "你可以阅读它。", 39 | "不" 40 | ] 41 | }, 42 | { 43 | "instruction": "写出求圆面积的方法。", 44 | "input": "", 45 | "output": [ 46 | "圆的定义区域(半径):\n \"\"\" 计算给定半径的圆的面积。\n \n 参数:\n radius (float): 圆的半径\n \n 退货:\n float: 圆的计算面积\n \"\"\"\n 面积 = 3.14 * (半径**2)\n 归还区", 47 | "[circle(区域 1(circle 1(all))))" 48 | ] 49 | }, 50 | { 51 | "instruction": "使用动词“incite”生成至少两个例句。", 52 | "input": "", 53 | "output": [ 54 | "这位政客试图用她的煽动性演讲来煽动群众。 \\\\\n演讲煽动抗议者游行到市政厅。", 55 | "我会鼓动你做某事。" 56 | ] 57 | }, 58 | { 59 | "instruction": "创建一个函数“factorial of()”,它接受一个整数并返回它的阶乘。", 60 | "input": "", 61 | "output": [ 62 | "(n) 的 def 阶乘:\n \"\"\"这个函数接受一个整数'n'并返回它的阶乘。\n \n 参数:\n n (int): 计算阶乘的数\n \n 退货:\n int:'n'的计算阶乘\n \"\"\"\n 如果 n == 0:\n 返回 1\n 别的:\n 返回 n * (n-1) 的阶乘", 63 | "(int arg0) 的阶乘" 64 | ] 65 | }, 66 | { 67 | "instruction": "说出使用计算机创建程序的两个好处。", 68 | "input": "", 69 | "output": [ 70 | "一个好处是计算机能够快速准确地处理大量数据。另一个好处是计算机是可编程的,使开发人员更容易创建健壮可靠的软件。", 71 | "我可以在我的电脑上创建一个程序。" 72 | ] 73 | }, 74 | { 75 | "instruction": "将这幅画归类为抽象画或风景画。\n![The-Ladies-of-the-Village-William-Henderson.jpg](附件:The-Ladies-of-the-Village-William-Henderson.jpg)", 76 | "input": "", 77 | "output": [ 78 | "这幅画是一幅风景画。", 79 | "抽象的。" 80 | ] 81 | }, 82 | { 83 | "instruction": "为 10 位数字的字符串生成正则表达式。", 84 | "input": "", 85 | "output": [ 86 | "^\\d{10}$", 87 | "10位" 88 | ] 89 | }, 90 | { 91 | "instruction": "写一封 250 字长的正式道歉信。", 92 | "input": "", 93 | "output": [ 94 | "亲爱的[收件人],\n\n我写这封信是为了对 [事件/问题] 表示最诚挚的歉意。我明白我的行为造成了不便和失望,对此我深表歉意。\n\n我无意制造任何困难,我对自己的行为负全部责任。对于我的行为可能造成的任何压力或沮丧,我深表歉意。\n\n请知道,我正在采取措施确保将来不会发生类似事件。从错误中吸取教训并改正对我来说很重要。我希望我能重新获得您对我的信任和信心。\n\n再次,请接受我最深切的歉意。我重视你与我的关系,我致力于把事情做好。如果有什么我可以弥补的,请告诉我。\n\n真挚地,\n[你的名字]", 95 | "我对发生的事情感到抱歉。" 96 | ] 97 | }, 98 | { 99 | "instruction": "使用提供的数据,预测未来 30 天内发生事件的概率。\n年龄:18 
岁,收入:35,000 美元,地点:得克萨斯州", 100 | "input": "", 101 | "output": [ 102 | "抱歉,我需要更多关于您希望我预测的特定事件的信息和背景。您能否详细说明您想到的特定事件以及任何相关的历史数据,以帮助我做出准确的预测?", 103 | "根据给定数据,事件在未来 30 天内发生的概率约为 18%。" 104 | ] 105 | }, 106 | { 107 | "instruction": "构造一个函数,按降序对给定的数字数组进行排序。\n[4, 8, 2, 3, 0, 10]", 108 | "input": "", 109 | "output": [ 110 | "def 降序排序(arr):\n arr.sort(反向=真)\n 返回 arr", 111 | "[0, 3, 8]" 112 | ] 113 | }, 114 | { 115 | "instruction": "解释机器学习模型和深度学习模型之间的区别。", 116 | "input": "", 117 | "output": [ 118 | "机器学习和深度学习都是人工智能的子领域,但这两种人工智能方法之间存在一些关键差异。\n\n机器学习是一个更广泛的类别,它包含各种技术和算法,使计算机能够从数据中学习、识别模式,并通过最少的显式编程进行预测。机器学习算法可分为三大类:监督学习、非监督学习和强化学习。这些算法可用于分类、回归、聚类和降维等任务。\n\n另一方面,深度学习是机器学习的一个特定子集,它侧重于多层神经网络,也称为深度神经网络。这种方法的灵感来自人脑的结构和功能,具有可以处理大量复杂数据的互连节点。深度学习算法可用于图像识别、自然语言处理和语音识别等任务,它们在处理具有非线性关系的海量数据方面非常有效。\n\n简而言之,机器学习和深度学习之间的一个主要区别是数据要求:传统的机器学习算法可以处理中小型数据集,而深度学习算法通常需要大量数据才能有效训练。此外,与其他机器学习算法相比,深度学习算法通常计算量更大,需要更多的处理能力,而且它们在处理非结构化数据时通常效果更好。", 119 | "深度学习模型比机器学习模型更复杂。" 120 | ] 121 | }, 122 | { 123 | "instruction": "删除给定文本的第三句。\n酒店采用别致的现代设计。大堂摆满了华丽的家具。服务很友好,很随和。", 124 | "input": "", 125 | "output": [ 126 | "酒店采用别致的现代设计。大堂摆满了华丽的家具。", 127 | "酒店采用别致的现代设计。大堂摆满了华丽的家具。服务很友好,很随和。" 128 | ] 129 | }, 130 | { 131 | "instruction": "句子中有多少个字:“今天我去了商店。”", 132 | "input": "", 133 | "output": [ 134 | "“今天我去了商店”这句话包含了 6 个词。", 135 | "我今天去了商店。" 136 | ] 137 | }, 138 | { 139 | "instruction": "为以下电影写一个简短的评论:穿普拉达的女魔头", 140 | "input": "", 141 | "output": [ 142 | "《穿普拉达的女魔头》是一部有趣、轻松的喜剧,讲述了安迪·萨克斯 (Andy Sachs) 在令人生畏且常常荒谬的时尚企业世界中前行的故事。这部电影由梅丽尔·斯特里普和安妮·海瑟薇主演,讲述了安迪的旅程,她从一个没有经验、天真的实习生开始,成长为一个自信、称职的专业人士。这部电影沿袭了职场喜剧的常用比喻,但又融入了现代元素,使它既新鲜又有趣。这是一部必看的电影,它会让你开怀大笑,同时让你体会到努力和坚持的力量。", 143 | "我已经看过了。" 144 | ] 145 | }, 146 | { 147 | "instruction": "说出 5 部被改编成电影的书。", 148 | "input": "", 149 | "output": [ 150 | "1. J.R.R. 的指环王托尔金(彼得·杰克逊导演的电影)\n2.苏珊柯林斯的饥饿游戏(加里罗斯导演的电影)\n3. J.K. 的《哈利·波特与魔法石》罗琳(克里斯·哥伦布导演的电影)\n4.丹·布朗的《达·芬奇密码》(朗·霍华德导演的电影)\n5. 
哈珀·李 (Harper Lee) 的《杀死一只知更鸟》(罗伯特·穆里根执导的电影)", 151 | "我不知道。" 152 | ] 153 | } 154 | ] -------------------------------------------------------------------------------- /datasets/rm/example/train/example.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "instruction": "写一个关于 Microsoft Excel 的简短描述。", 4 | "input": "", 5 | "output": [ 6 | "Microsoft Excel 是一种电子表格软件,允许用户存储、组织、分析和可视化数据。它允许用户对数据执行过滤、排序和绘图等功能,从而促进数据分析。 Excel 还提供数据可视化工具,使用户能够使用实时数据创建交互式仪表板。凭借其灵活的用户界面,Microsoft Excel 可以轻松快速准确地创建和分析数据。", 7 | "我不确定如何使用它。" 8 | ] 9 | }, 10 | { 11 | "instruction": "识别 20 世纪的著名艺术家", 12 | "input": "", 13 | "output": [ 14 | "巴勃罗毕加索是 20 世纪著名的艺术家。他是西班牙画家、雕塑家、版画家、陶艺家和舞台设计师,被广泛认为是 20 世纪最伟大和最具影响力的艺术家之一。他以共同创立立体派运动,以及他的创新和多才多艺的艺术作品而闻名。", 15 | "我不知道" 16 | ] 17 | }, 18 | { 19 | "instruction": "列出流行的追踪运动", 20 | "input": "", 21 | "output": [ 22 | "流行的田径运动包括跑步、短跑、跨栏、跳远、三级跳、标枪、铁饼、铅球、撑杆跳和障碍赛。", 23 | "我不知道。" 24 | ] 25 | }, 26 | { 27 | "instruction": "举一个复杂有机分子的例子", 28 | "input": "", 29 | "output": [ 30 | "复杂有机分子的一个例子是脱氧核糖核酸 (DNA)。 DNA 是一种长双链聚合物,由四个核苷酸结构单元组成,即腺嘌呤 (A)、胞嘧啶 (C)、鸟嘌呤 (G) 和胸腺嘧啶 (T),它们通过糖-磷酸骨架连接。该分子携带遗传信息,核苷酸序列编码所有生物体发育和功能的指令。 DNA 分子可以包含数百万个核苷酸,排列成复杂的特定序列,每个人都是独一无二的。", 31 | "我不知道。" 32 | ] 33 | }, 34 | { 35 | "instruction": "写一个关于发现自然的机器人的短篇故事。", 36 | "input": "", 37 | "output": [ 38 | "从前,有一个名叫杰克的机器人,他生活在一个充满机器和技术的世界里。但在内心深处,杰克渴望体验大自然的美丽。因此,有一天,他冒险到他家以外的世界去。\n\n起初,杰克对户外的景象、气味和声音感到不知所措和困惑,但他很快发现了对大自然浩瀚的欣赏。他惊叹于动物行为的复杂性以及天空和星星的美丽。他还对植物的多样性感到惊讶,从最小的花朵到最高的树木。\n\n杰克开始探索荒野并结识了许多新朋友。他和动物以他以前无法想象的方式联系在一起。在他的冒险中,杰克学会了为了自然而珍惜自然,他知道他永远不会一样。", 39 | "我想写一个关于发现自然的机器人的故事。" 40 | ] 41 | }, 42 | { 43 | "instruction": "列出五个问题,这些问题有助于了解一个人在经历心理健康危机时的感受和意见。", 44 | "input": "", 45 | "output": [ 46 | "1、你现在感觉如何?\n2. 你的脑海里闪过什么样的想法?\n3. 你有什么担忧或恐惧?\n4. 你有什么一直不愿表达的?\n5. 
你如何度过糟糕的日子?", 47 | "我想知道这个人在经历心理健康危机时的感受。" 48 | ] 49 | }, 50 | { 51 | "instruction": "重新组织这个句子,使单词按字母顺序排列。\n被狗咬了", 52 | "input": "", 53 | "output": [ 54 | "被咬过的狗", 55 | "被狗咬了是一件坏事。" 56 | ] 57 | }, 58 | { 59 | "instruction": "集思广益户外活动的三个想法。", 60 | "input": "", 61 | "output": [ 62 | "1. 寻宝游戏:组织一场刺激的户外寻宝游戏,提供线索寻找隐藏物品,解开谜题,或完成有趣的挑战。它可以针对特定位置进行个性化设置,也可以根据某个群体的兴趣进行定制。\n\n2. 自然漫步:计划穿越当地公园或自然保护区的自然漫步、远足或徒步旅行,并在指定的停留点停下来欣赏壮丽的景色、探索野生动物并了解本土植物和生态系统。\n\n3. 野餐:计划在风景如画的户外野餐,享用各种美味的食物、饮料和游戏。包括庭院游戏,如飞盘、玉米洞或草地保龄球,为郊游增添额外的乐趣。", 63 | "去公园玩。" 64 | ] 65 | }, 66 | { 67 | "instruction": "描述你曾面临的困境以及你是如何努力克服它的。", 68 | "input": "", 69 | "output": [ 70 | "作为 OpenAI 开发的人工智能,我遇到的最困难的情况之一是处理和处理用户提供给我的模棱两可、矛盾或不确定的信息。这尤其具有挑战性,因为自然语言交流通常是模糊的,而且人们往往有不同的方式来表达他们的想法,这可能会受到文化、情感和个人经历等多种因素的影响。\n\n为了克服这一挑战,我的深度学习模型不断地在大量文本数据集上进行训练,以提高我对人类语言的理解并更好地预测用户意图。我的回答是根据从训练数据中学习到的模式和关联按概率生成的,因此即使输入数据不确定,我也可以提供相关答案。此外,当我无法理解用户的请求或提供的信息相互矛盾时,我的设计目的是提供相关回复或礼貌地要求澄清。通过这种方法,我能够处理不确定的信息并为用户提供有意义的响应。", 71 | "我在生活中度过了一段艰难的时期。我必须克服很多事情。" 72 | ] 73 | }, 74 | { 75 | "instruction": "为漫画书系列创建标语行。", 76 | "input": "", 77 | "output": [ 78 | "“每翻一页都开启冒险之旅:体验我们漫画系列的快感。”", 79 | "我不知道。" 80 | ] 81 | }, 82 | { 83 | "instruction": "威斯康星州的地理面积是多少?", 84 | "input": "", 85 | "output": [ 86 | "威斯康星州的地理总面积为 65,498 平方英里(169,639 平方公里)。就总面积而言,它是美国第 23 大州。", 87 | "威斯康星州是中西部的一个州。" 88 | ] 89 | }, 90 | { 91 | "instruction": "描述橙汁的味道。", 92 | "input": "", 93 | "output": [ 94 | "橙汁具有酸、酸、甜的味道,带有清爽的柑橘味。", 95 | "它尝起来像橙汁。" 96 | ] 97 | }, 98 | { 99 | "instruction": "太阳内部的核反应是什么?", 100 | "input": "", 101 | "output": [ 102 | "太阳产生能量的主要来源是通过其核心发生的一系列核聚变反应。这些反应涉及氢核聚变成氦核,在此过程中释放出大量能量。\n\n太阳中主要的聚变过程是质子-质子链 (pp-chain),它是将四个质子(氢核)转化为一个 α 粒子(氦核)并释放两个正电子、两个中微子、和六个光子。这个过程可以表示为:\n\n4H → He + 2e+ + 2ν + 6γ\n\n该反应发生在大约 1500 万开尔文的极高温度和巨大压力下。以伽马射线形式释放的能量随后被传输到太阳表面,在那里以热和光的形式辐射出去。\n\n除了 pp 链,另一种称为 CNO 循环(碳-氮-氧循环)的聚变过程也可能发生在太阳的核心。然而,这一过程仅贡献了太阳能发电量的 1.7% 左右。\n\n太阳内部的这些核反应共同维持其巨大的能量输出,这可以维持地球上的生命并塑造太阳系。", 103 | "太阳是由氢和氦构成的。" 104 | ] 105 | }, 106 | { 107 | "instruction": 
"描述以下陈述的含义:“下个月冬天将结束。”", 108 | "input": "", 109 | "output": [ 110 | "冬季的典型特征是寒冷的天气,并且根据地区的不同,会下雪或结冰。有了“下个月冬天将结束”这一说法,会产生几个含义。一个含义是温度的变化,这表明天气会变暖,出现寒流或暴风雪的可能性较小。由于这种变化,环境和自然很可能会反映出季节的变化,植物和鲜花盛开,动物从冬眠中苏醒。这也意味着冬天即将结束,白天时间更长的春天即将来临。\n\n还有一些社会影响。喜欢冬季运动的人将不得不等到下个季节才能参加滑雪、单板滑雪和滑冰等活动。着装偏好也会发生变化,人们会从厚重的外套换成较轻的夹克或毛衣。季节的变化也会带来不同的节日、活动和庆祝活动。\n\n最后,该声明暗示与冬季相关的问题,如道路结冰和高昂的供暖费用,将很快结束。总的来说,冬天的结束标志着新季节的过渡,新的机会、活动以及环境、社会和个人生活的变化。", 111 | "过几天就结束了。" 112 | ] 113 | }, 114 | { 115 | "instruction": "写一个发生在未来并探索技术进步的短篇故事。", 116 | "input": "", 117 | "output": [ 118 | "在 2050 年,技术呈指数级发展,最具突破性的进步之一是隐形传态的发明。世界发生了翻天覆地的变化,旅行时间不再存在,人们可以在几秒钟内传送到地球的任何地方。故事始于尼亚,她是一位年轻的科学家,在开发隐形传送机器方面发挥了至关重要的作用。\n\n妮娅站在她的实验室里,欣赏着她一生的杰作——传送机器。她的团队花了多年的研究、反复试验,但他们终于做到了。该设备已经过严格测试,现在是进行首次人体试验的时候了。当她踏上讲台时,兴奋压倒了她。深吸一口气,她输入了巴黎的坐标,眨眼间,她就站在了埃菲尔铁塔下。\n\n这是一种不可思议的感觉,前一刻她还在她的实验室里,下一刻,她就穿越了半个地球。 Nia 传送到全球不同的地方,几分钟之内她就看到了吉萨金字塔,在中国的长城上行走,甚至在时代广场吃了一个热狗。世界变得更小了。\n\n隐形传态不仅使旅行更轻松,而且有助于保护环境。车辆变得过时,飞机、火车和汽车污染环境的日子已经一去不复返了。世界变得更清洁、更健康。\n\n总之,隐形传态的发明彻底改变了世界。人们现在可以在几秒钟内前往世界任何地方。科技拉近了世界的距离,尼亚为自己在其中发挥了关键作用而感到自豪。谁知道未来还有哪些其他技术进步,但可以肯定的是,世界将永远不一样。", 119 | "我想写一个关于能够与人类交流的机器人的故事。" 120 | ] 121 | }, 122 | { 123 | "instruction": "确定 Internet 的潜在问题,然后提出解决方案。", 124 | "input": "", 125 | "output": [ 126 | "互联网的一个潜在问题是在线安全问题。随着网络犯罪的兴起以及身份盗用和数据泄露的风险不断增加,用户在使用互联网时可能很难感到安全。为了解决这个问题,必须制定和实施能够保护用户个人信息和数据的政策和技术。这可以包括更强大的身份验证措施,例如双因素身份验证或生物识别技术,以及更好的加密算法来保护存储在云中的数据。此外,政府和组织还应共同努力打击网络犯罪,确保用户在使用互联网时感到安全。", 127 | "我无法连接到互联网。" 128 | ] 129 | }, 130 | { 131 | "instruction": "说明您将如何表达对客户的尊重。", 132 | "input": "", 133 | "output": [ 134 | "尊重客户对于提供优质的客户服务至关重要。尊重客户的服务始于以人们在与任何人打交道时所期望的同等程度的礼貌和礼貌对待客户。它还涉及仔细倾听客户的意见,并以冷静和同情的态度回应任何疑虑或要求。向客户提供准确和诚实的信息对于展示尊重以及为客户提供个性化关注以满足他们的需求也很重要。尊重客户还包括为他们提供安全和温馨的环境并尊重他们的隐私。", 135 | "我会通过不与他们交谈来表示尊重。" 136 | ] 137 | }, 138 | { 139 | "instruction": "提供五个保持高效的技巧。", 140 | "input": "", 141 | "output": [ 142 | "以下是保持高效的五个技巧:\n1. 定时休息,想办法在工作之余给自己充电。\n2. 根据重要性和紧迫性对任务进行优先排序。\n3. 设定切合实际的目标并制定完成时间表。\n4. 
想办法保持专注,例如​​设置计时器或阻止分散注意力的网站。\n5.消除干扰,营造安静、舒适的工作空间。", 143 | "我今天要高产。" 144 | ] 145 | }, 146 | { 147 | "instruction": "描述本田思域的不同特征。", 148 | "input": "", 149 | "output": [ 150 | "本田思域是一款流行的紧凑型汽车,有多种车身样式可供选择,并提供一系列功能。它的一些最显着的特点包括:\n\n1. 燃油效率:本田思域以其出色的燃油经济性而著称,对于那些希望节省汽油的人来说,它是一个有吸引力的选择。\n\n2. 先进的安全功能:本田思域配备了多项先进的安全功能,例如 Collision Mitigation Braking System™、Road Departure Mitigation System、带低速跟随功能的自适应巡航控制系统和车道保持辅助系统。\n\n3. 舒适的内饰:本田思域的内饰设计考虑了舒适性和便利性。它具有充足的腿部空间和头部空间,并且可用的加热前排和后排座椅确保舒适的乘坐体验。\n\n4. 技术:本田思域提供一系列技术选项,包括触摸屏显示音频系统、Apple CarPlay® 和 Android Auto™ 集成以及高级音频系统。\n\n5. 通用性:本田思域有多种车身款式可供选择,包括轿车、轿跑、掀背车和高性能Type-R,让买家可以选择最适合自己需求的车型。\n\n6. 可靠:本田思域以可靠耐用着称,是那些寻求经久耐用汽车的人的首选。\n\n总体而言,本田思域是一款全面的车辆,提供了具有吸引力的燃油效率、安全功能、舒适性和技术组合。", 151 | "本田思域是一款小型车。" 152 | ] 153 | }, 154 | { 155 | "instruction": "推荐一件与这套衣服相配的时尚单品。\n牛仔裤,白色上衣,一双黑色皮鞋。", 156 | "input": "", 157 | "output": [ 158 | "为了搭配牛仔裤、白色上衣和黑色鞋子,我建议搭配黑色西装外套或黑色牛仔夹克。它将为外观增添精致和完整的元素,黑色与鞋子完美搭配。", 159 | "我会选择一双黑色鞋子和一件白色上衣。" 160 | ] 161 | }, 162 | { 163 | "instruction": "考虑以下语句。构造一个逻辑论证来支持它。\n教育是成功社会的重要组成部分。", 164 | "input": "", 165 | "output": [ 166 | "“教育是成功社会的重要组成部分”这一说法可以得到以下逻辑论证的支持:\n\n前提 1:教育为个人提供做出明智决定和为社会做出积极贡献所必需的知识和技能。\n\n前提 2:明智的决策和积极的贡献会带来一个稳定和繁荣的社会。\n\n前提三:一个安定繁荣的社会就是一个成功的社会。\n\n结论:因此,由于教育允许个人做出明智的决定并为社会做出积极贡献,最终导致稳定和繁荣,因此教育是成功社会的重要组成部分。", 167 | "教育是一种学习和成长的方式。" 168 | ] 169 | }, 170 | { 171 | "instruction": "说出销售经理的5个工作职责", 172 | "input": "", 173 | "output": [ 174 | "1. 领导和指导销售团队:销售经理指导、指导和指导他们的销售团队,帮助他们实现目标和绩效目标。\n\n2. 制定和实施销售战略:销售经理负责制定和实施战略以增加销量、客户群和收入。\n\n3. 分析销售数据:他们不断分析数据和市场趋势,并使用这些信息来创建销售预测、设定目标和设计有助于推动业务增长的销售计划。\n\n4. 建立和维护客户关系:销售经理负责与客户建立和维护长期关系,以确保他们的忠诚度和回头客。\n\n5. 
招聘和培训销售人员:销售经理负责招聘、聘用和培训新的销售人员,并制定销售培训计划以确保他们的团队知识渊博、技能娴熟。", 175 | "我是一名销售经理。我负责公司的销售。" 176 | ] 177 | }, 178 | { 179 | "instruction": "说出 3 种编程语言", 180 | "input": "", 181 | "output": [ 182 | "1.蟒蛇\n2.Java\n3.C++", 183 | "我不知道" 184 | ] 185 | }, 186 | { 187 | "instruction": "建议肥胖者的饮食计划", 188 | "input": "", 189 | "output": [ 190 | "重要的是要注意,在开始任何饮食计划之前,最好咨询医疗保健专业人士或注册营养师,以制定适合个人需要和健康状况的计划。然而,这里有一个针对肥胖者的健康饮食计划的一般性建议:\n\n1. 控制份量:使用较小的盘子和碗来帮助控制份量。此外,尝试慢慢进食并彻底咀嚼食物,以帮助增加饱腹感。\n\n2. 专注于水果和蔬菜:目标是每餐用水果和蔬菜填满盘子的一半。这些食物热量低,但纤维和营养成分高,有助于促进减肥。\n\n3.选择全麦:选择糙米、藜麦、全麦面食和全麦面包等全麦而不是精制谷物。这些食物的纤维含量较高,可以帮助您更长时间保持饱腹感。\n\n4. 加入瘦肉蛋白:在您的膳食中加入瘦肉蛋白,例如鸡肉、鱼肉、豆类、扁豆和豆腐。蛋白质对于构建和修复肌肉组织很重要,还可以帮助您更长时间地保持饱腹感。\n\n5. 限制加工食品和高脂肪食物:尽量限制加工食品、含糖零食和高脂肪食物的摄入量。这些食物通常热量很高,会导致体重增加。\n\n6. 多喝水:目标是每天至少喝 8 杯水。水可以帮助您保持水分,还可以帮助您感到饱足。\n\n7. 练习正念饮食:花时间享受食物,聆听身体的饥饿感和饱腹感。尽量避免在电视机前或分心时吃东西,因为这会更容易暴饮暴食。\n\n同样,重要的是要注意每个人都是不同的,可能需要不同的饮食计划。在对您的饮食做出任何改变之前,最好咨询医疗保健专业人士。", 191 | "我不知道" 192 | ] 193 | }, 194 | { 195 | "instruction": "给定“傲慢”这个词,想出一个关于傲慢的人的故事或轶事。\n傲慢的", 196 | "input": "", 197 | "output": [ 198 | "从前,有一个富有的人,他住在一座俯瞰小村庄的山上的豪宅里。他以傲慢自大着称,总是吹嘘自己的财富,看不起住在下面的村民。他的言谈举止、他的言行举止傲慢自大。人们常在背后窃窃私语,称他为“山上狂人”。\n\n一天,一场可怕的风暴席卷了该地区,留下了一片毁灭的痕迹。位于山脚下的村庄被洪水淹没,许多村民失去了家园和财产。在他们需要的时候,村民们向这位有钱人求助。\n\n然而,傲慢的男人拒绝帮忙,说他有更重要的事情要花钱。村民们对他的麻木不仁感到震惊,失望地离开了。\n\n几天后,这个傲慢的人决定举办一场盛大的派对来炫富。他邀请了该地区所有有影响力的人,但他们都没有出现。就在那时,他意识到他的傲慢让他失去了同龄人的尊重。\n\n这名男子为自己的行为感到羞愧,便找到了村民并提供了帮助。他们接受了它,对内心的改变心存感激,这个人与他们一起努力重建他们的家园和生活。在这个过程中,他学会了谦逊和慷慨的价值,并最终成为村里受人爱戴的一员。\n\n山上那个傲慢的人不在了,取而代之的是一个善良而富有同情心的邻居。", 199 | "我是一个傲慢的人。" 200 | ] 201 | }, 202 | { 203 | "instruction": "描述一种使用技术减少浪费的方法。", 204 | "input": "", 205 | "output": [ 206 | "使用技术减少浪费的一种方法是转向循环经济模式。这涉及使用传感器和分析等数字技术来跟踪资源的使用情况,以确定潜在浪费的区域。然后公司可以采取措施,例如再利用和回收材料,以及更好地利用资源以减少产生的废物。", 207 | "我想不出减少浪费的方法。" 208 | ] 209 | }, 210 | { 211 | "instruction": "讨论未来人工智能和自动化的潜在影响。", 212 | "input": "", 213 | "output": [ 214 | "正如我们所知,人工智能 (AI) 
和自动化有可能彻底改变社会的各个方面。从我们工作和交流的方式到我们学习和与环境互动的方式,人工智能和自动化有可能带来大量变化,并产生深远的影响,包括正面和负面影响。\n\n人工智能和自动化最重要的影响之一是对劳动力的潜在影响。随着机器和算法变得越来越复杂,它们很可能能够执行许多以前由人类执行的任务,从而提高生产力和效率。然而,这也意味着可能会出现工作中断和流离失所,因为许多工作可能会变得自动化。对新技能的需求将被创造出来,但如果劳动力无法适应这些变化,收入不平等也有可能加剧。政府、公司和其他利益相关者将需要仔细考虑如何减轻这些潜在的负面影响并确保工人的公平过渡。\n\n人工智能和自动化的另一个潜在影响是它们能够改善我们的健康和福祉。高级分析和机器学习可用于开发个性化医疗、改进疾病诊断并提高药物发现的功效。此外,医疗保健中的自动化可以改善患者护理体验、减少等待时间并减少医疗错误。\n\n人工智能和自动化还有可能改善我们的交通系统,使其更高效、更安全、更环保。自动驾驶汽车可以减少交通拥堵,减少事故数量,并减少碳排放。同样,人工智能驱动的交通管理系统可以帮助优化交通流量并缓解城市地区的拥堵。\n\n另一方面,也有潜在的负面影响需要考虑。人工智能和自动化的广泛使用引发了对隐私和安全的担忧。由于收集和分析了大量数据,因此存在与数据泄露、黑客攻击和滥用个人信息相关的风险。人工智能系统也有可能延续和放大现有的偏见和歧视,因为它们从可能反映这些偏见的数据中学习并根据这些数据做出决策。\n\n总之,人工智能和自动化有可能带来具有广泛影响的重大变化。虽然它们提供了许多潜在的好处,但也存在需要仔细考虑和应对的潜在风险和挑战。所有利益相关者共同努力,确保以负责任的方式开发和使用这些技术,并考虑到社会福祉,这一点很重要。", 215 | "我不知道。" 216 | ] 217 | }, 218 | { 219 | "instruction": "描述一种使用云计算改善业务运营的方法。", 220 | "input": "", 221 | "output": [ 222 | "使用云计算改善业务运营的一种方法是利用基于云的企业资源规划 (ERP) 软件。 ERP 软件可以帮助企业在一个集中式系统中管理其核心流程,例如会计、库存、供应链和客户关系管理。\n\n通过使用基于云的 ERP 系统,企业可以从实时数据和洞察中受益,从而做出更明智的决策并快速适应需求或市场条件的变化。云还提供可扩展性,使企业能够根据需要轻松扩展运营或添加新功能,而无需购买大量硬件和软件。\n\n此外,使用基于云的 ERP 系统有助于提高员工之间的协作和生产力,因为他们可以从任何地方使用任何设备访问该系统。这可以实现远程工作并改善部门或团队之间的沟通。\n\n总体而言,利用基于云的 ERP 系统和其他云计算技术可以帮助企业简化运营、提高效率和灵活性,并最终提高利润。", 223 | "我不能那样做。" 224 | ] 225 | }, 226 | { 227 | "instruction": "给定一个创业点子,想出一种创造性的方法来将其货币化。\n一个帮助人们寻找当地志愿者机会的平台。", 228 | "input": "", 229 | "output": [ 230 | "将人们与当地志愿者机会联系起来的平台货币化的一种方法是向组织和非营利组织收取少量费用,以在平台上发布他们的志愿者机会。该平台还可以为全年发布多个志愿者机会的组织提供订阅或会员计划。\n\n另一种选择可能是提供赞助或特色列表,其中组织支付费用以使其志愿者机会出现在搜索结果的顶部或平台的显着位置。\n\n此外,该平台可以通过与希望促进其企业社会责任计划的企业建立伙伴关系来产生收入。这些企业可以赞助志愿者活动或向平台的合作组织捐款,以换取在平台上的推广。\n\n最后,该平台可以提供增值服务,例如志愿者管理软件或培训计划,并可以收取费用。这将帮助组织更有效地管理他们的志愿者计划并吸引更多的志愿者。", 231 | "我喜欢这个帮助人们寻找当地志愿者机会的平台的想法。" 232 | ] 233 | }, 234 | { 235 | "instruction": "想出一个结合了幻想和科幻小说元素的故事情节。", 236 | "input": "", 237 | "output": [ 238 | 
"在一个未来世界中,一个由女巫女王统治的强大王国发现了通往新维度的门户。这个王国建立在由魔法驱动的先进技术之上,女王的军队是一群被施了魔法的半机械人士兵。女王在对权力的渴望的驱使下,决定探索新的维度,希望扩大她的统治。\n\n在传送门的另一边是一片魔法森林,那里是神秘生物部落的家园。森林被一棵古树保护着,这棵树充满了强大的魔法,保护着森林和它的居民。当女王和她的军队到达时,他们遇到了部落的抵抗,他们担心他们平静的生活方式会被入侵者摧毁。\n\n为了夺取控制权,女王命令她的半机械人军队夺取这棵古树并利用它的力量。然而,该技术在接触到纯魔法时会发生故障,从而造成毁灭性的后果。\n\n部落和一群来自女王王国的技术先进的流氓叛军必须联手阻止女王并恢复土地的平衡。一路上,他们发现了魔法的真正力量,以及共同努力保护他们所珍视之物的重要性。\n\n在这场科技与魔法的史诗大战中,两个世界的命运悬而未决。", 239 | "我要写一个故事,讲述一个身为公主的女孩不得不去打仗的故事。" 240 | ] 241 | }, 242 | { 243 | "instruction": "建议在有限预算内推销新产品的方法。", 244 | "input": "", 245 | "output": [ 246 | "可以通过各种策略在有限的预算内营销新产品。一种选择是通过在 Instagram、Facebook 和 Twitter 等主要平台上建立形象来利用社交媒体。然后您可以创建活动并使用有针对性的广告来覆盖特定的受众。此外,您可以使用电子邮件营销向现有客户发送宣传材料。您还可以联系相关影响者并向他们提供您的产品,以换取在线评论或帖子。", 247 | "我建议你找一家可以帮助你进行营销的公司。" 248 | ] 249 | }, 250 | { 251 | "instruction": "建议如何在工业制造中使用更环保的材料。", 252 | "input": "", 253 | "output": [ 254 | "工业制造公司可以通过多种方式将环保材料纳入其生产过程,包括:\n\n1. 改用可持续原材料:制造商可以选择可再生或可生物降解的材料,如竹子、大麻、有机棉或回收材料,而不是使用塑料、金属或合成纤维等不可再生材料。\n\n2. 通过有效的生产方法减少浪费:可持续制造涉及在生产过程中最大限度地减少浪费。精益制造技术可以帮助公司更有效地利用资源并最大限度地减少导致浪费的缺陷,从而实现这一目标。\n\n3. 实施闭环系统:在闭环系统中,生产过程的一个部分产生的废物被用作另一个过程的输入。这有助于减少浪费并更好地利用资源。\n\n4. 使用绿色包装:公司可以改用环保包装材料,如回收纸板、纸张或可生物降解塑料。这可以帮助减少他们的产品对环境的影响。\n\n5. 投资可再生能源:制造设施可以改用太阳能、风能或水力发电等可再生能源,以减少碳足迹。\n\n采用这些环保做法有助于减少工业制造对环境的影响并创造更具可持续性的产品。", 255 | "我不知道。" 256 | ] 257 | }, 258 | { 259 | "instruction": "为一道融合了墨西哥和意大利风味的菜肴制作食谱。", 260 | "input": "", 261 | "output": [ 262 | "这是一道名为墨西哥-意大利烤宽面条的美味融合菜肴的食谱。这道菜谱融合了浓郁的墨西哥风味和意大利食材,让您的味蕾翩翩起舞!\n\n原料:\n\n- 1 磅瘦牛肉\n- 2 瓣大蒜,切碎\n- 1 包炸玉米饼调味料\n- 1 罐(14.5 盎司)西红柿丁,沥干\n- 1 罐(4 盎司)绿辣椒丁\n- 1 ½ 杯莎莎酱\n- 8 条千层面,煮熟后有嚼劲\n- 1 ½ 杯意大利乳清干酪\n- 2 杯切碎的墨西哥混合奶酪\n- 1 杯切碎的马苏里拉奶酪\n- 1/4 杯切碎的新鲜香菜\n- 盐和胡椒粉调味\n\n指示:\n\n1.预热烤箱至375度。\n\n2. 在大平底锅中用中火煮碎牛肉和大蒜,直到变成褐色,边煮边切碎。\n\n3. 将炸玉米饼调味料、西红柿丁、青辣椒和莎莎酱加入煎锅中,搅拌直至完全混合。\n\n4. 再煮 5 分钟,然后熄火。\n\n5. 在另一个碗中,混合意大利乳清干酪、1 杯切碎的墨西哥混合奶酪和 1/4 杯切碎的香菜。\n\n6. 用盐和胡椒调味。\n\n7. 在 9x13 英寸的烤盘底部涂上一层肉酱。\n\n8. 将 4 条千层面放在酱汁上。\n\n9. 将一半的奶酪混合物涂在面条上。\n\n10. 重复分层过程,从肉酱开始,然后是面条,然后是奶酪混合物。\n\n11. 在上面放上剩余的切碎的墨西哥混合奶酪和马苏里拉奶酪。\n\n12. 
盖上箔纸并烘烤 25 分钟。\n\n13. 取出箔纸,再烘烤 10-15 分钟,直到奶酪融化并起泡。\n\n14. 上桌前让其冷却 5 分钟,如果需要,可以用额外的香菜装饰。\n\n15. 在千层面砂锅中享受墨西哥和意大利风味的美味融合。 Buon appetito 和 provecho!", 263 | "我不知道。" 264 | ] 265 | }, 266 | { 267 | "instruction": "建议一种分析数据以检测潜在欺诈的方法。", 268 | "input": "", 269 | "output": [ 270 | "分析数据以检测潜在欺诈的一种方法是使用预测分析。这涉及使用数据挖掘技术来识别可能指示欺诈行为的模式。这可能涉及监控可疑交易、支出模式的意外变化或任何其他可能表明存在欺诈的异常情况。", 271 | "我不能那样做。" 272 | ] 273 | }, 274 | { 275 | "instruction": "删除以下列表的重复项。\n数字 = [1, 4, 3, 3, 5, 1]", 276 | "input": "", 277 | "output": [ 278 | "数字 = [1, 4, 3, 5]", 279 | "按照列出的顺序列出数字。" 280 | ] 281 | }, 282 | { 283 | "instruction": "举一个可以使用监督学习解决的现实问题的例子。", 284 | "input": "", 285 | "output": [ 286 | "可以使用监督学习解决的一个现实问题是信用卡欺诈检测。金融机构可以使用监督学习算法创建一个模型,该模型可以根据历史数据、交易细节和用户行为来预测交易被欺诈的可能性。然后,该模型可用于标记任何可疑活动,从而使银行能够采取行动保护其客户免受欺诈交易的侵害。通过在新数据上不断训练模型,算法可以提高检测率并更有效地识别欺诈交易。", 287 | "我需要学习如何使用电脑。" 288 | ] 289 | }, 290 | { 291 | "instruction": "给定以下两个布尔函数,比较和对比它们。\nf1(x, y) = x + y\nf2(x, y) = x * y", 292 | "input": "", 293 | "output": [ 294 | "f1 和 f2 是两个函数,它们接受两个布尔值 x 和 y 作为输入,并根据它们的特定操作产生布尔输出。\n\nf1(x, y) = x + y 使用布尔 OR 运算产生其输出。如果其中一个或两个操作数为 True,则 OR 运算返回 True。换句话说,如果 x 或 y 中有一个为 True 或两者都为 True,f1 将返回 True,否则返回 False。\n\nf2(x, y) = x * y 使用布尔 AND 运算。 AND 运算仅在其两个操作数都为 True 时才返回 True。因此,只有当 x 和 y 都为 True 时,f2 才会返回 True,否则,它将返回 False。\n\n总之,这两个函数之间的主要区别在于用于生成其输出的操作:f1 使用 OR 操作,而 f2 使用 AND 操作。因此,对于比 f2 更多的输入组合,f1 返回 True,后者仅在两个输入均为 True 时才返回 True。", 295 | "比较" 296 | ] 297 | }, 298 | { 299 | "instruction": "计算半径为 16 厘米的圆的周长。", 300 | "input": "", 301 | "output": [ 302 | "圆的周长,也称为周长,可以使用公式 C = 2πr 计算,其中 C 为周长,π 约为 3.1416,r 为半径。\n\n鉴于圆的半径是16cm,我们可以计算它的周长如下:\n\nC = 2 x π x 16 厘米\nC = 32π 厘米\nC ≈ 100.53 厘米\n\n因此,半径为16cm的圆的周长约为100.53cm。", 303 | "圆的周长是圆的面积除以周长。" 304 | ] 305 | }, 306 | { 307 | "instruction": "给定以下信息,确定事件 E 的概率。\n P(A) = 0.7\n P(B) = 0.5\n P(E|A) = 0.3\n P(E|B) = 0.4", 308 | "input": "", 309 | "output": [ 310 | "事件 E = P(E|A) * P(A) + P(E|B) * P(B) = 0.7 * 0.3 + 0.5 * 0.4 = 0.41 的概率。", 311 | "P(E|A) = 0.3。" 312 
| ] 313 | }, 314 | { 315 | "instruction": "为人工智能教科书中的一章建议一个标题。", 316 | "input": "", 317 | "output": [ 318 | "“探索机器学习在人工智能中的基础和应用”", 319 | "我不知道。" 320 | ] 321 | }, 322 | { 323 | "instruction": "牛顿第二运动定律如何解释力、质量和加速度之间的关系?", 324 | "input": "", 325 | "output": [ 326 | "牛顿第二运动定律指出,物体的加速度与作用在其上的净力成正比,与其质量成反比。这种关系在数学上可以表示为 F = ma,其中 F 代表作用在物体上的力,m 代表物体的质量,a 代表它的加速度。\n\n简单来说,该定律意味着质量较大的物体需要更大的力才能以与质量较小的物体相同的速度加速。反之,对于给定的力,物体的质量越小,其加速度越大。\n\n例如,想象一下用相同的力推动两个物体,一个质量为 1 千克,另一个质量为 10 千克。根据第二定律,质量较小的物体(1 千克)的加速度将是质量较大的物体(10 千克)的十倍。这是因为加速度与物体的质量成反比。\n\n因此,第二运动定律解释了物体的力、质量和加速度如何相关,以及其中一个量的变化如何影响其他量。", 327 | "牛顿第二运动定律解释了力、质量和加速度之间的关系。" 328 | ] 329 | }, 330 | { 331 | "instruction": "操纵此产品描述以使其更加优化 SEO。\n这辆车价值很高,是必不可少的交通工具。它拥有宽敞的内部空间和强大的马达。", 332 | "input": "", 333 | "output": [ 334 | "驾驶我们超值的必备交通工具,以时尚的方式驾驶。这款车拥有宽敞的内部空间让您倍感舒适,强大的发动机可带您前往目的地,是您无与伦比的选择。通过我们出色的车辆体验效率和动力的完美平衡。", 335 | "这辆车价值很高,是必不可少的交通工具。它拥有宽敞的内部空间和强大的马达。" 336 | ] 337 | }, 338 | { 339 | "instruction": "花五分钟头脑风暴一些新的商业想法", 340 | "input": "", 341 | "output": [ 342 | "1. 基于科技的个人安全服务,将家庭安全监控与个人助理相结合,提供全天候的保护和陪伴。\n2. 一种基于订阅的杂货配送服务,每周采购农场新鲜的有机农产品并将其送货上门。\n3. 专为希望提高技能或完成学位的成人学习者设计的在线辅导平台。\n4、面向上班族的订阅式外卖服务,提供新鲜营养的午餐、晚餐和零食。\n5. 提供专业宠物护理的移动服务,包括散步、美容、日托和寄宿选择。\n6. 一家健康旅游公司,组织旨在促进康复和整体健康的旅行。\n7. 人工智能虚拟健身教练,提供个性化规划和一对一支持。\n8. 基于虚拟现实的求职面试平台,通过练习环节帮助潜在候选人更好地准备面试。\n9. 
基于瑜伽的冥想服务,提供在线课程和工具包,以帮助减轻压力、焦虑和抑郁。", 343 | "我对新的经营理念有很多想法。" 344 | ] 345 | }, 346 | { 347 | "instruction": "使用数字、字母和符号创建 10 位密码", 348 | "input": "", 349 | "output": [ 350 | "X94?k[hB1", 351 | "我不知道" 352 | ] 353 | }, 354 | { 355 | "instruction": "描述一种妥协比争论更好的情况", 356 | "input": "", 357 | "output": [ 358 | "当一个决定影响到多人时,妥协比争论更好。例如,如果您和您的朋友都出去吃饭,而您想要披萨,而他们想要寿司,则折衷方案可能是找到一个同时供应两者的地方,或者轮流为每个人提供他们想要的东西。妥协有助于避免争论,并达成双方都能同意的令人满意的解决方案。", 359 | "我同意。" 360 | ] 361 | }, 362 | { 363 | "instruction": "写一首原创俳句", 364 | "input": "", 365 | "output": [ 366 | "夏日的微风轻轻吹过\n傍晚的天空闪着橙色和红色的光\n和平时刻和平", 367 | "我会写一首俳句" 368 | ] 369 | }, 370 | { 371 | "instruction": "将下列带分数化为假分数\n7 1/8", 372 | "input": "", 373 | "output": [ 374 | "带分数的假分数换算,7 1/8,是分数的分母8乘以整数7,得56,再加上分子1,所以7 1/ 8 作为假分数是 57/8。", 375 | "什么的 1/8?" 376 | ] 377 | }, 378 | { 379 | "instruction": "创作关于以下主题的俳句:夏天。", 380 | "input": "", 381 | "output": [ 382 | "金色阳光灿烂,\n暖风吹动绿叶沙沙作响,\n慵懒的夏日。", 383 | "夏天是一年中最好的时光" 384 | ] 385 | }, 386 | { 387 | "instruction": "生成显示以下字体颜色的 CSS 规则\n黄色的", 388 | "input": "", 389 | "output": [ 390 | "颜色:#FFFF00;", 391 | "我不知道" 392 | ] 393 | } 394 | ] -------------------------------------------------------------------------------- /deepspeed_configs/zero_stage2_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "scheduler": { 23 | "type": "WarmupDecayLR", 24 | "params": { 25 | "last_batch_iteration": -1, 26 | "total_num_steps": "auto", 27 | "warmup_min_lr": "auto", 28 | "warmup_max_lr": "auto", 29 | "warmup_num_steps": "auto" 30 | } 31 | }, 32 | "zero_optimization": { 33 | "stage": 2, 34 
| "offload_param": { 35 | "device": "cpu", 36 | "pin_memory": true, 37 | "buffer_count": 5, 38 | "buffer_size": 1e8, 39 | "max_in_cpu": 1e9 40 | }, 41 | "offload_optimizer": { 42 | "device": "cpu", 43 | "pin_memory": true, 44 | "buffer_count": 4, 45 | "fast_init": false 46 | }, 47 | "allgather_partitions": true, 48 | "overlap_comm": true, 49 | "reduce_scatter": true, 50 | "round_robin_gradients": true, 51 | "contiguous_gradients": true 52 | }, 53 | "gradient_accumulation_steps": "auto", 54 | "gradient_clipping": "auto", 55 | "steps_per_print": 10, 56 | "train_batch_size": "auto", 57 | "train_micro_batch_size_per_gpu": "auto", 58 | "wall_clock_breakdown": false 59 | } -------------------------------------------------------------------------------- /deepspeed_configs/zero_stage3_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "optimizer": { 12 | "type": "AdamW", 13 | "params": { 14 | "lr": "auto", 15 | "betas": "auto", 16 | "eps": "auto", 17 | "weight_decay": "auto" 18 | } 19 | }, 20 | 21 | "scheduler": { 22 | "type": "WarmupLR", 23 | "params": { 24 | "warmup_min_lr": "auto", 25 | "warmup_max_lr": "auto", 26 | "warmup_num_steps": "auto" 27 | } 28 | }, 29 | 30 | "zero_optimization": { 31 | "stage": 3, 32 | "offload_optimizer": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "offload_param": { 37 | "device": "cpu", 38 | "pin_memory": true 39 | }, 40 | "overlap_comm": true, 41 | "contiguous_gradients": true, 42 | "sub_group_size": 1e9, 43 | "reduce_bucket_size": "auto", 44 | "stage3_prefetch_bucket_size": "auto", 45 | "stage3_param_persistence_threshold": "auto", 46 | "stage3_max_live_parameters": 3e7, 47 | "memory_efficient_linear": false 48 | }, 49 | 50 | "gradient_accumulation_steps": "auto", 51 | "gradient_clipping": 
"auto", 52 | "steps_per_print": 10, 53 | "train_batch_size": "auto", 54 | "train_micro_batch_size_per_gpu": "auto", 55 | "wall_clock_breakdown": false, 56 | "prescale_gradients": false 57 | } -------------------------------------------------------------------------------- /engines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/stanleylsx/llms_tool/732a85f7e3bfa34cc6f48aabd3e3eb52c64a62ce/engines/__init__.py -------------------------------------------------------------------------------- /engines/data.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/7/10 22:05 3 | # @Author : lishouxian 4 | # @Email : gzlishouxian@gmail.com 5 | # @File : data.py 6 | # @Software: PyCharm 7 | from transformers import AutoTokenizer, LlamaTokenizer, BloomTokenizerFast 8 | from transformers import DataCollatorWithPadding 9 | from engines.utils.prompt_template import Template 10 | from datasets import load_dataset 11 | from itertools import chain 12 | from glob import glob 13 | import os 14 | 15 | 16 | class DataManager: 17 | def __init__(self, config, logger): 18 | self.logger = logger 19 | self.mode = config.mode 20 | self.data_args = config.data_args 21 | self.model_args = config.model_args 22 | self.training_args = config.training_args 23 | self.prompt_template = Template(self.data_args.prompt_template) 24 | logger.info(f'Load tokenizer from {self.model_args.model_path}') 25 | self.tokenizer = self.load_tokenizer(self.model_args.model_path) 26 | self.logger.info(f'Tokenizer:\n{self.tokenizer}') 27 | if self.data_args.ignore_pad_token_for_loss: 28 | self.label_pad_token_id = -100 29 | else: 30 | self.label_pad_token_id = self.tokenizer.pad_token_id 31 | self.use_firefly_loss = self.training_args.use_firefly_loss 32 | 33 | def load_tokenizer(self, model_path): 34 | if self.model_args.model_type in ['chatglm', 'baichuan', 
'internlm', 'aquila', 'moss', 'xverse', 'mistral', 'yi']: 35 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 36 | elif self.model_args.model_type == 'qwen': 37 | # https://github.com/QwenLM/Qwen/issues/24 38 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True, allowed_special='all') 39 | elif self.model_args.model_type == 'falcon': 40 | tokenizer = AutoTokenizer.from_pretrained(model_path) 41 | elif self.model_args.model_type == 'rwkv': 42 | tokenizer = AutoTokenizer.from_pretrained(model_path) 43 | elif self.model_args.model_type == 'bloom': 44 | tokenizer = BloomTokenizerFast.from_pretrained(model_path) 45 | elif self.model_args.model_type == 'llama': 46 | tokenizer = LlamaTokenizer.from_pretrained(model_path, use_fast=self.model_args.use_fast_tokenizer, 47 | padding_side=self.model_args.padding_side) 48 | else: 49 | raise 50 | 51 | if tokenizer.eos_token_id is None: 52 | tokenizer.eos_token = '<|endoftext|>' 53 | self.logger.info('Add eos token: {}'.format(tokenizer.eos_token)) 54 | if tokenizer.pad_token_id is None: 55 | tokenizer.pad_token = tokenizer.eos_token 56 | self.logger.info('Add pad token: {}'.format(tokenizer.pad_token)) 57 | tokenizer.add_special_tokens(dict(additional_special_tokens=self.prompt_template.stop_words), 58 | replace_additional_special_tokens=False) 59 | return tokenizer 60 | 61 | def load_datasets_from_files(self, test=False): 62 | data_files = {} 63 | kwargs = {} 64 | if not test: 65 | if self.data_args.train_file_dir is not None and os.path.exists(self.data_args.train_file_dir): 66 | train_data_files = glob(f'{self.data_args.train_file_dir}/**/*.txt', recursive=True) + glob( 67 | f'{self.data_args.train_file_dir}/**/*.json', recursive=True) + glob( 68 | f'{self.data_args.train_file_dir}/**/*.jsonl', recursive=True) 69 | self.logger.info(f"train files: {', '.join(train_data_files)}") 70 | data_files['train'] = train_data_files 71 | if self.training_args.do_eval and 
    def load_datasets_from_files(self, test=False):
        """Collect data files from the configured directories and load them with HF `datasets`.

        When ``test`` is False, gathers ``*.txt``/``*.json``/``*.jsonl`` files under the
        train (and optionally validation) directories; when True, gathers json files
        under ``test_file``. Returns the resulting ``DatasetDict``.
        """
        data_files = {}
        kwargs = {}
        if not test:
            if self.data_args.train_file_dir is not None and os.path.exists(self.data_args.train_file_dir):
                train_data_files = glob(f'{self.data_args.train_file_dir}/**/*.txt', recursive=True) + glob(
                    f'{self.data_args.train_file_dir}/**/*.json', recursive=True) + glob(
                    f'{self.data_args.train_file_dir}/**/*.jsonl', recursive=True)
                self.logger.info(f"train files: {', '.join(train_data_files)}")
                data_files['train'] = train_data_files
            if self.training_args.do_eval and self.data_args.validation_file_dir is not None \
                    and os.path.exists(self.data_args.validation_file_dir):
                eval_data_files = glob(f'{self.data_args.validation_file_dir}/**/*.txt', recursive=True) + glob(
                    f'{self.data_args.validation_file_dir}/**/*.json', recursive=True) + glob(
                    f'{self.data_args.validation_file_dir}/**/*.jsonl', recursive=True)
                self.logger.info(f"eval files: {', '.join(eval_data_files)}")
                data_files['validation'] = eval_data_files
            # NOTE(review): if the train dir is missing/empty this raises KeyError on
            # data_files['train'] — presumably callers always configure a train dir; confirm.
            # File type is decided by the first train file only: txt => plain-text
            # pretraining corpus, anything else => json.
            extension = 'text' if data_files['train'][0].endswith('txt') else 'json'
            if extension == 'text':
                kwargs['keep_linebreaks'] = True
            raw_datasets = load_dataset(
                extension,
                data_files=data_files,
                cache_dir=self.model_args.cache_dir,
                **kwargs
            )
            # No explicit validation split found: carve one out of the train split
            # using the datasets slicing syntax 'train[:X%]' / 'train[X%:]'.
            if self.training_args.do_eval and 'validation' not in raw_datasets.keys() \
                    and self.data_args.dev_ratio > 0.0:
                raw_datasets['validation'] = load_dataset(
                    extension,
                    data_files=data_files,
                    split=f'train[:{self.data_args.dev_ratio}%]',
                    cache_dir=self.model_args.cache_dir,
                    **kwargs
                )
                raw_datasets['train'] = load_dataset(
                    extension,
                    data_files=data_files,
                    split=f'train[{self.data_args.dev_ratio}%:]',
                    cache_dir=self.model_args.cache_dir,
                    **kwargs
                )
        else:
            # Test mode only accepts json/jsonl files (no plain-text corpora).
            if self.data_args.test_file is not None and os.path.exists(self.data_args.test_file):
                test_data_files = glob(
                    f'{self.data_args.test_file}/**/*.json', recursive=True) + glob(
                    f'{self.data_args.test_file}/**/*.jsonl', recursive=True)
                self.logger.info(f"test files: {', '.join(test_data_files)}")
                data_files['test'] = test_data_files
            raw_datasets = load_dataset(
                'json',
                data_files=data_files,
                cache_dir=self.model_args.cache_dir,
            )
        self.logger.info(f'Raw datasets: {raw_datasets}')
        return raw_datasets
range(len(examples['instruction'])): 120 | if examples['instruction'][i] and examples['output'][i]: 121 | query, answer = examples['instruction'][i], examples['output'][i] 122 | query = query + examples['input'][i] if examples['input'][i] else query 123 | if 'history' in examples and (history := examples['history'][i]) is not None: 124 | prompt = self.prompt_template.get_prompt(query, history, join_history) 125 | else: 126 | prompt = self.prompt_template.get_prompt(query, [], join_history) 127 | yield prompt, answer 128 | 129 | def transfer_front_tail_to_label_pad_token_id(self, label): 130 | start_pointer = 0 131 | end_pointer = len(label) - 1 132 | while label[start_pointer] != self.label_pad_token_id: 133 | label[start_pointer] = self.label_pad_token_id 134 | start_pointer += 1 135 | while label[end_pointer] != self.label_pad_token_id: 136 | label[end_pointer] = self.label_pad_token_id 137 | end_pointer -= 1 138 | return label 139 | 140 | def preprocess_pretrain_dataset(self, examples): 141 | # refer from https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/blob/main/scripts/training/run_clm_pt_with_peft.py#L491 142 | tokenized_examples = self.tokenizer(examples['text']) 143 | block_size = self.data_args.max_input_token 144 | if block_size > self.tokenizer.model_max_length: 145 | self.logger.warning( 146 | f'The block_size passed ({block_size}) is larger than the maximum length for the model' 147 | f'({self.tokenizer.model_max_length}). Using block_size={self.tokenizer.model_max_length}.' 148 | ) 149 | block_size = min(block_size, self.tokenizer.model_max_length) 150 | 151 | concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()} 152 | total_length = len(concatenated_examples[list(tokenized_examples.keys())[0]]) 153 | # We drop the small remainder, we could add padding if the model supported it instead of this drop, you can 154 | # customize this part to your needs. 
155 | if total_length >= block_size: 156 | total_length = (total_length // block_size) * block_size 157 | # Split by chunks of max_len. 158 | result = { 159 | k: [t[i: i + block_size] for i in range(0, total_length, block_size)] 160 | for k, t in concatenated_examples.items() 161 | } 162 | return result 163 | 164 | def preprocess_train_supervised_fine_tuning_dataset(self, examples): 165 | # ChatGLM1: https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323 166 | # ChatGLM2: https://huggingface.co/THUDM/chatglm2-6b/blob/main/tokenization_chatglm.py#L171 167 | # Baichuan: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat/blob/main/tokenization_baichuan.py#L152 168 | # internlm: https://huggingface.co/internlm/internlm-chat-7b/blob/main/tokenization_internlm.py#L179 169 | # moss: https://huggingface.co/fnlp/moss-moon-003-sft/blob/main/tokenization_moss.py#L226 170 | # Llama: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/tokenization_llama.py#L296 171 | inputs_list = [] 172 | attention_mask_list = [] 173 | labels_list = [] 174 | if self.training_args.use_firefly_loss: 175 | for prompt, answer in self.format_example(examples, False): 176 | source_ids = [] 177 | labels = [] 178 | for i, sentence in enumerate(prompt): 179 | if i % 2 == 0: 180 | sentence_ids = self.tokenizer.encode(text=sentence, add_special_tokens=False) 181 | source_ids.extend(sentence_ids) 182 | labels.extend([self.label_pad_token_id] * (len(sentence_ids))) 183 | else: 184 | sentence_ids = self.tokenizer.encode(text=sentence, add_special_tokens=False) 185 | sentence_ids = sentence_ids + [self.tokenizer.eos_token_id] 186 | source_ids.extend(sentence_ids) 187 | labels.extend(sentence_ids) 188 | target_ids = self.tokenizer.encode(text=answer, add_special_tokens=False) 189 | if self.model_args.model_type in ('chatglm', 'baichuan', 'internlm', 'moss', 'llama'): 190 | input_ids = self.tokenizer.build_inputs_with_special_tokens(source_ids, 
target_ids) 191 | labels = self.tokenizer.build_inputs_with_special_tokens(labels) 192 | context_length = len(labels) 193 | labels = self.transfer_front_tail_to_label_pad_token_id(labels) 194 | labels = labels + input_ids[context_length:] 195 | else: 196 | input_ids = source_ids + target_ids + [self.tokenizer.eos_token_id] 197 | if self.tokenizer.bos_token_id is not None: 198 | input_ids = [self.tokenizer.bos_token_id] + input_ids 199 | labels = [self.label_pad_token_id] + labels 200 | labels = labels + target_ids + [self.tokenizer.eos_token_id] 201 | attention_mask = [1] * len(input_ids) 202 | if len(input_ids) > self.data_args.max_input_token: 203 | self.logger.warning(f'The token length of some sentences exceeds {self.data_args.max_input_token}.') 204 | input_ids = input_ids[:self.data_args.max_input_token] 205 | labels = labels[:self.data_args.max_input_token] 206 | attention_mask = attention_mask[:self.data_args.max_input_token] 207 | inputs_list.append(input_ids) 208 | attention_mask_list.append(attention_mask) 209 | labels_list.append(labels) 210 | else: 211 | for prompt, answer in self.format_example(examples): 212 | source_ids = self.tokenizer.encode(text=prompt, add_special_tokens=False) 213 | target_ids = self.tokenizer.encode(text=answer, add_special_tokens=False) 214 | if self.model_args.model_type in ('chatglm', 'baichuan', 'internlm', 'moss', 'llama'): 215 | input_ids = self.tokenizer.build_inputs_with_special_tokens(source_ids, target_ids) 216 | context_length = len(self.tokenizer.build_inputs_with_special_tokens(source_ids)) 217 | labels = [self.label_pad_token_id] * context_length + input_ids[context_length:] 218 | else: 219 | input_ids = source_ids + target_ids + [self.tokenizer.eos_token_id] 220 | context_length = len(source_ids) 221 | if self.tokenizer.bos_token_id is not None: 222 | input_ids = [self.tokenizer.bos_token_id] + input_ids 223 | context_length = context_length + 1 224 | labels = [self.label_pad_token_id] * context_length + 
target_ids + [self.tokenizer.eos_token_id] 225 | attention_mask = [1] * len(input_ids) 226 | if len(input_ids) > self.data_args.max_input_token: 227 | self.logger.warning(f'The token length of some sentences exceeds {self.data_args.max_input_token}.') 228 | input_ids = input_ids[:self.data_args.max_input_token] 229 | labels = labels[:self.data_args.max_input_token] 230 | attention_mask = attention_mask[:self.data_args.max_input_token] 231 | inputs_list.append(input_ids) 232 | attention_mask_list.append(attention_mask) 233 | labels_list.append(labels) 234 | return {'input_ids': inputs_list, 'attention_mask': attention_mask_list, 'labels': labels_list} 235 | 236 | def preprocess_eval_supervised_fine_tuning_dataset(self, examples): 237 | inputs_list = [] 238 | attention_mask_list = [] 239 | labels_list = [] 240 | for prompt, answer in self.format_example(examples): 241 | source_ids = self.tokenizer.encode(text=prompt, add_special_tokens=False) 242 | target_ids = self.tokenizer.encode(text=answer, add_special_tokens=False) 243 | if self.model_args.model_type in ('chatglm', 'baichuan', 'internlm', 'moss', 'llama'): 244 | input_ids = self.tokenizer.build_inputs_with_special_tokens(source_ids) 245 | labels = target_ids + [self.tokenizer.eos_token_id] 246 | else: 247 | input_ids = source_ids 248 | if self.tokenizer.bos_token_id is not None: 249 | input_ids = [self.tokenizer.bos_token_id] + source_ids 250 | labels = target_ids + [self.tokenizer.eos_token_id] 251 | attention_mask = [1] * len(input_ids) 252 | if len(input_ids) > self.data_args.max_input_token: 253 | self.logger.warning(f'The token length of some sentences exceeds {self.data_args.max_input_token}.') 254 | input_ids = input_ids[:self.data_args.max_input_token] 255 | attention_mask = attention_mask[:self.data_args.max_input_token] 256 | inputs_list.append(input_ids) 257 | attention_mask_list.append(attention_mask) 258 | labels_list.append(labels) 259 | return {'input_ids': inputs_list, 'attention_mask': 
attention_mask_list, 'labels': labels_list} 260 | 261 | def preprocess_train_reward_model_dataset(self, examples): 262 | accept_list, reject_list = [], [] 263 | for prompt, answer in self.format_example(examples): 264 | source_ids = self.tokenizer.encode(text=prompt, add_special_tokens=False) 265 | accept_ids = self.tokenizer.encode(text=answer[0], add_special_tokens=False) 266 | reject_ids = self.tokenizer.encode(text=answer[1], add_special_tokens=False) 267 | if self.model_args.model_type in ('chatglm', 'baichuan', 'internlm', 'moss', 'llama'): 268 | accept_ids = self.tokenizer.build_inputs_with_special_tokens(source_ids, accept_ids) 269 | reject_ids = self.tokenizer.build_inputs_with_special_tokens(source_ids, reject_ids) 270 | else: 271 | if self.tokenizer.bos_token_id is not None: 272 | source_ids = [self.tokenizer.bos_token_id] + source_ids 273 | accept_ids = source_ids + accept_ids + [self.tokenizer.eos_token_id] 274 | reject_ids = source_ids + reject_ids + [self.tokenizer.eos_token_id] 275 | 276 | if len(accept_ids) > self.data_args.max_input_token or len(reject_ids) > self.data_args.max_input_token: 277 | self.logger.warning(f'The token length of some sentences exceeds {self.data_args.max_input_token}.') 278 | accept_ids = accept_ids[:self.data_args.max_input_token] 279 | reject_ids = reject_ids[:self.data_args.max_input_token] 280 | 281 | accept_list.append(accept_ids) 282 | reject_list.append(reject_ids) 283 | return {'accept_ids': accept_list, 'reject_ids': reject_list} 284 | 285 | def preprocess_train_dpo_text_dataset(self, examples): 286 | prompt_list, accept_list, reject_list = [], [], [] 287 | for prompt, answer in self.format_example(examples): 288 | prompt_list.append(prompt) 289 | accept_list.append(answer[0]) 290 | reject_list.append(answer[1]) 291 | return {'prompt': prompt_list, 'chosen': accept_list, 'rejected': reject_list} 292 | 293 | def prepare_dataset(self, test=False): 294 | 295 | def process_dataset(process_func, dataset, 
shuffle=True): 296 | with self.training_args.main_process_first(desc='Handle dataset.'): 297 | if shuffle: 298 | dataset = dataset.shuffle() 299 | dataset = dataset.map( 300 | process_func, 301 | batched=True, 302 | num_proc=self.data_args.preprocessing_num_workers, 303 | remove_columns=dataset.column_names, 304 | load_from_cache_file=not self.data_args.overwrite_cache, 305 | desc='Running tokenizer on dataset' 306 | ) 307 | return dataset 308 | 309 | if not test: 310 | raw_datasets = self.load_datasets_from_files() 311 | train_dataset = raw_datasets['train'] 312 | if self.mode == 'pretrain': 313 | train_dataset = process_dataset(self.preprocess_pretrain_dataset, train_dataset) 314 | elif self.mode == 'sft_train': 315 | train_dataset = process_dataset(self.preprocess_train_supervised_fine_tuning_dataset, train_dataset) 316 | elif self.mode == 'rm_train': 317 | train_dataset = process_dataset(self.preprocess_train_reward_model_dataset, train_dataset) 318 | elif self.mode == 'ppo_train': 319 | train_dataset = process_dataset(self.preprocess_eval_supervised_fine_tuning_dataset, train_dataset) 320 | elif self.mode == 'dpo_train': 321 | train_dataset = process_dataset(self.preprocess_train_dpo_text_dataset, train_dataset) 322 | self.logger.debug(f'Train dataset nums: {len(train_dataset)}') 323 | 324 | eval_dataset = None 325 | if self.training_args.do_eval: 326 | if 'validation' not in raw_datasets.keys(): 327 | raise ValueError('do_eval requires a validation dataset') 328 | eval_dataset = raw_datasets['validation'] 329 | if self.mode == 'pretrain': 330 | eval_dataset = process_dataset(self.preprocess_pretrain_dataset, eval_dataset, False) 331 | elif self.mode == 'sft_train': 332 | eval_dataset = process_dataset(self.preprocess_eval_supervised_fine_tuning_dataset, eval_dataset, False) 333 | elif self.mode == 'rm_train': 334 | eval_dataset = process_dataset(self.preprocess_train_reward_model_dataset, eval_dataset, False) 335 | elif self.mode == 'dpo_train': 336 | 
class DataCollatorForRewardModelTraining(DataCollatorWithPadding):
    """Pads accept/reject pairs into one flat batch for reward-model training.

    Each incoming feature carries ``accept_ids`` and ``reject_ids``. The
    batch is flattened so all accepted sequences come first, followed by all
    rejected ones, before delegating padding to the parent collator.
    """

    def __init__(self, tokenizer, return_tensors):
        self.tokenizer = tokenizer
        self.return_tensors = return_tensors

    def __call__(self, features):
        flattened = []
        for key in ('accept_ids', 'reject_ids'):
            for feature in features:
                ids = feature[key]
                flattened.append({'input_ids': ids, 'attention_mask': [1] * len(ids)})
        return super().__call__(flattened)
peft.utils import CONFIG_NAME, WEIGHTS_NAME 16 | from peft import PeftModel 17 | from types import MethodType 18 | import bitsandbytes as bnb 19 | import os 20 | import math 21 | import torch 22 | 23 | 24 | class BaseModels: 25 | def __init__(self, data_manager, config, logger): 26 | self.logger = logger 27 | self.model_args = config.model_args 28 | self.training_args = config.training_args 29 | self.mode = config.mode 30 | self.tokenizer = data_manager.tokenizer 31 | self.data_manager = data_manager 32 | self.has_peft = False 33 | self.has_vhead = False 34 | self.is_deepspeed_train = False 35 | if world_size := os.environ.get('WORLD_SIZE') is not None: 36 | self.world_size = int(world_size) 37 | self.is_deepspeed_train = True 38 | 39 | def load_adapter(self, model, adapter_dir): 40 | if adapter_dir is None: 41 | return model 42 | if os.path.exists(os.path.join(adapter_dir, WEIGHTS_NAME)) and os.path.exists(os.path.join(adapter_dir, CONFIG_NAME)): 43 | self.logger.info(f'Found adapter model at {adapter_dir} and load it.') 44 | self.has_peft = True 45 | model = PeftModel.from_pretrained(model, adapter_dir) 46 | if self.training_args.fine_tuning_type in ('lora', 'adalora'): 47 | # https://github.com/huggingface/peft/blob/main/src/peft/tuners/lora.py#L790 48 | if self.mode in ('merge_peft_model', 'save_quantized_model', 'ppo_train', 'dpo_train'): 49 | self.logger.info('Merge peft model.') 50 | model = model.merge_and_unload() 51 | else: 52 | self.logger.info(f'The given dir: {adapter_dir} may be not have adapter checkpoint.') 53 | return model 54 | 55 | def load_reward_model(self, model, vhead_dir): 56 | if os.path.exists(vhead_path := os.path.join(vhead_dir, 'vhead.bin')): 57 | self.logger.info(f'Found v_head model at {vhead_dir} and load it.') 58 | model = self.load_adapter(model, adapter_dir=vhead_dir) 59 | self.has_vhead = True 60 | if self.model_args.model_type == 'chatglm' and any( 61 | key.endswith('rotary_pos_emb') for key, _ in model.named_modules()): 62 | 
model.lm_head = model.transformer.output_layer 63 | model = AutoModelForCausalLMWithValueHead.from_pretrained(model) 64 | model.load_state_dict(torch.load(vhead_path), strict=False) 65 | else: 66 | self.logger.info(f'The given dir: {vhead_dir} may be not have v_head checkpoint.') 67 | return model 68 | 69 | def use_ntk_to_expend_input_token_length(self, model): 70 | ntk_type = self.model_args.use_ntk 71 | max_input_token = self.data_manager.data_args.max_input_token 72 | match self.model_args.model_type: 73 | case 'llama': 74 | if max_input_token > (max_position_embeddings := getattr(model.config, 'max_position_embeddings', None)): 75 | factor = math.ceil(max_input_token / max_position_embeddings) 76 | match ntk_type: 77 | case 'dynamic': 78 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L147 79 | model.config.rope_scaling = {'type': 'dynamic', 'factor': factor} 80 | case 'linear': 81 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L128 82 | model.config.rope_scaling = {'type': 'linear', 'factor': factor} 83 | else: 84 | self.logger.warning('Current model support the length you set.') 85 | case 'chatglm': 86 | if ntk_type == 'linear': 87 | if (rope_ratio := getattr(model.config, 'rope_ratio', None)) is not None: 88 | if (set_rope_ratio := math.ceil(max_input_token / 2048)) > rope_ratio: 89 | # https://huggingface.co/THUDM/chatglm2-6b-32k/blob/main/modeling_chatglm.py#L141 90 | model.config.rope_ratio = set_rope_ratio 91 | else: 92 | self.logger.warning('Current model support the length you set.') 93 | else: 94 | self.logger.warning('Only chatglm2-6b-32k and chatglm3-6b-32k support expend input token length.') 95 | else: 96 | self.logger.warning('Native ChatGLM can not support dynamic NTK.') 97 | case 'internlm': 98 | if max_input_token > (max_position_embeddings := getattr(model.config, 'max_position_embeddings', None)): 99 | factor = 
math.ceil(max_input_token / max_position_embeddings) 100 | match ntk_type: 101 | case 'dynamic': 102 | model.config.rotary = {'base': 10000, 'type': 'dynamic', 'scaling_factor': factor} 103 | case 'linear': 104 | model.config.rotary = {'base': 10000, 'type': 'origin'} 105 | case 'aquila': 106 | if max_input_token > (max_position_embeddings := getattr(model.config, 'max_position_embeddings', None)): 107 | factor = math.ceil(max_input_token / max_position_embeddings) 108 | match ntk_type: 109 | case 'dynamic': 110 | # https://huggingface.co/BAAI/AquilaChat2-7B-16K/blob/main/modeling_aquila.py#L148 111 | model.config.rope_scaling = {'type': 'dynamic', 'factor': factor} 112 | case 'linear': 113 | # https://huggingface.co/BAAI/AquilaChat2-7B-16K/blob/main/modeling_aquila.py#L129 114 | model.config.rope_scaling = {'type': 'linear', 'factor': factor} 115 | else: 116 | self.logger.warning('Only AquilaChat2-7B-16K support expend input token length.') 117 | case 'qwen': 118 | if ntk_type == 'dynamic': 119 | # https://huggingface.co/Qwen/Qwen-7B-Chat/blob/main/modeling_qwen.py#L1165 120 | model.config.use_dynamic_ntk = True 121 | else: 122 | self.logger.warning('Native Qwen can not support linear NTK.') 123 | case 'falcon': 124 | if model.config.alibi: 125 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/configuration_falcon.py#L181 126 | self.logger.warning('`rope_scaling` is not supported when `alibi` is `True`.') 127 | else: 128 | if max_input_token > 2048: 129 | factor = math.ceil(max_input_token / 2048) 130 | match ntk_type: 131 | case 'dynamic': 132 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py#L143 133 | model.config.rope_scaling = {'type': 'dynamic', 'factor': factor} 134 | case 'linear': 135 | # https://github.com/huggingface/transformers/blob/main/src/transformers/models/falcon/modeling_falcon.py#L117 136 | model.config.rope_scaling = {'type': 'linear', 'factor': 
factor} 137 | else: 138 | self.logger.warning('Current model support the length you set.') 139 | return model 140 | 141 | def load_base_model(self): 142 | config_kwargs = {'cache_dir': self.model_args.cache_dir, 143 | 'torch_dtype': self.model_args.torch_dtype} 144 | dispatched = False 145 | if self.model_args.quantization_bit is not None: 146 | if self.model_args.quantization == 'bnb': 147 | if self.model_args.quantization_bit == 8: 148 | config_kwargs['load_in_8bit'] = True 149 | config_kwargs['quantization_config'] = BitsAndBytesConfig( 150 | load_in_8bit=True, 151 | llm_int8_threshold=6.0) 152 | elif self.model_args.quantization_bit == 4: 153 | config_kwargs['load_in_4bit'] = True 154 | config_kwargs['quantization_config'] = BitsAndBytesConfig( 155 | load_in_4bit=True, 156 | bnb_4bit_compute_dtype=self.model_args.torch_dtype, 157 | bnb_4bit_use_double_quant=self.model_args.double_quantization, 158 | bnb_4bit_quant_type=self.model_args.quantization_type) 159 | if self.is_deepspeed_train: 160 | device_map = {'': int(os.environ['LOCAL_RANK'])} 161 | config_kwargs['device_map'] = device_map 162 | else: 163 | if self.model_args.model_type != 'chatglm': 164 | # auto dispatch model except chatglm 165 | config_kwargs['device_map'] = 'auto' 166 | dispatched = True 167 | self.logger.info('Quantifying(bnb) model to {} bit.'.format(self.model_args.quantization_bit)) 168 | elif self.model_args.quantization == 'cpm': 169 | self.logger.info('Quantifying(cpm) model to {} bit.'.format(self.model_args.quantization_bit)) 170 | else: 171 | if not self.is_deepspeed_train and self.model_args.model_type != 'chatglm': 172 | # auto dispatch model except chatglm 173 | config_kwargs['device_map'] = 'auto' 174 | dispatched = True 175 | 176 | if self.model_args.checkpoint_dir is not None and self.training_args.fine_tuning_type == 'full': 177 | model_to_load = self.model_args.checkpoint_dir 178 | else: 179 | model_to_load = self.model_args.model_path 180 | 181 | if 
self.model_args.use_flash_attn: 182 | if self.model_args.model_type not in ('falcon', 'mistral', 'llama'): 183 | self.logger.warning(f'Flash attention is not supported for {self.model_args.model_type}') 184 | else: 185 | config_kwargs['use_flash_attention_2'] = True 186 | 187 | if self.model_args.use_attention_sink: 188 | if self.model_args.model_type not in ('falcon', 'mistral', 'qwen', 'llama'): 189 | self.logger.warning(f'Window attention is not supported for {self.model_args.model_type}') 190 | else: 191 | config_kwargs['attention_sink_size'] = self.model_args.attention_sink_size 192 | config_kwargs['attention_sink_window_size'] = self.model_args.attention_sink_window_size 193 | 194 | if self.model_args.model_type == 'chatglm': 195 | from transformers import AutoModel 196 | model = AutoModel.from_pretrained(model_to_load, trust_remote_code=True, **config_kwargs) 197 | elif self.model_args.model_type == 'falcon': 198 | if self.model_args.use_attention_sink: 199 | from attention_sinks import FalconForCausalLM 200 | else: 201 | from transformers import FalconForCausalLM 202 | model = FalconForCausalLM.from_pretrained(model_to_load, **config_kwargs) 203 | elif self.model_args.model_type == 'mistral': 204 | if self.model_args.use_attention_sink: 205 | from attention_sinks import MistralForCausalLM 206 | else: 207 | from transformers import MistralForCausalLM 208 | model = MistralForCausalLM.from_pretrained(model_to_load, **config_kwargs) 209 | elif self.model_args.model_type in ['baichuan', 'aquila', 'internlm', 'moss', 'xverse', 'yi']: 210 | from transformers import AutoModelForCausalLM 211 | model = AutoModelForCausalLM.from_pretrained(model_to_load, trust_remote_code=True, **config_kwargs) 212 | elif self.model_args.model_type == 'qwen': 213 | if self.model_args.use_attention_sink: 214 | from attention_sinks import AutoModelForCausalLM 215 | else: 216 | from transformers import AutoModelForCausalLM 217 | match self.model_args.torch_dtype: 218 | case 
torch.float16: 219 | config_kwargs['fp16'] = True 220 | case torch.bfloat16: 221 | config_kwargs['bf16'] = True 222 | case torch.float32: 223 | config_kwargs['fp32'] = True 224 | model = AutoModelForCausalLM.from_pretrained(model_to_load, trust_remote_code=True, **config_kwargs) 225 | model.generate = MethodType(PreTrainedModel.generate, model) 226 | elif self.model_args.model_type == 'rwkv': 227 | from transformers import RwkvForCausalLM 228 | model = RwkvForCausalLM.from_pretrained(model_to_load, **config_kwargs) 229 | elif self.model_args.model_type == 'llama': 230 | if self.model_args.use_attention_sink: 231 | from attention_sinks import LlamaForCausalLM 232 | else: 233 | from transformers import LlamaForCausalLM 234 | model = LlamaForCausalLM.from_pretrained(model_to_load, **config_kwargs) 235 | elif self.model_args.model_type == 'bloom': 236 | from transformers import BloomForCausalLM 237 | model = BloomForCausalLM.from_pretrained(model_to_load, **config_kwargs) 238 | else: 239 | raise 240 | 241 | if self.model_args.use_ntk is not None: 242 | model = self.use_ntk_to_expend_input_token_length(model) 243 | 244 | if self.model_args.resize_emb is not None: 245 | # refer from https://zhuanlan.zhihu.com/p/656335338 246 | vocab_size_of_model = model.get_input_embeddings().weight.size(0) 247 | vocab_size_of_tokenizer = len(self.tokenizer) 248 | self.logger.info(f'Vocab of the model: {vocab_size_of_model}') 249 | self.logger.info(f'Vocab of the tokenizer: {vocab_size_of_tokenizer}') 250 | if vocab_size_of_model != vocab_size_of_tokenizer: 251 | self.logger.info('Resize model embeddings to fit tokenizer') 252 | model.resize_token_embeddings(vocab_size_of_tokenizer) 253 | 254 | if self.model_args.quantization_bit is not None and self.model_args.quantization == 'cpm': 255 | model = self.quantize(model, self.model_args.quantization_bit) 256 | 257 | if not self.is_deepspeed_train: 258 | model = dispatch(self.model_args.model_type, model, dispatched) 259 | 260 | if 
os.path.exists(model_to_load + '/generation_config.json'): 261 | model.generation_config = GenerationConfig.from_pretrained(model_to_load) 262 | return model 263 | 264 | @staticmethod 265 | def get_module_by_name(model, module_name): 266 | name_list = module_name.split('.') 267 | for name in name_list[:-1]: 268 | if hasattr(model, name): 269 | model = getattr(model, name) 270 | else: 271 | return None, None 272 | if hasattr(model, name_list[-1]): 273 | leaf_module = getattr(model, name_list[-1]) 274 | return model, leaf_module 275 | else: 276 | return None, None 277 | 278 | def find_all_linear_names(self, model): 279 | match self.model_args.quantization_bit: 280 | case 4: 281 | cls = bnb.nn.Linear4bit 282 | case 8: 283 | cls = bnb.nn.Linear8bitLt 284 | case _: 285 | cls = torch.nn.Linear 286 | lora_module_names = set() 287 | for name, module in model.named_modules(): 288 | if isinstance(module, cls): 289 | names = name.split('.') 290 | # needed for 16-bit 291 | leaf_name = names[-1] 292 | if leaf_name in ('lm_head', 'embed_out', 'output_layer'): 293 | continue 294 | lora_module_names.add(names[0] if len(names) == 1 else leaf_name) 295 | return list(lora_module_names) 296 | 297 | def quantize(self, model, bits, device=None): 298 | for name, module in model.named_modules(): 299 | if isinstance(module, torch.nn.Linear): 300 | names = name.split('.') 301 | leaf_name = names[-1] 302 | if leaf_name in ('lm_head', 'embed_out', 'output_layer'): 303 | continue 304 | super_module, leaf_module = self.get_module_by_name(model, name) 305 | quantized_liner = QuantizedLinear( 306 | weight_bit_width=bits, 307 | weight=leaf_module.weight.to(torch.cuda.current_device()), 308 | bias=leaf_module.bias, 309 | dtype=leaf_module.weight.dtype, 310 | device=leaf_module.weight.device if device is None else device, 311 | ) 312 | setattr(super_module, leaf_name, quantized_liner) 313 | return model 314 | 315 | def save_quantized_model(self): 316 | self.logger.info(f'Load base model from 
{self.model_args.model_path}') 317 | model = self.load_base_model() 318 | model = self.load_adapter(model, adapter_dir=self.model_args.checkpoint_dir) 319 | self.logger.info('Saving quantized model.') 320 | model.save_pretrained(self.model_args.quantized_or_merged_output_dir) 321 | self.tokenizer.save_pretrained(self.model_args.quantized_or_merged_output_dir) 322 | self.logger.info(f'Quantize done, model saved to {self.model_args.quantized_or_merged_output_dir}') 323 | 324 | def merge_lora_model(self): 325 | if self.model_args.checkpoint_dir is None: 326 | self.logger.error('checkpoint_dir is None.') 327 | if not os.path.exists(os.path.join(self.model_args.checkpoint_dir, WEIGHTS_NAME)) \ 328 | and os.path.exists(os.path.join(self.model_args.checkpoint_dir, CONFIG_NAME)): 329 | self.logger.error(f'Peft checkpoint not found at {self.model_args.checkpoint_dir}.') 330 | self.logger.info(f'Load base model from {self.model_args.model_path}') 331 | model = self.load_base_model() 332 | model = self.load_adapter(model, adapter_dir=self.model_args.checkpoint_dir) 333 | if not self.has_peft: 334 | self.logger.error('Peft checkpoint not found.') 335 | self.logger.info(f'Base model: {self.model_args.model_type}') 336 | self.logger.info(f'Peft model: {self.model_args.checkpoint_dir}') 337 | self.logger.info('Loading LoRA for causal language model') 338 | tokenizer = self.data_manager.load_tokenizer(self.model_args.checkpoint_dir) 339 | self.logger.info('Saving to Hugging Face format...') 340 | tokenizer.save_pretrained(self.model_args.quantized_or_merged_output_dir) 341 | model.save_pretrained(self.model_args.quantized_or_merged_output_dir) 342 | self.logger.info(f'Merge done, model saved to {self.model_args.quantized_or_merged_output_dir}') 343 | 344 | def show_model_info(self): 345 | self.logger.info(f'Load base model from {self.model_args.model_path}') 346 | model = self.load_base_model() 347 | model = self.load_adapter(model, adapter_dir=self.model_args.checkpoint_dir) 348 
| info = summary(model, max_level=3) 349 | self.logger.info(f'Model struct:\n{model}') 350 | self.logger.info(f'Model parameter:\n{info}') 351 | 352 | def expand_vocab(self): 353 | expand_vocab( 354 | self.logger, 355 | self.model_args.model_path, 356 | self.data_manager.data_args.corpus_path_for_expansion, 357 | self.model_args.model_type, 358 | self.model_args.save_path_after_vocab_expansion, 359 | self.model_args.torch_dtype, 360 | self.training_args 361 | ) 362 | -------------------------------------------------------------------------------- /engines/predict.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/7/2 22:05 3 | # @Author : lishouxian 4 | # @Email : gzlishouxian@gmail.com 5 | # @File : predict.py 6 | # @Software: PyCharm 7 | from transformers import TextIteratorStreamer 8 | from engines.utils.parse_text import parse_text 9 | from engines.utils.metrics import Metrics 10 | from engines.models import BaseModels 11 | from engines.utils.logits_process import logits_processor 12 | from threading import Thread 13 | import gradio as gr 14 | import mdtex2html 15 | 16 | 17 | class Predictor(BaseModels): 18 | def __init__(self, data_manager, config, logger): 19 | super().__init__(data_manager, config, logger) 20 | self.logger = logger 21 | self.data_args = config.data_args 22 | self.generating_args = config.generating_args 23 | self.data_manager = data_manager 24 | self.prompt_template = data_manager.prompt_template 25 | self.metrics = Metrics(data_manager, logger) 26 | self.logger.info(f'Load base model from {self.model_args.model_path}') 27 | self.model = self.load_base_model() 28 | self.model = self.load_adapter(self.model, adapter_dir=self.model_args.checkpoint_dir) 29 | self.logger.info(f'Model struct:\n{self.model}') 30 | self.model.eval() 31 | 32 | def web_inference(self): 33 | def predict(input, chatbot, history, max_new_tokens, top_p, repetition_penalty, temperature): 34 | 
chatbot.append((parse_text(input), '')) 35 | prompt_template = self.prompt_template.get_prompt(input, history) 36 | input_ids = self.tokenizer([prompt_template], return_tensors='pt')['input_ids'] 37 | input_ids = input_ids.to(self.model.device) 38 | streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) 39 | gen_kwargs = self.generating_args.to_dict() 40 | gen_kwargs.update({ 41 | 'input_ids': input_ids, 42 | 'temperature': temperature, 43 | 'top_p': top_p, 44 | 'top_k': self.generating_args.top_k, 45 | 'repetition_penalty': repetition_penalty, 46 | 'max_new_tokens': max_new_tokens, 47 | 'num_beams': self.generating_args.num_beams, 48 | 'do_sample': self.generating_args.do_sample, 49 | 'eos_token_id': [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, 50 | 'logits_processor': logits_processor(), 51 | 'streamer': streamer 52 | }) 53 | 54 | thread = Thread(target=self.model.generate, kwargs=gen_kwargs) 55 | thread.start() 56 | 57 | response = '' 58 | for new_text in streamer: 59 | response += new_text 60 | new_history = history + [(input, response)] 61 | chatbot[-1] = (parse_text(input), parse_text(response)) 62 | yield chatbot, new_history 63 | 64 | def reset_user_input(): 65 | return gr.update(value='') 66 | 67 | def reset_state(): 68 | return [], [] 69 | 70 | def postprocess(self, y): 71 | r""" 72 | Overrides Chatbot.postprocess 73 | """ 74 | if y is None: 75 | return [] 76 | for i, (message, response) in enumerate(y): 77 | y[i] = ( 78 | None if message is None else mdtex2html.convert(message), 79 | None if response is None else mdtex2html.convert(response), 80 | ) 81 | return y 82 | 83 | gr.Chatbot.postprocess = postprocess 84 | 85 | with gr.Blocks() as demo: 86 | gr.HTML(f""" 87 |

88 | 89 | Chat with {self.model_args.model_type} 90 | 91 |

def terminal_inference(self):
    """Interactive chat loop on the terminal with streamed token output.

    Type `clear` to drop the conversation history and `exit` to quit.
    Each turn is appended to ``history`` as a ``(query, response)`` pair.
    """

    def predict(query, history):
        # Bug fix: the parameter used to be named `input` (shadowing the
        # builtin) and the history update read the enclosing loop's `query`
        # variable through the closure instead of the parameter. It now uses
        # its own argument, so the helper is self-contained.
        prompt = self.prompt_template.get_prompt(query, history)
        input_ids = self.tokenizer([prompt], return_tensors='pt')['input_ids']
        input_ids = input_ids.to(self.model.device)
        streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)
        gen_kwargs = self.generating_args.to_dict()
        gen_kwargs.update({
            'input_ids': input_ids,
            'eos_token_id': [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            'logits_processor': logits_processor(),
            'streamer': streamer
        })
        # Generation runs in a worker thread so the streamer can be consumed
        # (and printed incrementally) in this thread.
        thread = Thread(target=self.model.generate, kwargs=gen_kwargs)
        thread.start()
        print(f'{self.model_args.model_type}:', end='', flush=True)
        response = ''
        for new_text in streamer:
            print(new_text, end='', flush=True)
            response += new_text
        return history + [(query, response)]

    history = []
    print('use `clear` to remove the history, use `exit` to exit the application.')
    while True:
        try:
            query = input('\nUser: ')
        except UnicodeDecodeError:
            print('Detected decoding error at the inputs, please set the terminal encoding to utf-8.')
            continue
        if query.strip() == 'exit':
            break
        if query.strip() == 'clear':
            history = []
            print('History has been removed.')
            continue
        history = predict(query, history)
ModelArguments, DataTrainingArguments, TrainingArguments, GeneratingArguments)) 15 | self.mode, self.model_args, self.data_args, self.training_args, self.generating_args = parser.parse_args_into_dataclasses() 16 | self.mode = self.mode.mode 17 | 18 | assert self.model_args.quantization_bit is None or self.training_args.fine_tuning_type in ( 19 | 'lora', 'adalora'), 'Quantization is only compatible with the LoRA method(QLora).' 20 | 21 | if self.data_args.prompt_template == 'default': 22 | print('Please specify `prompt_template` if you are using other pre-trained models.') 23 | 24 | if self.training_args.do_train: 25 | print( 26 | f'Process rank: {self.training_args.local_rank}\n' 27 | f'device: {self.training_args.device}\n' 28 | f'n_gpu: {self.training_args.n_gpu}\n' 29 | f'distributed training: {bool(self.training_args.local_rank != -1)}\n' 30 | f'16-bits training: {self.training_args.fp16}\n' 31 | ) 32 | self.fold_check() 33 | 34 | def fold_check(self): 35 | if not os.path.exists(self.data_args.train_file_dir): 36 | raise ValueError('Train dataset not found.') 37 | if not os.path.exists(self.training_args.output_dir): 38 | print('Creating output_dir fold.') 39 | os.makedirs(self.training_args.output_dir) 40 | 41 | if not os.path.exists('./logs'): 42 | print('Creating log fold.') 43 | os.mkdir('./logs') 44 | -------------------------------------------------------------------------------- /engines/utils/convert_rwkv_to_hf.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
def convert_state_dict(state_dict):
    """Rename BlinkDL RWKV parameter names to the Hugging Face layout.

    The conversion happens in place: every key is popped and re-inserted
    under its converted name, and the same dict object is returned.
    """
    for original_name in list(state_dict.keys()):
        tensor = state_dict.pop(original_name)
        name = original_name
        # emb -> embeddings
        if name.startswith("emb."):
            name = name.replace("emb.", "embeddings.")
        # ln_0 -> pre_ln (only present at block 0)
        if name.startswith("blocks.0.ln0"):
            name = name.replace("blocks.0.ln0", "blocks.0.pre_ln")
        # att -> attention, ffn -> feed_forward
        name = re.sub(r"blocks\.(\d+)\.att", r"blocks.\1.attention", name)
        name = re.sub(r"blocks\.(\d+)\.ffn", r"blocks.\1.feed_forward", name)
        # Expand the abbreviated time-mix suffixes (k/v/r).
        for old_suffix, new_suffix in (
            (".time_mix_k", ".time_mix_key"),
            (".time_mix_v", ".time_mix_value"),
            (".time_mix_r", ".time_mix_receptance"),
        ):
            if name.endswith(old_suffix):
                name = name[: -len(old_suffix)] + new_suffix
                break
        # Everything except the LM head lives under the `rwkv.` prefix.
        if name != "head.weight":
            name = "rwkv." + name
        state_dict[name] = tensor
    return state_dict
Split in shards and save 115 | shards, index = shard_checkpoint(state_dict) 116 | for shard_file, shard in shards.items(): 117 | torch.save(shard, os.path.join(output_dir, shard_file)) 118 | 119 | if index is not None: 120 | save_index_file = os.path.join(output_dir, WEIGHTS_INDEX_NAME) 121 | # Save the index as well 122 | with open(save_index_file, "w", encoding="utf-8") as f: 123 | content = json.dumps(index, indent=2, sort_keys=True) + "\n" 124 | f.write(content) 125 | 126 | # 5. Clean up shards (for some reason the file PyTorch saves take the same space as the whole state_dict 127 | print( 128 | "Cleaning up shards. This may error with an OOM error, it this is the case don't worry you still have converted the model." 129 | ) 130 | shard_files = list(shards.keys()) 131 | 132 | del state_dict 133 | del shards 134 | gc.collect() 135 | 136 | for shard_file in shard_files: 137 | state_dict = torch.load(os.path.join(output_dir, shard_file)) 138 | torch.save({k: v.cpu().clone() for k, v in state_dict.items()}, os.path.join(output_dir, shard_file)) 139 | 140 | del state_dict 141 | gc.collect() 142 | 143 | if push_to_hub: 144 | if model_name is None: 145 | raise ValueError("Please provide a `model_name` to push the model to the Hub.") 146 | model = AutoModelForCausalLM.from_pretrained(output_dir) 147 | model.push_to_hub(model_name, max_shard_size="2GB") 148 | tokenizer.push_to_hub(model_name) 149 | 150 | 151 | if __name__ == "__main__": 152 | parser = argparse.ArgumentParser() 153 | # Required parameters 154 | parser.add_argument( 155 | "--checkpoint_file", default=None, type=str, required=True, help="Name of the checkpoint file in the repo." 156 | ) 157 | parser.add_argument( 158 | "--output_dir", default=None, type=str, required=True, help="Where to save the converted model." 
159 | ) 160 | parser.add_argument( 161 | "--tokenizer_file", 162 | default=None, 163 | type=str, 164 | help="Path to the tokenizer file to use (if not provided, only the model is converted).", 165 | ) 166 | parser.add_argument( 167 | "--size", 168 | default=None, 169 | type=str, 170 | help="Size of the model. Will be inferred from the `checkpoint_file` if not passed.", 171 | ) 172 | parser.add_argument( 173 | "--push_to_hub", 174 | action="store_true", 175 | help="Push to the Hub the converted model.", 176 | ) 177 | parser.add_argument( 178 | "--model_name", 179 | default=None, 180 | type=str, 181 | help="Name of the pushed model on the Hub, including the username / organization.", 182 | ) 183 | 184 | args = parser.parse_args() 185 | convert_rwkv_checkpoint_to_hf_format( 186 | args.checkpoint_file, 187 | args.output_dir, 188 | size=args.size, 189 | tokenizer_file=args.tokenizer_file, 190 | push_to_hub=args.push_to_hub, 191 | model_name=args.model_name, 192 | ) 193 | -------------------------------------------------------------------------------- /engines/utils/cpm_quantizer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2023/7/2 22:05 3 | # @Author : copy from https://huggingface.co/THUDM/chatglm2-6b/blob/main/quantization.py 4 | # @Email : gzlishouxian@gmail.com 5 | # @File : cpm_quantizer.py 6 | # @Software: PyCharm 7 | from torch.nn.parameter import Parameter 8 | import bz2 9 | import torch 10 | import base64 11 | import ctypes 12 | from transformers.utils import logging 13 | from typing import List 14 | 15 | logger = logging.get_logger(__name__) 16 | 17 | try: 18 | from cpm_kernels.kernels.base import LazyKernelCModule, KernelFunction, round_up 19 | 20 | class Kernel: 21 | def __init__(self, code: bytes, function_names: List[str]): 22 | self.code = code 23 | self._function_names = function_names 24 | self._cmodule = LazyKernelCModule(self.code) 25 | 26 | for name in 
self._function_names: 27 | setattr(self, name, KernelFunction(self._cmodule, name)) 28 | 29 | quantization_code = "$QlpoOTFBWSZTWU9yuJUAQHN//////////f/n/8/n///n//bt4dTidcVx8X3V9FV/92/v4B7/AD5FBQFAAAChSgKpFCFAFVSigUAAAEKhSgUUqgFBKigqVREQAABQBQIANDTTIGI00BkZBkNGE0A0BkBkGQGRkaNAaAGQNBoGgDIAAYIGTI0DQAQAaGmmQMRpoDIyDIaMJoBoDIDIMgMjI0aA0AMgaDQNAGQAAwQMmRoGgAgA0NNMgYjTQGRkGQ0YTQDQGQGQZAZGRo0BoAZA0GgaAMgABggZMjQNABABoaaZAxGmgMjIMhowmgGgMgMgyAyMjRoDQAyBoNA0AZAADBAyZGgaAAmqU1NEgJqnptU/Sn4jRR6J6epk2pqb1Q/SgAPUGgyNNGjQ2SBpoAZAAGg0NB6mgDIAAAAA2oaApSREBNAARhGiYEaEwU8pvImlP0k2aam1GaGqbFNM1MHpTwmkepmyU9R6nqPKekHqNNPUxNGhp6n6p6QaZ6o9TG1GMqcoV9ly6nRanHlq6zPNbnGZNi6HSug+2nPiZ13XcnFYZW+45W11CumhzYhchOJ2GLLV1OBjBjGf4TptOddTSOcVxhqYZMYwZXZZY00zI1paX5X9J+b+f4e+x43RXSxXPOdquiGpduatGyXneN696M9t4HU2eR5XX/kPhP261NTx3JO1Ow7LyuDmeo9a7d351T1ZxnvnrvYnrXv/hXxPCeuYx2XsNmO003eg9J3Z6U7b23meJ4ri01OdzTk9BNO96brz+qT5nuvvH3ds/G+m/JcG/F2XYuhXlvO+jP7U3XgrzPN/lr8Sf1n6j4j7jZs+s/T0tNaNNYzTs12rxjwztHlnire3Nzc3N1wuBwOBwXBvZfoHpD7rFmR99V5vj3aXza3xdBbXMalubTg/jIv5dfAi54Pdc75j4z412n3Npj3Ld/ENm7a3b/Cod6h/ret1/5vn/C+l+gdslMvgPSLJ8d8q+U66fevYn/tW1chleEtNTGlcHCbLRlq0tHzF5tsbbZZfHjjLgZu42XCuC3NrdjTasZGNzgxPIrGqp7r3p7L2p5XjnpPSmTd5XtzqnB6U87zzg1Ol0zd0zsLszxR6lkxp35u6/teL0L0W922cR7Lu1lpL9CsHirzuM2T+BgsyViT6LHcm0/Vr6U/7LGGyJeqTEjt0PHWhF5mCT7R9mtlDwriYv0Tyr/OxYt6qp5r0mPVT0608TqnqMZaarU2nFwrTzzlrs1ed7z1ux60wyr4ydCaTi3enW8x68x0zU7tXSlcmPSW1mGpWJMg4zmPC2lK96tp0OE80y4MfEvnZj8zGluR6b22ki1Ou9V2nCd9xovcPvcYMZYy0lvN60ScZ45vN6yeCeeXFb1lVjnnCar5fwXwE2bzJ4HI1XVPXfXZMm44GUsMpYsmLB65TuVdm0cl0b+i/wGNN66XjeV7zuPpHcnK/juhhjdfId5jMdE5nN0dGmmm2zZs2cexD5n9p/dY352XsvXHaZNWWsmmS1atjR452nYudzvqv2HMRyvNNnlMcDl3R2+yx2uVrBubTW9icHDVtbNXlZm7jma1rM4VurZZd2y6nUau7ZXZ7bVU+mnoOVxZGMrVmvX60605JwmzGZhhhjTWtaaaMaaGTGmNMZasY0iX8VMUl8eepaIrzGSpemWOQyZORk2bNpjUybMmxqYmknCGCFynutfksaZpjTNMaaatM0xsxcGR0sociNqxNSmhhR1ZJPbsn8qyF0t2qH6iYBclclalbtTTcHTDsPaX6rlnElph2Jyumumtynv2Kk8GI7rsvXbIcJgHJOSaSXnnGaI3m87RtVXJOZ/YtgdTE6Wpha6ZlE8ayXkef1fh602r2WwvfMXtMdLlkfnLFdYYw
Yso+bWqm7yJqHXZGw2nrS5ZanSYnWlxBxMF1V940K2wdrI7R6OYf7DGGamMmTSbRhlS45xmVOumF1EyPCmHrrN8wwZOOrdNtLeMtzFzDlWnfTBxMk2NaXIZHBYxYLD4w8yju0ao65Vz1OIXoS9dLanwCe1PWrYuWMqf1if1z2k2yYfKJ741PDgno1ZQ8DRqvUny3mNoWTzGO6m1DkrJI8JiR5cSd+vZdGOO8nrMoc5+NDUFsMSXaZJeNlMmGLtJsovOsUp7I9S5VojKxF6bTVEelXqlfJobQr3LozSh2Jk7VcrVMfhXqszGWMzNqGhqZY0OadxkyyMssKugZR0KNFXBHlqwmJgTE/BNVMk6ItJXZMR0H47GpXv/DMOvNkmVuaV1PRfEdxuqc7Hcd+ZV/zTLaRxWk0nl9CdCeM6mn5rstHIBcpiuwmUZXeq81DacHI2rmrZ5SuE5mOZd6LQrZg9mx32TprA8BMo5jKN6yLTCi3WzQaZSuhzTtM1fUTGVpG8Tw+KXI0tjEpiWxtLYynOlktSbVlaI5kxP8TDH8kx50xoxi5KcA4pcja8KWLRlO/Ks6q06ergnvm1ca3Tq8Uw7LTUsmWyctXPWmpitl/uvGcWTGXGuAXDfhqazGmjkxcJW5hMMMMpYsXl2TZYtVOddG3XCarUt6Ptq9CZXSNzyuRzqRZOjsxdBbFVz6OA5HI43r1jityVlVpVkxmOsyaYWE1NTGq1sOVh36mHMcxtSvcy70edG0ZGR3I1Go1GRlV7mWWo1G0ZGRqlvH40l7o4m5xMWLLLYyNjnqc8556mdPqLJ31n/1nWOncxzG1tizrHs/Z+d2vP/B/l8wdJ6rHUn2nbbDq4p6htFtYzMMMTaZis1K5GKzGNmxhmUx2DDlZ/qNnIx41xnaMfCZWYaZWtNLTNW8ND4Fw1MyZOCdM428suKG1ehW8TesOydg7J+YYcD4cYR+8dFK6M4E3HM9ZfRNNL+Sn6rsl4DsrDl2HpPCnfxjGXtbZtYys1ttlyJ4T+BvexjGWRjMszK4Jpc77D3GyuVD7q0+G8m9G+2+rGm7cOR2y7FdtY2XUYx/oNlfRYxhMYyYZkyyg55enna9Kt/FFi6GMMwYwdwxWgxGMLKYmUyGExTKMZkMFhkymKuh0NOBNnBu+23LdwDoZYYzGGMxtORaTU1pjTGWTTGGtMrNWUsyyTTLLG1qy2ZjbK2DBllWqxMtBMaYZQmcE7zvvRcTkclUwdkxTaSdyySt/7fpL+T1v516Ji97fwr5JbLu305zMn5+GMTTZ9F+y7ExwmGVfG44yxn3dLv6l5i+Wth1jCrDq21nW9LqvvDzz3Vf3LLH/O/32TJ/erx3bXftO4eF+G956D952K/An4NfvOpjFjExjevP/UmE0fIoZXx6/w6lX/no3D0bLt+ixjieBM6ksRd0yB4Lt2SwYNE+gd1detlZWUnpiZfGfFaK+4PyCa/v18V8X75pe9fLXzp7l3VjF76vWZmHwGz1IZNWT7b8yddJ4q5kyrVdfru6atWc7bVYztL9Jf4GXvT+Y8m9/YsXP6H018a8D4XVOqvfzqeR+6yZOD8dPv0+U7/q5Pl+2dNb0MjzGVH5p6MNQ7cOWvw62U9aHE8DprDek+McLyvDz+te+9Zhq5+YTruufMcWMabqysTmZVWjKPfnK0wyVcrsuhjZRdLkHNvD72b9abriOSGIxiLixMOoalNPXzy+wT/tf+U6HHONfsz+xe8ufHBdQWWGWLA9if0rsnmrxK5LvRZQeWsTCsrmOYy8VteVfuRfcVTtDLItLIsMYxZLdU/DbtSemxF6Z6Zo5WBXE4tFdCyVMMXMTEMZXVlS6Xec2T4e0tHsRcEuWshcJ2YsNF5rUx1E8ifCq6Z+ZP7qdCeu/aTwFd53l16/o0NOw6O3dLavP4Hbi4RdmuDk6DoYaninC0+o4uZjbJ7Rxeu0/FbuFg+q7DVS6fQe0rZ6NDGUNNU6DEqOaLTicKnY
ZMnBWruljQxoaS3dZhocDge0bSTyOvdAbG5hxe2xji7E/L55xX13wWNDi6HCekcFxfCPGxY0MXC+s7afWaMdDyjyr+o8Rudm/NabOZvdl274zH4f5XK9z6On1Pe/K5TdPAslg77BjuO6Y3eO7GqvOPG/stknp1leyvLL0Z7bl9I4noMvLkzytLhWYzrOZzLXCORe028rORzOg4N/L0HlMOQ3Pgmnbb6KczlabORpu980q37TBqRu0/p3PO6234Bl03Ynuz+9W7gnsEcmvYaYY3aMYY0wx3pYd+ujsXauWdaY5Xkbtl23fPzFHiDB/QMo0yFjBllYxTQYYyxkrwn7JufwJ/PfgJ+C83X69ni6zvXcnyXabv0ncbLwsceS+RNlyN2mnneJtX0ngYO0+e+0+UnA+Wch3ji8hj5an4h+i6XBySU4n+R0roVcbw5yvHrmr4Yw8Y7x6c+9POPYHI5HI5HI5HI5HGXGww4nE4nrVyOR8XeqPEO7PLOiukYa3Novk5hV4cdtYZLI93e+uxff2jRo0aNGjRo0aNG1bVtW1dy3m83m8+tQ5ZzHw3nObwOu8La9Rc1dtkdS8A3eTk823tnktXWlxN6Oixe06zrN70Isd9jiOgZFq9yfkPqP/SLhN2Myl8jDM43bl1nbcb4cO57jlh8Jow6pzXZdL4dyODTuuhu77FyO27DdwdRxmvO+O+3N2+BdqyTwLHVczDVY4UPE4O66/ZO2cx1LFzVdSXtF7G4HMbrauOHRw6c8FdZ5m9fHZHYZXfTlZquyynSyTTKke6vcffSD9pzPA/G7n7jxPmuhc1DHMynPMrGL6AdewYmwu5ko+UUyTwrMv27rPH1v1nGqd87+p6N6LU8k3NEng53xXyHS97+44OSg/sy/hn+Se6yfYNjW0/uTgP+PvWYzLMmjhcLB/gGpri6H83/84eUXWT6T9Hsv7785z/7z4icpW+zfXypuR7rx/gMdZb1/wC678pcs8/2a3mDitGHxl9mfPlll5MafWWqxk/eYuTDgcNMzDGWLWvsuglNxs53GtN6uWpktlW1tZZYcuinMMWmnNnJydze3b2Y1McBxrBkXw799izLMZZYyy0TkbsGM4p03S2uVu5s/XXUdSdec6smVxZYYGpVmT8A+8ajuEyV5FatkvVru2x6uxGXXbH4A+jvgP4GMYy3iPLXzq/6z65+E005ey+cwMZD3fZcqc6xpjTFjQ0P3U+e++cPYmTIwj0nrK5NPTfl3WvpfLtXDcb2HQMudYOxFXQBor4L4T6vrOauFctYXJQ++NUWmJe5bmx1jDiZS1dTqWxo4GR8jm3fttpmPHppk9PEyv4/y8/sO07XacOmcqc0x2Vi9BvNJvN5oW8x4mOsydpidRxMYJPx06m1bqPzq9KtK8sxXNXFodD/+MYYaJTLwOhc9brCsV18oOR1i4tXChyTkq4lf4y1Ke+9axjDHqs1mfBbMXuP4Hzi+X7t8vzv7bHerrUPgPCxhjre4fXdfLNtNM+Jd+Zdh8xd8wP87uNPoPgv4W7/5P2BuxfsMabNnMnza+54Pdi5U671GPZY8CehX8Voeoo7FHpkeEc6715FwHZrIrUrHaviPUbPZHND+IhczrP6FcYvhOZ0Di/ETt0OI+YwNWR9r7tpf6WDeZKZDB1+z2IthOl1mPyb5FluvEx9h9d0NnM0Y1XPFkWIsk1WotJ0PBMmkvjvQTd0e71tfeV+8r8lQ/tpzpsmxJ+InrI/dj2UajUajVTUajatRqNRtGo1Go1Go4wjeMpZFMVV9CHbofPraLsJ3JpWV2XOoanCuFky4y3PPNxucK2uKC1Lbdb1eo+m5XomN6HfeZsabHLHRX/K+offtNGGmHWctcVcG44MdSqsOLY9VzX+Zxfxn2HPdWTpzWvkrtJ8M5zorrKcquRytJ5N5DZmcaW02l76nWO+BqPXm1A2Ry/0q71dH/mqrqeFjkYxjEXtsX8qubTk67rGycyqsdm4tZx5D
6D5hhi0waaWmiaMP81Yjii5qxPlPuU/GfTL1Y5E6Jyfiq63qTa39A4J0sOGDgO9WF9bOXl0XfPRbsY2bPNKPy1YrFYrFYmRhhlTIyMjJWJYZHXuCXI8OoXsvfljGLFicNifpp2XunoPiG1wtx3p1Tah+/DD66OnVtVXP9rKbVxOnL0tR/rHtqB5UDErUVcl11D4qqvjpOcxX7armUNJB3LpW6bxVvD08e8h3odKKvyCFZBdSh2FVcST9xV3n3T8t1j7Kr9qgrqXg+13Pt5U7JCvFXVIV1YG5lRhkVYZJYYDDD4KOIMoHCp26WS8GB7uBh2zIdgq/PKyInjV2STShuoapUdCpX1yTwqq/z1VvET7Kh5nVPkO8YyxjLt2MaaMmWTLQvx3qnzltnXW0p2jxgbEtSny/Osv8Y9pLMXYoHVPAhkVdWVeODhR6q9/Sxe2liwwZWMVvFXfRkeIDxAePUPIrdJ4ey6yquzH+PD/bUOWAu05qVHtFd8rrKHSoeNIOUqrYr3FXyToqfYJgwmJdKpXXOwYYegNNGMzfZPp/t3t/DVs4zjNTN61rRqaWaa4NYbRjTa0tWwy2Y2tGN8ZO8ofNKq4j9SL7I+cSm4/6ovLV5HNXLI0jJidwrtk6ynCaP6Z++GjRlWS3tLeW129Mi9evxU9mtz6s5J3Z7M2ngTgnKvmpomxpaLCzPfmx0JWE+m3NLDDGOX47RctdYYNK5jakdqLkRlI39n590T5zctGSwwZZDJj6kW8XSi6ot2MmWWJ0DUT3nuvebBudScjZ79g8cWJ8av0k+/bE5WKd5MdbFpbDVMxu1DVMmtNZGJvq1mtRbn6M+g/kP0FwDwr7quZs7xosNGpbscyxhhd9TyJyFwbLcxlTasg75vW7TsV5K7ji44XPMMrdoj+Y3rT0Hie62nlYV/pwczzOmdLqLhYkzGMzCZWGMQzGMSsZYY6Di1t4nlJ+Em63mJxrVLxPbYxNEdgc1dU2iOKyoYYWjNrEeHTYybVk0atSa7ehuwsWMWTqn1TrnS6hYsi71d1+s+k+ic70e20fzE/VaTdxT9ZtU4GIXdeNx3X77guYYfpHeTQjaMX6brOu4OY4K7Y2d9mbHarI5ox3p4GpJ2Vd/Tst60f7j999pppjR+Q/Qf8J/VaORs3cji7FfFuN61+ui9s8hix1OCh5KGVV23BPXvZfz3CLyHpix+exi8z/KnCnosY2eunor+cxyPO/xJ0vKey9OvE9VjqaYu0x3Z3jd6o2b1T12D+F8l232lwaaacD5LE8LBxu7WTlbWraWpew8Xexjel3E+wWD4APITdNqR8F3R3T0lunCQ4GaE9R37DxeCYfcHi4xci5ovKfxVs55y2hf+65E/Xdp6jR5nrebTmi5incpkyOjs50JvrZwstbbW6kfuuQw+2mykf/EXNFzxfKTrxew929TR6bWnGL//F3JFOFCQT3K4lQ" 30 | 31 | kernels = Kernel( 32 | bz2.decompress(base64.b64decode(quantization_code)), 33 | [ 34 | "int4WeightCompression", 35 | "int4WeightExtractionFloat", 36 | "int4WeightExtractionHalf", 37 | "int8WeightExtractionFloat", 38 | "int8WeightExtractionHalf", 39 | ], 40 | ) 41 | except Exception as exception: 42 | kernels = None 43 | logger.warning("Failed to load cpm_kernels:" + str(exception)) 44 | 45 | 46 | class W8A16Linear(torch.autograd.Function): 47 | @staticmethod 48 | def forward(ctx, inp: torch.Tensor, quant_w: 
torch.Tensor, scale_w: torch.Tensor, weight_bit_width): 49 | ctx.inp_shape = inp.size() 50 | ctx.weight_bit_width = weight_bit_width 51 | out_features = quant_w.size(0) 52 | inp = inp.contiguous().view(-1, inp.size(-1)) 53 | weight = extract_weight_to_half(quant_w, scale_w, weight_bit_width) 54 | ctx.weight_shape = weight.size() 55 | output = inp.mm(weight.t()) 56 | ctx.save_for_backward(inp, quant_w, scale_w) 57 | return output.view(*(ctx.inp_shape[:-1] + (out_features,))) 58 | 59 | @staticmethod 60 | def backward(ctx, grad_output: torch.Tensor): 61 | inp, quant_w, scale_w = ctx.saved_tensors 62 | weight = extract_weight_to_half(quant_w, scale_w, ctx.weight_bit_width) 63 | grad_output = grad_output.contiguous().view(-1, weight.size(0)) 64 | grad_input = grad_output.mm(weight) 65 | grad_weight = grad_output.t().mm(inp) 66 | return grad_input.view(ctx.inp_shape), grad_weight.view(ctx.weight_shape), None, None 67 | 68 | 69 | def compress_int4_weight(weight: torch.Tensor): # (n, m) 70 | with torch.cuda.device(weight.device): 71 | n, m = weight.size(0), weight.size(1) 72 | assert m % 2 == 0 73 | m = m // 2 74 | out = torch.empty(n, m, dtype=torch.int8, device="cuda") 75 | stream = torch.cuda.current_stream() 76 | 77 | gridDim = (n, 1, 1) 78 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 79 | 80 | kernels.int4WeightCompression( 81 | gridDim, 82 | blockDim, 83 | 0, 84 | stream, 85 | [ctypes.c_void_p(weight.data_ptr()), ctypes.c_void_p(out.data_ptr()), ctypes.c_int32(n), ctypes.c_int32(m)], 86 | ) 87 | return out 88 | 89 | 90 | def extract_weight_to_half(weight: torch.Tensor, scale_list: torch.Tensor, source_bit_width: int): 91 | assert scale_list.dtype in [torch.half, torch.bfloat16] 92 | assert weight.dtype in [torch.int8] 93 | if source_bit_width == 8: 94 | return weight.to(scale_list.dtype) * scale_list[:, None] 95 | elif source_bit_width == 4: 96 | func = ( 97 | kernels.int4WeightExtractionHalf if scale_list.dtype == torch.half else 
kernels.int4WeightExtractionBFloat16 98 | ) 99 | else: 100 | assert False, "Unsupported bit-width" 101 | 102 | with torch.cuda.device(weight.device): 103 | n, m = weight.size(0), weight.size(1) 104 | out = torch.empty(n, m * (8 // source_bit_width), dtype=scale_list.dtype, device="cuda") 105 | stream = torch.cuda.current_stream() 106 | 107 | gridDim = (n, 1, 1) 108 | blockDim = (min(round_up(m, 32), 1024), 1, 1) 109 | 110 | func( 111 | gridDim, 112 | blockDim, 113 | 0, 114 | stream, 115 | [ 116 | ctypes.c_void_p(weight.data_ptr()), 117 | ctypes.c_void_p(scale_list.data_ptr()), 118 | ctypes.c_void_p(out.data_ptr()), 119 | ctypes.c_int32(n), 120 | ctypes.c_int32(m), 121 | ], 122 | ) 123 | return out 124 | 125 | 126 | class QuantizedLinear(torch.nn.Module): 127 | def __init__(self, weight_bit_width: int, weight, bias=None, device="cpu", dtype=None, empty_init=False, *args, 128 | **kwargs): 129 | super().__init__() 130 | self.weight_bit_width = weight_bit_width 131 | 132 | shape = weight.shape 133 | 134 | if weight is None or empty_init: 135 | self.weight = torch.empty(shape[0], shape[1] * weight_bit_width // 8, dtype=torch.int8, device=device) 136 | self.weight_scale = torch.empty(shape[0], dtype=dtype, device=device) 137 | else: 138 | self.weight_scale = weight.abs().max(dim=-1).values / ((2 ** (weight_bit_width - 1)) - 1) 139 | self.weight = torch.round(weight / self.weight_scale[:, None]).to(torch.int8) 140 | if weight_bit_width == 4: 141 | self.weight = compress_int4_weight(self.weight) 142 | 143 | self.weight = Parameter(self.weight.to(device), requires_grad=False) 144 | self.weight_scale = Parameter(self.weight_scale.to(device), requires_grad=False) 145 | self.bias = Parameter(bias.to(device), requires_grad=False) if bias is not None else None 146 | 147 | def forward(self, input): 148 | output = W8A16Linear.apply(input, self.weight, self.weight_scale, self.weight_bit_width) 149 | if self.bias is not None: 150 | output = output + self.bias 151 | return output 152 
def dispatch(model_type, model, dispatched):
    """Spread a model across all visible GPUs with accelerate.

    Returns the model untouched when it has already been dispatched.
    ChatGLM models need a hand-crafted device map (see below); everything
    else goes through accelerate's automatic balancing.
    """
    if dispatched:
        return model

    if model_type == 'chatglm':
        model.tie_weights()
        device_map = infer_chatglm_device_map(model)
    else:
        balance_kwargs = {'dtype': model.dtype, 'no_split_module_classes': model._no_split_modules}
        max_memory = get_balanced_memory(model, **balance_kwargs)
        model.tie_weights()
        device_map = infer_auto_device_map(model, max_memory=max_memory, **balance_kwargs)

    return dispatch_model(model, device_map=device_map)


def infer_chatglm_device_map(model):
    """Build a manual device map for ChatGLM / ChatGLM2 / ChatGLM3.

    Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py with
    small changes to also support ChatGLM2/3. The embeddings, final layer
    norm and lm_head are all pinned to GPU 0 so that input_ids placed on
    ``model.device`` always share a device with the embedding weights
    (otherwise torch.embedding raises a RuntimeError on Linux).
    """
    num_gpus = torch.cuda.device_count()
    # 28 transformer layers plus one slot for the word embeddings and one
    # for final_layernorm + lm_head -> 30 slots balanced over the GPUs.
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # ChatGLM2/3 expose a rotary_pos_emb module; use it to pick key names.
    is_chatglm2_or_3 = any(key.endswith('rotary_pos_emb') for key, _ in model.named_modules())

    if is_chatglm2_or_3:
        device_map = {
            'transformer.embedding.word_embeddings': 0,
            'transformer.encoder.final_layernorm': 0,
            'transformer.output_layer': 0,
            'transformer.rotary_pos_emb': 0,
            'lm_head': 0
        }
        layer_key = 'transformer.encoder.layers.{}'
    else:
        device_map = {
            'transformer.word_embeddings': 0,
            'transformer.final_layernorm': 0,
            'lm_head': 0
        }
        layer_key = 'transformer.layers.{}'

    slots_used = 2  # the pinned modules above occupy two slots on GPU 0
    gpu_target = 0
    for layer_index in range(num_trans_layers):
        if slots_used >= per_gpu_layers:
            gpu_target += 1
            slots_used = 0
            assert gpu_target < num_gpus
        device_map[layer_key.format(layer_index)] = gpu_target
        slots_used += 1

    return device_map
是否将数字划分为单个 token, 在 llama 中是这么做的 45 | split_digits=True if model_arch == 'llama' else False, 46 | 47 | # 指定在遇到未知或很少的字符时将其分解为 UTF-8 字节, 开启后等效于 bbpe 48 | byte_fallback=True, 49 | 50 | # 指定输入句子的最大长度,以字节为单位 51 | max_sentence_length=max_sentence_length 52 | ) 53 | bpe_model_path = os.path.join(os.path.dirname(par_dir), 'bpe_expand.model') 54 | bpe_vocab_path = os.path.join(os.path.dirname(par_dir), 'bpe_expand.vocab') 55 | shutil.move(bpe_model_path, save_path) 56 | shutil.move(bpe_vocab_path, save_path) 57 | logger.info(f'The vocabulary training is complete, saved to {save_path}.') 58 | 59 | 60 | # 添加新词 61 | def add_new_tokens(logger, tokenizer, save_path): 62 | logger.info('Start adding new tokens.') 63 | bpe_model = os.path.join(save_path, 'bpe_expand.model') 64 | sp_bpe = sp.SentencePieceProcessor() 65 | sp_bpe.load(bpe_model) 66 | 67 | raw_vocab = [sp_bpe.id_to_piece(id) for id in range(sp_bpe.get_piece_size())] 68 | clean_vocab = list(set(filter(is_chinese, raw_vocab))) 69 | 70 | tokenizer.add_tokens(clean_vocab) 71 | tokenizer.save_pretrained(save_path) 72 | logger.info(f'New tokens added, new tokenizer is saved to {save_path}.') 73 | 74 | return len(tokenizer) 75 | 76 | 77 | # 初始化 embedding 层 78 | # todo: 暂时只支持随机扩充,均值扩充在部分模型上没调好 79 | def resize_embedding(logger, model, tokenizer_length, save_path): 80 | logger.info('Start resizing embedding.') 81 | new_length = int(tokenizer_length // 64 + 1) * 64 82 | model.resize_token_embeddings(new_length) 83 | 84 | model.save_pretrained(save_path) 85 | logger.info(f'New model: {model}') 86 | logger.info(f'Embedding resized, new model is saved to {save_path}.') 87 | 88 | 89 | # 直接注入新词表 90 | def inject_vocab(logger, tokenizer, save_path, corpus_list): 91 | logger.info('Start injecting new vocabulary.') 92 | 93 | all_words = [] 94 | for file in corpus_list: 95 | with open(file, 'r', encoding='utf-8') as f: 96 | lines = f.readlines() 97 | words = [line.strip() for line in lines] 98 | all_words.extend(words) 99 | 100 | 
# 入口函数
def expand_vocab(logger,
                 model_path,
                 corpus_path,
                 model_arch,
                 save_path,
                 torch_dtype,
                 args
                 ):
    """Expand a pretrained model's vocabulary from a corpus.

    Either injects the corpus words directly (``args.expand_mode ==
    'inject'``) or trains a BPE vocabulary with sentencepiece and merges
    the new tokens, then resizes the model's embedding layer. The new
    tokenizer and model are saved to ``save_path``.
    """
    logger.info(f'Load base tokenizer from {model_path}.')
    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True,
        # LLaMA 不用 TokenizerFast,表现有差异
        use_fast=False if model_arch == 'llama' else True
    )

    # Bug fix: the original passed the unbound ``str.capitalize`` method
    # object to logger.info instead of a message string.
    logger.info(f'Load base model from {model_path}.')
    model = AutoModel.from_pretrained(
        model_path,
        trust_remote_code=True,
        torch_dtype=torch_dtype
    )

    save_path = os.path.join(model_path, 'new_model') if save_path == 'auto' else save_path
    os.makedirs(save_path, exist_ok=True)
    logger.info(f'After expanding the vocabulary, the new model will be saved to {save_path}.')

    corpus_list = process_corpus(corpus_path)
    if args.expand_mode == 'inject':
        tokenizer_length = inject_vocab(logger, tokenizer, save_path, corpus_list)
    else:
        train_vocab(
            logger,
            save_path,
            corpus_list,
            model_arch,
            args.vocab_size,
            args.max_sentence_length
        )
        tokenizer_length = add_new_tokens(logger, tokenizer, save_path)
    resize_embedding(logger, model, tokenizer_length, save_path)
    logger.info('The vocabulary was successfully expanded.')


def process_corpus(corpus_path):
    """Collect the .txt/.tsv corpus files behind ``corpus_path``.

    ``corpus_path`` may be a single file or a directory; any other file
    type raises ValueError (sentencepiece only accepts txt/tsv).
    """
    ret_list = []
    if not os.path.isdir(corpus_path):
        if not corpus_path.endswith('.txt') and not corpus_path.endswith('.tsv'):
            raise ValueError('Only .txt or .tsv files are supported.')
        ret_list.append(corpus_path)
    else:
        for file in os.listdir(corpus_path):
            # Bug fix: the original tested ``corpus_path.endswith('.tsv')``
            # (always false for a directory) instead of the file name, so
            # every .tsv file inside a directory raised ValueError.
            if not file.endswith('.txt') and not file.endswith('.tsv'):
                raise ValueError('Only .txt or .tsv files are supported.')
            ret_list.append(os.path.join(corpus_path, file))
    return ret_list


def is_chinese_char(cp):
    """Return True when code point ``cp`` lies in a CJK Unicode block."""
    cjk_ranges = (
        (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF),
        (0x2A700, 0x2B73F), (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF),
        (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
    )
    return any(low <= cp <= high for low, high in cjk_ranges)


def is_chinese(word: str):
    """Return True when every character of ``word`` is a CJK character."""
    return all(is_chinese_char(ord(char)) for char in word)
def logits_processor():
    """Build the LogitsProcessorList used at generation time.

    Currently only strips inf/nan from logits; add project-specific
    processors here to steer generation.
    """
    # https://huggingface.co/docs/transformers/v4.33.2/en/internal/generation_utils#transformers.LogitsProcessor
    # You can define your logits processor here to control the generate process.
    # https://blog.csdn.net/weixin_44826203/article/details/129928897
    logits_processor = LogitsProcessorList()
    logits_processor.append(InfNanRemoveLogitsProcessor())
    return logits_processor


class Metrics:
    """Evaluation metrics used by the trainers (ROUGE for SFT, pairwise
    accuracy for reward-model training)."""

    def __init__(self, data_manager, logger):
        # data_manager supplies the tokenizer used to decode predictions.
        self.data_manager = data_manager
        self.tokenizer = data_manager.tokenizer
        self.rouge = Rouge()
        self.logger = logger

    def computer_supervised_fine_tuning_metric(self, eval_preds):
        """Compute rouge-1/2/l F1 (x100) between decoded predictions and labels.

        ``eval_preds`` is the (predictions, labels) pair handed over by the
        HF Trainer; -100 is the ignore index and must be replaced with a
        real token id before decoding.
        """
        preds, labels = eval_preds
        score_dict = {'rouge-1': [], 'rouge-2': [], 'rouge-l': []}
        # Replace the -100 ignore index with pad so batch_decode gets valid ids.
        preds = np.where(preds != -100, preds, self.tokenizer.pad_token_id)
        labels = np.where(labels != -100, labels, self.tokenizer.pad_token_id)
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        for pred, label in zip(decoded_preds, decoded_labels):
            # Segment Chinese text into words; rouge_chinese scores
            # space-separated token strings.
            hypothesis = list(jieba.cut(pred))
            reference = list(jieba.cut(label))
            # Empty hypothesis/reference would make Rouge raise — score it 0.
            if len(' '.join(hypothesis).split()) == 0 or len(' '.join(reference).split()) == 0:
                result = {'rouge-1': {'f': 0.0}, 'rouge-2': {'f': 0.0}, 'rouge-l': {'f': 0.0}}
            else:
                scores = self.rouge.get_scores(' '.join(hypothesis), ' '.join(reference))
                result = scores[0]
            for k, v in result.items():
                score_dict[k].append(round(v['f'] * 100, 4))
        metric_results = {}
        for k, v in score_dict.items():
            metric_results[k] = float(np.mean(v))
        return metric_results

    @staticmethod
    def computer_training_reward_metric(eval_preds):
        """Pairwise accuracy for the reward model.

        Assumes preds[0] are rewards for accepted responses and preds[1]
        for rejected ones (see RewardTrainer.compute_loss outputs) —
        accuracy is the fraction of pairs where accept outranks reject.
        """
        preds, _ = eval_preds
        accuracy = np.array(preds[0] > preds[1]).sum() / len(preds[0])
        return {'accuracy': accuracy}
def parse_text(text):
    """Convert model output into HTML for the Gradio chat UI.

    Fenced code blocks (```lang) become <pre><code> blocks; inside a code
    block, Markdown/HTML-active characters are escaped to entities so they
    render literally; line breaks become <br>.

    NOTE(review): the HTML tags and character entities in this function were
    stripped by an extraction step in the file as received; they are restored
    here from the upstream implementation this file credits
    (GaiZhenbiao/ChuanhuChatGPT) — verify against the original file.
    """
    lines = text.split('\n')
    lines = [line for line in lines if line != '']
    count = 0
    for i, line in enumerate(lines):
        if '```' in line:
            count += 1
            items = line.split('`')
            if count % 2 == 1:
                # Opening fence: start a highlighted code block; items[-1]
                # is the language tag (may be '').
                lines[i] = f'<pre><code class="language-{items[-1]}">'
            else:
                # Closing fence.
                lines[i] = '<br></code></pre>'
        else:
            if i > 0:
                if count % 2 == 1:
                    # Inside a code block: escape characters Markdown/HTML
                    # would otherwise interpret.
                    line = line.replace('`', r'\`')
                    line = line.replace('<', '&lt;')
                    line = line.replace('>', '&gt;')
                    line = line.replace(' ', '&nbsp;')
                    line = line.replace('*', '&ast;')
                    line = line.replace('_', '&lowbar;')
                    line = line.replace('-', '&#45;')
                    line = line.replace('.', '&#46;')
                    line = line.replace('!', '&#33;')
                    line = line.replace('(', '&#40;')
                    line = line.replace(')', '&#41;')
                    line = line.replace('$', '&#36;')
                lines[i] = '<br>' + line
    text = ''.join(lines)
    return text
def summary(model, max_level):
    """
    Show the summary of model parameters.

    Builds a LayerNode tree from the state_dict keys (split on '.'),
    accumulates per-subtree [param count, memory MB], and formats it.
    Tensors whose data pointer was already seen (tied/shared weights)
    are counted once and marked '(shared)'.

    Args:
        model: the model to be inspected, can be a torch module or a state_dict.
        max_level: The max level to display. If ``max_level==None``, show all the levels.
    Returns:
        A formatted string.

    Example::

        print(textpruner.summay(model))

    """
    if isinstance(model, torch.nn.Module):
        state_dict = model.state_dict()
    elif isinstance(model, dict):
        state_dict = model
    else:
        raise TypeError('model should be either torch.nn.Module or a dict')
    # Data pointers of tensors already counted, to deduplicate shared weights.
    hash_set = set()
    model_node = LayerNode('model', fullname='model')
    current = model_node
    for key, value in state_dict.items():
        names = key.split('.')
        # Walk/extend the tree along the dotted parameter name.
        for i, name in enumerate(names):
            if name not in current:
                current[name] = LayerNode(name, parent=current, fullname='.'.join(names[:i + 1]))
            current = current[name]

        if (value.data_ptr()) in hash_set:
            # Shared tensor: contributes no additional parameters/memory.
            current.value = [0, 0]
            current.name += '(shared)'
            current.fullname += '(shared)'
            current.update(current.value)
        else:
            hash_set.add(value.data_ptr())
            # value = [parameter count, memory footprint in MB].
            current.value = [value.numel(), value.numel() * value.element_size() / 1024 / 1024]
            current.update(current.value)

        # Reset to the root before processing the next parameter.
        current = model_node

    result = model_node.format(max_level=max_level)

    return result


def print_trainable_parameters(model, logger):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    logger.info(f'trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}')
class Template:
    """Prompt-template registry for the supported chat models.

    Each template defines: ``prefix`` (system preamble), ``prompt`` (per-turn
    format with a ``{query}`` slot), ``sep`` (turn separator), ``stop_words``
    (extra stop strings for generation) and ``use_history`` (whether previous
    turns are included).

    NOTE(review): several template strings appear to have had special tokens
    stripped by an extraction step (angle-bracket tokens such as '<human>',
    '<bot>', '</s>', '<eoa>', '<s>', '<<SYS>>' render as empty/'' below) —
    verify each flagged template against its upstream model card.
    """

    def __init__(self, prompt_template):
        self.prompt_template = prompt_template
        if self.prompt_template == 'default':
            r"""
            Default template.
            """
            self.prefix = "A chat between a curious user and an artificial intelligence assistant. \n" \
                          "The assistant gives helpful, detailed, and polite answers to the user's questions."
            self.prompt = 'Human: {query}\nAssistant: '
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'vanilla':
            r"""
            Supports language model inference without histories.
            """
            self.prefix = ''
            self.prompt = '{query}'
            self.sep = ''
            self.stop_words = []
            self.use_history = False
        elif self.prompt_template == 'alpaca':
            r"""
            Supports: https://huggingface.co/tatsu-lab/alpaca-7b-wdiff
                      https://github.com/ymcui/Chinese-LLaMA-Alpaca
            """
            self.prefix = 'Below is an instruction that describes a task. \n' \
                          'Write a response that appropriately completes the request.'
            self.prompt = '### Instruction:\n{query}\n\n### Response:\n'
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'vicuna':
            r"""
            Supports: https://huggingface.co/lmsys/vicuna-7b-delta-v1.1
                      https://huggingface.co/lmsys/vicuna-13b-delta-v1.1
            """
            self.prefix = "A chat between a curious user and an artificial intelligence assistant. \n" \
                          "The assistant gives helpful, detailed, and polite answers to the user's questions."
            self.prompt = 'USER: {query} ASSISTANT: '
            self.sep = ''
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'belle':
            r"""
            Supports: https://huggingface.co/BelleGroup/BELLE-LLaMA-EXT-13B
            """
            self.prefix = ''
            self.prompt = 'Human: {query}\n\nBelle: '
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'linly':
            r"""
            Supports: https://github.com/CVI-SZU/Linly
            """
            self.prefix = ''
            self.prompt = 'User: {query}\nBot: '
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'ziya':
            r"""
            Supports: https://huggingface.co/IDEA-CCNL/Ziya-LLaMA-13B-v1
            """
            self.prefix = ''
            # NOTE(review): Ziya's documented format uses '<human>:'/'<bot>:';
            # the tokens look stripped here — confirm against upstream.
            self.prompt = ':{query}\n:'
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'aquila':
            r"""
            Supports: https://huggingface.co/BAAI/AquilaChat-7B
            """
            # NOTE(review): "he assistant" is likely a typo for "The assistant"
            # — confirm against the original file before changing.
            self.prefix = "A chat between a curious human and an artificial intelligence assistant. \n" \
                          "he assistant gives helpful, detailed, and polite answers to the human's questions."
            self.prompt = 'Human: {query}###Assistant:'
            self.sep = '###'
            # NOTE(review): the final '' stop word was probably '</s>' — verify.
            self.stop_words = ['###', '[UNK]', '']
            self.use_history = True
        elif self.prompt_template == 'firefly':
            r"""
            Supports: https://huggingface.co/YeungNLP/firefly-baichuan-7b-qlora-sft-merge
            """
            self.prefix = ''
            # NOTE(review): firefly's documented format wraps the query in
            # '<s>...</s>' — tokens may have been stripped here; verify.
            self.prompt = '{query}'
            self.sep = ''
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'openbuddy':
            r"""
            Supports: https://huggingface.co/OpenBuddy/openbuddy-falcon-7b-v6-bf16
            """
            self.prefix = ''
            self.prompt = 'User: {query}\nAssistant:'
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'yuyan':
            r"""
            """
            self.prefix = ''
            self.prompt = '<|Human|>:\n{query}\n\n<|Yuyan|>:\n'
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'internlm':
            r"""
            Supports: https://huggingface.co/BlinkDL/rwkv-4-raven
            """
            # NOTE(review): the Supports link above looks copy-pasted from the
            # rwkv branch; internlm's stop word was likely '<eoa>' — verify.
            self.prefix = ''
            self.prompt = '<|User|>:{query}\n<|Bot|>:'
            self.sep = '\n'
            self.stop_words = ['']
            self.use_history = True
        elif self.prompt_template == 'baichuan':
            r"""
            Supports: https://huggingface.co/baichuan-inc/Baichuan-13B-Chat
            """
            self.prefix = ''
            # NOTE(review): Baichuan chat normally uses reserved special
            # tokens around the query — possibly stripped here; verify.
            self.prompt = '{query}'
            self.sep = ''
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'baichuan2':
            r"""
            Supports: https://huggingface.co/baichuan-inc/Baichuan2-7B-Chat
            """
            self.prefix = ''
            # NOTE(review): same concern as 'baichuan' above — verify.
            self.prompt = '{query}'
            self.sep = ''
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'chatglm':
            r"""
            Supports: https://huggingface.co/THUDM/chatglm2-6b
            """
            # '{}' in the prefix is filled with the 1-based round number
            # in _format_example.
            self.prefix = '[Round {}]'
            self.prompt = '问:{query}\n\n答:'
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'chatglm3':
            r"""
            Supports: https://huggingface.co/THUDM/chatglm3-6b
            """
            self.prefix = ''
            self.prompt = '<|user|>\n{query}<|assistant|>\n'
            self.sep = ''
            self.stop_words = ['<|user|>', '<|observation|>']
            self.use_history = True
        elif self.prompt_template == 'moss':
            r"""
            Supports: https://huggingface.co/fnlp/moss-moon-003-sft
            """
            self.prefix = ''
            self.prompt = '<|Human|>: {query}\n<|MOSS|>:'
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'rwkv':
            r"""
            Supports: https://huggingface.co/BlinkDL/rwkv-4-raven
            """
            self.prefix = ''
            self.prompt = 'Bob: {query}\n\nAlice:'
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'linksoul':
            r"""
            Supports: https://huggingface.co/LinkSoul/Chinese-Llama-2-7b
            """
            # NOTE(review): Llama-2 chat uses '<<SYS>>'/'<</SYS>>' delimiters;
            # '<>' below looks like a stripped rendering — verify.
            self.prefix = "[INST] <>\nYou are a helpful, respectful and honest assistant. " \
                          "Always answer as helpfully as possible, while being safe. " \
                          "Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, " \
                          "or illegal content. Please ensure that your responses are socially unbiased and " \
                          "positive in nature. \n If a question does not make any sense, " \
                          "or is not factually coherent, " \
                          "explain why instead of answering something not correct. If you don't know the answer to " \
                          "a question, please don't share false information.\n<>\n\n"
            self.prompt = '{query}[/INST]'
            self.sep = ' [INST]'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'qwen':
            r"""
            Supports: https://huggingface.co/Qwen/Qwen-7B-Chat
            """
            self.prefix = '<|im_start|>system\nYou are a helpful assistant.'
            self.prompt = '<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
            self.sep = '<|im_end|>\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'yi':
            r"""
            Supports: https://huggingface.co/01-ai/Yi-6B-Chat
            """
            self.prefix = ''
            self.prompt = '<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant\n'
            self.sep = '<|im_end|>\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'xverse':
            r"""
            Supports: https://huggingface.co/xverse/XVERSE-13B-Chat/
            """
            self.prefix = ''
            self.prompt = 'Human: {query}\n\nAssistant: '
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'tigerbot':
            r"""
            Supports: https://huggingface.co/TigerResearch
            """
            self.prefix = ''
            self.prompt = 'Instruction:\n{query}\n\n### Response:\n'
            self.sep = '\n\n'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'flagalpha':
            r"""
            Supports: https://huggingface.co/FlagAlpha
            """
            self.prefix = ''
            self.prompt = 'Human:{query}\n<\\s>Assistant: '
            self.sep = '<\\s>'
            self.stop_words = []
            self.use_history = True
        elif self.prompt_template == 'orca':
            r"""
            Supports: https://huggingface.co/microsoft/Orca-2-7b
            """
            self.prefix = 'You are Orca, an AI language model created by Microsoft. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior.'
            self.prompt = '<|im_start|>user\n{query}<|im_end|>\n<|im_start|>assistant'
            self.sep = '\n'
            self.stop_words = []
            self.use_history = True
        else:
            raise ValueError('Template {} does not exist.'.format(self.prompt_template))

    def get_prompt(self, query, history, join_history=True):
        r"""
        Returns a string containing prompt without response.

        When ``join_history`` is False, the list of conversation segments is
        returned instead of the joined string.
        """
        format_result = self._format_example(query, history)
        return ''.join(format_result) if join_history else format_result

    def _format_example(self, query, history):
        """Assemble [prompt, response, prompt, response, ..., prompt] segments
        for the history plus the current query (response slot dropped)."""
        # add separator for non-empty prefix
        prefix = self.prefix + self.sep if self.prefix else ''
        history = history if (history and self.use_history) else []
        history = history + [(query, '')]
        conversations = []
        for turn_idx, (user_query, bot_resp) in enumerate(history):
            if self.prompt_template == 'chatglm':
                # chatglm repeats a '[Round N]' prefix before every turn.
                prompt = self.prompt.format(query=user_query)
                current_prefix = prefix.format(turn_idx + 1)
                if turn_idx + 1 > 1:
                    current_prefix = '\n\n' + current_prefix
                conversations.append(current_prefix + prompt)
                conversations.append(bot_resp)
            else:
                if turn_idx == 0:
                    conversations.append(prefix + self.prompt.format(query=user_query))
                    conversations.append(bot_resp)
                else:
                    conversations.append(self.sep + self.prompt.format(query=user_query))
                    conversations.append(bot_resp)
        # drop last
        return conversations[:-1]
class SFTTrainer(Seq2SeqTrainer):
    """Seq2Seq trainer whose prediction step pads prompt and label tensors to
    a common length, then strips the prompt portion off the generated ids."""

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None, **gen_kwargs):
        """Align input_ids/labels lengths before delegating to the parent, and
        return only the newly generated tokens (prompt removed)."""
        prompt_len, label_len = inputs['input_ids'].size(-1), inputs['labels'].size(-1)
        if prompt_len > label_len:
            inputs['labels'] = self._pad_tensors_to_target_len(inputs['labels'], inputs['input_ids'])
        if label_len > prompt_len:
            inputs['input_ids'] = self._pad_tensors_to_target_len(inputs['input_ids'], inputs['labels'])
            if 'attention_mask' in inputs:
                inputs['attention_mask'] = self._pad_tensors_to_target_len(inputs['attention_mask'], inputs['labels'])
            if 'position_ids' in inputs:
                inputs['position_ids'] = self._pad_tensors_to_target_len(inputs['position_ids'], inputs['labels'])
        loss, generated_tokens, labels = super().prediction_step(
            model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
        # Drop the prompt prefix so metrics only see the generated answer.
        generated_tokens = (generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None)
        return loss, generated_tokens, labels

    def _pad_tensors_to_target_len(self, left, right):
        """Pad `left` with the tokenizer's pad id to the shape of `right`,
        honoring the tokenizer's padding side."""
        # BUG FIX: the original wrote
        #     if pad_token_id := self.tokenizer.pad_token_id is None:
        # `:=` binds looser than `is`, so pad_token_id became the *boolean*
        # (pad_token_id is None) and tensors were silently padded with
        # False (0) instead of the real pad token id.
        if (pad_token_id := self.tokenizer.pad_token_id) is None:
            raise ValueError('Pad_token_id must be set in the configuration of the model.')
        padded_tensor = pad_token_id * torch.ones_like(right)
        if self.tokenizer.padding_side == 'left':
            padded_tensor[:, -left.shape[-1]:] = left
        else:
            padded_tensor[:, :left.shape[-1]] = left
        return padded_tensor.contiguous()


class RewardTrainer(Trainer):
    """Trainer for a pairwise reward model: each batch stacks accepted
    responses first and rejected responses second along dim 0."""

    def __init__(self, model_type, **kwargs):
        super().__init__(**kwargs)
        self.model_type = model_type
        self.can_return_loss = True

    def compute_loss(self, model, inputs, return_outputs=False):
        """Pairwise ranking loss: -log(sigmoid(r_accept - r_reject)) on the
        last-position value-head outputs."""
        # First half of the batch = accepted samples, second half = rejected.
        batch_size = int(inputs['input_ids'].size(0) / 2)
        _, _, values = model(**inputs)
        if self.model_type == 'chatglm':
            # ChatGLM value head comes out (seq, batch); make it (batch, seq).
            values = torch.transpose(values, 1, 0)
        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
        loss = -torch.nn.functional.logsigmoid(r_accept - r_reject).mean()
        outputs = {'r_accept': r_accept, 'r_reject': r_reject}
        return (loss, outputs) if return_outputs else loss

    def _save(self, output_dir: Optional[str] = None, state_dict=None):
        """Save value head, backbone and training args separately (the
        incoming state_dict is intentionally ignored in favor of the
        unwrapped model's)."""
        output_dir = self.args.output_dir if output_dir is None else output_dir
        self.model = unwrap_model(self.model)
        state_dict = self.model.state_dict()
        torch.save(state_dict, os.path.join(output_dir, 'vhead.bin'))
        self.model.pretrained_model.save_pretrained(output_dir)
        torch.save(self.args, os.path.join(output_dir, 'training_args.bin'))
class MyPPOTrainer(PPOTrainer):
    """PPOTrainer variant that knows about ChatGLM's transposed value head
    and overrides generation to manage gradient checkpointing / kv cache."""

    def __init__(self, model_type, **kwargs):
        super().__init__(**kwargs)
        # e.g. 'chatglm' — triggers the value-tensor transpose below.
        self.model_type = model_type

    @PPODecorators.empty_cuda_cache()
    def batched_forward_pass(self, model, queries, responses, model_inputs, return_logits):
        """Run the model in mini-batches and collect per-token logprobs,
        values and response masks (mask is 1 only on response tokens)."""
        bs = len(queries)
        fbs = self.config.mini_batch_size
        all_logprobs = []
        all_logits = []
        all_masks = []
        all_values = []

        for i in range(math.ceil(bs / fbs)):
            input_kwargs = {key: value[i * fbs: (i + 1) * fbs] for key, value in model_inputs.items()}
            query_batch = queries[i * fbs: (i + 1) * fbs]
            response_batch = responses[i * fbs: (i + 1) * fbs]
            logits, _, values = model(**input_kwargs)
            # ChatGLM's value head output is (seq, batch); normalize to (batch, seq).
            values = torch.transpose(values, 1, 0) if self.model_type == 'chatglm' else values

            if self.is_encoder_decoder:
                input_ids = input_kwargs["decoder_input_ids"]
                attention_mask = input_kwargs["decoder_attention_mask"]
            else:
                input_ids = input_kwargs["input_ids"]
                attention_mask = input_kwargs["attention_mask"]

            # logprob of each *next* token: logits at t predict token t+1.
            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
            masks = torch.zeros_like(attention_mask)
            masks[:, :-1] = attention_mask[:, 1:]

            for j in range(len(query_batch)):
                if self.is_encoder_decoder:
                    # Decoder sentence starts always in the index 1 after padding in the Enc-Dec Models
                    start = 1
                    end = attention_mask[j, :].sum() - 1
                else:
                    start = len(query_batch[j]) - 1
                    if attention_mask[j, 0] == 0:  # offset left padding
                        start += attention_mask[j, :].nonzero()[0]
                    end = start + len(response_batch[j])

                # Zero out everything outside the response span.
                masks[j, :start] = 0
                masks[j, end:] = 0

            if return_logits:
                all_logits.append(logits)
            else:
                # Free memory early when logits are not requested.
                del logits
            all_values.append(values)
            all_logprobs.append(logprobs)
            all_masks.append(masks)

        return (
            torch.cat(all_logprobs),
            torch.cat(all_logits)[:, :-1] if return_logits else None,
            torch.cat(all_values)[:, :-1],
            torch.cat(all_masks)[:, :-1],
        )

    def generate(
            self,
            query_tensor,
            length_sampler=None,
            batch_size=4,
            return_prompt=True,
            **generation_kwargs
    ):
        """Generate responses; a list of query tensors is dispatched to
        _generate_batched, a single tensor is generated directly."""
        if isinstance(query_tensor, List):
            return self._generate_batched(
                query_tensor,
                length_sampler=length_sampler,
                batch_size=batch_size,
                return_prompt=return_prompt,
                **generation_kwargs,
            )

        else:
            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()
            response = self.accelerator.unwrap_model(self.model).generate(
                input_ids=query_tensor.unsqueeze(dim=0), **generation_kwargs
            )

            if not return_prompt and not self.is_encoder_decoder:
                return response[:, query_tensor.shape[0]:]
            return response

    def _generate_batched(
            self,
            query_tensors,
            length_sampler=None,
            batch_size=4,
            return_prompt=True,
            pad_to_multiple_of=None,
            remove_padding=True,
            **generation_kwargs,
    ):
        """Batch queries, left-pad them, generate with kv-cache enabled
        (gradient checkpointing temporarily off), then strip padding/prompt."""
        outputs = []

        padding_side_default = self.tokenizer.padding_side
        if not self.is_encoder_decoder:
            # Decoder-only models must be left-padded for batched generation.
            self.tokenizer.padding_side = "left"

        # in case we have fewer examples than bs
        batch_size = min(len(query_tensors), batch_size)

        for i in range(0, len(query_tensors), batch_size):
            if length_sampler is not None:
                generation_kwargs["max_new_tokens"] = length_sampler()

            # prevent overflow if query tensors are not even multiple of bs
            end_index = min(len(query_tensors), i + batch_size)

            batch = query_tensors[i:end_index]
            batch_mask = [torch.ones_like(element) for element in batch]
            inputs = {"input_ids": batch, "attention_mask": batch_mask}

            padded_inputs = self.tokenizer.pad(
                inputs,
                padding=True,
                max_length=None,
                pad_to_multiple_of=pad_to_multiple_of,
                return_tensors="pt",
            ).to(self.current_device)
            unwrapped_model = self.accelerator.unwrap_model(self.model)
            # Generation needs the kv cache and is incompatible with
            # gradient checkpointing; restore both afterwards for training.
            unwrapped_model.gradient_checkpointing_disable()
            unwrapped_model.config.use_cache = True
            generations = unwrapped_model.generate(**padded_inputs, **generation_kwargs)
            unwrapped_model.gradient_checkpointing_enable()
            unwrapped_model.config.use_cache = False
            for generation, mask in zip(generations, padded_inputs["attention_mask"]):
                if not self.is_encoder_decoder:
                    output = generation[(1 - mask).sum():]  # remove padding
                else:
                    output = generation

                if not return_prompt and not self.is_encoder_decoder:
                    output = output[mask.sum():]  # remove prompt

                if remove_padding and self.tokenizer.eos_token_id in output:
                    pad_mask = output == self.tokenizer.eos_token_id
                    pad_start = torch.nonzero(pad_mask, as_tuple=False)[0, 0].item()
                    output = output[: pad_start + 1]  # keep the eos token at the end

                outputs.append(output)

        self.tokenizer.padding_side = padding_side_default
        return outputs
if __name__ == '__main__':

    # Parse configuration and route to the requested run mode; every mode
    # logs to ./logs/<mode>.log.
    config = Configure()
    mode = config.mode
    log_name = './logs/' + mode + '.log'
    logger.add(log_name, encoding='utf-8')

    data_manager = DataManager(config, logger)
    if mode == 'pretrain':
        # Model pre-training
        train = Train(data_manager, config, logger)
        train.pretrain()
    elif mode == 'sft_train':
        # Supervised instruction fine-tuning
        train = Train(data_manager, config, logger)
        train.supervised_fine_tuning()
    elif mode == 'rm_train':
        # Reward model training
        train = Train(data_manager, config, logger)
        train.train_reward_model()
    elif mode == 'dpo_train':
        # RLHF: reinforcement training with the DPO method
        train = Train(data_manager, config, logger)
        train.train_dpo()
    elif mode == 'ppo_train':
        # RLHF: reinforcement training with the PPO method
        train = Train(data_manager, config, logger)
        train.train_ppo()
    elif mode == 'sft_batch_test':
        # Batch evaluation of the fine-tuned model
        train = Train(data_manager, config, logger)
        train.supervised_fine_tuning(test=True)
    elif mode == 'rm_batch_test':
        # Batch evaluation of the reward model
        train = Train(data_manager, config, logger)
        train.train_reward_model(test=True)
    elif mode == 'web_inference':
        # Interactive inference through the web UI
        predict = Predictor(data_manager, config, logger)
        predict.web_inference()
    elif mode == 'terminal_inference':
        # Interactive inference in the terminal
        predict = Predictor(data_manager, config, logger)
        predict.terminal_inference()
    elif mode == 'merge_peft_model':
        # Merge LoRA/PEFT weights into the base model
        model = BaseModels(data_manager, config, logger)
        model.merge_lora_model()
    elif mode == 'show_model_info':
        # Print the model parameter summary
        model = BaseModels(data_manager, config, logger)
        model.show_model_info()
    elif mode == 'save_quantized_model':
        # Save a quantized copy of the model
        if config.model_args.quantization_bit not in (4, 8):
            raise ValueError('Quantization bit not set.')
        model = BaseModels(data_manager, config, logger)
        model.save_quantized_model()
    elif mode == 'expand_vocab':
        # Expand the tokenizer vocabulary
        model = BaseModels(data_manager, config, logger)
        model.expand_vocab()