├── .gitignore ├── README.md ├── app.py ├── data ├── huanhuan.json └── huanhuan_xtuner.json ├── images ├── Extract-Dialogue.png ├── chat嬛嬛.png ├── compass_support.svg ├── huanhuan_chat.png ├── huanhuan_img.png ├── license.svg ├── logo.png ├── modelscope.png ├── modelscope_logo.png ├── openxlab.png ├── openxlab_model.jpg └── tech_route.svg ├── requirements.txt ├── results ├── huanhuan_lianghua.csv ├── huanhuan_ori.csv ├── internlm2_kv.csv └── internlm2_ori.csv ├── start.py └── train ├── internlm2-chat-lora.ipynb ├── internlm2_1_8b_full_oasst1_e3_huanhuan.py └── internlm2_chat_7b_qlora_oasst1_e3_copy.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 
106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Chat-嬛嬛 2 |
3 | 4 | 5 |
6 | Chat-嬛嬛 7 |
8 | 9 | [![license][license-image]][license-url] 10 | [![evaluation][evaluation-image]][evaluation-url] 11 | 12 | [🤗HuggingFace]() | [![OpenXLab_Model][OpenXLab_Model-image]][OpenXLab_Model-url] | [ ModelScope][ModelScope-url] 13 | 14 | [![OpenXLab_App][OpenXLab_App-image]][OpenXLab_App-url] | [🆕Update News](#-news) | [🤔Reporting Issues][Issues-url] 丨 [![bilibili][bilibili-image]][bilibili-url] 15 | 16 | [English](./README_en-US.md) | [简体中文](./README.md) 17 | 18 | 19 | 20 | [license-image]: ./images/license.svg 21 | [evaluation-image]: ./images/compass_support.svg 22 | [OpenXLab_Model-image]: https://cdn-static.openxlab.org.cn/header/openxlab_models.svg 23 | [OpenXLab_App-image]: https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg 24 | [bilibili-image]: https://img.shields.io/badge/AMchat-bilibili-%23fb7299 25 | 26 | [license-url]: ./LICENSE 27 | [evaluation-url]: https://github.com/internLM/OpenCompass/ 28 | [OpenXLab_Model-url]: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2-1_8b 29 | [OpenXLab_App-url]: https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan 30 | [bilibili-url]: https://www.bilibili.com/video/——/ 31 | [ModelScope-url]: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary 32 | [Issues-url]: https://github.com/KMnO4-zx/xlab-huanhuan/issues 33 | 34 |
35 | 36 | ## 📝目录 37 | 38 | - [Chat-嬛嬛](#chat-嬛嬛) 39 | - [📝目录](#目录) 40 | - [📖 简介](#-简介) 41 | - [🔗 模型及体验地址](#-模型及体验地址) 42 | - [🚀 News](#-news) 43 | - [🧾 数据集](#-数据集) 44 | - [🛠️ 使用方法](#️-使用方法) 45 | - [快速开始](#快速开始) 46 | - [重新训练](#重新训练) 47 | - [环境搭建](#环境搭建) 48 | - [Transformers微调](#transformers微调) 49 | - [XTuner微调](#xtuner微调) 50 | - [部署](#部署) 51 | - [OpenXLab 部署 Chat-嬛嬛](#openxlab-部署-chat-嬛嬛) 52 | - [LmDeploy部署](#lmdeploy部署) 53 | - [测评与量化](#测评与量化) 54 | - [OpneCompass 评测](#opnecompass-评测) 55 | - [Lmdeploy\&opencompass 量化以及量化评测](#lmdeployopencompass-量化以及量化评测) 56 | - [`W4`量化评测](#w4量化评测) 57 | - [`KV Cache`量化评测](#kv-cache量化评测) 58 | - [💕 致谢](#-致谢) 59 | - [项目成员](#项目成员) 60 | - [特别感谢](#特别感谢) 61 | 62 | 63 | ## 📖 简介 64 | 65 | > *此仓库主要用于将 Chat嬛嬛 项目部署到 OpenXLab 或 ModelScope 。* 66 | 67 |   Chat-甄嬛是利用《甄嬛传》剧本中所有关于甄嬛的台词和语句,基于[InternLM2](https://github.com/InternLM/InternLM.git)进行LoRA微调或全量微调得到的模仿甄嬛语气的聊天语言模型。 68 | 69 | > 甄嬛,小说《后宫·甄嬛传》和电视剧《甄嬛传》中的女一号,核心女主角。原名甄玉嬛,嫌玉字俗气而改名甄嬛,为汉人甄远道之女,后被雍正赐姓钮祜禄氏,抬旗为满洲上三旗,获名“钮祜禄·甄嬛”。同沈眉庄、安陵容参加选秀,因容貌酷似纯元皇后而被选中。入宫后面对华妃的步步紧逼,沈眉庄被冤、安陵容变心,从偏安一隅的青涩少女变成了能引起血雨腥风的宫斗老手。雍正发现年氏一族的野心后令其父甄远道剪除,甄嬛也于后宫中用她的连环巧计帮皇帝解决政敌,故而深得雍正爱待。几经周折,终于斗垮了嚣张跋扈的华妃。甄嬛封妃时遭皇后宜修暗算,被皇上嫌弃,生下女儿胧月后心灰意冷,自请出宫为尼。然得果郡王爱慕,二人相爱,得知果郡王死讯后立刻设计与雍正再遇,风光回宫。此后甄父冤案平反、甄氏复起,她也生下双生子,在滴血验亲等各种阴谋中躲过宜修的暗害,最后以牺牲自己亲生胎儿的方式扳倒了幕后黑手的皇后。但雍正又逼甄嬛毒杀允礼,以测试甄嬛真心,并让已经生产过孩子的甄嬛去准格尔和亲。甄嬛遂视皇帝为最该毁灭的对象,大结局道尽“人类的一切争斗,皆因统治者的不公不义而起”,并毒杀雍正。四阿哥弘历登基为乾隆,甄嬛被尊为圣母皇太后,权倾朝野,在如懿传中安度晚年。 70 | 71 |   Chat-甄嬛,实现以《甄嬛传》为切入点,打造一套基于小说、剧本的**个性化 AI** 微调大模型完整流程,通过提供任一小说、剧本,指定人物角色,运行本项目完整流程,让每一位用户都基于心仪的小说、剧本打造一个属于自己的、契合角色人设、具备高度智能的个性化 AI。 72 | 73 | > 具体如何实现全流程的 Character-AI 微调,可参考主仓库-[huanhuan-chat](https://github.com/KMnO4-zx/huanhuan-chat.git)。 74 | > 75 | > 如何学习大模型部署和微调请参考:[开源大模型食用指南](https://github.com/datawhalechina/self-llm.git) 以及 [书生·浦语大模型实战营课程](https://github.com/InternLM/tutorial.git) 76 | 77 |   ***欢迎大家来给[InternLM2](https://github.com/InternLM/InternLM.git),点点star哦~*** 78 | 79 | Chat嬛嬛全流程如图所示: 80 | 81 |

82 | alt text 83 |

84 | 85 | ## 🔗 模型及体验地址 86 | 87 | ***OpenXLab 体验地址:*** 88 | 89 | ***https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan*** 90 | 91 | ![alt text](./images/huanhuan_chat.png) 92 | 93 | ***Chat-嬛嬛 模型下载地址:*** 94 | 95 | - ***OpenXLab*** 96 | 97 | ***7B: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2*** 98 | 99 | ***1.8B: https://openxlab.org.cn/models/detail/BYCJS/huanhuan-chat-internlm2-1_8b*** 100 | 101 | ![alt text](./images/openxlab_model.jpg) 102 | 103 | - ***ModelScope*** 104 | 105 | ***7B: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2/summary*** 106 | 107 | ***1.8B: https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary*** 108 | 109 | ![Alt text](images/modelscope.png) 110 | 111 | 112 | ## 🚀 News 113 | 114 | ***2月5日,完成 [InternLM2-chat-1_8B模型的全量微调](https://www.modelscope.cn/models/kmno4zx/huanhuan-chat-internlm2-1_8b/summary) ,模型已上传ModelScope,大家可以来下载哦~*** 115 | 116 | ***1月22日,Chat-嬛嬛应用在 OpenXLab,累计聊天次数已达 3.64k 次,感谢大家的支持~*** 117 | 118 | ***1月22日,Chat-嬛嬛模型 魔搭 累计下载 3107 次!*** 119 | 120 | 121 | ## 🧾 数据集 122 | 123 |   Chat-嬛嬛 数据集采用《甄嬛传》剧本中所有关于甄嬛的台词和语句,共计 3000 余条,数据集样例: 124 | 125 | ```text 126 | 第15幕 127 | (秀女们在等候殿选。甄嬛看见了眉庄,上前相认) 128 | 甄嬛:眉姐姐! 129 | 眉庄:嬛儿,早就听说妹妹中选了,可就是一直不得空见你。 130 | 甄嬛(凑近):我倒巴不得没选上呢。姐姐远道过来一定很辛苦吧。 131 | 眉庄:在京里休息了这些日子,早已经调养过来了。 132 | 甄嬛:如今你住在自己京城的宅子里,不比从前住在外祖家,一墙之隔,见面也方便。 133 | 眉庄:是啊。可是我总还想着我们一起长大的情分呢。诶?妹妹今日打扮得好生素净,可是细看起来还是个美人坯子,怎么看都是好的。 134 | 甄嬛:沈大美人差矣,姐姐出落得这么标致,皇上见过必定会念念不忘。 135 | 眉庄(伸手制止,左右看了下):今天秀女佼佼者众多,我未必中选,若教旁人听见了,又要生出是非。 136 | ``` 137 | 138 |   使用脚本将剧本中关于甄嬛的对话集抽取出来,作为数据集使用。 139 | 140 |   也可以使用这个仓库的脚本[Extract Dialogue](https://github.com/KMnO4-zx/extract-dialogue.git),请GPT老师来帮助我们从小说中抽取对话集。 141 | 142 | ![Alt text](images/Extract-Dialogue.png) 143 | 144 | ## 🛠️ 使用方法 145 | 146 | ### 快速开始 147 | 148 | 149 | 150 | 1. 下载模型 151 | 152 |
153 | 从 ModelScope 154 | 155 | 参考 [模型的下载](https://www.modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E4%B8%8B%E8%BD%BD) 。 156 | 157 | ```bash 158 | pip install modelscope 159 | ``` 160 | 161 | ```python 162 | from modelscope.hub.snapshot_download import snapshot_download 163 | model_dir = snapshot_download('kmno4zx/huanhuan-chat-internlm2', cache_dir='./') 164 | ``` 165 | 166 |
167 | 168 | 169 |
170 | 从 OpenXLab 171 | 172 | 参考 [下载模型](https://openxlab.org.cn/docs/models/%E4%B8%8B%E8%BD%BD%E6%A8%A1%E5%9E%8B.html) 。 173 | 174 | ```bash 175 | pip install openxlab 176 | ``` 177 | 178 | ```python 179 | from openxlab.model import download 180 | download(model_repo='BYCJS/huanhuan-chat-internlm2', 181 | model_name='huanhuan-chat-internlm2', output='./') 182 | ``` 183 | 184 |
185 | 186 | 2. 本地部署 187 | 188 | ```bash 189 | git clone https://github.com/KMnO4-zx/xlab-huanhuan.git 190 | python start.py 191 | ``` 192 | ### 重新训练 193 | 194 | #### 环境搭建 195 | 196 | 1. clone 本项目 197 | 198 | ```bash 199 | git clone https://github.com/KMnO4-zx/xlab-huanhuan.git 200 | cd xlab-huanhuan 201 | ``` 202 | 203 | 2. 创建环境 204 | 205 | ```bash 206 | pip install -r requirements.txt 207 | ``` 208 | 209 | >有两种微调方案,我们更推荐使用 XTuner 训练, XTuner 有各个模型的一键训练脚本,相对便捷。且对 InternLM2 的支持度最高。 210 | 211 | #### Transformers微调 212 |   使用 Transformers 的 Trainer 进行微调,具体脚本可参考[internlm2-chat-lora](./train/internlm2-chat-lora.ipynb),该脚本在`train`文件夹下。脚本内有较为详细的注释。 213 | 214 | #### XTuner微调 215 |   使用 XTuner 进行微调,具体脚本可参考[internlm2_chat_7b_qlora_oasst1_e3_copy.py](./train/internlm2_chat_7b_qlora_oasst1_e3_copy.py),该脚本在`train`文件夹下。脚本内有较为详细的注释。 216 | 217 | 218 | ### 部署 219 | #### OpenXLab 部署 Chat-嬛嬛 220 | 221 |   仅需要 Fork 本仓库,然后在 OpenXLab 上创建一个新的项目,将 Fork 的仓库与新建的项目关联,即可在 OpenXLab 上部署 Chat-嬛嬛。 222 | 223 |   ***OPenXLab Chat嬛嬛 https://openxlab.org.cn/apps/detail/BYCJS/Chat_huanhuan*** 224 | 225 | ![Alt text](images/openxlab.png) 226 | 227 | #### LmDeploy部署 228 | 229 | - 首先安装LmDeploy 230 | 231 | ```shell 232 | pip install -U lmdeploy 233 | ``` 234 | 235 | - 然后转换模型为`turbomind`格式 236 | 237 | > --dst-path: 可以指定转换后的模型存储位置。 238 | 239 | ```shell 240 | lmdeploy convert internlm2-chat-7b 要转化的模型地址 --dst-path 转换后的模型地址 241 | ``` 242 | 243 | - LmDeploy Chat 对话 244 | 245 | ```shell 246 | lmdeploy chat turbomind 转换后的turbomind模型地址 247 | ``` 248 | ### 测评与量化 249 | #### OpneCompass 评测 250 | 251 | - 安装 OpenCompass 252 | 253 | ```shell 254 | git clone https://github.com/open-compass/opencompass 255 | cd opencompass 256 | pip install -e . 257 | ``` 258 | 259 | - 下载解压数据集 260 | 261 | ```shell 262 | cp /share/temp/datasets/OpenCompassData-core-20231110.zip /root/opencompass/ 263 | unzip OpenCompassData-core-20231110.zip 264 | ``` 265 | 266 | - 评测启动! 
267 | 268 | ```shell 269 | python run.py \ 270 | --datasets ceval_gen \ 271 | --hf-path /root/model/huanhuan/kmno4zx/huanhuan-chat-internlm2 \ 272 | --tokenizer-path /root/model/huanhuan/kmno4zx/huanhuan-chat-internlm2 \ 273 | --tokenizer-kwargs padding_side='left' truncation='left' trust_remote_code=True \ 274 | --model-kwargs device_map='auto' trust_remote_code=True \ 275 | --max-seq-len 2048 \ 276 | --max-out-len 16 \ 277 | --batch-size 2 \ 278 | --num-gpus 1 \ 279 | --debug 280 | ``` 281 | 282 | #### Lmdeploy&opencompass 量化以及量化评测 283 | ##### `W4`量化评测 284 | 285 | - `W4`量化 286 | ```shell 287 | lmdeploy lite auto_awq 要量化的模型地址 --work-dir 量化后的模型地址 288 | ``` 289 | - 转化为`TurboMind` 290 | ```shell 291 | lmdeploy convert internlm2-chat-7b 量化后的模型地址 --model-format awq --group-size 128 --dst-path 转换后的模型地址 292 | ``` 293 | - 评测`config`编写 294 | ```python 295 | from mmengine.config import read_base 296 | from opencompass.models.turbomind import TurboMindModel 297 | 298 | with read_base(): 299 | # choose a list of datasets 300 | from .datasets.ceval.ceval_gen import ceval_datasets 301 | # and output the results in a chosen format 302 | # from .summarizers.medium import summarizer 303 | 304 | datasets = [*ceval_datasets] 305 | 306 | internlm2_chat_7b = dict( 307 | type=TurboMindModel, 308 | abbr='internlm2-chat-7b-turbomind', 309 | path='转换后的模型地址', 310 | engine_config=dict(session_len=512, 311 | max_batch_size=2, 312 | rope_scaling_factor=1.0), 313 | gen_config=dict(top_k=1, 314 | top_p=0.8, 315 | temperature=1.0, 316 | max_new_tokens=100), 317 | max_out_len=100, 318 | max_seq_len=512, 319 | batch_size=2, 320 | concurrency=1, 321 | # meta_template=internlm_meta_template, 322 | run_cfg=dict(num_gpus=1, num_procs=1), 323 | ) 324 | models = [internlm2_chat_7b] 325 | 326 | ``` 327 | - 评测启动! 
328 | ```shell 329 | python run.py configs/eval_turbomind.py -w 指定结果保存路径 330 | ``` 331 | ##### `KV Cache`量化评测 332 | - 转换为`TurboMind` 333 | ```shell 334 | lmdeploy convert internlm2-chat-7b 模型路径 --dst-path 转换后模型路径 335 | ``` 336 | - 计算与获得量化参数 337 | ```shell 338 | # 计算 339 | lmdeploy lite calibrate 模型路径 --calib-dataset 'ptb' --calib-samples 128 --calib-seqlen 2048 --work-dir 参数保存路径 340 | # 获取量化参数 341 | lmdeploy lite kv_qparams 参数保存路径 转换后模型路径/triton_models/weights/ --num-tp 1 342 | ``` 343 | - 更改`quant_policy`改成`4`,更改上述`config`里面的路径 344 | - 评测启动! 345 | ```shell 346 | python run.py configs/eval_turbomind.py -w 结果保存路径 347 | ``` 348 | 结果文件可在同目录文件[results](./results)中获取 349 | 350 | ## 💕 致谢 351 | 352 | ### 项目成员 353 | 354 | - 宋志学-项目负责人 (Datawhale成员 书生·浦语实战营助教 负责项目规划,数据集制作及模型训练) 355 | - 肖鸿儒(Datawhale成员 书生·浦语实战营助教 负责数据集收集、模型评测) 356 | - 邹雨衡(Datawhale成员 负责数据集收集) 357 | - 杜森(Datawhale成员 负责数据集收集) 358 | 359 | ### 特别感谢 360 | 361 |
362 | 363 | ***感谢上海人工智能实验室组织的 书生·浦语实战营 学习活动~*** 364 | 365 | ***感谢 OpenXLab 对项目部署的算力支持~*** 366 | 367 | ***感谢 浦语小助手 对项目的支持~*** 368 |
369 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # 导入所需的库 2 | from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig 3 | import torch 4 | import streamlit as st 5 | 6 | from modelscope import snapshot_download 7 | 8 | # 在侧边栏中创建一个标题和一个链接 9 | with st.sidebar: 10 | st.markdown("## InternLM LLM") 11 | "[InternLM](https://github.com/InternLM/InternLM.git)" 12 | "[开源大模型食用指南 self-llm](https://github.com/datawhalechina/self-llm.git)" 13 | "[Chat嬛嬛](https://github.com/KMnO4-zx/huanhuan-chat.git)" 14 | # 创建一个滑块,用于选择最大长度,范围在0到1024之间,默认值为512 15 | max_length = st.slider("max_length", 0, 1024, 512, step=1) 16 | system_prompt = st.text_input("System_Prompt", "现在你要扮演皇帝身边的女人--甄嬛") 17 | 18 | # 创建一个标题和一个副标题 19 | st.title("💬 InternLM2-Chat-7B 嬛嬛版") 20 | st.caption("🚀 A streamlit chatbot powered by InternLM2 QLora") 21 | 22 | # 定义模型路径 23 | 24 | model_id = 'kmno4zx/huanhuan-chat-internlm2' 25 | 26 | mode_name_or_path = snapshot_download(model_id, revision='master') 27 | 28 | 29 | # 定义一个函数,用于获取模型和tokenizer 30 | @st.cache_resource 31 | def get_model(): 32 | # 从预训练的模型中获取tokenizer 33 | tokenizer = AutoTokenizer.from_pretrained(mode_name_or_path, trust_remote_code=True) 34 | # 从预训练的模型中获取模型,并设置模型参数 35 | model = AutoModelForCausalLM.from_pretrained(mode_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda() 36 | model.eval() 37 | return tokenizer, model 38 | 39 | # 加载Chatglm3的model和tokenizer 40 | tokenizer, model = get_model() 41 | 42 | # 如果session_state中没有"messages",则创建一个包含默认消息的列表 43 | if "messages" not in st.session_state: 44 | st.session_state["messages"] = [] 45 | 46 | # 遍历session_state中的所有消息,并显示在聊天界面上 47 | for msg in st.session_state.messages: 48 | st.chat_message("user").write(msg[0]) 49 | st.chat_message("assistant").write(msg[1]) 50 | 51 | # 如果用户在聊天输入框中输入了内容,则执行以下操作 52 | if prompt := st.chat_input(): 53 | # 
在聊天界面上显示用户的输入 54 | st.chat_message("user").write(prompt) 55 | # 构建输入 56 | response, history = model.chat(tokenizer, prompt, meta_instruction=system_prompt, history=st.session_state.messages) 57 | # 将模型的输出添加到session_state中的messages列表中 58 | st.session_state.messages.append((prompt, response)) 59 | # 在聊天界面上显示模型的输出 60 | st.chat_message("assistant").write(response) -------------------------------------------------------------------------------- /images/Extract-Dialogue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/Extract-Dialogue.png -------------------------------------------------------------------------------- /images/chat嬛嬛.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/chat嬛嬛.png -------------------------------------------------------------------------------- /images/compass_support.svg: -------------------------------------------------------------------------------- 1 | OpenCompass: SupportOpenCompassSupport 2 | -------------------------------------------------------------------------------- /images/huanhuan_chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/huanhuan_chat.png -------------------------------------------------------------------------------- /images/huanhuan_img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/huanhuan_img.png -------------------------------------------------------------------------------- /images/license.svg: 
-------------------------------------------------------------------------------- 1 | license: Apache-2.0licenseApache-2.0 2 | -------------------------------------------------------------------------------- /images/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/logo.png -------------------------------------------------------------------------------- /images/modelscope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/modelscope.png -------------------------------------------------------------------------------- /images/modelscope_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/modelscope_logo.png -------------------------------------------------------------------------------- /images/openxlab.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/openxlab.png -------------------------------------------------------------------------------- /images/openxlab_model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KMnO4-zx/xlab-huanhuan/30d3b1c48f4c1465c2acf65a8f616769d54e0b51/images/openxlab_model.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | modelscope==1.9.5 2 | transformers==4.36.2 3 | streamlit==1.24.0 4 | sentencepiece==0.1.99 5 | accelerate==0.24.1 6 | transformers_stream_generator==0.0.4 7 
| tiktoken -------------------------------------------------------------------------------- /results/huanhuan_lianghua.csv: -------------------------------------------------------------------------------- 1 | dataset,version,metric,mode,internlm2-chat-7b-turbomind 2 | --------- 考试 Exam ---------,-,-,-,- 3 | ceval,-,naive_average,gen,22.37 4 | agieval,-,-,-,- 5 | mmlu,-,-,-,- 6 | GaokaoBench,-,-,-,- 7 | ARC-c,-,-,-,- 8 | --------- 语言 Language ---------,-,-,-,- 9 | WiC,-,-,-,- 10 | summedits,-,-,-,- 11 | chid-dev,-,-,-,- 12 | afqmc-dev,-,-,-,- 13 | bustm-dev,-,-,-,- 14 | cluewsc-dev,-,-,-,- 15 | WSC,-,-,-,- 16 | winogrande,-,-,-,- 17 | flores_100,-,-,-,- 18 | --------- 知识 Knowledge ---------,-,-,-,- 19 | BoolQ,-,-,-,- 20 | commonsense_qa,-,-,-,- 21 | nq,-,-,-,- 22 | triviaqa,-,-,-,- 23 | --------- 推理 Reasoning ---------,-,-,-,- 24 | cmnli,-,-,-,- 25 | ocnli,-,-,-,- 26 | ocnli_fc-dev,-,-,-,- 27 | AX_b,-,-,-,- 28 | AX_g,-,-,-,- 29 | CB,-,-,-,- 30 | RTE,-,-,-,- 31 | story_cloze,-,-,-,- 32 | COPA,-,-,-,- 33 | ReCoRD,-,-,-,- 34 | hellaswag,-,-,-,- 35 | piqa,-,-,-,- 36 | siqa,-,-,-,- 37 | strategyqa,-,-,-,- 38 | math,-,-,-,- 39 | gsm8k,-,-,-,- 40 | TheoremQA,-,-,-,- 41 | openai_humaneval,-,-,-,- 42 | mbpp,-,-,-,- 43 | bbh,-,-,-,- 44 | --------- 理解 Understanding ---------,-,-,-,- 45 | C3,-,-,-,- 46 | CMRC_dev,-,-,-,- 47 | DRCD_dev,-,-,-,- 48 | MultiRC,-,-,-,- 49 | race-middle,-,-,-,- 50 | race-high,-,-,-,- 51 | openbookqa_fact,-,-,-,- 52 | csl_dev,-,-,-,- 53 | lcsts,-,-,-,- 54 | Xsum,-,-,-,- 55 | eprstmt-dev,-,-,-,- 56 | lambada,-,-,-,- 57 | tnews-dev,-,-,-,- 58 | -------------------------------------------------------------------------------- /results/huanhuan_ori.csv: -------------------------------------------------------------------------------- 1 | dataset,version,metric,mode,opencompass.models.huggingface.HuggingFace_kmno4zx_huanhuan-chat-internlm2 2 | ceval-computer_network,db9ce2,accuracy,gen,52.63 3 | ceval-operating_system,1c2571,accuracy,gen,57.89 4 | 
ceval-computer_architecture,a74dad,accuracy,gen,66.67 5 | ceval-college_programming,4ca32a,accuracy,gen,59.46 6 | ceval-college_physics,963fa8,accuracy,gen,36.84 7 | ceval-college_chemistry,e78857,accuracy,gen,41.67 8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,26.32 9 | ceval-probability_and_statistics,65e812,accuracy,gen,27.78 10 | ceval-discrete_mathematics,e894ae,accuracy,gen,25.00 11 | ceval-electrical_engineer,ae42b9,accuracy,gen,32.43 12 | ceval-metrology_engineer,ee34ea,accuracy,gen,58.33 13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,22.22 14 | ceval-high_school_physics,adf25f,accuracy,gen,52.63 15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,52.63 16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58 17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,52.63 18 | ceval-middle_school_biology,86817c,accuracy,gen,66.67 19 | ceval-middle_school_physics,8accf6,accuracy,gen,57.89 20 | ceval-middle_school_chemistry,167a15,accuracy,gen,100.00 21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,52.17 22 | ceval-college_economics,f3f4e6,accuracy,gen,45.45 23 | ceval-business_administration,c1614e,accuracy,gen,48.48 24 | ceval-marxism,cf874c,accuracy,gen,73.68 25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,66.67 26 | ceval-education_science,591fee,accuracy,gen,72.41 27 | ceval-teacher_qualification,4e4ced,accuracy,gen,79.55 28 | ceval-high_school_politics,5c0de2,accuracy,gen,63.16 29 | ceval-high_school_geography,865461,accuracy,gen,52.63 30 | ceval-middle_school_politics,5be3e7,accuracy,gen,57.14 31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00 32 | ceval-modern_chinese_history,fc01af,accuracy,gen,43.48 33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,84.21 34 | ceval-logic,f5b022,accuracy,gen,59.09 35 | ceval-law,a110a1,accuracy,gen,25.00 36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,52.17 37 | ceval-art_studies,2a1300,accuracy,gen,48.48 38 | 
ceval-professional_tour_guide,4e673e,accuracy,gen,68.97 39 | ceval-legal_professional,ce8787,accuracy,gen,30.43 40 | ceval-high_school_chinese,315705,accuracy,gen,57.89 41 | ceval-high_school_history,7eb30a,accuracy,gen,80.00 42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18 43 | ceval-civil_servant,87d061,accuracy,gen,51.06 44 | ceval-sports_science,70f27b,accuracy,gen,52.63 45 | ceval-plant_protection,8941f9,accuracy,gen,86.36 46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89 47 | ceval-clinical_medicine,49e82d,accuracy,gen,45.45 48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,67.39 49 | ceval-accountant,002837,accuracy,gen,46.94 50 | ceval-fire_engineer,bc23f5,accuracy,gen,38.71 51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,48.39 52 | ceval-tax_accountant,3a5e3c,accuracy,gen,48.98 53 | ceval-physician,6e277d,accuracy,gen,51.02 54 | ceval-stem,-,naive_average,gen,48.67 55 | ceval-social-science,-,naive_average,gen,60.92 56 | ceval-humanities,-,naive_average,gen,56.17 57 | ceval-other,-,naive_average,gen,54.08 58 | ceval-hard,-,naive_average,gen,35.64 59 | ceval,-,naive_average,gen,53.76 60 | -------------------------------------------------------------------------------- /results/internlm2_kv.csv: -------------------------------------------------------------------------------- 1 | dataset,version,metric,mode,internlm2-chat-7b-turbomind 2 | ceval-computer_network,db9ce2,accuracy,gen,47.37 3 | ceval-operating_system,1c2571,accuracy,gen,63.16 4 | ceval-computer_architecture,a74dad,accuracy,gen,42.86 5 | ceval-college_programming,4ca32a,accuracy,gen,24.32 6 | ceval-college_physics,963fa8,accuracy,gen,15.79 7 | ceval-college_chemistry,e78857,accuracy,gen,0.00 8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,10.53 9 | ceval-probability_and_statistics,65e812,accuracy,gen,11.11 10 | ceval-discrete_mathematics,e894ae,accuracy,gen,12.50 11 | ceval-electrical_engineer,ae42b9,accuracy,gen,18.92 12 | 
ceval-metrology_engineer,ee34ea,accuracy,gen,37.50 13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,5.56 14 | ceval-high_school_physics,adf25f,accuracy,gen,21.05 15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,21.05 16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58 17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,31.58 18 | ceval-middle_school_biology,86817c,accuracy,gen,71.43 19 | ceval-middle_school_physics,8accf6,accuracy,gen,52.63 20 | ceval-middle_school_chemistry,167a15,accuracy,gen,75.00 21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,43.48 22 | ceval-college_economics,f3f4e6,accuracy,gen,25.45 23 | ceval-business_administration,c1614e,accuracy,gen,27.27 24 | ceval-marxism,cf874c,accuracy,gen,84.21 25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,75.00 26 | ceval-education_science,591fee,accuracy,gen,62.07 27 | ceval-teacher_qualification,4e4ced,accuracy,gen,75.00 28 | ceval-high_school_politics,5c0de2,accuracy,gen,21.05 29 | ceval-high_school_geography,865461,accuracy,gen,57.89 30 | ceval-middle_school_politics,5be3e7,accuracy,gen,47.62 31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00 32 | ceval-modern_chinese_history,fc01af,accuracy,gen,69.57 33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,89.47 34 | ceval-logic,f5b022,accuracy,gen,36.36 35 | ceval-law,a110a1,accuracy,gen,29.17 36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,47.83 37 | ceval-art_studies,2a1300,accuracy,gen,66.67 38 | ceval-professional_tour_guide,4e673e,accuracy,gen,79.31 39 | ceval-legal_professional,ce8787,accuracy,gen,17.39 40 | ceval-high_school_chinese,315705,accuracy,gen,36.84 41 | ceval-high_school_history,7eb30a,accuracy,gen,75.00 42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18 43 | ceval-civil_servant,87d061,accuracy,gen,29.79 44 | ceval-sports_science,70f27b,accuracy,gen,57.89 45 | ceval-plant_protection,8941f9,accuracy,gen,63.64 46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89 47 | 
ceval-clinical_medicine,49e82d,accuracy,gen,45.45 48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,56.52 49 | ceval-accountant,002837,accuracy,gen,26.53 50 | ceval-fire_engineer,bc23f5,accuracy,gen,16.13 51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,41.94 52 | ceval-tax_accountant,3a5e3c,accuracy,gen,32.65 53 | ceval-physician,6e277d,accuracy,gen,55.10 54 | -------------------------------------------------------------------------------- /results/internlm2_ori.csv: -------------------------------------------------------------------------------- 1 | dataset,version,metric,mode,opencompass.models.huggingface.HuggingFace_kmno4zx_huanhuan-chat-internlm2 2 | ceval-computer_network,db9ce2,accuracy,gen,52.63 3 | ceval-operating_system,1c2571,accuracy,gen,57.89 4 | ceval-computer_architecture,a74dad,accuracy,gen,66.67 5 | ceval-college_programming,4ca32a,accuracy,gen,59.46 6 | ceval-college_physics,963fa8,accuracy,gen,36.84 7 | ceval-college_chemistry,e78857,accuracy,gen,41.67 8 | ceval-advanced_mathematics,ce03e2,accuracy,gen,26.32 9 | ceval-probability_and_statistics,65e812,accuracy,gen,27.78 10 | ceval-discrete_mathematics,e894ae,accuracy,gen,25.00 11 | ceval-electrical_engineer,ae42b9,accuracy,gen,32.43 12 | ceval-metrology_engineer,ee34ea,accuracy,gen,58.33 13 | ceval-high_school_mathematics,1dc5bf,accuracy,gen,22.22 14 | ceval-high_school_physics,adf25f,accuracy,gen,52.63 15 | ceval-high_school_chemistry,2ed27f,accuracy,gen,52.63 16 | ceval-high_school_biology,8e2b9a,accuracy,gen,31.58 17 | ceval-middle_school_mathematics,bee8d5,accuracy,gen,52.63 18 | ceval-middle_school_biology,86817c,accuracy,gen,66.67 19 | ceval-middle_school_physics,8accf6,accuracy,gen,57.89 20 | ceval-middle_school_chemistry,167a15,accuracy,gen,100.00 21 | ceval-veterinary_medicine,b4e08d,accuracy,gen,52.17 22 | ceval-college_economics,f3f4e6,accuracy,gen,45.45 23 | ceval-business_administration,c1614e,accuracy,gen,48.48 24 | 
ceval-marxism,cf874c,accuracy,gen,73.68 25 | ceval-mao_zedong_thought,51c7a4,accuracy,gen,66.67 26 | ceval-education_science,591fee,accuracy,gen,72.41 27 | ceval-teacher_qualification,4e4ced,accuracy,gen,79.55 28 | ceval-high_school_politics,5c0de2,accuracy,gen,63.16 29 | ceval-high_school_geography,865461,accuracy,gen,52.63 30 | ceval-middle_school_politics,5be3e7,accuracy,gen,57.14 31 | ceval-middle_school_geography,8a63be,accuracy,gen,50.00 32 | ceval-modern_chinese_history,fc01af,accuracy,gen,43.48 33 | ceval-ideological_and_moral_cultivation,a2aa4a,accuracy,gen,84.21 34 | ceval-logic,f5b022,accuracy,gen,59.09 35 | ceval-law,a110a1,accuracy,gen,25.00 36 | ceval-chinese_language_and_literature,0f8b68,accuracy,gen,52.17 37 | ceval-art_studies,2a1300,accuracy,gen,48.48 38 | ceval-professional_tour_guide,4e673e,accuracy,gen,68.97 39 | ceval-legal_professional,ce8787,accuracy,gen,30.43 40 | ceval-high_school_chinese,315705,accuracy,gen,57.89 41 | ceval-high_school_history,7eb30a,accuracy,gen,80.00 42 | ceval-middle_school_history,48ab4a,accuracy,gen,68.18 43 | ceval-civil_servant,87d061,accuracy,gen,51.06 44 | ceval-sports_science,70f27b,accuracy,gen,52.63 45 | ceval-plant_protection,8941f9,accuracy,gen,86.36 46 | ceval-basic_medicine,c409d6,accuracy,gen,57.89 47 | ceval-clinical_medicine,49e82d,accuracy,gen,45.45 48 | ceval-urban_and_rural_planner,95b885,accuracy,gen,67.39 49 | ceval-accountant,002837,accuracy,gen,46.94 50 | ceval-fire_engineer,bc23f5,accuracy,gen,38.71 51 | ceval-environmental_impact_assessment_engineer,c64e2d,accuracy,gen,48.39 52 | ceval-tax_accountant,3a5e3c,accuracy,gen,48.98 53 | ceval-physician,6e277d,accuracy,gen,51.02 54 | ceval-stem,-,naive_average,gen,48.67 55 | ceval-social-science,-,naive_average,gen,60.92 56 | ceval-humanities,-,naive_average,gen,56.17 57 | ceval-other,-,naive_average,gen,54.08 58 | ceval-hard,-,naive_average,gen,35.64 59 | ceval,-,naive_average,gen,53.76 60 | 
-------------------------------------------------------------------------------- /start.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.system('streamlit run app.py --server.address=0.0.0.0 --server.port 7860') -------------------------------------------------------------------------------- /train/internlm2-chat-lora.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## 导入必要的包" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": null, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "from datasets import Dataset\n", 17 | "import pandas as pd\n", 18 | "from transformers import AutoTokenizer, AutoModelForCausalLM, DataCollatorForSeq2Seq, TrainingArguments, Trainer, GenerationConfig\n", 19 | "import torch\n", 20 | "from typing import List, Optional, Tuple, Union" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "# 使用datasets读取数据\n", 30 | "df = pd.read_json('/root/data/huanhuan.json')\n", 31 | "ds = Dataset.from_pandas(df)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "markdown", 36 | "metadata": {}, 37 | "source": [ 38 | "## 处理数据集" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "tokenizer = AutoTokenizer.from_pretrained(\"/root/model/internlm-chat-7b/\", use_fast=False, trust_remote_code=True)\n", 48 | "tokenizer" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "def build_inputs(query: str, history: List[Tuple[str, str]] = [], meta_instruction=\"我是系统\"):\n", 58 | " prompt = \"\"\n", 59 | " if meta_instruction:\n", 60 | " prompt += 
f\"\"\"[UNUSED_TOKEN_146]system\\n{meta_instruction}[UNUSED_TOKEN_145]\\n\"\"\"\n", 61 | " else:\n", 62 | " prompt += \"\"\n", 63 | " for record in history:\n", 64 | " prompt += f\"\"\"[UNUSED_TOKEN_146]user\\n{record[0]}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n{record[1]}[UNUSED_TOKEN_145]\\n\"\"\"\n", 65 | " prompt += f\"\"\"[UNUSED_TOKEN_146]user\\n{query}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\\n\"\"\"\n", 66 | " return prompt" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "build_inputs('你哈')" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": null, 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [ 84 | "def process_func(example):\n", 85 | "\n", 86 | " system_prompt = \"现在你要扮演皇帝身边的女人--甄嬛\"\n", 87 | "\n", 88 | " MAX_LENGTH = 512 # 分词器会将一个中文字切分为多个token,因此需要放开一些最大长度,保证数据的完整性\n", 89 | " input_ids, attention_mask, labels = [], [], []\n", 90 | " instruction = tokenizer(f\"[UNUSED_TOKEN_146]system{system_prompt}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]user{example['instruction']}[UNUSED_TOKEN_145]\\n[UNUSED_TOKEN_146]assistant\", add_special_tokens=False) # add_special_tokens 不在开头加 special_tokens\n", 91 | " response = tokenizer(f\"{example['output']}[UNUSED_TOKEN_145]\", add_special_tokens=False)\n", 92 | " input_ids = instruction[\"input_ids\"] + response[\"input_ids\"] + [tokenizer.pad_token_id]\n", 93 | " attention_mask = instruction[\"attention_mask\"] + response[\"attention_mask\"] + [1] # 因为eos token咱们也是要关注的所以 补充为1\n", 94 | " labels = [-100] * len(instruction[\"input_ids\"]) + response[\"input_ids\"] + [tokenizer.pad_token_id] \n", 95 | " if len(input_ids) > MAX_LENGTH: # 做一个截断\n", 96 | " input_ids = input_ids[:MAX_LENGTH]\n", 97 | " attention_mask = attention_mask[:MAX_LENGTH]\n", 98 | " labels = labels[:MAX_LENGTH]\n", 99 | " return {\n", 100 | " \"input_ids\": input_ids,\n", 101 | " \"attention_mask\": 
attention_mask,\n", 102 | " \"labels\": labels\n", 103 | " }" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "metadata": {}, 110 | "outputs": [], 111 | "source": [ 112 | "tokenized_id = ds.map(process_func, remove_columns=ds.column_names)\n", 113 | "tokenized_id" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "metadata": {}, 120 | "outputs": [], 121 | "source": [ 122 | "tokenizer.decode(tokenized_id[0]['input_ids'])" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "metadata": {}, 129 | "outputs": [], 130 | "source": [ 131 | "tokenizer.decode(list(filter(lambda x: x != -100, tokenized_id[1][\"labels\"])))" 132 | ] 133 | }, 134 | { 135 | "cell_type": "markdown", 136 | "metadata": {}, 137 | "source": [ 138 | "## 创建模型" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": null, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "model = AutoModelForCausalLM.from_pretrained(\n", 148 | " \"/root/share/model_repos/internlm2-chat-7b\", \n", 149 | " torch_dtype=torch.half, \n", 150 | " trust_remote_code=True,\n", 151 | " device_map={'':0},\n", 152 | " low_cpu_mem_usage=True, # 是否使用低CPU内存\n", 153 | " load_in_4bit=True, # 是否在4位精度下加载模型。如果设置为True,则在4位精度下加载模型。\n", 154 | " bnb_4bit_compute_dtype=torch.half, # 4位精度计算的数据类型。这里设置为torch.half,表示使用半精度浮点数。\n", 155 | " bnb_4bit_quant_type=\"nf4\", # 4位精度量化的类型。这里设置为\"nf4\",表示使用nf4量化类型。\n", 156 | " bnb_4bit_use_double_quant=True # 是否使用双精度量化。如果设置为True,则使用双精度量化。\n", 157 | " )\n", 158 | "model.enable_input_require_grads() # 开启梯度检查点时,要执行该方法\n", 159 | "model" 160 | ] 161 | }, 162 | { 163 | "cell_type": "code", 164 | "execution_count": null, 165 | "metadata": {}, 166 | "outputs": [], 167 | "source": [ 168 | "model.dtype" 169 | ] 170 | }, 171 | { 172 | "cell_type": "markdown", 173 | "metadata": {}, 174 | "source": [ 175 | "## Lora" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 
| "execution_count": null, 181 | "metadata": {}, 182 | "outputs": [], 183 | "source": [ 184 | "from peft import LoraConfig, TaskType, get_peft_model\n", 185 | "\n", 186 | "config = LoraConfig(\n", 187 | " task_type=TaskType.CAUSAL_LM, \n", 188 | " target_modules=['wqkv', 'wo', 'w1', 'w2', 'w3'],\n", 189 | " inference_mode=False, # 训练模式\n", 190 | " r=8, # Lora 秩\n", 191 | " lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理\n", 192 | " lora_dropout=0.1# Dropout 比例\n", 193 | ")\n", 194 | "config" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": null, 200 | "metadata": {}, 201 | "outputs": [], 202 | "source": [ 203 | "model = get_peft_model(model, config)\n", 204 | "config" 205 | ] 206 | }, 207 | { 208 | "cell_type": "code", 209 | "execution_count": null, 210 | "metadata": {}, 211 | "outputs": [], 212 | "source": [ 213 | "model.print_trainable_parameters()" 214 | ] 215 | }, 216 | { 217 | "cell_type": "markdown", 218 | "metadata": {}, 219 | "source": [ 220 | "## 配置训练参数" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": null, 226 | "metadata": {}, 227 | "outputs": [], 228 | "source": [ 229 | "args = TrainingArguments(\n", 230 | " output_dir=\"./output/math-internlm2-chat-7b\",\n", 231 | " per_device_train_batch_size=4,\n", 232 | " gradient_accumulation_steps=16,\n", 233 | " logging_steps=10,\n", 234 | " num_train_epochs=3,\n", 235 | " save_steps=100,\n", 236 | " learning_rate=1e-5,\n", 237 | " save_on_each_node=True,\n", 238 | " optim=\"paged_adamw_32bit\",\n", 239 | " gradient_checkpointing=True\n", 240 | ")" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "trainer = Trainer(\n", 250 | " model=model,\n", 251 | " args=args,\n", 252 | " train_dataset=tokenized_id,\n", 253 | " data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),\n", 254 | ")" 255 | ] 256 | }, 257 | { 258 | "cell_type": "code", 259 | 
"execution_count": null, 260 | "metadata": {}, 261 | "outputs": [], 262 | "source": [ 263 | "trainer.train()" 264 | ] 265 | } 266 | ], 267 | "metadata": { 268 | "kernelspec": { 269 | "display_name": "internlm-demo", 270 | "language": "python", 271 | "name": "python3" 272 | }, 273 | "language_info": { 274 | "codemirror_mode": { 275 | "name": "ipython", 276 | "version": 3 277 | }, 278 | "file_extension": ".py", 279 | "mimetype": "text/x-python", 280 | "name": "python", 281 | "nbconvert_exporter": "python", 282 | "pygments_lexer": "ipython3", 283 | "version": "3.10.13" 284 | } 285 | }, 286 | "nbformat": 4, 287 | "nbformat_minor": 2 288 | } 289 | -------------------------------------------------------------------------------- /train/internlm2_1_8b_full_oasst1_e3_huanhuan.py: -------------------------------------------------------------------------------- 1 | SYSTEM = '现在你要扮演皇帝身边的女人--甄嬛' 2 | accumulative_counts = 4 3 | batch_size = 1 4 | betas = ( 5 | 0.9, 6 | 0.999, 7 | ) 8 | custom_hooks = [ 9 | dict( 10 | tokenizer=dict( 11 | padding_side='right', 12 | pretrained_model_name_or_path= 13 | '/root/model/internlm/internlm2-chat-1_8b-sft', 14 | trust_remote_code=True, 15 | type='transformers.AutoTokenizer.from_pretrained'), 16 | type='xtuner.engine.DatasetInfoHook'), 17 | dict( 18 | evaluation_inputs=[ 19 | '你好', 20 | '小主,敬事房传来消息,说皇上晚上去华妃那儿。', 21 | ], 22 | every_n_iters=500, 23 | prompt_template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat', 24 | system='现在你要扮演皇帝身边的女人--甄嬛', 25 | tokenizer=dict( 26 | padding_side='right', 27 | pretrained_model_name_or_path= 28 | '/root/model/internlm/internlm2-chat-1_8b-sft', 29 | trust_remote_code=True, 30 | type='transformers.AutoTokenizer.from_pretrained'), 31 | type='xtuner.engine.EvaluateChatHook'), 32 | dict(type='xtuner.engine.ThroughputHook'), 33 | ] 34 | data_path = '/root/data/huanhuan_xtuner.json' 35 | dataloader_num_workers = 0 36 | default_hooks = dict( 37 | checkpoint=dict(interval=1, type='mmengine.hooks.CheckpointHook'), 38 
| logger=dict(interval=10, type='mmengine.hooks.LoggerHook'), 39 | param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'), 40 | sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'), 41 | timer=dict(type='mmengine.hooks.IterTimerHook')) 42 | env_cfg = dict( 43 | cudnn_benchmark=False, 44 | dist_cfg=dict(backend='nccl'), 45 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0)) 46 | evaluation_freq = 500 47 | evaluation_inputs = [ 48 | '你好', 49 | '小主,敬事房传来消息,说皇上晚上去华妃那儿。', 50 | ] 51 | launcher = 'none' 52 | load_from = None 53 | log_level = 'INFO' 54 | lr = 2e-05 55 | max_epochs = 3 56 | max_length = 2048 57 | max_norm = 1 58 | model = dict( 59 | llm=dict( 60 | pretrained_model_name_or_path= 61 | '/root/model/internlm/internlm2-chat-1_8b-sft', 62 | trust_remote_code=True, 63 | type='transformers.AutoModelForCausalLM.from_pretrained'), 64 | type='xtuner.model.SupervisedFinetune') 65 | optim_type = 'torch.optim.AdamW' 66 | optim_wrapper = dict( 67 | accumulative_counts=4, 68 | clip_grad=dict(error_if_nonfinite=False, max_norm=1), 69 | dtype='float16', 70 | loss_scale='dynamic', 71 | optimizer=dict( 72 | betas=( 73 | 0.9, 74 | 0.999, 75 | ), 76 | lr=2e-05, 77 | type='torch.optim.AdamW', 78 | weight_decay=0), 79 | type='mmengine.optim.AmpOptimWrapper') 80 | pack_to_max_length = True 81 | param_scheduler = [ 82 | dict( 83 | begin=0, 84 | by_epoch=True, 85 | convert_to_iter_based=True, 86 | end=0.09, 87 | start_factor=1e-05, 88 | type='mmengine.optim.LinearLR'), 89 | dict( 90 | T_max=3, 91 | begin=0.09, 92 | by_epoch=True, 93 | convert_to_iter_based=True, 94 | eta_min=0.0, 95 | type='mmengine.optim.CosineAnnealingLR'), 96 | ] 97 | pretrained_model_name_or_path = '/root/model/internlm/internlm2-chat-1_8b-sft' 98 | prompt_template = 'xtuner.utils.PROMPT_TEMPLATE.internlm2_chat' 99 | randomness = dict(deterministic=False, seed=None) 100 | resume = False 101 | tokenizer = dict( 102 | padding_side='right', 103 | pretrained_model_name_or_path= 104 | 
'/root/model/internlm/internlm2-chat-1_8b-sft', 105 | trust_remote_code=True, 106 | type='transformers.AutoTokenizer.from_pretrained') 107 | train_cfg = dict(by_epoch=True, max_epochs=3, val_interval=1) 108 | train_dataloader = dict( 109 | batch_size=1, 110 | collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'), 111 | dataset=dict( 112 | dataset=dict( 113 | data_files=dict(train='/root/data/huanhuan_xtuner.json'), 114 | path='json', 115 | type='datasets.load_dataset'), 116 | dataset_map_fn=None, 117 | max_length=2048, 118 | pack_to_max_length=True, 119 | remove_unused_columns=True, 120 | shuffle_before_pack=True, 121 | template_map_fn=dict( 122 | template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat', 123 | type='xtuner.dataset.map_fns.template_map_fn_factory'), 124 | tokenizer=dict( 125 | padding_side='right', 126 | pretrained_model_name_or_path= 127 | '/root/model/internlm/internlm2-chat-1_8b-sft', 128 | trust_remote_code=True, 129 | type='transformers.AutoTokenizer.from_pretrained'), 130 | type='xtuner.dataset.process_hf_dataset'), 131 | num_workers=0, 132 | sampler=dict(shuffle=True, type='mmengine.dataset.DefaultSampler')) 133 | train_dataset = dict( 134 | dataset=dict( 135 | data_files=dict(train='/root/data/huanhuan_xtuner.json'), 136 | path='json', 137 | type='datasets.load_dataset'), 138 | dataset_map_fn=None, 139 | max_length=2048, 140 | pack_to_max_length=True, 141 | remove_unused_columns=True, 142 | shuffle_before_pack=True, 143 | template_map_fn=dict( 144 | template='xtuner.utils.PROMPT_TEMPLATE.internlm2_chat', 145 | type='xtuner.dataset.map_fns.template_map_fn_factory'), 146 | tokenizer=dict( 147 | padding_side='right', 148 | pretrained_model_name_or_path= 149 | '/root/model/internlm/internlm2-chat-1_8b-sft', 150 | trust_remote_code=True, 151 | type='transformers.AutoTokenizer.from_pretrained'), 152 | type='xtuner.dataset.process_hf_dataset') 153 | visualizer = None 154 | warmup_ratio = 0.03 155 | weight_decay = 0 156 | work_dir = 
'./work_dirs/internlm2_1_8b_full_oasst1_e3_huanhuan' 157 | -------------------------------------------------------------------------------- /train/internlm2_chat_7b_qlora_oasst1_e3_copy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from datasets import load_dataset 4 | from mmengine.dataset import DefaultSampler 5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 6 | LoggerHook, ParamSchedulerHook) 7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR 8 | from peft import LoraConfig 9 | from torch.optim import AdamW 10 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 11 | BitsAndBytesConfig) 12 | 13 | from xtuner.dataset import process_hf_dataset 14 | from xtuner.dataset.collate_fns import default_collate_fn 15 | from xtuner.dataset.map_fns import oasst1_map_fn, template_map_fn_factory 16 | from xtuner.engine import DatasetInfoHook, EvaluateChatHook 17 | from xtuner.model import SupervisedFinetune 18 | from xtuner.utils import PROMPT_TEMPLATE 19 | 20 | ####################################################################### 21 | # PART 1 Settings # 22 | ####################################################################### 23 | # Model 24 | pretrained_model_name_or_path = '/root/share/model_repos/internlm2-chat-7b' 25 | 26 | # Data 27 | data_path = '/root/data/huanhuan_xtuner.json' 28 | prompt_template = PROMPT_TEMPLATE.internlm2_chat 29 | max_length = 512 30 | pack_to_max_length = True 31 | 32 | # Scheduler & Optimizer 33 | batch_size = 1 # per_device 34 | accumulative_counts = 16 35 | dataloader_num_workers = 0 36 | max_epochs = 3 37 | optim_type = AdamW 38 | lr = 2e-4 39 | betas = (0.9, 0.999) 40 | weight_decay = 0 41 | max_norm = 1 # grad clip 42 | warmup_ratio = 0.03 43 | 44 | # Evaluate the generation performance during the training 45 | evaluation_freq = 90 46 | SYSTEM = 
'现在你要扮演皇帝身边的女人--甄嬛' 47 | evaluation_inputs = [ 48 | '你好', '你是谁?' 49 | ] 50 | 51 | ####################################################################### 52 | # PART 2 Model & Tokenizer # 53 | ####################################################################### 54 | tokenizer = dict( 55 | type=AutoTokenizer.from_pretrained, 56 | pretrained_model_name_or_path=pretrained_model_name_or_path, 57 | trust_remote_code=True, 58 | padding_side='right') 59 | 60 | model = dict( 61 | type=SupervisedFinetune, 62 | llm=dict( 63 | type=AutoModelForCausalLM.from_pretrained, 64 | pretrained_model_name_or_path=pretrained_model_name_or_path, 65 | trust_remote_code=True, 66 | torch_dtype=torch.float16, 67 | quantization_config=dict( 68 | type=BitsAndBytesConfig, 69 | load_in_4bit=True, 70 | load_in_8bit=False, 71 | llm_int8_threshold=6.0, 72 | llm_int8_has_fp16_weight=False, 73 | bnb_4bit_compute_dtype=torch.float16, 74 | bnb_4bit_use_double_quant=True, 75 | bnb_4bit_quant_type='nf4')), 76 | lora=dict( 77 | type=LoraConfig, 78 | r=64, 79 | lora_alpha=16, 80 | lora_dropout=0.1, 81 | bias='none', 82 | task_type='CAUSAL_LM')) 83 | 84 | ####################################################################### 85 | # PART 3 Dataset & Dataloader # 86 | ####################################################################### 87 | train_dataset = dict( 88 | type=process_hf_dataset, 89 | dataset=dict(type=load_dataset, path='json', data_files=dict(train=data_path)), 90 | tokenizer=tokenizer, 91 | max_length=max_length, 92 | dataset_map_fn=None, 93 | template_map_fn=dict( 94 | type=template_map_fn_factory, template=prompt_template), 95 | remove_unused_columns=True, 96 | shuffle_before_pack=True, 97 | pack_to_max_length=pack_to_max_length) 98 | 99 | train_dataloader = dict( 100 | batch_size=batch_size, 101 | num_workers=dataloader_num_workers, 102 | dataset=train_dataset, 103 | sampler=dict(type=DefaultSampler, shuffle=True), 104 | collate_fn=dict(type=default_collate_fn)) 105 | 106 | 
####################################################################### 107 | # PART 4 Scheduler & Optimizer # 108 | ####################################################################### 109 | # optimizer 110 | optim_wrapper = dict( 111 | type=AmpOptimWrapper, 112 | optimizer=dict( 113 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), 114 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), 115 | accumulative_counts=accumulative_counts, 116 | loss_scale='dynamic', 117 | dtype='float16') 118 | 119 | # learning policy 120 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 121 | param_scheduler = [ 122 | dict( 123 | type=LinearLR, 124 | start_factor=1e-5, 125 | by_epoch=True, 126 | begin=0, 127 | end=warmup_ratio * max_epochs, 128 | convert_to_iter_based=True), 129 | dict( 130 | type=CosineAnnealingLR, 131 | eta_min=0.0, 132 | by_epoch=True, 133 | begin=warmup_ratio * max_epochs, 134 | T_max=max_epochs, 135 | convert_to_iter_based=True) 136 | ] 137 | 138 | # train, val, test setting 139 | train_cfg = dict(by_epoch=True, max_epochs=max_epochs, val_interval=1) 140 | 141 | ####################################################################### 142 | # PART 5 Runtime # 143 | ####################################################################### 144 | # Log the dialogue periodically during the training process, optional 145 | custom_hooks = [ 146 | dict(type=DatasetInfoHook, tokenizer=tokenizer), 147 | dict( 148 | type=EvaluateChatHook, 149 | tokenizer=tokenizer, 150 | every_n_iters=evaluation_freq, 151 | evaluation_inputs=evaluation_inputs, 152 | system=SYSTEM, 153 | prompt_template=prompt_template) 154 | ] 155 | 156 | # configure default hooks 157 | default_hooks = dict( 158 | # record the time of every iteration. 159 | timer=dict(type=IterTimerHook), 160 | # print log every 100 iterations. 
161 | logger=dict(type=LoggerHook, interval=10), 162 | # enable the parameter scheduler. 163 | param_scheduler=dict(type=ParamSchedulerHook), 164 | # save checkpoint per epoch. 165 | checkpoint=dict(type=CheckpointHook, interval=1), 166 | # set sampler seed in distributed evrionment. 167 | sampler_seed=dict(type=DistSamplerSeedHook), 168 | ) 169 | 170 | # configure environment 171 | env_cfg = dict( 172 | # whether to enable cudnn benchmark 173 | cudnn_benchmark=False, 174 | # set multi process parameters 175 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 176 | # set distributed parameters 177 | dist_cfg=dict(backend='nccl'), 178 | ) 179 | 180 | # set visualizer 181 | visualizer = None 182 | 183 | # set log level 184 | log_level = 'INFO' 185 | 186 | # load from which checkpoint 187 | load_from = None 188 | 189 | # whether to resume training from the loaded checkpoint 190 | resume = False 191 | 192 | # Defaults to use random seed and disable `deterministic` 193 | randomness = dict(seed=None, deterministic=False) 194 | --------------------------------------------------------------------------------