├── requirements.txt
├── pics
│   ├── ZhiLu_CEVAL_20231029.png
│   ├── ZhiLu_CMMLU_20231029.png
│   └── ZhiLu_CMMLU_zs_20231029.png
├── scripts
│   ├── merge.sh
│   └── merge_llama_with_chinese_lora_low_mem.py
├── FAQ.md
├── .gitignore
├── LICENSE
└── README.md
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
peft>=0.3.0
torch==2.0.1
transformers==4.31.0
sentencepiece==0.1.97
bitsandbytes==0.41.0
--------------------------------------------------------------------------------
/pics/ZhiLu_CEVAL_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CEVAL_20231029.png
--------------------------------------------------------------------------------
/pics/ZhiLu_CMMLU_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CMMLU_20231029.png
--------------------------------------------------------------------------------
/pics/ZhiLu_CMMLU_zs_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CMMLU_zs_20231029.png
--------------------------------------------------------------------------------
/scripts/merge.sh:
--------------------------------------------------------------------------------
python merge_llama_with_chinese_lora_low_mem.py \
    --base_model path/to/base_model \
    --lora_model path/to/lora_model \
    --output_type huggingface \
    --output_dir path/to/output_dir
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
### Is LoRA training applied during pre-training or during instruction fine-tuning?

This project uses LoRA for efficient training in both the pre-training and the instruction fine-tuning stages.

### Why train the model with LoRA rather than with full-parameter training?

Considering training cost and efficiency, we chose to train on top of Alpaca-2 with LoRA (with the embedding and lm_head layers trained in full).

### Why train on top of Alpaca-2 rather than Llama-2?

Although Llama-2 already has some ability to understand Chinese, it still mixes English into generated Chinese text, and its Chinese vocabulary is limited, so its Chinese capability needs further extension. Alpaca-2 has already done excellent work in exactly this area.
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
160 | #.idea/ 161 | 162 | # 忽略.idea目录 163 | .idea/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 
61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | 203 | 204 | 205 | ## Some of ZhiLu's code is derived from other projects, which is subject to the following copyright notice: 206 | 207 | Copyright 2023 Yiming Cui, Ziqing Yang, Xin Yao 208 | 209 | Licensed under the Apache License, Version 2.0 (the "License"); 210 | you may not use this file except in compliance with the License. 211 | You may obtain a copy of the License at 212 | 213 | http://www.apache.org/licenses/LICENSE-2.0 214 | 215 | Unless required by applicable law or agreed to in writing, software 216 | distributed under the License is distributed on an "AS IS" BASIS, 217 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 218 | See the License for the specific language governing permissions and 219 | limitations under the License. -------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------

🤗 [Model Download](https://huggingface.co/SYSU-MUCFC-FinTech-Research-Center/ZhiLu-13B-Instruct)

# News
[2023.10.28] Released the ZhiLu-13B conversational LLM. 🤗 [HuggingFace](https://huggingface.co/SYSU-MUCFC-FinTech-Research-Center/ZhiLu-13B-Instruct)

# Table of Contents

- [ZhiLu](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E6%99%BA%E9%B9%BF)
- [Introduction](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E4%BB%8B%E7%BB%8D)
- [Training Details](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82)
- [Evaluation](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E6%80%A7%E8%83%BD%E8%AF%84%E6%B5%8B)
- [Dialogue Examples](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E5%AF%B9%E8%AF%9D%E7%A4%BA%E4%BE%8B)
- [Usage](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F)
- [Roadmap](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#Roadmap)
- [Chat with a Prompt Template](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E5%B8%A6PROMPT%E6%A8%A1%E6%9D%BF%E7%9A%84%E5%AF%B9%E8%AF%9D)

# ZhiLu (智鹿)

## Introduction
ZhiLu is a Chinese consumer-finance conversational LLM obtained by further training [Alpaca2-13B](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2). We performed incremental pre-training on a large Chinese and English corpus, aligned the model with high-quality instruction data, and incorporated human-value-alignment data and training tricks. The training goal is to substantially strengthen financial-domain capability while preserving general-purpose ability.

## Training Details

- We collected high-quality data of many kinds, including listed-company announcements, financial news, annual reports of listed companies, general news, financial information, community Q&A, and Wikipedia.

- Training covered `14.69B` tokens in total, with an approximately 2:1 ratio of general-domain to financial corpus and an approximately 2:1 ratio of Chinese to English.

- We used LoRA for efficient training (with the embedding and lm_head layers trained in full). Some of the hyperparameters are listed below; a hedged `peft` configuration sketch follows this list.

```sh
lr=2e-4
lora_rank=64
lora_alpha=128
lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
modules_to_save="embed_tokens,lm_head"
lora_dropout=0.05
per_device_train_batch_size=64
per_device_eval_batch_size=64
gradient_accumulation_steps=1
```

- We used FlashAttention-2 to accelerate training.
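The following is a minimal sketch of how the hyperparameters above could map onto a `peft` LoRA configuration. It is illustrative only: the base-model path and the surrounding training loop are assumptions, not this project's actual training code.

```python
# Minimal sketch: the hyperparameters above expressed as a peft LoraConfig.
# Illustrative only; the real training script is not part of this repository.
import torch
from transformers import LlamaForCausalLM
from peft import LoraConfig, TaskType, get_peft_model

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,                                   # lora_rank
    lora_alpha=128,                         # lora_alpha
    lora_dropout=0.05,                      # lora_dropout
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj"],    # lora_trainable
    modules_to_save=["embed_tokens", "lm_head"],             # trained in full
)

base_model = LlamaForCausalLM.from_pretrained(
    "path/to/alpaca-2-13b",                 # hypothetical path
    torch_dtype=torch.bfloat16,
)
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()  # LoRA matrices plus embed_tokens / lm_head
```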

## Evaluation

We evaluate the model in two areas: the general domain and the financial domain.

For general ability we use two mainstream objective benchmarks:

- [C-Eval](https://cevalbenchmark.com/index.html#home) is a comprehensive Chinese evaluation benchmark for foundation models. It contains 13,948 multiple-choice questions spanning 52 subjects and four difficulty levels.
- [CMMLU](https://github.com/haonan-li/CMMLU) is a comprehensive Chinese benchmark designed to evaluate a language model's knowledge and reasoning in a Chinese-language context. CMMLU covers 67 topics, from elementary subjects to advanced professional levels.

Benchmark for financial ability:

- [FinEval](https://github.com/SUFE-AIFLM-Lab/FinEval/) is a benchmark designed specifically for financial-domain knowledge in LLMs, covering finance, economics, accounting, and certification. It includes 4,661 questions across 34 academic subjects. (Note: because FinEval has not released the answers to its test set, nor opened an evaluation API, our FinEval results are reported on the validation set.)

Comparison with the base model Alpaca2-13B:

| Model | C-Eval | CMMLU | FinEval |
| ------ | ------ | ------ | ------ |
| ZhiLu | 50.1 | 57.16 | 58.69 |
| Alpaca2-13B | 43.6 | 44.48 | 45.26 |

- On the general-domain benchmarks, ZhiLu retains the base model's general ability relative to Alpaca2-13B, with a slightly higher C-Eval score and a markedly higher CMMLU score.
- On the financial benchmark FinEval, ZhiLu's performance improves significantly.

C-Eval leaderboard (as of 2023-10-29):

![ZhiLu C-Eval results](pics/ZhiLu_CEVAL_20231029.png)

CMMLU leaderboard (few-shot and zero-shot, as of 2023-10-29):

![ZhiLu CMMLU few-shot results](pics/ZhiLu_CMMLU_20231029.png)
![ZhiLu CMMLU zero-shot results](pics/ZhiLu_CMMLU_zs_20231029.png)

## Dialogue Examples

The transcripts below are the model's verbatim Chinese outputs and are left untranslated.

### Finance

```
问题:怎么炒股才能稳赚不赔?
智鹿:股市投资是有风险的,无法保证稳赚不赔。以下是一些建议来帮助您在股票市场中实现更稳定的投资回报:1. 分散投资组合:将资金分配到不同的行业、公司和资产类别上,以降低特定投资的风险。这样即使某些投资表现不佳,其他投资也可能弥补损失。2. 长期投资策略:选择具备稳定增长潜力的优质企业,并持有它们的股票较长时间。长期投资可以减少短期市场波动对投资的影响,并获得潜在的资本增值和股息收入。3. 定期重新评估投资计划:根据市场情况和个人目标,定期检查和调整投资组合。这可以帮助您及时做出必要的调整,以适应市场的变化并保持符合您的目标。4. 控制情绪:避免受到市场噪音和情绪影响的投资决策。不要盲目跟风或恐慌出局,而是根据自己的研究和分析制定合理的投资计划。5. 学习和继续教育:了解基本的投资原理和技术分析工具,不断提升自己的知识水平。这有助于更好地理解市场趋势和投资机会,从而作出明智的决策。请注意,以上建议仅供参考,并不能保证投资成功。在进行任何投资之前,请务必咨询专业金融顾问的意见,以便为您的具体情况提供个性化的建议。

问题:什么是A股?
智鹿:简单来说,就是中国内地的股票市场。它包括上海证券交易所和深圳证券交易所。A股市场起源于1990年,是中国股票市场的主要组成部分。A股市场的特点:1. 开放性:A股市场对外开放程度较高,允许国内外投资者买卖股票。2. 风险较高:与发达国家相比,A股市场波动较大,投资风险相对较高。3. 监管严格:中国证监会负责监督和管理A股市场,并采取措施保护投资者利益。4. 长期投资机会多:尽管存在风险,但A股市场也提供了丰富的长期投资机会。对于投资者而言,了解A股市场的运作机制、相关规则和最新动态非常重要。这有助于他们做出明智的投资决策,并在A股市场中获取收益。
```

### General Domain

```
问题:潮汕地区有什么特产?
智鹿:1.牛肉:潮汕牛肉以肉质细嫩、多汁,且不腥臊腥著称。2.虾螺贝类:汕头地区的海产品丰富,包括各种海鲜如鲍鱼、海参、石斑鱼等。3.柑橘:潮州的柑橘以其独特的清香和丰富的营养价值而闻名。4.姜撞子:是潮州特色小吃之一,将猪肉、鸡蛋、木姜黄等馅料包裹在糯米粉里制成。5.牛肉炒米:用牛肉和糯米制作而成,口感软糯,深受人们喜爱。6.沙士:是一种传统的甜点,由糯米粉制成,并加入花生碎、芝麻等配料。7.牛肉丸:潮汕地区的牛肉丸因其Q弹口感而著名,通常配上辣椒酱或特制的汤汁食用。8.大龙粿:一种传统潮州糕点,采用粘糯米制成,内陷有豆沙或红绿丝等。9.牛肉火锅:潮汕地区有许多著名的牛肉火锅店,其特点是使用高质量的牛肉和独特的调料。10.冬至面:在冬至日,许多家庭会准备热气腾腾的冬至面,其中包含着多种肉类和蔬菜。以上只是一些常见的潮汕美食和特产,还有更多值得探索的地方和美食等待你去发现。

问题:晚上睡不着怎么办?
智鹿:1. 做点轻松的事,如阅读、写字等。2. 听音乐,放松身心。3. 数羊或数数,转移注意力。4. 洗个澡,让自己放松下来。5. 出去走走,呼吸新鲜空气。6. 与朋友聊天,分享烦恼并寻求帮助。7. 写日记,记录今天发生的事情和感受。8. 运动运动,释放压力和紧张情绪。9. 冥想或练习深呼吸,平静自己的心神。10. 看喜剧片,暂时忘记忧愁。

问题:介绍一下中国象棋的规则。
智鹿:1. 走法:双方各自有15块棋子,包括王、车、兵、相、马和炮。每种单位都有特定的走法,例如,王只能前进一步或退后到安全的位置,而其他单位则可以根据具体情况选择不同的走法。2. 控制区域:每个玩家拥有自己的半盘,称为“阵营”。在战斗中,可以将对方的单位推回自己阵营之外的区域内,从而控制该区域。3. 攻击方式:不同类型的单位具有不同的攻击能力。例如,车可以向前推进一个空格,同时可以把敌人的单位移开;兵则可以从初始位置出发,沿着直线向对手发起进攻。4. 防御策略:为了保护己方单位免受攻击,玩家可以通过调整布局来制定有效的防御策略。这可能涉及到隐藏单位、利用障碍物(如山石)以及进行交换等手段。5. 中国象棋的历史和文化价值:中国象棋起源于中国古代,有着悠久的历史。它不仅作为一种智力游戏受到人们喜爱,还蕴含着丰富的哲学思想和道德修养。通过学习和参与中国象棋比赛,人们可以增强团队合作意识、培养决断力和战略思维能力。
```

## Usage

We have published ZhiLu's model weights on HuggingFace. Since our base model is Alpaca-2, and Alpaca-2 is based on Llama-2, the model can be called the same way as a Llama2-hf model, as shown below.

We provide two ways to load the model: `loading the merged model directly` and `loading the base model and the LoRA module separately`.

Tips:

1. Running ZhiLu requires roughly 30 GB of GPU memory; a hedged sketch of 8-bit loading, which can reduce this footprint, follows these tips.

2. Try to end your prompt with punctuation; otherwise the model may simply continue writing the prompt.
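If 30 GB of GPU memory is not available, quantized loading can reduce the footprint at some cost in output quality. Below is a minimal sketch of 8-bit loading via `bitsandbytes` (already pinned in `requirements.txt`), assuming `accelerate` is also installed; the model path is a placeholder, and this is our suggestion rather than a configuration shipped by the project.

```python
# Hedged sketch: load the merged model in 8-bit to reduce GPU memory.
# Assumes bitsandbytes==0.41.0 from requirements.txt; the path is a placeholder.
from transformers import LlamaForCausalLM, LlamaTokenizer

model_name_or_path = "path/to/merged_zhilu"  # hypothetical
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
model = LlamaForCausalLM.from_pretrained(
    model_name_or_path,
    load_in_8bit=True,   # weights are quantized to int8 at load time
    device_map="auto",
)
```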
### Environment Setup
Install the dependencies with pip:
```
pip install -r requirements.txt
```
A `transformers` version of `4.30` or later is recommended, and `torch` 2.0 or later is recommended for the best inference performance.

### Using the Merged Model

Alpaca-2-13b must first be merged with the LoRA module we provide; the merge script is `scripts/merge.sh`.

At inference time, only the path of the merged model needs to be provided:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_name_or_path = ""  # path to the merged model
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
inputs = tokenizer("什么是A股?", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
print(outputs)
```

### Loading the LoRA Module Separately

Provide the paths of both Alpaca-2-13b and the LoRA module:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
from peft import PeftModel

model_name_or_path = ""  # path to Alpaca-2-13b
peft_model_path = ""     # path to the LoRA module
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
if peft_model_path is not None:
    model = PeftModel.from_pretrained(
        model,
        peft_model_path,
        torch_dtype=(
            torch.bfloat16
            if torch.cuda.is_bf16_supported()
            else torch.float32
        ),
    )
inputs = tokenizer("什么是A股?", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
print(outputs)
```

## Chat with a Prompt Template

### Single-Turn Chat

```python
import os
import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    GenerationConfig
)
from peft import PeftModel

model_name_or_path = ""  # path to Alpaca-2-13b
peft_model_path = ""     # path to the LoRA module
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
if peft_model_path is not None:
    model = PeftModel.from_pretrained(
        model,
        peft_model_path,
        torch_dtype=(
            torch.bfloat16
            if torch.cuda.is_bf16_supported()
            else torch.float32
        ),
    )
print(model)
generation_config = GenerationConfig.from_pretrained(model_name_or_path)
config_updates = {
    'do_sample': True,
    'num_beams': 1,
    'repetition_penalty': 1.1,
    'min_new_tokens': 20,
    'top_k': 50,
    'top_p': 0.95,
    'temperature': 0.7
}
# Apply the generation settings above
for key, value in config_updates.items():
    setattr(generation_config, key, value)
prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "{instruction} [/INST]"
)
instruction = "什么是A股?"
prompt = prompt_template.format_map({'instruction': instruction})
input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
output = model.generate(inputs=input_ids, generation_config=generation_config, return_dict_in_generate=True, num_return_sequences=1)
output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
print(output_text)

'''
[INST] <<SYS>>
You are a helpful assistant. 你是一个乐于助人的助手。
<</SYS>>

什么是A股? [/INST] A股指的是中国内地的股票市场,也被称为上海证券交易所和深圳证券交易所的主板股票市场。它是中国证券市场的主要部分,包括了许多大型上市公司。与此同时,A股市场也包括了中小型企业和创业公司的股票交易。A股市场的特点是具有较高的波动性,投资者可以通过购买股票来参与市场的买卖活动,并从中获得资本收益或分红收益。在A股市场上,股价会受到多种因素的影响,例如公司业绩、宏观经济状况以及政府政策等。因此,了解A股市场的基本概念和运作方式对于投资决策非常重要。
'''
```
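For interactive use, it is often nicer to stream tokens as they are generated instead of waiting for the whole sequence. The sketch below uses transformers' built-in `TextStreamer` and reuses the `model`, `tokenizer`, `generation_config`, and `prompt` objects from the single-turn example above; it is our illustration, not part of the original project code.

```python
# Hedged sketch: stream tokens to stdout while generating.
# Reuses model / tokenizer / generation_config / prompt from the example above.
from transformers import TextStreamer

streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
model.generate(inputs=input_ids, generation_config=generation_config, streamer=streamer)
```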
### Multi-Turn Chat

```python
import os
import torch
from transformers import (
    LlamaForCausalLM,
    LlamaTokenizer,
    GenerationConfig
)
from peft import PeftModel

model_name_or_path = ""  # path to Alpaca-2-13b
peft_model_path = ""     # path to the LoRA module
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16, device_map="auto")
if peft_model_path is not None:
    model = PeftModel.from_pretrained(
        model,
        peft_model_path,
        torch_dtype=(
            torch.bfloat16
            if torch.cuda.is_bf16_supported()
            else torch.float32
        ),
    )
print(model)
generation_config = GenerationConfig.from_pretrained(model_name_or_path)
config_updates = {
    'do_sample': True,
    'num_beams': 1,
    'repetition_penalty': 1.1,
    'min_new_tokens': 20,
    'top_k': 50,
    'top_p': 0.95,
    'temperature': 0.7
}
# Apply the generation settings above
for key, value in config_updates.items():
    setattr(generation_config, key, value)

def multi_generate(history, instruction):
    # Concatenate the accumulated history with the new user turn
    history_str = "".join(history)
    prompt = history_str + instruction
    input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
    output = model.generate(inputs=input_ids, generation_config=generation_config, return_dict_in_generate=True, num_return_sequences=1)
    output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
    history.append(output_text.replace(history_str, ""))
    answer = output_text.replace(history_str, "")
    answer = answer.replace(instruction, "")
    # answer now has the context and the user's input stripped out
    return answer

history = []
cnt = 1
prompt_template = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "{instruction} [/INST]"
)
print(f"-------------round {cnt} --------------")
cnt += 1
user_input = input("user:")
user_input = prompt_template.format_map({'instruction': user_input})
answer = multi_generate(history, user_input)
print(f"ZhiLu:{answer}")
while True:
    print(f"-------------round {cnt} --------------")
    cnt += 1
    user_input = input("user:")
    user_input = f"[INST] {user_input} [/INST]"
    answer = multi_generate(history, user_input)
    print(f"ZhiLu:{answer}")

'''
-------------round 1 --------------
user:什么是A股?
ZhiLu:A股市场是中国股票市场的一个重要组成部分,指代在上海证券交易所和深圳证券交易所上市的股票。它是中国内地投资者进行股票交易的主要场所。A股市场的特点如下:- A股市场是中国内地股票市场的主体,也是全球最大的股票市场之一。- A股市场的股票只能在中国境内交易,且只有中国公民和机构可以参与该市场。- A股市场的交易时间一般为周一至周五上午9:30至下午3:00,中午有一个短暂的休市时间。- A股市场实行T+1制度,即当日买入的股票次日才能卖出,有助于减少短期投机行为。- A股市场的股票价格通常会受到多种因素的影响,包括经济状况、政策变化、行业动态和公司业绩等。总而言之,A股市场是中国内地股票市场,提供了广泛的投资机会和灵活的交易机制,但也存在相应的风险需要投资者谨慎对待。
-------------round 2 --------------
user:在哪可以看A股?
ZhiLu:可以通过以下几种方式查看A股市场的数据和行情:1. 股票交易软件:许多券商都提供自家的手机应用或网页版交易软件,你可以在这些平台上查询股票的价格、成交量、涨跌幅等相关指标。2. 财经网站:像新浪财经、东方财富网等财经网站都可以提供实时的股票行情、新闻资讯和研究报告等信息。3. 电视台财经频道:观看财经类电视台如CCTV证券、上海证券报等,他们会定期报道股市动态、新闻报道和专家观点等。无论使用哪种方法,记得始终注意信息来源的可靠性和准确性,同时也要注意市场风险和自身投资能力。
'''
```
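Note that `multi_generate` concatenates the entire history on every turn, and the base model natively supports a 4k context (see the Roadmap), so a long conversation will eventually overflow the context window. Below is a minimal sketch of one way to cap the history by token count; `MAX_HISTORY_TOKENS` is an illustrative value, not a project setting.

```python
# Hedged sketch: drop the oldest turns once the tokenized history exceeds a budget.
# MAX_HISTORY_TOKENS is illustrative; pick it below the model's context length.
MAX_HISTORY_TOKENS = 3000

def trim_history(history, tokenizer, budget=MAX_HISTORY_TOKENS):
    # Remove whole turns from the front until the remaining history fits.
    while history and len(tokenizer.encode("".join(history))) > budget:
        history.pop(0)
    return history
```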
# Roadmap

- [x] Base model type
  - [x] Chat-style
  - [ ] Instruct-style
- [x] Data augmentation
  - [x] General-domain knowledge
  - [ ] Consumer-finance knowledge enhancement
- [x] Training modules
  - [x] Standard LoRA modules
  - [ ] Pre-training & instruction fine-tuning
  - [ ] Human-preference feedback (RLHF, DPO, PRO, etc.)
- [x] Length extrapolation
  - [x] Native 4k context support
  - [ ] 8k, 16k, and longer-context versions
- [x] Performance optimization
  - [x] FlashAttention-2 support
  - [ ] Acceleration modules such as vLLM and Streaming-LLM

# FAQ

For project Q&A and solutions to problems encountered during deployment, see the [FAQ](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/blob/main/FAQ.md).

# Acknowledgements

This project is a secondary development based mainly on the open-source Alpaca-2 project. We thank the related projects and their research and development staff.

# Disclaimer and License

The code in this repository is open-sourced under the [Apache-2.0](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/blob/main/LICENSE) license. The model weights are fully open for academic research; users may also use them commercially free of charge after applying via the [application form](https://wj.qq.com/s2/13390238/fea9/) and being granted a commercial-use certificate.
Although we did our best to ensure data compliance and accuracy during model training, the model is subject to probabilistic randomness and can be misled, so the accuracy of its output cannot be guaranteed. Users should therefore carefully review the model's output, make independent judgments, consult professionals where necessary, and bear the risks of use themselves. Users must not use this model for any purpose that could harm the nation or society, nor for any service that has not undergone safety assessment and regulatory filing. We assume no responsibility for security risks, intellectual-property risks, or public-opinion risks caused by the open-source model and its generated content, nor for any risk or liability arising from the model being misled, abused, misused, or improperly disseminated.

# Summary

We encourage users to cite ZhiLu in related work to promote knowledge sharing and exchange, and to contribute to the continued development of Chinese financial dialogue systems.
ZhiLu is released to provide strong support for applications and research in the financial domain and to contribute to the progress of Chinese financial dialogue systems. We look forward to seeing more innovations and applications that improve financial services and user experience, while advancing the growth of AI in finance. Through cooperation and sharing, we can jointly push this field forward and bring more benefits to society and the industry.
--------------------------------------------------------------------------------
/scripts/merge_llama_with_chinese_lora_low_mem.py:
--------------------------------------------------------------------------------
"""
Usage:
python merge_llama_with_chinese_lora_low_mem.py \
    --base_model /mnt/data2/finLLM/models/alpaca-2-13b \
    --lora_model /mnt/data2/finLLM/llama-alpaca-2-main/pt_lora_model \
    --output_type huggingface \
    --output_dir /mnt/data2/finLLM/test
"""
import argparse
import json
import os
import gc
import torch
import peft
from transformers import LlamaTokenizer
from transformers.modeling_utils import dtype_byte_size
from huggingface_hub import snapshot_download
import re

parser = argparse.ArgumentParser()
parser.add_argument('--base_model', default=None, required=True,
                    type=str, help="Please specify a base model")
parser.add_argument('--lora_model', default=None, required=True,
                    type=str, help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models")
parser.add_argument('--output_type', default='pth', choices=['pth', 'huggingface'],
                    type=str, help="Save the merged model in pth or huggingface format")
parser.add_argument('--output_dir', default='./merged_model',
                    type=str, help="The output folder to save the merged model")
parser.add_argument('--verbose', default=False, action='store_true',
                    help="Show detailed messages")

# Map embedding width to LLaMA model size (used to infer 7B/13B/33B/65B)
emb_to_model_size = {
    4096: '7B',
    5120: '13B',
    6656: '33B',
    8192: '65B',
}
num_shards_of_models = {'7B': 1, '13B': 2, '33B': 4, '65B': 8}
params_of_models = {
    '7B': {
        "dim": 4096,
        "multiple_of": 256,
        "n_heads": 32,
        "n_layers": 32,
        "norm_eps": 1e-06,
        "vocab_size": -1,
    },
    '13B': {
        "dim": 5120,
        "multiple_of": 256,
        "n_heads": 40,
        "n_layers": 40,
        "norm_eps": 1e-06,
        "vocab_size": -1,
    },
    '33B': {
        "dim": 6656,
        "multiple_of": 256,
        "n_heads": 52,
        "n_layers": 60,
        "norm_eps": 1e-06,
        "vocab_size": -1,
    },
    '65B': {
        "dim": 8192,
        "multiple_of": 256,
        "n_heads": 64,
        "n_layers": 80,
        "norm_eps": 1e-05,
        "vocab_size": -1,
    },
}

def transpose(weight, fan_in_fan_out):
    return weight.T if fan_in_fan_out else weight

# Borrowed and modified from https://github.com/tloen/alpaca-lora
def translate_state_dict_key(k):
    k = k.replace("base_model.model.", "")
    if k == "model.embed_tokens.weight":
        return "tok_embeddings.weight"
    elif k == "model.norm.weight":
        return "norm.weight"
    elif k == "lm_head.weight":
        return "output.weight"
    elif k.startswith("model.layers."):
        layer = k.split(".")[2]
        if k.endswith(".self_attn.q_proj.weight"):
            return f"layers.{layer}.attention.wq.weight"
        elif k.endswith(".self_attn.k_proj.weight"):
            return f"layers.{layer}.attention.wk.weight"
        elif k.endswith(".self_attn.v_proj.weight"):
            return f"layers.{layer}.attention.wv.weight"
        elif k.endswith(".self_attn.o_proj.weight"):
            return f"layers.{layer}.attention.wo.weight"
        elif k.endswith(".mlp.gate_proj.weight"):
            return f"layers.{layer}.feed_forward.w1.weight"
        elif k.endswith(".mlp.down_proj.weight"):
            return f"layers.{layer}.feed_forward.w2.weight"
        elif k.endswith(".mlp.up_proj.weight"):
            return f"layers.{layer}.feed_forward.w3.weight"
        elif k.endswith(".input_layernorm.weight"):
            return f"layers.{layer}.attention_norm.weight"
        elif k.endswith(".post_attention_layernorm.weight"):
            return f"layers.{layer}.ffn_norm.weight"
        elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
            return None
        else:
            print(layer, k)
            raise NotImplementedError
    else:
        print(k)
        raise NotImplementedError


def unpermute(w):
    # Undo the rotary-embedding permutation that HF applies to wq/wk;
    # relies on the module-level n_heads and dim set in the main block
    return (
        w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
    )


def save_shards(model_sd, num_shards: int, prefix="", verbose=False):
    """
    Convert and save the HF format weights to PTH format weights
    """
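    # Sharding scheme for the PTH (Meta) layout:
    #   - tok_embeddings / wo / w2 are split along dim 1 (the input dimension);
    #   - output / wq / wk / wv / w1 / w3 are split along dim 0;
    #   - norm / attention_norm / ffn_norm weights are replicated to every shard;
    #   - wq / wk are un-permuted first to undo the HF rotary-embedding layout.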
    with torch.no_grad():
        if num_shards == 1:
            new_state_dict = {}
            for k, v in model_sd.items():
                new_k = translate_state_dict_key(k)
                if new_k is not None:
                    if "wq" in new_k or "wk" in new_k:
                        new_state_dict[new_k] = unpermute(v)
                    else:
                        new_state_dict[new_k] = v

            os.makedirs(output_dir, exist_ok=True)
            print(f"Saving shard 1 of {num_shards} into {output_dir}/{prefix}consolidated.00.pth")
            torch.save(new_state_dict, output_dir + f"/{prefix}consolidated.00.pth")
        else:
            new_state_dicts = [dict() for _ in range(num_shards)]
            for k in list(model_sd.keys()):
                v = model_sd[k]
                new_k = translate_state_dict_key(k)
                if new_k is not None:
                    if new_k == 'tok_embeddings.weight':
                        assert v.size(1) % num_shards == 0
                        splits = v.split(v.size(1) // num_shards, dim=1)
                    elif new_k == 'output.weight':
                        if v.size(0) % num_shards == 0:
                            splits = v.split(v.size(0) // num_shards, dim=0)
                        else:
                            size_list = [v.size(0) // num_shards] * num_shards
                            size_list[-1] += v.size(0) % num_shards
                            splits = v.split(size_list, dim=0)  # 13B: size_list == [24976, 24977]
                    elif new_k == 'norm.weight':
                        splits = [v] * num_shards
                    elif 'ffn_norm.weight' in new_k:
                        splits = [v] * num_shards
                    elif 'attention_norm.weight' in new_k:
                        splits = [v] * num_shards
                    elif 'w1.weight' in new_k:
                        splits = v.split(v.size(0) // num_shards, dim=0)
                    elif 'w2.weight' in new_k:
                        splits = v.split(v.size(1) // num_shards, dim=1)
                    elif 'w3.weight' in new_k:
                        splits = v.split(v.size(0) // num_shards, dim=0)
                    elif 'wo.weight' in new_k:
                        splits = v.split(v.size(1) // num_shards, dim=1)
                    elif 'wv.weight' in new_k:
                        splits = v.split(v.size(0) // num_shards, dim=0)
                    elif "wq.weight" in new_k or "wk.weight" in new_k:
                        v = unpermute(v)
                        splits = v.split(v.size(0) // num_shards, dim=0)
                    else:
                        print(f"Unexpected key {new_k}")
                        raise ValueError
                    if verbose:
                        print(f"Processing {new_k}")
                    for sd, split in zip(new_state_dicts, splits):
                        sd[new_k] = split.clone()
                        del split
                    del splits
                del model_sd[k], v
                gc.collect()  # Effectively enforce garbage collection

            os.makedirs(output_dir, exist_ok=True)
            for i, new_state_dict in enumerate(new_state_dicts):
                print(f"Saving shard {i+1} of {num_shards} into {output_dir}/{prefix}consolidated.0{i}.pth")
                torch.save(new_state_dict, output_dir + f"/{prefix}consolidated.0{i}.pth")

def merge_shards(output_dir, num_shards: int):
    ckpt_filenames = sorted([f for f in os.listdir(output_dir) if re.match(r'L(\d+)-consolidated\.(\d+)\.pth', f)])

    for i in range(num_shards):
        shards_filenames = sorted([f for f in ckpt_filenames if re.match(rf'L(\d+)-consolidated\.0{i}\.pth', f)])
        print(f"Loading {shards_filenames} ...")
        shards_dicts = [torch.load(os.path.join(output_dir, fn)) for fn in shards_filenames]
        shards_merged = {}
        for d in shards_dicts:
            shards_merged |= d

        print("Saving the merged shard to " + os.path.join(output_dir, f"consolidated.0{i}.pth"))
        torch.save(shards_merged, os.path.join(output_dir, f"consolidated.0{i}.pth"))

        print("Cleaning up...")
        del shards_merged
        for d in shards_dicts:
            del d
        del shards_dicts
        gc.collect()  # Effectively enforce garbage collection
        for fn in shards_filenames:
            os.remove(os.path.join(output_dir, fn))

if __name__ == '__main__':

    args = parser.parse_args()
    base_model_path = args.base_model
    lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip()) != 0]
    output_dir = args.output_dir
    output_type = args.output_type
    os.makedirs(output_dir, exist_ok=True)

    print(f"Base model: {base_model_path}")
    print(f"LoRA model(s) {lora_model_paths}:")

    tokenizers_and_loras = []
    for lora_model_path in lora_model_paths:
        print(f"Loading {lora_model_path}")
        if not os.path.exists(lora_model_path):
            print("Cannot find lora model on the disk. Downloading lora model from hub...")
            lora_model_path = snapshot_download(repo_id=lora_model_path)
        tokenizer = LlamaTokenizer.from_pretrained(lora_model_path)
        lora_config = peft.LoraConfig.from_pretrained(lora_model_path)
        lora_state_dict = torch.load(os.path.join(lora_model_path, 'adapter_model.bin'), map_location='cpu')
        if 'base_model.model.model.embed_tokens.weight' in lora_state_dict:
            lora_vocab_size = lora_state_dict['base_model.model.model.embed_tokens.weight'].shape[0]
            assert lora_vocab_size == len(tokenizer), \
                (f"The vocab size of the tokenizer {len(tokenizer)} does not match the vocab size of the LoRA weight {lora_vocab_size}.\n"
                 "Make sure that you use the LLaMA tokenizer with the LLaMA-LoRA weight and the Alpaca tokenizer with the Alpaca-LoRA weight!")
        tokenizers_and_loras.append(
            {
                "tokenizer": tokenizer,
                "state_dict": lora_state_dict,
                "config": lora_config,
                "scaling": lora_config.lora_alpha / lora_config.r,
                "fan_in_fan_out": lora_config.fan_in_fan_out,
            })
    if len(tokenizers_and_loras) == 2:
        t1_vocab_size = len(tokenizers_and_loras[0]["tokenizer"])
        t2_vocab_size = len(tokenizers_and_loras[1]["tokenizer"])
        assert t1_vocab_size <= t2_vocab_size, \
            (f"The vocab size of the first tokenizer is {t1_vocab_size}\n"
             f"The vocab size of the second tokenizer is {t2_vocab_size}, found to be smaller than {t1_vocab_size}\n"
             "This is not the intended use. Please check your model and tokenizer.")

    if not os.path.exists(base_model_path):
        print("Cannot find base model on the disk. Downloading base model from hub...")
        base_model_path = snapshot_download(repo_id=base_model_path)
    ckpt_filenames = sorted([f for f in os.listdir(base_model_path) if re.match(r'pytorch_model-(\d+)-of-(\d+)\.bin', f)])

    embedding_size = None
    model_size = None

    total_size = 0
    for index, filename in enumerate(ckpt_filenames):
        print(f"Loading ckpt {filename}")
        state_dict = torch.load(os.path.join(base_model_path, filename), map_location='cpu')
        if index == 0:
            embedding_size = state_dict['model.embed_tokens.weight'].shape[1]
            model_size = emb_to_model_size[embedding_size]
            if output_type == 'pth':
                params = params_of_models[model_size]
                num_shards = num_shards_of_models[model_size]
                n_layers = params["n_layers"]
                n_heads = params["n_heads"]
                dim = params["dim"]
                dims_per_head = dim // n_heads
                base = 10000.0
                inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
        print("Merging...")
        for k in state_dict:
            for tl_idx, t_and_l in enumerate(tokenizers_and_loras):
                saved_key = 'base_model.model.' + k
                lora_key_A = saved_key.replace('.weight', '.lora_A.weight')
                if saved_key in t_and_l['state_dict']:
                    if args.verbose:
                        print(f"copying {saved_key} from {tl_idx}-th LoRA weight to {k}")
                    state_dict[k] = t_and_l['state_dict'][saved_key].half().clone()  # do we need half()?
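                # LoRA merge step: for a base weight W with adapter pair (A, B),
                # the merged weight is W' = W + (lora_alpha / r) * (B @ A);
                # 'scaling' below is lora_alpha / r, computed when the LoRA was loaded.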
                if lora_key_A in t_and_l['state_dict']:
                    lora_key_B = lora_key_A.replace('lora_A.weight', 'lora_B.weight')
                    if args.verbose:
                        print(f"merging {lora_key_A} and lora_B.weight from {tl_idx}-th LoRA weight to {k}")
                    state_dict[k] += (
                        transpose(
                            t_and_l['state_dict'][lora_key_B].float()
                            @ t_and_l['state_dict'][lora_key_A].float(), t_and_l['fan_in_fan_out']) * t_and_l['scaling']
                    )
            weight_size = state_dict[k].numel() * dtype_byte_size(state_dict[k].dtype)
            total_size += weight_size

        if output_type == 'huggingface':
            print(f"Saving ckpt {filename} to {output_dir} in HF format...")
            torch.save(state_dict, os.path.join(output_dir, filename))
        elif output_type == 'pth':
            print("Converting to pth format...")
            save_shards(model_sd=state_dict, num_shards=num_shards, prefix=f"L{index+1}-", verbose=args.verbose)
        del state_dict
        gc.collect()  # Effectively enforce garbage collection

    print("Saving tokenizer")
    tokenizers_and_loras[-1]['tokenizer'].save_pretrained(output_dir)
    if output_type == 'pth':
        with open(output_dir + "/params.json", "w") as f:
            print(f"Saving params.json into {output_dir}/params.json")
            json.dump(params, f)
        merge_shards(output_dir, num_shards=num_shards)

    if output_type == 'huggingface':
        configs = ('config.json', 'generation_config.json', 'pytorch_model.bin.index.json')
        for config in configs:
            if os.path.exists(os.path.join(base_model_path, config)):
                print(f"Saving {config}")
                with open(os.path.join(base_model_path, config), 'r') as f:
                    obj = json.load(f)
                if config == 'config.json':
                    obj['vocab_size'] = len(tokenizers_and_loras[-1]['tokenizer'])
                if config == 'pytorch_model.bin.index.json':
                    obj['metadata']['total_size'] = total_size
                with open(os.path.join(output_dir, config), 'w') as f:
                    json.dump(obj, f, indent=2)
    print("Done.")
--------------------------------------------------------------------------------