├── .gitignore
├── README.md
├── app.py
├── data
├── huanhuan.json
└── huanhuan_xtuner.json
├── images
├── Extract-Dialogue.png
├── chat嬛嬛.png
├── compass_support.svg
├── huanhuan_chat.png
├── huanhuan_img.png
├── license.svg
├── logo.png
├── modelscope.png
├── modelscope_logo.png
├── openxlab.png
├── openxlab_model.jpg
└── tech_route.svg
├── requirements.txt
├── results
├── huanhuan_lianghua.csv
├── huanhuan_ori.csv
├── internlm2_kv.csv
└── internlm2_ori.csv
├── start.py
└── train
├── internlm2-chat-lora.ipynb
├── internlm2_1_8b_full_oasst1_e3_huanhuan.py
└── internlm2_chat_7b_qlora_oasst1_e3_copy.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # poetry
98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102 | #poetry.lock
103 |
104 | # pdm
105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106 | #pdm.lock
107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108 | # in version control.
109 | # https://pdm.fming.dev/#use-with-ide
110 | .pdm.toml
111 |
112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113 | __pypackages__/
114 |
115 | # Celery stuff
116 | celerybeat-schedule
117 | celerybeat.pid
118 |
119 | # SageMath parsed files
120 | *.sage.py
121 |
122 | # Environments
123 | .env
124 | .venv
125 | env/
126 | venv/
127 | ENV/
128 | env.bak/
129 | venv.bak/
130 |
131 | # Spyder project settings
132 | .spyderproject
133 | .spyproject
134 |
135 | # Rope project settings
136 | .ropeproject
137 |
138 | # mkdocs documentation
139 | /site
140 |
141 | # mypy
142 | .mypy_cache/
143 | .dmypy.json
144 | dmypy.json
145 |
146 | # Pyre type checker
147 | .pyre/
148 |
149 | # pytype static type analyzer
150 | .pytype/
151 |
152 | # Cython debug symbols
153 | cython_debug/
154 |
155 | # PyCharm
156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158 | # and can be added to the global gitignore or merged into this file. For a more nuclear
159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160 | #.idea/
161 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Chat-嬛嬛
2 |
3 |
4 |

5 |
6 | Chat-嬛嬛
7 |
8 |
9 | [![license][license-image]][license-url]
10 | [![evaluation][evaluation-image]][evaluation-url]
11 |
12 | [🤗HuggingFace]() | [![OpenXLab_Model][OpenXLab_Model-image]][OpenXLab_Model-url] | [

ModelScope][ModelScope-url]
13 |
14 | [![OpenXLab_App][OpenXLab_App-image]][OpenXLab_App-url] | [🆕Update News](#-news) | [🤔Reporting Issues][Issues-url] 丨 [![bilibili][bilibili-image]][bilibili-url]
15 |
16 | [English](./README_en-US.md) | [简体中文](./README.md)
17 |
18 |
19 |
20 | [license-image]: ./images/license.svg
21 | [evaluation-image]: ./images/compass_support.svg
22 | [OpenXLab_Model-image]: https://cdn-static.openxlab.org.cn/header/openxlab_models.svg
23 | [OpenXLab_App-image]: https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg
24 | [bilibili-image]: https://img.shields.io/badge/AMchat-bilibili-%23fb7299
25 |
26 | [license-url]: ./LICENSE
27 | [evaluation-url]: https://github.com/internLM/OpenCompass/
28 | [OpenXLab_Model-url]: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2-1_8b
29 | [OpenXLab_App-url]: https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan
30 | [bilibili-url]: https://www.bilibili.com/video/——/
31 | [ModelScope-url]: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary
32 | [Issues-url]: https://github.com/KMnO4-zx/xlab-huanhuan/issues
33 |
34 |
35 |
36 | ## 📝目录
37 |
38 | - [Chat-嬛嬛](#chat-嬛嬛)
39 | - [📝目录](#目录)
40 | - [📖 简介](#-简介)
41 | - [🔗 模型及体验地址](#-模型及体验地址)
42 | - [🚀 News](#-news)
43 | - [🧾 数据集](#-数据集)
44 | - [🛠️ 使用方法](#️-使用方法)
45 | - [快速开始](#快速开始)
46 | - [重新训练](#重新训练)
47 | - [环境搭建](#环境搭建)
48 | - [Transformers微调](#transformers微调)
49 | - [XTuner微调](#xtuner微调)
50 | - [部署](#部署)
51 | - [OpenXLab 部署 Chat-嬛嬛](#openxlab-部署-chat-嬛嬛)
52 | - [LmDeploy部署](#lmdeploy部署)
53 | - [测评与量化](#测评与量化)
54 |   - [OpenCompass 评测](#opencompass-评测)
55 | - [Lmdeploy\&opencompass 量化以及量化评测](#lmdeployopencompass-量化以及量化评测)
56 | - [`W4`量化评测](#w4量化评测)
57 | - [`KV Cache`量化评测](#kv-cache量化评测)
58 | - [💕 致谢](#-致谢)
59 | - [项目成员](#项目成员)
60 | - [特别感谢](#特别感谢)
61 |
62 |
63 | ## 📖 简介
64 |
65 | > *此仓库主要用于将 Chat嬛嬛 项目部署到 OpenXLab 或 ModelScope 。*
66 |
67 | Chat-甄嬛是利用《甄嬛传》剧本中所有关于甄嬛的台词和语句,基于[InternLM2](https://github.com/InternLM/InternLM.git)进行LoRA微调或全量微调得到的模仿甄嬛语气的聊天语言模型。
68 |
69 | > 甄嬛,小说《后宫·甄嬛传》和电视剧《甄嬛传》中的女一号,核心女主角。原名甄玉嬛,嫌玉字俗气而改名甄嬛,为汉人甄远道之女,后被雍正赐姓钮祜禄氏,抬旗为满洲上三旗,获名“钮祜禄·甄嬛”。同沈眉庄、安陵容参加选秀,因容貌酷似纯元皇后而被选中。入宫后面对华妃的步步紧逼,沈眉庄被冤、安陵容变心,从偏安一隅的青涩少女变成了能引起血雨腥风的宫斗老手。雍正发现年氏一族的野心后令其父甄远道剪除,甄嬛也于后宫中用她的连环巧计帮皇帝解决政敌,故而深得雍正爱待。几经周折,终于斗垮了嚣张跋扈的华妃。甄嬛封妃时遭皇后宜修暗算,被皇上嫌弃,生下女儿胧月后心灰意冷,自请出宫为尼。然得果郡王爱慕,二人相爱,得知果郡王死讯后立刻设计与雍正再遇,风光回宫。此后甄父冤案平反、甄氏复起,她也生下双生子,在滴血验亲等各种阴谋中躲过宜修的暗害,最后以牺牲自己亲生胎儿的方式扳倒了幕后黑手的皇后。但雍正又逼甄嬛毒杀允礼,以测试甄嬛真心,并让已经生产过孩子的甄嬛去准格尔和亲。甄嬛遂视皇帝为最该毁灭的对象,大结局道尽“人类的一切争斗,皆因统治者的不公不义而起”,并毒杀雍正。四阿哥弘历登基为乾隆,甄嬛被尊为圣母皇太后,权倾朝野,在如懿传中安度晚年。
70 |
71 | Chat-甄嬛,实现以《甄嬛传》为切入点,打造一套基于小说、剧本的**个性化 AI** 微调大模型完整流程,通过提供任一小说、剧本,指定人物角色,运行本项目完整流程,让每一位用户都基于心仪的小说、剧本打造一个属于自己的、契合角色人设、具备高度智能的个性化 AI。
72 |
73 | > 具体如何实现全流程的 Character-AI 微调,可参考主仓库-[huanhuan-chat](https://github.com/KMnO4-zx/huanhuan-chat.git)。
74 | >
75 | > 如何学习大模型部署和微调请参考:[开源大模型食用指南](https://github.com/datawhalechina/self-llm.git) 以及 [书生·浦语大模型实战营课程](https://github.com/InternLM/tutorial.git)
76 |
77 | ***欢迎大家来给[InternLM2](https://github.com/InternLM/InternLM.git),点点star哦~***
78 |
79 | Chat嬛嬛全流程如图所示:
80 |
81 |
82 |
83 |
84 |
85 | ## 🔗 模型及体验地址
86 |
87 | ***OpenXLab 体验地址:***
88 |
89 | ***https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan***
90 |
91 | 
92 |
93 | ***Chat-嬛嬛 模型下载地址:***
94 |
95 | - ***OpenXLab***
96 |
97 | ***7B: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2***
98 |
99 | ***1.8B: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2-1_8b***
100 |
101 | 
102 |
103 | - ***ModelScope***
104 |
105 | ***7B: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2/summary***
106 |
107 | ***1.8B: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary***
108 |
109 | 
110 |
111 |
112 | ## 🚀 News
113 |
114 | ***2月5日,完成 [InternLM2-chat-1_8B模型的全量微调](https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary) ,模型已上传ModelScope,大家可以来下载哦~***
115 |
116 | ***1月22日,Chat-嬛嬛应用在 OpenXLab,累计聊天次数已达 3.64k 次,感谢大家的支持~***
117 |
118 | ***1月22日,Chat-嬛嬛模型 魔搭 累计下载 3107 次!***
119 |
120 |
121 | ## 🧾 数据集
122 |
123 | Chat-嬛嬛 数据集采用《甄嬛传》剧本中所有关于甄嬛的台词和语句,共计 3000 余条,数据集样例:
124 |
125 | ```text
126 | 第15幕
127 | (秀女们在等候殿选。甄嬛看见了眉庄,上前相认)
128 | 甄嬛:眉姐姐!
129 | 眉庄:嬛儿,早就听说妹妹中选了,可就是一直不得空见你。
130 | 甄嬛(凑近):我倒巴不得没选上呢。姐姐远道过来一定很辛苦吧。
131 | 眉庄:在京里休息了这些日子,早已经调养过来了。
132 | 甄嬛:如今你住在自己京城的宅子里,不比从前住在外祖家,一墙之隔,见面也方便。
133 | 眉庄:是啊。可是我总还想着我们一起长大的情分呢。诶?妹妹今日打扮得好生素净,可是细看起来还是个美人坯子,怎么看都是好的。
134 | 甄嬛:沈大美人差矣,姐姐出落得这么标致,皇上见过必定会念念不忘。
135 | 眉庄(伸手制止,左右看了下):今天秀女佼佼者众多,我未必中选,若教旁人听见了,又要生出是非。
136 | ```
137 |
138 | 使用脚本将剧本中关于甄嬛的对话集抽取出来,作为数据集使用。
139 |
140 | 也可以使用这个仓库的脚本[Extract Dialogue](https://github.com/KMnO4-zx/extract-dialogue.git),请GPT老师来帮助我们从小说中抽取对话集。
141 |
142 | 
143 |
144 | ## 🛠️ 使用方法
145 |
146 | ### 快速开始
147 |
148 |
149 |
150 | 1. 下载模型
151 |
152 |
153 | 从 ModelScope
154 |
155 | 参考 [模型的下载](https://www.modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%B8%8B%E8%BD%BD) 。
156 |
157 | ```bash
158 | pip install modelscope
159 | ```
160 |
161 | ```python
162 | from modelscope.hub.snapshot_download import snapshot_download
163 | model_dir = snapshot_download('kmno4zx/huanhuan-chat-internlm2', cache_dir='./')
164 | ```
165 |
166 |
167 |
168 |
169 |
170 | 从 OpenXLab
171 |
172 | 参考 [下载模型](https://openxlab.org.cn/docs/models/%E4%B8%8B%E8%BD%BD%E6%A8%A1%E5%9E%8B.html) 。
173 |
174 | ```bash
175 | pip install openxlab
176 | ```
177 |
178 | ```python
179 | from openxlab.model import download
180 | download(model_repo='BYCJS/huanhuan-chat-internlm2',
181 | model_name='huanhuan-chat-internlm2', output='./')
182 | ```
183 |
184 |
185 |
186 | 2. 本地部署
187 |
188 | ```bash
189 | git clone https://github.com/KMnO4-zx/xlab-huanhuan.git
190 | python start.py
191 | ```
192 | ### 重新训练
193 |
194 | #### 环境搭建
195 |
196 | 1. clone 本项目
197 |
198 | ```bash
199 | git clone https://github.com/KMnO4-zx/xlab-huanhuan.git
200 | cd xlab-huanhuan
201 | ```
202 |
203 | 2. 创建环境
204 |
205 | ```bash
206 | pip install -r requirements.txt
207 | ```
208 |
209 | >有两种微调方案,我们更推荐使用 XTuner 训练, XTuner 有各个模型的一键训练脚本,相对便捷。且对 InternLM2 的支持度最高。
210 |
211 | #### Transformers微调
212 | 使用 Transformers 的 Trainer 进行微调,具体脚本可参考[internlm2-chat-lora](./train/internlm2-chat-lora.ipynb),该脚本在`train`文件夹下。脚本内有较为详细的注释。
213 |
214 | #### XTuner微调
215 | 使用 XTuner 进行微调,具体脚本可参考[internlm2_chat_7b_qlora_oasst1_e3_copy.py](./train/internlm2_chat_7b_qlora_oasst1_e3_copy.py),该脚本在`train`文件夹下。脚本内有较为详细的注释。
216 |
217 |
218 | ### 部署
219 | #### OpenXLab 部署 Chat-嬛嬛
220 |
221 | 仅需要 Fork 本仓库,然后在 OpenXLab 上创建一个新的项目,将 Fork 的仓库与新建的项目关联,即可在 OpenXLab 上部署 Chat-嬛嬛。
222 |
223 | ***OpenXLab Chat嬛嬛 https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan***
224 |
225 | 
226 |
227 | #### LmDeploy部署
228 |
229 | - 首先安装LmDeploy
230 |
231 | ```shell
232 | pip install -U lmdeploy
233 | ```
234 |
235 | - 然后转换模型为`turbomind`格式
236 |
237 | > --dst-path: 可以指定转换后的模型存储位置。
238 |
239 | ```shell
240 | lmdeploy convert internlm2-chat-7b 要转化的模型地址 --dst-path 转换后的模型地址
241 | ```
242 |
243 | - LmDeploy Chat 对话
244 |
245 | ```shell
246 | lmdeploy chat turbomind 转换后的turbomind模型地址
247 | ```
248 | ### 测评与量化
249 | #### OpenCompass 评测
250 |
251 | - 安装 OpenCompass
252 |
253 | ```shell
254 | git clone https://github.com/open-compass/opencompass
255 | cd opencompass
256 | pip install -e .
257 | ```
258 |
259 | - 下载解压数据集
260 |
261 | ```shell
262 | cp /share/temp/datasets/OpenCompassData-core-20231110.zip /root/opencompass/
263 | unzip OpenCompassData-core-20231110.zip
264 | ```
265 |
266 | - 评测启动!
267 |
268 | ```shell
269 | python run.py \
270 | --datasets ceval_gen \
271 | --hf-path /root/model/huanhuan/kmno4zx/huanhuan-chat-internlm2 \
272 | --tokenizer-path /root/model/huanhuan/kmno4zx/huanhuan-chat-internlm2 \
273 | --tokenizer-kwargs padding_side='left' truncation='left' trust_remote_code=True \
274 | --model-kwargs device_map='auto' trust_remote_code=True \
275 | --max-seq-len 2048 \
276 | --max-out-len 16 \
277 | --batch-size 2 \
278 | --num-gpus 1 \
279 | --debug
280 | ```
281 |
282 | #### Lmdeploy&opencompass 量化以及量化评测
283 | ##### `W4`量化评测
284 |
285 | - `W4`量化
286 | ```shell
287 | lmdeploy lite auto_awq 要量化的模型地址 --work-dir 量化后的模型地址
288 | ```
289 | - 转化为`TurboMind`
290 | ```shell
291 | lmdeploy convert internlm2-chat-7b 量化后的模型地址 --model-format awq --group-size 128 --dst-path 转换后的模型地址
292 | ```
293 | - 评测`config`编写
294 | ```python
295 | from mmengine.config import read_base
296 | from opencompass.models.turbomind import TurboMindModel
297 |
298 | with read_base():
299 | # choose a list of datasets
300 | from .datasets.ceval.ceval_gen import ceval_datasets
301 | # and output the results in a choosen format
302 | # from .summarizers.medium import summarizer
303 |
304 | datasets = [*ceval_datasets]
305 |
306 | internlm2_chat_7b = dict(
307 | type=TurboMindModel,
308 | abbr='internlm2-chat-7b-turbomind',
309 | path='转换后的模型地址',
310 | engine_config=dict(session_len=512,
311 | max_batch_size=2,
312 | rope_scaling_factor=1.0),
313 | gen_config=dict(top_k=1,
314 | top_p=0.8,
315 | temperature=1.0,
316 | max_new_tokens=100),
317 | max_out_len=100,
318 | max_seq_len=512,
319 | batch_size=2,
320 | concurrency=1,
321 | # meta_template=internlm_meta_template,
322 | run_cfg=dict(num_gpus=1, num_procs=1),
323 | )
324 | models = [internlm2_chat_7b]
325 |
326 | ```
327 | - 评测启动!
328 | ```shell
329 | python run.py configs/eval_turbomind.py -w 指定结果保存路径
330 | ```
331 | ##### `KV Cache`量化评测
332 | - 转换为`TurboMind`
333 | ```shell
334 | lmdeploy convert internlm2-chat-7b 模型路径 --dst-path 转换后模型路径
335 | ```
336 | - 计算与获得量化参数
337 | ```shell
338 | # 计算
339 | lmdeploy lite calibrate 模型路径 --calib-dataset 'ptb' --calib-samples 128 --calib-seqlen 2048 --work-dir 参数保存路径
340 | # 获取量化参数
341 | lmdeploy lite kv_qparams 参数保存路径 转换后模型路径/triton_models/weights/ --num-tp 1
342 | ```
343 | - 更改`quant_policy`改成`4`,更改上述`config`里面的路径
344 | - 评测启动!
345 | ```shell
346 | python run.py configs/eval_turbomind.py -w 结果保存路径
347 | ```
348 | 结果文件可在同目录文件[results](./results)中获取
349 |
350 | ## 💕 致谢
351 |
352 | ### 项目成员
353 |
354 | - 宋志学-项目负责人 (Datawhale成员 书生·浦语实战营助教 负责项目规划,数据集制作及模型训练)
355 | - 肖鸿儒(Datawhale成员 书生·浦语实战营助教 负责数据集收集、模型评测)
356 | - 邹雨衡(Datawhale成员 负责数据集收集)
357 | - 杜森(Datawhale成员 负责数据集收集)
358 |
359 | ### 特别感谢
360 |
361 |
362 |
363 | ***感谢上海人工智能实验室组织的 书生·浦语实战营 学习活动~***
364 |
365 | ***感谢 OpenXLab 对项目部署的算力支持~***
366 |
367 | ***感谢 浦语小助手 对项目的支持~***
368 |
369 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # 导入所需的库
2 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
3 | import torch
4 | import streamlit as st
5 |
6 | from modelscope import snapshot_download
7 |
# Sidebar: app title plus links to the related projects.
with st.sidebar:
    st.markdown("## InternLM LLM")
    "[InternLM](https://github.com/InternLM/InternLM.git)"
    "[开源大模型食用指南 self-llm](https://github.com/datawhalechina/self-llm.git)"
    "[Chat嬛嬛](https://github.com/KMnO4-zx/huanhuan-chat.git)"
    # Slider selecting the maximum length, range 0-1024, default 512.
    # NOTE(review): max_length is never passed to model.chat() below — confirm
    # whether it was meant to cap generation length.
    max_length = st.slider("max_length", 0, 1024, 512, step=1)
    # System prompt forwarded to model.chat() as meta_instruction
    # (role-play instruction for the Zhen Huan persona).
    system_prompt = st.text_input("System_Prompt", "现在你要扮演皇帝身边的女人--甄嬛")

# Page title and subtitle.
st.title("💬 InternLM2-Chat-7B 嬛嬛版")
st.caption("🚀 A streamlit chatbot powered by InternLM2 QLora")

# Model repository id on ModelScope.

model_id = 'kmno4zx/huanhuan-chat-internlm2'

# Download (or reuse cached) weights; returns the local path that
# get_model() loads the tokenizer and model from.
mode_name_or_path = snapshot_download(model_id, revision='master')
27 |
28 |
# Build the tokenizer/model pair once; st.cache_resource reuses it across reruns.
@st.cache_resource
def get_model():
    """Load the tokenizer and the causal-LM weights onto the GPU.

    Returns:
        tuple: ``(tokenizer, model)`` with the model in eval mode on CUDA.
    """
    tok = AutoTokenizer.from_pretrained(
        mode_name_or_path,
        trust_remote_code=True,
    )
    llm = AutoModelForCausalLM.from_pretrained(
        mode_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    ).cuda()
    llm.eval()  # inference only — disable dropout and friends
    return tok, llm
38 |
# Fetch the cached tokenizer and model.
tokenizer, model = get_model()

# First run of the session: start with an empty conversation history.
if "messages" not in st.session_state:
    st.session_state["messages"] = []

# Replay every previous (user, assistant) turn so the transcript persists
# across Streamlit reruns.
for user_turn, assistant_turn in st.session_state.messages:
    st.chat_message("user").write(user_turn)
    st.chat_message("assistant").write(assistant_turn)

# Handle a new message from the chat input box, if any.
user_input = st.chat_input()
if user_input:
    # Echo the user's message.
    st.chat_message("user").write(user_input)
    # Generate a reply, conditioning on the system prompt and prior turns.
    response, history = model.chat(tokenizer, user_input, meta_instruction=system_prompt, history=st.session_state.messages)
    # Persist this turn for future reruns.
    st.session_state.messages.append((user_input, response))
    # Show the model's reply.
    st.chat_message("assistant").write(response)
--------------------------------------------------------------------------------
/images/Extract-Dialogue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/Extract-Dialogue.png
--------------------------------------------------------------------------------
/images/chat嬛嬛.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/chat嬛嬛.png
--------------------------------------------------------------------------------
/images/compass_support.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/huanhuan_chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/huanhuan_chat.png
--------------------------------------------------------------------------------
/images/huanhuan_img.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/huanhuan_img.png
--------------------------------------------------------------------------------
/images/license.svg:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/logo.png
--------------------------------------------------------------------------------
/images/modelscope.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/modelscope.png
--------------------------------------------------------------------------------
/images/modelscope_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/modelscope_logo.png
--------------------------------------------------------------------------------
/images/openxlab.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/openxlab.png
--------------------------------------------------------------------------------
/images/openxlab_model.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/openxlab_model.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | modelscope==1.9.5
2 | transformers==4.36.2
3 | streamlit==1.24.0
4 | sentencepiece==0.1.99
5 | accelerate==0.24.1
6 | transformers_stream_generator==0.0.4
7 | tiktoken
--------------------------------------------------------------------------------
/results/huanhuan_lianghua.csv:
--------------------------------------------------------------------------------
1 | dataset,version,metric,mode,internlm2-chat-7b-turbomind
2 | --------- 考试 Exam ---------,-,-,-,-
3 | ceval,-,naive_average,gen,22.37
4 | agieval,-,-,-,-
5 | mmlu,-,-,-,-
6 | GaokaoBench,-,-,-,-
7 | ARC-c,-,-,-,-
8 | --------- 语言 Language ---------,-,-,-,-
9 | WiC,-,-,-,-
10 | summedits,-,-,-,-
11 | chid-dev,-,-,-,-
12 | afqmc-dev,-,-,-,-
13 | bustm-dev,-,-,-,-
14 | cluewsc-dev,-,-,-,-
15 | WSC,-,-,-,-
16 | winogrande,-,-,-,-
17 | flores_100,-,-,-,-
18 | --------- 知识 Knowledge ---------,-,-,-,-
19 | BoolQ,-,-,-,-
20 | commonsense_qa,-,-,-,-
21 | nq,-,-,-,-
22 | triviaqa,-,-,-,-
23 | --------- 推理 Reasoning ---------,-,-,-,-
24 | cmnli,-,-,-,-
25 | ocnli,-,-,-,-
26 | ocnli_fc-dev,-,-,-,-
27 | AX_b,-,-,-,-
28 | AX_g,-,-,-,-
29 | CB,-,-,-,-
30 | RTE,-,-,-,-
31 | story_cloze,-,-,-,-
32 | COPA,-,-,-,-
33 | ReCoRD,-,-,-,-
34 | hellaswag,-,-,-,-
35 | piqa,-,-,-,-
36 | siqa,-,-,-,-
37 | strategyqa,-,-,-,-
38 | math,-,-,-,-
39 | gsm8k,-,-,-,-
40 | TheoremQA,-,-,-,-
41 | openai_humaneval,-,-,-,-
42 | mbpp,-,-,-,-
43 | bbh,-,-,-,-
44 | --------- 理解 Understanding ---------,-,-,-,-
45 | C3,-,-,-,-
46 | CMRC_dev,-,-,-,-
47 | DRCD_dev,-,-,-,-
48 | MultiRC,-,-,-,-
49 | race-middle,-,-,-,-
50 | race-high,-,-,-,-
51 | openbookqa_fact,-,-,-,-
52 | csl_dev,-,-,-,-
53 | lcsts,-,-,-,-
54 | Xsum,-,-,-,-
55 | eprstmt-dev,-,-,-,-
56 | lambada,-,-,-,-
57 | tnews-dev,-,-,-,-
58 |
--------------------------------------------------------------------------------
/results/huanhuan_ori.csv:
--------------------------------------------------------------------------------
1 | dataset,version,metric,mode,opencompass.models.huggingface.HuggingFace_kmno4zx_huanhuan-chat-internlm2
2 | ceval-computer_network,db9ce2,accuracy,gen,52.63
3 | ceval-operating_system,1c2571,accuracy,gen,57.89
4 | ceval-computer_architecture,a74dad,accuracy,gen,66.67
5 | ceval-college_programming,4ca32a,accuracy,gen,59.46
6 | ceval-college_physics,963fa8,accuracy,gen,36.84
7 | ceval-college_chemistry,e78857,accuracy,gen,41.67
8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,26.32
9 | ceval-probability_and_statistics,65e812,accuracy,gen,27.78
10 | ceval-discrete_mathematics,e894ae,accuracy,gen,25.00
11 | ceval-electrical_engineer,ae42b9,accuracy,gen,32.43
12 | ceval-metrology_engineer,ee34ea,accuracy,gen,58.33
13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,22.22
14 | ceval-high_school_physics,adf25f,accuracy,gen,52.63
15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,52.63
16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58
17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,52.63
18 | ceval-middle_school_biology,86817c,accuracy,gen,66.67
19 | ceval-middle_school_physics,8accf6,accuracy,gen,57.89
20 | ceval-middle_school_chemistry,167a15,accuracy,gen,100.00
21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,52.17
22 | ceval-college_economics,f3f4e6,accuracy,gen,45.45
23 | ceval-business_administration,c1614e,accuracy,gen,48.48
24 | ceval-marxism,cf874c,accuracy,gen,73.68
25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,66.67
26 | ceval-education_science,591fee,accuracy,gen,72.41
27 | ceval-teacher_qualification,4e4ced,accuracy,gen,79.55
28 | ceval-high_school_politics,5c0de2,accuracy,gen,63.16
29 | ceval-high_school_geography,865461,accuracy,gen,52.63
30 | ceval-middle_school_politics,5be3e7,accuracy,gen,57.14
31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00
32 | ceval-modern_chinese_history,fc01af,accuracy,gen,43.48
33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,84.21
34 | ceval-logic,f5b022,accuracy,gen,59.09
35 | ceval-law,a110a1,accuracy,gen,25.00
36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,52.17
37 | ceval-art_studies,2a1300,accuracy,gen,48.48
38 | ceval-professional_tour_guide,4e673e,accuracy,gen,68.97
39 | ceval-legal_professional,ce8787,accuracy,gen,30.43
40 | ceval-high_school_chinese,315705,accuracy,gen,57.89
41 | ceval-high_school_history,7eb30a,accuracy,gen,80.00
42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18
43 | ceval-civil_servant,87d061,accuracy,gen,51.06
44 | ceval-sports_science,70f27b,accuracy,gen,52.63
45 | ceval-plant_protection,8941f9,accuracy,gen,86.36
46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89
47 | ceval-clinical_medicine,49e82d,accuracy,gen,45.45
48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,67.39
49 | ceval-accountant,002837,accuracy,gen,46.94
50 | ceval-fire_engineer,bc23f5,accuracy,gen,38.71
51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,48.39
52 | ceval-tax_accountant,3a5e3c,accuracy,gen,48.98
53 | ceval-physician,6e277d,accuracy,gen,51.02
54 | ceval-stem,-,naive_average,gen,48.67
55 | ceval-social-science,-,naive_average,gen,60.92
56 | ceval-humanities,-,naive_average,gen,56.17
57 | ceval-other,-,naive_average,gen,54.08
58 | ceval-hard,-,naive_average,gen,35.64
59 | ceval,-,naive_average,gen,53.76
60 |
--------------------------------------------------------------------------------
/results/internlm2_kv.csv:
--------------------------------------------------------------------------------
1 | dataset,version,metric,mode,internlm2-chat-7b-turbomind
2 | ceval-computer_network,db9ce2,accuracy,gen,47.37
3 | ceval-operating_system,1c2571,accuracy,gen,63.16
4 | ceval-computer_architecture,a74dad,accuracy,gen,42.86
5 | ceval-college_programming,4ca32a,accuracy,gen,24.32
6 | ceval-college_physics,963fa8,accuracy,gen,15.79
7 | ceval-college_chemistry,e78857,accuracy,gen,0.00
8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,10.53
9 | ceval-probability_and_statistics,65e812,accuracy,gen,11.11
10 | ceval-discrete_mathematics,e894ae,accuracy,gen,12.50
11 | ceval-electrical_engineer,ae42b9,accuracy,gen,18.92
12 | ceval-metrology_engineer,ee34ea,accuracy,gen,37.50
13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,5.56
14 | ceval-high_school_physics,adf25f,accuracy,gen,21.05
15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,21.05
16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58
17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,31.58
18 | ceval-middle_school_biology,86817c,accuracy,gen,71.43
19 | ceval-middle_school_physics,8accf6,accuracy,gen,52.63
20 | ceval-middle_school_chemistry,167a15,accuracy,gen,75.00
21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,43.48
22 | ceval-college_economics,f3f4e6,accuracy,gen,25.45
23 | ceval-business_administration,c1614e,accuracy,gen,27.27
24 | ceval-marxism,cf874c,accuracy,gen,84.21
25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,75.00
26 | ceval-education_science,591fee,accuracy,gen,62.07
27 | ceval-teacher_qualification,4e4ced,accuracy,gen,75.00
28 | ceval-high_school_politics,5c0de2,accuracy,gen,21.05
29 | ceval-high_school_geography,865461,accuracy,gen,57.89
30 | ceval-middle_school_politics,5be3e7,accuracy,gen,47.62
31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00
32 | ceval-modern_chinese_history,fc01af,accuracy,gen,69.57
33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,89.47
34 | ceval-logic,f5b022,accuracy,gen,36.36
35 | ceval-law,a110a1,accuracy,gen,29.17
36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,47.83
37 | ceval-art_studies,2a1300,accuracy,gen,66.67
38 | ceval-professional_tour_guide,4e673e,accuracy,gen,79.31
39 | ceval-legal_professional,ce8787,accuracy,gen,17.39
40 | ceval-high_school_chinese,315705,accuracy,gen,36.84
41 | ceval-high_school_history,7eb30a,accuracy,gen,75.00
42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18
43 | ceval-civil_servant,87d061,accuracy,gen,29.79
44 | ceval-sports_science,70f27b,accuracy,gen,57.89
45 | ceval-plant_protection,8941f9,accuracy,gen,63.64
46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89
47 | ceval-clinical_medicine,49e82d,accuracy,gen,45.45
48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,56.52
49 | ceval-accountant,002837,accuracy,gen,26.53
50 | ceval-fire_engineer,bc23f5,accuracy,gen,16.13
51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,41.94
52 | ceval-tax_accountant,3a5e3c,accuracy,gen,32.65
53 | ceval-physician,6e277d,accuracy,gen,55.10
54 |
--------------------------------------------------------------------------------
/results/internlm2_ori.csv:
--------------------------------------------------------------------------------
1 | dataset,version,metric,mode,opencompass.models.huggingface.HuggingFace_kmno4zx_huanhuan-chat-internlm2
2 | ceval-computer_network,db9ce2,accuracy,gen,52.63
3 | ceval-operating_system,1c2571,accuracy,gen,57.89
4 | ceval-computer_architecture,a74dad,accuracy,gen,66.67
5 | ceval-college_programming,4ca32a,accuracy,gen,59.46
6 | ceval-college_physics,963fa8,accuracy,gen,36.84
7 | ceval-college_chemistry,e78857,accuracy,gen,41.67
8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,26.32
9 | ceval-probability_and_statistics,65e812,accuracy,gen,27.78
10 | ceval-discrete_mathematics,e894ae,accuracy,gen,25.00
11 | ceval-electrical_engineer,ae42b9,accuracy,gen,32.43
12 | ceval-metrology_engineer,ee34ea,accuracy,gen,58.33
13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,22.22
14 | ceval-high_school_physics,adf25f,accuracy,gen,52.63
15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,52.63
16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58
17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,52.63
18 | ceval-middle_school_biology,86817c,accuracy,gen,66.67
19 | ceval-middle_school_physics,8accf6,accuracy,gen,57.89
20 | ceval-middle_school_chemistry,167a15,accuracy,gen,100.00
21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,52.17
22 | ceval-college_economics,f3f4e6,accuracy,gen,45.45
23 | ceval-business_administration,c1614e,accuracy,gen,48.48
24 | ceval-marxism,cf874c,accuracy,gen,73.68
25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,66.67
26 | ceval-education_science,591fee,accuracy,gen,72.41
27 | ceval-teacher_qualification,4e4ced,accuracy,gen,79.55
28 | ceval-high_school_politics,5c0de2,accuracy,gen,63.16
29 | ceval-high_school_geography,865461,accuracy,gen,52.63
30 | ceval-middle_school_politics,5be3e7,accuracy,gen,57.14
31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00
32 | ceval-modern_chinese_history,fc01af,accuracy,gen,43.48
33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,84.21
34 | ceval-logic,f5b022,accuracy,gen,59.09
35 | ceval-law,a110a1,accuracy,gen,25.00
36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,52.17
37 | ceval-art_studies,2a1300,accuracy,gen,48.48
38 | ceval-professional_tour_guide,4e673e,accuracy,gen,68.97
39 | ceval-legal_professional,ce8787,accuracy,gen,30.43
40 | ceval-high_school_chinese,315705,accuracy,gen,57.89
41 | ceval-high_school_history,7eb30a,accuracy,gen,80.00
42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18
43 | ceval-civil_servant,87d061,accuracy,gen,51.06
44 | ceval-sports_science,70f27b,accuracy,gen,52.63
45 | ceval-plant_protection,8941f9,accuracy,gen,86.36
46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89
47 | ceval-clinical_medicine,49e82d,accuracy,gen,45.45
48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,67.39
49 | ceval-accountant,002837,accuracy,gen,46.94
50 | ceval-fire_engineer,bc23f5,accuracy,gen,38.71
51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,48.39
52 | ceval-tax_accountant,3a5e3c,accuracy,gen,48.98
53 | ceval-physician,6e277d,accuracy,gen,51.02
54 | ceval-stem,-,naive_average,gen,48.67
55 | ceval-social-science,-,naive_average,gen,60.92
56 | ceval-humanities,-,naive_average,gen,56.17
57 | ceval-other,-,naive_average,gen,54.08
58 | ceval-hard,-,naive_average,gen,35.64
59 | ceval,-,naive_average,gen,53.76
60 |
--------------------------------------------------------------------------------
/start.py:
--------------------------------------------------------------------------------
"""Launcher: start the Streamlit chat app on 0.0.0.0:7860."""
import subprocess

if __name__ == '__main__':
    # Use an argument list with shell=False instead of os.system so the
    # command line is not re-parsed by a shell, and check=True so a failed
    # launch surfaces as an exception instead of being silently ignored.
    subprocess.run(
        ['streamlit', 'run', 'app.py',
         '--server.address=0.0.0.0', '--server.port', '7860'],
        check=True,
    )
--------------------------------------------------------------------------------
/train/internlm2-chat-lora.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "## 导入必要的包"
8 | ]
9 | },
10 | {
11 | "cell_type": "code",
12 | "execution_count": null,
13 | "metadata": {},
14 | "outputs": [],
15 | "source": [
16 | "from datasets import Dataset\n",
17 | "import pandas as pd\n",
18 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig\n",
19 | "import torch\n",
20 | "from typing import List, Optional, Tuple, Union"
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": null,
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "# 使用datasets读取数据\n",
30 | "df = pd.read_json('/root/data/huanhuan.json')\n",
31 | "ds = Dataset.from_pandas(df)"
32 | ]
33 | },
34 | {
35 | "cell_type": "markdown",
36 | "metadata": {},
37 | "source": [
38 | "## 处理数据集"
39 | ]
40 | },
41 | {
42 | "cell_type": "code",
43 | "execution_count": null,
44 | "metadata": {},
45 | "outputs": [],
46 | "source": [
47 | "tokenizer = AutoTokenizer.from_pretrained(\"/root/share/model_repos/internlm2-chat-7b\", use_fast=False, trust_remote_code=True)\n",
48 | "tokenizer"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": null,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "def build_inputs(query: str, history: List[Tuple[str, str]] = [], meta_instruction=\"我是系统\"):\n",
58 | " prompt = \"\"\n",
59 | " if meta_instruction:\n",
60 | " prompt += f\"\"\"[UNUSED_TOKEN_146]system\\n{meta_instruction}[UNUSED_TOKEN_145]\\n\"\"\"\n",
61 | " else:\n",
62 | " prompt += \"\"\n",
63 | " for record in history:\n",
64 | " prompt += f\"\"\"[UNUSED_TOKEN_146]user\\n{record[0]}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n{record[1]}[UNUSED_TOKEN_145]\\n\"\"\"\n",
65 | " prompt += f\"\"\"[UNUSED_TOKEN_146]user\\n{query}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n\"\"\"\n",
66 | " return prompt"
67 | ]
68 | },
69 | {
70 | "cell_type": "code",
71 | "execution_count": null,
72 | "metadata": {},
73 | "outputs": [],
74 | "source": [
75 | "build_inputs('你好')"
76 | ]
77 | },
78 | {
79 | "cell_type": "code",
80 | "execution_count": null,
81 | "metadata": {},
82 | "outputs": [],
83 | "source": [
84 | "def process_func(example):\n",
85 | "\n",
86 | " system_prompt = \"现在你要扮演皇帝身边的女人--甄嬛\"\n",
87 | "\n",
88 | " MAX_LENGTH = 512 # 分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性\n",
89 | " input_ids, attention_mask, labels = [], [], []\n",
90 | "instruction = tokenizer(f\"[UNUSED_TOKEN_146]system\\n{system_prompt}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]user\\n{example['instruction']}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n\", add_special_tokens=False) # add_special_tokens 不在开头加 special_tokens\n",
91 | " response = tokenizer(f\"{example['output']}[UNUSED_TOKEN_145]\", add_special_tokens=False)\n",
92 | " input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n",
93 | " attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1] # 因为eos token咱们也是要关注的所以 补充为1\n",
94 | " labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id] \n",
95 | " if len(input_ids) > MAX_LENGTH: # 做一个截断\n",
96 | " input_ids = input_ids[:MAX_LENGTH]\n",
97 | " attention_mask = attention_mask[:MAX_LENGTH]\n",
98 | " labels = labels[:MAX_LENGTH]\n",
99 | " return {\n",
100 | " \"input_ids\": input_ids,\n",
101 | " \"attention_mask\": attention_mask,\n",
102 | " \"labels\": labels\n",
103 | " }"
104 | ]
105 | },
106 | {
107 | "cell_type": "code",
108 | "execution_count": null,
109 | "metadata": {},
110 | "outputs": [],
111 | "source": [
112 | "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n",
113 | "tokenized_id"
114 | ]
115 | },
116 | {
117 | "cell_type": "code",
118 | "execution_count": null,
119 | "metadata": {},
120 | "outputs": [],
121 | "source": [
122 | "tokenizer.decode(tokenized_id[0]['input_ids'])"
123 | ]
124 | },
125 | {
126 | "cell_type": "code",
127 | "execution_count": null,
128 | "metadata": {},
129 | "outputs": [],
130 | "source": [
131 | "tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1][\"labels\"])))"
132 | ]
133 | },
134 | {
135 | "cell_type": "markdown",
136 | "metadata": {},
137 | "source": [
138 | "## 创建模型"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": null,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "model = AutoModelForCausalLM.from_pretrained(\n",
148 | " \"/root/share/model_repos/internlm2-chat-7b\", \n",
149 | " torch_dtype=torch.half, \n",
150 | " trust_remote_code=True,\n",
151 | " device_map={'':0},\n",
152 | " low_cpu_mem_usage=True, # 是否使用低CPU内存\n",
153 | " load_in_4bit=True, # 是否在4位精度下加载模型。如果设置为True,则在4位精度下加载模型。\n",
154 | " bnb_4bit_compute_dtype=torch.half, # 4位精度计算的数据类型。这里设置为torch.half,表示使用半精度浮点数。\n",
155 | " bnb_4bit_quant_type=\"nf4\", # 4位精度量化的类型。这里设置为\"nf4\",表示使用nf4量化类型。\n",
156 | " bnb_4bit_use_double_quant=True # 是否使用双精度量化。如果设置为True,则使用双精度量化。\n",
157 | " )\n",
158 | "model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法\n",
159 | "model"
160 | ]
161 | },
162 | {
163 | "cell_type": "code",
164 | "execution_count": null,
165 | "metadata": {},
166 | "outputs": [],
167 | "source": [
168 | "model.dtype"
169 | ]
170 | },
171 | {
172 | "cell_type": "markdown",
173 | "metadata": {},
174 | "source": [
175 | "## Lora"
176 | ]
177 | },
178 | {
179 | "cell_type": "code",
180 | "execution_count": null,
181 | "metadata": {},
182 | "outputs": [],
183 | "source": [
184 | "from peft import LoraConfig, TaskType, get_peft_model\n",
185 | "\n",
186 | "config = LoraConfig(\n",
187 | " task_type=TaskType.CAUSAL_LM, \n",
188 | " target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3'],\n",
189 | " inference_mode=False, # 训练模式\n",
190 | " r=8, # Lora 秩\n",
191 | " lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理\n",
192 | " lora_dropout=0.1# Dropout 比例\n",
193 | ")\n",
194 | "config"
195 | ]
196 | },
197 | {
198 | "cell_type": "code",
199 | "execution_count": null,
200 | "metadata": {},
201 | "outputs": [],
202 | "source": [
203 | "model = get_peft_model(model, config)\n",
204 | "config"
205 | ]
206 | },
207 | {
208 | "cell_type": "code",
209 | "execution_count": null,
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "model.print_trainable_parameters()"
214 | ]
215 | },
216 | {
217 | "cell_type": "markdown",
218 | "metadata": {},
219 | "source": [
220 | "## 配置训练参数"
221 | ]
222 | },
223 | {
224 | "cell_type": "code",
225 | "execution_count": null,
226 | "metadata": {},
227 | "outputs": [],
228 | "source": [
229 | "args = TrainingArguments(\n",
230 | " output_dir=\"./output/math-internlm2-chat-7b\",\n",
231 | " per_device_train_batch_size=4,\n",
232 | " gradient_accumulation_steps=16,\n",
233 | " logging_steps=10,\n",
234 | " num_train_epochs=3,\n",
235 | " save_steps=100,\n",
236 | " learning_rate=1e-5,\n",
237 | " save_on_each_node=True,\n",
238 | " optim=\"paged_adamw_32bit\",\n",
239 | " gradient_checkpointing=True\n",
240 | ")"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": null,
246 | "metadata": {},
247 | "outputs": [],
248 | "source": [
249 | "trainer = Trainer(\n",
250 | " model=model,\n",
251 | " args=args,\n",
252 | " train_dataset=tokenized_id,\n",
253 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n",
254 | ")"
255 | ]
256 | },
257 | {
258 | "cell_type": "code",
259 | "execution_count": null,
260 | "metadata": {},
261 | "outputs": [],
262 | "source": [
263 | "trainer.train()"
264 | ]
265 | }
266 | ],
267 | "metadata": {
268 | "kernelspec": {
269 | "display_name": "internlm-demo",
270 | "language": "python",
271 | "name": "python3"
272 | },
273 | "language_info": {
274 | "codemirror_mode": {
275 | "name": "ipython",
276 | "version": 3
277 | },
278 | "file_extension": ".py",
279 | "mimetype": "text/x-python",
280 | "name": "python",
281 | "nbconvert_exporter": "python",
282 | "pygments_lexer": "ipython3",
283 | "version": "3.10.13"
284 | }
285 | },
286 | "nbformat": 4,
287 | "nbformat_minor": 2
288 | }
289 |
--------------------------------------------------------------------------------
/train/internlm2_1_8b_full_oasst1_e3_huanhuan.py:
--------------------------------------------------------------------------------
# Auto-generated (mmengine dump) XTuner config: FULL-parameter fine-tune of
# internlm2-chat-1_8b-sft on the huanhuan role-play dataset.
# Every `type` value is an import-path string that mmengine resolves when the
# config is built; the repeated nested tokenizer/dataset dicts are an artifact
# of the dump and intentionally duplicate the top-level definitions.

# System prompt injected into evaluation (and baked into the dataset template).
SYSTEM = '现在你要扮演皇帝身边的女人--甄嬛'
# Gradient accumulation steps (effective batch = batch_size * accumulative_counts).
accumulative_counts = 4
# Per-device batch size.
batch_size = 1
# AdamW beta coefficients.
betas = (
    0.9,
    0.999,
)
# Extra hooks: dataset preview, periodic chat evaluation, throughput logging.
custom_hooks = [
    dict(
        tokenizer=dict(
            padding_side='right',
            pretrained_model_name_or_path=
            '/root/model/internlm/internlm2-chat-1_8b-sft',
            trust_remote_code=True,
            type='transformers.AutoTokenizer.from_pretrained'),
        type='xtuner.engine.DatasetInfoHook'),
    dict(
        evaluation_inputs=[
            '你好',
            '小主,敬事房传来消息,说皇上晚上去华妃那儿。',
        ],
        every_n_iters=500,
        prompt_template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
        system='现在你要扮演皇帝身边的女人--甄嬛',
        tokenizer=dict(
            padding_side='right',
            pretrained_model_name_or_path=
            '/root/model/internlm/internlm2-chat-1_8b-sft',
            trust_remote_code=True,
            type='transformers.AutoTokenizer.from_pretrained'),
        type='xtuner.engine.EvaluateChatHook'),
    dict(type='xtuner.engine.ThroughputHook'),
]
# Training data in XTuner's conversation JSON format.
data_path = '/root/data/huanhuan_xtuner.json'
dataloader_num_workers = 0
# Standard mmengine hooks: checkpoint every epoch, log every 10 iterations.
default_hooks = dict(
    checkpoint=dict(interval=1, type='mmengine.hooks.CheckpointHook'),
    logger=dict(interval=10, type='mmengine.hooks.LoggerHook'),
    param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
    sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
    timer=dict(type='mmengine.hooks.IterTimerHook'))
# Runtime environment: NCCL distributed backend, cudnn benchmark disabled.
env_cfg = dict(
    cudnn_benchmark=False,
    dist_cfg=dict(backend='nccl'),
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0))
# Run the chat-evaluation hook every 500 iterations.
evaluation_freq = 500
evaluation_inputs = [
    '你好',
    '小主,敬事房传来消息,说皇上晚上去华妃那儿。',
]
launcher = 'none'
load_from = None
log_level = 'INFO'
lr = 2e-05
max_epochs = 3
# Max tokens per (packed) training sample.
max_length = 2048
# Gradient-clipping max norm.
max_norm = 1
# Full fine-tune: no LoRA / quantization section in the model config.
model = dict(
    llm=dict(
        pretrained_model_name_or_path=
        '/root/model/internlm/internlm2-chat-1_8b-sft',
        trust_remote_code=True,
        type='transformers.AutoModelForCausalLM.from_pretrained'),
    type='xtuner.model.SupervisedFinetune')
optim_type = 'torch.optim.AdamW'
# fp16 AMP optimizer wrapper with dynamic loss scaling and gradient clipping.
optim_wrapper = dict(
    accumulative_counts=4,
    clip_grad=dict(error_if_nonfinite=False, max_norm=1),
    dtype='float16',
    loss_scale='dynamic',
    optimizer=dict(
        betas=(
            0.9,
            0.999,
        ),
        lr=2e-05,
        type='torch.optim.AdamW',
        weight_decay=0),
    type='mmengine.optim.AmpOptimWrapper')
pack_to_max_length = True
# LR schedule: linear warm-up for the first 0.09 epochs
# (= warmup_ratio 0.03 * 3 epochs), then cosine decay to 0.
param_scheduler = [
    dict(
        begin=0,
        by_epoch=True,
        convert_to_iter_based=True,
        end=0.09,
        start_factor=1e-05,
        type='mmengine.optim.LinearLR'),
    dict(
        T_max=3,
        begin=0.09,
        by_epoch=True,
        convert_to_iter_based=True,
        eta_min=0.0,
        type='mmengine.optim.CosineAnnealingLR'),
]
pretrained_model_name_or_path = '/root/model/internlm/internlm2-chat-1_8b-sft'
prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.internlm2_chat'
randomness = dict(deterministic=False, seed=None)
resume = False
tokenizer = dict(
    padding_side='right',
    pretrained_model_name_or_path=
    '/root/model/internlm/internlm2-chat-1_8b-sft',
    trust_remote_code=True,
    type='transformers.AutoTokenizer.from_pretrained')
train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1)
train_dataloader = dict(
    batch_size=1,
    collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'),
    dataset=dict(
        dataset=dict(
            data_files=dict(train='/root/data/huanhuan_xtuner.json'),
            path='json',
            type='datasets.load_dataset'),
        dataset_map_fn=None,
        max_length=2048,
        pack_to_max_length=True,
        remove_unused_columns=True,
        shuffle_before_pack=True,
        template_map_fn=dict(
            template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
            type='xtuner.dataset.map_fns.template_map_fn_factory'),
        tokenizer=dict(
            padding_side='right',
            pretrained_model_name_or_path=
            '/root/model/internlm/internlm2-chat-1_8b-sft',
            trust_remote_code=True,
            type='transformers.AutoTokenizer.from_pretrained'),
        type='xtuner.dataset.process_hf_dataset'),
    num_workers=0,
    sampler=dict(shuffle=True, type='mmengine.dataset.DefaultSampler'))
# Same dataset definition as train_dataloader.dataset (dump duplication).
train_dataset = dict(
    dataset=dict(
        data_files=dict(train='/root/data/huanhuan_xtuner.json'),
        path='json',
        type='datasets.load_dataset'),
    dataset_map_fn=None,
    max_length=2048,
    pack_to_max_length=True,
    remove_unused_columns=True,
    shuffle_before_pack=True,
    template_map_fn=dict(
        template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat',
        type='xtuner.dataset.map_fns.template_map_fn_factory'),
    tokenizer=dict(
        padding_side='right',
        pretrained_model_name_or_path=
        '/root/model/internlm/internlm2-chat-1_8b-sft',
        trust_remote_code=True,
        type='transformers.AutoTokenizer.from_pretrained'),
    type='xtuner.dataset.process_hf_dataset')
visualizer = None
warmup_ratio = 0.03
weight_decay = 0
work_dir = './work_dirs/internlm2_1_8b_full_oasst1_e3_huanhuan'
157 |
--------------------------------------------------------------------------------
/train/internlm2_chat_7b_qlora_oasst1_e3_copy.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from datasets import load_dataset
4 | from mmengine.dataset import DefaultSampler
5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
6 | LoggerHook, ParamSchedulerHook)
7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
8 | from peft import LoraConfig
9 | from torch.optim import AdamW
10 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
11 | BitsAndBytesConfig)
12 |
13 | from xtuner.dataset import process_hf_dataset
14 | from xtuner.dataset.collate_fns import default_collate_fn
15 | from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory
16 | from xtuner.engine import DatasetInfoHook, EvaluateChatHook
17 | from xtuner.model import SupervisedFinetune
18 | from xtuner.utils import PROMPT_TEMPLATE
19 |
#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model: base checkpoint to be QLoRA fine-tuned.
pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b'

# Data: single-turn role-play dataset in XTuner's conversation JSON format.
data_path = '/root/data/huanhuan_xtuner.json'
prompt_template = PROMPT_TEMPLATE.internlm2_chat
max_length = 512  # max tokens per (packed) training sample
pack_to_max_length = True  # concatenate short samples up to max_length

# Scheduler & Optimizer
batch_size = 1  # per_device
accumulative_counts = 16  # effective batch = batch_size * accumulative_counts
dataloader_num_workers = 0
max_epochs = 3
optim_type = AdamW
lr = 2e-4
betas = (0.9, 0.999)
weight_decay = 0
max_norm = 1  # grad clip
warmup_ratio = 0.03  # fraction of training spent in linear LR warm-up

# Evaluate the generation performance during the training
evaluation_freq = 90  # generate sample answers every 90 iterations
SYSTEM = '现在你要扮演皇帝身边的女人--甄嬛'
evaluation_inputs = [
    '你好', '你是谁?'
]
50 |
#######################################################################
#                      PART 2  Model & Tokenizer                      #
#######################################################################
# Lazy (dict-style) config: XTuner/mmengine instantiates these at build time.
tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=pretrained_model_name_or_path,
    trust_remote_code=True,
    padding_side='right')

# QLoRA: base weights loaded 4-bit (NF4, double quant), fp16 compute,
# with LoRA adapters trained on top via SupervisedFinetune.
model = dict(
    type=SupervisedFinetune,
    llm=dict(
        type=AutoModelForCausalLM.from_pretrained,
        pretrained_model_name_or_path=pretrained_model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=dict(
            type=BitsAndBytesConfig,
            load_in_4bit=True,
            load_in_8bit=False,
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type='nf4')),
    lora=dict(
        type=LoraConfig,
        r=64,  # LoRA rank
        lora_alpha=16,  # LoRA scaling factor
        lora_dropout=0.1,
        bias='none',
        task_type='CAUSAL_LM'))
83 |
#######################################################################
#                     PART 3  Dataset & Dataloader                    #
#######################################################################
train_dataset = dict(
    type=process_hf_dataset,
    dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)),
    tokenizer=tokenizer,
    max_length=max_length,
    # Data is already in XTuner's conversation schema, so no map fn is needed.
    dataset_map_fn=None,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    remove_unused_columns=True,
    shuffle_before_pack=True,
    pack_to_max_length=pack_to_max_length)

train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(type=DefaultSampler, shuffle=True),
    collate_fn=dict(type=default_collate_fn))
105 |
#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer: AMP wrapper handles fp16 loss scaling and grad accumulation.
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='float16')

# learning policy: linear warm-up for warmup_ratio of training,
# then cosine decay to zero.
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        T_max=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1)
140 |
#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    # Print a few decoded training samples so the template can be sanity-checked.
    dict(type=DatasetInfoHook, tokenizer=tokenizer),
    # Generate answers to `evaluation_inputs` every `evaluation_freq` iterations.
    dict(
        type=EvaluateChatHook,
        tokenizer=tokenizer,
        every_n_iters=evaluation_freq,
        evaluation_inputs=evaluation_inputs,
        system=SYSTEM,
        prompt_template=prompt_template)
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type=IterTimerHook),
    # print log every 10 iterations.
    logger=dict(type=LoggerHook, interval=10),
    # enable the parameter scheduler.
    param_scheduler=dict(type=ParamSchedulerHook),
    # save checkpoint per epoch.
    checkpoint=dict(type=CheckpointHook, interval=1),
    # set sampler seed in distributed environment.
    sampler_seed=dict(type=DistSamplerSeedHook),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,
    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
visualizer = None

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)
194 |
--------------------------------------------------------------------------------