├── requirements.txt
├── pics
│   ├── ZhiLu_CEVAL_20231029.png
│   ├── ZhiLu_CMMLU_20231029.png
│   └── ZhiLu_CMMLU_zs_20231029.png
├── scripts
│   ├── merge.sh
│   └── merge_llama_with_chinese_lora_low_mem.py
├── FAQ.md
├── .gitignore
├── LICENSE
└── README.md
/requirements.txt:
--------------------------------------------------------------------------------
1 | peft>=0.3.0
2 | torch==2.0.1
3 | transformers==4.31.0
4 | sentencepiece==0.1.97
5 | bitsandbytes==0.41.0
6 |
--------------------------------------------------------------------------------
/pics/ZhiLu_CEVAL_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CEVAL_20231029.png
--------------------------------------------------------------------------------
/pics/ZhiLu_CMMLU_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CMMLU_20231029.png
--------------------------------------------------------------------------------
/pics/ZhiLu_CMMLU_zs_20231029.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/HEAD/pics/ZhiLu_CMMLU_zs_20231029.png
--------------------------------------------------------------------------------
/scripts/merge.sh:
--------------------------------------------------------------------------------
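# Merge the Alpaca-2 base model with the ZhiLu LoRA adapter into a standalone
# checkpoint; see the usage docstring in merge_llama_with_chinese_lora_low_mem.py
# for the supported flags.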
1 | python merge_llama_with_chinese_lora_low_mem.py \
2 | --base_model path/to/base_model \
3 | --lora_model path/to/lora_model \
4 | --output_type huggingface \
5 | --output_dir path/to/output_dir
--------------------------------------------------------------------------------
/FAQ.md:
--------------------------------------------------------------------------------
1 | ### Is LoRA training applied in the pre-training stage or the instruction fine-tuning stage?
2 |
3 | This project uses LoRA for parameter-efficient training in both the pre-training and the instruction fine-tuning stages.
4 |
5 | ### Why train the model with LoRA rather than with full-parameter training?
6 |
7 | Considering training cost and efficiency, we chose to train on top of Alpaca-2 with LoRA (with the embedding/lm_head layers fully trained).
8 |
9 | ### Why train on top of Alpaca-2 rather than Llama-2?
10 |
11 | Although Llama-2 already has some ability to understand Chinese, it still mixes English into generated Chinese text and has a limited Chinese vocabulary, so its Chinese capability needs further extension; Alpaca-2 has already done excellent work in this regard.
12 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
162 | # Ignore the .idea directory
163 | .idea/
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
203 |
204 |
205 | ## Some of ZhiLu's code is derived from other projects, which are subject to the following copyright notice:
206 |
207 | Copyright 2023 Yiming Cui, Ziqing Yang, Xin Yao
208 |
209 | Licensed under the Apache License, Version 2.0 (the "License");
210 | you may not use this file except in compliance with the License.
211 | You may obtain a copy of the License at
212 |
213 | http://www.apache.org/licenses/LICENSE-2.0
214 |
215 | Unless required by applicable law or agreed to in writing, software
216 | distributed under the License is distributed on an "AS IS" BASIS,
217 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
218 | See the License for the specific language governing permissions and
219 | limitations under the License.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 | 🤗 [Model Download](https://huggingface.co/SYSU-MUCFC-FinTech-Research-Center/ZhiLu-13B-Instruct)
3 |
4 |
5 | # News
6 | [2023.10.28] Open-sourced the ZhiLu-13B dialogue LLM. 🤗 [HuggingFace](https://huggingface.co/SYSU-MUCFC-FinTech-Research-Center/ZhiLu-13B-Instruct)
7 |
8 | # Table of Contents
9 |
10 | - [ZhiLu](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E6%99%BA%E9%B9%BF)
11 | - [Introduction](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E4%BB%8B%E7%BB%8D)
12 | - [Training Details](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E8%AE%AD%E7%BB%83%E7%BB%86%E8%8A%82)
13 | - [Evaluation](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E6%80%A7%E8%83%BD%E8%AF%84%E6%B5%8B)
14 | - [Dialogue Examples](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E5%AF%B9%E8%AF%9D%E7%A4%BA%E4%BE%8B)
15 | - [Usage](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E4%BD%BF%E7%94%A8%E6%96%B9%E5%BC%8F)
16 | - [Roadmap](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#Roadmap)
17 | - [Chat with a Prompt Template](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu#%E5%B8%A6PROMPT%E6%A8%A1%E6%9D%BF%E7%9A%84%E5%AF%B9%E8%AF%9D)
18 |
19 | # ZhiLu
20 |
21 | ## Introduction
22 | ZhiLu (智鹿) is a Chinese consumer-finance dialogue LLM obtained by further training [Alpaca2-13B](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2). We performed incremental pre-training on a large Chinese and English corpus, aligned the model with high-quality instruction data, and incorporated human-value alignment data and training tricks. The goal of training is to significantly improve financial-domain capability while preserving general capability.
23 |
24 | ## Training Details
25 |
26 | - We collected a wide range of high-quality data, including listed-company announcements, financial news, annual reports of listed companies, general news, financial information, community Q&A, and Wikipedia.
27 |
28 | - The model was trained on `14.69B` tokens in total, with a general-to-financial corpus ratio of roughly 2:1 and a Chinese-to-English ratio of roughly 2:1.
29 |
30 | - LoRA is used for parameter-efficient training (with the embedding and lm_head layers fully trained); some of the hyperparameters are listed below, and a PEFT sketch follows this list:
31 |
32 | ```sh
33 | lr=2e-4
34 | lora_rank=64
35 | lora_alpha=128
36 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj"
37 | modules_to_save="embed_tokens,lm_head"
38 | lora_dropout=0.05
39 | per_device_train_batch_size=64
40 | per_device_eval_batch_size=64
41 | gradient_accumulation_steps=1
42 | ```
43 |
44 | - FlashAttention-2 is used to speed up training.
45 |
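For reference, the hyperparameters above correspond one-to-one to a PEFT `LoraConfig`. The following is a minimal sketch of how such a configuration could be assembled; the model path is a placeholder and this is our illustration, not the project's actual training code:

```python
from peft import LoraConfig, TaskType, get_peft_model
from transformers import LlamaForCausalLM

base = LlamaForCausalLM.from_pretrained("path/to/alpaca-2-13b")  # hypothetical path

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=64,                  # lora_rank
    lora_alpha=128,
    lora_dropout=0.05,
    # lora_trainable: the attention and MLP projection matrices
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj",
                    "gate_proj", "down_proj", "up_proj"],
    # embed_tokens and lm_head are trained in full rather than adapted
    modules_to_save=["embed_tokens", "lm_head"],
)
model = get_peft_model(base, lora_config)
model.print_trainable_parameters()
```
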
46 | ## Evaluation
47 |
48 | We evaluate the model on both general-domain and financial-domain benchmarks.
49 |
50 | For general capability, we use two mainstream objective benchmarks:
51 |
52 | - [C-Eval](https://cevalbenchmark.com/index.html#home) is a comprehensive Chinese evaluation benchmark for foundation models. It contains 13,948 multiple-choice questions spanning 52 subjects and four difficulty levels.
53 | - [CMMLU](https://github.com/haonan-li/CMMLU) is a comprehensive Chinese benchmark designed to evaluate the knowledge and reasoning ability of language models in a Chinese context, covering 67 topics from elementary subjects to advanced professional levels.
54 |
55 | Benchmark for financial capability:
56 |
57 | - [FinEval](https://github.com/SUFE-AIFLM-Lab/FinEval/) is a benchmark designed specifically for financial-domain knowledge in LLMs, covering finance, economics, accounting, and professional certification. It contains 4,661 questions spanning 34 academic subjects. (Note that FinEval has not released the answers to its test set, nor opened an evaluation API, so we report FinEval results on the validation set.)
58 |
59 | The comparison with the base model Alpaca2-13B is as follows:
60 |
61 | | Model | C-Eval | CMMLU | FinEval |
62 | | ------ | ------ | ------ | ------ |
63 | | ZhiLu | 50.1 | 57.16 | 58.69 |
64 | | Alpaca2-13B | 43.6 | 44.48 | 45.26 |
65 |
66 | - On the general-domain benchmarks, ZhiLu retains its base general capability relative to Alpaca2-13B: the C-Eval score improves slightly and the CMMLU score improves significantly.
67 | - On the financial benchmark FinEval, ZhiLu's performance improves significantly.
68 |
69 | C-Eval leaderboard (as of 2023.10.29)
70 | ![C-Eval leaderboard](pics/ZhiLu_CEVAL_20231029.png)
71 |
72 | CMMLU leaderboard (few-shot and zero-shot) (as of 2023.10.29)
73 | ![CMMLU leaderboard (few-shot)](pics/ZhiLu_CMMLU_20231029.png)
74 | ![CMMLU leaderboard (zero-shot)](pics/ZhiLu_CMMLU_zs_20231029.png)
75 |
76 | ## Dialogue Examples
77 |
78 | ### Finance
79 |
80 | ```
81 | Q: How can I trade stocks so that I am guaranteed to profit and never lose?
82 | ZhiLu: Stock-market investment carries risk, and guaranteed profit without loss is impossible. Here are some suggestions to help you pursue more stable returns in the stock market: 1. Diversify your portfolio: allocate funds across different industries, companies, and asset classes to reduce the risk of any single investment; even if some investments underperform, others may offset the losses. 2. Take a long-term approach: choose quality companies with stable growth potential and hold their shares for a longer period; long-term investing reduces the impact of short-term market volatility and captures potential capital appreciation and dividend income. 3. Re-evaluate your plan regularly: review and adjust your portfolio according to market conditions and personal goals, so that you can make necessary adjustments in time and stay aligned with your objectives. 4. Control your emotions: avoid investment decisions driven by market noise and sentiment; do not blindly follow the crowd or panic out of the market, but make a sound investment plan based on your own research and analysis. 5. Keep learning: understand basic investment principles and technical-analysis tools and keep raising your level of knowledge; this helps you better understand market trends and investment opportunities and make informed decisions. Please note that the suggestions above are for reference only and cannot guarantee investment success. Before making any investment, be sure to consult a professional financial advisor for advice tailored to your specific situation.
83 |
84 | Q: What are A-shares?
85 | ZhiLu: Simply put, they are mainland China's stock market, which comprises the Shanghai Stock Exchange and the Shenzhen Stock Exchange. The A-share market dates from 1990 and is the main component of China's stock market. Its characteristics: 1. Openness: the A-share market is relatively open and allows both domestic and foreign investors to trade. 2. Higher risk: compared with developed markets, the A-share market is more volatile and investment risk is relatively high. 3. Strict regulation: the China Securities Regulatory Commission supervises and administers the A-share market and takes measures to protect investors' interests. 4. Many long-term opportunities: despite the risks, the A-share market also offers abundant long-term investment opportunities. For investors, it is very important to understand how the A-share market operates, the relevant rules, and the latest developments; this helps them make informed investment decisions and earn returns in the A-share market.
86 | ```
87 |
88 | ### General Domain
89 |
90 |
91 | ```
92 | Q: What local specialties does the Chaoshan region have?
93 | ZhiLu: 1. Beef: Chaoshan beef is famous for being tender and juicy with no gamey taste. 2. Shellfish and seafood: the Shantou area is rich in seafood, including abalone, sea cucumber, grouper, and more. 3. Citrus: Chaozhou citrus is famous for its distinctive fragrance and rich nutritional value. 4. Jiangzhuangzi: one of Chaozhou's specialty snacks, made by wrapping fillings such as pork, egg, and litsea in glutinous-rice flour. 5. Beef sticky rice: made from beef and glutinous rice, soft and chewy, and much loved. 6. Shashi: a traditional dessert made from glutinous-rice flour with crushed peanuts, sesame, and other toppings. 7. Beef balls: Chaoshan beef balls are famous for their springy texture and are usually eaten with chili sauce or a special broth. 8. Dalong guo: a traditional Chaozhou pastry made from sticky glutinous rice with fillings such as red-bean paste or candied shreds. 9. Beef hotpot: the Chaoshan area has many famous beef-hotpot restaurants, known for high-quality beef and distinctive condiments. 10. Winter-solstice noodles: on the winter solstice, many families prepare steaming bowls of noodles containing various meats and vegetables. These are only some of the common Chaoshan foods and specialties; there are many more places and dishes waiting to be discovered.
94 |
95 | Q: What should I do if I can't sleep at night?
96 | ZhiLu: 1. Do something relaxing, such as reading or writing. 2. Listen to music and relax body and mind. 3. Count sheep or count numbers to shift your attention. 4. Take a bath and let yourself unwind. 5. Go out for a walk and breathe some fresh air. 6. Chat with friends, share your worries, and seek help. 7. Write in a diary, recording what happened today and how you feel. 8. Exercise to release stress and tension. 9. Meditate or practice deep breathing to calm your mind. 10. Watch a comedy and forget your worries for a while.
97 |
98 | Q: Introduce the rules of Chinese chess.
99 | ZhiLu: 1. Moves: each side has 15 pieces, including the king, chariots, soldiers, elephants, horses, and cannons. Each type of piece has its own way of moving; for example, the king may only advance one step or retreat to a safe position, while other pieces choose different moves depending on the situation. 2. Controlled territory: each player owns half of the board, called their "camp". In battle, you can push the opponent's pieces back out of the area beyond your own camp and thereby control that area. 3. Attacks: different types of pieces have different attacking abilities; for example, the chariot can advance one square and push an enemy piece aside, while soldiers set out from their starting positions and advance on the opponent in a straight line. 4. Defense: to protect your own pieces from attack, you can devise effective defensive strategies by adjusting your layout, which may involve hiding pieces, using obstacles (such as rocks), and making exchanges. 5. History and cultural value: Chinese chess originated in ancient China and has a long history. It is not only a well-loved game of skill but also embodies rich philosophical thought and moral cultivation. By studying and competing in Chinese chess, people can strengthen their sense of teamwork and cultivate decisiveness and strategic thinking.
100 | ```
101 |
102 | ## Usage
103 |
104 | We have released the ZhiLu model weights on HuggingFace. Since our base model is Alpaca2, which is itself based on Llama2, ZhiLu can be loaded the same way as a Llama2-hf model, as follows:
105 |
106 | We provide two ways to load the model: `loading the merged model directly` and `loading the base model and the LoRA module separately`.
107 |
108 | Notes:
109 |
110 | 1. Running ZhiLu requires about 30GB of GPU memory.
111 |
112 | 2. End the prompt with a punctuation mark where possible; otherwise the model may simply continue writing the prompt.
113 |
114 | ### Environment Setup
115 | Install the dependencies with pip:
116 | ```
117 | pip install -r requirements.txt
118 | ```
119 | We recommend `transformers` version `4.30` or above and `torch` version 2.0 or above for the best inference performance.
120 |
121 | ### Using the Merged Model
122 |
123 | Alpaca-2-13b must first be merged with the LoRA module we provide before use; the merge script is `scripts/merge.sh`.
124 |
125 | At load time, only the path of the merged model is needed:
126 |
127 | ```python
128 | import torch
129 | from transformers import LlamaForCausalLM, LlamaTokenizer
130 | model_name_or_path = ""
131 | tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
132 | model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16,device_map="auto")
133 | inputs = tokenizer("什么是A股?", return_tensors="pt").to("cuda")
134 | outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
135 | outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
136 | print(outputs)
137 | ```
138 |
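If 30GB of GPU memory is not available, the merged model can usually be loaded in 8-bit via bitsandbytes (already pinned in `requirements.txt`). A minimal sketch, assuming a merged checkpoint directory; expect some loss of speed and accuracy compared to bfloat16:

```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model_name_or_path = ""  # path to the merged model
tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
# load_in_8bit quantizes the linear layers with bitsandbytes, roughly halving
# the memory footprint relative to bfloat16
model = LlamaForCausalLM.from_pretrained(model_name_or_path, load_in_8bit=True, device_map="auto")
```
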
139 | ### Loading the LoRA Module
140 |
141 | Provide the paths of both Alpaca-2-13b and the LoRA module:
142 |
143 | ```python
144 | import torch
145 | from transformers import LlamaForCausalLM, LlamaTokenizer
146 | from peft import PeftModel
147 | model_name_or_path = ""
148 | peft_model_path = ""
149 | tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
150 | model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16,device_map="auto")
151 | if peft_model_path is not None:
152 | model = PeftModel.from_pretrained(
153 | model,
154 | peft_model_path,
155 | torch_dtype=(
156 | torch.bfloat16
157 | if torch.cuda.is_bf16_supported()
158 | else torch.float32
159 | ),
160 | )
161 | inputs = tokenizer("什么是A股?", return_tensors="pt").to("cuda")
162 | outputs = model.generate(**inputs, max_new_tokens=64, repetition_penalty=1.1)
163 | outputs = tokenizer.decode(outputs.cpu()[0][len(inputs.input_ids[0]):], skip_special_tokens=True)
164 | print(outputs)
165 | ```
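
With recent versions of `peft`, the adapter can alternatively be folded into the base weights right after loading, so inference no longer passes through the adapter wrappers. A small sketch, assuming the `model` built above (our illustration, not part of the project's scripts):

```python
# Bake the LoRA weights into the base model in place of the PEFT wrappers;
# merge_and_unload returns a plain LlamaForCausalLM with the merged weights.
model = model.merge_and_unload()
```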
166 |
167 | ## Chat with a Prompt Template
168 |
169 | ### Single-turn Dialogue
170 |
171 | ```python
172 | import os
173 | import torch
174 | from transformers import (
175 | LlamaForCausalLM,
176 | LlamaTokenizer,
177 | GenerationConfig
178 | )
179 | from peft import PeftModel
180 | model_name_or_path = ""
181 | peft_model_path = ""
182 | tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
183 | model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16,device_map="auto")
184 | if peft_model_path is not None:
185 | model = PeftModel.from_pretrained(
186 | model,
187 | peft_model_path,
188 | torch_dtype=(
189 | torch.bfloat16
190 | if torch.cuda.is_bf16_supported()
191 | else torch.float32
192 | ),
193 | )
194 | print(model)
195 | generation_config = GenerationConfig.from_pretrained(model_name_or_path)
196 | config_updates = {
197 | 'do_sample': True,
198 | 'num_beams': 1,
199 | 'repetition_penalty': 1.1,
200 | 'min_new_tokens': 20,
201 | 'top_k' : 50,
202 | 'top_p' : 0.95,
203 | 'temperature' : 0.7
204 | }
205 | # Update the generation config
206 | for key, value in config_updates.items():
207 | setattr(generation_config, key, value)
208 | prompt_template = (
209 | "[INST] <>\n"
210 | "You are a helpful assistant. 你是一个乐于助人的助手。\n"
211 | "<>\n\n"
212 | "{instruction} [/INST]"
213 | )
214 | instruction = "什么是A股?"
215 | prompt = prompt_template.format_map({'instruction':instruction})
216 | input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
217 | output = model.generate(inputs = input_ids,generation_config = generation_config, return_dict_in_generate=True, num_return_sequences=1)
218 | output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
219 | print(output_text)
220 |
221 | '''
222 | [INST] <<SYS>>
223 | You are a helpful assistant. 你是一个乐于助人的助手。
224 | <</SYS>>
225 |
226 | 什么是A股? [/INST] A-shares refer to mainland China's stock market, also known as the main-board markets of the Shanghai Stock Exchange and the Shenzhen Stock Exchange. It is the main part of China's securities market and includes many large listed companies, as well as trading in the shares of small and medium-sized enterprises and startups. The A-share market is characterized by relatively high volatility; investors can take part in the market by buying and selling shares and earn capital gains or dividend income from it. In the A-share market, share prices are affected by many factors, such as company performance, macroeconomic conditions, and government policy. Understanding the basic concepts and workings of the A-share market is therefore very important for investment decisions.
227 | '''
228 | ```
229 |
230 | ### Multi-turn Dialogue
231 |
232 | ```python
233 | import os
234 | import torch
235 | from transformers import (
236 | LlamaForCausalLM,
237 | LlamaTokenizer,
238 | GenerationConfig
239 | )
240 | from peft import PeftModel
241 | model_name_or_path = ""
242 | peft_model_path = ""
243 | tokenizer = LlamaTokenizer.from_pretrained(model_name_or_path, use_fast=False, legacy=True)
244 | model = LlamaForCausalLM.from_pretrained(model_name_or_path, torch_dtype=torch.bfloat16,device_map="auto")
245 | if peft_model_path is not None:
246 | model = PeftModel.from_pretrained(
247 | model,
248 | peft_model_path,
249 | torch_dtype=(
250 | torch.bfloat16
251 | if torch.cuda.is_bf16_supported()
252 | else torch.float32
253 | ),
254 | )
255 | print(model)
256 | generation_config = GenerationConfig.from_pretrained(model_name_or_path)
257 | config_updates = {
258 | 'do_sample': True,
259 | 'num_beams': 1,
260 | 'repetition_penalty': 1.1,
261 | 'min_new_tokens': 20,
262 | 'top_k' : 50,
263 | 'top_p' : 0.95,
264 | 'temperature' : 0.7
265 | }
266 | # Update the generation config
267 | for key, value in config_updates.items():
268 | setattr(generation_config, key, value)
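# multi_generate replays the full concatenated history before each new instruction,
# so the context grows with every round and is bounded only by the model's maximum
# sequence length.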
269 | def multi_generate(history, instruction):
270 |     history_str = "".join(history)
271 |     prompt = history_str + instruction
272 |     input_ids = tokenizer.encode(prompt, return_tensors="pt").cuda()
273 |     output = model.generate(inputs = input_ids,generation_config = generation_config, return_dict_in_generate=True, num_return_sequences=1)
274 |     output_text = tokenizer.decode(output.sequences[0], skip_special_tokens=True)
275 |     history.append(output_text.replace(history_str, ""))
276 |     answer = output_text.replace(history_str, "")
277 |     answer = answer.replace(instruction, "")
278 |     # answer now has the conversation context and the user's input stripped out
279 |     return answer
280 | history = []
281 | cnt = 1
282 | prompt_template = (
283 | "[INST] <>\n"
284 | "You are a helpful assistant. 你是一个乐于助人的助手。\n"
285 | "<>\n\n"
286 | "{instruction} [/INST]"
287 | )
288 | print(f"-------------round {cnt} --------------")
289 | cnt += 1
290 | user_input = input("user:")
291 | user_input = prompt_template.format_map({'instruction':user_input})
292 | answer = multi_generate(history, user_input)
293 | print(f"ZhiLu:{answer}")
294 | while True:
295 | print(f"-------------round {cnt} --------------")
296 | cnt += 1
297 | user_input = input("user:")
298 | user_input = f"[INST] {user_input} [/INST]"
299 | answer = multi_generate(history, user_input)
300 | print(f"ZhiLu:{answer}")
301 |
302 | '''
303 | -------------round 1 --------------
304 | user: What are A-shares?
305 | ZhiLu: The A-share market is an important part of China's stock market, referring to shares listed on the Shanghai Stock Exchange and the Shenzhen Stock Exchange. It is the main venue where mainland Chinese investors trade stocks. Characteristics of the A-share market: - The A-share market is the main body of mainland China's stock market and one of the largest stock markets in the world. - A-shares can only be traded within mainland China, and only Chinese citizens and institutions may participate in the market. - Trading hours are generally Monday to Friday, 9:30 a.m. to 3:00 p.m., with a short midday break. - The A-share market uses a T+1 system: shares bought on a given day can only be sold the next day, which helps reduce short-term speculation. - Share prices in the A-share market are usually affected by many factors, including economic conditions, policy changes, industry dynamics, and company performance. In short, the A-share market is mainland China's stock market; it offers broad investment opportunities and flexible trading mechanisms, but also carries corresponding risks that investors should treat with caution.
306 | -------------round 2 --------------
307 | user: Where can I follow A-shares?
308 | ZhiLu: You can view A-share market data and quotes in the following ways: 1. Trading software: many brokerages provide their own mobile apps or web-based trading platforms, where you can look up prices, trading volume, percentage changes, and other indicators. 2. Financial websites: sites such as Sina Finance and East Money provide real-time quotes, news, and research reports. 3. TV financial channels: watch financial broadcasters such as CCTV Securities and Shanghai Securities News, which regularly report market developments, news, and expert commentary. Whichever method you use, always mind the reliability and accuracy of your information sources, as well as market risk and your own investment capacity.
309 | '''
310 | ```
311 |
312 |
313 |
314 | # Roadmap
315 |
316 | - [x] Base model type
317 |   - [x] Chat model
318 |   - [ ] Instruct model
319 | - [x] Data augmentation
320 |   - [x] General-domain knowledge
321 |   - [ ] Enhanced consumer-finance domain knowledge
322 | - [x] Training modules
323 |   - [x] Standard LoRA modules
324 |   - [ ] Pre-training & instruction fine-tuning
325 |   - [ ] Human-preference feedback (RLHF, DPO, PRO, etc.)
326 | - [x] Length extrapolation
327 |   - [x] Native 4k context
328 |   - [ ] 8k, 16k, and longer variants
329 | - [x] Performance optimization
330 |   - [x] FlashAttention-2 support
331 |   - [ ] vLLM, StreamingLLM, and other acceleration modules
332 |
333 | # FAQ
334 |
335 | For project Q&A and solutions to problems encountered during deployment, please see the [FAQ](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/blob/main/FAQ.md).
336 |
337 | # Acknowledgements
338 |
339 | This project is largely built on the open-source Alpaca2 project; we thank the related projects and their research and development staff.
340 |
341 | # Disclaimer and License
342 | The code in this repository is open-sourced under the [Apache-2.0](https://github.com/SYSU-MUCFC-FinTech-Research-Center/ZhiLu/blob/main/LICENSE) license. The model weights are fully open for academic research; users may also apply via the [application form](https://wj.qq.com/s2/13390238/fea9/) and, once approved and issued a commercial license certificate, use them commercially free of charge.
343 | Although we have done our best to ensure the compliance and accuracy of the data during training, the model's output is affected by probabilistic randomness and can be misled, so we cannot guarantee the accuracy of its content. Users should therefore carefully verify model output and make independent judgments, consulting professionals where necessary, and bear the risks of use themselves. Users must not use this model for any purpose that may harm the nation or society, or for any service that has not undergone safety assessment and filing. We accept no liability for security risks, intellectual-property risks, or public-opinion risks arising from the open-source model and its output, nor for any risks and liabilities arising from the model being misled, abused, misused, or improperly disseminated.
344 |
345 | # Summary
346 |
347 | We encourage users to cite ZhiLu in related work to promote knowledge sharing and exchange and to contribute to the continued development of Chinese financial dialogue systems.
348 | ZhiLu is released to provide strong support for applications and research in the financial domain and to contribute to the progress of Chinese financial dialogue systems. We look forward to more innovations and applications that improve financial services and user experience, and to the flourishing of AI technology in finance. Through cooperation and sharing, we can advance this field together and bring greater benefits to society and the industry.
349 |
--------------------------------------------------------------------------------
/scripts/merge_llama_with_chinese_lora_low_mem.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage:
3 | python merge_llama_with_chinese_lora_low_mem.py \
4 | --base_model /mnt/data2/finLLM/models/alpaca-2-13b \
5 | --lora_model /mnt/data2/finLLM/llama-alpaca-2-main/pt_lora_model \
6 | --output_type huggingface \
7 | --output_dir /mnt/data2/finLLM/test
8 | """
9 | import argparse
10 | import json
11 | import os
12 | import gc
13 | import torch
14 | import peft
15 | from transformers import LlamaTokenizer
16 | from transformers.modeling_utils import dtype_byte_size
17 | from huggingface_hub import snapshot_download
18 | import re
19 |
20 | parser = argparse.ArgumentParser()
21 | parser.add_argument('--base_model', default=None, required=True,
22 | type=str, help="Please specify a base model")
23 | parser.add_argument('--lora_model', default=None, required=True,
24 | type=str, help="Please specify LoRA models to be merged (ordered); use commas to separate multiple LoRA models")
25 | parser.add_argument('--output_type', default='pth',choices=['pth','huggingface'],
26 | type=str, help="Save the merged model in pth or huggingface format")
27 | parser.add_argument('--output_dir', default='./merged_model',
28 | type=str, help="The output folder to save the merged model")
29 | parser.add_argument('--verbose', default=False, action='store_true',
30 | help="Show detailed messages")
31 |
32 |
33 | emb_to_model_size = {
34 | 4096 : '7B',
35 | 5120 : '13B',
36 | 6656 : '33B',
37 | 8192 : '65B',
38 | }
39 | num_shards_of_models = {'7B': 1, '13B': 2, '33B': 4, '65B': 8}
40 | params_of_models = {
41 | '7B':
42 | {
43 | "dim": 4096,
44 | "multiple_of": 256,
45 | "n_heads": 32,
46 | "n_layers": 32,
47 | "norm_eps": 1e-06,
48 | "vocab_size": -1,
49 | },
50 | '13B':
51 | {
52 | "dim": 5120,
53 | "multiple_of": 256,
54 | "n_heads": 40,
55 | "n_layers": 40,
56 | "norm_eps": 1e-06,
57 | "vocab_size": -1,
58 | },
59 | '33B':
60 | {
61 | "dim": 6656,
62 | "multiple_of": 256,
63 | "n_heads": 52,
64 | "n_layers": 60,
65 | "norm_eps": 1e-06,
66 | "vocab_size": -1,
67 | },
68 | '65B':
69 | {
70 | "dim": 8192,
71 | "multiple_of": 256,
72 | "n_heads": 64,
73 | "n_layers": 80,
74 | "norm_eps": 1e-05,
75 | "vocab_size": -1,
76 | },
77 | }
78 |
79 | def transpose(weight, fan_in_fan_out):
80 | return weight.T if fan_in_fan_out else weight
81 |
82 | # Borrowed and modified from https://github.com/tloen/alpaca-lora
83 | def translate_state_dict_key(k):
84 | k = k.replace("base_model.model.", "")
85 | if k == "model.embed_tokens.weight":
86 | return "tok_embeddings.weight"
87 | elif k == "model.norm.weight":
88 | return "norm.weight"
89 | elif k == "lm_head.weight":
90 | return "output.weight"
91 | elif k.startswith("model.layers."):
92 | layer = k.split(".")[2]
93 | if k.endswith(".self_attn.q_proj.weight"):
94 | return f"layers.{layer}.attention.wq.weight"
95 | elif k.endswith(".self_attn.k_proj.weight"):
96 | return f"layers.{layer}.attention.wk.weight"
97 | elif k.endswith(".self_attn.v_proj.weight"):
98 | return f"layers.{layer}.attention.wv.weight"
99 | elif k.endswith(".self_attn.o_proj.weight"):
100 | return f"layers.{layer}.attention.wo.weight"
101 | elif k.endswith(".mlp.gate_proj.weight"):
102 | return f"layers.{layer}.feed_forward.w1.weight"
103 | elif k.endswith(".mlp.down_proj.weight"):
104 | return f"layers.{layer}.feed_forward.w2.weight"
105 | elif k.endswith(".mlp.up_proj.weight"):
106 | return f"layers.{layer}.feed_forward.w3.weight"
107 | elif k.endswith(".input_layernorm.weight"):
108 | return f"layers.{layer}.attention_norm.weight"
109 | elif k.endswith(".post_attention_layernorm.weight"):
110 | return f"layers.{layer}.ffn_norm.weight"
111 | elif k.endswith("rotary_emb.inv_freq") or "lora" in k:
112 | return None
113 | else:
114 | print(layer, k)
115 | raise NotImplementedError
116 | else:
117 | print(k)
118 | raise NotImplementedError
119 |
120 |
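# Invert the per-head interleaving that the HF conversion script applies to LLaMA's
# q/k projection weights, restoring the original Meta checkpoint layout (relies on
# the globals n_heads and dim set in __main__ before save_shards is called).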
121 | def unpermute(w):
122 | return (
123 | w.view(n_heads, 2, dim // n_heads // 2, dim).transpose(1, 2).reshape(dim, dim)
124 | )
125 |
126 |
127 | def save_shards(model_sd, num_shards: int, prefix="", verbose=False):
128 | """
129 | Convert and save the HF format weights to PTH format weights
130 | """
131 | with torch.no_grad():
132 | if num_shards == 1:
133 | new_state_dict = {}
134 | for k, v in model_sd.items():
135 | new_k = translate_state_dict_key(k)
136 | if new_k is not None:
137 | if "wq" in new_k or "wk" in new_k:
138 | new_state_dict[new_k] = unpermute(v)
139 | else:
140 | new_state_dict[new_k] = v
141 |
142 | os.makedirs(output_dir, exist_ok=True)
143 | print(f"Saving shard 1 of {num_shards} into {output_dir}/{prefix}consolidated.00.pth")
144 | torch.save(new_state_dict, output_dir + f"/{prefix}consolidated.00.pth")
145 | else:
146 | new_state_dicts = [dict() for _ in range(num_shards)]
147 | for k in list(model_sd.keys()):
148 | v = model_sd[k]
149 | new_k = translate_state_dict_key(k)
150 | if new_k is not None:
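                    # Split each tensor across shards following the original LLaMA
                    # tensor-parallel layout: tok_embeddings, w2 and wo are split along
                    # columns; output, w1, w3 and the attention projections along rows;
                    # norm weights are replicated on every shard.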
151 | if new_k=='tok_embeddings.weight':
152 | assert v.size(1)%num_shards==0
153 | splits = v.split(v.size(1)//num_shards,dim=1)
154 | elif new_k=='output.weight':
155 | if v.size(0)%num_shards==0:
156 | splits = v.split(v.size(0)//num_shards,dim=0)
157 | else:
158 | size_list = [v.size(0)//num_shards] * num_shards
159 | size_list[-1] += v.size(0)%num_shards
160 | splits = v.split(size_list, dim=0) # 13B: size_list == [24976,24977]
161 | elif new_k=='norm.weight':
162 | splits = [v] * num_shards
163 | elif 'ffn_norm.weight' in new_k:
164 | splits = [v] * num_shards
165 | elif 'attention_norm.weight' in new_k:
166 | splits = [v] * num_shards
167 |
168 |
169 | elif 'w1.weight' in new_k:
170 | splits = v.split(v.size(0)//num_shards,dim=0)
171 | elif 'w2.weight' in new_k:
172 | splits = v.split(v.size(1)//num_shards,dim=1)
173 | elif 'w3.weight' in new_k:
174 | splits = v.split(v.size(0)//num_shards,dim=0)
175 |
176 |
177 | elif 'wo.weight' in new_k:
178 | splits = v.split(v.size(1)//num_shards,dim=1)
179 |
180 | elif 'wv.weight' in new_k:
181 | splits = v.split(v.size(0)//num_shards,dim=0)
182 |
183 | elif "wq.weight" in new_k or "wk.weight" in new_k:
184 | v = unpermute(v)
185 | splits = v.split(v.size(0)//num_shards,dim=0)
186 | else:
187 | print(f"Unexpected key {new_k}")
188 | raise ValueError
189 | if verbose:
190 | print(f"Processing {new_k}")
191 | for sd,split in zip(new_state_dicts,splits):
192 | sd[new_k] = split.clone()
193 | del split
194 | del splits
195 | del model_sd[k],v
196 | gc.collect() # Effectively enforce garbage collection
197 |
198 | os.makedirs(output_dir, exist_ok=True)
199 | for i,new_state_dict in enumerate(new_state_dicts):
200 | print(f"Saving shard {i+1} of {num_shards} into {output_dir}/{prefix}consolidated.0{i}.pth")
201 | torch.save(new_state_dict, output_dir + f"/{prefix}consolidated.0{i}.pth")
202 |
203 | def merge_shards(output_dir, num_shards: int):
204 |     ckpt_filenames = sorted([f for f in os.listdir(output_dir) if re.match(r'L(\d+)-consolidated\.(\d+)\.pth', f)])
205 |
206 | for i in range(num_shards):
207 |         shards_filenames = sorted([f for f in ckpt_filenames if re.match(rf'L(\d+)-consolidated\.0{i}\.pth', f)])
208 | print(f"Loading {shards_filenames} ...")
209 | shards_dicts = [torch.load(os.path.join(output_dir,fn)) for fn in shards_filenames]
210 | shards_merged = {}
211 | for d in shards_dicts:
212 | shards_merged |= d
213 |
214 | print(f"Saving the merged shard to " + os.path.join(output_dir, f"consolidated.0{i}.pth"))
215 | torch.save(shards_merged, os.path.join(output_dir, f"consolidated.0{i}.pth"))
216 |
217 | print("Cleaning up...")
218 | del shards_merged
219 | for d in shards_dicts:
220 | del d
221 | del shards_dicts
222 | gc.collect() # Effectively enforce garbage collection
223 | for fn in shards_filenames:
224 | os.remove(os.path.join(output_dir,fn))
225 |
226 | if __name__=='__main__':
227 |
228 | args = parser.parse_args()
229 | base_model_path = args.base_model
230 | lora_model_paths = [s.strip() for s in args.lora_model.split(',') if len(s.strip())!=0]
231 | output_dir = args.output_dir
232 | output_type = args.output_type
233 | os.makedirs(output_dir, exist_ok=True)
234 |
235 | print(f"Base model: {base_model_path}")
236 |     print(f"LoRA model(s): {lora_model_paths}")
237 |
238 | tokenizers_and_loras = []
239 | for lora_model_path in lora_model_paths:
240 | print(f"Loading {lora_model_path}")
241 | if not os.path.exists(lora_model_path):
242 | print("Cannot find lora model on the disk. Downloading lora model from hub...")
243 | lora_model_path = snapshot_download(repo_id=lora_model_path)
244 | tokenizer = LlamaTokenizer.from_pretrained(lora_model_path)
245 | lora_config = peft.LoraConfig.from_pretrained(lora_model_path)
246 | lora_state_dict = torch.load(os.path.join(lora_model_path,'adapter_model.bin'),map_location='cpu')
247 | if 'base_model.model.model.embed_tokens.weight' in lora_state_dict:
248 | lora_vocab_size = lora_state_dict['base_model.model.model.embed_tokens.weight'].shape[0]
249 | assert lora_vocab_size==len(tokenizer), \
250 | (f"The vocab size of the tokenizer {len(tokenizer)} does not match the vocab size of the LoRA weight {lora_vocab_size}.\n"
251 | "Make sure that you use LLaMA tokenizer with the LLaMA-LoRA weight and Alpaca tokenizer with the Alpaca-LoRA weight!")
252 | tokenizers_and_loras.append(
253 | {
254 | "tokenizer" :tokenizer,
255 | "state_dict" :lora_state_dict,
256 | "config": lora_config,
257 | "scaling": lora_config.lora_alpha / lora_config.r,
258 | "fan_in_fan_out" : lora_config.fan_in_fan_out,
259 | })
260 | if len(tokenizers_and_loras)==2:
261 | t1_vocab_size = len(tokenizers_and_loras[0]["tokenizer"])
262 | t2_vocab_size = len(tokenizers_and_loras[1]["tokenizer"])
263 | assert t1_vocab_size<=t2_vocab_size, \
264 | (f"The vocab size of the first tokenizer is {t1_vocab_size}\n"
265 | f"The vocab size of the second tokenizer is {t2_vocab_size}, found to be smaller than {t1_vocab_size}\n"
266 | "This is not the intended use. Please check your model and tokenizer.")
267 |
268 | if not os.path.exists(base_model_path):
269 |         print("Cannot find the base model on disk. Downloading the base model from the hub...")
270 | base_model_path = snapshot_download(repo_id=base_model_path)
271 |     ckpt_filenames = sorted([f for f in os.listdir(base_model_path) if re.match(r'pytorch_model-(\d+)-of-(\d+)\.bin', f)])
272 |
273 | embedding_size = None
274 | model_size = None
275 |
276 |
277 | total_size = 0
278 | for index, filename in enumerate(ckpt_filenames):
279 | print(f"Loading ckpt {filename}")
280 | state_dict = torch.load(os.path.join(base_model_path,filename), map_location='cpu')
281 | if index == 0:
282 | embedding_size = state_dict['model.embed_tokens.weight'].shape[1]
283 | model_size = emb_to_model_size[embedding_size]
284 | if output_type=='pth':
285 | params = params_of_models[model_size]
286 | num_shards = num_shards_of_models[model_size]
287 | n_layers = params["n_layers"]
288 | n_heads = params["n_heads"]
289 | dim = params["dim"]
290 | dims_per_head = dim // n_heads
291 | base = 10000.0
292 | inv_freq = 1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
293 | print("Merging...")
294 | for k in state_dict:
295 | for tl_idx, t_and_l in enumerate(tokenizers_and_loras):
296 | saved_key = 'base_model.model.'+k
297 | lora_key_A = saved_key.replace('.weight','.lora_A.weight')
298 | if saved_key in t_and_l['state_dict']:
299 | if args.verbose:
300 | print(f"copying {saved_key} from {tl_idx}-th LoRA weight to {k}")
301 | state_dict[k] = t_and_l['state_dict'][saved_key].half().clone() # do we need half()?
302 | if lora_key_A in t_and_l['state_dict']:
303 | lora_key_B = lora_key_A.replace('lora_A.weight','lora_B.weight')
304 | if args.verbose:
305 |                         print(f"merging {lora_key_A} and lora_B.weight from {tl_idx}-th LoRA weight to {k}")
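                    # Apply the LoRA update in place: W <- W + (B @ A) * (lora_alpha / r),
                    # transposed first when the layer stores its weight fan_in_fan_out.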
306 | state_dict[k] += (
307 | transpose(
308 | t_and_l['state_dict'][lora_key_B].float()
309 | @ t_and_l['state_dict'][lora_key_A].float(), t_and_l['fan_in_fan_out']) * t_and_l['scaling']
310 | )
311 | weight_size = state_dict[k].numel() * dtype_byte_size(state_dict[k].dtype)
312 | total_size += weight_size
313 |
314 | if output_type=='huggingface':
315 | print(f"Saving ckpt {filename} to {output_dir} in HF format...")
316 | torch.save(state_dict,os.path.join(output_dir, filename))
317 | elif output_type=='pth':
318 | print(f"Converting to pth format...")
319 | save_shards(model_sd=state_dict, num_shards=num_shards,prefix=f"L{index+1}-", verbose=args.verbose)
320 | del state_dict
321 | gc.collect() # Effectively enforce garbage collection
322 |
323 |
324 | print(f"Saving tokenizer")
325 | tokenizers_and_loras[-1]['tokenizer'].save_pretrained(output_dir)
326 | if output_type == 'pth':
327 | with open(output_dir + "/params.json", "w") as f:
328 | print(f"Saving params.json into {output_dir}/params.json")
329 | json.dump(params, f)
330 | merge_shards(output_dir, num_shards=num_shards)
331 |
332 | if output_type=='huggingface':
333 | configs = ('config.json', 'generation_config.json', 'pytorch_model.bin.index.json')
334 | for config in configs:
335 | if os.path.exists(os.path.join(base_model_path, config)):
336 | print(f"Saving {config}")
337 | with open(os.path.join(base_model_path, config),'r') as f:
338 | obj = json.load(f)
339 | if config=='config.json':
340 | obj['vocab_size'] = len(tokenizers_and_loras[-1]['tokenizer'])
341 | if config=='pytorch_model.bin.index.json':
342 | obj['metadata']['total_size'] = total_size
343 | with open(os.path.join(output_dir, config), 'w') as f:
344 | json.dump(obj, f, indent=2)
345 | print("Done.")
346 |
--------------------------------------------------------------------------------