├── .gitignore ├── LICENSE ├── README.md ├── README_tmp.md ├── app.py ├── appPrepare ├── env_prepare.sh ├── files_prepare.py ├── func_prepare.py └── list_prepare.py ├── chat └── model_center.py ├── pictures ├── 1.png ├── demo.mp4 └── workflow.jpg ├── requirements.txt ├── template_configs ├── full_finetune.py ├── lora.py └── qlora.py ├── xtuner_config ├── build_config.py ├── check_custom_dataset.py ├── get_default_hyperparameters.py └── get_prompt_template.py ├── xtuner_convert ├── convert_and_merge.py ├── convert_with_progress.py ├── merge.py └── pth_to_hf.py ├── xtuner_download ├── README.md ├── __init__.py ├── data_list.txt ├── download_dataset.py ├── download_model.py ├── download_utils.py ├── find_datalist.py ├── kill_hf.sh ├── model_list.txt ├── test_download.py └── todo_list.md ├── xtuner_result └── draw.py └── xtuner_run ├── example.py ├── kill_xtuner.sh ├── shell_train.py ├── todo_list.md ├── train.py └── train_utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | /xtuner_download/__pycache__ 3 | /xtuner_run/__pycache__ 4 | /appPrepare/__pycache__ 5 | /appPrepare/tmp 6 | /appPrepare/work_dir 7 | /work_dir 8 | /appPrepare/download_cache 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/*/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | # custom 116 | data/ 117 | data 118 | .vscode 119 | .idea 120 | .DS_Store 121 | *.pkl 122 | *.pkl.json 123 | *.log.json 124 | work_dirs/ 125 | 126 | # Pytorch 127 | *.pth 128 | *.py~ 129 | *.sh~ 130 | 131 | # srun 132 | *.out 133 | batchscript-* 134 | 135 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Scc_hy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 |

4 | 5 | [![GitHub Repo stars](https://img.shields.io/github/stars/scchy/XtunerGUI?style=social)](https://github.com/scchy/XtunerGUI/stargazers) 6 | 7 |
8 | 9 | # 1. 项目背景 10 | 11 | XTuner是由InternLM团队开发的一个高效、灵活且全能的轻量化大模型微调工具库。其主要用于多种大型语言模型的高效微调,包括大语言模型InternLM和多模态图文模型LLaVa。XTuner不仅提供了丰富的模型、数据集、数据管道和算法支持,还配备了现成的配置文件和快速入门指南,使得用户能够便捷地进行模型微调和部署。总体来看,XTuner为大型语言模型的微调提供了一个高效、全面且用户友好的解决方案,适用于追求性能优化和定制化的开发者和研究者。 12 | 13 | 虽然XTuner已经简化了大量微调中的步骤,但由于对于0基础的小白而言,还是具有一定的技术门槛。因此,借由InternLM官方推出的大模型实战训练营的机会,我们小组成员有幸与XTuner官方技术团队合作,在参考了 [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) 的基础上,根据XTuner的特性进行修改完善,从而完成了基于Gradio的XTuner可视化的界面设计。 14 | 此项目旨在为基础知识较弱的初学者提供便捷的微调解决方案,使他们能够通过简单的点击来尝试对模型进行微调。该界面能够实时展示训练信息和训练结果,并支持用户对微调后的模型与原始模型进行对比测试。此外,除了支持官方提供的模型和数据集之外,高级用户还可以上传自己的模型和数据集进行微调。这种自定义模型的功能不仅有助于初学者在已经微调过的模型基础上进行进一步的学习和探索,也大大增强了该界面的实用性和灵活性。 15 | 16 | # 2. 项目成员介绍 17 | XTuner GUI项目得到了XTuner官方的支持,因此除了浦语实战训练营里的四位成员外,还包括了两名XTuner专业的开发人员cwh及pppppM。下面是对各个成员的贡献进行介绍,感谢大家一个多月以来的辛勤付出!也感谢书生.浦语官方为我们所提供课程以及算力支持!相信AI Lab将持续做大做强,成为中国数一数二的开源社区! 18 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/f10cabd8-c027-4fd9-95da-15c5f351f79d) 19 | 20 | 2.1 Xtuner GUI团队成员包括 21 | - Jianfeng777 – 负责整体前端开发、任务策划及文案撰写 22 | - Scchy – 负责整体后端开发及规划 23 | - L241025097 – 负责模型训练终端可视化 24 | - Semple030228 – 负责模型转换及整合部分 25 | 26 | 2.2 XTuner开发人员 27 | - HIT-cwh - 负责mmengine相关内容设置及配置文件生成,模型和数据集检查等开发工作 28 | - pppppM - 提供XTuner方面专业的指导意见 29 | 30 | # 3. 快速启动(仅支持Linux系统) 31 | 首先我们需要创建一个新的虚拟环境,并将GitHub的内容克隆到本地。 32 | 33 | ```bash 34 | conda create -n XtunerGUI python=3.10 -y 35 | conda activate XtunerGUI 36 | git clone https://github.com/scchy/XtunerGUI.git 37 | ``` 38 | 39 | 然后我们需要进入仓库的内部并安装运行XTunerGUI所需要的包(假如安装速度过慢请使用清华源镜像)。 40 | 41 | ```bash 42 | cd XtunerGUI 43 | pip install -r requirements.txt 44 | ``` 45 | 46 | 经过一段时间的安装后,我们就可以启动`app.py`文件进入我们创建的界面。 47 | 48 | ```bash 49 | python app.py 50 | ``` 51 | 52 | # 4. UI界面介绍 53 | 本页面共分为六部分,内容涵盖了大语言模型中所有基础的步骤(具体可看下图),下面我将一步步带领大家了解页面的布局以及具体的使用方式。此外,我们可以在OpenXLab里查看完整的页面细节([链接](https://openxlab.org.cn/apps/detail/Scchy/XtunerFactory))。 54 | 55 | ## 4.1 本地路径设置 56 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/e0d540bb-8a7b-4c67-a379-f5aeac8cbd90) 57 | 58 | 第一步,我们先要输入两个本地的路径。一个是整体文件保存的位置,另外一个是模型和数据集下载保存的位置。对于文件保存路径(customer_path)而言,在该路径下将保存配置文件、所有的训练过程文件(权重文件、训练日志等)及模型转换后的内容。那对于模型数据集文件路径(download_cache)而言,该路径下将保存在UI界面中下载的模型和数据集文件。在对路径进行修改后切记要点击确认路径按钮哦! 
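以默认设置为例,上述两个路径下最终会形成类似下面的目录结构(仅为示意,示例中的时间戳、权重文件名和模型/数据集名称均为假设,实际以运行时生成的内容为准):

```text
<文件保存路径 customer_path>/
`-- work_dir/
    |-- 20240202_153301/          # 训练日志(按时间戳命名)与可视化数据
    |-- iter_50.pth               # 训练过程中按间隔保存的权重文件
    |-- iter_100.pth
    |-- last_checkpoint
    |-- xtuner_config.py          # 生成的配置文件
    `-- xtuner_iter_100_hf/       # 模型转换(及整合)后的输出

<模型数据集文件路径 download_cache>/
|-- model_download/
|   `-- internlm_internlm-chat-7b/
`-- data_download/
    `-- tatsu-lab_alpaca/
```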
59 | ## 4.2 微调方法、模型、数据集设置 60 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/428b39c3-c1cd-4eca-90b1-2cfc79d8c6cc) 61 | 62 | 第二步,在这里我们需要选择微调的方法,目前已支持QLoRA、LoRA和全量微调(full)三种方法,大家可以根据自己的硬件情况和实际任务需求进行选择。另外我们也支持大量的官方数据集和模型,通过点击下方下载按钮即可自动从Modelscope、Huggingface和OpenXLab三个平台进行下载。假如发现下载错模型或者数据集也可点击取消下载工作,下载任务取消后将删除已下载的内容,从而减少内存的占用。 63 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/c5bf98aa-3cc3-4dbd-b33f-0fe0f6605fb9) 64 | 65 | 另外我们还支持大家上传自定义的模型或者数据集。 66 | - 对于自定义的模型,我们可以通过上传本地的路径,并且点击按钮即可检查模型是否可用及对模型提示词模版的匹配(在UI界面点击下载按钮下载的官方模型会自动进行提示词模版匹配),不同的模型会有其独特的提示词模版,更多详细信息可以进入UI界面中查看。 67 | - 对于自定义的数据集,我们目前仅支持OpenAI的数据集格式(最通用的数据集格式)。在UI界面中也展示了OpenAI数据集的格式,大家可以通过各类的大语言模型(比如说ChatGPT)对自己的数据集进行格式的更改。在将数据集格式转为OpenAI格式后,我们可以通过输入本地数据集路径或者将文件在gradio的文件框上上传。在完成数据集文件上传后,还可点击按钮检查数据集是否符合规定。 68 | ## 4.3 微调参数设置 69 | 第三步,在微调参数设置处,我们将所有的参数分为了基础参数和进阶参数两部分。我们根据微调方法的不同设置了几套默认的参数模版,一般情况下大家可以直接使用该参数进行微调即可。 70 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/d776dae4-cb02-4e12-a09f-f5187f79be55) 71 | 72 | 在基础参数部分是我们比较常用的参数,包括学习率、预热比、数据集最大长度、GPU数量、设备样本个数以及评估问题等等。值得一提的是我们可以自己设置多个评估问题,默认的评估问题是"请给我介绍五个上海景点"的中英文版,但是我们可以将其修改为我们所需要的问题,并且通过选择问题的数量可以增加在评估时候的问题(最多十个问题)。 73 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/048eb439-12d4-46a1-bff5-361d7641b3d3) 74 | 75 | 对于进阶参数而言,就是一些比较不常使用也不怎么需要修改的参数,比如说优化器的类型、权重衰减、梯度剪裁等。这些虽然会影响模型的效果,但是修改优化的难度比较大,大家在使用过程中除非有明确的修改目的,否则不建议大家进行更改。 76 | 在完成了参数的设置后,接下来就需要点击按钮生成配置文件了。配置文件的生成的模型训练的灵魂,模型的训练过程和方向都是基于配置文件里的内容进行。在这里配置文件的创建就是基于上面我们设置的内容。这里需要注意的是,假如大家同时在自定义数据集/模型以及在GUI界面下载了模型/数据集,这里默认以自定义的数据集/模型作为配置文件的模型/数据集的路径,因此大家在使用的过程中需要注意这一点。 77 | ## 4.4 微调模型训练 78 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/ed919fd9-3641-475c-b17a-7504f05fc502) 79 | 80 | 在完成配置文件后,我们就可以点击按钮启动模型训练工作。当然我们也可以点击按钮暂时中断我们的训练过程。当我们需要对中断的模型继续训练的时候,我们可以选择之前保存的权重并点击按钮继续训练。中断后续训是不会影响最终模型训练的效果的。 81 | 另外,在我们点击训练后,我们可以打开下面的终端界面查看训练的过程以及内容,这样我们就能够更好的监控整体的训练过程。假如训练效果过差,我们也能够及时进行模型训练的中断,以免浪费无谓的时间。 82 | ## 4.5 微调结果展示 83 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/90f24f77-31c6-40c3-aec6-2958f6c75450) 84 | 85 | 在模型微调进程结束后,我们可以点击下方按钮生成一些关键内容的展示。包括说损失函数的的变化图、学习率在训练过程中的变化图以及不同权重文件下测试问题。这样我们就既能够看到模型训练过程的变化,也能够通过测试问题的对比来看到模型是否过拟合,从而找到最优的权重文件进行模型测试及部署。 86 | ## 4.6 微调模型转化及测试 87 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/31a0a9fe-58f2-4ce4-8fa1-1cc7dc84de55) 88 | 89 | 在我们通过微调结果展示找到效果最好的模型权重文件后,我们还需要将我们的模型转化为常见的HuggingFace格式。对于LoRA或者QLoRA方式微调出来的模型还需要与原模型进行整合。在这里我们合并了这两部分,我们会基于大家第二步选择的微调方法进行调整。具体的操作就是我们需要在下拉列表中找到对应的权重文件后,点击模型转换按钮即可。 90 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/f7e5356e-ceb3-4baa-a430-68668514f383) 91 | 92 | 在模型转换后,我们就可以进行对话的测试。在左边可以展示原来底座模型的效果,而右边展示的是微调后模型的效果。我们只需要选择合适的模型推理参数,点击模型启动即可进行对话。我们可以通过原模型和微调后模型的对比查看微调的实际效果。 93 | 94 | 以下是我们录制的一个简短使用视频([B站](https://www.bilibili.com/video/BV1av42117yT/?spm_id_from=333.999.0.0)),大家可以通过视频来作更进一步的了解。 95 | 96 | 以上就是页面的一个基本的介绍,假如大家想单纯的使用的话就可以马上开始上手啦!但是假如大家希望对我们设计的思路以及原理有更深刻的认识的话,那就继续往下看吧! 97 | 下面的部分我们将谈谈XTuner GUI背后的XTuner的运作原理,从而能够更深一层次的了解XTuner GUI的实现原理。正所谓知其然还需要知其所以然,假如我们能够真正的通过XTuner整体的结构设计以及指令,那我们就能更好的理解XTuner GUI项目的运行机理。 98 | 99 | # 5. 
XTuner流程介绍 100 | 对于XTuner的基本操作,我们可以通过以下这张图,简单的了解一下。高清图片链接请点[击此位置](https://www.figma.com/file/0SVTWhnGxbY7ADy2UEluCR/XTuner-Flow?type=whiteboard&node-id=0%3A1&t=bzZP6fCSAuBj2uon-1)。 101 | 102 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/50dd5543-1f65-4984-a5a2-35a045e1f5c6) 103 | 104 | 可以看到,整个工作流程分为以下四个步骤(具体各个步骤的调用代码可参考下图): 105 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/f6a7722a-aca1-43cf-81d4-3e3d9de37a9a) 106 | 107 | ## 5.1 数据采集及格式转换 108 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/15384b89-aea9-4049-8866-f407ab93d314) 109 | 110 | - 首先,根据任务的不同,我们需要明确微调目标,进行数据采集,并将数据转换为 XTuner 所支持的格式类型。这部分需要大家自行完成,当然我们假如只是体验的话仅需要使用官方支持的数据集即可。 111 | - 然后我们还需要根据自己的硬件条件选择合适的微调方法和合适的基座模型。不同的基座模型对显存的需求都不太一样,模型参数越大,微调所需要显存就越多。而在微调方法中,对显存需求最小的就是QLoRA(最少8GB即可运行),而显存需求最大的则是全量微调。 112 | ## 5.2 配置文件的创建 113 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/310323ad-def0-4748-b2e5-fe8e1c808344) 114 | 115 | - 首先,我们可以通过执行 xtuner list-cfg 命令列出所有配置文件。 116 | - 通过上面选择的微调方法和基座模型找到合适的配置文件,并使用 xtuner copy-cfg ${CONFIG_NAME} ${SAVE_PATH} 命令复制到本地端。 117 | - 复制完成后还需要根据自己的需求修改配置文件以更新模型路径和数据集路径。 118 | - 特定时候还需要调整模型参数和配置,更改 load_dataset 函数和 dataset_map_fn 函数。并根据模型选择合适的 prompt_template。 119 | ## 5.3 模型训练 120 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/1b568d6a-5da1-493e-826c-f474b7a00db0) 121 | 122 | - 修改配置文件后,我就可以使用 xtuner train 命令启动训练。 123 | - 除此之外我们还可以设置特定参数优化训练,如启用 deepspeed,以及设置训练文件的保存路径。 124 | - 假如意外的中断了训练,还可以通过加上--resume {checkpoint_path}的方式进行模型续训。具体可看下面的指令详解。 125 | ## 5.4 模型转换、测试及部署 126 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/ba34bad6-f771-40aa-a82c-f53096c71841) 127 | 128 | - 在完成训练后,找到对应的训练文件并执行 `xtuner convert pth_to_hf` 命令,就可以将转换模型格式为 `huggingface` 格式。 129 | - 对于LoRA类的模型而言,则需要执行 `xtuner convert merge` 命令将 `adapter` 层与原模型进行合并。 130 | - 转换完成后,我们就可以以转换后的文件路径并使用 `xtuner chat` 命令启动模型进行性能测试。 131 | - 除此之外,我们还可以在安装 `LMDeploy` 后通过 `python -m lmdeploy.pytorch.chat` 命令进行模型部署,即使用TurboMind进行推理。 132 | 133 | 以上就是关于XTuner项目的一些基础功能及指令的展示,下面我将通过XTuner GUI整体的逻辑图来深入的剖析我们在设计原型过程中的思路和对现有流程的优化点。 134 | # 6. XTuner GUI设计思路介绍 135 | 那对于XTuner GUI而言,我们将其分为了六个部分路径设置、模型数据集微调方法设置、相关参数设置、模型训练、训练结果展示、模型转换及测试部分。之所以这样进行设计,主要目的还是希望作为一个小白可以先抛开一系列的专业知识,能够真真正正的先将模型跑起来看到效果后,再一步步的进行研究到底每一步的原理是什么。想当年我学习OpenMMLab相关的算法库,例如MMDetection和MMSegementation,我最开始的时候也就是去找到一些数据集然后跑起来,然后再慢慢研究怎么优化,整体运行逻辑。同样的,虽然XTuner的门槛已经非常低了,但是我希望能够把这个门槛能够降得更低,能够让更多人能够无痛上手。 136 | 137 | 那真正让用户无痛上手,那就必须要砍掉让他们思考的部分。比如说在原生的XTuner里面,我们还需要自己下载模型数据集,还需要找到合适的配置文件,还需要我们自己找到对应的文件夹进行转换等等,这些通通都不再需要,我们只需要通过点击按钮、选择下拉框的内容或者说调整一下参数就可以将模型跑起来了,并且最后也将整体的对话测试也是直接输入文本即可同时与原模型进行对话,这些都是我们希望能够最小化大家跑起模型的难度,能够真正打开大模型微调的大门。 138 | 139 | 除了对0基础的小白进行支持以外,我们对拥有一定使用经验的人也作出了考量。首先就是增加了自定义模型和自定义数据集两部分,那对于想要对自己的数据集或者模型微调的人就能够节省真正进入文件修改的时间。其次是提供了大量可修改的参数让大家进行选择。这些对于一个拥有一定经验的“炼丹师”而言,无疑是非常有意义的。 140 | 141 | 那对于大师而言,尤其是需要训练多模态模型的人,这里其实我们就没有做过多的特定支持。主要原因是这部分人群的代码能力和调试能力非常强,无论是使用原生的XTuner或者其他的微调工具都会得心应手,不需要过多在这些细节上进行可视化的展示。总的来说,我们所针对的人群其实更多是哪些0基础的小白以及有一定经验的炼丹师,通过使用这样一个工具能够更好的完成他们的工作。 142 | 那对于XTuner GUI而言,我同样也是制作了一个逻辑图来展示整体的运行思路(高清图片链接请[点击此位置](https://www.figma.com/file/wFN0wMlknYyzV3ZMCihnPC/XTuner-GUI-Flow?type=whiteboard&t=ch8xUYvYdXnoWGNX-1))。 143 | 144 | 下面将一步步的解释整体的架构,并说明相比于原生的XTuner,我们作出了哪些的调整以及设计时的思路: 145 | 146 | ## 6.1 路径设置 147 | 148 |

149 | image 150 | image 151 |

152 | 153 | 首先我们可以看到我们需要传入的是文件保存路径及模型数据集文件保存路径。对于文件保存路径(customer_path)而言,在该路径下将保存配置文件、所有的训练过程文件(权重文件、训练日志等)及模型转换后的内容。那对于模型数据集文件路径(download_cache)而言,该路径下将保存在UI界面中下载的模型和数据集文件。 154 | 其实在初版的设计当中,这一部分其实是没有被添加进去的,但是后面我们发现下载的模型和数据集文件都可能会相对比较大,再加上后面微调后的文件可能会撑爆内存,因此我们决定将两者分开,用户可以自行选择合适的路径进行保存,这样就可以降低内存的压力。当然对于那些内存充足的人而言,仅需要按照默认的路径即可。 155 | ## 6.2 模型、数据集及微调方法设置 156 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/d3f26565-e445-4d0b-ab3f-031e855eaa05) 157 | 158 | 对于所有微调任务而言,第一步我们要考虑的都是说,我要微调什么模型,我要用什么数据集去微调这个模型,具体使用的微调方法是什么。这三个基本的步骤其实就拦住了很多的人,因为他们不知道去哪里找这些东西。即便是使用原生XTuner的时候,即便我们真的根据仓库中给出的快速开始将模型跑起来,但是我们还是可能不太理解这一切是怎么执行的。并且,在XTuner仓库里已有的config文件其实并不包含所有的模型、数据集和微调方法的组合,因此对于那些想直接用但是找不到对应config文件的人们来说,可能就真的是从入门到放弃了,毕竟对于他们而言,修改一个类似的config然后调整里面对应的东西难度都太高了吗,真的能做下来也不是0基础的小白了。 159 | 基于以上的思考,我们所做的就是简化这一系列的流程。首先我们设置了下拉框来直接根据需求选择模型微调的方法。其次是对于在Huggingface、Modelscope和OpenXLab上已有的数据集和模型,我们提供下拉框让他们直接进行选择并能够点击按钮进行下载。下载完的模型也将自动保存在上面设置的模型数据集文件保存路径(download_cache)上。这样用户是真的知道自己要训练的是一个什么模型,用的是一个什么数据集,具体的微调方法是什么,而不是仅仅给他们一个config文件一个文件名去自己领悟。 160 | 其次对于进阶用户的自定义模型和数据集,那用户可以选择上传自己的模型然后使用官方的数据集,也可以使用官方的模型然后使用自己的数据集进行微调,这些都是可行的。并且无论是数据集还是模型,我们都增加了一个检查的机制,来帮助用户了解自己的模型和数据集是否存在问题,是否能够正常使用,那就避免了后续出现bug无法解决的问题。 161 | 那对于模型而言,还有一个很重要的步骤就是有一个与之相匹配的提示词模版。一般来说,用户上传自己的模型也不会说提供一套自己的提示词模版,真的能够训练出一个自己提示词模版的模型那也不是小白了。一般而言,这些用户自己上传的模型都是微调别的官方大模型实现的。基于这一层的思考,我们就决定了对于用户自定义上传的模型,我们不仅检查其是否可用,还找到其对应的提示词模版,这样用户也不需要再找到原模型的提示词模版然后放进去了,这就节省了他们不少的时间和精力。 162 | ## 6.3 相关参数设置 163 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/38181c97-6158-4ab4-8aad-af914344aab7) 164 | 165 | 那在准备好模型数据集了以后,其实我们最重要需要的就是模型和数据集的一个路径作为配置文件的一个重要参数。除此之外呢,其实我们还是需要对一些基本的超参数以及训练的一些参数进行设置,这样才能够生成一个好的pipeline用于实际的模型训练。 166 | 对于大部分的人而言,如何调整参数都是一件相对复杂且困难的事情,因此这里我们也是基于微调方法的不同配套了不同的超参数模版,这个在XTuner原生的config创建时就已经制定好的了,只不过没有公开而已。另外,在UI界面里,每个超参数我都是进行了基本的解释和介绍,以方便初学者对这些参数的含义进行了解。 167 | 除了默认的基本超参数设置,一些个性化的内容还需要我们自己进行设置,包括GPU的数量、是否启动Deepspeed优化等。基于以上提到的这些的参数我们就能够生成一份可用的配置文件,并存放在我们最开始设置的文件保存路径(customer_path),从而能够利用创建好的这个pipeline开始我们个性化的训练过程。 168 | 但是需要注意的是,这个配置文件生成的内容并不是我们平常在XTuner见到的一个config文件,而是一个具有同等效用的json文件,这里也需要感谢XTuner官方开发团队的cwh为我们提供的支持。 169 | ## 6.4 模型训练 170 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/12dc5e9d-df3a-4496-8466-8f691d27d269) 171 | 172 | 其实到了模型训练这个阶段,基本上就是程序自己运行的事情了。但是我们考虑到有可能会出现的情况是中途因为什么特殊原因训练被中断了,这个时候我们必须提供一个续训的选项让用户能够重新将模型跑起来。假如是重新拉起一个训练过程的话,训练的结果可能和预先结果不太一样,这样问题也蛮严重。因此我们就选择使用XTuner train里面提供的选项--resume来进行续训。 173 | 其实在原本的考虑当中,我希望做成的样子其实是可以想OpenMMLab里面的魔改。就比如说我在MMDetection里的模型Loss一直降不下去了,那我可能就会拿到最后一个epoch的权重文件,然后修改config文件把这个权重文件load进去,并且修改调高学习率等参数来尝试看看能不能好的效果。但是在大模型中,这种魔改的操作并不常见,并且也不是初学者需要考虑的内容,因此在这里也没有加上去了。但是未来要是有需求的话也可以把这部分内容加上去。 174 | 此外,在我们最初的设置里也没有加上终端界面展示部分的内容,但是我们考虑到的一点是:模型训练所要花费的时间太长了,假如用户就看着前端的界面在那转,但是不知道到底发生了什么,训练到了哪一步,自己设置的评估问题回复大概是什么样的话,那么这也是一件很煎熬的事情。因此我们也是增加了一个不断更新的终端界面,以展示训练过程的内容和状态。并且有实验证明说,假如用户能够看到进度条的话,那么整体的焦虑情绪就不会那么严重(除非一直卡在99%),因此后续我们也将根据iter的数量加上一个进度条来让用户看到整体的进度,从而降低他们等待的焦虑。 175 | 在训练的过程中会产生大量的文件,包括说模型的权重文件以及记载模型训练过程的log文件,这些文件都将保存在文件保存路径(customer_path)下。需要注意的是,我们可以通过调整epoch数量、每个n个iter保存一次权重文件以及只保留n个权重文件这类型的训练参数来进行调整实际的训练过程。这样就能够让用户能够自己决定到底训练怎么进行,从而给予了大家更多的自由度。 176 | ## 6.5 训练结果展示 177 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/7c451d8c-a84c-4ec5-8308-89bd582dedc0) 178 | 179 | 在训练完成后,我们其实可以通过查看损失值的变化图等内容来看看模型训练的成果如何。这里我主要希望能够让大家能够真正的看到自己训练的模型是否满足大致的要求,尤其是看到不同的权重文件对应自己设定的评估问题的答案。这个能够看到不同阶段下模型能够对评估问题产生的回复其实我认为非常重要,这个其实就等同于看到一个婴儿一岁到两岁到三岁到长大成人的变化过程。可能一开始模型微调的时候所得到的答案并不满意,到后面慢慢变得成熟,到最后过拟合可能就只能懂得回答一个问题,这样的变化过程可以让我们清晰的找到哪个iter下的权重文件是最好的,从而能够真正筛选出好的权重文件。 180 | 此外,单纯的看Loss看learning 
rate其实并不足够,因为这些都只是一些数字而已:Loss太低你可能会觉得肯定是过拟合了,Loss太高又会怀疑是不是没训练好。真正能够评判一个模型好坏的,还是要真刀真枪地进行对话测试。而对话测试一般也只是针对某几个问题,假如我们把这些关键、核心的问题放在评估问题里,让模型在训练过程中回答并展示其变化,确实能够节省不少时间。模型的其他能力,比如上下文理解、连续回复等,就交给最后的对话测试环节;仅仅是筛选权重的话,查看设定的评估问题的答案其实已经完全足够了。 181 | ## 6.6 模型转换及测试 182 | ![image](https://github.com/scchy/XtunerGUI/assets/108343727/37b87354-7bf1-43e0-a222-3559d4096023) 183 | 184 | 在找到效果最好的模型权重文件后,由于训练过程生成的是PyTorch格式的权重文件,因此我们在后续使用前需要将其转为HuggingFace格式。对于LoRA或QLoRA方法训练出来的模型,产出的是额外的adapter层,因此需要将其与原模型进行整合才能进行后续操作;而全量微调则只需要进行格式转换即可。转换或整合后的模型也将保存在文件保存路径(customer_path)当中。在XTuner里面,模型的转换和整合是分开进行的,但是由于在UI界面完成的都是指定的任务,因此我们可以把这两步组合起来,根据最初设定的微调方法进行判定:假如是QLoRA或者LoRA的话就使用转换+整合,假如是全量微调就只用转换即可。这样用户就不用再去找文件夹然后输入到终端进行设置,而是直接一键完成即可(对应的命令行操作可参考文末附录的示意)。 185 | 另外对于模型对话测试这一块,我们其实是希望能够让用户感受一下微调前模型和微调后模型的差别。主要就是在页面里将聊天框一分为二,一边是原模型,一边是微调后的模型,用户同时问一个问题的时候两个模型都能给出回复,并且你也能看到两者的差别。这样用户就能够真切地感受到微调后到底是什么样的,这也帮助他们了解进一步改进的方向。假若测试后对模型不满意,可以将转换后的模型路径传入自定义模型中继续训练或者重新进行训练,然后重复这样一个流程即可。 186 | # 7. 总结及展望 187 | 总的来说,我们在XTuner提供的基础功能之上,增加了一些独有的内容,包括绘制loss和learning rate的变化图表、提取每次评估问题的答案并保存、手动下载数据集和模型到本地使用等等。我们可以很自信地说,该界面目前已实现了基本的微调功能,并且能够有效地帮助新入门者快速上手并训练属于自己的模型。我们相信随着XTuner的流行,XTuner GUI也将能够受到更多的关注,并帮助更多的初级开发者快速上手模型的微调工作。 188 | 对于这个XTuner GUI项目而言,未来可能会朝着以下几个方向进行持续的发展: 189 | 190 | - **完善XTuner GUI的功能**:由于我们只花费了一个多月,并且大家都并不是全职进行开发,因此其中还有一些问题需要解决,包括实时显示训练的损失值、显示训练的进度条等等。另外还有很重要的一个内容就是需要适配Windows系统,由于我们的开发都是在浦语所提供的开发机上进行的,而开发机的环境是Linux系统,因此我们还需要对Windows系统以及macOS系统作出适配工作。 191 | - **利用CSS及HTML等前端技术完善界面**:目前我们的前端开发完全依赖于Gradio,这主要是由于我作为一名外行的前端开发者仅仅只懂得Gradio的制作,因此后期可能还需要专门的前端开发人员对界面作进一步的设计,让整体更加美观。 192 | - **接入模型部署功能**:微调好的模型转换后的格式为HuggingFace格式,但是HuggingFace格式下的推理速度和模型文件大小可能都不满足实际项目需求,所以可能还是需要进行实际的落地部署工作。因此我们还需要与LMDeploy进行整合,完成包括模型本地部署以及模型API部署等工作。 193 | - **打造类似于OpenAI的Playground**:为了降低Agent的使用门槛,我们预计将来也将推出类似于Playground的界面,届时可以像OpenAI一样提供Function Call、Retrieval、Code Interpreter、Vision、Audio以及Video等多模态调用能力。 194 | 195 | 以上就是XTuner GUI这个项目未来的计划,再次感谢所有人员对此的付出,也感谢星佬、米佬等书生.浦语官方人员的帮助和算力支持,未来我们会持续地推进这个计划,为更多新手的快速入门出自己的一份力,也为开源社区作出更多的贡献! 196 | 作为项目的主要负责人,完成这个任务并且参与其中我也觉得非常有成就感。一方面,能够为开源社区提供一套方案是一件非常有意义的事情;另一方面,在这个项目的开展过程中也认识了很多为爱发电的大佬们,大家能够一起不断学习且进步。真心希望书生浦语开源社区的生态越来越好,为中国的开源社区建立一个良好的榜样! 197 | 假如认为XTuner和XTuner GUI真的对你产生了帮助的话,也希望能给我们Star呀,对于开源社区来说,多一个Star就代表多一份认可。正如那句歌词说的,假如人人都能贡献出一点爱,那么世界也将更加美好~相信随着开源社区的繁荣,国内也能不再那么浮躁,而是真的做一些推动人类发展的大事吧!
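附:上文第 5.2~5.4 节以及 6.6 节所描述的步骤,在命令行中大致对应如下流程(仅为示意,其中 `${...}` 均为占位符;命令及参数可能随 XTuner 版本变化,请以 `xtuner --help` 与官方文档为准):

```bash
# 1. 列出并复制配置文件(对应 5.2 节)
xtuner list-cfg
xtuner copy-cfg ${CONFIG_NAME} ${SAVE_PATH}

# 2. 启动训练,可选启用 deepspeed;中断后可通过 --resume 续训(对应 5.3 节)
xtuner train ${CONFIG_FILE} --work-dir ${WORK_DIR} --deepspeed deepspeed_zero2
xtuner train ${CONFIG_FILE} --work-dir ${WORK_DIR} --resume ${CHECKPOINT_PATH}

# 3. 将训练得到的 pth 权重转换为 HuggingFace 格式(对应 5.4 / 6.6 节)
xtuner convert pth_to_hf ${CONFIG_FILE} ${PTH_FILE} ${SAVE_HF_DIR}

# 4. LoRA / QLoRA:将 adapter 与原模型整合;全量微调可跳过此步
xtuner convert merge ${MODEL_PATH} ${SAVE_HF_DIR} ${SAVE_MERGED_DIR}

# 5. 对话测试(提示词模版请按所选模型填写)
xtuner chat ${SAVE_MERGED_DIR} --prompt-template internlm_chat

# 6. 部署:安装 LMDeploy 后可按 5.4 节使用 python -m lmdeploy.pytorch.chat 进行推理
```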
198 | 199 | 200 | > **都看到这里了,不留下个Star是不是有点说不过去啦~那就动动手指点进下面的链接给我Star啦!** 201 | > - **XTuner项目链接**:[https://github.com/InternLM/xtuner](https://github.com/InternLM/xtuner) 202 | > - **XTuner GUI项目链接**:[https://github.com/scchy/XtunerGUI](https://github.com/scchy/XtunerGUI) 203 | > 204 | > **什么?还要给我fork?!那我就代表所有开发者对你感谢啦!好人一生平安!** 205 | 206 | -------------------------------------------------------------------------------- /README_tmp.md: -------------------------------------------------------------------------------- 1 | # XtunerGUI 2 | Xtuner Factory 3 | 4 | [Disign Doc: XtunerGUI](https://aab2vs0do9o.feishu.cn/docx/JWkbdoDiboVKBAxUyQvcg9MQnbb?from=from_copylink) 5 | 6 | 7 | 8 | - 下载`xtuner_donwload` 9 | - 模型下载 10 | - 入参: `model: gr.Dropdown` 11 | - 出参: `model_path: gr.Textbox` 12 | - 路径: 当前版本路径 `XtunerGUI/appPrepare/download_cache/model_download/{username}_{repository}` 13 | - 数据下载 14 | - 入参: `dataset: gr.Dropdown` 15 | - 出参: `data_path: gr.Textbox` 16 | - 路径:当前版本路径 `XtunerGUI/appPrepare/download_cache/data_download/dataset_{username}_{repository}` 17 | 18 | - config `xtuner_config` 19 | - 入参: 非常多 20 | - 出参:`cfg_py_box` 21 | 22 | - fintune `xtuner_run` 23 | - 指定 环境变量 24 | - shell_train: 直接执行shell `xtuner train xxxx` 25 | - 日志显示: -> 日志显示慢的问题 26 | 27 | - 转换合并 `xtuner_convert` 28 | - `convert_and_merge.py` 29 | - 入参: 30 | - todo: 选择epoch 31 | - `config_file`: `xtuner_config` 生成 32 | - `pth_model`: work_dir 目录下模型问题 33 | - `save_hf_dir`: 指定生成目录 {work_dir}/hf 34 | - `model_path`: 模型路径 `xtuner_donwload` 产出 `model_path` 35 | - `save_merged_dir`: 指定生成目录 {work_dir}/merge_epoch{n} 36 | 37 | 38 | todo: 39 | - [X] todo: load_dataset 下载数 40 | - [ ] 测试问题 是否可以添加, 直接输入list 41 | - [ ] 自定义模型 42 | 1. template 映射 -> 43 | 2. 路径校验(template) 44 | - [X] 路径最终确定 45 | - [ ] prompt_template 位置改动? 46 | - [ ] 自定义数据集 只支持openAI 数据集格式 47 | 48 | 49 | ```text 50 | customer_path 51 | |-- download_cache 52 | | |-- data_download 53 | | | `-- tatsu-lab_alpaca 54 | | `-- model_download 55 | | `-- internlm_internlm-chat-7b 56 | `-- work_dir 57 | |-- 20240202_153301 58 | | |-- 20240202_153301.log 59 | | `-- vis_data 60 | |-- iter_100.pth 61 | |-- iter_50.pth 62 | |-- last_checkpoint 63 | |-- xtuner_config.py 64 | `-- xtuner_iter_100_hf 65 | |-- README.md 66 | |-- adapter_config.json 67 | |-- adapter_model.safetensors 68 | `-- xtuner_config.py 69 | 70 | ``` 71 | 72 | /root/share/model_repos/internlm-chat-7b 73 | /root/personal_assistant/data/personal_assistant_openai_final.json 74 | 75 | 76 | ## Test 77 | - [X] customer-root /root/sccTest3 78 | - [X] customer-data-dir /root/download_cache 79 | - [X] customer model: /root/share/model_repos/internlm-chat-7b 80 | - [X] check customer model template detect 81 | - [X] -> detect_prompt_template -> prompt_template_show 82 | - [X] customer dataset: 83 | - /root/personal_assistant/data/personal_assistant_openai_final.json 84 | - /root/xtunerUITest/ttt.json 85 | - [X] data: tatsu-lab/alpaca -> downloading 86 | - [X] config 87 | - [X] ft_method -> DEFAULT_HYPERPARAMETERS 88 | - [X] generate check 89 | - [ ] xtuner 90 | - [ ] running without pregress ? 
91 | - show result 92 | - [X] plot 93 | - [X] dynamic select_checkpoint -> 94 | - convert 95 | - [X] choose pth 96 | - [ ] convert progress 97 | 98 | 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /app.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-26 3 | # ======================================== 4 | 5 | from xtuner_download.download_model import xtunerModelDownload 6 | from xtuner_download.download_dataset import xtunerDataDownload 7 | from xtuner_convert.convert_and_merge import convert_and_merged 8 | from xtuner_convert.convert_with_progress import ConvertMerged 9 | from xtuner_run.shell_train import quickTrain 10 | from appPrepare.files_prepare import DATA_DOWNLOAD_DIR, MODEL_DOWNLOAD_DIR, CUR_PATH, DEFAULT_DOWNLOAD_DIR 11 | from appPrepare.list_prepare import DATA_LIST, MODEL_LIST, PROMPT_TEMPLATE_LIST 12 | from appPrepare.func_prepare import read_first_ten_lines, get_template_format_by_name, OPENAI_FORMAT 13 | from xtuner_config.build_config import build_and_save_config, model_path_map_fn 14 | from xtuner_config.check_custom_dataset import check_custom_dataset 15 | from xtuner_config.get_prompt_template import app_get_prompt_template 16 | from xtuner_config.get_default_hyperparameters import get_default_hyperparameters 17 | from chat.model_center import ModelCenter 18 | from tqdm import tqdm 19 | from xtuner_result.draw import resPlot 20 | import gradio as gr 21 | import warnings 22 | warnings.filterwarnings(action='ignore') 23 | CHAT_ORG = ModelCenter() 24 | FT_CHAT_ORG = ModelCenter() 25 | CVT_MG = ConvertMerged() 26 | 27 | def combine_message_and_history(message, chat_history): 28 | # 将聊天历史中的每个元素(假设是元组)转换为字符串 29 | history_str = "\n".join(f"{sender}: {text}" for sender, text in chat_history) 30 | 31 | # 将新消息和聊天历史结合成一个字符串 32 | full_message = f"{history_str}\nUser: {message}" 33 | return full_message 34 | 35 | def respond(message, chat_history): 36 | # message1 = combine_message_and_history(message,chat_history) 37 | # client = OpenAI() 38 | # messages=[ 39 | # {"role": "system", "content": "You are a helpful assistant."}, 40 | # {"role": "user", "content": message1} 41 | # ] 42 | 43 | # completion = client.chat.completions.create( 44 | # model="gpt-3.5-turbo", 45 | # messages=messages, 46 | # max_tokens=150, # 设置生成响应的最大 token 数量 47 | # seed=12345, # 设置种子以获得确定性采样(如果可能) 48 | # temperature=0.7, # 设置采样温度 49 | # top_p=0.9 # 设置核心采样的概率质量百分比 50 | # ) 51 | # bot_message_text = completion.choices[0].message.content 52 | # #这里的bot_message_text就是最后输出的文本 53 | # chat_history.append((message, bot_message_text)) 54 | return "", chat_history 55 | 56 | def clear_history(chat_history): 57 | chat_history.clear() 58 | return chat_history 59 | 60 | def regenerate(chat_history): 61 | if chat_history: 62 | # 提取上一条输入消息 63 | last_message = chat_history[-1][0] 64 | # 移除最后一条记录 65 | chat_history.pop() 66 | # 使用上一条输入消息调用 respond 函数以生成新的回复 67 | msg,chat_history = respond(last_message, chat_history) 68 | # 返回更新后的聊天记录 69 | return msg, chat_history 70 | 71 | 72 | def evaluation_question_number_change_wrap(max_textboxes): 73 | def evaluation_question_number_change(k): 74 | k = int(k) 75 | return [gr.Textbox(visible=True)]*k + [gr.Textbox(value='', visible=False)]*(max_textboxes-k) 76 | return evaluation_question_number_change 77 | 78 | with gr.Blocks() as demo: 79 | gr.Markdown(value=''' 80 |
81 | 82 |

83 | 84 | [![GitHub Repo stars](https://img.shields.io/github/stars/scchy/XtunerGUI?style=social)](https://github.com/scchy/XtunerGUI/stargazers) 85 | ''') 86 | 87 | with gr.Tab("基础训练"): 88 | with gr.Accordion(label='使用指南', open=False): 89 | gr.Markdown('## 流程图') 90 | # process = gr.Image(value='/root/XtunerGUI/pictures/workflow.jpg',label='使用流程图',container=False,show_download_button=False ) 91 | gr.Markdown('## 演示视频') 92 | # video_customer_introduction = gr.Video(label='Xtuner GUI用法演示',value='/mnt/d/xtuner/demo.mp4',interactive=False) 93 | gr.Markdown("## 1. 本地路径设置") 94 | with gr.Row(): 95 | local_path = gr.Textbox( 96 | label='请上传所有文件保存的文件本地路径', 97 | value=CUR_PATH, 98 | info='将会在选择的路径下保存模型的配置文件、训练过程文件及模型转换后的文件' 99 | ) 100 | local_model_path = gr.Textbox(label='请确定数据集和模型下载的本地位置', value=DEFAULT_DOWNLOAD_DIR, info='将保存所有通过下方按钮下载的模型和数据集内容在该路径') 101 | 102 | # 这个里面是存放着保存数据集的路径 103 | local_path_button = gr.Button('确认路径') 104 | gr.Markdown("## 2. 微调方法、模型、数据集设置") 105 | 106 | with gr.Row(): 107 | ft_method = gr.Dropdown(choices=['qlora', 'lora', 'full'], value='qlora',label = '微调方法', info='''请选择需要的微调方法(全量微调(full)需要大量显存,请谨慎选择)''',interactive=True) 108 | with gr.Column(): 109 | model = gr.Dropdown(choices=MODEL_LIST + ['自定义'], value='internlm/internlm-chat-7b',label = '模型', info='请选择你希望微调的模型,选择后可点击下方按钮进行下载',interactive=True) 110 | DM_CLS = xtunerModelDownload( 111 | model_name=model.value, 112 | out_path=MODEL_DOWNLOAD_DIR, 113 | tqdm_class=tqdm 114 | ) 115 | local_path_button.click(DM_CLS.reset_path, inputs=[local_model_path]) 116 | model.change(DM_CLS.reset, inputs=[model]) 117 | with gr.Row(): 118 | model_download_button = gr.Button('模型下载') 119 | model_stop_download = gr.Button('取消下载') 120 | model_path = gr.Markdown(label='模型下载详情') 121 | 122 | # model_download_information = gr.Markdown(label='模型下载信息') 123 | # model_download_path = gr.Textbox(visible=False) 124 | model_download_button.click(DM_CLS.auto_download, outputs=[model_path]) 125 | model_stop_download.click(DM_CLS.break_download, outputs=[model_path]) 126 | 127 | with gr.Column(): 128 | dataset = gr.Dropdown(choices=DATA_LIST + ['自定义'], value='shibing624/medical',label = '数据集', info='请选择合适的数据集,选择后可点击下方按钮进行下载',interactive=True) 129 | DT_CLS = xtunerDataDownload( 130 | data_name= dataset.value, 131 | out_path=DATA_DOWNLOAD_DIR, 132 | tqdm_class=tqdm 133 | ) 134 | local_path_button.click(DT_CLS.reset_path, inputs=[local_model_path]) 135 | dataset.change(DT_CLS.reset, inputs=[dataset]) 136 | with gr.Row(): 137 | dataset_download_button = gr.Button('数据集下载') 138 | dataset_stop_download = gr.Button('取消下载') 139 | data_path = gr.Markdown(label='数据下载详情') 140 | # dataset_download_information = gr.Markdown(label='数据集下载信息') 141 | # dataset_download_path = gr.Textbox(visible=False) 142 | dataset_download_button.click(DT_CLS.auto_download, outputs=[data_path]) 143 | dataset_stop_download.click(DT_CLS.break_download, outputs=[data_path]) 144 | wrong_message1 = gr.Markdown() 145 | with gr.Row(): 146 | with gr.Column(scale=1): 147 | with gr.Accordion(label="自定义模型",open=False): 148 | model_personal_path = gr.Textbox(label='自定义模型本地路径', info = '请输入模型的本地路径在下方文本框中') 149 | personal_model = gr.Files(label='请上传自定义模型文件',visible=False) 150 | check_personal_model = gr.Button('模型检查及提示词模板自动匹配(请务必点击!)') 151 | detect_prompt_status = gr.Markdown() #可用于承接检查后得到的结果 152 | # 上传文件自动显示在 model_personal_path 153 | personal_model.change(lambda x: x, inputs=[personal_model], outputs=[model_personal_path]) 154 | with gr.Column(scale=2): 155 | with 
gr.Accordion(label="自定义数据集(仅支持OpenAI格式)",open=False): 156 | with gr.Row(): 157 | with gr.Column(): 158 | dataset_type = gr.Dropdown(choices=['OpenAI'],value='OpenAI',label = '支持的数据集格式', interactive=False) 159 | dataset_type_preview = gr.TextArea(label='OpenAI数据集格式展示', info= '该数据集的标准格式如下所示,请将自定义的数据集格式转化为该格式。',value=OPENAI_FORMAT) 160 | #dataset_type_preview = gr.JSON(label='数据集格式展示') 161 | with gr.Column(): 162 | dataset_personal_path = gr.Textbox(label = '数据集本地路径', info='请填入本地数据集路径或直接在下方上传数据文件') 163 | # dataset_personal_path_upload = gr.Button('请点击上传数据集本地路径') 164 | dataset_personal = gr.File(label='请上传自定义的数据集或在上方填入本地路径',type='filepath') 165 | check_personal_dataset = gr.Button('检查数据集是否符合要求') 166 | wrong_message3 = gr.Markdown() #判定数据集格式是否符合要求,符合就在上面显示 167 | check_personal_dataset.click(check_custom_dataset, inputs=[dataset_personal_path, dataset_personal], outputs=wrong_message3) 168 | 169 | # with gr.Accordion(label="数据集预览",open=False): 170 | # dataset_preview = gr.TextArea(label='数据集展示', info = '截取前n行内容,可用于对比原数据集格式。') 171 | # #dataset_preview = gr.JSON(label='数据集展示') 172 | # dataset_personal_path_upload.click(fn=read_first_ten_lines, inputs=dataset_personal_path, outputs=dataset_preview, queue=False) 173 | # dataset_personal.change(fn=read_first_ten_lines, inputs=dataset_personal, outputs=dataset_preview, queue=False) 174 | with gr.Accordion(label="对应提示词模版展示",open=False): 175 | with gr.Row(): 176 | prompt_template = gr.Dropdown(PROMPT_TEMPLATE_LIST, label='提示词模版', value='default', info='请选择合适的提示词模版(请勿随意进行调整)',interactive=True) 177 | prompt_template_show = gr.TextArea(label='提示词模版展示') 178 | 179 | model.change(model_path_map_fn, inputs=[model], outputs=[prompt_template]) 180 | # 检测完毕后 -> 改变 prompt_template -> prompt_template_show 181 | check_personal_model.click(app_get_prompt_template, inputs=[model_personal_path, personal_model], outputs=[detect_prompt_status, prompt_template]) 182 | prompt_template.change(fn=get_template_format_by_name, inputs=prompt_template, outputs=prompt_template_show) 183 | 184 | gr.Markdown("## 3. 
微调参数设置") 185 | # with gr.Accordion(label="参数调整指南",open=False): 186 | # gr.Markdown('#### 参数调整方式为...') 187 | with gr.Tab("基础参数"): 188 | with gr.Row(): 189 | lr = gr.Number(label='学习率(Learning Rate)', value=2.0e-5, info='学习率控制模型权重调整的幅度,在训练过程中对损失函数的优化有直接影响。较小的学习率可能导致学习过程缓慢,而较大的学习率可能导致学习过程中出现不稳定。') 190 | warmup_ratio = gr.Number(label='预热比(Warmup Ratio)', value=0.03, info='预热比例用于在训练初期逐渐增加学习率,这有助于模型训练初期的稳定性,避免因学习率过高导致的训练不稳定。') 191 | max_length = gr.Number(label='数据集最大长度(Max Length)', value=2048, info='设置数据在处理前的最大长度,确保模型可以处理的序列长度范围内,有助于控制训练过程的内存使用。') 192 | pack_to_max_length = gr.Dropdown(choices=[True, False], value=True, label='合并为最长样本(Pack to Max Length)', info='决定是否将多个样本合并成一个最大长度的样本。这可以提高数据处理的效率,但可能影响模型学习到的模式。') 193 | with gr.Row(): 194 | batch_size_per_device = gr.Number(label='每设备样本个数(Batch Size per Device)', value=1, info='定义每个设备上进行处理的样本数量。较大的批量大小可以提高训练效率,但也会增加内存的使用量。') 195 | accumulative_counts = gr.Number(label='梯度累计数(Gradient Accumulation Steps)', value=16, info='在进行一次参数更新前累积的梯度步数,可以增大批处理大小的效果而不增加内存消耗。') 196 | deepspeed = gr.Dropdown(choices=['None','zero1','zero2','zero3'], value='None', label='Deepspeed算子(Deepspeed)', info='选择Deepspeed优化策略来加速训练和降低内存使用。不同的优化级别提供了不同的内存和计算优化。') 197 | num_GPU = gr.Number(label='GPU数量(Number of GPUs)', value=1, info='设置训练过程中使用的GPU数量。增加GPU数量可以提高训练速度,但需要确保硬件资源充足。') 198 | with gr.Row(): 199 | max_epochs = gr.Number(label='训练迭代数(Max Epochs)', value=2, info='设置模型训练过程中数据将被遍历的次数。较多的迭代次数可以提高模型性能,但也会增加训练时间。') 200 | save_checkpoint_interval = gr.Number(label='保存权重间隔(Save Checkpoint Interval)', value=1000, info='设置自动保存模型权重的间隔(以迭代次数计)。这有助于从训练中途的某个点恢复训练过程。') 201 | save_total_limit = gr.Number(label='最多保存权重文件数(Save Total Limit)', value=2, info='限制保存的模型权重文件的最大数量,有助于管理存储空间,避免因保存过多的模型文件而耗尽存储。') 202 | evaluation_freq = gr.Number(label='验证对话效果频率(evaluation_freq)', value=100, info='请确定模型每多少轮需要验证一次对话效果,具体的对话问题及系统提示词可以在下方评估问题处进行设置') 203 | 204 | # todo: 测试问题 多个的问题 205 | with gr.Accordion(label="评估问题设置", open=True): 206 | evaluation_system_prompt = gr.Textbox(label = '系统提示词(system_prompt)', value='', info='请设置在评估模式下的系统提示词(默认为无)') 207 | default_evaluation_question_number = 2 208 | max_evaluation_question_number = 10 209 | default_evaluation_question_list = [ 210 | '请给我介绍五个上海的景点', 211 | 'Please tell me five scenic spots in Shanghai' 212 | ] 213 | evaluation_question_list = [] 214 | with gr.Accordion(label='评估问题数量及内容',open=True): 215 | with gr.Row(): 216 | with gr.Column(): 217 | evaluation_question_number = gr.Number(label='评估问题数', value=default_evaluation_question_number, minimum=1, maximum=max_evaluation_question_number, info='调整评估问题的数量(最多10个问题)') 218 | with gr.Column(): 219 | for i in range(max_evaluation_question_number): 220 | evaluation_question_if_visible = True if i < default_evaluation_question_number else False 221 | evaluation_question_value = default_evaluation_question_list[i] if i < default_evaluation_question_number else '' 222 | t = gr.Textbox(label=f'评估问题{i + 1}', value=evaluation_question_value, interactive=True, placeholder=f"请输入第{i + 1}个评估的问题", visible=evaluation_question_if_visible) 223 | evaluation_question_list.append(t) 224 | evaluation_question_number.change(evaluation_question_number_change_wrap(max_evaluation_question_number), evaluation_question_number, evaluation_question_list) 225 | with gr.Tab('进阶参数'): 226 | with gr.Row(): 227 | optim_type = gr.Dropdown(choices=['AdamW'], value='AdamW', label='优化器(Optimizer)', info='选择优化器用于调整网络权重以减少误差;AdamW是Adam优化器的一种变体,提供权重衰减控制,通常用于更好的泛化。', visible=True) 228 | 229 | weight_decay = gr.Number(label='权重衰减(Weight 
Decay)', value=0, info='权重衰减是一种正则化技术,通过为模型的损失函数添加一个与权重大小成比例的惩罚项来防止模型的过拟合。') 230 | 231 | with gr.Row(): 232 | max_norm = gr.Number(label='梯度剪裁(Gradient Clipping)', value=1, info='通过设置梯度的最大阈值来防止在训练过程中梯度爆炸的问题,有助于稳定模型的训练过程。') 233 | dataloader_num_workers = gr.Number(label='数据加载线程数(Data Loader Number of Workers)', value=0, info='设置在数据加载时并行工作的线程数,较高的值可以加快数据加载速度,但会增加内存和处理器的负担。') 234 | 235 | with gr.Accordion(label="AdamW优化器betas", open=False): 236 | beta1 = gr.Number(label='beta1 (一阶矩估计)', value=0.9, info='用于计算梯度的一阶矩估计(即梯度的指数移动平均),决定了过去梯度的权重,高值意味着模型更加关注过去的梯度。') 237 | beta2 = gr.Number(label='beta2 (二阶矩估计)', value=0.999, info='用于计算梯度的二阶矩估计(即梯度平方的指数移动平均),决定了梯度变化率的平滑程度,高值可以使优化过程在长时间内更加平稳。') 238 | 239 | ft_method.change( 240 | get_default_hyperparameters, inputs=[ft_method], 241 | outputs=[ 242 | warmup_ratio, 243 | batch_size_per_device, 244 | accumulative_counts, 245 | num_GPU, 246 | max_length, 247 | pack_to_max_length, 248 | evaluation_freq, 249 | optim_type, 250 | weight_decay, 251 | max_norm, 252 | dataloader_num_workers, 253 | beta1, 254 | beta2, 255 | lr, 256 | save_checkpoint_interval, 257 | save_total_limit 258 | ] 259 | ) 260 | change_config_button = gr.Button('点击生成配置文件') 261 | cfg_py_box = gr.Markdown(value="还未生成配置文件") 262 | change_config_button.click( 263 | build_and_save_config, 264 | inputs=[ 265 | dataset_personal_path, 266 | dataset_personal, 267 | model_personal_path, 268 | personal_model, 269 | prompt_template, 270 | local_path, 271 | ft_method, 272 | model_path, 273 | data_path, 274 | deepspeed, 275 | lr, 276 | warmup_ratio, 277 | batch_size_per_device, 278 | accumulative_counts, 279 | num_GPU, 280 | max_length, 281 | pack_to_max_length, 282 | max_epochs, 283 | save_checkpoint_interval, 284 | save_total_limit, 285 | evaluation_freq, 286 | evaluation_system_prompt, 287 | optim_type, 288 | weight_decay, 289 | max_norm, 290 | dataloader_num_workers, 291 | beta1, 292 | beta2, 293 | prompt_template, 294 | *evaluation_question_list 295 | ], 296 | outputs=[cfg_py_box] 297 | ) 298 | wrong_message4 = gr.Markdown() 299 | 300 | gr.Markdown("## 4. 
微调模型训练") 301 | TR_CLS = quickTrain( 302 | config_py_path=cfg_py_box.value, 303 | work_dir=f'{local_path.value}/work_dir', 304 | deepspeed_seed=deepspeed 305 | ) 306 | change_config_button.click(TR_CLS.reset_cfg_py, inputs=[cfg_py_box]) 307 | cfg_py_box.change(TR_CLS.reset_cfg_py, inputs=[cfg_py_box]) 308 | deepspeed.change(TR_CLS.reset_deepspeed, inputs=[deepspeed]) 309 | local_path_button.click(TR_CLS.reset_work_dir, inputs=[local_path]) 310 | with gr.Row(): 311 | train_model = gr.Button('Xtuner!启动!',size='lg') 312 | stop_button = gr.Button('训练中断',size='lg') 313 | work_path = gr.Textbox(label='work dir',visible=False) 314 | 315 | tmp_trian_pg_md = gr.Markdown() 316 | train_model.click(TR_CLS.quick_train, outputs=[tmp_trian_pg_md, work_path]) 317 | stop_button.click(TR_CLS.break_train, outputs=[tmp_trian_pg_md, work_path], queue=False) 318 | with gr.Accordion(label='模型续训', open=False): 319 | retry_path_dropdown = gr.Dropdown(label='请选择需要继续训练的权重文件', info='将从训练中断前的模型权重文件进行搜索',interactive=True) 320 | retry_button = gr.Button('继续训练') 321 | retry_path_dropdown.change(TR_CLS.reset_resume_from_checkpoint, inputs=[retry_path_dropdown]) 322 | retry_button.click(TR_CLS.resume_train, outputs=[tmp_trian_pg_md, work_path]) 323 | 324 | with gr.Accordion(label="终端界面",open=False): 325 | log_file = gr.TextArea(label='日志文件打印', info= '点击可查看模型训练信息') 326 | # train_model.click(TR_CLS.start_log, outputs=[log_file]) 327 | # retry_button.click(TR_CLS.start_log, outputs=[log_file]) 328 | 329 | wrong_message5 = gr.Markdown() 330 | gr.Markdown("## 5. 微调结果展示") 331 | PLT = resPlot( 332 | work_dir = f'{local_path.value}/work_dir', 333 | ) 334 | # 点击停止训练的时候 retry_path_dropdown 进行更新 335 | stop_button.click(PLT.dynamic_drop_down, outputs=retry_path_dropdown, queue=False) 336 | work_path.change(PLT.dynamic_drop_down, outputs=retry_path_dropdown, queue=False) 337 | local_path_button.click(PLT.reset_work_dir, inputs=[local_path]) 338 | work_path.change(PLT.reset_work_dir, inputs=[local_path]) 339 | with gr.Tab('训练结果'): 340 | # with gr.Row(): 341 | # ft_model_save_path = gr.Textbox(label='模型保存路径',visible=False) 342 | # detect work_dir find newest 343 | # iter_num = gr.Number(label='训练轮数', scale=1) 344 | # num_pth = gr.Number(label='权重文件数量', scale=1) 345 | with gr.Row(): 346 | # lr_plot = gr.Image(label='学习率变化图',container=False,show_download_button=False,interactive=False) 347 | # loss_graph = gr.Image(label='损失变化图',container=False,show_download_button=False) 348 | lr_plot = gr.LinePlot(label='学习率变化图') 349 | loss_graph = gr.LinePlot(label='损失变化图') 350 | with gr.Row(): 351 | num_pth_evaluation = gr.Dropdown(label='请选择权重文件', info='可获取模型训练过程中评估问题的结果展示') 352 | evaluation_question = gr.TextArea(label='测试问题结果') 353 | 354 | stop_button.click(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False) 355 | show_evaluation_button = gr.Button('微调结果生成') 356 | show_evaluation_button.click(PLT.reset_work_dir, inputs=[local_path], queue=False) 357 | show_evaluation_button.click(PLT.lr_plot, outputs=[lr_plot], queue=False) 358 | show_evaluation_button.click(PLT.loss_plot, outputs=[loss_graph], queue=False) 359 | # 更新eval的下拉列表 360 | show_evaluation_button.click(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False) 361 | work_path.change(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False) 362 | # 找到 & read eval 363 | num_pth_evaluation.change(PLT.get_eval_test, inputs=[num_pth_evaluation], outputs=[evaluation_question]) 364 | 365 | gr.Markdown("## 6. 
微调模型转化及测试") 366 | 367 | with gr.Accordion(label="模型转换",open=True): 368 | # Textbox 369 | # select_checkpoint =gr.Dropdown(choices=['epoch_1.pth', 'epoch_1.pth'], value='epoch_1.pth', label='微调模型的权重文件', info = '请选择需要进行测试的模型权重文件并进行转化') 370 | select_checkpoint = gr.Dropdown(label='微调模型的权重文件', info = '请选择需要进行测试的模型权重文件并进行转化',interactive = True) 371 | stop_button.click(PLT.dynamic_drop_down, outputs=select_checkpoint, queue=False) 372 | show_evaluation_button.click(PLT.dynamic_drop_down, outputs=select_checkpoint, queue=False) 373 | 374 | covert_hf = gr.Button('模型转换',scale=1) 375 | covert_hf_path = gr.Textbox(label='模型转换后地址', visible=False) # False 376 | wrong_message6 = gr.Markdown() 377 | 378 | # root_dir, config_file, epoch_pth, model_path, customer_model_path) 379 | # todo ft_method full-convert oth-convert+merge 380 | covert_hf.click(CVT_MG.auto_convert_merge, inputs=[local_path, cfg_py_box, select_checkpoint, model_path, model_personal_path, ft_method], outputs=[wrong_message6, covert_hf_path]) 381 | with gr.Accordion(label='对话测试', open=True): 382 | with gr.Row(): 383 | with gr.Accordion(label="原模型对话测试", open=True): 384 | with gr.Column(): 385 | with gr.Accordion(label='参数设置',open=False): 386 | max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=1024, label='模型输出的最长Token(max_new_tokens)', info='这个参数决定了模型输出的最大token数量。增加这个值允许模型生成更长的文本,而减少这个值会导致生成的文本更短。') 387 | temperature = gr.Slider(maximum=2, minimum=0, label='温度值(temperature)',value=1, info='控制生成文本的随机性。较高的温度值会使输出更加多样和不可预测,而较低的值使输出更确定和重复。') 388 | top_k = gr.Slider(minimum=0, maximum=100, value=40, label='Top-k Sampling(top-k)', info='限制模型在每一步生成文本时考虑的最可能候选词的数量。较大的k值增加了多样性,但可能降低文本的连贯性;较小的k值则相反。') 389 | top_p = gr.Slider(minimum=0, maximum=2, value=0.75, label='Top-p Sampling(top-p)', info='类似于top_k,但通过选择累积概率高于某个阈值p的最小词集,动态调整考虑的候选词数量。较高的p值增加多样性,较低的p值提高连贯性。') 390 | num_beams = gr.Slider(minimum=0, maximum=12, value=5, label='Beam Search(num_beams)', info='在beam search中,num_beams指定了搜索宽度。更多的beams可以提高生成文本的质量,但也会增加计算负担。') 391 | 392 | #还可以添加更多 393 | wrong_message9 = gr.Markdown() 394 | start_testing_model = gr.Button('模型启动') 395 | testig_model_loaded = gr.Markdown() 396 | chatbot = gr.Chatbot(label='微调模型测试') 397 | msg = gr.Textbox(label="输入信息") 398 | msg.submit(CHAT_ORG.qa_answer, inputs=[msg, max_new_tokens, temperature, top_k, top_p, num_beams, chatbot], outputs=[msg, chatbot]) 399 | # 模型载入 400 | start_testing_model.click(CHAT_ORG.load_model, inputs=[model_personal_path, personal_model, model_path], outputs=[testig_model_loaded]) 401 | send = gr.Button('信息发送') # .click(regenerate, inputs=[chatbot], outputs = [msg, chatbot]) 402 | with gr.Row(): 403 | clear = gr.Button('记录删除') # .click(clear_history, inputs=[chatbot], outputs=[chatbot]) 404 | undo = gr.Button('撤回上一条') # .click(undo, inputs=[chatbot], outputs=[chatbot]) 405 | 406 | 407 | clear.click(CHAT_ORG.qa_clear, inputs=[chatbot], outputs=[chatbot]) 408 | undo.click(CHAT_ORG.qa_undo, inputs=[chatbot], outputs=[chatbot]) 409 | send.click(CHAT_ORG.qa_answer, inputs=[msg, max_new_tokens, temperature, top_k, top_p, num_beams, chatbot], outputs=[msg, chatbot]) 410 | 411 | with gr.Accordion(label="微调模型对话测试", open=True): 412 | with gr.Column(): 413 | with gr.Accordion(label='参数设置',open=False): 414 | ft_max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=1024, label='模型输出的最长Token(max_new_tokens)', info='这个参数决定了模型输出的最大token数量。增加这个值允许模型生成更长的文本,而减少这个值会导致生成的文本更短。') 415 | ft_temperature = gr.Slider(maximum=2, minimum=0,value=1, label='温度值(temperature)', 
info='控制生成文本的随机性。较高的温度值会使输出更加多样和不可预测,而较低的值使输出更确定和重复。') 416 | ft_top_k = gr.Slider(minimum=0, maximum=100, value=40, label='Top-k Sampling(top-k)', info='限制模型在每一步生成文本时考虑的最可能候选词的数量。较大的k值增加了多样性,但可能降低文本的连贯性;较小的k值则相反。') 417 | ft_top_p = gr.Slider(minimum=0, maximum=2, value=0.75, label='Top-p Sampling(top-p)', info='类似于top_k,但通过选择累积概率高于某个阈值p的最小词集,动态调整考虑的候选词数量。较高的p值增加多样性,较低的p值提高连贯性。') 418 | ft_num_beams = gr.Slider(minimum=0, maximum=12, value=5, label='Beam Search(num_beams)', info='在beam search中,num_beams指定了搜索宽度。更多的beams可以提高生成文本的质量,但也会增加计算负担。') 419 | #还可以添加更多 420 | ft_wrong_message9 = gr.Markdown() 421 | ft_start_testing_model = gr.Button('模型启动') 422 | ft_testig_model_loaded = gr.Markdown() 423 | ft_chatbot = gr.Chatbot(label='微调模型测试') 424 | ft_msg = gr.Textbox(label="输入信息") 425 | 426 | ft_msg.submit(FT_CHAT_ORG.qa_answer, inputs=[ft_msg, ft_max_new_tokens, ft_temperature, ft_top_k, ft_top_p, ft_num_beams, ft_chatbot], outputs=[ft_msg, ft_chatbot]) 427 | # 模型载入 428 | ft_start_testing_model.click(FT_CHAT_ORG.load_model, inputs=[covert_hf_path, personal_model, model_path], outputs=[ft_testig_model_loaded]) 429 | 430 | ft_send = gr.Button('信息发送') 431 | with gr.Row(): 432 | ft_clear = gr.Button('记录删除') 433 | ft_undo = gr.Button('撤回上一条') 434 | 435 | 436 | ft_clear.click(FT_CHAT_ORG.qa_clear, inputs=[ft_chatbot], outputs=[ft_chatbot]) 437 | ft_undo.click(FT_CHAT_ORG.qa_undo, inputs=[ft_chatbot], outputs=[ft_chatbot]) 438 | ft_send.click(FT_CHAT_ORG.qa_answer, inputs=[ft_msg, ft_max_new_tokens, ft_temperature, ft_top_k, ft_top_p, ft_num_beams, ft_chatbot], outputs=[ft_msg, ft_chatbot]) 439 | # with gr.Accordion(label='模型基础能力评估测试',open=False): 440 | # mmlu_test_button = gr.Button('MMLU模型能力评估测试') 441 | with gr.Accordion(label="其他信息", open=True): 442 | star = gr.Markdown('### 如果您觉得该UI界面能帮助您高效完成微调工作,请为[XTuner](https://github.com/InternLM/xtuner.git)和[XTunerGUI](https://github.com/scchy/XtunerGUI.git)点个星星!非常感谢您的支持!') 443 | thanks = gr.Markdown('''### 最后,感谢XTUnerGUI团队成员对该项目的贡献: 444 | - [scchy](https://github.com/scchy) - 整体后端开发 445 | - [jianfeng777](https://github.com/Jianfeng777) - 整体前端开发 446 | - [l241025097](https://github.com/l241025097) - 模型训练终端可视化 447 | - [semple030228](https://github.com/semple030228) - 模型转化 448 | 449 | ### 同时也感谢XTuner团队的大力支持: 450 | - [HIT-cwh](https://github.com/HIT-cwh) - 配置文件生成及相关检查 451 | - [pppppM](https://github.com/pppppM) - 提供指导意见 452 | 453 | 我们相信XTuner将成为国内有影响力的模型微调工具包! 
454 | ''') 455 | 456 | #with gr.Tab('微调模型部署(LMDeploy)'): 457 | 458 | #with gr.Tab('微调模型测评(OpenCompass)'): 459 | 460 | demo.load(TR_CLS.read_log, outputs=[log_file], every=1) 461 | demo.launch(share=True) #, server_name="0.0.0.0", server_port=6007, root_path=f'/proxy/6007/') #, server_name="0.0.0.0", server_port=6006, root_path=f'/proxy/6006/') 462 | 463 | 464 | 465 | 466 | -------------------------------------------------------------------------------- /appPrepare/env_prepare.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | cd ~ 5 | mkdir xtunerPKG 6 | git clone https://gitee.com/InternLM/xtuner.git 7 | cd xtuner 8 | pip install -e '.[all]' 9 | 10 | -------------------------------------------------------------------------------- /appPrepare/files_prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | CUR_PATH = os.path.dirname(__file__) 5 | DATA_PATH = os.path.dirname(CUR_PATH) 6 | 7 | def dir_create(_dir): 8 | if not os.path.exists(_dir): 9 | os.system(f'mkdir -p {_dir}') 10 | return _dir 11 | 12 | DEFAULT_DOWNLOAD_DIR = dir_create(f"{DATA_PATH}/download_cache") 13 | MODEL_DOWNLOAD_DIR = dir_create(f"{DEFAULT_DOWNLOAD_DIR}/model_download") 14 | DATA_DOWNLOAD_DIR = dir_create(f"{DEFAULT_DOWNLOAD_DIR}/data_download") 15 | WORK_DIR = dir_create(f"{CUR_PATH}/work_dir") 16 | 17 | 18 | 19 | 20 | 21 | -------------------------------------------------------------------------------- /appPrepare/func_prepare.py: -------------------------------------------------------------------------------- 1 | from xtuner.utils import PROMPT_TEMPLATE 2 | 3 | OPENAI_FORMAT = ''' 4 | [ 5 | { 6 | "messages": 7 | [ 8 | { "role": "system", "content": "You are an assistant that occasionally misspells words." }, 9 | { "role": "user", "content": "Tell me a story." }, 10 | { "role": "assistant", "content": "One day a student went to schoool." } 11 | ] 12 | }, 13 | 14 | { 15 | "messages": 16 | [ 17 | { "role": "user", "content": "Tell me a story." }, 18 | { "role": "assistant", "content": "One day a student went to schoool." 
} 19 | ] 20 | } 21 | ] 22 | ''' 23 | 24 | def read_first_ten_lines(input_path, upload_path=None): 25 | file_path = input_path if len(input_path) >= 3 else upload_path 26 | try: 27 | with open(file_path, 'r', encoding='utf-8') as file: 28 | lines = file.readlines() # 读取所有行 29 | first_ten_lines = lines[:20] # 获取前十行 30 | return ''.join(first_ten_lines) # 将前十行合并为一个字符串并返回 31 | except Exception as e: 32 | return f"Error reading file: {str(e)}" 33 | 34 | 35 | def get_template_format_by_name(template_name): 36 | template = PROMPT_TEMPLATE.get(template_name, None) 37 | if template is None: 38 | return "Template not found" 39 | return str(template) 40 | -------------------------------------------------------------------------------- /appPrepare/list_prepare.py: -------------------------------------------------------------------------------- 1 | 2 | DATA_LIST = [ 3 | 'ArmelR/stack-exchange-instruction', 4 | 'HuggingFaceH4/CodeAlpaca_20K', 5 | 'Open-Orca/OpenOrca', 6 | 'Skywork/SkyPile-150B', 7 | 'WizardLM/WizardLM_evol_instruct_V2_196k', 8 | 'b-mc2/sql-create-context', 9 | 'burkelibbey/colors', 10 | 'damo/MSAgent-Bench', 11 | 'garage-bAInd/Open-Platypus', 12 | 'mistralai/Mistral-7B-v0.1', 13 | 'nampdn-ai/tiny-codes', 14 | 'shibing624/medical', 15 | 'silk-road/alpaca-data-gpt4-chinese', 16 | 'tatsu-lab/alpaca', 17 | 'timdettmers/openassistant-guanaco', 18 | ] 19 | 20 | MODEL_LIST = [ 21 | 'internlm/internlm-7b', 22 | 'internlm/internlm-20b', 23 | 'internlm/internlm-chat-7b', 24 | 'internlm/internlm-chat-20b', 25 | 'meta-llama/Llama-2-7b-chat', 26 | 'meta-llama/llama2-70b', 27 | 'meta-llama/Llama-2-7b', 28 | 'huggyllama/llama-7b', 29 | 'baichuan-inc/Baichuan-7B-Chat', 30 | 'baichuan-inc/Baichuan-13B-Base', 31 | 'baichuan-inc/baichuan-7B', 32 | 'baichuan-inc/Baichuan2-13B-Chat', 33 | 'baichuan-inc/Baichuan2-7B-Chat', 34 | 'baichuan-inc/Baichuan2-13B-Chat', 35 | 'baichuan-inc/Baichuan2-13B-Base', 36 | 'THUDM/chatglm3-6b', 37 | 'THUDM/chatglm2-6b', 38 | 'THUDM/chatglm3-6b-base', 39 | '01-ai/Yi-24B', 40 | '01-ai/Yi-6B', 41 | 'Qwen/Qwen-7B-Chat', 42 | 'Qwen/Qwen-7B', 43 | ] 44 | 45 | 46 | PROMPT_TEMPLATE_LIST = [ 47 | 'default', 48 | 'zephyr', 49 | 'internlm_chat', 50 | 'internlm2_chat', 51 | 'moss_sft', 52 | 'llama2_chat', 53 | 'code_llama_chat', 54 | 'chatglm2', 55 | 'chatglm3', 56 | 'qwen_chat', 57 | 'baichuan_chat', 58 | 'baichuan2_chat', 59 | 'wizardlm', 60 | 'wizardcoder', 61 | 'vicuna', 62 | 'deepseek_coder', 63 | 'deepseekcoder', 64 | 'deepseek_moe', 65 | 'mistral', 66 | 'mixtral' 67 | ] -------------------------------------------------------------------------------- /chat/model_center.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-02-03 3 | # Author: Scc_hy 4 | # Func: chat center 5 | # ============================================================================== 6 | 7 | import torch 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | user_prompt = "<|User|>:{user}\n" 11 | robot_prompt = "<|Bot|>:{robot}\n" 12 | cur_query_prompt = "<|User|>:{user}\n<|Bot|>:" 13 | 14 | 15 | class ModelCenter(): 16 | def __init__(self): 17 | self.model = None 18 | self.tokenizer = None 19 | 20 | def load_model(self, 21 | model_personal_path, 22 | personal_model, 23 | model_path_in 24 | ): 25 | # 构造函数,加载检索问答链 26 | model_path = self.choice_path(model_personal_path, personal_model, model_path_in) 27 | print(f'ModelCenter.load_model({model_path})') 28 | self.model = ( 29 | AutoModelForCausalLM.from_pretrained(model_path, 
trust_remote_code=True) 30 | .to(torch.bfloat16) 31 | .cuda() 32 | ) 33 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 34 | print('>>>>> Loading Model Done !') 35 | return f'>>> Loaded {model_path}' 36 | 37 | @staticmethod 38 | def choice_path(model_personal_path, personal_model, model_path_in): 39 | if len(model_personal_path) >= 3: 40 | return model_personal_path 41 | if len(personal_model) >= 3: 42 | return personal_model 43 | return model_path_in 44 | 45 | def qa_answer(self, question: str, max_new_tokens, temperature, top_k, top_p, num_beams, chat_history: list = []): 46 | if question == None or len(question) < 1: 47 | return "", chat_history 48 | try: 49 | question = question.replace(" ", ' ') 50 | response, history = self.model.chat( 51 | self.tokenizer, 52 | question, 53 | history=chat_history, 54 | max_new_tokens=max_new_tokens, 55 | temperature=temperature, 56 | top_k=top_k, 57 | top_p=top_p, 58 | num_beams=int(num_beams) 59 | ) 60 | chat_history.append((question, response)) 61 | return "", chat_history 62 | except Exception as e: 63 | return e, chat_history 64 | 65 | def qa_undo(self, chat_history: list = []): 66 | if len(chat_history): 67 | chat_history.pop() 68 | return chat_history 69 | 70 | def qa_clear(self, chat_history: list = []): 71 | return [] 72 | -------------------------------------------------------------------------------- /pictures/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/1.png -------------------------------------------------------------------------------- /pictures/demo.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/demo.mp4 -------------------------------------------------------------------------------- /pictures/workflow.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/workflow.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | openxlab==0.0.34 3 | gradio==4.4.0 4 | transformers==4.34.0 5 | huggingface_hub 6 | modelscope==1.11.0 7 | unstructured==0.10.30 8 | markdown==3.3.7 9 | xtuner 10 | altair 11 | -------------------------------------------------------------------------------- /template_configs/full_finetune.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | from datasets import load_dataset 3 | from mmengine.dataset import DefaultSampler 4 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 5 | LoggerHook, ParamSchedulerHook) 6 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR 7 | from torch.optim import AdamW 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | from xtuner.dataset import process_hf_dataset 11 | from xtuner.dataset.collate_fns import default_collate_fn 12 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory 13 | from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, 14 | ThroughputHook, 15 | VarlenAttnArgsToMessageHubHook) 16 | from xtuner.engine.runner import TrainLoop 17 | from xtuner.model import SupervisedFinetune 18 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE 19 | 20 | ####################################################################### 21 | # PART 1 Settings # 22 | ####################################################################### 23 | # Model 24 | pretrained_model_name_or_path = 'internlm/internlm-7b' 25 | use_varlen_attn = False 26 | 27 | # Data 28 | data_path = 'tatsu-lab/alpaca' 29 | prompt_template = PROMPT_TEMPLATE.internlm_chat 30 | max_length = 2048 31 | pack_to_max_length = True 32 | 33 | # Scheduler & Optimizer 34 | batch_size = 1 # per_device 35 | accumulative_counts = 2 36 | dataloader_num_workers = 0 37 | max_epochs = 3 38 | optim_type = AdamW 39 | lr = 2e-5 40 | betas = (0.9, 0.999) 41 | weight_decay = 0 42 | max_norm = 1 # grad clip 43 | warmup_ratio = 0.03 44 | 45 | # Save 46 | save_steps = 500 47 | save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) 48 | 49 | # Evaluate the generation performance during the training 50 | evaluation_freq = 500 51 | SYSTEM = SYSTEM_TEMPLATE.alpaca 52 | evaluation_inputs = [ 53 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' 54 | ] 55 | 56 | ####################################################################### 57 | # PART 2 Model & Tokenizer # 58 | ####################################################################### 59 | tokenizer = dict( 60 | type=AutoTokenizer.from_pretrained, 61 | pretrained_model_name_or_path=pretrained_model_name_or_path, 62 | trust_remote_code=True, 63 | padding_side='right') 64 | 65 | model = dict( 66 | type=SupervisedFinetune, 67 | use_varlen_attn=use_varlen_attn, 68 | llm=dict( 69 | type=AutoModelForCausalLM.from_pretrained, 70 | pretrained_model_name_or_path=pretrained_model_name_or_path, 71 | trust_remote_code=True)) 72 | 73 | ####################################################################### 74 | # PART 3 Dataset & Dataloader # 75 | ####################################################################### 76 | train_dataset = dict( 77 | type=process_hf_dataset, 78 | dataset=dict(type=load_dataset, path=data_path), 79 | tokenizer=tokenizer, 80 | max_length=max_length, 81 | dataset_map_fn=alpaca_map_fn, 82 | template_map_fn=dict( 83 | type=template_map_fn_factory, template=prompt_template), 84 | remove_unused_columns=True, 85 | shuffle_before_pack=True, 86 | pack_to_max_length=pack_to_max_length, 87 | use_varlen_attn=use_varlen_attn) 88 | 89 | train_dataloader = dict( 90 | batch_size=batch_size, 91 | num_workers=dataloader_num_workers, 92 | dataset=train_dataset, 93 | sampler=dict(type=DefaultSampler, shuffle=True), 94 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) 95 | 96 | 
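# Note: the schedulers in PART 4 below split training into two phases: a LinearLR
# warmup from epoch 0 to `warmup_ratio * max_epochs` (with this template's defaults,
# 0.03 * 3 = 0.09 epochs), followed by CosineAnnealingLR decay until `max_epochs`.
# Variables such as `lr`, `warmup_ratio`, `batch_size` and `max_epochs` are the fields
# the GUI's config builder is expected to fill in (illustrative note; the actual
# generated config may differ).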
####################################################################### 97 | # PART 4 Scheduler & Optimizer # 98 | ####################################################################### 99 | # optimizer 100 | optim_wrapper = dict( 101 | type=AmpOptimWrapper, 102 | optimizer=dict( 103 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), 104 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), 105 | accumulative_counts=accumulative_counts, 106 | loss_scale='dynamic', 107 | dtype='float16') 108 | 109 | # learning policy 110 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 111 | param_scheduler = [ 112 | dict( 113 | type=LinearLR, 114 | start_factor=1e-5, 115 | by_epoch=True, 116 | begin=0, 117 | end=warmup_ratio * max_epochs, 118 | convert_to_iter_based=True), 119 | dict( 120 | type=CosineAnnealingLR, 121 | eta_min=0.0, 122 | by_epoch=True, 123 | begin=warmup_ratio * max_epochs, 124 | end=max_epochs, 125 | convert_to_iter_based=True) 126 | ] 127 | 128 | # train, val, test setting 129 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) 130 | 131 | ####################################################################### 132 | # PART 5 Runtime # 133 | ####################################################################### 134 | # Log the dialogue periodically during the training process, optional 135 | custom_hooks = [ 136 | dict(type=DatasetInfoHook, tokenizer=tokenizer), 137 | dict( 138 | type=EvaluateChatHook, 139 | tokenizer=tokenizer, 140 | every_n_iters=evaluation_freq, 141 | evaluation_inputs=evaluation_inputs, 142 | system=SYSTEM, 143 | prompt_template=prompt_template), 144 | dict(type=ThroughputHook) 145 | ] 146 | 147 | if use_varlen_attn: 148 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] 149 | 150 | # configure default hooks 151 | default_hooks = dict( 152 | # record the time of every iteration. 153 | timer=dict(type=IterTimerHook), 154 | # print log every 10 iterations. 155 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), 156 | # enable the parameter scheduler. 157 | param_scheduler=dict(type=ParamSchedulerHook), 158 | # save checkpoint per `save_steps`. 159 | checkpoint=dict( 160 | type=CheckpointHook, 161 | by_epoch=False, 162 | interval=save_steps, 163 | max_keep_ckpts=save_total_limit), 164 | # set sampler seed in distributed evrionment. 165 | sampler_seed=dict(type=DistSamplerSeedHook), 166 | ) 167 | 168 | # configure environment 169 | env_cfg = dict( 170 | # whether to enable cudnn benchmark 171 | cudnn_benchmark=False, 172 | # set multi process parameters 173 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 174 | # set distributed parameters 175 | dist_cfg=dict(backend='nccl'), 176 | ) 177 | 178 | # set visualizer 179 | visualizer = None 180 | 181 | # set log level 182 | log_level = 'INFO' 183 | 184 | # load from which checkpoint 185 | load_from = None 186 | 187 | # whether to resume training from the loaded checkpoint 188 | resume = False 189 | 190 | # Defaults to use random seed and disable `deterministic` 191 | randomness = dict(seed=None, deterministic=False) 192 | 193 | # set log processor 194 | log_processor = dict(by_epoch=False) 195 | -------------------------------------------------------------------------------- /template_configs/lora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from datasets import load_dataset 4 | from mmengine.dataset import DefaultSampler 5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 6 | LoggerHook, ParamSchedulerHook) 7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR 8 | from peft import LoraConfig 9 | from torch.optim import AdamW 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | from xtuner.dataset import process_hf_dataset 13 | from xtuner.dataset.collate_fns import default_collate_fn 14 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory 15 | from xtuner.engine import DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook 16 | from xtuner.engine.runner import TrainLoop 17 | from xtuner.model import SupervisedFinetune 18 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE 19 | 20 | ####################################################################### 21 | # PART 1 Settings # 22 | ####################################################################### 23 | # Model 24 | pretrained_model_name_or_path = 'internlm/internlm-7b' 25 | use_varlen_attn = False 26 | 27 | # Data 28 | data_path = 'tatsu-lab/alpaca' 29 | prompt_template = PROMPT_TEMPLATE.default 30 | max_length = 2048 31 | pack_to_max_length = True 32 | 33 | # Scheduler & Optimizer 34 | batch_size = 1 # per_device 35 | accumulative_counts = 2 36 | dataloader_num_workers = 0 37 | max_epochs = 3 38 | optim_type = AdamW 39 | lr = 2e-4 40 | betas = (0.9, 0.999) 41 | weight_decay = 0 42 | max_norm = 1 # grad clip 43 | warmup_ratio = 0.03 44 | 45 | # Save 46 | save_steps = 200 47 | save_total_limit = 5 # Maximum checkpoints to keep (-1 means unlimited) 48 | 49 | # Evaluate the generation performance during the training 50 | evaluation_freq = 500 51 | SYSTEM = SYSTEM_TEMPLATE.alpaca 52 | evaluation_inputs = [ 53 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' 54 | ] 55 | 56 | ####################################################################### 57 | # PART 2 Model & Tokenizer # 58 | ####################################################################### 59 | tokenizer = dict( 60 | type=AutoTokenizer.from_pretrained, 61 | pretrained_model_name_or_path=pretrained_model_name_or_path, 62 | trust_remote_code=True, 63 | padding_side='right') 64 | 65 | model = dict( 66 | type=SupervisedFinetune, 67 | use_varlen_attn=use_varlen_attn, 68 | llm=dict( 69 | type=AutoModelForCausalLM.from_pretrained, 70 | pretrained_model_name_or_path=pretrained_model_name_or_path, 71 | trust_remote_code=True, 72 | torch_dtype=torch.float16), 73 | lora=dict( 74 | type=LoraConfig, 75 | r=64, 76 | lora_alpha=16, 77 | lora_dropout=0.1, 78 | bias='none', 79 | task_type='CAUSAL_LM')) 80 | 81 | ####################################################################### 82 | # PART 3 Dataset & Dataloader # 83 | ####################################################################### 84 | train_dataset = dict( 85 | type=process_hf_dataset, 86 | dataset=dict(type=load_dataset, path=data_path), 87 | tokenizer=tokenizer, 88 | max_length=max_length, 89 | dataset_map_fn=alpaca_map_fn, 90 | template_map_fn=dict( 91 | type=template_map_fn_factory, template=prompt_template), 92 | remove_unused_columns=True, 93 | shuffle_before_pack=True, 94 | pack_to_max_length=pack_to_max_length, 95 | use_varlen_attn=use_varlen_attn) 96 | 97 | train_dataloader = dict( 98 | batch_size=batch_size, 99 | num_workers=dataloader_num_workers, 100 | dataset=train_dataset, 101 | 
sampler=dict(type=DefaultSampler, shuffle=True), 102 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) 103 | 104 | ####################################################################### 105 | # PART 4 Scheduler & Optimizer # 106 | ####################################################################### 107 | # optimizer 108 | optim_wrapper = dict( 109 | type=AmpOptimWrapper, 110 | optimizer=dict( 111 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), 112 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), 113 | accumulative_counts=accumulative_counts, 114 | loss_scale='dynamic', 115 | dtype='float16') 116 | 117 | # learning policy 118 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 119 | param_scheduler = [ 120 | dict( 121 | type=LinearLR, 122 | start_factor=1e-5, 123 | by_epoch=True, 124 | begin=0, 125 | end=warmup_ratio * max_epochs, 126 | convert_to_iter_based=True), 127 | dict( 128 | type=CosineAnnealingLR, 129 | eta_min=0.0, 130 | by_epoch=True, 131 | begin=warmup_ratio * max_epochs, 132 | end=max_epochs, 133 | convert_to_iter_based=True) 134 | ] 135 | 136 | # train, val, test setting 137 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) 138 | 139 | ####################################################################### 140 | # PART 5 Runtime # 141 | ####################################################################### 142 | # Log the dialogue periodically during the training process, optional 143 | custom_hooks = [ 144 | dict(type=DatasetInfoHook, tokenizer=tokenizer), 145 | dict( 146 | type=EvaluateChatHook, 147 | tokenizer=tokenizer, 148 | every_n_iters=evaluation_freq, 149 | evaluation_inputs=evaluation_inputs, 150 | system=SYSTEM, 151 | prompt_template=prompt_template) 152 | ] 153 | 154 | if use_varlen_attn: 155 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] 156 | 157 | # configure default hooks 158 | default_hooks = dict( 159 | # record the time of every iteration. 160 | timer=dict(type=IterTimerHook), 161 | # print log every 10 iterations. 162 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), 163 | # enable the parameter scheduler. 164 | param_scheduler=dict(type=ParamSchedulerHook), 165 | # save checkpoint per `save_steps`. 166 | checkpoint=dict( 167 | type=CheckpointHook, 168 | by_epoch=False, 169 | interval=save_steps, 170 | max_keep_ckpts=save_total_limit), 171 | # set sampler seed in distributed evrionment. 
172 | sampler_seed=dict(type=DistSamplerSeedHook), 173 | ) 174 | 175 | # configure environment 176 | env_cfg = dict( 177 | # whether to enable cudnn benchmark 178 | cudnn_benchmark=False, 179 | # set multi process parameters 180 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 181 | # set distributed parameters 182 | dist_cfg=dict(backend='nccl'), 183 | ) 184 | 185 | # set visualizer 186 | visualizer = None 187 | 188 | # set log level 189 | log_level = 'INFO' 190 | 191 | # load from which checkpoint 192 | load_from = None 193 | 194 | # whether to resume training from the loaded checkpoint 195 | resume = False 196 | 197 | # Defaults to use random seed and disable `deterministic` 198 | randomness = dict(seed=None, deterministic=False) 199 | 200 | # set log processor 201 | log_processor = dict(by_epoch=False) 202 | -------------------------------------------------------------------------------- /template_configs/qlora.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import torch 3 | from datasets import load_dataset 4 | from mmengine.dataset import DefaultSampler 5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 6 | LoggerHook, ParamSchedulerHook) 7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR 8 | from peft import LoraConfig 9 | from torch.optim import AdamW 10 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 11 | BitsAndBytesConfig) 12 | 13 | from xtuner.dataset import process_hf_dataset 14 | from xtuner.dataset.collate_fns import default_collate_fn 15 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory 16 | from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook, 17 | VarlenAttnArgsToMessageHubHook) 18 | from xtuner.engine.runner import TrainLoop 19 | from xtuner.model import SupervisedFinetune 20 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE 21 | 22 | ####################################################################### 23 | # PART 1 Settings # 24 | ####################################################################### 25 | # Model 26 | pretrained_model_name_or_path = 'internlm/internlm-7b' 27 | use_varlen_attn = False 28 | 29 | # Data 30 | alpaca_en_path = 'tatsu-lab/alpaca' 31 | prompt_template = PROMPT_TEMPLATE.default 32 | max_length = 2048 33 | pack_to_max_length = True 34 | 35 | # Scheduler & Optimizer 36 | batch_size = 1 # per_device 37 | accumulative_counts = 2 38 | dataloader_num_workers = 0 39 | max_epochs = 3 40 | optim_type = AdamW 41 | lr = 2e-4 42 | betas = (0.9, 0.999) 43 | weight_decay = 0 44 | max_norm = 1 # grad clip 45 | warmup_ratio = 0.03 46 | 47 | # Save 48 | save_steps = 200 49 | save_total_limit = 5 # Maximum checkpoints to keep (-1 means unlimited) 50 | 51 | # Evaluate the generation performance during the training 52 | evaluation_freq = 500 53 | SYSTEM = SYSTEM_TEMPLATE.alpaca 54 | evaluation_inputs = [ 55 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' 56 | ] 57 | 58 | ####################################################################### 59 | # PART 2 Model & Tokenizer # 60 | ####################################################################### 61 | tokenizer = dict( 62 | type=AutoTokenizer.from_pretrained, 63 | pretrained_model_name_or_path=pretrained_model_name_or_path, 64 | trust_remote_code=True, 65 | padding_side='right') 66 | 67 | model = dict( 68 | type=SupervisedFinetune, 69 | 
use_varlen_attn=use_varlen_attn, 70 | llm=dict( 71 | type=AutoModelForCausalLM.from_pretrained, 72 | pretrained_model_name_or_path=pretrained_model_name_or_path, 73 | trust_remote_code=True, 74 | torch_dtype=torch.float16, 75 | quantization_config=dict( 76 | type=BitsAndBytesConfig, 77 | load_in_4bit=True, 78 | load_in_8bit=False, 79 | llm_int8_threshold=6.0, 80 | llm_int8_has_fp16_weight=False, 81 | bnb_4bit_compute_dtype=torch.float16, 82 | bnb_4bit_use_double_quant=True, 83 | bnb_4bit_quant_type='nf4')), 84 | lora=dict( 85 | type=LoraConfig, 86 | r=64, 87 | lora_alpha=16, 88 | lora_dropout=0.1, 89 | bias='none', 90 | task_type='CAUSAL_LM')) 91 | 92 | ####################################################################### 93 | # PART 3 Dataset & Dataloader # 94 | ####################################################################### 95 | alpaca_en = dict( 96 | type=process_hf_dataset, 97 | dataset=dict(type=load_dataset, path=alpaca_en_path), 98 | tokenizer=tokenizer, 99 | max_length=max_length, 100 | dataset_map_fn=alpaca_map_fn, 101 | template_map_fn=dict( 102 | type=template_map_fn_factory, template=prompt_template), 103 | remove_unused_columns=True, 104 | shuffle_before_pack=True, 105 | pack_to_max_length=pack_to_max_length, 106 | use_varlen_attn=use_varlen_attn) 107 | 108 | train_dataloader = dict( 109 | batch_size=batch_size, 110 | num_workers=dataloader_num_workers, 111 | dataset=alpaca_en, 112 | sampler=dict(type=DefaultSampler, shuffle=True), 113 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn)) 114 | 115 | ####################################################################### 116 | # PART 4 Scheduler & Optimizer # 117 | ####################################################################### 118 | # optimizer 119 | optim_wrapper = dict( 120 | type=AmpOptimWrapper, 121 | optimizer=dict( 122 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), 123 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), 124 | accumulative_counts=accumulative_counts, 125 | loss_scale='dynamic', 126 | dtype='float16') 127 | 128 | # learning policy 129 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 130 | param_scheduler = [ 131 | dict( 132 | type=LinearLR, 133 | start_factor=1e-5, 134 | by_epoch=True, 135 | begin=0, 136 | end=warmup_ratio * max_epochs, 137 | convert_to_iter_based=True), 138 | dict( 139 | type=CosineAnnealingLR, 140 | eta_min=0.0, 141 | by_epoch=True, 142 | begin=warmup_ratio * max_epochs, 143 | end=max_epochs, 144 | convert_to_iter_based=True) 145 | ] 146 | 147 | # train, val, test setting 148 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) 149 | 150 | ####################################################################### 151 | # PART 5 Runtime # 152 | ####################################################################### 153 | # Log the dialogue periodically during the training process, optional 154 | custom_hooks = [ 155 | dict(type=DatasetInfoHook, tokenizer=tokenizer), 156 | dict( 157 | type=EvaluateChatHook, 158 | tokenizer=tokenizer, 159 | every_n_iters=evaluation_freq, 160 | evaluation_inputs=evaluation_inputs, 161 | system=SYSTEM, 162 | prompt_template=prompt_template) 163 | ] 164 | 165 | if use_varlen_attn: 166 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)] 167 | 168 | # configure default hooks 169 | default_hooks = dict( 170 | # record the time of every iteration. 
171 | timer=dict(type=IterTimerHook), 172 | # print log every 10 iterations. 173 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), 174 | # enable the parameter scheduler. 175 | param_scheduler=dict(type=ParamSchedulerHook), 176 | # save checkpoint per `save_steps`. 177 | checkpoint=dict( 178 | type=CheckpointHook, 179 | by_epoch=False, 180 | interval=save_steps, 181 | max_keep_ckpts=save_total_limit), 182 | # set sampler seed in distributed evrionment. 183 | sampler_seed=dict(type=DistSamplerSeedHook), 184 | ) 185 | 186 | # configure environment 187 | env_cfg = dict( 188 | # whether to enable cudnn benchmark 189 | cudnn_benchmark=False, 190 | # set multi process parameters 191 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 192 | # set distributed parameters 193 | dist_cfg=dict(backend='nccl'), 194 | ) 195 | 196 | # set visualizer 197 | visualizer = None 198 | 199 | # set log level 200 | log_level = 'INFO' 201 | 202 | # load from which checkpoint 203 | load_from = None 204 | 205 | # whether to resume training from the loaded checkpoint 206 | resume = False 207 | 208 | # Defaults to use random seed and disable `deterministic` 209 | randomness = dict(seed=None, deterministic=False) 210 | 211 | # set log processor 212 | log_processor = dict(by_epoch=False) 213 | -------------------------------------------------------------------------------- /xtuner_config/build_config.py: -------------------------------------------------------------------------------- 1 | from mmengine import Config, ConfigDict 2 | from mmengine.config.lazy import LazyObject 3 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE 4 | import torch 5 | import os 6 | from .get_prompt_template import get_prompt_template 7 | CUR_DIR = os.path.dirname(__file__) 8 | TEMPLATE_DIR = os.path.join(os.path.dirname(CUR_DIR), "template_configs") 9 | 10 | 11 | MODEL_TO_TEMPLATE = { 12 | "baichuan-inc/Baichuan-7B": "default", 13 | "baichuan-inc/Baichuan-13B-Base": "default", 14 | "baichuan-inc/Baichuan-13B-Chat": "baichuan_chat", 15 | "baichuan-inc/Baichuan2-7B-Base": "default", 16 | "baichuan-inc/Baichuan2-7B-Chat": "baichuan2_chat", 17 | "baichuan-inc/Baichuan2-13B-Base": "default", 18 | "baichuan-inc/Baichuan2-13B-Chat": "baichuan2_chat", 19 | "THUDM/chatglm2-6b": "chatglm2", 20 | "THUDM/chatglm3-6b": "chatglm3", 21 | "THUDM/chatglm3-6b-base": "chatglm3", 22 | "deepseek-ai/deepseek-coder-6.7b-base": "deepseek_coder", 23 | "deepseek-ai/deepseek-coder-6.7b-instruct": "deepseek_coder", 24 | "internlm/internlm-7b": "default", 25 | "internlm/internlm-20b": "default", 26 | "internlm/internlm-chat-7b": "internlm_chat", 27 | "internlm/internlm-chat-20b": "internlm_chat", 28 | "huggyllama/llama-7b": "default", 29 | "meta-llama/Llama-2-7b-hf": "llama2_chat", 30 | "meta-llama/Llama-2-7b": "llama2_chat", 31 | "meta-llama/Llama-2-7b-chat-hf": "llama2_chat", 32 | "meta-llama/Llama-2-7b-chat": "llama2_chat", 33 | "meta-llama/Llama-2-70b-hf": "llama2_chat", 34 | "lmsys/vicuna-7b-v1.5": "vicuna", 35 | "lmsys/vicuna-13b-v1.5": "vicuna", 36 | "mistralai/Mistral-7B-v0.1": "mistral", 37 | "mistralai/Mixtral-8x7B-v0.1": "mixtral", 38 | "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral", 39 | "Qwen/Qwen-1_8B": "default", 40 | "Qwen/Qwen-1_8B-Chat": "qwen_chat", 41 | "Qwen/Qwen-7B": "default", 42 | "Qwen/Qwen-7B-Chat": "qwen_chat", 43 | "Qwen/Qwen-72B": "default", 44 | "Qwen/Qwen-72B-Chat": "qwen_chat", 45 | "bigcode/starcoder": "default", 46 | "01-ai/Yi-6B": "default", 47 | "01-ai/Yi-34B": "default", 48 | 
"HuggingFaceH4/zephyr-7b-beta": "zephyr", 49 | "deepseek-ai/deepseek-moe-16b-base": "deepseek_moe", 50 | "deepseek-ai/deepseek-moe-16b-chat": "deepseek_moe", 51 | "internlm/internlm2-7b": "default", 52 | "internlm/internlm2-20b": "default", 53 | "internlm/internlm2-chat-7b": "internlm2_chat", 54 | "internlm/internlm2-chat-20b": "internlm2_chat" 55 | } 56 | 57 | DATA2MAPFN = { 58 | 'tatsu-lab/alpaca': 'alpaca_map_fn', 59 | 'silk-road/alpaca-data-gpt4-chinese': 'alpaca_zh_map_fn', 60 | 'garage-bAInd/Open-Platypus': 'alpaca_map_fn', 61 | 'HuggingFaceH4/CodeAlpaca_20K': 'code_alpaca_map_fn', 62 | 'burkelibbey/colors': 'colors_map_fn', 63 | 'shibing624/medical': 'medical_map_fn', 64 | 'damo/MSAgent-Bench': 'msagent_react_map_fn', 65 | 'timdettmers/openassistant-guanaco': 'oasst1_map_fn', 66 | 'Open-Orca/OpenOrca': 'openorca_map_fn', 67 | 'Skywork/SkyPile-150B': 'pretrain_map_fn', 68 | 'mistralai/Mistral-7B-v0.1': 'pretrain_map_fn', 69 | 'b-mc2/sql-create-context': 'sql_map_fn', 70 | 'ArmelR/stack-exchange-instruction': 'stack_exchange_map_fn', 71 | 'nampdn-ai/tiny-codes': 'tiny_codes_map_fn', 72 | 'WizardLM/WizardLM_evol_instruct_V2_196k': 'wizardlm_map_fn', 73 | } 74 | 75 | def data_path_map_fn(file): 76 | if file in DATA2MAPFN: 77 | return DATA2MAPFN[file] 78 | for k, v in DATA2MAPFN.items(): 79 | k_list = k.split('/') 80 | k_fix = '_'.join(k_list) 81 | if k_fix in file: 82 | return v 83 | return None 84 | 85 | def model_path_map_fn(file): 86 | print(f'model_path_map_fn({file})') 87 | if file in MODEL_TO_TEMPLATE: 88 | return MODEL_TO_TEMPLATE[file] 89 | for k, v in MODEL_TO_TEMPLATE.items(): 90 | k_list = k.split('/') 91 | k_fix = '_'.join(k_list) 92 | if k_fix in file: 93 | return v 94 | return None 95 | 96 | """ 97 | save_checkpoint_ratio -> save_checkpoint_interval 98 | accumulative_counts -> accumulative_counts 99 | 新增 save_total_limit 100 | 'bigcode/starcoder' 不是DATA_LIST 101 | """ 102 | def traverse_keys(cfg_dict, target_keys, new_value): 103 | if isinstance(cfg_dict, dict): 104 | for key, value in dict.items(cfg_dict): 105 | if key in target_keys: 106 | cfg_dict[key] = new_value 107 | else: 108 | traverse_keys(value, target_keys, new_value) 109 | elif isinstance(cfg_dict, (list, tuple)): 110 | for value in cfg_dict: 111 | traverse_keys(value, target_keys, new_value) 112 | 113 | def traverse_value(cfg_dict, target_value, new_value): 114 | if isinstance(cfg_dict, dict): 115 | for key, value in dict.items(cfg_dict): 116 | if value == target_value: 117 | cfg_dict[key] = new_value 118 | else: 119 | traverse_value(value, target_value, new_value) 120 | elif isinstance(cfg_dict, (list, tuple)): 121 | for value in cfg_dict: 122 | traverse_value(value, target_value, new_value) 123 | 124 | 125 | def set_model_related(cfg, model_path): 126 | traverse_keys(cfg._cfg_dict, ('pretrained_model_name_or_path', ), model_path) 127 | 128 | 129 | def set_data_related(cfg, dataset, is_custom_dataset, prompt_template, max_length, pack_to_max_length): 130 | if is_custom_dataset: 131 | dataset = ConfigDict(path='json', data_files=dataset) 132 | cfg.alpaca_en.dataset.update(dataset) 133 | cfg.train_dataloader.dataset.dataset.update(dataset) 134 | 135 | traverse_keys(cfg._cfg_dict, ('dataset_map_fn', ), LazyObject('xtuner.dataset.map_fns', 'openai_map_fn')) 136 | else: 137 | traverse_value(cfg._cfg_dict, 'tatsu-lab/alpaca', dataset) 138 | 139 | traverse_keys(cfg._cfg_dict, ('dataset_map_fn', ), LazyObject('xtuner.dataset.map_fns', data_path_map_fn(dataset))) 140 | 141 | assert prompt_template in PROMPT_TEMPLATE, 
\ 142 | f'Expect prompt_template to be one of {PROMPT_TEMPLATE.keys()}, but got {prompt_template}.' 143 | prompt_template = PROMPT_TEMPLATE[prompt_template] 144 | traverse_keys(cfg._cfg_dict, ('template', 'prompt_template'), prompt_template) 145 | 146 | traverse_keys(cfg._cfg_dict, ('max_length', ), max_length) 147 | 148 | traverse_keys(cfg._cfg_dict, ('pack_to_max_length', ), pack_to_max_length) 149 | 150 | 151 | def set_scheduler_optimizer_related( 152 | cfg, batch_size_per_device, accumulative_counts, dataloader_num_workers, 153 | max_epochs, optim_type, lr, beta1, beta2, weight_decay, max_norm, warmup_ratio): 154 | traverse_keys(cfg._cfg_dict, ('batch_size', ), batch_size_per_device) 155 | traverse_keys(cfg._cfg_dict, ('accumulative_counts', ), accumulative_counts) 156 | traverse_keys(cfg._cfg_dict, ('dataloader_num_workers', 'num_workers'), dataloader_num_workers) 157 | 158 | traverse_keys(cfg._cfg_dict, ('max_epochs', 'T_max'), max_epochs) 159 | cfg.param_scheduler[0].end = warmup_ratio * max_epochs 160 | cfg.param_scheduler[1].begin = warmup_ratio * max_epochs 161 | cfg.warmup_ratio = warmup_ratio 162 | 163 | assert hasattr(torch.optim, optim_type) 164 | cfg.optim_type = LazyObject('torch.optim', optim_type) 165 | cfg.optim_wrapper.optimizer.type = LazyObject('torch.optim', optim_type) 166 | 167 | cfg.lr = lr 168 | cfg.optim_wrapper.optimizer.lr = lr 169 | 170 | if optim_type == 'AdamW': 171 | traverse_keys(cfg._cfg_dict, ('betas', ), (beta1, beta2)) 172 | 173 | traverse_keys(cfg._cfg_dict, ('weight_decay', ), weight_decay) 174 | traverse_keys(cfg._cfg_dict, ('max_norm', ), max_norm) 175 | 176 | 177 | def set_checkpoint_related(cfg, save_checkpoint_interval, save_total_limit): 178 | cfg.save_steps = save_checkpoint_interval 179 | cfg.default_hooks.checkpoint.interval = save_checkpoint_interval 180 | 181 | cfg.save_total_limit = save_total_limit 182 | cfg.default_hooks.checkpoint.max_keep_ckpts = save_total_limit 183 | 184 | 185 | def set_evaluate_related(cfg, evaluation_freq, evaluation_system_prompt, evaluation_inputs): 186 | traverse_keys(cfg._cfg_dict, ('evaluation_freq', 'every_n_iters'), evaluation_freq) 187 | 188 | system_prompt = SYSTEM_TEMPLATE[evaluation_system_prompt] if evaluation_system_prompt else '' 189 | traverse_keys(cfg._cfg_dict, ('SYSTEM', 'system'), system_prompt) 190 | 191 | # evaluation_inputs = [evaluation_input1, evaluation_input2] 192 | traverse_keys(cfg._cfg_dict, ('evaluation_inputs', ), evaluation_inputs) 193 | 194 | 195 | def build_config( 196 | ft_method, model_path, dataset, is_custom_dataset, deepspeed, lr, warmup_ratio, batch_size_per_device, 197 | accumulative_counts, num_GPU, max_length, pack_to_max_length, max_epochs, save_checkpoint_interval, save_total_limit, 198 | evaluation_freq, evaluation_system_prompt, evaluation_inputs, 199 | optim_type, weight_decay, max_norm, dataloader_num_workers, beta1, beta2, 200 | prompt_template): 201 | if ft_method == 'full': 202 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/full_finetune.py') 203 | elif ft_method == 'lora': 204 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/lora.py') 205 | elif ft_method == 'qlora': 206 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/qlora.py') 207 | else: 208 | raise NotImplementedError(f'Expect ft_method to be one of (full, lora, qlora), but got {ft_method}.') 209 | 210 | set_model_related(cfg, model_path) 211 | set_data_related(cfg, dataset, is_custom_dataset, prompt_template, max_length, pack_to_max_length) 212 | set_scheduler_optimizer_related(cfg, batch_size_per_device, accumulative_counts, 
dataloader_num_workers, 213 | max_epochs, optim_type, lr, beta1, beta2, weight_decay, max_norm, warmup_ratio) 214 | set_checkpoint_related(cfg, save_checkpoint_interval, save_total_limit) 215 | set_evaluate_related(cfg, evaluation_freq, evaluation_system_prompt, evaluation_inputs) 216 | 217 | return cfg 218 | 219 | 220 | kwargs = dict( 221 | ft_method='full', 222 | model_path='/mnt/petrelfs/share_data/caoweihan/official_Ampere_7B_1_0_0', 223 | dataset='timdettmers/openassistant-guanaco', 224 | is_custom_dataset=False, 225 | deepspeed=None, # 与生成config无关 226 | lr=2e-5, 227 | warmup_ratio=0.03, 228 | batch_size_per_device=1, 229 | accumulative_counts=2, 230 | num_GPU=None, # 与生成config无关 231 | max_length=2048, 232 | pack_to_max_length=True, 233 | max_epochs=2, 234 | save_checkpoint_interval=1000, 235 | save_total_limit=2, 236 | evaluation_freq=100, 237 | evaluation_system_prompt='', 238 | evaluation_inputs=['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'], 239 | optim_type='AdamW', 240 | weight_decay=0, 241 | max_norm=1, 242 | dataloader_num_workers=0, 243 | beta1=0.9, 244 | beta2=0.999, 245 | prompt_template='internlm2_chat' 246 | ) 247 | 248 | int_args = [ 249 | 'batch_size_per_device', 250 | 'accumulative_counts', 251 | 'num_GPU', 252 | 'max_length', 253 | 'pack_to_max_length', 254 | 'max_epochs', 255 | 'save_checkpoint_interval', 256 | 'save_total_limit', 257 | 'evaluation_freq', 258 | 'dataloader_num_workers', 259 | ] 260 | default_args_key = [ 261 | 'ft_method', 262 | 'model_path', 263 | 'dataset', 264 | 'deepspeed', 265 | 'lr', 266 | 'warmup_ratio', 267 | 'batch_size_per_device', 268 | 'accumulative_counts', 269 | 'num_GPU', 270 | 'max_length', 271 | 'pack_to_max_length', 272 | 'max_epochs', 273 | 'save_checkpoint_interval', 274 | 'save_total_limit', 275 | 'evaluation_freq', 276 | 'evaluation_system_prompt', 277 | 'optim_type', 278 | 'weight_decay', 279 | 'max_norm', 280 | 'dataloader_num_workers', 281 | 'beta1', 282 | 'beta2', 283 | 'prompt_template', 284 | ] 285 | 286 | def build_config_path(root_dir): 287 | work_dir = os.path.join(root_dir, 'work_dir') 288 | if not os.path.exists(work_dir): 289 | os.system(f'mkdir -p {work_dir}') 290 | return os.path.join(work_dir, 'xtuner_config.py') 291 | 292 | 293 | def build_and_save_config( 294 | dataset_personal_path, 295 | dataset_personal, 296 | model_personal_path, 297 | personal_model, 298 | detect_prompt_template, 299 | root_dir, 300 | *args, **kwargs 301 | ): 302 | kwargs.update( 303 | dict(zip(default_args_key, list(args))) 304 | ) 305 | # prepare 'evaluation_inputs' 306 | evaluation_inputs = list(args)[len(default_args_key):] 307 | kwargs['evaluation_inputs'] = [i for i in evaluation_inputs if len(i)] 308 | print(f'dataset_personal_path={dataset_personal_path}||') 309 | # float -> int 310 | for k in int_args: 311 | kwargs[k] = int(kwargs[k]) 312 | # custom dataset 313 | kwargs['is_custom_dataset'] = False 314 | # dataset_personal_path > dataset_personal > dataset 315 | if dataset_personal is not None and len(dataset_personal) >= 3: 316 | kwargs['is_custom_dataset'] = True 317 | kwargs['dataset'] = dataset_personal 318 | if dataset_personal_path is not None and len(dataset_personal_path) >= 3: 319 | kwargs['is_custom_dataset'] = True 320 | kwargs['dataset'] = dataset_personal_path 321 | 322 | # dropdown-list prompt_template 323 | prompt_template = model_path_map_fn(kwargs['model_path']) 324 | if personal_model is not None and len(personal_model) >= 3: 325 | kwargs['model_path'] = personal_model 326 | prompt_template = 
detect_prompt_template 327 | 328 | if model_personal_path is not None and len(model_personal_path) >= 3: 329 | kwargs['model_path'] = model_personal_path 330 | prompt_template = detect_prompt_template 331 | 332 | # final prompt_template 333 | kwargs['prompt_template'] = prompt_template 334 | if kwargs['prompt_template'] is None: 335 | kwargs['prompt_template'] = detect_prompt_template 336 | print(f'kwargs={kwargs}') 337 | cfg = build_config(**kwargs) 338 | cfg_py = build_config_path(root_dir) 339 | cfg.dump(cfg_py) 340 | print('cfg_py=', cfg_py) 341 | return cfg_py 342 | 343 | 344 | if __name__ == '__main__': 345 | build_and_save_config('.', **kwargs) 346 | -------------------------------------------------------------------------------- /xtuner_config/check_custom_dataset.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | import pprint 3 | 4 | 5 | DATA_EXAMPLE = """example["messages"] = [ 6 | { "role": "system", "content": "You are an assistant that 7 | occasionally misspells words." }, 8 | { "role": "user", "content": "Tell me a story." }, 9 | { "role": "assistant", "content": "One day a student 10 | went to schoool." }]""" 11 | 12 | def check_custom_dataset(input_path, upload_path): 13 | path = input_path if len(input_path) >= 3 else upload_path 14 | try: 15 | data = load_dataset('json', data_files=path) 16 | except: 17 | return f"There's a problem with the JSON file in {path}; it can't be read." 18 | data = data['train'] 19 | 20 | if 'messages' not in data.column_names: 21 | return ('Expect "messages" as a column in the dataset. Here is an ' 22 | f'example:\n{DATA_EXAMPLE}') 23 | 24 | if not isinstance(data['messages'], (list, tuple)): 25 | return ('Expect the type of example["messages"] to be a list or ' 26 | f'a tuple, but got {type(data["messages"])}.' 27 | f'Here is an example:\n{DATA_EXAMPLE}') 28 | 29 | check_first_n_messages = 100 30 | for message_idx, message in enumerate(data['messages'][:check_first_n_messages]): 31 | for conv_idx, single_conversation in enumerate(message): 32 | if not isinstance(single_conversation, dict): 33 | return ('Expect each single conversation to be a dict, ' 34 | f'but got {type(single_conversation)}. ' 35 | f'Here is an example:\n{DATA_EXAMPLE}') 36 | if not {'role', 'content'}.issubset(single_conversation.keys()): 37 | return ('Expect "role" and "content" in each single ' 38 | f'conversation. The {conv_idx + 1} conversation in the' 39 | f' {message_idx} message is {single_conversation}.' 40 | f'Here is an example:\n{DATA_EXAMPLE}') 41 | 42 | return 'Data is OK.' 
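# A minimal dataset that passes this check (illustrative sketch, not shipped with
# the repo): a JSON file whose records each carry a "messages" list in the format
# shown in DATA_EXAMPLE, e.g.
#
#   [{"messages": [
#       {"role": "system", "content": "You are a helpful assistant."},
#       {"role": "user", "content": "Tell me a story."},
#       {"role": "assistant", "content": "One day a student went to school."}]}]
#
# Saved as, say, /path/to/data.json (a hypothetical path),
# check_custom_dataset('/path/to/data.json', '') then returns 'Data is OK.'.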
43 | 44 | 45 | if __name__ == "__main__": 46 | out = check_custom_dataset('/mnt/petrelfs/caoweihan/projects/xtuner/data.json') 47 | if out is None: 48 | print('Data is OK.') 49 | else: 50 | pprint.pprint(out) 51 | -------------------------------------------------------------------------------- /xtuner_config/get_default_hyperparameters.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from collections import OrderedDict 3 | 4 | 5 | DEFAULT_HYPERPARAMETERS = OrderedDict( 6 | warmup_ratio=0.03, 7 | batch_size_per_device=1, 8 | accumulative_counts=2, 9 | num_GPU=8, 10 | max_length=2048, 11 | pack_to_max_length=True, 12 | evaluation_freq=500, 13 | optim_type='AdamW', 14 | weight_decay=0, 15 | max_norm=1, 16 | dataloader_num_workers=0, 17 | beta1=0.9, 18 | beta2=0.999, 19 | lr=2e-5, 20 | save_checkpoint_interval=500, 21 | save_total_limit=2 22 | ) 23 | 24 | 25 | def get_default_hyperparameters(ft_method): 26 | out_dict = copy.deepcopy(DEFAULT_HYPERPARAMETERS) 27 | if ft_method.lower() == 'full': 28 | out_dict.update(dict( 29 | lr=2e-5, 30 | save_checkpoint_interval=500, 31 | save_total_limit=2)) 32 | else: 33 | out_dict.update(dict( 34 | lr=2e-4, 35 | save_checkpoint_interval=200, 36 | save_total_limit=5)) 37 | out_list = [] 38 | for i in DEFAULT_HYPERPARAMETERS.keys(): 39 | out_list.append(out_dict[i]) 40 | return out_list 41 | 42 | 43 | if __name__ == '__main__': 44 | print(DEFAULT_HYPERPARAMETERS.keys()) 45 | print( 46 | get_default_hyperparameters('full') 47 | ) 48 | print( 49 | get_default_hyperparameters('lora') 50 | ) 51 | 52 | -------------------------------------------------------------------------------- /xtuner_config/get_prompt_template.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoConfig, AutoModelForCausalLM 2 | from accelerate import init_empty_weights 3 | 4 | 5 | MODEL_TEMPLATE_MAPPING = dict( 6 | InternLM2ForCausalLM='internlm2_chat', 7 | InternLMForCausalLM='internlm_chat', 8 | BaiChuanForCausalLM='baichuan_chat', 9 | BaichuanForCausalLM='baichuan2_chat', 10 | DeepseekForCausalLM='deepseek_moe', 11 | MixtralForCausalLM='mixtral', 12 | QWenLMHeadModel='qwen_chat', 13 | GPTBigCodeForCausalLM='default' 14 | ) 15 | 16 | 17 | def get_prompt_template(pretrained_model_name_or_path): 18 | try: 19 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) 20 | with init_empty_weights(): 21 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) 22 | except: 23 | return f'Model {pretrained_model_name_or_path} can not be loaded.', None 24 | model_type = type(model).__name__ 25 | if model_type == 'LlamaForCausalLM': 26 | vocab_size = config.vocab_size 27 | if vocab_size == 32256: 28 | return 'Success', 'deepseek_coder' 29 | elif vocab_size == 64000: # yi 30 | return 'Success', 'default' 31 | elif vocab_size == 32000: # llama2 32 | return 'Success', 'llama2_chat' 33 | elif model_type == 'ChatGLMForConditionalGeneration': 34 | seq_length = config.seq_length 35 | if seq_length == 131072: 36 | return 'Success', 'chatglm3' 37 | elif seq_length == 32768: 38 | return 'Success', 'chatglm2' 39 | else: 40 | return 'Fail to match automatically, please enter corresponding prompt template manually', None 41 | elif model_type == 'MistralForCausalLM': 42 | # 无法判断 43 | return 'The prompt template should be one of mistral or zephyr, please enter the correct prompt template manually', None 44 | elif model_type in 
MODEL_TEMPLATE_MAPPING: 45 | return 'Success', MODEL_TEMPLATE_MAPPING[model_type] 46 | else: 47 | return 'Fail to match automatically, please enter corresponding prompt template manually', None 48 | 49 | 50 | def app_get_prompt_template(input_path, upload_path): 51 | print(f'app_get_prompt_template({input_path}, {upload_path})') 52 | pretrained_model_name_or_path = input_path if len(input_path) >= 3 else upload_path 53 | info, prompt_template = get_prompt_template(pretrained_model_name_or_path) 54 | return f'{info} >> {prompt_template}', prompt_template 55 | 56 | 57 | if __name__ == "__main__": 58 | print(get_prompt_template('/mnt/petrelfs/share_data/caoweihan/official_Ampere_7B_1_0_0')) -------------------------------------------------------------------------------- /xtuner_convert/convert_and_merge.py: -------------------------------------------------------------------------------- 1 | 2 | # python3 3 | # Create Date: 2024-01-30 4 | # Author: 爱科研的瞌睡虫 5 | 6 | import os 7 | from .merge import merge 8 | from .pth_to_hf import convert_to_hf 9 | 10 | def _convert_and_merged(config_file, pth_model, save_hf_dir, model_path, save_merged_dir): 11 | convert_to_hf(config_file, pth_model, save_hf_dir) 12 | merge(model_path, save_hf_dir, save_merged_dir) 13 | 14 | 15 | def build_convert_and_merged_path(root_dir, epoch_pth): 16 | epoch = os.path.basename(epoch_pth).split('.')[0] 17 | work_dir = os.path.join(root_dir, 'work_dir') 18 | if not os.path.exists(work_dir): 19 | os.system(f'mkdir -p {work_dir}') 20 | hf = os.path.join(work_dir, f'xtuner_{epoch}_hf') 21 | mg = os.path.join(work_dir, f'xtuner_{epoch}_merge') 22 | # clear 23 | if os.path.exists(hf): 24 | os.system(f'rm -rf {hf}') 25 | if os.path.exists(mg): 26 | os.system(f'rm -rf {mg}') 27 | return work_dir, hf, mg 28 | 29 | 30 | def convert_and_merged(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method): 31 | if len(model_personal_path) >= 3: 32 | model_path = model_personal_path 33 | work_dir, save_hf_dir, save_merged_dir = build_convert_and_merged_path(root_dir, epoch_pth) 34 | pth_model = os.path.join(work_dir, epoch_pth) 35 | print( 36 | f'config_file = {config_file}' 37 | ,f'\npth_model = {pth_model}' 38 | ,f'\nsave_hf_dir = {save_hf_dir}' 39 | ,f'\nmodel_path ={model_path}' 40 | ,f'\nsave_merged_dir ={save_merged_dir}' 41 | ) 42 | merged_flag = ft_method != 'full' 43 | try: 44 | convert_to_hf(config_file, pth_model, save_hf_dir) 45 | out_dir = save_hf_dir 46 | if merged_flag: 47 | merge(model_path, save_hf_dir, save_merged_dir) 48 | out_dir = save_merged_dir 49 | 50 | info = 'Successfully converted model ! 
' 51 | except Exception as e: 52 | info = e 53 | pass 54 | return info, out_dir 55 | 56 | 57 | if __name__ == '__main__': 58 | 59 | config_file = '/root/ft-oasst1/internlm_chat_7b_qlora_oasst1_e3_copy.py' 60 | pth_model = '/root/ft-oasst1/work_dirs/internlm_chat_7b_qlora_oasst1_e3_copy/epoch_1.pth' 61 | save_hf_dir = '/root/ft-oasst1/hf5' 62 | model_path = '/root/ft-oasst1/internlm-chat-7b' 63 | save_merged_dir = '/root/ft-oasst1/merged5' 64 | 65 | _convert_and_merged(config_file, pth_model, save_hf_dir, model_path, save_merged_dir) 66 | -------------------------------------------------------------------------------- /xtuner_convert/convert_with_progress.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | from .convert_and_merge import build_convert_and_merged_path, convert_to_hf, merge 4 | import re 5 | from tqdm.auto import tqdm 6 | import threading 7 | import time 8 | import gradio as gr 9 | 10 | 11 | class ConvertMerged: 12 | def __init__(self): 13 | self.save_hf_dir = None 14 | self.save_merged_dir = None 15 | self.merged_flag = True 16 | self.model_path = None 17 | self.out_dir = None 18 | self.info = None 19 | 20 | def convert_and_merged(self, root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method): 21 | if len(model_personal_path) >= 3: 22 | model_path = model_personal_path 23 | 24 | self.model_path = model_path 25 | work_dir, save_hf_dir, save_merged_dir = build_convert_and_merged_path(root_dir, epoch_pth) 26 | self.save_hf_dir = save_hf_dir 27 | self.save_merged_dir = save_merged_dir 28 | pth_model = os.path.join(work_dir, epoch_pth) 29 | print( 30 | f'config_file = {config_file}' 31 | ,f'\npth_model = {pth_model}' 32 | ,f'\nsave_hf_dir = {save_hf_dir}' 33 | ,f'\nmodel_path ={model_path}' 34 | ,f'\nsave_merged_dir ={save_merged_dir}' 35 | ) 36 | self.merged_flag = ft_method.lower() != 'full' 37 | try: 38 | convert_to_hf(config_file, pth_model, save_hf_dir) 39 | self.out_dir = save_hf_dir 40 | if self.merged_flag: 41 | merge(model_path, save_hf_dir, save_merged_dir) 42 | self.out_dir = save_merged_dir 43 | 44 | self.info = 'Successfully converted model ! 
' 45 | except Exception as e: 46 | self.info = e 47 | pass 48 | return self.info, self.out_dir 49 | 50 | def auto_convert_merge( 51 | self, 52 | root_dir, 53 | config_file, 54 | epoch_pth, 55 | model_path, 56 | model_personal_path, 57 | ft_method, 58 | progress=gr.Progress(track_tqdm=True) 59 | ): 60 | self._t_convert(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method) 61 | time.sleep(2) 62 | print( 63 | f'self.model_path={self.model_path}', 64 | f'self.save_hf_dir={self.save_hf_dir}', 65 | f'self.save_merged_dir={self.save_merged_dir}' 66 | ) 67 | self.progress() 68 | return self.info, self.out_dir 69 | 70 | def _t_convert(self, root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method): 71 | self._t_handle_convert = threading.Thread( 72 | target=self.convert_and_merged, args=(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method) , 73 | name='X-model-convert-merge', daemon=True) 74 | self._t_handle_convert.start() 75 | 76 | def find_max_sub(self, _dir): 77 | total_ = [i for i in os.listdir(_dir) if len(re.findall(r'[0-9]+-of-[0-9]+', i))] 78 | if len(total_): 79 | info = [int(re.findall(r'(\d+)-of', i)[0]) for i in total_ if len(re.findall(r'(\d+)-of', i))] 80 | if len(info): 81 | return max(info) 82 | return 0 83 | 84 | def progress(self, progress=None): 85 | big_step = 100 86 | total = 0 87 | hf_total = 1 88 | # /root/share/model_repos/internlm-chat-7b/pytorch_model-00001-of-00008.bin 89 | base_model_parts = [i for i in os.listdir(self.model_path) if len(re.findall(r'[0-9]+-of-[0-9]+', i))] 90 | base_max = 0 91 | if len(base_model_parts): 92 | fd = re.findall(r'-of-(\d+)', base_model_parts[0]) 93 | print(f'progress fd => {fd}') 94 | base_max = int(fd[0]) if len(fd) else 0 95 | if self.merged_flag: 96 | total = hf_total + base_max 97 | else: 98 | total = hf_total = base_max 99 | 100 | hf_total *= big_step 101 | total *= big_step 102 | tq_bar = tqdm(total=total) 103 | big_step_hf_now = 0 104 | big_step_mg_now = 0 105 | hf_now = 0 106 | mg_now = 0 107 | while True: 108 | if self._t_handle_convert is None: 109 | break 110 | if not self._t_handle_convert.is_alive(): 111 | break 112 | 113 | up_hf = 0 114 | if os.path.exists(self.save_hf_dir) and not self.merged_flag: 115 | max_hf_b = self.find_max_sub(self.save_hf_dir) * big_step 116 | # 在一个的时候 117 | if big_step_hf_now == max_hf_b and (big_step_hf_now + big_step) > hf_now and hf_now < hf_total: 118 | up_hf = 1 119 | hf_now += 1 120 | elif max_hf_b > hf_now: 121 | up_hf = max_hf_b - hf_now 122 | hf_now = max_hf_b 123 | else: 124 | up_hf = 0 125 | 126 | big_step_hf_now = max_hf_b 127 | elif self.merged_flag and not os.path.exists(self.save_hf_dir): 128 | if big_step >= hf_now: 129 | up_hf = 1 130 | hf_now += 1 131 | else: 132 | max_hf_b = big_step 133 | up_hf = max_hf_b - hf_now 134 | hf_now = max_hf_b 135 | 136 | up_mg = 0 137 | if self.merged_flag: 138 | if not os.path.exists(self.save_merged_dir): 139 | if big_step > mg_now: 140 | up_mg = 1 141 | mg_now += 1 142 | else: 143 | max_mg_b = self.find_max_sub(self.save_merged_dir) * big_step 144 | # 在一个的时候 145 | if big_step_mg_now == max_mg_b and (big_step_mg_now + big_step) > mg_now and (mg_now + hf_now) < total: 146 | up_mg = 1 147 | mg_now += 1 148 | elif max_mg_b > mg_now: 149 | up_mg = max_mg_b - mg_now 150 | mg_now = max_mg_b 151 | else: 152 | up_mg = 0 153 | 154 | big_step_mg_now = max_mg_b 155 | 156 | tq_bar.update(up_mg + up_hf) 157 | time.sleep(1) 158 | 159 | 
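if __name__ == '__main__':
    # Usage sketch (illustrative only; the paths below are hypothetical). The
    # Gradio app normally drives this class through auto_convert_merge, which
    # adds a background thread and a progress bar on top of the same logic.
    cm = ConvertMerged()
    info, out_dir = cm.convert_and_merged(
        root_dir='/root/ft-oasst1',                      # hypothetical work root
        config_file='/root/ft-oasst1/work_dir/xtuner_config.py',
        epoch_pth='epoch_1.pth',                         # resolved under <root_dir>/work_dir
        model_path='/root/ft-oasst1/internlm-chat-7b',
        model_personal_path='',                          # empty -> keep model_path
        ft_method='qlora')                               # 'full' skips the merge step
    print(info, out_dir)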
-------------------------------------------------------------------------------- /xtuner_convert/merge.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-30 3 | # Author: 爱科研的瞌睡虫 4 | 5 | import argparse 6 | 7 | import torch 8 | from peft import PeftModel 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | 12 | def parse_args(): 13 | parser = argparse.ArgumentParser( 14 | description='Merge a HuggingFace adapter to LLM') 15 | # parser.add_argument('model_name_or_path', help='model name or path') 16 | # parser.add_argument('adapter_name_or_path', help='adapter name or path') 17 | # parser.add_argument( 18 | # 'save_dir', help='the directory to save the merged model') 19 | parser.add_argument( 20 | '--max-shard-size', 21 | type=str, 22 | default='2GB', 23 | help='Only applicable for LLM. The maximum size for ' 24 | 'each sharded checkpoint.') 25 | parser.add_argument( 26 | '--offload-folder', 27 | default=None, 28 | help='The folder in which to offload the model weights (or where ' 29 | 'the model weights are already offloaded).') 30 | args = parser.parse_args() 31 | return args 32 | 33 | 34 | def merge(model_path, adapter_hf_path, save_dir): 35 | args = parse_args() 36 | model = AutoModelForCausalLM.from_pretrained( 37 | model_path, 38 | torch_dtype=torch.float16, 39 | low_cpu_mem_usage=True, 40 | device_map='auto', 41 | offload_folder=args.offload_folder, 42 | trust_remote_code=True) 43 | tokenizer = AutoTokenizer.from_pretrained( 44 | model_path, 45 | trust_remote_code=True, 46 | encode_special_tokens=True) 47 | model_unmerged = PeftModel.from_pretrained( 48 | model, 49 | adapter_hf_path, 50 | device_map='auto', 51 | torch_dtype=torch.float16, 52 | offload_folder=args.offload_folder, 53 | is_trainable=False) 54 | model_merged = model_unmerged.merge_and_unload() 55 | print(f'Merged Saving to {save_dir}...') 56 | model_merged.save_pretrained( 57 | save_dir, max_shard_size=args.max_shard_size) 58 | tokenizer.save_pretrained(save_dir) 59 | print('Merged All done!') 60 | 61 | 62 | if __name__ == '__main__': 63 | 64 | model_path = '/root/ft-oasst1/internlm-chat-7b' 65 | adapter_hf_path = '/root/ft-oasst1/hf4' 66 | save_dir = '/root/ft-oasst1/merged5' 67 | 68 | merge(model_path, adapter_hf_path, save_dir) -------------------------------------------------------------------------------- /xtuner_convert/pth_to_hf.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-30 3 | # Author: 爱科研的瞌睡虫 4 | 5 | import argparse 6 | import os 7 | import shutil 8 | 9 | import torch 10 | from mmengine.config import Config, DictAction 11 | 12 | from xtuner.configs import cfgs_name_path 13 | from xtuner.registry import BUILDER 14 | 15 | 16 | def parse_args(): 17 | parser = argparse.ArgumentParser( 18 | description='Convert the pth model to HuggingFace model') 19 | # parser.add_argument('config', help='config file name or path.') 20 | # parser.add_argument('pth_model', help='pth model file') 21 | # parser.add_argument( 22 | # 'save_dir', help='the directory to save HuggingFace model') 23 | parser.add_argument( 24 | '--fp32', 25 | action='store_true', 26 | help='Save as fp32. If not set, fp16 will be used by default.') 27 | parser.add_argument( 28 | '--max-shard-size', 29 | type=str, 30 | default='2GB', 31 | help='Only applicable for LLM. 
The maximum size for ' 32 | 'each sharded checkpoint.') 33 | parser.add_argument( 34 | '--cfg-options', 35 | nargs='+', 36 | action=DictAction, 37 | help='override some settings in the used config, the key-value pair ' 38 | 'in xxx=yyy format will be merged into config file. If the value to ' 39 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' 40 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" ' 41 | 'Note that the quotation marks are necessary and that no white space ' 42 | 'is allowed.') 43 | args = parser.parse_args() 44 | return args 45 | 46 | 47 | def guess_load_checkpoint(pth_model): 48 | if os.path.isfile(pth_model): 49 | state_dict = torch.load(pth_model, map_location='cpu') 50 | if 'state_dict' in state_dict: 51 | state_dict = state_dict['state_dict'] 52 | elif os.path.isdir(pth_model): 53 | try: 54 | from deepspeed.utils.zero_to_fp32 import \ 55 | get_fp32_state_dict_from_zero_checkpoint 56 | except ImportError: 57 | raise ImportError( 58 | 'The provided PTH model appears to be a DeepSpeed checkpoint. ' 59 | 'However, DeepSpeed library is not detected in current ' 60 | 'environment. This suggests that DeepSpeed may not be ' 61 | 'installed or is incorrectly configured. Please verify your ' 62 | 'setup.') 63 | state_dict = get_fp32_state_dict_from_zero_checkpoint( 64 | os.path.dirname(pth_model), os.path.basename(pth_model)) 65 | else: 66 | raise FileNotFoundError(f'Cannot find {pth_model}') 67 | return state_dict 68 | 69 | 70 | def convert_to_hf(config_file, pth_model, save_dir): 71 | args = parse_args() 72 | 73 | # parse config 74 | if not os.path.isfile(config_file): 75 | try: 76 | config_file = cfgs_name_path[config_file] 77 | except KeyError: 78 | raise FileNotFoundError(f'Cannot find {config_file}') 79 | 80 | # load config 81 | cfg = Config.fromfile(config_file) 82 | if args.cfg_options is not None: 83 | cfg.merge_from_dict(args.cfg_options) 84 | 85 | model = BUILDER.build(cfg.model) 86 | 87 | state_dict = guess_load_checkpoint(pth_model) 88 | model.load_state_dict(state_dict, strict=False) 89 | print(f'Load PTH model from {pth_model}') 90 | 91 | if not args.fp32: 92 | print('Convert weights to float16') 93 | model.llm.half() 94 | 95 | print(f'Saving HuggingFace model to {save_dir}') 96 | model.llm.save_pretrained( 97 | save_dir, max_shard_size=args.max_shard_size) 98 | if 'PeftModel' not in model.llm.__class__.__name__: 99 | print(f'Saving HuggingFace tokenizer to {save_dir}') 100 | tokenizer = BUILDER.build(cfg.tokenizer) 101 | tokenizer.save_pretrained(save_dir) 102 | shutil.copyfile(config_file, os.path.join(save_dir, 103 | 'xtuner_config.py')) 104 | print('Pth to hf all done!') 105 | 106 | 107 | if __name__ == '__main__': 108 | 109 | config_file = '/root/ft-oasst1/internlm_chat_7b_qlora_oasst1_e3_copy.py' 110 | pth_model = '/root/ft-oasst1/work_dirs/internlm_chat_7b_qlora_oasst1_e3_copy/epoch_1.pth' 111 | save_dir = '/root/ft-oasst1/hf5' 112 | 113 | convert_to_hf(config_file, pth_model, save_dir) -------------------------------------------------------------------------------- /xtuner_download/README.md: -------------------------------------------------------------------------------- 1 | 2 | ## 模型下载 3 | 4 | 核心流程 5 | ```mermaid 6 | flowchart LR 7 | 8 | IN1(model_name) --> A 9 | IN2(out_path) --> A 10 | IN3(tqdm_class) --> A 11 | subgraph xtunerModelDownload 12 | A(initial) --> B(获取模型信息:文件数量、文件大小) 13 | 14 | DM(起用下载线程:X-model-download) 15 | DP(起用进度线程:X-model-progress) 16 | clear(文件清除) 17 | DP -->|检查下载进度|DM 18 | end 19 | 20 | B --> 
C 21 | B --> Break 22 | C -->|Start线程|DM 23 | C -->|Start线程|DP --> grP 24 | subgraph gr-client 25 | 26 | C(模型下载) 27 | grP(进度条显示) 28 | 29 | Break(下载中断) 30 | end 31 | 32 | Break -->|kill线程|DM 33 | Break -->|kill线程|DP 34 | Break --> clear 35 | 36 | DM --> hf 37 | subgraph download 38 | modelscope(modelscope-download) 39 | hf(huggingface-download) 40 | openxlab(openxlab-download) 41 | 42 | hf--> check1{下载成功?}-->|否|modelscope--> check2{下载成功?}-->|否|openxlab 43 | 44 | check1 -->|是|finshed 45 | check2 -->|是|finshed 46 | end 47 | ``` 48 | 49 | - example: 50 | ```python 51 | from download_model import xtunerModelDownload 52 | from tqdm.auto import tqdm 53 | import time 54 | 55 | model_name = 'internlm/internlm-chat-7b' 56 | d_model = xtunerModelDownload( 57 | model_name, 58 | out_path='/root/tmp/download_model', 59 | tqdm_class=tqdm 60 | ) 61 | d_model.auto_download() 62 | time.sleep(60) 63 | d_model.break_download() 64 | print('Yes') 65 | ``` 66 | 67 | ## 数据下载 68 | 69 | 70 | 核心流程 71 | ```mermaid 72 | flowchart LR 73 | 74 | IN1(data_name) --> A 75 | IN2(out_path) --> A 76 | IN3(tqdm_class) --> A 77 | subgraph xtunerModelDownload 78 | A(initial) --> B(获取数据信息:文件数量、文件大小) 79 | 80 | DM(起用下载线程:X-dataset-download) 81 | DP(起用进度线程:X-dataset-progress) 82 | clear(文件清除) 83 | DP -->|检查下载进度|DM 84 | end 85 | 86 | B --> C 87 | B --> Break 88 | C -->|Start线程|DM 89 | C -->|Start线程|DP --> grP 90 | subgraph gr-client 91 | 92 | C(数据下载) 93 | grP(进度条显示) 94 | 95 | Break(下载中断) 96 | end 97 | 98 | Break -->|kill线程|DM 99 | Break -->|kill线程|DP 100 | Break --> clear 101 | 102 | DM --> hf 103 | subgraph download 104 | hf(huggingface-download) 105 | 106 | hf--> check1{下载成功?}-->|否|hf-->|Retry-n times|check1 107 | 108 | check1 -->|是|finshed 109 | end 110 | ``` 111 | 112 | - example: 113 | ```python 114 | from download_dataset import xtunerDataDownload 115 | from tqdm.auto import tqdm 116 | import time 117 | 118 | data_name = 'shibing624/medical' 119 | d_data = xtunerDataDownload( 120 | data_name, 121 | out_path='/root/tmp/download_data', 122 | tqdm_class=tqdm, 123 | retry_times=1 124 | ) 125 | d_data.auto_download() 126 | time.sleep(60) 127 | d_data.break_download() 128 | print('Yes') 129 | ``` 130 | 131 | 132 | -------------------------------------------------------------------------------- /xtuner_download/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/xtuner_download/__init__.py -------------------------------------------------------------------------------- /xtuner_download/data_list.txt: -------------------------------------------------------------------------------- 1 | ./data/CrimeKgAssitant清洗后_52k.json 2 | ./data/arxiv_data.json 3 | ./data/moss-003-sft-no-tools.jsonl 4 | ./data/训练数据_带法律依据_92k.json 5 | 6 | 7 | ArmelR/stack-exchange-instruction 8 | HuggingFaceH4/CodeAlpaca_20K 9 | Open-Orca/OpenOrca # 这个丰富的增强 FLAN 数据集尽可能地与 Orca 论文中概述的分布保持一致。它有助于生成高性能的模型检查点,是所有 NLP 研究人员和开发人员的宝贵资源! 
10 | Skywork/SkyPile-150B # SkyPile-150B 是一个全面的大型中文数据集,专门用于大型语言模型的预训练。 11 | WizardLM/WizardLM_evol_instruct_V2_196k 12 | b-mc2/sql-create-context 13 | bigcode/starcoder 14 | burkelibbey/colors 15 | damo/MSAgent-Bench 16 | garage-bAInd/Open-Platypus 17 | mistralai/Mistral-7B-v0.1 18 | nampdn-ai/tiny-codes 19 | shibing624/medical 20 | silk-road/alpaca-data-gpt4-chinese 21 | tatsu-lab/alpaca 22 | timdettmers/openassistant-guanaco 23 | -------------------------------------------------------------------------------- /xtuner_download/download_dataset.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-25 3 | # Author: Scc_hy 4 | # Func: 模型拉取到本地 5 | # =========================================================================================== 6 | import os 7 | import gradio as gr 8 | from tqdm.auto import tqdm 9 | from os.path import getsize as p_getsize 10 | from os.path import join as p_join 11 | import threading 12 | import time 13 | from .download_utils import stop_thread, _split_repo, get_hf_cache_files, get_data_info, get_final_out_files, TOKEN 14 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 15 | CUR_DIR = os.path.dirname(__file__) 16 | 17 | 18 | class xtunerDataDownload(): 19 | def __init__(self, data_name, out_path, tqdm_class=tqdm, progress_sleep=1, retry_times=0) -> None: 20 | self.progress_sleep = progress_sleep 21 | self.run_times = retry_times + 1 22 | self.tqdm_class = tqdm_class 23 | self.username, self.repository = _split_repo(data_name) 24 | self.data_name = data_name 25 | self.out_path = out_path 26 | self.final_out_path = p_join(out_path, f'dataset_{self.username}_{self.repository}') 27 | self.mid_download_dir = self.final_out_path 28 | self._t_handle_dl = None 29 | self._t_handle_pg = None 30 | self._break_flag = False 31 | self._get_info_flag = False 32 | self.__check_create_dir() 33 | self.get_download_info() 34 | 35 | def reset_path(self, customer_dir): 36 | self.__remove_mid_files() 37 | self.__remove_final_files() 38 | self.out_path = f'{customer_dir}/data_download' 39 | print(f'xtunerDataDownload reset_path->{self.out_path}') 40 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}') 41 | self.mid_download_dir = self.final_out_path 42 | self.__check_create_dir() 43 | 44 | def reset(self, data_name): 45 | self.remove_and_create() 46 | print(f'reset({data_name})') 47 | self.username, self.repository = _split_repo(data_name) 48 | self.data_name = data_name 49 | self.final_out_path = p_join(self.out_path, f'dataset_{self.username}_{self.repository}') 50 | self.mid_download_dir = self.final_out_path 51 | self._t_handle_dl = None 52 | self._t_handle_pg = None 53 | self._break_flag = False 54 | self._get_info_flag = False 55 | self.__check_create_dir() 56 | self.get_download_info() 57 | 58 | def get_download_info(self): 59 | self.total_MB, self.total_file_nums = get_data_info(self.data_name) 60 | self._get_info_flag = True 61 | 62 | def __check_create_dir(self): 63 | if not os.path.exists(self.out_path): 64 | os.system(f'mkdir -p {self.out_path}') 65 | if not os.path.exists(self.final_out_path): 66 | os.system(f'mkdir -p {self.final_out_path}') 67 | 68 | def __remove_mid_files(self): 69 | """中断时删除所有文件""" 70 | os.system(f'rm -rf {self.mid_download_dir}') 71 | # cd rm 72 | rm_dir = './' + self.mid_download_dir.replace(self.out_path, '.')[2:].split('/')[0] 73 | os.system(f'cd {self.out_path} && rm -rf {rm_dir} && rm -rf temp') 74 | # 删除 hf 的cache 75 | os.system(f'rm -rf 
{self.final_out_path}/cache') 76 | 77 | def __remove_final_files(self): 78 | os.system(f'rm -rf {self.final_out_path}') 79 | os.system(f'cd {self.out_path} && rm -rf dataset_{self.username}_{self.repository}') 80 | 81 | def remove_and_create(self): 82 | self.__remove_mid_files() 83 | self.__remove_final_files() 84 | self.__check_create_dir() 85 | 86 | def auto_download(self, progress=gr.Progress(track_tqdm=True)): 87 | # wait for info 88 | cnt_ = 0 89 | while not self._get_info_flag: 90 | cnt_ += 1 91 | time.sleep(1) 92 | if cnt_ == 30: 93 | break 94 | 95 | self._break_flag = False 96 | self._t_download(self.safe_download) 97 | # self._t_start() 98 | self.progress(progress=progress) 99 | if self._break_flag: 100 | return "Done! Data-Download had interrupted!" 101 | return self.final_out_path 102 | 103 | def safe_download(self): 104 | for _ in range(self.run_times): 105 | self.hf_download() 106 | time.sleep(0.5) 107 | # 执行完检验 108 | if self._finished_check(): 109 | print('finished download all model files') 110 | break 111 | self.remove_and_create() 112 | return 113 | 114 | def hf_download(self): 115 | print('>>>>>>> Start hf_download') 116 | # 1- mid download local dir 117 | self.mid_download_dir = self.final_out_path 118 | # 2- download load_dataset 环境变量调整 119 | os.system(f""" 120 | export HF_ENDPOINT=https://hf-mirror.com && \ 121 | huggingface-cli download --resume-download {self.data_name} --local-dir-use-symlinks False \ 122 | --repo-type dataset \ 123 | --local-dir {self.final_out_path} \ 124 | --cache-dir {self.final_out_path}/cache \ 125 | --token {TOKEN} 126 | """) 127 | os.system(f'rm -rf {self.final_out_path}/cache') 128 | return self.final_out_path 129 | 130 | def _finished_check(self): 131 | """检查是否下载完整数据 132 | """ 133 | no_flag = (self.total_file_nums is not None) or (self.total_file_nums <= 0.01) 134 | print(f'self.total_file_nums={self.total_file_nums} no_flag={no_flag}') 135 | if not no_flag and os.path.exists(self.final_out_path): 136 | downloaded_files, download_bytes = self._get_final_out_bytes() 137 | print(f'downloaded_files={downloaded_files}\ndownload_bytes={download_bytes}') 138 | file_same = len(downloaded_files) == self.total_file_nums 139 | size_same = download_bytes / 1024**2 / (self.total_MB + 1e-5) >= 0.99 140 | return size_same & file_same 141 | return True 142 | 143 | def _t_start(self): 144 | self._t_handle_pg = threading.Thread(target=self.progress, name='X-dataset-progress', daemon=True) 145 | self._t_handle_pg.start() 146 | 147 | def _t_download(self, d_func): 148 | self._t_handle_dl = threading.Thread(target=d_func, name='X-dataset-download', daemon=True) 149 | self._t_handle_dl.start() 150 | 151 | def _get_final_out_bytes(self): 152 | # data存在多层的情况 153 | downloaded_files = get_final_out_files(self.final_out_path) 154 | cached_mb1 = sum([p_getsize(f) for f in downloaded_files]) 155 | return downloaded_files, cached_mb1 156 | 157 | def progress(self, progress=None): 158 | hf_cache_dir = p_join(self.final_out_path, 'cache') 159 | self.bar_ = self.tqdm_class(total=round(self.total_MB*1024**2, 3), unit='iB', unit_scale=True) 160 | self.bar_.set_description('TotalProgress') 161 | bf = 0 162 | while True: 163 | if self._t_handle_dl is None: 164 | break 165 | if not self._t_handle_dl.is_alive(): 166 | break 167 | hf_cache_files = get_hf_cache_files(hf_cache_dir) if os.path.exists(hf_cache_dir) else [] 168 | _, cached_mb1 = self._get_final_out_bytes() 169 | cached_mb2 = sum([p_getsize(f) for f in hf_cache_files]) 170 | cached_mb = (cached_mb1 + cached_mb2) 171 
| 172 | self.bar_.update(round(cached_mb - bf, 3)) 173 | bf = cached_mb 174 | time.sleep(self.progress_sleep) 175 | 176 | # 数据统计可能不准确 177 | finished_rate = cached_mb / (self.total_MB + 1e-5) / 1024**2 178 | if self._t_handle_dl is None and finished_rate <= 0.99: 179 | left = self.total_MB * 1024**2 - bf 180 | self.bar_.update(round(left, 3)) 181 | 182 | return 183 | 184 | def break_download(self): 185 | # 然后杀死该线程 186 | # 删除文件 187 | if self._t_handle_dl is not None: 188 | print('>>>>>>>>>>>>>>>>> break_download') 189 | stop_thread(self._t_handle_dl) 190 | self._t_handle_dl = None 191 | os.system(f'sh {CUR_DIR}/kill_hf.sh dataset') 192 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_dl)') 193 | 194 | if self._t_handle_pg is not None: 195 | stop_thread(self._t_handle_pg) 196 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_pg)') 197 | self._t_handle_pg = None 198 | self.remove_and_create() 199 | self._break_flag = True 200 | return "Done! Data-Download had interrupted!" 201 | 202 | 203 | if __name__ == '__main__': 204 | print(os.getcwd()) 205 | download_ = xtunerDataDownload( 206 | 'shibing624/medical', 207 | out_path='/home/scc/sccWork/myGitHub/My_Learn/tmp/download') 208 | # out_path='/root/tmp/download' 209 | # ) 210 | # download_.hf_download() # checked 211 | download_.auto_download() # checked-download & progress 212 | time.sleep(10) 213 | download_.break_download() # checked-download & progress & break 214 | print('Yes') 215 | # chech finished 216 | f_ = download_._finished_check() 217 | print(f'_finished_check={f_}') 218 | 219 | -------------------------------------------------------------------------------- /xtuner_download/download_model.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-23 3 | # Author: Scc_hy 4 | # Func: 模型拉取到本地 5 | # =========================================================================================== 6 | import os 7 | import gradio as gr 8 | from tqdm.auto import tqdm 9 | from openxlab.model import download as ox_download 10 | from modelscope.hub.snapshot_download import snapshot_download 11 | from modelscope.hub.api import HubApi, ModelScopeConfig 12 | from os.path import getsize as p_getsize 13 | from os.path import join as p_join 14 | import threading 15 | import time 16 | from .download_utils import stop_thread, _split_repo, get_hf_cache_files, get_model_info, TOKEN 17 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 18 | CUR_DIR = os.path.dirname(__file__) 19 | 20 | 21 | class xtunerModelDownload(): 22 | def __init__(self, model_name, out_path, tqdm_class=tqdm, progress_sleep=1) -> None: 23 | self.progress_sleep = progress_sleep 24 | self.tqdm_class = tqdm_class 25 | self.username, self.repository = _split_repo(model_name) 26 | self.model_name = model_name 27 | self.out_path = out_path 28 | self.final_out_path = p_join(out_path, f'{self.username}_{self.repository}') 29 | self.mid_download_dir = self.final_out_path 30 | self._t_handle_dl = None 31 | self._t_handle_pg = None 32 | self._break_flag = False 33 | self._get_info_flag = False 34 | self.remove_and_create() 35 | self.get_download_info() 36 | 37 | def reset_path(self, customer_dir): 38 | self.__remove_mid_files() 39 | self.__remove_final_files() 40 | self.out_path = f'{customer_dir}/model_download' 41 | print(f'xtunerModelDownload reset_path->{self.out_path}') 42 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}') 43 | self.mid_download_dir = self.final_out_path 44 | 
self.__check_create_dir() 45 | 46 | def reset(self, model_name): 47 | self.remove_and_create() 48 | print(f'reset({model_name})') 49 | self.username, self.repository = _split_repo(model_name) 50 | self.model_name = model_name 51 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}') 52 | self.mid_download_dir = self.final_out_path 53 | self._t_handle_dl = None 54 | self._t_handle_pg = None 55 | self._break_flag = False 56 | self._get_info_flag = False 57 | self.remove_and_create() 58 | self.get_download_info() 59 | 60 | def _username_map(self, tp): 61 | """username 映射 62 | """ 63 | modelscope_map_dict = { 64 | 'internlm': 'Shanghai_AI_Laboratory', 65 | 'meta-llama': 'shakechen', # Llma-2 66 | 'huggyllama': 'skyline2006', # Llma 67 | 'THUDM': 'ZhipuAI', 68 | '01-ai': '01ai', 69 | 'Qwen': 'qwen' 70 | } 71 | hf_map_dict = {} 72 | openxlab_map_dict = { 73 | 'internlm': 'OpenLMLab', 74 | 'meta-llama': 'shakechen', # Llma-2 75 | 'huggyllama': 'skyline2006', # Llma 76 | 'THUDM': 'ZhipuAI', 77 | '01-ai': '01ai' 78 | } 79 | sp_model_name = '{u_name}/{rep}'.format( 80 | u_name=eval(f"{tp}_map_dict.get('{self.username}', '{self.username}')"), 81 | rep=self.repository 82 | ) 83 | return sp_model_name 84 | 85 | def get_download_info(self): 86 | # 优先modelscope查看 87 | try: 88 | self.total_MB, self.total_file_nums = self._get_download_info() 89 | except Exception as e: 90 | self.total_MB, self.total_file_nums = get_model_info(self.model_name) 91 | self._get_info_flag = True 92 | 93 | def _get_download_info(self): 94 | _api = HubApi() 95 | headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None, )} 96 | snapshot_header = headers if 'CI_TEST' in os.environ else { 97 | **headers, 98 | **{ 99 | 'Snapshot': 'True' 100 | } 101 | } 102 | model_id = self._username_map('modelscope') 103 | model_files = _api.get_model_files( 104 | model_id=model_id, 105 | revision=None, 106 | recursive=True, 107 | use_cookies=False, 108 | headers=snapshot_header, 109 | ) 110 | total_MB = sum([i['Size']/1024**2 for i in model_files]) 111 | total_file_nums = len(model_files) 112 | return total_MB, total_file_nums 113 | 114 | def __check_create_dir(self): 115 | if not os.path.exists(self.out_path): 116 | os.system(f'mkdir -p {self.out_path}') 117 | if not os.path.exists(self.final_out_path): 118 | os.system(f'mkdir -p {self.final_out_path}') 119 | 120 | def __remove_mid_files(self): 121 | """中断时删除所有文件""" 122 | os.system(f'rm -rf {self.mid_download_dir}') 123 | # cd rm 124 | rm_dir = './' + self.mid_download_dir.replace(self.out_path, '.')[2:].split('/')[0] 125 | os.system(f'cd {self.out_path} && rm -rf {rm_dir} && rm -rf temp') 126 | # 删除 hf 的cache 127 | os.system(f'rm -rf {self.final_out_path}/cache') 128 | 129 | def __remove_final_files(self): 130 | os.system(f'rm -rf {self.final_out_path}') 131 | os.system(f'cd {self.out_path} && rm -rf {self.username}_{self.repository}') 132 | 133 | def remove_and_create(self): 134 | self.__remove_mid_files() 135 | self.__remove_final_files() 136 | self.__check_create_dir() 137 | self.mid_download_dir = self.final_out_path 138 | 139 | def auto_download(self, progress=gr.Progress(track_tqdm=True), tp='speed'): 140 | cnt_ = 0 141 | while not self._get_info_flag: 142 | cnt_ += 1 143 | time.sleep(1) 144 | if cnt_ == 30: 145 | break 146 | 147 | self._break_flag = False 148 | self._t_download(self.loop_download, tp) 149 | # self._t_start(progress) 150 | # progress not use thread 151 | self.progress(progress=progress) 152 | if self._break_flag: 153 | return 
"Done! Model-Download had interrupted!" 154 | return self.final_out_path 155 | 156 | def loop_download(self, tp='speed'): 157 | # modelscope first 158 | # if 'internlm' in self.model_name.lower(): 159 | # loop_list = [self.openxlab_download, self.modelscope_download, self.hf_download] 160 | if tp == 'speed': 161 | loop_list = [self.modelscope_download, self.hf_download, self.openxlab_download] 162 | else: 163 | loop_list = [self.hf_download, self.modelscope_download, self.openxlab_download] 164 | 165 | for download_func in loop_list: 166 | print("download_func=", download_func) 167 | print("_get_info_flag=", self._get_info_flag) 168 | try: 169 | download_func() 170 | time.sleep(1) 171 | except Exception as e: 172 | print("download_func=", download_func, '\n', '--'*25, f'\nerror={e}\n', '--'*25) 173 | pass 174 | # 执行完检验 175 | if self._finished_check(): 176 | print('finished download all model files') 177 | break 178 | 179 | print('Failed download all model files & remove_and_create') 180 | self.remove_and_create() 181 | return 182 | 183 | def hf_download(self): 184 | print('>>>>>>> Start hf_download') 185 | # 1- mid download local dir 186 | self.mid_download_dir = self.final_out_path 187 | # 2- download 188 | os.system(f""" 189 | export HF_ENDPOINT=https://hf-mirror.com && \ 190 | huggingface-cli download --resume-download {self.model_name} --local-dir-use-symlinks False \ 191 | --repo-type model \ 192 | --local-dir {self.final_out_path} \ 193 | --cache-dir {self.final_out_path}/cache \ 194 | --token {TOKEN} 195 | """) 196 | os.system(f'rm -rf {self.final_out_path}/cache') 197 | return self.final_out_path 198 | 199 | def modelscope_download(self): 200 | print('>>>>>>> Start modelscope_download') 201 | # 1- fix-name 202 | model_name = self._username_map('modelscope') 203 | # 2- mid download local dir 204 | self.mid_download_dir = mid_download_dir = p_join(self.out_path, model_name) 205 | # 3- download 206 | snapshot_download(model_id=model_name, cache_dir=self.out_path) 207 | # 保证目录一致 out_path/sccHyFuture/LLM_medQA_adapter --> final_out_path 208 | os.system(f'mv {mid_download_dir}/* {self.final_out_path}') 209 | self.__remove_mid_files() 210 | return self.final_out_path 211 | 212 | def openxlab_download(self): 213 | print('>>>>>>> Start openxlab_download') 214 | # 1- fix-name 215 | model_name = self._username_map('openxlab') 216 | # 2- mid download local dir 217 | self.mid_download_dir = self.final_out_path 218 | # 3- download 219 | ox_download(model_repo=model_name, output=self.final_out_path, cache=False) 220 | return self.final_out_path 221 | 222 | def _finished_check(self): 223 | """检查是否下载完整数据 224 | """ 225 | no_flag = (self.total_file_nums is not None) or (self.total_file_nums <= 0.01) 226 | if no_flag and os.path.exists(self.final_out_path): 227 | final_nums = len([i for i in os.listdir(self.final_out_path) if not os.path.isdir(i) ]) 228 | print('os.listdir(self.final_out_path)=', os.listdir(self.final_out_path)) 229 | final_MB = sum([p_getsize(p_join(self.final_out_path, i))/ 1024**2 for i in os.listdir(self.final_out_path)]) 230 | file_same = final_nums >= self.total_file_nums 231 | size_same = final_MB/(self.total_MB + 1e-5) >= 0.99 232 | print(f"self.total_file_nums={self.total_file_nums} final_nums={final_nums} >>>>> file_same={file_same}") 233 | print(f"self.total_MB={self.total_MB:.3f} final_MB={final_MB:.3f} >>>>> size_same={size_same}") 234 | return size_same & file_same 235 | return True 236 | 237 | def _t_start(self, pg=None): 238 | self._t_handle_pg = 
threading.Thread(target=self.progress, args=(pg,), name='X-model-progress', daemon=True) 239 | self._t_handle_pg.start() 240 | 241 | def _t_download(self, d_func, tp): 242 | self._t_handle_dl = threading.Thread(target=d_func, args=(tp,) ,name='X-model-download', daemon=True) 243 | self._t_handle_dl.start() 244 | 245 | def progress(self, progress=None): 246 | model_scope_cache_dir = p_join(self.out_path, 'temp') 247 | hf_cache_dir = p_join(self.final_out_path, 'cache') 248 | self.bar_ = self.tqdm_class(total=round(self.total_MB*1024**2, 3), unit='iB', unit_scale=True) 249 | self.bar_.set_description('TotalProgress') 250 | bf = 0 251 | while True: 252 | if self._t_handle_dl is None: 253 | break 254 | if not self._t_handle_dl.is_alive(): 255 | break 256 | hf_cache_files = get_hf_cache_files(hf_cache_dir) if os.path.exists(hf_cache_dir) else [] 257 | if self.mid_download_dir == self.final_out_path: 258 | cached_mb1 = sum([p_getsize(p_join(self.final_out_path, i)) for i in os.listdir(self.final_out_path)]) 259 | cached_mb4 = sum([p_getsize(f) for f in hf_cache_files]) 260 | cached_mb = cached_mb1 + cached_mb4 261 | else: 262 | # 获取最近创建的temp文件 263 | try: 264 | model_scope_cache_dir_tmp = sorted([ 265 | [p_join(model_scope_cache_dir, i), os.stat(p_join(model_scope_cache_dir, i)).st_atime] for i in os.listdir(model_scope_cache_dir) 266 | ], key=lambda c:c[1])[-1][0] 267 | except Exception as e: 268 | model_scope_cache_dir_tmp = None 269 | 270 | cached_mb1 = 0 271 | if os.path.exists(self.mid_download_dir): 272 | cached_mb1 = sum([p_getsize(p_join(self.mid_download_dir, i)) for i in os.listdir(self.mid_download_dir)]) 273 | 274 | cached_mb2 = sum([p_getsize(p_join(self.final_out_path, i)) for i in os.listdir(self.final_out_path)]) 275 | cached_mb3 = 0 276 | if model_scope_cache_dir_tmp is not None: 277 | cached_mb3 = sum([p_getsize(p_join(model_scope_cache_dir_tmp, i)) for i in os.listdir(model_scope_cache_dir_tmp)]) 278 | 279 | cached_mb4 = sum([p_getsize(f) for f in hf_cache_files]) 280 | cached_mb = (cached_mb1 + cached_mb2 + cached_mb3 + cached_mb4) 281 | 282 | self.bar_.update(round(cached_mb - bf, 3)) 283 | bf = cached_mb 284 | if cached_mb / (self.total_MB + 1e-5) / 1024**2 > 99.99: 285 | break 286 | time.sleep(self.progress_sleep) 287 | return 288 | 289 | def break_download(self): 290 | # 然后杀死该线程 291 | # 删除文件 292 | if self._t_handle_dl is not None: 293 | print('>>>>>>>>>>>>>>>>> break_download') 294 | stop_thread(self._t_handle_dl) 295 | self._t_handle_dl = None 296 | os.system(f'sh {CUR_DIR}/kill_hf.sh model') 297 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_dl)') 298 | 299 | if self._t_handle_pg is not None: 300 | stop_thread(self._t_handle_pg) 301 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_pg)') 302 | self._t_handle_pg = None 303 | self.remove_and_create() 304 | self._break_flag = True 305 | return "Done! Model-Download had interrupted!" 
306 | 307 | 308 | if __name__ == '__main__': 309 | print(os.getcwd()) 310 | download_ = xtunerModelDownload( 311 | 'internlm/InternLM-chat-7b', 312 | out_path='/home/scc/sccWork/myGitHub/My_Learn/tmp/download') 313 | #'/root/tmp/download') 314 | #'/home/scc/sccWork/myGitHub/My_Learn/tmp/download') 315 | # download_.hf_download() # checked-download & progress 316 | # download_.openxlab_download() # checked-download & progress 317 | # download_.modelscope_download() # checked-download & progress 318 | download_.auto_download() # checked-download & progress 319 | time.sleep(10) 320 | download_.break_download() # checked-download & progress & break 321 | print('Yes') 322 | # chech finished 323 | f_ = download_._finished_check() 324 | print(f'_finished_check={f_}') 325 | -------------------------------------------------------------------------------- /xtuner_download/download_utils.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-24 3 | # Author: Scc_hy 4 | # Func: 模型拉取到本地 5 | # =========================================================================================== 6 | import os 7 | import re 8 | import inspect 9 | import ctypes 10 | import requests 11 | import inspect 12 | import ctypes 13 | import requests 14 | from tqdm.auto import tqdm 15 | from os.path import join as p_join 16 | from concurrent.futures import ThreadPoolExecutor 17 | from huggingface_hub.hf_api import HfApi 18 | 19 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 20 | TOKEN = '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' 21 | 22 | def get_hf_cache_files(folder_path): 23 | all_files = [] 24 | for root, dirs, files in os.walk(folder_path): 25 | for file in files: 26 | file_tt = p_join(root, file) 27 | if os.path.isfile(file_tt) and '.incomplete' in file: 28 | all_files.append(file_tt) 29 | return all_files 30 | 31 | 32 | def get_final_out_files(folder_path): 33 | all_files = [] 34 | for root, dirs, files in os.walk(folder_path): 35 | for file in files: 36 | file_tt = p_join(root, file) 37 | if not os.path.isdir(file_tt) and 'cache' not in file_tt: 38 | all_files.append(file_tt) 39 | return all_files 40 | 41 | 42 | 43 | def _split_repo(model_repo) -> (str, str): 44 | """ 45 | Split a full repository name into two separate strings: the username and the repository name. 
46 | """ 47 | # username/repository format check 48 | pattern = r'.+/.+' 49 | if not re.match(pattern, model_repo): 50 | raise ValueError("The input string must be in the format 'username/model_repo'") 51 | 52 | values = model_repo.split('/') 53 | return values[0], values[1] 54 | 55 | 56 | def _async_raise(tid, exctype): 57 | """Raises an exception in the threads with id tid""" 58 | if not inspect.isclass(exctype): 59 | raise TypeError("Only types can be raised (not instances)") 60 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype)) 61 | if res == 0: 62 | raise ValueError("invalid thread id") 63 | elif res != 1: 64 | # """if it returns a number greater than one, you're in trouble, 65 | # and you should call it again with exc=NULL to revert the effect""" 66 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) 67 | raise SystemError("PyThreadState_SetAsyncExc failed") 68 | 69 | 70 | def stop_thread(thread): 71 | try: 72 | _async_raise(thread.ident, SystemExit) 73 | except Exception as e: 74 | print(e) 75 | 76 | 77 | def detect_data_file_bytes(data_name, data_file): 78 | try: 79 | txt = requests.get(f'https://hf-mirror.com/datasets/{data_name}/blob/main/{data_file}', timeout=3).text 80 | except Exception as e: 81 | print(e) 82 | return 0.0 83 | find_out = re.findall(r'Size of remote file:(.*?)B', txt, flags=re.S) 84 | find_info = find_out[0] if len(find_out) else '0 ' 85 | info_num = float(re.findall(r'\d+.\d+|\d+', find_info)[0]) 86 | info_cat = find_info[-1]+'B' 87 | # huggingface 直接1000 88 | info_map = { 89 | 'kB': 1000, 90 | 'KB': 1000, 91 | 'MB': 1000 ** 2, 92 | 'mB': 1000 ** 2, 93 | 'GB': 1000 ** 3, 94 | 'gB': 1000 ** 3, 95 | ' B': 1, 96 | } 97 | return info_num * info_map.get(info_cat, 1.0) 98 | 99 | 100 | def get_data_info(data_name): 101 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 102 | api_ = HfApi(token=TOKEN) 103 | try: 104 | df_info = api_.dataset_info(repo_id=data_name, token=TOKEN, timeout=3) 105 | except Exception as e: 106 | print(e) 107 | return 0, 0 108 | df_files = [i.rfilename for i in df_info.siblings] 109 | exec_ = ThreadPoolExecutor(max_workers=2) 110 | tasks = [exec_.submit(detect_data_file_bytes, data_name=data_name, data_file=i) for i in df_files] 111 | res = [] 112 | for t in tqdm(tasks): 113 | res.append(t.result()) 114 | 115 | total_MB = sum(res) / 1024 ** 2 116 | total_file_nums = len(df_files) 117 | return total_MB, total_file_nums 118 | 119 | 120 | def detect_model_file_bytes(model_name, data_file): 121 | try: 122 | txt = requests.get(f'https://hf-mirror.com/{model_name}/blob/main/{data_file}', timeout=3).text 123 | except Exception as e: 124 | print(e) 125 | return 0.0 126 | find_out = re.findall(r'Size of remote file:(.*?)B', txt, flags=re.S) 127 | find_info = find_out[0] if len(find_out) else '0 ' 128 | info_num = float(re.findall(r'\d+.\d+|\d+', find_info)[0]) 129 | info_cat = find_info[-1]+'B' 130 | # huggingface 直接1000 131 | info_map = { 132 | 'kB': 1000, 133 | 'KB': 1000, 134 | 'MB': 1000 ** 2, 135 | 'mB': 1000 ** 2, 136 | 'GB': 1000 ** 3, 137 | 'gB': 1000 ** 3, 138 | ' B': 1, 139 | } 140 | return info_num * info_map.get(info_cat, 1.0) 141 | 142 | 143 | def get_model_info(model_name): 144 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com' 145 | api_ = HfApi(token=TOKEN) 146 | try: 147 | df_info = api_.model_info(repo_id=model_name, token=TOKEN, timeout=3) 148 | except Exception as e: 149 | print(e) 150 | return 0, 0 151 | df_files = [i.rfilename for i in df_info.siblings] 152 | exec_ = 
ThreadPoolExecutor(max_workers=2) 153 | tasks = [exec_.submit(detect_model_file_bytes, model_name=model_name, data_file=i) for i in df_files] 154 | res = [] 155 | for t in tqdm(tasks): 156 | res.append(t.result()) 157 | 158 | total_MB = sum(res) / 1024 ** 2 159 | total_file_nums = len(df_files) 160 | return total_MB, total_file_nums 161 | 162 | 163 | def test_get_data_info(): 164 | print('>>>>>>>> Start test') 165 | data_name = 'shibing624/medical' 166 | total_MB, total_file_nums = get_data_info(data_name) 167 | print(f'[ data_name={data_name} ] total_file_nums={total_file_nums} | total_MB={total_MB:.3f}MiB') 168 | 169 | def test_down_load_data(): 170 | print('>>>>>>>> Start test') 171 | model_name = 'internlm/internlm-chat-7b' 172 | total_MB, total_file_nums = get_model_info(model_name) 173 | print(f'[ data_name={model_name} ] total_file_nums={total_file_nums} | total_MB={total_MB:.3f}MiB') 174 | 175 | 176 | 177 | if __name__ == '__main__': 178 | test_get_data_info() 179 | test_down_load_data() 180 | -------------------------------------------------------------------------------- /xtuner_download/find_datalist.py: -------------------------------------------------------------------------------- 1 | 2 | import re 3 | import os 4 | from os.path import getsize as p_getsize 5 | from os.path import join as p_join 6 | 7 | 8 | def get_py_config_files(folder_path): 9 | all_files = [] 10 | for root, dirs, files in os.walk(folder_path): 11 | for file in files: 12 | file_tt = p_join(root, file) 13 | if os.path.isfile(file_tt) and '.py' in file and '__init__' not in file: 14 | all_files.append(file_tt) 15 | return all_files 16 | 17 | 18 | def read_and_find_path(file): 19 | with open(file, 'r') as f: 20 | res = f.readlines()[:50] 21 | return re.findall(r"path\s+=\s+'(.*?)'\n", ''.join(res)) 22 | 23 | 24 | father_p = '/home/scc/sccWork/openProject/xtuner019/xtuner/xtuner/configs' 25 | py_files = get_py_config_files(father_p) 26 | py_need_files = [i.rsplit('/', 1)[1] for i in py_files] 27 | files_need = set([i.split('lora_')[-1].replace('.py', '') for i in py_need_files]) 28 | path_ = [read_and_find_path(i) for i in py_files] 29 | path_final = [] 30 | for p in path_: 31 | path_final.extend(p) 32 | 33 | model_str = ['Llama', 'Qwen', 'Baichuan', 'chatglm', '01-ai', 'internlm', 'llama'] 34 | path_final = [i for i in sorted(set(path_final)) if all(j not in i for j in model_str)] 35 | 36 | 37 | 38 | 39 | -------------------------------------------------------------------------------- /xtuner_download/kill_hf.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | rp_tp=$1; 4 | 5 | pid=`ps -ef | grep -E "huggingface-cli.*?${rp_tp}" | grep -v "grep" | awk '{print $2}'` 6 | for id in $pid 7 | do 8 | kill -9 $id 9 | echo "killed $id" 10 | done; 11 | -------------------------------------------------------------------------------- /xtuner_download/model_list.txt: -------------------------------------------------------------------------------- 1 | OpenLMLab/internlm-7b 2 | OpenLMLab/internlm-20b 3 | OpenLMLab/internlm-chat-7b 4 | OpenLMLab/internlm-chat-20b 5 | 6 | shakechen/Llama-2-7b-chat meta-llama/Llama-2-7b-chat 7 | --llama2_70b meta-llama/llama2-70b-- 8 | shakechen/Llama-2-7b meta-llama/Llama-2-7b 9 | skyline2006/llama-7b huggyllama/llama-7b 10 | 11 | baichuan-inc/Baichuan-7B-Chat 12 | baichuan-inc/Baichuan-13B-Base 13 | baichuan-inc/baichuan-7B 14 | baichuan-inc/Baichuan2-13B-Chat 15 | baichuan-inc/Baichuan2-7B-Chat 16 | baichuan-inc/Baichuan2-13B-Chat 17 | 
baichuan-inc/Baichuan2-13B-Base 18 | 19 | ZhipuAI/chatglm3-6b THUDM/chatglm3-6b 20 | ZhipuAI/chatglm2-6b THUDM/chatglm2-6b 21 | ZhipuAI/chatglm3-6b-base THUDM/chatglm3-6b-base 22 | 23 | yi_34b 24 | 01ai/Yi-6B 01-ai/Yi-6B 25 | Qwen/Qwen-7B-Chat 26 | Qwen/Qwen-7B 27 | 28 | 29 | -------------------------------------------------------------------------------- /xtuner_download/test_download.py: -------------------------------------------------------------------------------- 1 | from download_model import xtunerModelDownload 2 | from download_dataset import xtunerDataDownload 3 | from tqdm.auto import tqdm 4 | import time 5 | 6 | def main(): 7 | print('>>>>>>>> Start xtunerModelDownload') 8 | model_name = 'internlm/internlm-chat-7b' 9 | d_model = xtunerModelDownload( 10 | model_name, 11 | out_path='/root/tmp/download_model', 12 | tqdm_class=tqdm 13 | ) 14 | d_model.auto_download() 15 | print('>>>>>>>> Start xtunerDataDownload') 16 | data_name = 'shibing624/medical' 17 | d_data = xtunerDataDownload( 18 | data_name, 19 | out_path='/root/tmp/download_data', 20 | tqdm_class=tqdm, 21 | retry_times=0 22 | ) 23 | d_data.auto_download() 24 | time.sleep(60) 25 | d_data.break_download() 26 | d_model.break_download() 27 | print('Yes') 28 | 29 | 30 | if __name__ == '__main__': 31 | main() 32 | -------------------------------------------------------------------------------- /xtuner_download/todo_list.md: -------------------------------------------------------------------------------- 1 | - modelDownload 2 | - [X] download base 3 | - [X] check finished download 4 | - [X] 通过modelscope获取文件列表 5 | - [X] progress 6 | - [X] break down of downloading 7 | - [X] kill and delete restart files 8 | 9 | - dataDownload 10 | - [X] DataLists 11 | - [X] Download method 12 | - [X] only huggingface 13 | -------------------------------------------------------------------------------- /xtuner_result/draw.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from datasets import load_dataset 3 | import altair as alt 4 | import pandas as pd 5 | import os 6 | import re 7 | import gradio as gr 8 | 9 | 10 | class resPlot: 11 | def __init__(self, work_dir): 12 | self.work_dir = work_dir 13 | self.log_file = None 14 | self.iter_dir_list = [] 15 | self.get_log_path() 16 | self.find_iter_pth() 17 | print(f"resPlot(self.log_file={self.log_file})") 18 | 19 | def get_log_path(self): 20 | try: 21 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i and re.match(r'\d+_\d+', i)]) 22 | dir_name = list_[-1] 23 | self.log_file = os.path.join(self.work_dir, dir_name, 'vis_data' , f'{dir_name}.json') 24 | except Exception as e: 25 | print(e) 26 | pass 27 | 28 | def find_iter_pth(self): 29 | try: 30 | self.iter_dir_list = sorted([i for i in os.listdir(self.work_dir) if '.pth' in i]) 31 | except Exception as e: 32 | print(e) 33 | pass 34 | 35 | def get_eval_test(self, ep_pth): 36 | ep_str = ep_pth.split('.')[0] 37 | ep = ep_str.split('_')[0] + '_' + str(int(ep_str.split('_')[1]) - 1) 38 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i and re.match(r'\d+_\d+', i)]) 39 | dir_name = list_[-1] 40 | eval_file = os.path.join(self.work_dir, dir_name, 'vis_data' , f'eval_outputs_{ep}.txt') 41 | try: 42 | return open(eval_file, 'r').read() 43 | except Exception as e: 44 | return f'eval_file={eval_file}\nERROR: {e} ' 45 | 46 | def dynamic_eval_drop_down(self): 47 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' 
not in i and re.match(r'\d+_\d+', i)]) 48 | dir_name = list_[-1] 49 | # /root/xtunerUITest/test/appPrepare/work_dir/20240204_204337/vis_data/eval_outputs_iter_49.txt 50 | eval_file = [i for i in os.listdir(os.path.join(self.work_dir, dir_name, 'vis_data')) if '.txt' in i] 51 | final_list = [] 52 | if len(eval_file): 53 | final_list = ["iter_{}".format(int(i.split('_')[-1].split('.')[0])+1) for i in eval_file] 54 | 55 | return gr.Dropdown(choices=final_list, interactive=True) 56 | 57 | def dynamic_drop_down(self): 58 | self.iter_dir_list = sorted([i for i in os.listdir(self.work_dir) if '.pth' in i]) 59 | return gr.Dropdown(choices=self.iter_dir_list, interactive=True) 60 | 61 | def reset_work_dir(self, root_dir): 62 | self.work_dir = f'{root_dir}/work_dir' 63 | self.get_log_path() 64 | self.find_iter_pth() 65 | print(f"resPlot -> self.work_dir={self.work_dir}\nself.log_file={self.log_file}\nself.iter_dir_list={self.iter_dir_list}") 66 | 67 | def lr_plot(self): 68 | self.get_log_path() 69 | y_axis_name = 'lr' 70 | return self.gr_line_plot(y_axis_name, self.log_file) 71 | 72 | def loss_plot(self): 73 | self.get_log_path() 74 | y_axis_name = 'loss' 75 | return self.gr_line_plot(y_axis_name, self.log_file) 76 | 77 | @staticmethod 78 | def make_plot(y_axis_name, log_path): 79 | ds = load_dataset('json', data_files=log_path) 80 | ds = ds['train'].to_pandas() 81 | ds = ds.rename(columns={'iter': 'iter_num'}) 82 | # ['lr', 'data_time', 'loss', 'time', 'grad_norm', 'iter', 'memory', 'step'] 83 | source = pd.DataFrame({ 84 | 'iter_num': ds['iter_num'].map(int).tolist(), 85 | y_axis_name: ds[y_axis_name].map(float).tolist(), 86 | }) 87 | base = alt.Chart(source).mark_line( 88 | point=alt.OverlayMarkDef(filled=False, fill="white") 89 | ).encode(x='iter_num',y=y_axis_name) 90 | return base 91 | 92 | @staticmethod 93 | def gr_line_plot(y_axis_name, log_path): 94 | ds = load_dataset('json', data_files=log_path) 95 | ds = ds['train'].to_pandas() 96 | ds = ds.rename(columns={'iter': 'iter_num'}) 97 | source = pd.DataFrame({ 98 | 'iter_num': ds['iter_num'].map(int).tolist(), 99 | y_axis_name: ds[y_axis_name].map(float).tolist(), 100 | }) 101 | return gr.LinePlot( 102 | source, 103 | x="iter_num", 104 | x_title='iter_num', 105 | y=y_axis_name, 106 | y_title=y_axis_name, 107 | overlay_point=True, 108 | tooltip=["iter_num", y_axis_name], 109 | title=y_axis_name, 110 | height=300, 111 | width=500, 112 | ) 113 | 114 | def draw(y_axis_name, log_path, save_path): 115 | """ 116 | Args: 117 | y_axis_name: One of ('lr', 'loss') 118 | """ 119 | ds = load_dataset('json', data_files=log_path) 120 | ds = ds['train'] 121 | x = ds['iter'] 122 | y = ds[y_axis_name] 123 | plt.figure(figsize=(10,5)) 124 | 125 | plt.plot(x, y, marker='.') 126 | 127 | plt.title(f'Training Iterations vs {y_axis_name}') 128 | plt.xlabel('Training iterations') 129 | plt.ylabel(y_axis_name) 130 | 131 | plt.savefig(save_path) 132 | 133 | 134 | if __name__ == '__main__': 135 | draw('loss', './dummy_log.json', 'd1.png') 136 | -------------------------------------------------------------------------------- /xtuner_run/example.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 
2 | import torch 3 | from datasets import load_dataset 4 | from mmengine.dataset import DefaultSampler 5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook, 6 | LoggerHook, ParamSchedulerHook) 7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR 8 | from peft import LoraConfig 9 | from torch.optim import AdamW 10 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 11 | BitsAndBytesConfig) 12 | 13 | from xtuner.dataset import process_hf_dataset 14 | from xtuner.dataset.collate_fns import default_collate_fn 15 | from xtuner.dataset.map_fns import code_alpaca_map_fn, template_map_fn_factory 16 | from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook 17 | # from xtuner.engine.runner import TrainLoop 18 | from xtuner.model import SupervisedFinetune 19 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE 20 | 21 | ####################################################################### 22 | # PART 1 Settings # 23 | ####################################################################### 24 | # Model 25 | pretrained_model_name_or_path = 'internlm/internlm2-chat-7b' 26 | 27 | # Data 28 | data_path = 'HuggingFaceH4/CodeAlpaca_20K' 29 | prompt_template = PROMPT_TEMPLATE.internlm2_chat 30 | max_length = 2048 31 | pack_to_max_length = True 32 | 33 | # Scheduler & Optimizer 34 | batch_size = 1 # per_device 35 | accumulative_counts = 16 36 | dataloader_num_workers = 0 37 | max_epochs = 3 38 | optim_type = AdamW 39 | lr = 2e-4 40 | betas = (0.9, 0.999) 41 | weight_decay = 0 42 | max_norm = 1 # grad clip 43 | warmup_ratio = 0.03 44 | 45 | # Save 46 | save_steps = 500 47 | save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited) 48 | 49 | # Evaluate the generation performance during the training 50 | evaluation_freq = 100 51 | SYSTEM = SYSTEM_TEMPLATE.coder 52 | evaluation_inputs = [ 53 | ('写一个Python函数,将十六进制颜色代码(如#0066ee)转换为对应的' 54 | '红、绿、蓝(RGB)三个颜色分量值,并以元组的形式返回。'), 55 | ('Write a Python function that takes a hexadecimal color code ' 56 | '(e.g., #0066ee) as input and converts it into the corresponding ' 57 | 'red, green, and blue (RGB) color component values.') 58 | ] 59 | 60 | ####################################################################### 61 | # PART 2 Model & Tokenizer # 62 | ####################################################################### 63 | tokenizer = dict( 64 | type=AutoTokenizer.from_pretrained, 65 | pretrained_model_name_or_path=pretrained_model_name_or_path, 66 | trust_remote_code=True, 67 | padding_side='right') 68 | 69 | model = dict( 70 | type=SupervisedFinetune, 71 | llm=dict( 72 | type=AutoModelForCausalLM.from_pretrained, 73 | pretrained_model_name_or_path=pretrained_model_name_or_path, 74 | trust_remote_code=True, 75 | torch_dtype=torch.float16, 76 | quantization_config=dict( 77 | type=BitsAndBytesConfig, 78 | load_in_4bit=True, 79 | load_in_8bit=False, 80 | llm_int8_threshold=6.0, 81 | llm_int8_has_fp16_weight=False, 82 | bnb_4bit_compute_dtype=torch.float16, 83 | bnb_4bit_use_double_quant=True, 84 | bnb_4bit_quant_type='nf4')), 85 | lora=dict( 86 | type=LoraConfig, 87 | r=64, 88 | lora_alpha=16, 89 | lora_dropout=0.1, 90 | bias='none', 91 | task_type='CAUSAL_LM')) 92 | 93 | ####################################################################### 94 | # PART 3 Dataset & Dataloader # 95 | ####################################################################### 96 | train_dataset = dict( 97 | type=process_hf_dataset, 98 | dataset=dict(type=load_dataset, path=data_path), 99 | 
tokenizer=tokenizer, 100 | max_length=max_length, 101 | dataset_map_fn=code_alpaca_map_fn, 102 | template_map_fn=dict( 103 | type=template_map_fn_factory, template=prompt_template), 104 | remove_unused_columns=True, 105 | shuffle_before_pack=True, 106 | pack_to_max_length=pack_to_max_length) 107 | 108 | train_dataloader = dict( 109 | batch_size=batch_size, 110 | num_workers=dataloader_num_workers, 111 | dataset=train_dataset, 112 | sampler=dict(type=DefaultSampler, shuffle=True), 113 | collate_fn=dict(type=default_collate_fn)) 114 | 115 | ####################################################################### 116 | # PART 4 Scheduler & Optimizer # 117 | ####################################################################### 118 | # optimizer 119 | optim_wrapper = dict( 120 | type=AmpOptimWrapper, 121 | optimizer=dict( 122 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay), 123 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False), 124 | accumulative_counts=accumulative_counts, 125 | loss_scale='dynamic', 126 | dtype='float16') 127 | 128 | # learning policy 129 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501 130 | param_scheduler = [ 131 | dict( 132 | type=LinearLR, 133 | start_factor=1e-5, 134 | by_epoch=True, 135 | begin=0, 136 | end=warmup_ratio * max_epochs, 137 | convert_to_iter_based=True), 138 | dict( 139 | type=CosineAnnealingLR, 140 | eta_min=0.0, 141 | by_epoch=True, 142 | begin=warmup_ratio * max_epochs, 143 | T_max=max_epochs, 144 | convert_to_iter_based=True) 145 | ] 146 | 147 | # train, val, test setting 148 | # train_cfg = dict(type=TrainLoop, max_epochs=max_epochs) 149 | 150 | ####################################################################### 151 | # PART 5 Runtime # 152 | ####################################################################### 153 | # Log the dialogue periodically during the training process, optional 154 | custom_hooks = [ 155 | dict(type=DatasetInfoHook, tokenizer=tokenizer), 156 | dict( 157 | type=EvaluateChatHook, 158 | tokenizer=tokenizer, 159 | every_n_iters=evaluation_freq, 160 | evaluation_inputs=evaluation_inputs, 161 | system=SYSTEM, 162 | prompt_template=prompt_template) 163 | ] 164 | 165 | # configure default hooks 166 | default_hooks = dict( 167 | # record the time of every iteration. 168 | timer=dict(type=IterTimerHook), 169 | # print log every 10 iterations. 170 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10), 171 | # enable the parameter scheduler. 172 | param_scheduler=dict(type=ParamSchedulerHook), 173 | # save checkpoint per `save_steps`. 174 | checkpoint=dict( 175 | type=CheckpointHook, 176 | by_epoch=False, 177 | interval=save_steps, 178 | max_keep_ckpts=save_total_limit), 179 | # set sampler seed in distributed evrionment. 
180 | sampler_seed=dict(type=DistSamplerSeedHook), 181 | ) 182 | 183 | # configure environment 184 | env_cfg = dict( 185 | # whether to enable cudnn benchmark 186 | cudnn_benchmark=False, 187 | # set multi process parameters 188 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 189 | # set distributed parameters 190 | dist_cfg=dict(backend='nccl'), 191 | ) 192 | 193 | # set visualizer 194 | visualizer = None 195 | 196 | # set log level 197 | log_level = 'INFO' 198 | 199 | # load from which checkpoint 200 | load_from = None 201 | 202 | # whether to resume training from the loaded checkpoint 203 | resume = False 204 | 205 | # Defaults to use random seed and disable `deterministic` 206 | randomness = dict(seed=None, deterministic=False) 207 | 208 | # set log processor 209 | log_processor = dict(by_epoch=False) -------------------------------------------------------------------------------- /xtuner_run/kill_xtuner.sh: -------------------------------------------------------------------------------- 1 | 2 | pid=`ps -ef | grep -E "xtuner train.*" | grep -v "grep" | awk '{print $2}'` 3 | for id in $pid 4 | do 5 | kill -9 $id 6 | echo "killed $id" 7 | done; 8 | 9 | pid=`ps -ef | grep -E "python.*?train.py.*" | grep -v "grep" | awk '{print $2}'` 10 | for id in $pid 11 | do 12 | kill -9 $id 13 | echo "killed $id" 14 | done; -------------------------------------------------------------------------------- /xtuner_run/shell_train.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-30 3 | # Author: Scc_hy 4 | # Func: 用shell 启动xtuner 5 | # =========================================================================================== 6 | 7 | from .train_utils import prepareConfig, prepareUtil, stop_thread 8 | import threading 9 | import os 10 | import gradio as gr 11 | CUR_DIR = os.path.dirname(__file__) 12 | 13 | 14 | class quickTrain: 15 | def __init__(self, 16 | work_dir, 17 | config_py_path, 18 | deepspeed_seed=None, 19 | resume_from_checkpoint=None, 20 | run_type='mmengine'): 21 | self.work_dir = work_dir 22 | self.config_py_path = config_py_path 23 | self.resume_from_checkpoint = resume_from_checkpoint 24 | self.run_type = run_type 25 | self.deepspeed_seed = deepspeed_seed 26 | self._t_handle_tr = None 27 | self.log_file = os.path.join(self.work_dir, '__xtuner_tr.log') 28 | self.remove_log_file() 29 | print(f'config_py_path={config_py_path}') 30 | 31 | def reset_resume_from_checkpoint(self, ckpt): 32 | self.resume_from_checkpoint = f'{self.work_dir}/{ckpt}' 33 | print(f"reset_resume_from_checkpoint({self.resume_from_checkpoint})") 34 | 35 | def reset_deepspeed(self, deepspeed): 36 | print(f"reset_deepspeed({deepspeed})") 37 | self.deepspeed_seed = deepspeed 38 | 39 | def reset_work_dir(self, local_path): 40 | print(f"reset_work_dir({local_path})") 41 | self.work_dir = os.path.join(local_path, 'work_dir') 42 | if not os.path.exists(self.work_dir): 43 | os.system(f'mkdir -p {self.work_dir}') 44 | self.log_file = os.path.join(self.work_dir, '__xtuner_tr.log') 45 | self.remove_log_file() 46 | 47 | def reset_cfg_py(self, cfg_py): 48 | print(f"reset_cfg_py({cfg_py})") 49 | self.config_py_path = cfg_py 50 | 51 | def remove_log_file(self): 52 | if os.path.exists(self.log_file): 53 | os.system(f'rm -rf {self.log_file}') 54 | 55 | def _quick_train(self, resume, progress=gr.Progress(track_tqdm=True)): 56 | self.remove_log_file() 57 | add_ = resume_ = '' 58 | if str(self.deepspeed_seed).lower() not in ['none', 'dropdown']: 59 | 
add_ = f'--deepspeed deepspeed_{self.deepspeed_seed} ' 60 | 61 | if self.resume_from_checkpoint is not None: 62 | resume_ = f'--resume {self.resume_from_checkpoint}' 63 | 64 | if resume: 65 | exec_ = f'xtuner train {self.config_py_path} --work-dir {self.work_dir} {add_} {resume_} > {self.log_file} 2>&1' 66 | else: 67 | exec_ = f'xtuner train {self.config_py_path} --work-dir {self.work_dir} {add_} > {self.log_file} 2>&1' 68 | 69 | print(f'exec={exec_}') 70 | os.system(exec_) 71 | 72 | def _t_start(self, resume=0): 73 | self._t_handle_tr = threading.Thread(target=self._quick_train, args=(resume,), name=f'X-train-{self.run_type}', daemon=True) 74 | self._t_handle_tr.start() 75 | 76 | def quick_train(self, progress=gr.Progress(track_tqdm=True)): 77 | self._break_flag = False 78 | self._t_start(0) 79 | self._t_handle_tr.join() 80 | if self._break_flag: 81 | return f"Done! Xtuner had interrupted!\nwork_dir={self.work_dir}", self.work_dir 82 | return "Success", self.work_dir 83 | 84 | def resume_train(self, progress=gr.Progress(track_tqdm=True)): 85 | self._break_flag = False 86 | self._t_start(1) 87 | self._t_handle_tr.join() 88 | if self._break_flag: 89 | return f"Done! Xtuner had interrupted!\nRESUME work_dir={self.work_dir}", self.work_dir 90 | return "Success", self.work_dir 91 | 92 | def _tail(self, n=100): 93 | line_list = [] 94 | with open(self.log_file, "rb") as f: 95 | f.seek(0, 2) 96 | while 1: 97 | if f.read(1) == b"\n": 98 | now_index = f.tell() 99 | line_list.append(f.readline()) 100 | f.seek(now_index, 0) 101 | if len(line_list) >= n: 102 | return line_list[::-1] 103 | if f.tell() <= 1: 104 | f.seek(0, 0) 105 | line_list.append(f.readline()) 106 | return line_list[::-1] 107 | f.seek(-2, 1) 108 | 109 | def read_log(self): 110 | if self._t_handle_tr is None: 111 | return "" 112 | if os.path.exists(self.log_file): 113 | # with open(self.log_file, 'r') as f: 114 | # res_ = f.readlines() 115 | # return ''.join(res_) 116 | line_list = self._tail(5) 117 | return b"".join(line_list).decode() 118 | 119 | def break_train(self): 120 | # 然后杀死该线程 121 | # 删除文件 122 | if self._t_handle_tr is not None: 123 | print('>>>>>>>>>>>>>>>>> break_download') 124 | stop_thread(self._t_handle_tr) 125 | os.system(f'sh {CUR_DIR}/kill_xtuner.sh') 126 | self._t_handle_tr = None 127 | 128 | self._break_flag = True 129 | return f"Done! 
Xtuner had interrupted!\nwork_dir={self.work_dir}", self.work_dir 130 | 131 | -------------------------------------------------------------------------------- /xtuner_run/todo_list.md: -------------------------------------------------------------------------------- 1 | 2 | - [X] config prepare 3 | - [X] test 4 | - [X] model sft 5 | - [X] code 6 | - [X] app add func 7 | - [ ] break & continue 8 | - [ ] log print 9 | 10 | -------------------------------------------------------------------------------- /xtuner_run/train.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-29 3 | # Author: Scc_hy 4 | # Func: 基于指定参数进行模型训练 5 | # nohup xtuner train ./internlm_chat_7b_qlora_oasst1_e3_copy.py > __xtuner.log & 6 | # pip install -U xtuner 7 | # =========================================================================================== 8 | 9 | from .train_utils import prepareConfig, prepareUtil, stop_thread 10 | import threading 11 | import os 12 | import gradio as gr 13 | from transformers import Trainer 14 | from xtuner.dataset.collate_fns import default_collate_fn 15 | from functools import partial 16 | from xtuner.dataset import process_hf_dataset 17 | from datasets import load_dataset 18 | from xtuner.dataset.map_fns import template_map_fn_factory 19 | from transformers import TrainingArguments 20 | import torch 21 | from peft import LoraConfig 22 | from transformers import (AutoModelForCausalLM, AutoTokenizer, 23 | BitsAndBytesConfig) 24 | from xtuner.model import SupervisedFinetune 25 | from xtuner.apis.datasets import alpaca_data_collator, alpaca_dataset 26 | from mmengine.runner import Runner 27 | import time 28 | 29 | def safe_load_data(file): 30 | tp = file.split('.')[-1] 31 | if tp == 'csv': 32 | return load_dataset('csv', data_files=dict(train=file)) 33 | if 'json' in tp: 34 | return load_dataset('json', data_files=dict(train=file)) 35 | 36 | # py 37 | return load_dataset(file, split='train') 38 | 39 | 40 | def saft_build_model(model_name_or_path, 41 | quantization_config=None, 42 | lora_config=None, 43 | return_tokenizer=True, 44 | qlora_flag=True): 45 | if quantization_config is None: 46 | quantization_config = BitsAndBytesConfig( 47 | load_in_4bit=True, 48 | load_in_8bit=False, 49 | llm_int8_threshold=6.0, 50 | llm_int8_has_fp16_weight=False, 51 | bnb_4bit_compute_dtype=torch.float16, 52 | bnb_4bit_use_double_quant=True, 53 | bnb_4bit_quant_type='nf4') 54 | if lora_config is None: 55 | lora_config = LoraConfig( 56 | r=64, 57 | lora_alpha=16, 58 | lora_dropout=0.1, 59 | bias='none', 60 | task_type='CAUSAL_LM') 61 | 62 | llm = AutoModelForCausalLM.from_pretrained( 63 | model_name_or_path, 64 | torch_dtype=torch.float16, 65 | trust_remote_code=True, 66 | quantization_config=quantization_config if qlora_flag else None) 67 | 68 | try: 69 | model = SupervisedFinetune(llm, lora=lora_config) 70 | except Exception as e: 71 | model = SupervisedFinetune(llm, lora=lora_config, use_activation_checkpointing=False) 72 | 73 | if return_tokenizer: 74 | tokenizer = AutoTokenizer.from_pretrained( 75 | model_name_or_path, 76 | trust_remote_code=True, 77 | encode_special_tokens=True) 78 | return model.llm, tokenizer 79 | else: 80 | return model.llm 81 | 82 | 83 | 84 | def mm_run( 85 | model_name_or_path, 86 | dataset_name_or_path, 87 | work_dir, 88 | xtuner_type='qlora', 89 | resume_from_checkpoint=None, 90 | progress=gr.Progress(track_tqdm=True) 91 | ): 92 | cfg_org = prepareConfig( 93 | model_name_or_path=model_name_or_path, 94 | 
dataset_name_or_path=dataset_name_or_path 95 | ) 96 | if resume_from_checkpoint is not None: 97 | cfg_org.resume_from_checkpoint = resume_from_checkpoint 98 | pp = prepareUtil(cfg_org, work_dir=work_dir, lora_type=xtuner_type) 99 | cfg = pp.auto_prepare() 100 | if resume_from_checkpoint is not None: 101 | cfg['load_from'] = resume_from_checkpoint 102 | cfg['resume'] = True 103 | try: 104 | runner = Runner.from_cfg(cfg) 105 | # runner = Runner.from_cfg(org_cfg) 106 | runner.train() 107 | # runner.test() 108 | except Exception as e: 109 | print(f"mm_run ERROR: \n{e}") 110 | return f"mm_run ERROR: \n{e}" 111 | return pp.work_dir 112 | 113 | 114 | def hf_run( 115 | model_name_or_path, 116 | dataset_name_or_path, 117 | work_dir, 118 | xtuner_type='qlora', 119 | resume_from_checkpoint=None, 120 | progress=gr.Progress(track_tqdm=True) 121 | ): 122 | cfg = prepareConfig( 123 | model_name_or_path=model_name_or_path, 124 | dataset_name_or_path=dataset_name_or_path 125 | ) 126 | cfg.output_dir = work_dir 127 | if 'LOCAL_RANK' not in os.environ: 128 | os.environ['LOCAL_RANK'] = str(cfg.local_rank) 129 | 130 | cfg_dict = cfg.to_tr_dict() 131 | if resume_from_checkpoint is not None: 132 | cfg_dict['resume_from_checkpoint'] = resume_from_checkpoint 133 | tr_args = TrainingArguments(**cfg_dict) 134 | print('=='*35) 135 | print('tr_args=', tr_args) 136 | print('=='*35) 137 | model, tokenizer = saft_build_model( 138 | model_name_or_path=cfg.model_name_or_path, 139 | return_tokenizer=True, 140 | qlora_flag=xtuner_type == 'qlora' 141 | ) 142 | train_dataset = process_hf_dataset( 143 | dataset=safe_load_data(cfg.dataset_name_or_path), 144 | tokenizer=tokenizer, 145 | max_length=cfg.max_length, 146 | dataset_map_fn=cfg.dataset_map_fn, 147 | template_map_fn=template_map_fn_factory(template=cfg.task_prompt_template), 148 | pack_to_max_length=cfg.pack_to_max_length 149 | ) 150 | # build trainer 151 | trainer = Trainer( 152 | model=model, 153 | args=tr_args, 154 | train_dataset=train_dataset, 155 | data_collator=partial(default_collate_fn, return_hf_format=True) 156 | ) 157 | # training 158 | trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint) 159 | trainer.save_state() 160 | trainer.save_model(output_dir=tr_args.output_dir) 161 | return tr_args.output_dir 162 | 163 | 164 | 165 | class quickTrain: 166 | def __init__(self, 167 | model_name_or_path, 168 | dataset_name_or_path, 169 | work_dir, 170 | xtuner_type='qlora', 171 | resume_from_checkpoint=None, 172 | run_type='mmengine'): 173 | self.model_name_or_path = model_name_or_path 174 | self.dataset_name_or_path = dataset_name_or_path 175 | self.xtuner_type = xtuner_type 176 | self.resume_from_checkpoint = resume_from_checkpoint 177 | self.run_type = run_type 178 | self.work_dir = work_dir 179 | self._break_flag = False 180 | self._t_handle_tr = None 181 | self.log_file = None 182 | self.mm_run_res_ = None 183 | 184 | def get_log_path(self): 185 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' 
not in i]) 186 | dir_name = list_[-2] if 'last_' in list_[-1] else list_[-1] 187 | self.log_file = os.path.join(self.work_dir, dir_name, f'{dir_name}.log') 188 | 189 | def set_model_path(self, model_path): 190 | print(f'set_model_path({model_path})') 191 | self.model_name_or_path = model_path 192 | 193 | def set_data_path(self, data_path): 194 | print(f'set_data_path({data_path})') 195 | self.dataset_name_or_path = data_path 196 | 197 | def set_xtuner_type(self, xtuner_type): 198 | print(f'set_xtuner_type({xtuner_type})') 199 | self.xtuner_type = xtuner_type 200 | 201 | def set_work_dir(self, work_dir): 202 | print(f'set_work_dir({work_dir})') 203 | self.work_dir = f'{work_dir}/work_dir' 204 | if not os.path.exists(self.work_dir): 205 | os.system(f'mkdir -p {self.work_dir}') 206 | 207 | def set_resume_from_checkpoint(self, work_dir=None): 208 | self.resume_from_checkpoint = work_dir 209 | 210 | def _t_start(self): 211 | self._t_handle_tr = threading.Thread(target=self._quick_train, name=f'X-train-{self.run_type}', daemon=True) 212 | self._t_handle_tr.start() 213 | 214 | def _quick_train(self, progress=gr.Progress(track_tqdm=True)): 215 | print( 216 | f'self.model_name_or_path={self.model_name_or_path}\nself.dataset_name_or_path={self.dataset_name_or_path}\nself.work_dir={self.work_dir}\nself.xtuner_type={self.xtuner_type}' 217 | ) 218 | if self.run_type == 'mmengine': 219 | self.mm_run_res_ = None 220 | self.mm_run_res_ = mm_run(self.model_name_or_path, self.dataset_name_or_path, self.work_dir, self.xtuner_type, self.resume_from_checkpoint) 221 | return self.mm_run_res_ 222 | return hf_run(self.model_name_or_path, self.dataset_name_or_path, self.work_dir, self.xtuner_type, self.resume_from_checkpoint) 223 | 224 | def read_log(self): 225 | if self.log_file is None: 226 | return "" 227 | if not os.path.exists(self.log_file): 228 | return "" 229 | with open(self.log_file , 'r') as f: 230 | read_res = f.readlines() 231 | read_res = ''.join(read_res) # [-20:] 232 | if self.mm_run_res_ is not None: 233 | return f'{read_res}\n{self.mm_run_res_}' 234 | return read_res 235 | 236 | def start_log(self): 237 | time.sleep(10) 238 | return "Start Training" 239 | 240 | def quick_train(self, progress=gr.Progress(track_tqdm=True)): 241 | self.log_file = None 242 | self._break_flag = False 243 | self._t_start() 244 | time.sleep(5) 245 | self.get_log_path() 246 | self._t_handle_tr.join() 247 | if self._break_flag: 248 | return "Done! Xtuner had interrupted!" 249 | return self.work_dir 250 | 251 | def break_train(self): 252 | # 然后杀死该线程 253 | # 删除文件 254 | if self._t_handle_tr is not None: 255 | print('>>>>>>>>>>>>>>>>> break_train') 256 | stop_thread(self._t_handle_tr) 257 | self._t_handle_tr = None 258 | 259 | self._break_flag = True 260 | return "Done! Xtuner had interrupted!"
261 | 262 | 263 | 264 | def main_test(): 265 | model_ = '/root/share/model_repos/internlm-chat-7b' 266 | model_2 = '/root/share/model_repos/internlm2-chat-7b' 267 | TR_ = quickTrain( 268 | model_name_or_path=model_, 269 | dataset_name_or_path='/root/ft-medqa/MedQA2019-structured-train.jsonl', 270 | work_dir='./work_dir', 271 | xtuner_type='qlora', 272 | run_type='mmengine' 273 | ) 274 | TR_.quick_train() 275 | 276 | 277 | if __name__ == '__main__': 278 | main_test() 279 | 280 | 281 | 282 | 283 | -------------------------------------------------------------------------------- /xtuner_run/train_utils.py: -------------------------------------------------------------------------------- 1 | # python3 2 | # Create Date: 2024-01-29 3 | # Author: Scc_hy 4 | # Func: 参数准备 5 | # =========================================================================================== 6 | 7 | import os 8 | import transformers 9 | from dataclasses import dataclass, field 10 | from typing import List, Dict, ClassVar, Tuple, AnyStr, Callable 11 | import warnings 12 | import os 13 | import re 14 | import inspect 15 | import ctypes 16 | import inspect 17 | import ctypes 18 | warnings.filterwarnings(action='ignore') 19 | 20 | 21 | def _async_raise(tid, exctype): 22 | """Raises an exception in the threads with id tid""" 23 | if not inspect.isclass(exctype): 24 | raise TypeError("Only types can be raised (not instances)") 25 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype)) 26 | if res == 0: 27 | raise ValueError("invalid thread id") 28 | elif res != 1: 29 | # """if it returns a number greater than one, you're in trouble, 30 | # and you should call it again with exc=NULL to revert the effect""" 31 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None) 32 | raise SystemError("PyThreadState_SetAsyncExc failed") 33 | 34 | 35 | def stop_thread(thread): 36 | try: 37 | _async_raise(thread.ident, SystemExit) 38 | except Exception as e: 39 | print(e) 40 | 41 | 42 | @dataclass 43 | class prepareConfig: 44 | model_name_or_path: AnyStr 45 | dataset_name_or_path: AnyStr 46 | do_train: bool = True 47 | save_strategy: AnyStr = 'epoch' 48 | lr_scheduler_type: AnyStr = 'cosine' 49 | logging_steps: int = 5 50 | framework: AnyStr = 'huggingface' 51 | output_dir: AnyStr = './work_dir' 52 | deepspeed: bool = None 53 | local_rank: int = -1 54 | seed: int = 42 55 | max_length: int = 2048 56 | pack_to_max_length: bool = True 57 | batch_size: int = 1 58 | per_device_train_batch_size: int = 1 # per_device 59 | per_device_eval_batch_size: int = 1 # per_device 60 | # accumulative_counts: int = 16 61 | gradient_accumulation_steps: int = 16 62 | dataloader_num_workers: int = 0 63 | max_epochs: int = 3 64 | num_train_epochs: float = 3.0 # TrainingArguments 65 | optim: AnyStr = "adamw_torch" 66 | # optim_args 67 | optim_type: AnyStr = 'bitsandbytes.optim.PagedAdamW32bit' 68 | learning_rate: float = 2e-4 69 | betas: Tuple[float, float] = (0.9, 0.999) 70 | adam_beta1: float = 0.9 71 | adam_beta2: float = 0.999 72 | weight_decay: float = 0 73 | max_norm: float = 1 74 | max_grad_norm: float = 1 75 | warmup_ratio: float = 0.03 76 | task_prompt_template: AnyStr = 'xtuner.utils.PROMPT_TEMPLATE.internlm_chat' 77 | # Save 78 | save_steps: int = 500 79 | save_total_limit: int = 2 # Maximum checkpoints to keep (-1 means unlimited) 80 | # Evaluate the generation performance during the training 81 | evaluation_freq: int = 500 82 | system: AnyStr = "" # SYSTEM_TEMPLATE.coder 83 | # prompt trick 84 | evaluation_inputs: 
List = field(default_factory=lambda: [ 85 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai' 86 | ]) 87 | # set visualizer 88 | visualizer = None 89 | # set log level 90 | log_level: AnyStr = 'info' 91 | # load from which checkpoint 92 | resume_from_checkpoint: AnyStr = None 93 | # whether to resume training from the loaded checkpoint 94 | resume: bool = False 95 | # Defaults to use random seed and disable `deterministic` 96 | randomness: Dict = field(default_factory=lambda: dict(seed=None, deterministic=False)) 97 | trust_remote_code: bool = True 98 | env_cfg: Dict = field(default_factory=lambda: dict( 99 | # whether to enable cudnn benchmark 100 | cudnn_benchmark=False, 101 | # set multi process parameters 102 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0), 103 | # set distributed parameters 104 | dist_cfg=dict(backend='nccl'), 105 | )) 106 | # Callable path 107 | dataset_map_fn:AnyStr = None 108 | 109 | def get_device_map(self): 110 | if self.deepspeed: 111 | self.device_map = None 112 | else: 113 | self.device_map = { 114 | '': int(os.environ.get('LOCAL_RANK', self.local_rank)) 115 | } 116 | 117 | def to_tr_dict(self): 118 | self_dict = self.__dict__ 119 | focus_key = [ 120 | 'output_dir', 121 | 'deepspeed', 122 | 'local_rank', 123 | 'seed', 124 | 'per_device_train_batch_size', 125 | 'per_device_eval_batch_size', 126 | 'dataloader_num_workers', 127 | 'gradient_accumulation_steps', 128 | 'num_train_epochs', 129 | 'optim', 130 | 'weight_decay', 131 | 'adam_beta1', 132 | 'adam_beta2', 133 | 'max_grad_norm', 134 | 'warmup_ratio', 135 | 'save_steps', 136 | 'save_total_limit', 137 | 'log_level', 138 | 'resume_from_checkpoint', 139 | 'deepspeed', 140 | 'do_train', 141 | 'save_strategy' , 142 | 'lr_scheduler_type', 143 | 'logging_steps', 144 | ] 145 | res = {} 146 | for k in focus_key: 147 | res[k] = self_dict[k] 148 | return res 149 | 150 | 151 | class prepareUtil: 152 | def __init__(self, cfg: prepareConfig, work_dir: AnyStr, lora_type: str='qlora'): 153 | self.cfg = cfg 154 | self.cfg.output_dir = work_dir 155 | self.work_dir = work_dir 156 | self.qlora_flag = lora_type == 'qlora' 157 | self._model, self._tokenizer = self.prepare_model_tokenizer() 158 | 159 | def auto_prepare(self): 160 | train_dataset, train_dataloader = self.prepare_data() 161 | optim_wrapper, param_scheduler = self.prepare_scheduler_optimizer() 162 | custom_hooks, default_hooks = self.prepare_hook() 163 | return dict( 164 | model=self._model, 165 | work_dir=self.cfg.output_dir, 166 | train_dataloader=train_dataloader, 167 | val_dataloader=None, 168 | test_dataloader=None, 169 | train_cfg=dict(by_epoch=True, max_epochs=self.cfg.max_epochs, val_interval=1), 170 | val_cfg=None, 171 | test_cfg=None, 172 | optim_wrapper=optim_wrapper, 173 | param_scheduler=param_scheduler, 174 | val_evaluator=None, 175 | test_evaluator=None, 176 | custom_hooks=custom_hooks, 177 | default_hooks=default_hooks, 178 | resume=self.cfg.resume_from_checkpoint is not None, 179 | env_cfg=self.cfg.env_cfg, 180 | visualizer=self.cfg.visualizer, 181 | log_level=self.cfg.log_level.upper(), 182 | randomness=self.cfg.randomness, 183 | launcher='none' 184 | ) 185 | 186 | def prepare_scheduler_optimizer(self): 187 | optim_wrapper = dict( 188 | type='mmengine.optim.AmpOptimWrapper', 189 | optimizer=dict( 190 | type=self.cfg.optim_type, lr=self.cfg.learning_rate, betas=self.cfg.betas, weight_decay=self.cfg.weight_decay), 191 | clip_grad=dict(max_norm=self.cfg.max_norm, error_if_nonfinite=False), 192 | 
accumulative_counts=self.cfg.gradient_accumulation_steps, 193 | loss_scale='dynamic', 194 | dtype='float16' 195 | ) 196 | param_scheduler = dict( 197 | type='mmengine.optim.CosineAnnealingLR', 198 | eta_min=0.0, 199 | by_epoch=True, 200 | begin=self.cfg.warmup_ratio * self.cfg.max_epochs, 201 | T_max=self.cfg.max_epochs, 202 | convert_to_iter_based=True 203 | ) 204 | return optim_wrapper, param_scheduler 205 | 206 | def prepare_model_tokenizer(self): 207 | tokenizer = dict( 208 | type='transformers.AutoTokenizer.from_pretrained', 209 | pretrained_model_name_or_path=self.cfg.model_name_or_path, 210 | trust_remote_code=True, 211 | padding_side='right' 212 | ) 213 | model = dict( 214 | type='xtuner.model.SupervisedFinetune', 215 | llm=dict( 216 | type='transformers.AutoModelForCausalLM.from_pretrained', 217 | pretrained_model_name_or_path=self.cfg.model_name_or_path, 218 | trust_remote_code=True, 219 | torch_dtype='torch.float16', 220 | quantization_config=dict( 221 | type='transformers.BitsAndBytesConfig', 222 | load_in_4bit=True, 223 | load_in_8bit=False, 224 | llm_int8_threshold=6.0, 225 | llm_int8_has_fp16_weight=False, 226 | bnb_4bit_compute_dtype='torch.float16', 227 | bnb_4bit_use_double_quant=True, 228 | bnb_4bit_quant_type='nf4') if self.qlora_flag else None 229 | ), 230 | lora=dict( 231 | type='peft.LoraConfig', 232 | r=64, 233 | lora_alpha=16, 234 | lora_dropout=0.1, 235 | bias='none', 236 | task_type='CAUSAL_LM')) 237 | return model, tokenizer 238 | 239 | def prepare_hook(self): 240 | custom_hooks = [ 241 | dict(type='xtuner.engine.DatasetInfoHook', tokenizer=self._tokenizer), 242 | dict( 243 | type='xtuner.engine.EvaluateChatHook', 244 | tokenizer=self._tokenizer , 245 | every_n_iters=self.cfg.evaluation_freq, 246 | evaluation_inputs=self.cfg.evaluation_inputs, 247 | system=self.cfg.system, 248 | prompt_template=self.cfg.task_prompt_template) 249 | ] 250 | # configure default hooks 251 | default_hooks = dict( 252 | checkpoint=dict(interval=1, type='mmengine.hooks.CheckpointHook'), 253 | logger=dict(interval=10, type='mmengine.hooks.LoggerHook'), 254 | param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'), 255 | sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'), 256 | timer=dict(type='mmengine.hooks.IterTimerHook') 257 | ) 258 | return custom_hooks, default_hooks 259 | 260 | def prepare_data(self): 261 | train_dataset = dict( 262 | type='xtuner.dataset.process_hf_dataset', 263 | dataset=self.safe_load_dataset(self.cfg.dataset_name_or_path), 264 | tokenizer=self._tokenizer, 265 | max_length=self.cfg.max_length, 266 | dataset_map_fn=self.cfg.dataset_map_fn, 267 | template_map_fn=dict( 268 | type='xtuner.dataset.map_fns.template_map_fn_factory', 269 | template=self.cfg.task_prompt_template), 270 | remove_unused_columns=True, 271 | shuffle_before_pack=True, 272 | pack_to_max_length=self.cfg.pack_to_max_length) 273 | 274 | train_dataloader = dict( 275 | batch_size=self.cfg.batch_size, 276 | num_workers=self.cfg.dataloader_num_workers, 277 | dataset=train_dataset, 278 | sampler=dict(type='mmengine.dataset.DefaultSampler', shuffle=True), 279 | collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn')) 280 | return train_dataset, train_dataloader 281 | 282 | def safe_load_dataset(self, file): 283 | load_tp = 'datasets.load_dataset' 284 | tp = file.split('.')[-1] 285 | if tp == 'csv': 286 | return dict(type=load_tp, path='csv', data_files=dict(train=file)) 287 | if 'json' in tp: 288 | return dict(type=load_tp, path='json', data_files=dict(train=file)) 289 | 
# otherwise assume a dataset loading script (.py) or a dataset name 290 | return dict(type=load_tp, path=file) 291 | 292 | 293 | def main_test(): 294 | 295 | cfg = prepareConfig( 296 | model_name_or_path='/root/opencompass/InternLM/Shanghai_AI_Laboratory/internlm2-chat-7b', 297 | dataset_name_or_path='/root/ft-medqa/MedQA2019-structured-train.jsonl' 298 | ) 299 | print(f'type(cfg.evaluation_inputs)={type(cfg.evaluation_inputs)}') 300 | print(cfg.evaluation_inputs) 301 | pp = prepareUtil(cfg, work_dir=cfg.output_dir) 302 | pp_res = pp.auto_prepare() 303 | # pp_res = cfg.to_tr_dict() 304 | print('--'*25) 305 | print(pp_res['custom_hooks']) 306 | print(pp_res['custom_hooks'][0]['type']) 307 | # print('=='*35) 308 | 309 | 310 | if __name__ == '__main__': 311 | main_test() # checked 312 | 313 | 314 | --------------------------------------------------------------------------------
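The dictionary returned by `prepareUtil.auto_prepare()` uses the same keys that `mmengine`'s `Runner.from_cfg` expects (`model`, `train_dataloader`, `optim_wrapper`, `param_scheduler`, hooks, and so on). Below is a minimal sketch of how that output could be consumed; the model and dataset paths are placeholders, and this wiring is an assumption for illustration rather than the actual `mm_run` implementation in `xtuner_run/train.py`.

```python
# Hedged sketch: hand prepareUtil's output to an mmengine Runner.
# Paths are placeholders; this is an assumed usage pattern, not the
# repository's actual mm_run implementation.
from mmengine.config import Config
from mmengine.runner import Runner

from xtuner_run.train_utils import prepareConfig, prepareUtil

cfg = prepareConfig(
    model_name_or_path='/path/to/internlm2-chat-7b',    # placeholder
    dataset_name_or_path='/path/to/train-data.jsonl',   # placeholder
)
prep = prepareUtil(cfg, work_dir='./work_dir', lora_type='qlora')
runner_cfg = Config(prep.auto_prepare())   # plain dict -> mmengine Config
runner = Runner.from_cfg(runner_cfg)       # builds model, dataloader, hooks lazily
runner.train()
```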