
# 1. Project Background

XTuner is an efficient, flexible, and full-featured lightweight toolkit for fine-tuning large models, developed by the InternLM team. It supports efficient fine-tuning of a wide range of large models, including the InternLM language models and the multimodal image-text model LLaVA. Besides rich support for models, datasets, data pipelines, and algorithms, XTuner ships ready-made configuration files and quick-start guides, so users can fine-tune and deploy models with little friction. Overall, XTuner provides an efficient, comprehensive, and user-friendly fine-tuning solution for developers and researchers who care about performance and customization.

Although XTuner already simplifies a large part of the fine-tuning workflow, it still poses a technical barrier for complete beginners. Thanks to the large-model hands-on camp run by the InternLM team, our group had the chance to work with the XTuner development team: taking [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) as a reference and adapting it to XTuner's characteristics, we built a Gradio-based visual interface for XTuner.
The project aims to give beginners with a weaker background a convenient fine-tuning workflow, so they can try fine-tuning a model with a few clicks. The interface shows training information and results in real time and lets users compare the fine-tuned model against the original one. In addition to the officially supported models and datasets, advanced users can upload their own models and datasets for fine-tuning. This custom-model capability helps beginners keep learning and exploring on top of a model that has already been fine-tuned, and makes the interface considerably more practical and flexible.

# 2. Project Members
The XTuner GUI project is supported by the XTuner team, so besides the four members from the InternLM hands-on camp it also includes two XTuner developers, cwh and pppppM. Each member's contribution is listed below. Thanks to everyone for more than a month of hard work, and to the InternLM (书生.浦语) team for the course and the compute support! We believe AI Lab will keep growing into one of China's leading open-source communities!


## 2.1 XTuner GUI team members
- Jianfeng777 – overall front-end development, task planning, and documentation
- Scchy – overall back-end development and planning
- L241025097 – terminal visualization of model training
- Semple030228 – model conversion and merging

## 2.2 XTuner developers
- HIT-cwh – mmengine-related setup, configuration-file generation, and model/dataset checks
- pppppM – expert guidance on XTuner

# 3. Quick Start (Linux only)
First, create a new virtual environment and clone the repository from GitHub.

```bash
conda create -n XtunerGUI python=3.10 -y
conda activate XtunerGUI
git clone https://github.com/scchy/XtunerGUI.git
```

Then enter the repository and install the packages required to run XTunerGUI (if installation is slow, you can use the Tsinghua mirror; see the example after the block below).

```bash
cd XtunerGUI
pip install -r requirements.txt
```
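
If the default PyPI index is slow, pip can be pointed at the Tsinghua mirror instead, for example:

```bash
pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
```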

After the installation finishes, launch `app.py` to open the interface.

```bash
python app.py
```

# 4. UI Overview
The page is divided into six parts, covering all the basic steps of fine-tuning a large language model (see the figure below). The sections that follow walk through the layout and how to use each part. The full page can also be browsed on OpenXLab ([link](https://openxlab.org.cn/apps/detail/Scchy/XtunerFactory)).

## 4.1 Local path setup


Step one is to enter two local paths: one where all output files are saved, and one where downloaded models and datasets are saved. The file save path (customer_path) will hold the configuration file, all training artifacts (checkpoints, training logs, etc.), and the converted model. The model/dataset path (download_cache) will hold the models and datasets downloaded through the UI. After changing either path, remember to click the confirm-path button!
## 4.2 Fine-tuning method, model, and dataset setup


Step two is to choose the fine-tuning method. QLoRA, LoRA, and full fine-tuning (full) are currently supported; pick one according to your hardware and the task at hand. A large number of official datasets and models are supported as well: click the download button below and they are fetched automatically from ModelScope, HuggingFace, or OpenXLab. If the wrong model or dataset was selected, the download can be cancelled; cancelling deletes what has already been downloaded and frees up disk space.


Custom models and datasets are also supported.
- For a custom model, provide its local path and click the button to check that the model is usable and to match it to a prompt template (official models downloaded through the UI are matched automatically). Different models use different prompt templates; more details are shown in the UI.
- For a custom dataset, only the OpenAI format (the most common dataset format) is currently supported. The UI shows what the OpenAI format looks like (see the example below), and you can use an LLM such as ChatGPT to convert your own data into it. Once converted, either enter the local dataset path or upload the file through the Gradio file box, then click the button to check that the dataset is valid.
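
For reference, this is the OpenAI-style format the interface expects; it is the same sample that the UI displays (the misspelling in the assistant replies is intentional in the sample):

```json
[
  {
    "messages": [
      { "role": "system", "content": "You are an assistant that occasionally misspells words." },
      { "role": "user", "content": "Tell me a story." },
      { "role": "assistant", "content": "One day a student went to schoool." }
    ]
  },
  {
    "messages": [
      { "role": "user", "content": "Tell me a story." },
      { "role": "assistant", "content": "One day a student went to schoool." }
    ]
  }
]
```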
## 4.3 Fine-tuning parameter setup
Step three is parameter setup. All parameters are split into a basic group and an advanced group, and several default parameter templates are provided depending on the fine-tuning method, so in most cases the defaults can be used directly.


The basic parameters are the commonly used ones: learning rate, warmup ratio, maximum dataset length, number of GPUs, per-device batch size, evaluation questions, and so on. Notably, multiple evaluation questions can be defined. The defaults are the Chinese and English versions of "请给我介绍五个上海景点" ("Please tell me five scenic spots in Shanghai"), but they can be replaced with whatever questions you need, and the number of questions can be increased (up to ten).


The advanced parameters are ones that are rarely used or changed, such as the optimizer type, weight decay, and gradient clipping. They do affect the result, but tuning them well is difficult, so unless you have a clear reason we do not recommend changing them.
Once the parameters are set, click the button to generate the configuration file. The configuration file is the soul of model training: the training process and its direction follow what the config specifies, and here it is built from everything set above. Note that if you provided a custom model/dataset and also downloaded one through the GUI, the custom one takes precedence as the model/dataset path in the config, so keep this in mind.
## 4.4 Model training


Once the configuration file is ready, click the button to start training. Training can also be paused with the interrupt button, and an interrupted run can be continued by selecting a previously saved checkpoint and clicking the resume button; interrupting and resuming does not affect the final result.
After training starts, the terminal panel below can be opened to watch the training output, which makes it much easier to monitor the whole run. If training is clearly going badly, it can be interrupted early instead of wasting time.
## 4.5 Training results


After fine-tuning finishes, click the button below to generate the key results: the loss curve, the learning-rate curve over training, and the answers to the test questions under the different checkpoints. This shows how the model changed during training, and comparing the test answers makes it easy to spot overfitting and to pick the best checkpoint for testing and deployment.
## 4.6 Model conversion and testing


After the best checkpoint has been identified from the results view, the model still needs to be converted to the common HuggingFace format, and models fine-tuned with LoRA or QLoRA additionally need to be merged with the base model. These two steps are combined here and adjusted automatically according to the fine-tuning method chosen in step two: simply pick the corresponding checkpoint from the drop-down list and click the conversion button.


After conversion, a chat test can be run. The left side shows the original base model and the right side shows the fine-tuned model. Choose suitable inference parameters and click to start the models, then chat with both; comparing the original and fine-tuned models side by side shows the real effect of fine-tuning.

We also recorded a short demo video ([Bilibili](https://www.bilibili.com/video/BV1av42117yT/?spm_id_from=333.999.0.0)) for a closer look.

That is the basic tour of the interface. If you just want to use it, you can start right away! If you would like a deeper understanding of the design and the ideas behind it, keep reading.
The next part covers how XTuner itself works underneath XTuner GUI, which in turn explains how XTuner GUI is implemented. Knowing not only how but also why things work, via XTuner's overall structure and commands, makes the workings of the XTuner GUI project much easier to understand.

# 5. The XTuner Workflow
XTuner's basic workflow can be understood from the figure below; a high-resolution version is available [here](https://www.figma.com/file/0SVTWhnGxbY7ADy2UEluCR/XTuner-Flow?type=whiteboard&node-id=0%3A1&t=bzZP6fCSAuBj2uon-1).



As the figure shows, the whole workflow consists of the following four steps (the commands used in each step are shown in the figure below):


## 5.1 Data collection and format conversion


- First, depending on the task, define the fine-tuning goal, collect data, and convert it into a format XTuner supports. This part is up to you; if you only want to try things out, the officially supported datasets are enough.
- Then choose a fine-tuning method and a base model that fit your hardware. Different base models need different amounts of GPU memory: the larger the model, the more memory fine-tuning requires. Among the methods, QLoRA needs the least memory (it can run with as little as 8 GB), while full fine-tuning needs the most.
## 5.2 Creating the configuration file


- First, run `xtuner list-cfg` to list all built-in configuration files.
- Based on the fine-tuning method and base model chosen above, pick a suitable config and copy it locally with `xtuner copy-cfg ${CONFIG_NAME} ${SAVE_PATH}`.
- After copying, edit the config to update the model path and dataset path.
- In some cases you also need to adjust model parameters and settings, change the `load_dataset` call and the `dataset_map_fn`, and pick the `prompt_template` that matches the model. The commands for this step are summarized below.
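
As a quick reference, the two commands from this step look like this (`${CONFIG_NAME}` and `${SAVE_PATH}` are placeholders to fill in):

```bash
# List all built-in XTuner configuration files
xtuner list-cfg
# Copy the chosen config to a local directory so it can be edited
xtuner copy-cfg ${CONFIG_NAME} ${SAVE_PATH}
```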
## 5.3 Model training


- With the config edited, start training with `xtuner train`.
- Additional options can be set to optimize training, such as enabling DeepSpeed and choosing where the training files are saved.
- If training is interrupted unexpectedly, it can be resumed by adding `--resume {checkpoint_path}`. See the command sketch below.
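
A rough sketch of the training commands; the `--work-dir` and `--deepspeed` flag spellings follow XTuner's CLI as we understand it and may differ slightly between versions:

```bash
# Start training from the edited config and save outputs to a chosen directory
xtuner train ${CONFIG_FILE} --work-dir ${WORK_DIR}
# Optionally enable DeepSpeed (e.g. ZeRO-2) to reduce GPU memory usage
xtuner train ${CONFIG_FILE} --work-dir ${WORK_DIR} --deepspeed deepspeed_zero2
# Resume an interrupted run from a saved checkpoint
xtuner train ${CONFIG_FILE} --work-dir ${WORK_DIR} --resume ${CHECKPOINT_PATH}
```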
## 5.4 Model conversion, testing, and deployment


- After training finishes, locate the trained checkpoint and run `xtuner convert pth_to_hf` to convert the model to the `huggingface` format.
- For LoRA-style models, run `xtuner convert merge` to merge the `adapter` layers with the base model.
- With the converted model path, run `xtuner chat` to load the model and test its performance.
- Additionally, after installing `LMDeploy`, the model can be deployed and chatted with via `python -m lmdeploy.pytorch.chat` (LMDeploy also provides the TurboMind engine for inference). Example commands are sketched below.
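
The corresponding commands, sketched with placeholder paths (the exact positional arguments may vary with the XTuner version):

```bash
# Convert the trained PTH checkpoint to HuggingFace format
xtuner convert pth_to_hf ${CONFIG_FILE} ${PTH_PATH} ${SAVE_HF_DIR}
# For LoRA / QLoRA runs, merge the adapter with the base model
xtuner convert merge ${BASE_MODEL_PATH} ${SAVE_HF_DIR} ${SAVE_MERGED_DIR}
# Chat with the converted / merged model to check its quality
xtuner chat ${SAVE_MERGED_DIR}
# Optional deployment test with LMDeploy
python -m lmdeploy.pytorch.chat ${SAVE_MERGED_DIR}
```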

That covers XTuner's basic features and commands. Next, the overall XTuner GUI logic diagram is used to walk through our design decisions and the places where we streamlined the existing workflow.
# 6. XTuner GUI Design
XTuner GUI is split into six parts: path setup; model, dataset, and fine-tuning method setup; parameter setup; model training; training results; and model conversion and testing. The main goal of this layout is to let a beginner set the theory aside, actually get a model running and see results first, and only then dig into what each step does. When I first learned the OpenMMLab toolboxes, such as MMDetection and MMSegmentation, I also started by finding some datasets and getting something to run, and only later studied how to optimize it and how the whole pipeline works. XTuner's barrier to entry is already very low, but we want to push it even lower so that more people can get started painlessly.

To make the first run truly painless, the parts that force users to stop and think have to go. With vanilla XTuner you still need to download the model and dataset yourself, find a suitable config file, and locate the right folders for conversion; none of that is needed here. Clicking buttons, picking from drop-downs, and adjusting a few parameters is enough to get a model training, and at the end the chat test only requires typing text to talk to the fine-tuned and original models side by side. All of this is meant to minimize the effort of getting a model running and to genuinely open the door to large-model fine-tuning.

Beyond complete beginners, we also considered users with some experience. First, custom models and custom datasets are supported, which saves anyone who wants to fine-tune their own data or model the time of digging into the files. Second, a large set of adjustable parameters is exposed. For an experienced practitioner, both are clearly worthwhile.

For experts, especially those who need to train multimodal models, we added little dedicated support. Their coding and debugging skills are strong enough that vanilla XTuner or any other fine-tuning tool already serves them well, and they do not need this level of visualization. In short, our target users are complete beginners and practitioners with some experience, for whom a tool like this genuinely makes the work easier.
For XTuner GUI I also drew a logic diagram of the overall flow (high-resolution version [here](https://www.figma.com/file/wFN0wMlknYyzV3ZMCihnPC/XTuner-GUI-Flow?type=whiteboard&t=ch8xUYvYdXnoWGNX-1)).

The sections below explain the architecture step by step and describe what we changed compared with vanilla XTuner and why:

## 6.1 Path setup



First, two paths need to be provided: the file save path and the model/dataset path. The file save path (customer_path) holds the configuration file, all training artifacts (checkpoints, training logs, etc.), and the converted model. The model/dataset path (download_cache) holds the models and datasets downloaded through the UI.
This part was not in the initial design. We later realized that downloaded models and datasets can be quite large, and together with the files produced by fine-tuning they can fill up the disk, so we separated the two and let users choose where each goes, which reduces the storage pressure. Users with plenty of space can simply keep the default paths.
## 6.2 Model, dataset, and fine-tuning method setup


For any fine-tuning task, the first questions are: which model do I fine-tune, which dataset do I use, and which fine-tuning method do I apply. These three basic steps already stop many people, because they do not know where to find these things. Even with vanilla XTuner, following the quick start in the repository may get a model running without making it clear what is actually happening. Moreover, the config files shipped in the XTuner repository do not cover every combination of model, dataset, and method, so someone who cannot find a matching config is left to modify a similar one and adjust its contents, which is too much to ask of a complete beginner.
Based on this, we simplified the whole flow. A drop-down directly selects the fine-tuning method. For the datasets and models already available on HuggingFace, ModelScope, or OpenXLab, drop-downs let users pick one and download it with a button; the download is saved under the model/dataset path (download_cache) set above. This way users actually know which model they are training, which dataset they are using, and which method they chose, instead of being handed a config file name to figure out on their own.
For advanced users there are custom models and datasets: a user can upload their own model and use an official dataset, or use an official model with their own dataset; both work. In either case a check mechanism tells the user whether their model or dataset has problems and whether it can be used, which avoids bugs later that are hard to resolve.
For models there is one more important step: a matching prompt template. Users who upload their own model rarely bring their own prompt template; such uploads are usually fine-tuned from an official base model. So for custom uploads we not only check usability but also detect the corresponding prompt template, which saves users the trouble of finding the base model's template and entering it themselves.
## 6.3 Parameter setup


With the model and dataset ready, the most important inputs to the configuration file are their paths. Beyond that, the basic hyperparameters and a few training settings still need to be configured to produce a good pipeline for the actual training run.
For most people, tuning parameters is complicated and difficult, so different hyperparameter templates are provided for each fine-tuning method; these defaults were already decided when XTuner's original configs were created, just never exposed. In the UI, every hyperparameter also comes with a short explanation so that beginners can understand what it means.
Besides the default hyperparameters, a few things still have to be set per run, such as the number of GPUs and whether to enable DeepSpeed. From all of these settings a usable configuration file is generated and stored under the file save path (customer_path) set at the start, and this pipeline then drives the customized training run.
Note that the generated configuration is not the usual config file seen in XTuner but a JSON file with equivalent effect; thanks to cwh from the XTuner development team for supporting this.
## 6.4 Model training


Once training starts, the program mostly runs on its own. But training may be interrupted mid-way for some special reason, and in that case there must be a resume option so users can get the run going again; simply starting a fresh run might produce results that differ from the expected ones, which is a fairly serious problem. We therefore use the `--resume` option provided by `xtuner train` to continue training.
Originally I also wanted to support the kind of surgery common in OpenMMLab: for example, if the loss in MMDetection stops going down, I might take the last epoch's checkpoint, load it through the config, raise the learning rate, and see whether things improve. For large models this kind of tinkering is uncommon and not something beginners need, so it was left out; it can be added later if there is demand.
The terminal panel was also not in the original design, but training takes a long time, and staring at a spinning front end without knowing what is happening, how far training has progressed, or roughly what the model currently answers to your evaluation questions is painful. So we added a continuously updating terminal panel that shows the training output and status. There is also evidence that users feel less anxious when they can see a progress bar (unless it sticks at 99%), so we plan to add an iteration-based progress bar later to make the overall progress visible and reduce the anxiety of waiting.
Training produces many files, including model checkpoints and logs of the training process, all saved under the file save path (customer_path). Note that the run can be shaped through training parameters such as the number of epochs, saving a checkpoint every n iterations, and keeping only n checkpoints, which lets users decide how training proceeds and gives them more freedom.
## 6.5 Training results


After training completes, the loss curve and related plots show how the run went. The main goal here is to let users actually see whether the model they trained roughly meets their expectations, and in particular to see each checkpoint's answers to the evaluation questions they set. Being able to watch the answers evolve across checkpoints matters a lot; it is like watching a child grow from one year old to adulthood. Early in fine-tuning the answers may be unsatisfying, later they gradually mature, and at the end an overfit model may only know how to answer a single question. This progression makes it easy to identify which iteration's checkpoint is best and to pick out the genuinely good weights.
Looking only at the loss and learning rate is really not enough; they are just numbers. A very low loss makes you suspect overfitting, a high loss makes you suspect undertraining, but the real judgment of a model comes from actual chat testing. A chat test usually only covers a handful of questions anyway, so putting the key questions into the evaluation set and watching the answers change during training saves a lot of time. Other abilities, such as handling context and multi-turn replies, can be left to the final chat test; for shortlisting checkpoints, looking at the answers to the configured evaluation questions is entirely sufficient.
## 6.6 Model conversion and testing


After the best checkpoint is found, it has to be converted to HuggingFace format, since training produces PyTorch-format checkpoint files. Models trained with LoRA or QLoRA are just extra adapter layers, so they must also be merged with the base model before further use, while full fine-tuning only needs the conversion step. The merged model is saved under the file save path (customer_path). In XTuner, conversion and merging are separate commands, but since the UI already knows the task, we combine the two and decide based on the fine-tuning method chosen at the start: QLoRA or LoRA gets convert plus merge, full fine-tuning gets convert only. Users no longer have to hunt down folders and type them into a terminal; one click does it.
For the chat test, the idea is to let users feel the difference between the model before and after fine-tuning. The chat area is split in two, the original model on one side and the fine-tuned model on the other; ask one question and both models answer, so the difference is visible directly. This also helps users figure out what to improve next. If the result is unsatisfying, the converted model's path can be fed back in as a custom model for further training, or training can be redone from scratch, and the loop repeats.
# 7. Summary and Outlook
Overall, on top of the basic functionality XTuner provides, we added some features of our own, including plotting the loss and learning-rate curves, extracting and saving the answer to each evaluation question, and downloading datasets and models locally for use. We can confidently say that the interface already covers the basic fine-tuning workflow and effectively helps newcomers get started quickly and train their own models. As XTuner gains popularity, we believe XTuner GUI will attract more attention and help more junior developers pick up model fine-tuning quickly.
Going forward, the XTuner GUI project may develop in the following directions:

- **Complete XTuner GUI's functionality**: the project took just over a month of part-time work, so some issues remain to be solved, including live display of the training loss and a training progress bar. Another important item is Windows support: development happened on the dev machines provided by InternLM, which run Linux, so Windows and macOS support still needs to be added.
- **Polish the interface with CSS, HTML, and other front-end techniques**: the front end currently relies entirely on Gradio, mainly because, as a non-specialist front-end developer, Gradio is all I know. A dedicated front-end developer could later refine the design and make the interface look better.
- **Integrate model deployment**: the fine-tuned model is converted to HuggingFace format, whose inference speed and file size may not meet the requirements of real projects, so actual deployment work is still needed. We therefore plan to integrate with LMDeploy to support local deployment as well as API-based serving.
- **Build an OpenAI-style Playground**: to lower the barrier to using agents, we plan to release a Playground-like interface that, like OpenAI's, supports Function calling, Retrieval, Code Interpreter, Vision, Audio, Video, and other multimodal capabilities.

Those are the future plans for XTuner GUI. Thanks again to everyone who contributed, and to 星佬, 米佬, and the other InternLM (书生.浦语) staff for their help and compute support. We will keep pushing this plan forward, doing our part to help more newcomers get started quickly and contributing more to the open-source community!
As the project lead, completing this task and taking part in it has been very rewarding. Contributing a solution to the open-source community is meaningful in itself, and along the way I met many people who build things purely out of passion, so we can keep learning and improving together. I sincerely hope the InternLM open-source ecosystem keeps getting better and sets a good example for open source in China!
If XTuner or XTuner GUI has genuinely helped you, please give us a Star; for an open-source project, one more Star means one more vote of confidence. As the lyric goes, if everyone gives a little love, the world becomes a better place. As the open-source community thrives, hopefully we can all be a little less restless and instead do real work that pushes things forward!


> **You've read this far, so it would be a shame not to leave a Star! Just click the links below and give us one!**
> - **XTuner repository**: [https://github.com/InternLM/xtuner](https://github.com/InternLM/xtuner)
> - **XTuner GUI repository**: [https://github.com/scchy/XtunerGUI](https://github.com/scchy/XtunerGUI)
>
> **What, a fork as well?! Then on behalf of all the developers: thank you, and may good fortune follow you!**


--------------------------------------------------------------------------------
/README_tmp.md:
--------------------------------------------------------------------------------
1 | # XtunerGUI
2 | Xtuner Factory
3 |
4 | [Design Doc: XtunerGUI](https://aab2vs0do9o.feishu.cn/docx/JWkbdoDiboVKBAxUyQvcg9MQnbb?from=from_copylink)
5 |
6 |
7 |
8 | - Download `xtuner_download`
9 |   - Model download
10 |     - Input: `model: gr.Dropdown`
11 |     - Output: `model_path: gr.Textbox`
12 |     - Path: current version `XtunerGUI/appPrepare/download_cache/model_download/{username}_{repository}`
13 |   - Dataset download
14 |     - Input: `dataset: gr.Dropdown`
15 |     - Output: `data_path: gr.Textbox`
16 |     - Path: current version `XtunerGUI/appPrepare/download_cache/data_download/dataset_{username}_{repository}`
17 |
18 | - config `xtuner_config`
19 |   - Input: many parameters
20 |   - Output: `cfg_py_box`
21 |
22 | - finetune `xtuner_run`
23 |   - set environment variables
24 |   - shell_train: run the shell command directly: `xtuner train xxxx`
25 |   - log display -> slow log-display issue
26 |
27 | - convert & merge `xtuner_convert`
28 |   - `convert_and_merge.py`
29 |   - Inputs:
30 |     - todo: choose epoch
31 |     - `config_file`: generated by `xtuner_config`
32 |     - `pth_model`: checkpoint under the work_dir directory
33 |     - `save_hf_dir`: output directory {work_dir}/hf
34 |     - `model_path`: model path, the `model_path` produced by `xtuner_download`
35 |     - `save_merged_dir`: output directory {work_dir}/merge_epoch{n}
36 |
37 |
38 | todo:
39 | - [X] todo: load_dataset data download
40 | - [ ] evaluation questions: allow adding more, pass a list directly
41 | - [ ] custom model
42 |   1. template mapping ->
43 |   2. path validation (template)
44 | - [X] paths finalized
45 | - [ ] move prompt_template elsewhere?
46 | - [ ] custom dataset: only the OpenAI dataset format is supported
47 |
48 |
49 | ```text
50 | customer_path
51 | |-- download_cache
52 | | |-- data_download
53 | | | `-- tatsu-lab_alpaca
54 | | `-- model_download
55 | | `-- internlm_internlm-chat-7b
56 | `-- work_dir
57 | |-- 20240202_153301
58 | | |-- 20240202_153301.log
59 | | `-- vis_data
60 | |-- iter_100.pth
61 | |-- iter_50.pth
62 | |-- last_checkpoint
63 | |-- xtuner_config.py
64 | `-- xtuner_iter_100_hf
65 | |-- README.md
66 | |-- adapter_config.json
67 | |-- adapter_model.safetensors
68 | `-- xtuner_config.py
69 |
70 | ```
71 |
72 | /root/share/model_repos/internlm-chat-7b
73 | /root/personal_assistant/data/personal_assistant_openai_final.json
74 |
75 |
76 | ## Test
77 | - [X] customer-root /root/sccTest3
78 | - [X] customer-data-dir /root/download_cache
79 | - [X] customer model: /root/share/model_repos/internlm-chat-7b
80 | - [X] check customer model template detect
81 | - [X] -> detect_prompt_template -> prompt_template_show
82 | - [X] customer dataset:
83 | - /root/personal_assistant/data/personal_assistant_openai_final.json
84 | - /root/xtunerUITest/ttt.json
85 | - [X] data: tatsu-lab/alpaca -> downloading
86 | - [X] config
87 | - [X] ft_method -> DEFAULT_HYPERPARAMETERS
88 | - [X] generate check
89 | - [ ] xtuner
90 |     - [ ] running without progress?
91 | - show result
92 | - [X] plot
93 | - [X] dynamic select_checkpoint ->
94 | - convert
95 | - [X] choose pth
96 | - [ ] convert progress
97 |
98 |
99 |
100 |
101 |
102 |
103 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-26
3 | # ========================================
4 |
5 | from xtuner_download.download_model import xtunerModelDownload
6 | from xtuner_download.download_dataset import xtunerDataDownload
7 | from xtuner_convert.convert_and_merge import convert_and_merged
8 | from xtuner_convert.convert_with_progress import ConvertMerged
9 | from xtuner_run.shell_train import quickTrain
10 | from appPrepare.files_prepare import DATA_DOWNLOAD_DIR, MODEL_DOWNLOAD_DIR, CUR_PATH, DEFAULT_DOWNLOAD_DIR
11 | from appPrepare.list_prepare import DATA_LIST, MODEL_LIST, PROMPT_TEMPLATE_LIST
12 | from appPrepare.func_prepare import read_first_ten_lines, get_template_format_by_name, OPENAI_FORMAT
13 | from xtuner_config.build_config import build_and_save_config, model_path_map_fn
14 | from xtuner_config.check_custom_dataset import check_custom_dataset
15 | from xtuner_config.get_prompt_template import app_get_prompt_template
16 | from xtuner_config.get_default_hyperparameters import get_default_hyperparameters
17 | from chat.model_center import ModelCenter
18 | from tqdm import tqdm
19 | from xtuner_result.draw import resPlot
20 | import gradio as gr
21 | import warnings
22 | warnings.filterwarnings(action='ignore')
23 | CHAT_ORG = ModelCenter()
24 | FT_CHAT_ORG = ModelCenter()
25 | CVT_MG = ConvertMerged()
26 |
27 | def combine_message_and_history(message, chat_history):
28 | # 将聊天历史中的每个元素(假设是元组)转换为字符串
29 | history_str = "\n".join(f"{sender}: {text}" for sender, text in chat_history)
30 |
31 | # 将新消息和聊天历史结合成一个字符串
32 | full_message = f"{history_str}\nUser: {message}"
33 | return full_message
34 |
35 | def respond(message, chat_history):
36 | # message1 = combine_message_and_history(message,chat_history)
37 | # client = OpenAI()
38 | # messages=[
39 | # {"role": "system", "content": "You are a helpful assistant."},
40 | # {"role": "user", "content": message1}
41 | # ]
42 |
43 | # completion = client.chat.completions.create(
44 | # model="gpt-3.5-turbo",
45 | # messages=messages,
46 | # max_tokens=150, # 设置生成响应的最大 token 数量
47 | # seed=12345, # 设置种子以获得确定性采样(如果可能)
48 | # temperature=0.7, # 设置采样温度
49 | # top_p=0.9 # 设置核心采样的概率质量百分比
50 | # )
51 | # bot_message_text = completion.choices[0].message.content
52 | # #这里的bot_message_text就是最后输出的文本
53 | # chat_history.append((message, bot_message_text))
54 | return "", chat_history
55 |
56 | def clear_history(chat_history):
57 | chat_history.clear()
58 | return chat_history
59 |
60 | def regenerate(chat_history):
61 | if chat_history:
62 | # 提取上一条输入消息
63 | last_message = chat_history[-1][0]
64 | # 移除最后一条记录
65 | chat_history.pop()
66 | # 使用上一条输入消息调用 respond 函数以生成新的回复
67 | msg,chat_history = respond(last_message, chat_history)
68 | # 返回更新后的聊天记录
69 | return msg, chat_history
70 |
71 |
72 | def evaluation_question_number_change_wrap(max_textboxes):
73 | def evaluation_question_number_change(k):
74 | k = int(k)
75 | return [gr.Textbox(visible=True)]*k + [gr.Textbox(value='', visible=False)]*(max_textboxes-k)
76 | return evaluation_question_number_change
77 |
78 | with gr.Blocks() as demo:
79 | gr.Markdown(value='''
80 |
81 |

82 |
83 |
84 | [](https://github.com/scchy/XtunerGUI/stargazers)
85 | ''')
86 |
87 | with gr.Tab("基础训练"):
88 | with gr.Accordion(label='使用指南', open=False):
89 | gr.Markdown('## 流程图')
90 | # process = gr.Image(value='/root/XtunerGUI/pictures/workflow.jpg',label='使用流程图',container=False,show_download_button=False )
91 | gr.Markdown('## 演示视频')
92 | # video_customer_introduction = gr.Video(label='Xtuner GUI用法演示',value='/mnt/d/xtuner/demo.mp4',interactive=False)
93 | gr.Markdown("## 1. 本地路径设置")
94 | with gr.Row():
95 | local_path = gr.Textbox(
96 | label='请上传所有文件保存的文件本地路径',
97 | value=CUR_PATH,
98 | info='将会在选择的路径下保存模型的配置文件、训练过程文件及模型转换后的文件'
99 | )
100 | local_model_path = gr.Textbox(label='请确定数据集和模型下载的本地位置', value=DEFAULT_DOWNLOAD_DIR, info='将保存所有通过下方按钮下载的模型和数据集内容在该路径')
101 |
102 | # 这个里面是存放着保存数据集的路径
103 | local_path_button = gr.Button('确认路径')
104 | gr.Markdown("## 2. 微调方法、模型、数据集设置")
105 |
106 | with gr.Row():
107 | ft_method = gr.Dropdown(choices=['qlora', 'lora', 'full'], value='qlora',label = '微调方法', info='''请选择需要的微调方法(全量微调(full)需要大量显存,请谨慎选择)''',interactive=True)
108 | with gr.Column():
109 | model = gr.Dropdown(choices=MODEL_LIST + ['自定义'], value='internlm/internlm-chat-7b',label = '模型', info='请选择你希望微调的模型,选择后可点击下方按钮进行下载',interactive=True)
110 | DM_CLS = xtunerModelDownload(
111 | model_name=model.value,
112 | out_path=MODEL_DOWNLOAD_DIR,
113 | tqdm_class=tqdm
114 | )
115 | local_path_button.click(DM_CLS.reset_path, inputs=[local_model_path])
116 | model.change(DM_CLS.reset, inputs=[model])
117 | with gr.Row():
118 | model_download_button = gr.Button('模型下载')
119 | model_stop_download = gr.Button('取消下载')
120 | model_path = gr.Markdown(label='模型下载详情')
121 |
122 | # model_download_information = gr.Markdown(label='模型下载信息')
123 | # model_download_path = gr.Textbox(visible=False)
124 | model_download_button.click(DM_CLS.auto_download, outputs=[model_path])
125 | model_stop_download.click(DM_CLS.break_download, outputs=[model_path])
126 |
127 | with gr.Column():
128 | dataset = gr.Dropdown(choices=DATA_LIST + ['自定义'], value='shibing624/medical',label = '数据集', info='请选择合适的数据集,选择后可点击下方按钮进行下载',interactive=True)
129 | DT_CLS = xtunerDataDownload(
130 | data_name= dataset.value,
131 | out_path=DATA_DOWNLOAD_DIR,
132 | tqdm_class=tqdm
133 | )
134 | local_path_button.click(DT_CLS.reset_path, inputs=[local_model_path])
135 | dataset.change(DT_CLS.reset, inputs=[dataset])
136 | with gr.Row():
137 | dataset_download_button = gr.Button('数据集下载')
138 | dataset_stop_download = gr.Button('取消下载')
139 | data_path = gr.Markdown(label='数据下载详情')
140 | # dataset_download_information = gr.Markdown(label='数据集下载信息')
141 | # dataset_download_path = gr.Textbox(visible=False)
142 | dataset_download_button.click(DT_CLS.auto_download, outputs=[data_path])
143 | dataset_stop_download.click(DT_CLS.break_download, outputs=[data_path])
144 | wrong_message1 = gr.Markdown()
145 | with gr.Row():
146 | with gr.Column(scale=1):
147 | with gr.Accordion(label="自定义模型",open=False):
148 | model_personal_path = gr.Textbox(label='自定义模型本地路径', info = '请输入模型的本地路径在下方文本框中')
149 | personal_model = gr.Files(label='请上传自定义模型文件',visible=False)
150 | check_personal_model = gr.Button('模型检查及提示词模板自动匹配(请务必点击!)')
151 | detect_prompt_status = gr.Markdown() #可用于承接检查后得到的结果
152 | # 上传文件自动显示在 model_personal_path
153 | personal_model.change(lambda x: x, inputs=[personal_model], outputs=[model_personal_path])
154 | with gr.Column(scale=2):
155 | with gr.Accordion(label="自定义数据集(仅支持OpenAI格式)",open=False):
156 | with gr.Row():
157 | with gr.Column():
158 | dataset_type = gr.Dropdown(choices=['OpenAI'],value='OpenAI',label = '支持的数据集格式', interactive=False)
159 | dataset_type_preview = gr.TextArea(label='OpenAI数据集格式展示', info= '该数据集的标准格式如下所示,请将自定义的数据集格式转化为该格式。',value=OPENAI_FORMAT)
160 | #dataset_type_preview = gr.JSON(label='数据集格式展示')
161 | with gr.Column():
162 | dataset_personal_path = gr.Textbox(label = '数据集本地路径', info='请填入本地数据集路径或直接在下方上传数据文件')
163 | # dataset_personal_path_upload = gr.Button('请点击上传数据集本地路径')
164 | dataset_personal = gr.File(label='请上传自定义的数据集或在上方填入本地路径',type='filepath')
165 | check_personal_dataset = gr.Button('检查数据集是否符合要求')
166 | wrong_message3 = gr.Markdown() #判定数据集格式是否符合要求,符合就在上面显示
167 | check_personal_dataset.click(check_custom_dataset, inputs=[dataset_personal_path, dataset_personal], outputs=wrong_message3)
168 |
169 | # with gr.Accordion(label="数据集预览",open=False):
170 | # dataset_preview = gr.TextArea(label='数据集展示', info = '截取前n行内容,可用于对比原数据集格式。')
171 | # #dataset_preview = gr.JSON(label='数据集展示')
172 | # dataset_personal_path_upload.click(fn=read_first_ten_lines, inputs=dataset_personal_path, outputs=dataset_preview, queue=False)
173 | # dataset_personal.change(fn=read_first_ten_lines, inputs=dataset_personal, outputs=dataset_preview, queue=False)
174 | with gr.Accordion(label="对应提示词模版展示",open=False):
175 | with gr.Row():
176 | prompt_template = gr.Dropdown(PROMPT_TEMPLATE_LIST, label='提示词模版', value='default', info='请选择合适的提示词模版(请勿随意进行调整)',interactive=True)
177 | prompt_template_show = gr.TextArea(label='提示词模版展示')
178 |
179 | model.change(model_path_map_fn, inputs=[model], outputs=[prompt_template])
180 | # 检测完毕后 -> 改变 prompt_template -> prompt_template_show
181 | check_personal_model.click(app_get_prompt_template, inputs=[model_personal_path, personal_model], outputs=[detect_prompt_status, prompt_template])
182 | prompt_template.change(fn=get_template_format_by_name, inputs=prompt_template, outputs=prompt_template_show)
183 |
184 | gr.Markdown("## 3. 微调参数设置")
185 | # with gr.Accordion(label="参数调整指南",open=False):
186 | # gr.Markdown('#### 参数调整方式为...')
187 | with gr.Tab("基础参数"):
188 | with gr.Row():
189 | lr = gr.Number(label='学习率(Learning Rate)', value=2.0e-5, info='学习率控制模型权重调整的幅度,在训练过程中对损失函数的优化有直接影响。较小的学习率可能导致学习过程缓慢,而较大的学习率可能导致学习过程中出现不稳定。')
190 | warmup_ratio = gr.Number(label='预热比(Warmup Ratio)', value=0.03, info='预热比例用于在训练初期逐渐增加学习率,这有助于模型训练初期的稳定性,避免因学习率过高导致的训练不稳定。')
191 | max_length = gr.Number(label='数据集最大长度(Max Length)', value=2048, info='设置数据在处理前的最大长度,确保模型可以处理的序列长度范围内,有助于控制训练过程的内存使用。')
192 | pack_to_max_length = gr.Dropdown(choices=[True, False], value=True, label='合并为最长样本(Pack to Max Length)', info='决定是否将多个样本合并成一个最大长度的样本。这可以提高数据处理的效率,但可能影响模型学习到的模式。')
193 | with gr.Row():
194 | batch_size_per_device = gr.Number(label='每设备样本个数(Batch Size per Device)', value=1, info='定义每个设备上进行处理的样本数量。较大的批量大小可以提高训练效率,但也会增加内存的使用量。')
195 | accumulative_counts = gr.Number(label='梯度累计数(Gradient Accumulation Steps)', value=16, info='在进行一次参数更新前累积的梯度步数,可以增大批处理大小的效果而不增加内存消耗。')
196 | deepspeed = gr.Dropdown(choices=['None','zero1','zero2','zero3'], value='None', label='Deepspeed算子(Deepspeed)', info='选择Deepspeed优化策略来加速训练和降低内存使用。不同的优化级别提供了不同的内存和计算优化。')
197 | num_GPU = gr.Number(label='GPU数量(Number of GPUs)', value=1, info='设置训练过程中使用的GPU数量。增加GPU数量可以提高训练速度,但需要确保硬件资源充足。')
198 | with gr.Row():
199 | max_epochs = gr.Number(label='训练迭代数(Max Epochs)', value=2, info='设置模型训练过程中数据将被遍历的次数。较多的迭代次数可以提高模型性能,但也会增加训练时间。')
200 | save_checkpoint_interval = gr.Number(label='保存权重间隔(Save Checkpoint Interval)', value=1000, info='设置自动保存模型权重的间隔(以迭代次数计)。这有助于从训练中途的某个点恢复训练过程。')
201 | save_total_limit = gr.Number(label='最多保存权重文件数(Save Total Limit)', value=2, info='限制保存的模型权重文件的最大数量,有助于管理存储空间,避免因保存过多的模型文件而耗尽存储。')
202 | evaluation_freq = gr.Number(label='验证对话效果频率(evaluation_freq)', value=100, info='请确定模型每多少轮需要验证一次对话效果,具体的对话问题及系统提示词可以在下方评估问题处进行设置')
203 |
204 | # todo: 测试问题 多个的问题
205 | with gr.Accordion(label="评估问题设置", open=True):
206 | evaluation_system_prompt = gr.Textbox(label = '系统提示词(system_prompt)', value='', info='请设置在评估模式下的系统提示词(默认为无)')
207 | default_evaluation_question_number = 2
208 | max_evaluation_question_number = 10
209 | default_evaluation_question_list = [
210 | '请给我介绍五个上海的景点',
211 | 'Please tell me five scenic spots in Shanghai'
212 | ]
213 | evaluation_question_list = []
214 | with gr.Accordion(label='评估问题数量及内容',open=True):
215 | with gr.Row():
216 | with gr.Column():
217 | evaluation_question_number = gr.Number(label='评估问题数', value=default_evaluation_question_number, minimum=1, maximum=max_evaluation_question_number, info='调整评估问题的数量(最多10个问题)')
218 | with gr.Column():
219 | for i in range(max_evaluation_question_number):
220 | evaluation_question_if_visible = True if i < default_evaluation_question_number else False
221 | evaluation_question_value = default_evaluation_question_list[i] if i < default_evaluation_question_number else ''
222 | t = gr.Textbox(label=f'评估问题{i + 1}', value=evaluation_question_value, interactive=True, placeholder=f"请输入第{i + 1}个评估的问题", visible=evaluation_question_if_visible)
223 | evaluation_question_list.append(t)
224 | evaluation_question_number.change(evaluation_question_number_change_wrap(max_evaluation_question_number), evaluation_question_number, evaluation_question_list)
225 | with gr.Tab('进阶参数'):
226 | with gr.Row():
227 | optim_type = gr.Dropdown(choices=['AdamW'], value='AdamW', label='优化器(Optimizer)', info='选择优化器用于调整网络权重以减少误差;AdamW是Adam优化器的一种变体,提供权重衰减控制,通常用于更好的泛化。', visible=True)
228 |
229 | weight_decay = gr.Number(label='权重衰减(Weight Decay)', value=0, info='权重衰减是一种正则化技术,通过为模型的损失函数添加一个与权重大小成比例的惩罚项来防止模型的过拟合。')
230 |
231 | with gr.Row():
232 | max_norm = gr.Number(label='梯度剪裁(Gradient Clipping)', value=1, info='通过设置梯度的最大阈值来防止在训练过程中梯度爆炸的问题,有助于稳定模型的训练过程。')
233 | dataloader_num_workers = gr.Number(label='数据加载线程数(Data Loader Number of Workers)', value=0, info='设置在数据加载时并行工作的线程数,较高的值可以加快数据加载速度,但会增加内存和处理器的负担。')
234 |
235 | with gr.Accordion(label="AdamW优化器betas", open=False):
236 | beta1 = gr.Number(label='beta1 (一阶矩估计)', value=0.9, info='用于计算梯度的一阶矩估计(即梯度的指数移动平均),决定了过去梯度的权重,高值意味着模型更加关注过去的梯度。')
237 | beta2 = gr.Number(label='beta2 (二阶矩估计)', value=0.999, info='用于计算梯度的二阶矩估计(即梯度平方的指数移动平均),决定了梯度变化率的平滑程度,高值可以使优化过程在长时间内更加平稳。')
238 |
239 | ft_method.change(
240 | get_default_hyperparameters, inputs=[ft_method],
241 | outputs=[
242 | warmup_ratio,
243 | batch_size_per_device,
244 | accumulative_counts,
245 | num_GPU,
246 | max_length,
247 | pack_to_max_length,
248 | evaluation_freq,
249 | optim_type,
250 | weight_decay,
251 | max_norm,
252 | dataloader_num_workers,
253 | beta1,
254 | beta2,
255 | lr,
256 | save_checkpoint_interval,
257 | save_total_limit
258 | ]
259 | )
260 | change_config_button = gr.Button('点击生成配置文件')
261 | cfg_py_box = gr.Markdown(value="还未生成配置文件")
262 | change_config_button.click(
263 | build_and_save_config,
264 | inputs=[
265 | dataset_personal_path,
266 | dataset_personal,
267 | model_personal_path,
268 | personal_model,
269 | prompt_template,
270 | local_path,
271 | ft_method,
272 | model_path,
273 | data_path,
274 | deepspeed,
275 | lr,
276 | warmup_ratio,
277 | batch_size_per_device,
278 | accumulative_counts,
279 | num_GPU,
280 | max_length,
281 | pack_to_max_length,
282 | max_epochs,
283 | save_checkpoint_interval,
284 | save_total_limit,
285 | evaluation_freq,
286 | evaluation_system_prompt,
287 | optim_type,
288 | weight_decay,
289 | max_norm,
290 | dataloader_num_workers,
291 | beta1,
292 | beta2,
293 | prompt_template,
294 | *evaluation_question_list
295 | ],
296 | outputs=[cfg_py_box]
297 | )
298 | wrong_message4 = gr.Markdown()
299 |
300 | gr.Markdown("## 4. 微调模型训练")
301 | TR_CLS = quickTrain(
302 | config_py_path=cfg_py_box.value,
303 | work_dir=f'{local_path.value}/work_dir',
304 | deepspeed_seed=deepspeed
305 | )
306 | change_config_button.click(TR_CLS.reset_cfg_py, inputs=[cfg_py_box])
307 | cfg_py_box.change(TR_CLS.reset_cfg_py, inputs=[cfg_py_box])
308 | deepspeed.change(TR_CLS.reset_deepspeed, inputs=[deepspeed])
309 | local_path_button.click(TR_CLS.reset_work_dir, inputs=[local_path])
310 | with gr.Row():
311 | train_model = gr.Button('Xtuner!启动!',size='lg')
312 | stop_button = gr.Button('训练中断',size='lg')
313 | work_path = gr.Textbox(label='work dir',visible=False)
314 |
315 | tmp_trian_pg_md = gr.Markdown()
316 | train_model.click(TR_CLS.quick_train, outputs=[tmp_trian_pg_md, work_path])
317 | stop_button.click(TR_CLS.break_train, outputs=[tmp_trian_pg_md, work_path], queue=False)
318 | with gr.Accordion(label='模型续训', open=False):
319 | retry_path_dropdown = gr.Dropdown(label='请选择需要继续训练的权重文件', info='将从训练中断前的模型权重文件进行搜索',interactive=True)
320 | retry_button = gr.Button('继续训练')
321 | retry_path_dropdown.change(TR_CLS.reset_resume_from_checkpoint, inputs=[retry_path_dropdown])
322 | retry_button.click(TR_CLS.resume_train, outputs=[tmp_trian_pg_md, work_path])
323 |
324 | with gr.Accordion(label="终端界面",open=False):
325 | log_file = gr.TextArea(label='日志文件打印', info= '点击可查看模型训练信息')
326 | # train_model.click(TR_CLS.start_log, outputs=[log_file])
327 | # retry_button.click(TR_CLS.start_log, outputs=[log_file])
328 |
329 | wrong_message5 = gr.Markdown()
330 | gr.Markdown("## 5. 微调结果展示")
331 | PLT = resPlot(
332 | work_dir = f'{local_path.value}/work_dir',
333 | )
334 | # 点击停止训练的时候 retry_path_dropdown 进行更新
335 | stop_button.click(PLT.dynamic_drop_down, outputs=retry_path_dropdown, queue=False)
336 | work_path.change(PLT.dynamic_drop_down, outputs=retry_path_dropdown, queue=False)
337 | local_path_button.click(PLT.reset_work_dir, inputs=[local_path])
338 | work_path.change(PLT.reset_work_dir, inputs=[local_path])
339 | with gr.Tab('训练结果'):
340 | # with gr.Row():
341 | # ft_model_save_path = gr.Textbox(label='模型保存路径',visible=False)
342 | # detect work_dir find newest
343 | # iter_num = gr.Number(label='训练轮数', scale=1)
344 | # num_pth = gr.Number(label='权重文件数量', scale=1)
345 | with gr.Row():
346 | # lr_plot = gr.Image(label='学习率变化图',container=False,show_download_button=False,interactive=False)
347 | # loss_graph = gr.Image(label='损失变化图',container=False,show_download_button=False)
348 | lr_plot = gr.LinePlot(label='学习率变化图')
349 | loss_graph = gr.LinePlot(label='损失变化图')
350 | with gr.Row():
351 | num_pth_evaluation = gr.Dropdown(label='请选择权重文件', info='可获取模型训练过程中评估问题的结果展示')
352 | evaluation_question = gr.TextArea(label='测试问题结果')
353 |
354 | stop_button.click(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False)
355 | show_evaluation_button = gr.Button('微调结果生成')
356 | show_evaluation_button.click(PLT.reset_work_dir, inputs=[local_path], queue=False)
357 | show_evaluation_button.click(PLT.lr_plot, outputs=[lr_plot], queue=False)
358 | show_evaluation_button.click(PLT.loss_plot, outputs=[loss_graph], queue=False)
359 | # 更新eval的下拉列表
360 | show_evaluation_button.click(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False)
361 | work_path.change(PLT.dynamic_eval_drop_down, outputs=num_pth_evaluation, queue=False)
362 | # 找到 & read eval
363 | num_pth_evaluation.change(PLT.get_eval_test, inputs=[num_pth_evaluation], outputs=[evaluation_question])
364 |
365 | gr.Markdown("## 6. 微调模型转化及测试")
366 |
367 | with gr.Accordion(label="模型转换",open=True):
368 | # Textbox
369 | # select_checkpoint =gr.Dropdown(choices=['epoch_1.pth', 'epoch_1.pth'], value='epoch_1.pth', label='微调模型的权重文件', info = '请选择需要进行测试的模型权重文件并进行转化')
370 | select_checkpoint = gr.Dropdown(label='微调模型的权重文件', info = '请选择需要进行测试的模型权重文件并进行转化',interactive = True)
371 | stop_button.click(PLT.dynamic_drop_down, outputs=select_checkpoint, queue=False)
372 | show_evaluation_button.click(PLT.dynamic_drop_down, outputs=select_checkpoint, queue=False)
373 |
374 | covert_hf = gr.Button('模型转换',scale=1)
375 | covert_hf_path = gr.Textbox(label='模型转换后地址', visible=False) # False
376 | wrong_message6 = gr.Markdown()
377 |
378 | # root_dir, config_file, epoch_pth, model_path, customer_model_path)
379 | # todo ft_method full-convert oth-convert+merge
380 | covert_hf.click(CVT_MG.auto_convert_merge, inputs=[local_path, cfg_py_box, select_checkpoint, model_path, model_personal_path, ft_method], outputs=[wrong_message6, covert_hf_path])
381 | with gr.Accordion(label='对话测试', open=True):
382 | with gr.Row():
383 | with gr.Accordion(label="原模型对话测试", open=True):
384 | with gr.Column():
385 | with gr.Accordion(label='参数设置',open=False):
386 | max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=1024, label='模型输出的最长Token(max_new_tokens)', info='这个参数决定了模型输出的最大token数量。增加这个值允许模型生成更长的文本,而减少这个值会导致生成的文本更短。')
387 | temperature = gr.Slider(maximum=2, minimum=0, label='温度值(temperature)',value=1, info='控制生成文本的随机性。较高的温度值会使输出更加多样和不可预测,而较低的值使输出更确定和重复。')
388 | top_k = gr.Slider(minimum=0, maximum=100, value=40, label='Top-k Sampling(top-k)', info='限制模型在每一步生成文本时考虑的最可能候选词的数量。较大的k值增加了多样性,但可能降低文本的连贯性;较小的k值则相反。')
389 | top_p = gr.Slider(minimum=0, maximum=2, value=0.75, label='Top-p Sampling(top-p)', info='类似于top_k,但通过选择累积概率高于某个阈值p的最小词集,动态调整考虑的候选词数量。较高的p值增加多样性,较低的p值提高连贯性。')
390 | num_beams = gr.Slider(minimum=0, maximum=12, value=5, label='Beam Search(num_beams)', info='在beam search中,num_beams指定了搜索宽度。更多的beams可以提高生成文本的质量,但也会增加计算负担。')
391 |
392 | #还可以添加更多
393 | wrong_message9 = gr.Markdown()
394 | start_testing_model = gr.Button('模型启动')
395 | testig_model_loaded = gr.Markdown()
396 | chatbot = gr.Chatbot(label='微调模型测试')
397 | msg = gr.Textbox(label="输入信息")
398 | msg.submit(CHAT_ORG.qa_answer, inputs=[msg, max_new_tokens, temperature, top_k, top_p, num_beams, chatbot], outputs=[msg, chatbot])
399 | # 模型载入
400 | start_testing_model.click(CHAT_ORG.load_model, inputs=[model_personal_path, personal_model, model_path], outputs=[testig_model_loaded])
401 | send = gr.Button('信息发送') # .click(regenerate, inputs=[chatbot], outputs = [msg, chatbot])
402 | with gr.Row():
403 | clear = gr.Button('记录删除') # .click(clear_history, inputs=[chatbot], outputs=[chatbot])
404 | undo = gr.Button('撤回上一条') # .click(undo, inputs=[chatbot], outputs=[chatbot])
405 |
406 |
407 | clear.click(CHAT_ORG.qa_clear, inputs=[chatbot], outputs=[chatbot])
408 | undo.click(CHAT_ORG.qa_undo, inputs=[chatbot], outputs=[chatbot])
409 | send.click(CHAT_ORG.qa_answer, inputs=[msg, max_new_tokens, temperature, top_k, top_p, num_beams, chatbot], outputs=[msg, chatbot])
410 |
411 | with gr.Accordion(label="微调模型对话测试", open=True):
412 | with gr.Column():
413 | with gr.Accordion(label='参数设置',open=False):
414 | ft_max_new_tokens = gr.Slider(minimum=0, maximum=4096, value=1024, label='模型输出的最长Token(max_new_tokens)', info='这个参数决定了模型输出的最大token数量。增加这个值允许模型生成更长的文本,而减少这个值会导致生成的文本更短。')
415 | ft_temperature = gr.Slider(maximum=2, minimum=0,value=1, label='温度值(temperature)', info='控制生成文本的随机性。较高的温度值会使输出更加多样和不可预测,而较低的值使输出更确定和重复。')
416 | ft_top_k = gr.Slider(minimum=0, maximum=100, value=40, label='Top-k Sampling(top-k)', info='限制模型在每一步生成文本时考虑的最可能候选词的数量。较大的k值增加了多样性,但可能降低文本的连贯性;较小的k值则相反。')
417 | ft_top_p = gr.Slider(minimum=0, maximum=2, value=0.75, label='Top-p Sampling(top-p)', info='类似于top_k,但通过选择累积概率高于某个阈值p的最小词集,动态调整考虑的候选词数量。较高的p值增加多样性,较低的p值提高连贯性。')
418 | ft_num_beams = gr.Slider(minimum=0, maximum=12, value=5, label='Beam Search(num_beams)', info='在beam search中,num_beams指定了搜索宽度。更多的beams可以提高生成文本的质量,但也会增加计算负担。')
419 | #还可以添加更多
420 | ft_wrong_message9 = gr.Markdown()
421 | ft_start_testing_model = gr.Button('模型启动')
422 | ft_testig_model_loaded = gr.Markdown()
423 | ft_chatbot = gr.Chatbot(label='微调模型测试')
424 | ft_msg = gr.Textbox(label="输入信息")
425 |
426 | ft_msg.submit(FT_CHAT_ORG.qa_answer, inputs=[ft_msg, ft_max_new_tokens, ft_temperature, ft_top_k, ft_top_p, ft_num_beams, ft_chatbot], outputs=[ft_msg, ft_chatbot])
427 | # 模型载入
428 | ft_start_testing_model.click(FT_CHAT_ORG.load_model, inputs=[covert_hf_path, personal_model, model_path], outputs=[ft_testig_model_loaded])
429 |
430 | ft_send = gr.Button('信息发送')
431 | with gr.Row():
432 | ft_clear = gr.Button('记录删除')
433 | ft_undo = gr.Button('撤回上一条')
434 |
435 |
436 | ft_clear.click(FT_CHAT_ORG.qa_clear, inputs=[ft_chatbot], outputs=[ft_chatbot])
437 | ft_undo.click(FT_CHAT_ORG.qa_undo, inputs=[ft_chatbot], outputs=[ft_chatbot])
438 | ft_send.click(FT_CHAT_ORG.qa_answer, inputs=[ft_msg, ft_max_new_tokens, ft_temperature, ft_top_k, ft_top_p, ft_num_beams, ft_chatbot], outputs=[ft_msg, ft_chatbot])
439 | # with gr.Accordion(label='模型基础能力评估测试',open=False):
440 | # mmlu_test_button = gr.Button('MMLU模型能力评估测试')
441 | with gr.Accordion(label="其他信息", open=True):
442 | star = gr.Markdown('### 如果您觉得该UI界面能帮助您高效完成微调工作,请为[XTuner](https://github.com/InternLM/xtuner.git)和[XTunerGUI](https://github.com/scchy/XtunerGUI.git)点个星星!非常感谢您的支持!')
443 | thanks = gr.Markdown('''### 最后,感谢XTUnerGUI团队成员对该项目的贡献:
444 | - [scchy](https://github.com/scchy) - 整体后端开发
445 | - [jianfeng777](https://github.com/Jianfeng777) - 整体前端开发
446 | - [l241025097](https://github.com/l241025097) - 模型训练终端可视化
447 | - [semple030228](https://github.com/semple030228) - 模型转化
448 |
449 | ### 同时也感谢XTuner团队的大力支持:
450 | - [HIT-cwh](https://github.com/HIT-cwh) - 配置文件生成及相关检查
451 | - [pppppM](https://github.com/pppppM) - 提供指导意见
452 |
453 | 我们相信XTuner将成为国内有影响力的模型微调工具包!
454 | ''')
455 |
456 | #with gr.Tab('微调模型部署(LMDeploy)'):
457 |
458 | #with gr.Tab('微调模型测评(OpenCompass)'):
459 |
460 | demo.load(TR_CLS.read_log, outputs=[log_file], every=1)
461 | demo.launch(share=True) #, server_name="0.0.0.0", server_port=6007, root_path=f'/proxy/6007/') #, server_name="0.0.0.0", server_port=6006, root_path=f'/proxy/6006/')
462 |
463 |
464 |
465 |
466 |
--------------------------------------------------------------------------------
/appPrepare/env_prepare.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | cd ~
5 | mkdir xtunerPKG
6 | git clone https://gitee.com/InternLM/xtuner.git
7 | cd xtuner
8 | pip install -e '.[all]'
9 |
10 |
--------------------------------------------------------------------------------
/appPrepare/files_prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | CUR_PATH = os.path.dirname(__file__)
5 | DATA_PATH = os.path.dirname(CUR_PATH)
6 |
7 | def dir_create(_dir):
8 | if not os.path.exists(_dir):
9 | os.system(f'mkdir -p {_dir}')
10 | return _dir
11 |
12 | DEFAULT_DOWNLOAD_DIR = dir_create(f"{DATA_PATH}/download_cache")
13 | MODEL_DOWNLOAD_DIR = dir_create(f"{DEFAULT_DOWNLOAD_DIR}/model_download")
14 | DATA_DOWNLOAD_DIR = dir_create(f"{DEFAULT_DOWNLOAD_DIR}/data_download")
15 | WORK_DIR = dir_create(f"{CUR_PATH}/work_dir")
16 |
17 |
18 |
19 |
20 |
21 |
--------------------------------------------------------------------------------
/appPrepare/func_prepare.py:
--------------------------------------------------------------------------------
1 | from xtuner.utils import PROMPT_TEMPLATE
2 |
3 | OPENAI_FORMAT = '''
4 | [
5 | {
6 | "messages":
7 | [
8 | { "role": "system", "content": "You are an assistant that occasionally misspells words." },
9 | { "role": "user", "content": "Tell me a story." },
10 | { "role": "assistant", "content": "One day a student went to schoool." }
11 | ]
12 | },
13 |
14 | {
15 | "messages":
16 | [
17 | { "role": "user", "content": "Tell me a story." },
18 | { "role": "assistant", "content": "One day a student went to schoool." }
19 | ]
20 | }
21 | ]
22 | '''
23 |
24 | def read_first_ten_lines(input_path, upload_path=None):
25 | file_path = input_path if len(input_path) >= 3 else upload_path
26 | try:
27 | with open(file_path, 'r', encoding='utf-8') as file:
28 | lines = file.readlines() # 读取所有行
29 | first_ten_lines = lines[:20] # 获取前十行
30 | return ''.join(first_ten_lines) # 将前十行合并为一个字符串并返回
31 | except Exception as e:
32 | return f"Error reading file: {str(e)}"
33 |
34 |
35 | def get_template_format_by_name(template_name):
36 | template = PROMPT_TEMPLATE.get(template_name, None)
37 | if template is None:
38 | return "Template not found"
39 | return str(template)
40 |
--------------------------------------------------------------------------------
/appPrepare/list_prepare.py:
--------------------------------------------------------------------------------
1 |
2 | DATA_LIST = [
3 | 'ArmelR/stack-exchange-instruction',
4 | 'HuggingFaceH4/CodeAlpaca_20K',
5 | 'Open-Orca/OpenOrca',
6 | 'Skywork/SkyPile-150B',
7 | 'WizardLM/WizardLM_evol_instruct_V2_196k',
8 | 'b-mc2/sql-create-context',
9 | 'burkelibbey/colors',
10 | 'damo/MSAgent-Bench',
11 | 'garage-bAInd/Open-Platypus',
12 | 'mistralai/Mistral-7B-v0.1',
13 | 'nampdn-ai/tiny-codes',
14 | 'shibing624/medical',
15 | 'silk-road/alpaca-data-gpt4-chinese',
16 | 'tatsu-lab/alpaca',
17 | 'timdettmers/openassistant-guanaco',
18 | ]
19 |
20 | MODEL_LIST = [
21 | 'internlm/internlm-7b',
22 | 'internlm/internlm-20b',
23 | 'internlm/internlm-chat-7b',
24 | 'internlm/internlm-chat-20b',
25 | 'meta-llama/Llama-2-7b-chat',
26 | 'meta-llama/llama2-70b',
27 | 'meta-llama/Llama-2-7b',
28 | 'huggyllama/llama-7b',
29 | 'baichuan-inc/Baichuan-7B-Chat',
30 | 'baichuan-inc/Baichuan-13B-Base',
31 | 'baichuan-inc/baichuan-7B',
32 | 'baichuan-inc/Baichuan2-13B-Chat',
33 | 'baichuan-inc/Baichuan2-7B-Chat',
34 | 'baichuan-inc/Baichuan2-13B-Chat',
35 | 'baichuan-inc/Baichuan2-13B-Base',
36 | 'THUDM/chatglm3-6b',
37 | 'THUDM/chatglm2-6b',
38 | 'THUDM/chatglm3-6b-base',
39 | '01-ai/Yi-24B',
40 | '01-ai/Yi-6B',
41 | 'Qwen/Qwen-7B-Chat',
42 | 'Qwen/Qwen-7B',
43 | ]
44 |
45 |
46 | PROMPT_TEMPLATE_LIST = [
47 | 'default',
48 | 'zephyr',
49 | 'internlm_chat',
50 | 'internlm2_chat',
51 | 'moss_sft',
52 | 'llama2_chat',
53 | 'code_llama_chat',
54 | 'chatglm2',
55 | 'chatglm3',
56 | 'qwen_chat',
57 | 'baichuan_chat',
58 | 'baichuan2_chat',
59 | 'wizardlm',
60 | 'wizardcoder',
61 | 'vicuna',
62 | 'deepseek_coder',
63 | 'deepseekcoder',
64 | 'deepseek_moe',
65 | 'mistral',
66 | 'mixtral'
67 | ]
--------------------------------------------------------------------------------
/chat/model_center.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-02-03
3 | # Author: Scc_hy
4 | # Func: chat center
5 | # ==============================================================================
6 |
7 | import torch
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 |
10 | user_prompt = "<|User|>:{user}\n"
11 | robot_prompt = "<|Bot|>:{robot}\n"
12 | cur_query_prompt = "<|User|>:{user}\n<|Bot|>:"
13 |
14 |
15 | class ModelCenter():
16 | def __init__(self):
17 | self.model = None
18 | self.tokenizer = None
19 |
20 | def load_model(self,
21 | model_personal_path,
22 | personal_model,
23 | model_path_in
24 | ):
25 | # 构造函数,加载检索问答链
26 | model_path = self.choice_path(model_personal_path, personal_model, model_path_in)
27 | print(f'ModelCenter.load_model({model_path})')
28 | self.model = (
29 | AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
30 | .to(torch.bfloat16)
31 | .cuda()
32 | )
33 | self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
34 | print('>>>>> Loading Model Done !')
35 | return f'>>> Loaded {model_path}'
36 |
37 | @staticmethod
38 | def choice_path(model_personal_path, personal_model, model_path_in):
39 | if len(model_personal_path) >= 3:
40 | return model_personal_path
41 | if len(personal_model) >= 3:
42 | return personal_model
43 | return model_path_in
44 |
45 | def qa_answer(self, question: str, max_new_tokens, temperature, top_k, top_p, num_beams, chat_history: list = []):
46 | if question == None or len(question) < 1:
47 | return "", chat_history
48 | try:
49 | question = question.replace(" ", ' ')
50 | response, history = self.model.chat(
51 | self.tokenizer,
52 | question,
53 | history=chat_history,
54 | max_new_tokens=max_new_tokens,
55 | temperature=temperature,
56 | top_k=top_k,
57 | top_p=top_p,
58 | num_beams=int(num_beams)
59 | )
60 | chat_history.append((question, response))
61 | return "", chat_history
62 | except Exception as e:
63 | return e, chat_history
64 |
65 | def qa_undo(self, chat_history: list = []):
66 | if len(chat_history):
67 | chat_history.pop()
68 | return chat_history
69 |
70 | def qa_clear(self, chat_history: list = []):
71 | return []
72 |
--------------------------------------------------------------------------------
/pictures/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/1.png
--------------------------------------------------------------------------------
/pictures/demo.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/demo.mp4
--------------------------------------------------------------------------------
/pictures/workflow.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/pictures/workflow.jpg
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | openxlab==0.0.34
3 | gradio==4.4.0
4 | transformers==4.34.0
5 | huggingface_hub
6 | modelscope==1.11.0
7 | unstructured==0.10.30
8 | markdown==3.3.7
9 | xtuner
10 | altair
11 |
--------------------------------------------------------------------------------
/template_configs/full_finetune.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | from datasets import load_dataset
3 | from mmengine.dataset import DefaultSampler
4 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
5 | LoggerHook, ParamSchedulerHook)
6 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
7 | from torch.optim import AdamW
8 | from transformers import AutoModelForCausalLM, AutoTokenizer
9 |
10 | from xtuner.dataset import process_hf_dataset
11 | from xtuner.dataset.collate_fns import default_collate_fn
12 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
13 | from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
14 | ThroughputHook,
15 | VarlenAttnArgsToMessageHubHook)
16 | from xtuner.engine.runner import TrainLoop
17 | from xtuner.model import SupervisedFinetune
18 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
19 |
20 | #######################################################################
21 | # PART 1 Settings #
22 | #######################################################################
23 | # Model
24 | pretrained_model_name_or_path = 'internlm/internlm-7b'
25 | use_varlen_attn = False
26 |
27 | # Data
28 | data_path = 'tatsu-lab/alpaca'
29 | prompt_template = PROMPT_TEMPLATE.internlm_chat
30 | max_length = 2048
31 | pack_to_max_length = True
32 |
33 | # Scheduler & Optimizer
34 | batch_size = 1 # per_device
35 | accumulative_counts = 2
36 | dataloader_num_workers = 0
37 | max_epochs = 3
38 | optim_type = AdamW
39 | lr = 2e-5
40 | betas = (0.9, 0.999)
41 | weight_decay = 0
42 | max_norm = 1 # grad clip
43 | warmup_ratio = 0.03
44 |
45 | # Save
46 | save_steps = 500
47 | save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
48 |
49 | # Evaluate the generation performance during the training
50 | evaluation_freq = 500
51 | SYSTEM = SYSTEM_TEMPLATE.alpaca
52 | evaluation_inputs = [
53 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
54 | ]
55 |
56 | #######################################################################
57 | # PART 2 Model & Tokenizer #
58 | #######################################################################
59 | tokenizer = dict(
60 | type=AutoTokenizer.from_pretrained,
61 | pretrained_model_name_or_path=pretrained_model_name_or_path,
62 | trust_remote_code=True,
63 | padding_side='right')
64 |
65 | model = dict(
66 | type=SupervisedFinetune,
67 | use_varlen_attn=use_varlen_attn,
68 | llm=dict(
69 | type=AutoModelForCausalLM.from_pretrained,
70 | pretrained_model_name_or_path=pretrained_model_name_or_path,
71 | trust_remote_code=True))
72 |
73 | #######################################################################
74 | # PART 3 Dataset & Dataloader #
75 | #######################################################################
76 | train_dataset = dict(
77 | type=process_hf_dataset,
78 | dataset=dict(type=load_dataset, path=data_path),
79 | tokenizer=tokenizer,
80 | max_length=max_length,
81 | dataset_map_fn=alpaca_map_fn,
82 | template_map_fn=dict(
83 | type=template_map_fn_factory, template=prompt_template),
84 | remove_unused_columns=True,
85 | shuffle_before_pack=True,
86 | pack_to_max_length=pack_to_max_length,
87 | use_varlen_attn=use_varlen_attn)
88 |
89 | train_dataloader = dict(
90 | batch_size=batch_size,
91 | num_workers=dataloader_num_workers,
92 | dataset=train_dataset,
93 | sampler=dict(type=DefaultSampler, shuffle=True),
94 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
95 |
96 | #######################################################################
97 | # PART 4 Scheduler & Optimizer #
98 | #######################################################################
99 | # optimizer
100 | optim_wrapper = dict(
101 | type=AmpOptimWrapper,
102 | optimizer=dict(
103 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
104 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
105 | accumulative_counts=accumulative_counts,
106 | loss_scale='dynamic',
107 | dtype='float16')
108 |
109 | # learning policy
110 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
111 | param_scheduler = [
112 | dict(
113 | type=LinearLR,
114 | start_factor=1e-5,
115 | by_epoch=True,
116 | begin=0,
117 | end=warmup_ratio * max_epochs,
118 | convert_to_iter_based=True),
119 | dict(
120 | type=CosineAnnealingLR,
121 | eta_min=0.0,
122 | by_epoch=True,
123 | begin=warmup_ratio * max_epochs,
124 | end=max_epochs,
125 | convert_to_iter_based=True)
126 | ]
127 |
128 | # train, val, test setting
129 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
130 |
131 | #######################################################################
132 | # PART 5 Runtime #
133 | #######################################################################
134 | # Log the dialogue periodically during the training process, optional
135 | custom_hooks = [
136 | dict(type=DatasetInfoHook, tokenizer=tokenizer),
137 | dict(
138 | type=EvaluateChatHook,
139 | tokenizer=tokenizer,
140 | every_n_iters=evaluation_freq,
141 | evaluation_inputs=evaluation_inputs,
142 | system=SYSTEM,
143 | prompt_template=prompt_template),
144 | dict(type=ThroughputHook)
145 | ]
146 |
147 | if use_varlen_attn:
148 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
149 |
150 | # configure default hooks
151 | default_hooks = dict(
152 | # record the time of every iteration.
153 | timer=dict(type=IterTimerHook),
154 | # print log every 10 iterations.
155 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
156 | # enable the parameter scheduler.
157 | param_scheduler=dict(type=ParamSchedulerHook),
158 | # save checkpoint per `save_steps`.
159 | checkpoint=dict(
160 | type=CheckpointHook,
161 | by_epoch=False,
162 | interval=save_steps,
163 | max_keep_ckpts=save_total_limit),
164 | # set sampler seed in distributed evrionment.
165 | sampler_seed=dict(type=DistSamplerSeedHook),
166 | )
167 |
168 | # configure environment
169 | env_cfg = dict(
170 | # whether to enable cudnn benchmark
171 | cudnn_benchmark=False,
172 | # set multi process parameters
173 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
174 | # set distributed parameters
175 | dist_cfg=dict(backend='nccl'),
176 | )
177 |
178 | # set visualizer
179 | visualizer = None
180 |
181 | # set log level
182 | log_level = 'INFO'
183 |
184 | # load from which checkpoint
185 | load_from = None
186 |
187 | # whether to resume training from the loaded checkpoint
188 | resume = False
189 |
190 | # Defaults to use random seed and disable `deterministic`
191 | randomness = dict(seed=None, deterministic=False)
192 |
193 | # set log processor
194 | log_processor = dict(by_epoch=False)
195 |
--------------------------------------------------------------------------------
/template_configs/lora.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from datasets import load_dataset
4 | from mmengine.dataset import DefaultSampler
5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
6 | LoggerHook, ParamSchedulerHook)
7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
8 | from peft import LoraConfig
9 | from torch.optim import AdamW
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 | from xtuner.dataset import process_hf_dataset
13 | from xtuner.dataset.collate_fns import default_collate_fn
14 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
15 | from xtuner.engine import DatasetInfoHook, EvaluateChatHook, VarlenAttnArgsToMessageHubHook
16 | from xtuner.engine.runner import TrainLoop
17 | from xtuner.model import SupervisedFinetune
18 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
19 |
20 | #######################################################################
21 | # PART 1 Settings #
22 | #######################################################################
23 | # Model
24 | pretrained_model_name_or_path = 'internlm/internlm-7b'
25 | use_varlen_attn = False
26 |
27 | # Data
28 | data_path = 'tatsu-lab/alpaca'
29 | prompt_template = PROMPT_TEMPLATE.default
30 | max_length = 2048
31 | pack_to_max_length = True
32 |
33 | # Scheduler & Optimizer
34 | batch_size = 1 # per_device
35 | accumulative_counts = 2
36 | dataloader_num_workers = 0
37 | max_epochs = 3
38 | optim_type = AdamW
39 | lr = 2e-4
40 | betas = (0.9, 0.999)
41 | weight_decay = 0
42 | max_norm = 1 # grad clip
43 | warmup_ratio = 0.03
44 |
45 | # Save
46 | save_steps = 200
47 | save_total_limit = 5 # Maximum checkpoints to keep (-1 means unlimited)
48 |
49 | # Evaluate the generation performance during the training
50 | evaluation_freq = 500
51 | SYSTEM = SYSTEM_TEMPLATE.alpaca
52 | evaluation_inputs = [
53 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
54 | ]
55 |
56 | #######################################################################
57 | # PART 2 Model & Tokenizer #
58 | #######################################################################
59 | tokenizer = dict(
60 | type=AutoTokenizer.from_pretrained,
61 | pretrained_model_name_or_path=pretrained_model_name_or_path,
62 | trust_remote_code=True,
63 | padding_side='right')
64 |
65 | model = dict(
66 | type=SupervisedFinetune,
67 | use_varlen_attn=use_varlen_attn,
68 | llm=dict(
69 | type=AutoModelForCausalLM.from_pretrained,
70 | pretrained_model_name_or_path=pretrained_model_name_or_path,
71 | trust_remote_code=True,
72 | torch_dtype=torch.float16),
73 | lora=dict(
74 | type=LoraConfig,
75 | r=64,
76 | lora_alpha=16,
77 | lora_dropout=0.1,
78 | bias='none',
79 | task_type='CAUSAL_LM'))
80 |
81 | #######################################################################
82 | # PART 3 Dataset & Dataloader #
83 | #######################################################################
84 | train_dataset = dict(
85 | type=process_hf_dataset,
86 | dataset=dict(type=load_dataset, path=data_path),
87 | tokenizer=tokenizer,
88 | max_length=max_length,
89 | dataset_map_fn=alpaca_map_fn,
90 | template_map_fn=dict(
91 | type=template_map_fn_factory, template=prompt_template),
92 | remove_unused_columns=True,
93 | shuffle_before_pack=True,
94 | pack_to_max_length=pack_to_max_length,
95 | use_varlen_attn=use_varlen_attn)
96 |
97 | train_dataloader = dict(
98 | batch_size=batch_size,
99 | num_workers=dataloader_num_workers,
100 | dataset=train_dataset,
101 | sampler=dict(type=DefaultSampler, shuffle=True),
102 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
103 |
104 | #######################################################################
105 | # PART 4 Scheduler & Optimizer #
106 | #######################################################################
107 | # optimizer
108 | optim_wrapper = dict(
109 | type=AmpOptimWrapper,
110 | optimizer=dict(
111 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
112 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
113 | accumulative_counts=accumulative_counts,
114 | loss_scale='dynamic',
115 | dtype='float16')
116 |
117 | # learning policy
118 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
119 | param_scheduler = [
120 | dict(
121 | type=LinearLR,
122 | start_factor=1e-5,
123 | by_epoch=True,
124 | begin=0,
125 | end=warmup_ratio * max_epochs,
126 | convert_to_iter_based=True),
127 | dict(
128 | type=CosineAnnealingLR,
129 | eta_min=0.0,
130 | by_epoch=True,
131 | begin=warmup_ratio * max_epochs,
132 | end=max_epochs,
133 | convert_to_iter_based=True)
134 | ]
135 |
136 | # train, val, test setting
137 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
138 |
139 | #######################################################################
140 | # PART 5 Runtime #
141 | #######################################################################
142 | # Log the dialogue periodically during the training process, optional
143 | custom_hooks = [
144 | dict(type=DatasetInfoHook, tokenizer=tokenizer),
145 | dict(
146 | type=EvaluateChatHook,
147 | tokenizer=tokenizer,
148 | every_n_iters=evaluation_freq,
149 | evaluation_inputs=evaluation_inputs,
150 | system=SYSTEM,
151 | prompt_template=prompt_template)
152 | ]
153 |
154 | if use_varlen_attn:
155 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
156 |
157 | # configure default hooks
158 | default_hooks = dict(
159 | # record the time of every iteration.
160 | timer=dict(type=IterTimerHook),
161 | # print log every 10 iterations.
162 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
163 | # enable the parameter scheduler.
164 | param_scheduler=dict(type=ParamSchedulerHook),
165 | # save checkpoint per `save_steps`.
166 | checkpoint=dict(
167 | type=CheckpointHook,
168 | by_epoch=False,
169 | interval=save_steps,
170 | max_keep_ckpts=save_total_limit),
171 |     # set sampler seed in distributed environment.
172 | sampler_seed=dict(type=DistSamplerSeedHook),
173 | )
174 |
175 | # configure environment
176 | env_cfg = dict(
177 | # whether to enable cudnn benchmark
178 | cudnn_benchmark=False,
179 | # set multi process parameters
180 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
181 | # set distributed parameters
182 | dist_cfg=dict(backend='nccl'),
183 | )
184 |
185 | # set visualizer
186 | visualizer = None
187 |
188 | # set log level
189 | log_level = 'INFO'
190 |
191 | # load from which checkpoint
192 | load_from = None
193 |
194 | # whether to resume training from the loaded checkpoint
195 | resume = False
196 |
197 | # Defaults to use random seed and disable `deterministic`
198 | randomness = dict(seed=None, deterministic=False)
199 |
200 | # set log processor
201 | log_processor = dict(by_epoch=False)
202 |
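上面 lora.py 模板中的 `lora=dict(type=LoraConfig, ...)` 字段采用 mmengine 的懒加载写法,真正的适配器配置对象由 XTuner 在构建模型时才实例化。为便于理解各参数的含义,下面给出一段大致等价的、直接用 peft 构造该配置的示意代码(仅作说明,并非 XTuner 内部的实际调用流程):

```python
# 示意:与上面 lora 字段大致等价的 peft 写法
from peft import LoraConfig

lora_cfg = LoraConfig(
    r=64,             # 低秩分解的秩,秩越大可训练参数越多
    lora_alpha=16,    # 缩放系数,实际缩放约为 lora_alpha / r
    lora_dropout=0.1, # LoRA 分支上的 dropout
    bias='none',      # 不训练 bias 参数
    task_type='CAUSAL_LM',
)
print(lora_cfg)
```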
--------------------------------------------------------------------------------
/template_configs/qlora.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from datasets import load_dataset
4 | from mmengine.dataset import DefaultSampler
5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
6 | LoggerHook, ParamSchedulerHook)
7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
8 | from peft import LoraConfig
9 | from torch.optim import AdamW
10 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
11 | BitsAndBytesConfig)
12 |
13 | from xtuner.dataset import process_hf_dataset
14 | from xtuner.dataset.collate_fns import default_collate_fn
15 | from xtuner.dataset.map_fns import alpaca_map_fn, template_map_fn_factory
16 | from xtuner.engine.hooks import (DatasetInfoHook, EvaluateChatHook,
17 | VarlenAttnArgsToMessageHubHook)
18 | from xtuner.engine.runner import TrainLoop
19 | from xtuner.model import SupervisedFinetune
20 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
21 |
22 | #######################################################################
23 | # PART 1 Settings #
24 | #######################################################################
25 | # Model
26 | pretrained_model_name_or_path = 'internlm/internlm-7b'
27 | use_varlen_attn = False
28 |
29 | # Data
30 | alpaca_en_path = 'tatsu-lab/alpaca'
31 | prompt_template = PROMPT_TEMPLATE.default
32 | max_length = 2048
33 | pack_to_max_length = True
34 |
35 | # Scheduler & Optimizer
36 | batch_size = 1 # per_device
37 | accumulative_counts = 2
38 | dataloader_num_workers = 0
39 | max_epochs = 3
40 | optim_type = AdamW
41 | lr = 2e-4
42 | betas = (0.9, 0.999)
43 | weight_decay = 0
44 | max_norm = 1 # grad clip
45 | warmup_ratio = 0.03
46 |
47 | # Save
48 | save_steps = 200
49 | save_total_limit = 5 # Maximum checkpoints to keep (-1 means unlimited)
50 |
51 | # Evaluate the generation performance during the training
52 | evaluation_freq = 500
53 | SYSTEM = SYSTEM_TEMPLATE.alpaca
54 | evaluation_inputs = [
55 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
56 | ]
57 |
58 | #######################################################################
59 | # PART 2 Model & Tokenizer #
60 | #######################################################################
61 | tokenizer = dict(
62 | type=AutoTokenizer.from_pretrained,
63 | pretrained_model_name_or_path=pretrained_model_name_or_path,
64 | trust_remote_code=True,
65 | padding_side='right')
66 |
67 | model = dict(
68 | type=SupervisedFinetune,
69 | use_varlen_attn=use_varlen_attn,
70 | llm=dict(
71 | type=AutoModelForCausalLM.from_pretrained,
72 | pretrained_model_name_or_path=pretrained_model_name_or_path,
73 | trust_remote_code=True,
74 | torch_dtype=torch.float16,
75 | quantization_config=dict(
76 | type=BitsAndBytesConfig,
77 | load_in_4bit=True,
78 | load_in_8bit=False,
79 | llm_int8_threshold=6.0,
80 | llm_int8_has_fp16_weight=False,
81 | bnb_4bit_compute_dtype=torch.float16,
82 | bnb_4bit_use_double_quant=True,
83 | bnb_4bit_quant_type='nf4')),
84 | lora=dict(
85 | type=LoraConfig,
86 | r=64,
87 | lora_alpha=16,
88 | lora_dropout=0.1,
89 | bias='none',
90 | task_type='CAUSAL_LM'))
91 |
92 | #######################################################################
93 | # PART 3 Dataset & Dataloader #
94 | #######################################################################
95 | alpaca_en = dict(
96 | type=process_hf_dataset,
97 | dataset=dict(type=load_dataset, path=alpaca_en_path),
98 | tokenizer=tokenizer,
99 | max_length=max_length,
100 | dataset_map_fn=alpaca_map_fn,
101 | template_map_fn=dict(
102 | type=template_map_fn_factory, template=prompt_template),
103 | remove_unused_columns=True,
104 | shuffle_before_pack=True,
105 | pack_to_max_length=pack_to_max_length,
106 | use_varlen_attn=use_varlen_attn)
107 |
108 | train_dataloader = dict(
109 | batch_size=batch_size,
110 | num_workers=dataloader_num_workers,
111 | dataset=alpaca_en,
112 | sampler=dict(type=DefaultSampler, shuffle=True),
113 | collate_fn=dict(type=default_collate_fn, use_varlen_attn=use_varlen_attn))
114 |
115 | #######################################################################
116 | # PART 4 Scheduler & Optimizer #
117 | #######################################################################
118 | # optimizer
119 | optim_wrapper = dict(
120 | type=AmpOptimWrapper,
121 | optimizer=dict(
122 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
123 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
124 | accumulative_counts=accumulative_counts,
125 | loss_scale='dynamic',
126 | dtype='float16')
127 |
128 | # learning policy
129 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
130 | param_scheduler = [
131 | dict(
132 | type=LinearLR,
133 | start_factor=1e-5,
134 | by_epoch=True,
135 | begin=0,
136 | end=warmup_ratio * max_epochs,
137 | convert_to_iter_based=True),
138 | dict(
139 | type=CosineAnnealingLR,
140 | eta_min=0.0,
141 | by_epoch=True,
142 | begin=warmup_ratio * max_epochs,
143 | end=max_epochs,
144 | convert_to_iter_based=True)
145 | ]
146 |
147 | # train, val, test setting
148 | train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
149 |
150 | #######################################################################
151 | # PART 5 Runtime #
152 | #######################################################################
153 | # Log the dialogue periodically during the training process, optional
154 | custom_hooks = [
155 | dict(type=DatasetInfoHook, tokenizer=tokenizer),
156 | dict(
157 | type=EvaluateChatHook,
158 | tokenizer=tokenizer,
159 | every_n_iters=evaluation_freq,
160 | evaluation_inputs=evaluation_inputs,
161 | system=SYSTEM,
162 | prompt_template=prompt_template)
163 | ]
164 |
165 | if use_varlen_attn:
166 | custom_hooks += [dict(type=VarlenAttnArgsToMessageHubHook)]
167 |
168 | # configure default hooks
169 | default_hooks = dict(
170 | # record the time of every iteration.
171 | timer=dict(type=IterTimerHook),
172 | # print log every 10 iterations.
173 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
174 | # enable the parameter scheduler.
175 | param_scheduler=dict(type=ParamSchedulerHook),
176 | # save checkpoint per `save_steps`.
177 | checkpoint=dict(
178 | type=CheckpointHook,
179 | by_epoch=False,
180 | interval=save_steps,
181 | max_keep_ckpts=save_total_limit),
182 |     # set sampler seed in distributed environment.
183 | sampler_seed=dict(type=DistSamplerSeedHook),
184 | )
185 |
186 | # configure environment
187 | env_cfg = dict(
188 | # whether to enable cudnn benchmark
189 | cudnn_benchmark=False,
190 | # set multi process parameters
191 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
192 | # set distributed parameters
193 | dist_cfg=dict(backend='nccl'),
194 | )
195 |
196 | # set visualizer
197 | visualizer = None
198 |
199 | # set log level
200 | log_level = 'INFO'
201 |
202 | # load from which checkpoint
203 | load_from = None
204 |
205 | # whether to resume training from the loaded checkpoint
206 | resume = False
207 |
208 | # Defaults to use random seed and disable `deterministic`
209 | randomness = dict(seed=None, deterministic=False)
210 |
211 | # set log processor
212 | log_processor = dict(by_epoch=False)
213 |
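qlora.py 与 lora.py 的核心区别在于 `quantization_config` 字段:基座模型以 4bit NF4 量化方式加载,只有 LoRA 适配器以半精度参与训练,因此显存占用显著降低。下面是一段示意代码(假设已安装 bitsandbytes 且有可用的 CUDA 环境),展示该字段对应的模型加载方式:

```python
# 示意:上面 quantization_config 字段对应的 4bit 量化加载
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,                     # 以 4bit 精度加载基座权重
    bnb_4bit_quant_type='nf4',             # NF4 量化格式
    bnb_4bit_use_double_quant=True,        # 二次量化,进一步节省显存
    bnb_4bit_compute_dtype=torch.float16,  # 前向计算使用 float16
)
model = AutoModelForCausalLM.from_pretrained(
    'internlm/internlm-7b',
    quantization_config=bnb_cfg,
    torch_dtype=torch.float16,
    trust_remote_code=True,
)
```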
--------------------------------------------------------------------------------
/xtuner_config/build_config.py:
--------------------------------------------------------------------------------
1 | from mmengine import Config, ConfigDict
2 | from mmengine.config.lazy import LazyObject
3 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
4 | import torch
5 | import os
6 | from .get_prompt_template import get_prompt_template
7 | CUR_DIR = os.path.dirname(__file__)
8 | TEMPLATE_DIR = os.path.join(os.path.dirname(CUR_DIR), "template_configs")
9 |
10 |
11 | MODEL_TO_TEMPLATE = {
12 | "baichuan-inc/Baichuan-7B": "default",
13 | "baichuan-inc/Baichuan-13B-Base": "default",
14 | "baichuan-inc/Baichuan-13B-Chat": "baichuan_chat",
15 | "baichuan-inc/Baichuan2-7B-Base": "default",
16 | "baichuan-inc/Baichuan2-7B-Chat": "baichuan2_chat",
17 | "baichuan-inc/Baichuan2-13B-Base": "default",
18 | "baichuan-inc/Baichuan2-13B-Chat": "baichuan2_chat",
19 | "THUDM/chatglm2-6b": "chatglm2",
20 | "THUDM/chatglm3-6b": "chatglm3",
21 | "THUDM/chatglm3-6b-base": "chatglm3",
22 | "deepseek-ai/deepseek-coder-6.7b-base": "deepseek_coder",
23 | "deepseek-ai/deepseek-coder-6.7b-instruct": "deepseek_coder",
24 | "internlm/internlm-7b": "default",
25 | "internlm/internlm-20b": "default",
26 | "internlm/internlm-chat-7b": "internlm_chat",
27 | "internlm/internlm-chat-20b": "internlm_chat",
28 | "huggyllama/llama-7b": "default",
29 | "meta-llama/Llama-2-7b-hf": "llama2_chat",
30 | "meta-llama/Llama-2-7b": "llama2_chat",
31 | "meta-llama/Llama-2-7b-chat-hf": "llama2_chat",
32 | "meta-llama/Llama-2-7b-chat": "llama2_chat",
33 | "meta-llama/Llama-2-70b-hf": "llama2_chat",
34 | "lmsys/vicuna-7b-v1.5": "vicuna",
35 | "lmsys/vicuna-13b-v1.5": "vicuna",
36 | "mistralai/Mistral-7B-v0.1": "mistral",
37 | "mistralai/Mixtral-8x7B-v0.1": "mixtral",
38 | "mistralai/Mixtral-8x7B-Instruct-v0.1": "mixtral",
39 | "Qwen/Qwen-1_8B": "default",
40 | "Qwen/Qwen-1_8B-Chat": "qwen_chat",
41 | "Qwen/Qwen-7B": "default",
42 | "Qwen/Qwen-7B-Chat": "qwen_chat",
43 | "Qwen/Qwen-72B": "default",
44 | "Qwen/Qwen-72B-Chat": "qwen_chat",
45 | "bigcode/starcoder": "default",
46 | "01-ai/Yi-6B": "default",
47 | "01-ai/Yi-34B": "default",
48 | "HuggingFaceH4/zephyr-7b-beta": "zephyr",
49 | "deepseek-ai/deepseek-moe-16b-base": "deepseek_moe",
50 | "deepseek-ai/deepseek-moe-16b-chat": "deepseek_moe",
51 | "internlm/internlm2-7b": "default",
52 | "internlm/internlm2-20b": "default",
53 | "internlm/internlm2-chat-7b": "internlm2_chat",
54 | "internlm/internlm2-chat-20b": "internlm2_chat"
55 | }
56 |
57 | DATA2MAPFN = {
58 | 'tatsu-lab/alpaca': 'alpaca_map_fn',
59 | 'silk-road/alpaca-data-gpt4-chinese': 'alpaca_zh_map_fn',
60 | 'garage-bAInd/Open-Platypus': 'alpaca_map_fn',
61 | 'HuggingFaceH4/CodeAlpaca_20K': 'code_alpaca_map_fn',
62 | 'burkelibbey/colors': 'colors_map_fn',
63 | 'shibing624/medical': 'medical_map_fn',
64 | 'damo/MSAgent-Bench': 'msagent_react_map_fn',
65 | 'timdettmers/openassistant-guanaco': 'oasst1_map_fn',
66 | 'Open-Orca/OpenOrca': 'openorca_map_fn',
67 | 'Skywork/SkyPile-150B': 'pretrain_map_fn',
68 | 'mistralai/Mistral-7B-v0.1': 'pretrain_map_fn',
69 | 'b-mc2/sql-create-context': 'sql_map_fn',
70 | 'ArmelR/stack-exchange-instruction': 'stack_exchange_map_fn',
71 | 'nampdn-ai/tiny-codes': 'tiny_codes_map_fn',
72 | 'WizardLM/WizardLM_evol_instruct_V2_196k': 'wizardlm_map_fn',
73 | }
74 |
75 | def data_path_map_fn(file):
76 | if file in DATA2MAPFN:
77 | return DATA2MAPFN[file]
78 | for k, v in DATA2MAPFN.items():
79 | k_list = k.split('/')
80 | k_fix = '_'.join(k_list)
81 | if k_fix in file:
82 | return v
83 | return None
84 |
85 | def model_path_map_fn(file):
86 | print(f'model_path_map_fn({file})')
87 | if file in MODEL_TO_TEMPLATE:
88 | return MODEL_TO_TEMPLATE[file]
89 | for k, v in MODEL_TO_TEMPLATE.items():
90 | k_list = k.split('/')
91 | k_fix = '_'.join(k_list)
92 | if k_fix in file:
93 | return v
94 | return None
95 |
96 | """
97 | save_checkpoint_ratio -> save_checkpoint_interval
98 | accumulative_counts -> accumulative_counts
99 | 新增 save_total_limit
100 | 'bigcode/starcoder' 不是DATA_LIST
101 | """
102 | def traverse_keys(cfg_dict, target_keys, new_value):
103 | if isinstance(cfg_dict, dict):
104 | for key, value in dict.items(cfg_dict):
105 | if key in target_keys:
106 | cfg_dict[key] = new_value
107 | else:
108 | traverse_keys(value, target_keys, new_value)
109 | elif isinstance(cfg_dict, (list, tuple)):
110 | for value in cfg_dict:
111 | traverse_keys(value, target_keys, new_value)
112 |
113 | def traverse_value(cfg_dict, target_value, new_value):
114 | if isinstance(cfg_dict, dict):
115 | for key, value in dict.items(cfg_dict):
116 | if value == target_value:
117 | cfg_dict[key] = new_value
118 | else:
119 | traverse_value(value, target_value, new_value)
120 | elif isinstance(cfg_dict, (list, tuple)):
121 | for value in cfg_dict:
122 | traverse_value(value, target_value, new_value)
123 |
124 |
125 | def set_model_related(cfg, model_path):
126 | traverse_keys(cfg._cfg_dict, ('pretrained_model_name_or_path', ), model_path)
127 |
128 |
129 | def set_data_related(cfg, dataset, is_custom_dataset, prompt_template, max_length, pack_to_max_length):
130 | if is_custom_dataset:
131 | dataset = ConfigDict(path='json', data_files=dataset)
132 | cfg.alpaca_en.dataset.update(dataset)
133 | cfg.train_dataloader.dataset.dataset.update(dataset)
134 |
135 | traverse_keys(cfg._cfg_dict, ('dataset_map_fn', ), LazyObject('xtuner.dataset.map_fns', 'openai_map_fn'))
136 | else:
137 | traverse_value(cfg._cfg_dict, 'tatsu-lab/alpaca', dataset)
138 |
139 | traverse_keys(cfg._cfg_dict, ('dataset_map_fn', ), LazyObject('xtuner.dataset.map_fns', data_path_map_fn(dataset)))
140 |
141 | assert prompt_template in PROMPT_TEMPLATE, \
142 | f'Expect prompt_template to be one of {PROMPT_TEMPLATE.keys()}, but got {prompt_template}.'
143 | prompt_template = PROMPT_TEMPLATE[prompt_template]
144 | traverse_keys(cfg._cfg_dict, ('template', 'prompt_template'), prompt_template)
145 |
146 | traverse_keys(cfg._cfg_dict, ('max_length', ), max_length)
147 |
148 | traverse_keys(cfg._cfg_dict, ('pack_to_max_length', ), pack_to_max_length)
149 |
150 |
151 | def set_scheduler_optimizer_related(
152 | cfg, batch_size_per_device, accumulative_counts, dataloader_num_workers,
153 | max_epochs, optim_type, lr, beta1, beta2, weight_decay, max_norm, warmup_ratio):
154 | traverse_keys(cfg._cfg_dict, ('batch_size', ), batch_size_per_device)
155 | traverse_keys(cfg._cfg_dict, ('accumulative_counts', ), accumulative_counts)
156 | traverse_keys(cfg._cfg_dict, ('dataloader_num_workers', 'num_workers'), dataloader_num_workers)
157 |
158 | traverse_keys(cfg._cfg_dict, ('max_epochs', 'T_max'), max_epochs)
159 | cfg.param_scheduler[0].end = warmup_ratio * max_epochs
160 | cfg.param_scheduler[1].begin = warmup_ratio * max_epochs
161 | cfg.warmup_ratio = warmup_ratio
162 |
163 | assert hasattr(torch.optim, optim_type)
164 | cfg.optim_type = LazyObject('torch.optim', optim_type)
165 | cfg.optim_wrapper.optimizer.type = LazyObject('torch.optim', optim_type)
166 |
167 | cfg.lr = lr
168 | cfg.optim_wrapper.optimizer.lr = lr
169 |
170 | if optim_type == 'AdamW':
171 | traverse_keys(cfg._cfg_dict, ('betas', ), (beta1, beta2))
172 |
173 | traverse_keys(cfg._cfg_dict, ('weight_decay', ), weight_decay)
174 | traverse_keys(cfg._cfg_dict, ('max_norm', ), max_norm)
175 |
176 |
177 | def set_checkpoint_related(cfg, save_checkpoint_interval, save_total_limit):
178 | cfg.save_steps = save_checkpoint_interval
179 | cfg.default_hooks.checkpoint.interval = save_checkpoint_interval
180 |
181 | cfg.save_total_limit = save_total_limit
182 | cfg.default_hooks.checkpoint.max_keep_ckpts = save_total_limit
183 |
184 |
185 | def set_evaluate_related(cfg, evaluation_freq, evaluation_system_prompt, evaluation_inputs):
186 | traverse_keys(cfg._cfg_dict, ('evaluation_freq', 'every_n_iters'), evaluation_freq)
187 |
188 | system_prompt = SYSTEM_TEMPLATE[evaluation_system_prompt] if evaluation_system_prompt else ''
189 | traverse_keys(cfg._cfg_dict, ('SYSTEM', 'system'), system_prompt)
190 |
191 | # evaluation_inputs = [evaluation_input1, evaluation_input2]
192 | traverse_keys(cfg._cfg_dict, ('evaluation_inputs', ), evaluation_inputs)
193 |
194 |
195 | def build_config(
196 | ft_method, model_path, dataset, is_custom_dataset, deepspeed, lr, warmup_ratio, batch_size_per_device,
197 | accumulative_counts, num_GPU, max_length, pack_to_max_length, max_epochs, save_checkpoint_interval, save_total_limit,
198 | evaluation_freq, evaluation_system_prompt, evaluation_inputs,
199 | optim_type, weight_decay, max_norm, dataloader_num_workers, beta1, beta2,
200 | prompt_template):
201 | if ft_method == 'full':
202 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/full_finetune.py')
203 | elif ft_method == 'lora':
204 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/lora.py')
205 | elif ft_method == 'qlora':
206 | cfg = Config.fromfile(f'{TEMPLATE_DIR}/qlora.py')
207 | else:
208 | raise NotImplementedError(f'Expect ft_method to be one of (full, lora, qlora), but got {ft_method}.')
209 |
210 | set_model_related(cfg, model_path)
211 | set_data_related(cfg, dataset, is_custom_dataset, prompt_template, max_length, pack_to_max_length)
212 | set_scheduler_optimizer_related(cfg, batch_size_per_device, accumulative_counts, dataloader_num_workers,
213 | max_epochs, optim_type, lr, beta1, beta2, weight_decay, max_norm, warmup_ratio)
214 | set_checkpoint_related(cfg, save_checkpoint_interval, save_total_limit)
215 | set_evaluate_related(cfg, evaluation_freq, evaluation_system_prompt, evaluation_inputs)
216 |
217 | return cfg
218 |
219 |
220 | kwargs = dict(
221 | ft_method='full',
222 | model_path='/mnt/petrelfs/share_data/caoweihan/official_Ampere_7B_1_0_0',
223 | dataset='timdettmers/openassistant-guanaco',
224 | is_custom_dataset=False,
225 | deepspeed=None, # 与生成config无关
226 | lr=2e-5,
227 | warmup_ratio=0.03,
228 | batch_size_per_device=1,
229 | accumulative_counts=2,
230 | num_GPU=None, # 与生成config无关
231 | max_length=2048,
232 | pack_to_max_length=True,
233 | max_epochs=2,
234 | save_checkpoint_interval=1000,
235 | save_total_limit=2,
236 | evaluation_freq=100,
237 | evaluation_system_prompt='',
238 | evaluation_inputs=['请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'],
239 | optim_type='AdamW',
240 | weight_decay=0,
241 | max_norm=1,
242 | dataloader_num_workers=0,
243 | beta1=0.9,
244 | beta2=0.999,
245 | prompt_template='internlm2_chat'
246 | )
247 |
248 | int_args = [
249 | 'batch_size_per_device',
250 | 'accumulative_counts',
251 | 'num_GPU',
252 | 'max_length',
253 | 'pack_to_max_length',
254 | 'max_epochs',
255 | 'save_checkpoint_interval',
256 | 'save_total_limit',
257 | 'evaluation_freq',
258 | 'dataloader_num_workers',
259 | ]
260 | default_args_key = [
261 | 'ft_method',
262 | 'model_path',
263 | 'dataset',
264 | 'deepspeed',
265 | 'lr',
266 | 'warmup_ratio',
267 | 'batch_size_per_device',
268 | 'accumulative_counts',
269 | 'num_GPU',
270 | 'max_length',
271 | 'pack_to_max_length',
272 | 'max_epochs',
273 | 'save_checkpoint_interval',
274 | 'save_total_limit',
275 | 'evaluation_freq',
276 | 'evaluation_system_prompt',
277 | 'optim_type',
278 | 'weight_decay',
279 | 'max_norm',
280 | 'dataloader_num_workers',
281 | 'beta1',
282 | 'beta2',
283 | 'prompt_template',
284 | ]
285 |
286 | def build_config_path(root_dir):
287 | work_dir = os.path.join(root_dir, 'work_dir')
288 | if not os.path.exists(work_dir):
289 | os.system(f'mkdir -p {work_dir}')
290 | return os.path.join(work_dir, 'xtuner_config.py')
291 |
292 |
293 | def build_and_save_config(
294 | dataset_personal_path,
295 | dataset_personal,
296 | model_personal_path,
297 | personal_model,
298 | detect_prompt_template,
299 | root_dir,
300 | *args, **kwargs
301 | ):
302 | kwargs.update(
303 | dict(zip(default_args_key, list(args)))
304 | )
305 | # prepare 'evaluation_inputs'
306 | evaluation_inputs = list(args)[len(default_args_key):]
307 | kwargs['evaluation_inputs'] = [i for i in evaluation_inputs if len(i)]
308 | print(f'dataset_personal_path={dataset_personal_path}||')
309 | # float -> int
310 | for k in int_args:
311 | kwargs[k] = int(kwargs[k])
312 | # custom dataset
313 | kwargs['is_custom_dataset'] = False
314 | # dataset_personal_path > dataset_personal > dataset
315 | if dataset_personal is not None and len(dataset_personal) >= 3:
316 | kwargs['is_custom_dataset'] = True
317 | kwargs['dataset'] = dataset_personal
318 | if dataset_personal_path is not None and len(dataset_personal_path) >= 3:
319 | kwargs['is_custom_dataset'] = True
320 | kwargs['dataset'] = dataset_personal_path
321 |
322 | # dropdown-list prompt_template
323 | prompt_template = model_path_map_fn(kwargs['model_path'])
324 | if personal_model is not None and len(personal_model) >= 3:
325 | kwargs['model_path'] = personal_model
326 | prompt_template = detect_prompt_template
327 |
328 | if model_personal_path is not None and len(model_personal_path) >= 3:
329 | kwargs['model_path'] = model_personal_path
330 | prompt_template = detect_prompt_template
331 |
332 | # final prompt_template
333 | kwargs['prompt_template'] = prompt_template
334 | if kwargs['prompt_template'] is None:
335 | kwargs['prompt_template'] = detect_prompt_template
336 | print(f'kwargs={kwargs}')
337 | cfg = build_config(**kwargs)
338 | cfg_py = build_config_path(root_dir)
339 | cfg.dump(cfg_py)
340 | print('cfg_py=', cfg_py)
341 | return cfg_py
342 |
343 |
344 | if __name__ == '__main__':
345 |     # 简单自测:直接用上面的默认 kwargs 生成配置并保存到 ./work_dir/xtuner_config.py
346 |     build_config(**kwargs).dump(build_config_path('.'))
--------------------------------------------------------------------------------
/xtuner_config/check_custom_dataset.py:
--------------------------------------------------------------------------------
1 | from datasets import load_dataset
2 | import pprint
3 |
4 |
5 | DATA_EXAMPLE = """example["messages"] = [
6 | { "role": "system", "content": "You are an assistant that
7 | occasionally misspells words." },
8 | { "role": "user", "content": "Tell me a story." },
9 | { "role": "assistant", "content": "One day a student
10 | went to schoool." }]"""
11 |
12 | def check_custom_dataset(input_path, upload_path):
13 | path = input_path if len(input_path) >= 3 else upload_path
14 | try:
15 | data = load_dataset('json', data_files=path)
16 |     except Exception:
17 | return f"There's a problem with the JSON file in {path}; it can't be read."
18 | data = data['train']
19 |
20 | if 'messages' not in data.column_names:
21 | return ('Expect "messages" as a column in the dataset. Here is an '
22 | f'example:\n{DATA_EXAMPLE}')
23 |
24 | if not isinstance(data['messages'], (list, tuple)):
25 | return ('Expect the type of example["messages"] to be a list or '
26 | f'a tuple, but got {type(data["messages"])}.'
27 | f'Here is an example:\n{DATA_EXAMPLE}')
28 |
29 | check_first_n_messages = 100
30 | for message_idx, message in enumerate(data['messages'][:check_first_n_messages]):
31 | for conv_idx, single_conversation in enumerate(message):
32 | if not isinstance(single_conversation, dict):
33 | return ('Expect each single conversation to be a dict, '
34 | f'but got {type(single_conversation)}. '
35 | f'Here is an example:\n{DATA_EXAMPLE}')
36 | if not {'role', 'content'}.issubset(single_conversation.keys()):
37 | return ('Expect "role" and "content" in each single '
38 | f'conversation. The {conv_idx + 1} conversation in the'
39 | f' {message_idx} message is {single_conversation}.'
40 | f'Here is an example:\n{DATA_EXAMPLE}')
41 |
42 | return 'Data is OK.'
43 |
44 |
45 | if __name__ == "__main__":
46 |     out = check_custom_dataset('/mnt/petrelfs/caoweihan/projects/xtuner/data.json', '')
47 |     if out == 'Data is OK.':
48 |         print(out)
49 |     else:
50 |         pprint.pprint(out)
51 |
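为了直观说明该检查函数所要求的 OpenAI 格式,下面构造一个可以通过上述检查的最小数据集示例(文件名与对话内容均为虚构,仅供参考):

```python
# 示意:生成一个符合 OpenAI messages 格式的最小数据集文件
import json

sample = [{
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "请给我介绍五个上海的景点"},
        {"role": "assistant", "content": "外滩、豫园、东方明珠、田子坊、朱家角。"},
    ]
}]

with open("my_dataset.json", "w", encoding="utf-8") as f:
    json.dump(sample, f, ensure_ascii=False, indent=2)

# from xtuner_config.check_custom_dataset import check_custom_dataset
# print(check_custom_dataset("my_dataset.json", ""))  # 预期输出:Data is OK.
```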
--------------------------------------------------------------------------------
/xtuner_config/get_default_hyperparameters.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from collections import OrderedDict
3 |
4 |
5 | DEFAULT_HYPERPARAMETERS = OrderedDict(
6 | warmup_ratio=0.03,
7 | batch_size_per_device=1,
8 | accumulative_counts=2,
9 | num_GPU=8,
10 | max_length=2048,
11 | pack_to_max_length=True,
12 | evaluation_freq=500,
13 | optim_type='AdamW',
14 | weight_decay=0,
15 | max_norm=1,
16 | dataloader_num_workers=0,
17 | beta1=0.9,
18 | beta2=0.999,
19 | lr=2e-5,
20 | save_checkpoint_interval=500,
21 | save_total_limit=2
22 | )
23 |
24 |
25 | def get_default_hyperparameters(ft_method):
26 | out_dict = copy.deepcopy(DEFAULT_HYPERPARAMETERS)
27 | if ft_method.lower() == 'full':
28 | out_dict.update(dict(
29 | lr=2e-5,
30 | save_checkpoint_interval=500,
31 | save_total_limit=2))
32 | else:
33 | out_dict.update(dict(
34 | lr=2e-4,
35 | save_checkpoint_interval=200,
36 | save_total_limit=5))
37 | out_list = []
38 | for i in DEFAULT_HYPERPARAMETERS.keys():
39 | out_list.append(out_dict[i])
40 | return out_list
41 |
42 |
43 | if __name__ == '__main__':
44 | print(DEFAULT_HYPERPARAMETERS.keys())
45 | print(
46 | get_default_hyperparameters('full')
47 | )
48 | print(
49 | get_default_hyperparameters('lora')
50 | )
51 |
52 |
--------------------------------------------------------------------------------
/xtuner_config/get_prompt_template.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoConfig, AutoModelForCausalLM
2 | from accelerate import init_empty_weights
3 |
4 |
5 | MODEL_TEMPLATE_MAPPING = dict(
6 | InternLM2ForCausalLM='internlm2_chat',
7 | InternLMForCausalLM='internlm_chat',
8 | BaiChuanForCausalLM='baichuan_chat',
9 | BaichuanForCausalLM='baichuan2_chat',
10 | DeepseekForCausalLM='deepseek_moe',
11 | MixtralForCausalLM='mixtral',
12 | QWenLMHeadModel='qwen_chat',
13 | GPTBigCodeForCausalLM='default'
14 | )
15 |
16 |
17 | def get_prompt_template(pretrained_model_name_or_path):
18 | try:
19 | config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
20 | with init_empty_weights():
21 | model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
22 |     except Exception:
23 | return f'Model {pretrained_model_name_or_path} can not be loaded.', None
24 | model_type = type(model).__name__
25 | if model_type == 'LlamaForCausalLM':
26 | vocab_size = config.vocab_size
27 | if vocab_size == 32256:
28 | return 'Success', 'deepseek_coder'
29 | elif vocab_size == 64000: # yi
30 | return 'Success', 'default'
31 | elif vocab_size == 32000: # llama2
32 | return 'Success', 'llama2_chat'
33 | elif model_type == 'ChatGLMForConditionalGeneration':
34 | seq_length = config.seq_length
35 | if seq_length == 131072:
36 | return 'Success', 'chatglm3'
37 | elif seq_length == 32768:
38 | return 'Success', 'chatglm2'
39 | else:
40 | return 'Fail to match automatically, please enter corresponding prompt template manually', None
41 | elif model_type == 'MistralForCausalLM':
42 | # 无法判断
43 | return 'The prompt template should be one of mistral or zephyr, please enter the correct prompt template manually', None
44 | elif model_type in MODEL_TEMPLATE_MAPPING:
45 | return 'Success', MODEL_TEMPLATE_MAPPING[model_type]
46 | else:
47 | return 'Fail to match automatically, please enter corresponding prompt template manually', None
48 |
49 |
50 | def app_get_prompt_template(input_path, upload_path):
51 | print(f'app_get_prompt_template({input_path}, {upload_path})')
52 | pretrained_model_name_or_path = input_path if len(input_path) >= 3 else upload_path
53 | info, prompt_template = get_prompt_template(pretrained_model_name_or_path)
54 | return f'{info} >> {prompt_template}', prompt_template
55 |
56 |
57 | if __name__ == "__main__":
58 | print(get_prompt_template('/mnt/petrelfs/share_data/caoweihan/official_Ampere_7B_1_0_0'))
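该函数只借助 `init_empty_weights` 读取模型结构(不实际加载权重)即可推断提示词模板。下面是一个假设性的调用示例(模型名仅作演示,需要能够访问 Hugging Face Hub 或本地已缓存对应模型,且假设在仓库根目录下运行):

```python
# 示意:根据模型结构自动匹配提示词模板
from xtuner_config.get_prompt_template import get_prompt_template

info, template = get_prompt_template('internlm/internlm2-chat-7b')
print(info, template)  # 预期输出:Success internlm2_chat
```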
--------------------------------------------------------------------------------
/xtuner_convert/convert_and_merge.py:
--------------------------------------------------------------------------------
1 |
2 | # python3
3 | # Create Date: 2024-01-30
4 | # Author: 爱科研的瞌睡虫
5 |
6 | import os
7 | from .merge import merge
8 | from .pth_to_hf import convert_to_hf
9 |
10 | def _convert_and_merged(config_file, pth_model, save_hf_dir, model_path, save_merged_dir):
11 | convert_to_hf(config_file, pth_model, save_hf_dir)
12 | merge(model_path, save_hf_dir, save_merged_dir)
13 |
14 |
15 | def build_convert_and_merged_path(root_dir, epoch_pth):
16 | epoch = os.path.basename(epoch_pth).split('.')[0]
17 | work_dir = os.path.join(root_dir, 'work_dir')
18 | if not os.path.exists(work_dir):
19 | os.system(f'mkdir -p {work_dir}')
20 | hf = os.path.join(work_dir, f'xtuner_{epoch}_hf')
21 | mg = os.path.join(work_dir, f'xtuner_{epoch}_merge')
22 | # clear
23 | if os.path.exists(hf):
24 | os.system(f'rm -rf {hf}')
25 | if os.path.exists(mg):
26 | os.system(f'rm -rf {mg}')
27 | return work_dir, hf, mg
28 |
29 |
30 | def convert_and_merged(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method):
31 | if len(model_personal_path) >= 3:
32 | model_path = model_personal_path
33 | work_dir, save_hf_dir, save_merged_dir = build_convert_and_merged_path(root_dir, epoch_pth)
34 | pth_model = os.path.join(work_dir, epoch_pth)
35 | print(
36 | f'config_file = {config_file}'
37 | ,f'\npth_model = {pth_model}'
38 | ,f'\nsave_hf_dir = {save_hf_dir}'
39 | ,f'\nmodel_path ={model_path}'
40 | ,f'\nsave_merged_dir ={save_merged_dir}'
41 | )
42 | merged_flag = ft_method != 'full'
43 | try:
44 | convert_to_hf(config_file, pth_model, save_hf_dir)
45 | out_dir = save_hf_dir
46 | if merged_flag:
47 | merge(model_path, save_hf_dir, save_merged_dir)
48 | out_dir = save_merged_dir
49 |
50 | info = 'Successfully converted model ! '
51 | except Exception as e:
52 | info = e
53 |         out_dir = None  # 转换失败时置空,避免后续 out_dir 未定义
54 | return info, out_dir
55 |
56 |
57 | if __name__ == '__main__':
58 |
59 | config_file = '/root/ft-oasst1/internlm_chat_7b_qlora_oasst1_e3_copy.py'
60 | pth_model = '/root/ft-oasst1/work_dirs/internlm_chat_7b_qlora_oasst1_e3_copy/epoch_1.pth'
61 | save_hf_dir = '/root/ft-oasst1/hf5'
62 | model_path = '/root/ft-oasst1/internlm-chat-7b'
63 | save_merged_dir = '/root/ft-oasst1/merged5'
64 |
65 | _convert_and_merged(config_file, pth_model, save_hf_dir, model_path, save_merged_dir)
66 |
--------------------------------------------------------------------------------
/xtuner_convert/convert_with_progress.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | from .convert_and_merge import build_convert_and_merged_path, convert_to_hf, merge
4 | import re
5 | from tqdm.auto import tqdm
6 | import threading
7 | import time
8 | import gradio as gr
9 |
10 |
11 | class ConvertMerged:
12 | def __init__(self):
13 | self.save_hf_dir = None
14 | self.save_merged_dir = None
15 | self.merged_flag = True
16 | self.model_path = None
17 | self.out_dir = None
18 | self.info = None
19 |
20 | def convert_and_merged(self, root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method):
21 | if len(model_personal_path) >= 3:
22 | model_path = model_personal_path
23 |
24 | self.model_path = model_path
25 | work_dir, save_hf_dir, save_merged_dir = build_convert_and_merged_path(root_dir, epoch_pth)
26 | self.save_hf_dir = save_hf_dir
27 | self.save_merged_dir = save_merged_dir
28 | pth_model = os.path.join(work_dir, epoch_pth)
29 | print(
30 | f'config_file = {config_file}'
31 | ,f'\npth_model = {pth_model}'
32 | ,f'\nsave_hf_dir = {save_hf_dir}'
33 | ,f'\nmodel_path ={model_path}'
34 | ,f'\nsave_merged_dir ={save_merged_dir}'
35 | )
36 | self.merged_flag = ft_method.lower() != 'full'
37 | try:
38 | convert_to_hf(config_file, pth_model, save_hf_dir)
39 | self.out_dir = save_hf_dir
40 | if self.merged_flag:
41 | merge(model_path, save_hf_dir, save_merged_dir)
42 | self.out_dir = save_merged_dir
43 |
44 | self.info = 'Successfully converted model ! '
45 | except Exception as e:
46 | self.info = e
47 | pass
48 | return self.info, self.out_dir
49 |
50 | def auto_convert_merge(
51 | self,
52 | root_dir,
53 | config_file,
54 | epoch_pth,
55 | model_path,
56 | model_personal_path,
57 | ft_method,
58 | progress=gr.Progress(track_tqdm=True)
59 | ):
60 | self._t_convert(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method)
61 | time.sleep(2)
62 | print(
63 | f'self.model_path={self.model_path}',
64 | f'self.save_hf_dir={self.save_hf_dir}',
65 | f'self.save_merged_dir={self.save_merged_dir}'
66 | )
67 | self.progress()
68 | return self.info, self.out_dir
69 |
70 | def _t_convert(self, root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method):
71 | self._t_handle_convert = threading.Thread(
72 | target=self.convert_and_merged, args=(root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method) ,
73 | name='X-model-convert-merge', daemon=True)
74 | self._t_handle_convert.start()
75 |
76 | def find_max_sub(self, _dir):
77 | total_ = [i for i in os.listdir(_dir) if len(re.findall(r'[0-9]+-of-[0-9]+', i))]
78 | if len(total_):
79 | info = [int(re.findall(r'(\d+)-of', i)[0]) for i in total_ if len(re.findall(r'(\d+)-of', i))]
80 | if len(info):
81 | return max(info)
82 | return 0
83 |
84 | def progress(self, progress=None):
85 | big_step = 100
86 | total = 0
87 | hf_total = 1
88 | # /root/share/model_repos/internlm-chat-7b/pytorch_model-00001-of-00008.bin
89 | base_model_parts = [i for i in os.listdir(self.model_path) if len(re.findall(r'[0-9]+-of-[0-9]+', i))]
90 | base_max = 0
91 | if len(base_model_parts):
92 | fd = re.findall(r'-of-(\d+)', base_model_parts[0])
93 | print(f'progress fd => {fd}')
94 | base_max = int(fd[0]) if len(fd) else 0
95 | if self.merged_flag:
96 | total = hf_total + base_max
97 | else:
98 | total = hf_total = base_max
99 |
100 | hf_total *= big_step
101 | total *= big_step
102 | tq_bar = tqdm(total=total)
103 | big_step_hf_now = 0
104 | big_step_mg_now = 0
105 | hf_now = 0
106 | mg_now = 0
107 | while True:
108 | if self._t_handle_convert is None:
109 | break
110 | if not self._t_handle_convert.is_alive():
111 | break
112 |
113 | up_hf = 0
114 | if os.path.exists(self.save_hf_dir) and not self.merged_flag:
115 | max_hf_b = self.find_max_sub(self.save_hf_dir) * big_step
116 | # 在一个的时候
117 | if big_step_hf_now == max_hf_b and (big_step_hf_now + big_step) > hf_now and hf_now < hf_total:
118 | up_hf = 1
119 | hf_now += 1
120 | elif max_hf_b > hf_now:
121 | up_hf = max_hf_b - hf_now
122 | hf_now = max_hf_b
123 | else:
124 | up_hf = 0
125 |
126 | big_step_hf_now = max_hf_b
127 | elif self.merged_flag and not os.path.exists(self.save_hf_dir):
128 | if big_step >= hf_now:
129 | up_hf = 1
130 | hf_now += 1
131 | else:
132 | max_hf_b = big_step
133 | up_hf = max_hf_b - hf_now
134 | hf_now = max_hf_b
135 |
136 | up_mg = 0
137 | if self.merged_flag:
138 | if not os.path.exists(self.save_merged_dir):
139 | if big_step > mg_now:
140 | up_mg = 1
141 | mg_now += 1
142 | else:
143 | max_mg_b = self.find_max_sub(self.save_merged_dir) * big_step
144 | # 在一个的时候
145 | if big_step_mg_now == max_mg_b and (big_step_mg_now + big_step) > mg_now and (mg_now + hf_now) < total:
146 | up_mg = 1
147 | mg_now += 1
148 | elif max_mg_b > mg_now:
149 | up_mg = max_mg_b - mg_now
150 | mg_now = max_mg_b
151 | else:
152 | up_mg = 0
153 |
154 | big_step_mg_now = max_mg_b
155 |
156 | tq_bar.update(up_mg + up_hf)
157 | time.sleep(1)
158 |
159 |
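`ConvertMerged.auto_convert_merge` 通过 `gr.Progress(track_tqdm=True)` 把 `progress()` 中的 tqdm 进度同步到前端进度条。下面给出一段把该流程绑定到 Gradio 按钮的简化示意(组件名称与布局均为假设,并非 app.py 的实际实现):

```python
# 示意:将模型转换/合并流程接入 Gradio(需要开启 queue 才能显示 tqdm 进度)
import gradio as gr
from xtuner_convert.convert_with_progress import ConvertMerged

converter = ConvertMerged()

with gr.Blocks() as demo:
    root_dir = gr.Textbox(label='customer_path')
    config_file = gr.Textbox(label='config file')
    epoch_pth = gr.Textbox(label='checkpoint file, e.g. iter_500.pth')
    model_path = gr.Textbox(label='model path')
    model_personal_path = gr.Textbox(label='personal model path', value='')
    ft_method = gr.Dropdown(['qlora', 'lora', 'full'], value='qlora', label='ft method')
    info_box = gr.Textbox(label='convert info')
    out_dir_box = gr.Textbox(label='converted model dir')
    gr.Button('convert & merge').click(
        converter.auto_convert_merge,
        inputs=[root_dir, config_file, epoch_pth, model_path, model_personal_path, ft_method],
        outputs=[info_box, out_dir_box],
    )

demo.queue().launch()
```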
--------------------------------------------------------------------------------
/xtuner_convert/merge.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-30
3 | # Author: 爱科研的瞌睡虫
4 |
5 | import argparse
6 |
7 | import torch
8 | from peft import PeftModel
9 | from transformers import AutoModelForCausalLM, AutoTokenizer
10 |
11 |
12 | def parse_args():
13 | parser = argparse.ArgumentParser(
14 | description='Merge a HuggingFace adapter to LLM')
15 | # parser.add_argument('model_name_or_path', help='model name or path')
16 | # parser.add_argument('adapter_name_or_path', help='adapter name or path')
17 | # parser.add_argument(
18 | # 'save_dir', help='the directory to save the merged model')
19 | parser.add_argument(
20 | '--max-shard-size',
21 | type=str,
22 | default='2GB',
23 | help='Only applicable for LLM. The maximum size for '
24 | 'each sharded checkpoint.')
25 | parser.add_argument(
26 | '--offload-folder',
27 | default=None,
28 | help='The folder in which to offload the model weights (or where '
29 | 'the model weights are already offloaded).')
30 | args = parser.parse_args()
31 | return args
32 |
33 |
34 | def merge(model_path, adapter_hf_path, save_dir):
35 | args = parse_args()
36 | model = AutoModelForCausalLM.from_pretrained(
37 | model_path,
38 | torch_dtype=torch.float16,
39 | low_cpu_mem_usage=True,
40 | device_map='auto',
41 | offload_folder=args.offload_folder,
42 | trust_remote_code=True)
43 | tokenizer = AutoTokenizer.from_pretrained(
44 | model_path,
45 | trust_remote_code=True,
46 | encode_special_tokens=True)
47 | model_unmerged = PeftModel.from_pretrained(
48 | model,
49 | adapter_hf_path,
50 | device_map='auto',
51 | torch_dtype=torch.float16,
52 | offload_folder=args.offload_folder,
53 | is_trainable=False)
54 | model_merged = model_unmerged.merge_and_unload()
55 | print(f'Merged Saving to {save_dir}...')
56 | model_merged.save_pretrained(
57 | save_dir, max_shard_size=args.max_shard_size)
58 | tokenizer.save_pretrained(save_dir)
59 | print('Merged All done!')
60 |
61 |
62 | if __name__ == '__main__':
63 |
64 | model_path = '/root/ft-oasst1/internlm-chat-7b'
65 | adapter_hf_path = '/root/ft-oasst1/hf4'
66 | save_dir = '/root/ft-oasst1/merged5'
67 |
68 | merge(model_path, adapter_hf_path, save_dir)
--------------------------------------------------------------------------------
/xtuner_convert/pth_to_hf.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-30
3 | # Author: 爱科研的瞌睡虫
4 |
5 | import argparse
6 | import os
7 | import shutil
8 |
9 | import torch
10 | from mmengine.config import Config, DictAction
11 |
12 | from xtuner.configs import cfgs_name_path
13 | from xtuner.registry import BUILDER
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(
18 | description='Convert the pth model to HuggingFace model')
19 | # parser.add_argument('config', help='config file name or path.')
20 | # parser.add_argument('pth_model', help='pth model file')
21 | # parser.add_argument(
22 | # 'save_dir', help='the directory to save HuggingFace model')
23 | parser.add_argument(
24 | '--fp32',
25 | action='store_true',
26 | help='Save as fp32. If not set, fp16 will be used by default.')
27 | parser.add_argument(
28 | '--max-shard-size',
29 | type=str,
30 | default='2GB',
31 | help='Only applicable for LLM. The maximum size for '
32 | 'each sharded checkpoint.')
33 | parser.add_argument(
34 | '--cfg-options',
35 | nargs='+',
36 | action=DictAction,
37 | help='override some settings in the used config, the key-value pair '
38 | 'in xxx=yyy format will be merged into config file. If the value to '
39 | 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
40 | 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
41 | 'Note that the quotation marks are necessary and that no white space '
42 | 'is allowed.')
43 | args = parser.parse_args()
44 | return args
45 |
46 |
47 | def guess_load_checkpoint(pth_model):
48 | if os.path.isfile(pth_model):
49 | state_dict = torch.load(pth_model, map_location='cpu')
50 | if 'state_dict' in state_dict:
51 | state_dict = state_dict['state_dict']
52 | elif os.path.isdir(pth_model):
53 | try:
54 | from deepspeed.utils.zero_to_fp32 import \
55 | get_fp32_state_dict_from_zero_checkpoint
56 | except ImportError:
57 | raise ImportError(
58 | 'The provided PTH model appears to be a DeepSpeed checkpoint. '
59 | 'However, DeepSpeed library is not detected in current '
60 | 'environment. This suggests that DeepSpeed may not be '
61 | 'installed or is incorrectly configured. Please verify your '
62 | 'setup.')
63 | state_dict = get_fp32_state_dict_from_zero_checkpoint(
64 | os.path.dirname(pth_model), os.path.basename(pth_model))
65 | else:
66 | raise FileNotFoundError(f'Cannot find {pth_model}')
67 | return state_dict
68 |
69 |
70 | def convert_to_hf(config_file, pth_model, save_dir):
71 | args = parse_args()
72 |
73 | # parse config
74 | if not os.path.isfile(config_file):
75 | try:
76 | config_file = cfgs_name_path[config_file]
77 | except KeyError:
78 | raise FileNotFoundError(f'Cannot find {config_file}')
79 |
80 | # load config
81 | cfg = Config.fromfile(config_file)
82 | if args.cfg_options is not None:
83 | cfg.merge_from_dict(args.cfg_options)
84 |
85 | model = BUILDER.build(cfg.model)
86 |
87 | state_dict = guess_load_checkpoint(pth_model)
88 | model.load_state_dict(state_dict, strict=False)
89 | print(f'Load PTH model from {pth_model}')
90 |
91 | if not args.fp32:
92 | print('Convert weights to float16')
93 | model.llm.half()
94 |
95 | print(f'Saving HuggingFace model to {save_dir}')
96 | model.llm.save_pretrained(
97 | save_dir, max_shard_size=args.max_shard_size)
98 | if 'PeftModel' not in model.llm.__class__.__name__:
99 | print(f'Saving HuggingFace tokenizer to {save_dir}')
100 | tokenizer = BUILDER.build(cfg.tokenizer)
101 | tokenizer.save_pretrained(save_dir)
102 | shutil.copyfile(config_file, os.path.join(save_dir,
103 | 'xtuner_config.py'))
104 | print('Pth to hf all done!')
105 |
106 |
107 | if __name__ == '__main__':
108 |
109 | config_file = '/root/ft-oasst1/internlm_chat_7b_qlora_oasst1_e3_copy.py'
110 | pth_model = '/root/ft-oasst1/work_dirs/internlm_chat_7b_qlora_oasst1_e3_copy/epoch_1.pth'
111 | save_dir = '/root/ft-oasst1/hf5'
112 |
113 | convert_to_hf(config_file, pth_model, save_dir)
--------------------------------------------------------------------------------
/xtuner_download/README.md:
--------------------------------------------------------------------------------
1 |
2 | ## 模型下载
3 |
4 | 核心流程
5 | ```mermaid
6 | flowchart LR
7 |
8 | IN1(model_name) --> A
9 | IN2(out_path) --> A
10 | IN3(tqdm_class) --> A
11 | subgraph xtunerModelDownload
12 | A(initial) --> B(获取模型信息:文件数量、文件大小)
13 |
14 | DM(启用下载线程:X-model-download)
15 | DP(启用进度线程:X-model-progress)
16 | clear(文件清除)
17 | DP -->|检查下载进度|DM
18 | end
19 |
20 | B --> C
21 | B --> Break
22 | C -->|Start线程|DM
23 | C -->|Start线程|DP --> grP
24 | subgraph gr-client
25 |
26 | C(模型下载)
27 | grP(进度条显示)
28 |
29 | Break(下载中断)
30 | end
31 |
32 | Break -->|kill线程|DM
33 | Break -->|kill线程|DP
34 | Break --> clear
35 |
36 | DM --> hf
37 | subgraph download
38 | modelscope(modelscope-download)
39 | hf(huggingface-download)
40 | openxlab(openxlab-download)
41 |
42 | hf--> check1{下载成功?}-->|否|modelscope--> check2{下载成功?}-->|否|openxlab
43 |
44 | check1 -->|是|finished
45 | check2 -->|是|finished
46 | end
47 | ```
48 |
49 | - example:
50 | ```python
51 | from download_model import xtunerModelDownload
52 | from tqdm.auto import tqdm
53 | import time
54 |
55 | model_name = 'internlm/internlm-chat-7b'
56 | d_model = xtunerModelDownload(
57 | model_name,
58 | out_path='/root/tmp/download_model',
59 | tqdm_class=tqdm
60 | )
61 | d_model.auto_download()
62 | time.sleep(60)
63 | d_model.break_download()
64 | print('Yes')
65 | ```
66 |
67 | ## 数据下载
68 |
69 |
70 | 核心流程
71 | ```mermaid
72 | flowchart LR
73 |
74 | IN1(data_name) --> A
75 | IN2(out_path) --> A
76 | IN3(tqdm_class) --> A
77 | subgraph xtunerDataDownload
78 | A(initial) --> B(获取数据信息:文件数量、文件大小)
79 |
80 | DM(启用下载线程:X-dataset-download)
81 | DP(启用进度线程:X-dataset-progress)
82 | clear(文件清除)
83 | DP -->|检查下载进度|DM
84 | end
85 |
86 | B --> C
87 | B --> Break
88 | C -->|Start线程|DM
89 | C -->|Start线程|DP --> grP
90 | subgraph gr-client
91 |
92 | C(数据下载)
93 | grP(进度条显示)
94 |
95 | Break(下载中断)
96 | end
97 |
98 | Break -->|kill线程|DM
99 | Break -->|kill线程|DP
100 | Break --> clear
101 |
102 | DM --> hf
103 | subgraph download
104 | hf(huggingface-download)
105 |
106 | hf--> check1{下载成功?}-->|否|hf-->|Retry-n times|check1
107 |
108 | check1 -->|是|finished
109 | end
110 | ```
111 |
112 | - example:
113 | ```python
114 | from download_dataset import xtunerDataDownload
115 | from tqdm.auto import tqdm
116 | import time
117 |
118 | data_name = 'shibing624/medical'
119 | d_data = xtunerDataDownload(
120 | data_name,
121 | out_path='/root/tmp/download_data',
122 | tqdm_class=tqdm,
123 | retry_times=1
124 | )
125 | d_data.auto_download()
126 | time.sleep(60)
127 | d_data.break_download()
128 | print('Yes')
129 | ```
130 |
131 |
132 |
--------------------------------------------------------------------------------
/xtuner_download/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/scchy/XtunerGUI/6e510e39abf55169f45f0e94fb14055c2b973af4/xtuner_download/__init__.py
--------------------------------------------------------------------------------
/xtuner_download/data_list.txt:
--------------------------------------------------------------------------------
1 | ./data/CrimeKgAssitant清洗后_52k.json
2 | ./data/arxiv_data.json
3 | ./data/moss-003-sft-no-tools.jsonl
4 | ./data/训练数据_带法律依据_92k.json
5 |
6 |
7 | ArmelR/stack-exchange-instruction
8 | HuggingFaceH4/CodeAlpaca_20K
9 | Open-Orca/OpenOrca # 这个丰富的增强 FLAN 数据集尽可能地与 Orca 论文中概述的分布保持一致。它有助于生成高性能的模型检查点,是所有 NLP 研究人员和开发人员的宝贵资源!
10 | Skywork/SkyPile-150B # SkyPile-150B 是一个全面的大型中文数据集,专门用于大型语言模型的预训练。
11 | WizardLM/WizardLM_evol_instruct_V2_196k
12 | b-mc2/sql-create-context
13 | bigcode/starcoder
14 | burkelibbey/colors
15 | damo/MSAgent-Bench
16 | garage-bAInd/Open-Platypus
17 | mistralai/Mistral-7B-v0.1
18 | nampdn-ai/tiny-codes
19 | shibing624/medical
20 | silk-road/alpaca-data-gpt4-chinese
21 | tatsu-lab/alpaca
22 | timdettmers/openassistant-guanaco
23 |
--------------------------------------------------------------------------------
/xtuner_download/download_dataset.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-25
3 | # Author: Scc_hy
4 | # Func: 数据集拉取到本地
5 | # ===========================================================================================
6 | import os
7 | import gradio as gr
8 | from tqdm.auto import tqdm
9 | from os.path import getsize as p_getsize
10 | from os.path import join as p_join
11 | import threading
12 | import time
13 | from .download_utils import stop_thread, _split_repo, get_hf_cache_files, get_data_info, get_final_out_files, TOKEN
14 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
15 | CUR_DIR = os.path.dirname(__file__)
16 |
17 |
18 | class xtunerDataDownload():
19 | def __init__(self, data_name, out_path, tqdm_class=tqdm, progress_sleep=1, retry_times=0) -> None:
20 | self.progress_sleep = progress_sleep
21 | self.run_times = retry_times + 1
22 | self.tqdm_class = tqdm_class
23 | self.username, self.repository = _split_repo(data_name)
24 | self.data_name = data_name
25 | self.out_path = out_path
26 | self.final_out_path = p_join(out_path, f'dataset_{self.username}_{self.repository}')
27 | self.mid_download_dir = self.final_out_path
28 | self._t_handle_dl = None
29 | self._t_handle_pg = None
30 | self._break_flag = False
31 | self._get_info_flag = False
32 | self.__check_create_dir()
33 | self.get_download_info()
34 |
35 | def reset_path(self, customer_dir):
36 | self.__remove_mid_files()
37 | self.__remove_final_files()
38 | self.out_path = f'{customer_dir}/data_download'
39 | print(f'xtunerDataDownload reset_path->{self.out_path}')
40 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}')
41 | self.mid_download_dir = self.final_out_path
42 | self.__check_create_dir()
43 |
44 | def reset(self, data_name):
45 | self.remove_and_create()
46 | print(f'reset({data_name})')
47 | self.username, self.repository = _split_repo(data_name)
48 | self.data_name = data_name
49 | self.final_out_path = p_join(self.out_path, f'dataset_{self.username}_{self.repository}')
50 | self.mid_download_dir = self.final_out_path
51 | self._t_handle_dl = None
52 | self._t_handle_pg = None
53 | self._break_flag = False
54 | self._get_info_flag = False
55 | self.__check_create_dir()
56 | self.get_download_info()
57 |
58 | def get_download_info(self):
59 | self.total_MB, self.total_file_nums = get_data_info(self.data_name)
60 | self._get_info_flag = True
61 |
62 | def __check_create_dir(self):
63 | if not os.path.exists(self.out_path):
64 | os.system(f'mkdir -p {self.out_path}')
65 | if not os.path.exists(self.final_out_path):
66 | os.system(f'mkdir -p {self.final_out_path}')
67 |
68 | def __remove_mid_files(self):
69 | """中断时删除所有文件"""
70 | os.system(f'rm -rf {self.mid_download_dir}')
71 | # cd rm
72 | rm_dir = './' + self.mid_download_dir.replace(self.out_path, '.')[2:].split('/')[0]
73 | os.system(f'cd {self.out_path} && rm -rf {rm_dir} && rm -rf temp')
74 | # 删除 hf 的cache
75 | os.system(f'rm -rf {self.final_out_path}/cache')
76 |
77 | def __remove_final_files(self):
78 | os.system(f'rm -rf {self.final_out_path}')
79 | os.system(f'cd {self.out_path} && rm -rf dataset_{self.username}_{self.repository}')
80 |
81 | def remove_and_create(self):
82 | self.__remove_mid_files()
83 | self.__remove_final_files()
84 | self.__check_create_dir()
85 |
86 | def auto_download(self, progress=gr.Progress(track_tqdm=True)):
87 | # wait for info
88 | cnt_ = 0
89 | while not self._get_info_flag:
90 | cnt_ += 1
91 | time.sleep(1)
92 | if cnt_ == 30:
93 | break
94 |
95 | self._break_flag = False
96 | self._t_download(self.safe_download)
97 | # self._t_start()
98 | self.progress(progress=progress)
99 | if self._break_flag:
100 | return "Done! Data-Download had interrupted!"
101 | return self.final_out_path
102 |
103 | def safe_download(self):
104 | for _ in range(self.run_times):
105 | self.hf_download()
106 | time.sleep(0.5)
107 | # 执行完检验
108 | if self._finished_check():
109 | print('finished download all model files')
110 | break
111 | self.remove_and_create()
112 | return
113 |
114 | def hf_download(self):
115 | print('>>>>>>> Start hf_download')
116 | # 1- mid download local dir
117 | self.mid_download_dir = self.final_out_path
118 | # 2- download load_dataset 环境变量调整
119 | os.system(f"""
120 | export HF_ENDPOINT=https://hf-mirror.com && \
121 | huggingface-cli download --resume-download {self.data_name} --local-dir-use-symlinks False \
122 | --repo-type dataset \
123 | --local-dir {self.final_out_path} \
124 | --cache-dir {self.final_out_path}/cache \
125 | --token {TOKEN}
126 | """)
127 | os.system(f'rm -rf {self.final_out_path}/cache')
128 | return self.final_out_path
129 |
130 | def _finished_check(self):
131 | """检查是否下载完整数据
132 | """
133 |         no_flag = (self.total_file_nums is None) or (self.total_file_nums <= 0.01)
134 | print(f'self.total_file_nums={self.total_file_nums} no_flag={no_flag}')
135 | if not no_flag and os.path.exists(self.final_out_path):
136 | downloaded_files, download_bytes = self._get_final_out_bytes()
137 | print(f'downloaded_files={downloaded_files}\ndownload_bytes={download_bytes}')
138 | file_same = len(downloaded_files) == self.total_file_nums
139 | size_same = download_bytes / 1024**2 / (self.total_MB + 1e-5) >= 0.99
140 | return size_same & file_same
141 | return True
142 |
143 | def _t_start(self):
144 | self._t_handle_pg = threading.Thread(target=self.progress, name='X-dataset-progress', daemon=True)
145 | self._t_handle_pg.start()
146 |
147 | def _t_download(self, d_func):
148 | self._t_handle_dl = threading.Thread(target=d_func, name='X-dataset-download', daemon=True)
149 | self._t_handle_dl.start()
150 |
151 | def _get_final_out_bytes(self):
152 | # data存在多层的情况
153 | downloaded_files = get_final_out_files(self.final_out_path)
154 | cached_mb1 = sum([p_getsize(f) for f in downloaded_files])
155 | return downloaded_files, cached_mb1
156 |
157 | def progress(self, progress=None):
158 | hf_cache_dir = p_join(self.final_out_path, 'cache')
159 | self.bar_ = self.tqdm_class(total=round(self.total_MB*1024**2, 3), unit='iB', unit_scale=True)
160 | self.bar_.set_description('TotalProgress')
161 | bf = 0
162 | while True:
163 | if self._t_handle_dl is None:
164 | break
165 | if not self._t_handle_dl.is_alive():
166 | break
167 | hf_cache_files = get_hf_cache_files(hf_cache_dir) if os.path.exists(hf_cache_dir) else []
168 | _, cached_mb1 = self._get_final_out_bytes()
169 | cached_mb2 = sum([p_getsize(f) for f in hf_cache_files])
170 | cached_mb = (cached_mb1 + cached_mb2)
171 |
172 | self.bar_.update(round(cached_mb - bf, 3))
173 | bf = cached_mb
174 | time.sleep(self.progress_sleep)
175 |
176 | # 数据统计可能不准确
177 | finished_rate = cached_mb / (self.total_MB + 1e-5) / 1024**2
178 | if self._t_handle_dl is None and finished_rate <= 0.99:
179 | left = self.total_MB * 1024**2 - bf
180 | self.bar_.update(round(left, 3))
181 |
182 | return
183 |
184 | def break_download(self):
185 | # 然后杀死该线程
186 | # 删除文件
187 | if self._t_handle_dl is not None:
188 | print('>>>>>>>>>>>>>>>>> break_download')
189 | stop_thread(self._t_handle_dl)
190 | self._t_handle_dl = None
191 | os.system(f'sh {CUR_DIR}/kill_hf.sh dataset')
192 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_dl)')
193 |
194 | if self._t_handle_pg is not None:
195 | stop_thread(self._t_handle_pg)
196 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_pg)')
197 | self._t_handle_pg = None
198 | self.remove_and_create()
199 | self._break_flag = True
200 | return "Done! Data-Download was interrupted!"
201 |
202 |
203 | if __name__ == '__main__':
204 | print(os.getcwd())
205 | download_ = xtunerDataDownload(
206 | 'shibing624/medical',
207 | out_path='/home/scc/sccWork/myGitHub/My_Learn/tmp/download')
208 | # out_path='/root/tmp/download'
209 | # )
210 | # download_.hf_download() # checked
211 | download_.auto_download() # checked-download & progress
212 | time.sleep(10)
213 | download_.break_download() # checked-download & progress & break
214 | print('Yes')
215 | # check finished
216 | f_ = download_._finished_check()
217 | print(f'_finished_check={f_}')
218 |
219 |
--------------------------------------------------------------------------------
/xtuner_download/download_model.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-23
3 | # Author: Scc_hy
4 | # Func: 模型拉取到本地
5 | # ===========================================================================================
6 | import os
7 | import gradio as gr
8 | from tqdm.auto import tqdm
9 | from openxlab.model import download as ox_download
10 | from modelscope.hub.snapshot_download import snapshot_download
11 | from modelscope.hub.api import HubApi, ModelScopeConfig
12 | from os.path import getsize as p_getsize
13 | from os.path import join as p_join
14 | import threading
15 | import time
16 | from .download_utils import stop_thread, _split_repo, get_hf_cache_files, get_model_info, TOKEN
17 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
18 | CUR_DIR = os.path.dirname(__file__)
19 |
20 |
21 | class xtunerModelDownload():
22 | def __init__(self, model_name, out_path, tqdm_class=tqdm, progress_sleep=1) -> None:
23 | self.progress_sleep = progress_sleep
24 | self.tqdm_class = tqdm_class
25 | self.username, self.repository = _split_repo(model_name)
26 | self.model_name = model_name
27 | self.out_path = out_path
28 | self.final_out_path = p_join(out_path, f'{self.username}_{self.repository}')
29 | self.mid_download_dir = self.final_out_path
30 | self._t_handle_dl = None
31 | self._t_handle_pg = None
32 | self._break_flag = False
33 | self._get_info_flag = False
34 | self.remove_and_create()
35 | self.get_download_info()
36 |
37 | def reset_path(self, customer_dir):
38 | self.__remove_mid_files()
39 | self.__remove_final_files()
40 | self.out_path = f'{customer_dir}/model_download'
41 | print(f'xtunerModelDownload reset_path->{self.out_path}')
42 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}')
43 | self.mid_download_dir = self.final_out_path
44 | self.__check_create_dir()
45 |
46 | def reset(self, model_name):
47 | self.remove_and_create()
48 | print(f'reset({model_name})')
49 | self.username, self.repository = _split_repo(model_name)
50 | self.model_name = model_name
51 | self.final_out_path = p_join(self.out_path, f'{self.username}_{self.repository}')
52 | self.mid_download_dir = self.final_out_path
53 | self._t_handle_dl = None
54 | self._t_handle_pg = None
55 | self._break_flag = False
56 | self._get_info_flag = False
57 | self.remove_and_create()
58 | self.get_download_info()
59 |
60 | def _username_map(self, tp):
61 | """username 映射
62 | """
63 | modelscope_map_dict = {
64 | 'internlm': 'Shanghai_AI_Laboratory',
65 | 'meta-llama': 'shakechen', # Llama-2
66 | 'huggyllama': 'skyline2006', # Llama
67 | 'THUDM': 'ZhipuAI',
68 | '01-ai': '01ai',
69 | 'Qwen': 'qwen'
70 | }
71 | hf_map_dict = {}
72 | openxlab_map_dict = {
73 | 'internlm': 'OpenLMLab',
74 | 'meta-llama': 'shakechen', # Llama-2
75 | 'huggyllama': 'skyline2006', # Llama
76 | 'THUDM': 'ZhipuAI',
77 | '01-ai': '01ai'
78 | }
79 | type_map_dict = {'modelscope': modelscope_map_dict, 'hf': hf_map_dict, 'openxlab': openxlab_map_dict}[tp]
80 | sp_model_name = '{u_name}/{rep}'.format(
81 | u_name=type_map_dict.get(self.username, self.username),
82 | rep=self.repository)
83 | return sp_model_name
84 |
85 | def get_download_info(self):
86 | # 优先modelscope查看
87 | try:
88 | self.total_MB, self.total_file_nums = self._get_download_info()
89 | except Exception as e:
90 | self.total_MB, self.total_file_nums = get_model_info(self.model_name)
91 | self._get_info_flag = True
92 |
93 | def _get_download_info(self):
94 | _api = HubApi()
95 | headers = {'user-agent': ModelScopeConfig.get_user_agent(user_agent=None, )}
96 | snapshot_header = headers if 'CI_TEST' in os.environ else {
97 | **headers,
98 | **{
99 | 'Snapshot': 'True'
100 | }
101 | }
102 | model_id = self._username_map('modelscope')
103 | model_files = _api.get_model_files(
104 | model_id=model_id,
105 | revision=None,
106 | recursive=True,
107 | use_cookies=False,
108 | headers=snapshot_header,
109 | )
110 | total_MB = sum([i['Size']/1024**2 for i in model_files])
111 | total_file_nums = len(model_files)
112 | return total_MB, total_file_nums
113 |
114 | def __check_create_dir(self):
115 | if not os.path.exists(self.out_path):
116 | os.system(f'mkdir -p {self.out_path}')
117 | if not os.path.exists(self.final_out_path):
118 | os.system(f'mkdir -p {self.final_out_path}')
119 |
120 | def __remove_mid_files(self):
121 | """中断时删除所有文件"""
122 | os.system(f'rm -rf {self.mid_download_dir}')
123 | # cd rm
124 | rm_dir = './' + self.mid_download_dir.replace(self.out_path, '.')[2:].split('/')[0]
125 | os.system(f'cd {self.out_path} && rm -rf {rm_dir} && rm -rf temp')
126 | # 删除 hf 的cache
127 | os.system(f'rm -rf {self.final_out_path}/cache')
128 |
129 | def __remove_final_files(self):
130 | os.system(f'rm -rf {self.final_out_path}')
131 | os.system(f'cd {self.out_path} && rm -rf {self.username}_{self.repository}')
132 |
133 | def remove_and_create(self):
134 | self.__remove_mid_files()
135 | self.__remove_final_files()
136 | self.__check_create_dir()
137 | self.mid_download_dir = self.final_out_path
138 |
139 | def auto_download(self, progress=gr.Progress(track_tqdm=True), tp='speed'):
140 | cnt_ = 0
141 | while not self._get_info_flag:
142 | cnt_ += 1
143 | time.sleep(1)
144 | if cnt_ == 30:
145 | break
146 |
147 | self._break_flag = False
148 | self._t_download(self.loop_download, tp)
149 | # self._t_start(progress)
150 | # progress not use thread
151 | self.progress(progress=progress)
152 | if self._break_flag:
153 | return "Done! Model-Download was interrupted!"
154 | return self.final_out_path
155 |
156 | def loop_download(self, tp='speed'):
157 | # modelscope first
158 | # if 'internlm' in self.model_name.lower():
159 | # loop_list = [self.openxlab_download, self.modelscope_download, self.hf_download]
160 | if tp == 'speed':
161 | loop_list = [self.modelscope_download, self.hf_download, self.openxlab_download]
162 | else:
163 | loop_list = [self.hf_download, self.modelscope_download, self.openxlab_download]
164 |
165 | for download_func in loop_list:
166 | print("download_func=", download_func)
167 | print("_get_info_flag=", self._get_info_flag)
168 | try:
169 | download_func()
170 | time.sleep(1)
171 | except Exception as e:
172 | print("download_func=", download_func, '\n', '--'*25, f'\nerror={e}\n', '--'*25)
173 | pass
174 | # 执行完检验
175 | if self._finished_check():
176 | print('finished download all model files')
177 | break
178 |
179 | print('Failed download all model files & remove_and_create')
180 | self.remove_and_create()
181 | return
182 |
183 | def hf_download(self):
184 | print('>>>>>>> Start hf_download')
185 | # 1- mid download local dir
186 | self.mid_download_dir = self.final_out_path
187 | # 2- download
188 | os.system(f"""
189 | export HF_ENDPOINT=https://hf-mirror.com && \
190 | huggingface-cli download --resume-download {self.model_name} --local-dir-use-symlinks False \
191 | --repo-type model \
192 | --local-dir {self.final_out_path} \
193 | --cache-dir {self.final_out_path}/cache \
194 | --token {TOKEN}
195 | """)
196 | os.system(f'rm -rf {self.final_out_path}/cache')
197 | return self.final_out_path
198 |
199 | def modelscope_download(self):
200 | print('>>>>>>> Start modelscope_download')
201 | # 1- fix-name
202 | model_name = self._username_map('modelscope')
203 | # 2- mid download local dir
204 | self.mid_download_dir = mid_download_dir = p_join(self.out_path, model_name)
205 | # 3- download
206 | snapshot_download(model_id=model_name, cache_dir=self.out_path)
207 | # 保证目录一致 out_path/sccHyFuture/LLM_medQA_adapter --> final_out_path
208 | os.system(f'mv {mid_download_dir}/* {self.final_out_path}')
209 | self.__remove_mid_files()
210 | return self.final_out_path
211 |
212 | def openxlab_download(self):
213 | print('>>>>>>> Start openxlab_download')
214 | # 1- fix-name
215 | model_name = self._username_map('openxlab')
216 | # 2- mid download local dir
217 | self.mid_download_dir = self.final_out_path
218 | # 3- download
219 | ox_download(model_repo=model_name, output=self.final_out_path, cache=False)
220 | return self.final_out_path
221 |
222 | def _finished_check(self):
223 | """检查是否下载完整数据
224 | """
225 | no_flag = (self.total_file_nums is None) or (self.total_file_nums <= 0.01)
226 | if not no_flag and os.path.exists(self.final_out_path):
227 | final_nums = len([i for i in os.listdir(self.final_out_path) if not os.path.isdir(p_join(self.final_out_path, i))])
228 | print('os.listdir(self.final_out_path)=', os.listdir(self.final_out_path))
229 | final_MB = sum([p_getsize(p_join(self.final_out_path, i))/ 1024**2 for i in os.listdir(self.final_out_path)])
230 | file_same = final_nums >= self.total_file_nums
231 | size_same = final_MB/(self.total_MB + 1e-5) >= 0.99
232 | print(f"self.total_file_nums={self.total_file_nums} final_nums={final_nums} >>>>> file_same={file_same}")
233 | print(f"self.total_MB={self.total_MB:.3f} final_MB={final_MB:.3f} >>>>> size_same={size_same}")
234 | return size_same & file_same
235 | return True
236 |
237 | def _t_start(self, pg=None):
238 | self._t_handle_pg = threading.Thread(target=self.progress, args=(pg,), name='X-model-progress', daemon=True)
239 | self._t_handle_pg.start()
240 |
241 | def _t_download(self, d_func, tp):
242 | self._t_handle_dl = threading.Thread(target=d_func, args=(tp,) ,name='X-model-download', daemon=True)
243 | self._t_handle_dl.start()
244 |
245 | def progress(self, progress=None):
246 | model_scope_cache_dir = p_join(self.out_path, 'temp')
247 | hf_cache_dir = p_join(self.final_out_path, 'cache')
248 | self.bar_ = self.tqdm_class(total=round(self.total_MB*1024**2, 3), unit='iB', unit_scale=True)
249 | self.bar_.set_description('TotalProgress')
250 | bf = 0
251 | while True:
252 | if self._t_handle_dl is None:
253 | break
254 | if not self._t_handle_dl.is_alive():
255 | break
256 | hf_cache_files = get_hf_cache_files(hf_cache_dir) if os.path.exists(hf_cache_dir) else []
257 | if self.mid_download_dir == self.final_out_path:
258 | cached_mb1 = sum([p_getsize(p_join(self.final_out_path, i)) for i in os.listdir(self.final_out_path)])
259 | cached_mb4 = sum([p_getsize(f) for f in hf_cache_files])
260 | cached_mb = cached_mb1 + cached_mb4
261 | else:
262 | # 获取最近创建的temp文件
263 | try:
264 | model_scope_cache_dir_tmp = sorted([
265 | [p_join(model_scope_cache_dir, i), os.stat(p_join(model_scope_cache_dir, i)).st_atime] for i in os.listdir(model_scope_cache_dir)
266 | ], key=lambda c:c[1])[-1][0]
267 | except Exception as e:
268 | model_scope_cache_dir_tmp = None
269 |
270 | cached_mb1 = 0
271 | if os.path.exists(self.mid_download_dir):
272 | cached_mb1 = sum([p_getsize(p_join(self.mid_download_dir, i)) for i in os.listdir(self.mid_download_dir)])
273 |
274 | cached_mb2 = sum([p_getsize(p_join(self.final_out_path, i)) for i in os.listdir(self.final_out_path)])
275 | cached_mb3 = 0
276 | if model_scope_cache_dir_tmp is not None:
277 | cached_mb3 = sum([p_getsize(p_join(model_scope_cache_dir_tmp, i)) for i in os.listdir(model_scope_cache_dir_tmp)])
278 |
279 | cached_mb4 = sum([p_getsize(f) for f in hf_cache_files])
280 | cached_mb = (cached_mb1 + cached_mb2 + cached_mb3 + cached_mb4)
281 |
282 | self.bar_.update(round(cached_mb - bf, 3))
283 | bf = cached_mb
284 | if cached_mb / (self.total_MB + 1e-5) / 1024**2 >= 0.9999:
285 | break
286 | time.sleep(self.progress_sleep)
287 | return
288 |
289 | def break_download(self):
290 | # 然后杀死该线程
291 | # 删除文件
292 | if self._t_handle_dl is not None:
293 | print('>>>>>>>>>>>>>>>>> break_download')
294 | stop_thread(self._t_handle_dl)
295 | self._t_handle_dl = None
296 | os.system(f'sh {CUR_DIR}/kill_hf.sh model')
297 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_dl)')
298 |
299 | if self._t_handle_pg is not None:
300 | stop_thread(self._t_handle_pg)
301 | print('>>>>>>>>>>>>>>>>> stop_thread(self._t_handle_pg)')
302 | self._t_handle_pg = None
303 | self.remove_and_create()
304 | self._break_flag = True
305 | return "Done! Model-Download was interrupted!"
306 |
307 |
308 | if __name__ == '__main__':
309 | print(os.getcwd())
310 | download_ = xtunerModelDownload(
311 | 'internlm/InternLM-chat-7b',
312 | out_path='/home/scc/sccWork/myGitHub/My_Learn/tmp/download')
313 | #'/root/tmp/download')
314 | #'/home/scc/sccWork/myGitHub/My_Learn/tmp/download')
315 | # download_.hf_download() # checked-download & progress
316 | # download_.openxlab_download() # checked-download & progress
317 | # download_.modelscope_download() # checked-download & progress
318 | download_.auto_download() # checked-download & progress
319 | time.sleep(10)
320 | download_.break_download() # checked-download & progress & break
321 | print('Yes')
322 | # check finished
323 | f_ = download_._finished_check()
324 | print(f'_finished_check={f_}')
325 |
--------------------------------------------------------------------------------
/xtuner_download/download_utils.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-24
3 | # Author: Scc_hy
4 | # Func: 模型拉取到本地
5 | # ===========================================================================================
6 | import os
7 | import re
8 | import inspect
9 | import ctypes
10 | import requests
11 | import inspect
12 | import ctypes
13 | import requests
14 | from tqdm.auto import tqdm
15 | from os.path import join as p_join
16 | from concurrent.futures import ThreadPoolExecutor
17 | from huggingface_hub.hf_api import HfApi
18 |
19 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
20 | TOKEN = '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'
21 |
22 | def get_hf_cache_files(folder_path):
23 | all_files = []
24 | for root, dirs, files in os.walk(folder_path):
25 | for file in files:
26 | file_tt = p_join(root, file)
27 | if os.path.isfile(file_tt) and '.incomplete' in file:
28 | all_files.append(file_tt)
29 | return all_files
30 |
31 |
32 | def get_final_out_files(folder_path):
33 | all_files = []
34 | for root, dirs, files in os.walk(folder_path):
35 | for file in files:
36 | file_tt = p_join(root, file)
37 | if not os.path.isdir(file_tt) and 'cache' not in file_tt:
38 | all_files.append(file_tt)
39 | return all_files
40 |
41 |
42 |
43 | def _split_repo(model_repo) -> (str, str):
44 | """
45 | Split a full repository name into two separate strings: the username and the repository name.
46 | """
47 | # username/repository format check
48 | pattern = r'.+/.+'
49 | if not re.match(pattern, model_repo):
50 | raise ValueError("The input string must be in the format 'username/model_repo'")
51 |
52 | values = model_repo.split('/')
53 | return values[0], values[1]
54 |
55 |
56 | def _async_raise(tid, exctype):
57 | """Raises an exception in the threads with id tid"""
58 | if not inspect.isclass(exctype):
59 | raise TypeError("Only types can be raised (not instances)")
60 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
61 | if res == 0:
62 | raise ValueError("invalid thread id")
63 | elif res != 1:
64 | # """if it returns a number greater than one, you're in trouble,
65 | # and you should call it again with exc=NULL to revert the effect"""
66 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
67 | raise SystemError("PyThreadState_SetAsyncExc failed")
68 |
69 |
70 | def stop_thread(thread):
71 | try:
72 | _async_raise(thread.ident, SystemExit)
73 | except Exception as e:
74 | print(e)
75 |
76 |
77 | def detect_data_file_bytes(data_name, data_file):
78 | try:
79 | txt = requests.get(f'https://hf-mirror.com/datasets/{data_name}/blob/main/{data_file}', timeout=3).text
80 | except Exception as e:
81 | print(e)
82 | return 0.0
83 | find_out = re.findall(r'Size of remote file:(.*?)B', txt, flags=re.S)
84 | find_info = find_out[0] if len(find_out) else '0 '
85 | info_num = float(re.findall(r'\d+\.\d+|\d+', find_info)[0])
86 | info_cat = find_info[-1]+'B'
87 | # huggingface 直接1000
88 | info_map = {
89 | 'kB': 1000,
90 | 'KB': 1000,
91 | 'MB': 1000 ** 2,
92 | 'mB': 1000 ** 2,
93 | 'GB': 1000 ** 3,
94 | 'gB': 1000 ** 3,
95 | ' B': 1,
96 | }
97 | return info_num * info_map.get(info_cat, 1.0)
98 |
99 |
100 | def get_data_info(data_name):
101 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
102 | api_ = HfApi(token=TOKEN)
103 | try:
104 | df_info = api_.dataset_info(repo_id=data_name, token=TOKEN, timeout=3)
105 | except Exception as e:
106 | print(e)
107 | return 0, 0
108 | df_files = [i.rfilename for i in df_info.siblings]
109 | exec_ = ThreadPoolExecutor(max_workers=2)
110 | tasks = [exec_.submit(detect_data_file_bytes, data_name=data_name, data_file=i) for i in df_files]
111 | res = []
112 | for t in tqdm(tasks):
113 | res.append(t.result())
114 |
115 | total_MB = sum(res) / 1024 ** 2
116 | total_file_nums = len(df_files)
117 | return total_MB, total_file_nums
118 |
119 |
120 | def detect_model_file_bytes(model_name, data_file):
121 | try:
122 | txt = requests.get(f'https://hf-mirror.com/{model_name}/blob/main/{data_file}', timeout=3).text
123 | except Exception as e:
124 | print(e)
125 | return 0.0
126 | find_out = re.findall(r'Size of remote file:(.*?)B', txt, flags=re.S)
127 | find_info = find_out[0] if len(find_out) else '0 '
128 | info_num = float(re.findall(r'\d+\.\d+|\d+', find_info)[0])
129 | info_cat = find_info[-1]+'B'
130 | # huggingface 直接1000
131 | info_map = {
132 | 'kB': 1000,
133 | 'KB': 1000,
134 | 'MB': 1000 ** 2,
135 | 'mB': 1000 ** 2,
136 | 'GB': 1000 ** 3,
137 | 'gB': 1000 ** 3,
138 | ' B': 1,
139 | }
140 | return info_num * info_map.get(info_cat, 1.0)
141 |
142 |
143 | def get_model_info(model_name):
144 | os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
145 | api_ = HfApi(token=TOKEN)
146 | try:
147 | df_info = api_.model_info(repo_id=model_name, token=TOKEN, timeout=3)
148 | except Exception as e:
149 | print(e)
150 | return 0, 0
151 | df_files = [i.rfilename for i in df_info.siblings]
152 | exec_ = ThreadPoolExecutor(max_workers=2)
153 | tasks = [exec_.submit(detect_model_file_bytes, model_name=model_name, data_file=i) for i in df_files]
154 | res = []
155 | for t in tqdm(tasks):
156 | res.append(t.result())
157 |
158 | total_MB = sum(res) / 1024 ** 2
159 | total_file_nums = len(df_files)
160 | return total_MB, total_file_nums
161 |
162 |
163 | def test_get_data_info():
164 | print('>>>>>>>> Start test')
165 | data_name = 'shibing624/medical'
166 | total_MB, total_file_nums = get_data_info(data_name)
167 | print(f'[ data_name={data_name} ] total_file_nums={total_file_nums} | total_MB={total_MB:.3f}MiB')
168 |
169 | def test_down_load_data():
170 | print('>>>>>>>> Start test')
171 | model_name = 'internlm/internlm-chat-7b'
172 | total_MB, total_file_nums = get_model_info(model_name)
173 | print(f'[ data_name={model_name} ] total_file_nums={total_file_nums} | total_MB={total_MB:.3f}MiB')
174 |
175 |
176 |
177 | if __name__ == '__main__':
178 | test_get_data_info()
179 | test_down_load_data()
180 |
--------------------------------------------------------------------------------
/xtuner_download/find_datalist.py:
--------------------------------------------------------------------------------
1 |
2 | import re
3 | import os
4 | from os.path import getsize as p_getsize
5 | from os.path import join as p_join
6 |
7 |
8 | def get_py_config_files(folder_path):
9 | all_files = []
10 | for root, dirs, files in os.walk(folder_path):
11 | for file in files:
12 | file_tt = p_join(root, file)
13 | if os.path.isfile(file_tt) and '.py' in file and '__init__' not in file:
14 | all_files.append(file_tt)
15 | return all_files
16 |
17 |
18 | def read_and_find_path(file):
19 | with open(file, 'r') as f:
20 | res = f.readlines()[:50]
21 | return re.findall(r"path\s+=\s+'(.*?)'\n", ''.join(res))
22 |
23 |
24 | father_p = '/home/scc/sccWork/openProject/xtuner019/xtuner/xtuner/configs'
25 | py_files = get_py_config_files(father_p)
26 | py_need_files = [i.rsplit('/', 1)[1] for i in py_files]
27 | files_need = set([i.split('lora_')[-1].replace('.py', '') for i in py_need_files])
28 | path_ = [read_and_find_path(i) for i in py_files]
29 | path_final = []
30 | for p in path_:
31 | path_final.extend(p)
32 |
33 | model_str = ['Llama', 'Qwen', 'Baichuan', 'chatglm', '01-ai', 'internlm', 'llama']
34 | path_final = [i for i in sorted(set(path_final)) if all(j not in i for j in model_str)]
35 |
36 |
37 |
38 |
39 |
--------------------------------------------------------------------------------
/xtuner_download/kill_hf.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | rp_tp=$1;
4 |
5 | pid=`ps -ef | grep -E "huggingface-cli.*?${rp_tp}" | grep -v "grep" | awk '{print $2}'`
6 | for id in $pid
7 | do
8 | kill -9 $id
9 | echo "killed $id"
10 | done;
11 |
--------------------------------------------------------------------------------
/xtuner_download/model_list.txt:
--------------------------------------------------------------------------------
1 | OpenLMLab/internlm-7b
2 | OpenLMLab/internlm-20b
3 | OpenLMLab/internlm-chat-7b
4 | OpenLMLab/internlm-chat-20b
5 |
6 | shakechen/Llama-2-7b-chat meta-llama/Llama-2-7b-chat
7 | --llama2_70b meta-llama/llama2-70b--
8 | shakechen/Llama-2-7b meta-llama/Llama-2-7b
9 | skyline2006/llama-7b huggyllama/llama-7b
10 |
11 | baichuan-inc/Baichuan-7B-Chat
12 | baichuan-inc/Baichuan-13B-Base
13 | baichuan-inc/baichuan-7B
14 | baichuan-inc/Baichuan2-13B-Chat
15 | baichuan-inc/Baichuan2-7B-Chat
16 | baichuan-inc/Baichuan2-13B-Chat
17 | baichuan-inc/Baichuan2-13B-Base
18 |
19 | ZhipuAI/chatglm3-6b THUDM/chatglm3-6b
20 | ZhipuAI/chatglm2-6b THUDM/chatglm2-6b
21 | ZhipuAI/chatglm3-6b-base THUDM/chatglm3-6b-base
22 |
23 | yi_34b
24 | 01ai/Yi-6B 01-ai/Yi-6B
25 | Qwen/Qwen-7B-Chat
26 | Qwen/Qwen-7B
27 |
28 |
29 |
--------------------------------------------------------------------------------
/xtuner_download/test_download.py:
--------------------------------------------------------------------------------
1 | from download_model import xtunerModelDownload
2 | from download_dataset import xtunerDataDownload
3 | from tqdm.auto import tqdm
4 | import time
5 |
6 | def main():
7 | print('>>>>>>>> Start xtunerModelDownload')
8 | model_name = 'internlm/internlm-chat-7b'
9 | d_model = xtunerModelDownload(
10 | model_name,
11 | out_path='/root/tmp/download_model',
12 | tqdm_class=tqdm
13 | )
14 | d_model.auto_download()
15 | print('>>>>>>>> Start xtunerDataDownload')
16 | data_name = 'shibing624/medical'
17 | d_data = xtunerDataDownload(
18 | data_name,
19 | out_path='/root/tmp/download_data',
20 | tqdm_class=tqdm,
21 | retry_times=0
22 | )
23 | d_data.auto_download()
24 | time.sleep(60)
25 | d_data.break_download()
26 | d_model.break_download()
27 | print('Yes')
28 |
29 |
30 | if __name__ == '__main__':
31 | main()
32 |
--------------------------------------------------------------------------------
/xtuner_download/todo_list.md:
--------------------------------------------------------------------------------
1 | - modelDownload
2 | - [X] download base
3 | - [X] check finished download
4 | - [X] 通过modelscope获取文件列表
5 | - [X] progress
6 | - [X] break down of downloading
7 | - [X] kill and delete restart files
8 |
9 | - dataDownload
10 | - [X] DataLists
11 | - [X] Download method
12 | - [X] only huggingface
13 |
--------------------------------------------------------------------------------
/xtuner_result/draw.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | from datasets import load_dataset
3 | import altair as alt
4 | import pandas as pd
5 | import os
6 | import re
7 | import gradio as gr
8 |
9 |
10 | class resPlot:
11 | def __init__(self, work_dir):
12 | self.work_dir = work_dir
13 | self.log_file = None
14 | self.iter_dir_list = []
15 | self.get_log_path()
16 | self.find_iter_pth()
17 | print(f"resPlot(self.log_file={self.log_file})")
18 |
19 | def get_log_path(self):
20 | try:
21 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i and re.match(r'\d+_\d+', i)])
22 | dir_name = list_[-1]
23 | self.log_file = os.path.join(self.work_dir, dir_name, 'vis_data' , f'{dir_name}.json')
24 | except Exception as e:
25 | print(e)
26 | pass
27 |
28 | def find_iter_pth(self):
29 | try:
30 | self.iter_dir_list = sorted([i for i in os.listdir(self.work_dir) if '.pth' in i])
31 | except Exception as e:
32 | print(e)
33 | pass
34 |
35 | def get_eval_test(self, ep_pth):
36 | ep_str = ep_pth.split('.')[0]
37 | ep = ep_str.split('_')[0] + '_' + str(int(ep_str.split('_')[1]) - 1)
38 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i and re.match(r'\d+_\d+', i)])
39 | dir_name = list_[-1]
40 | eval_file = os.path.join(self.work_dir, dir_name, 'vis_data' , f'eval_outputs_{ep}.txt')
41 | try:
42 | return open(eval_file, 'r').read()
43 | except Exception as e:
44 | return f'eval_file={eval_file}\nERROR: {e} '
45 |
46 | def dynamic_eval_drop_down(self):
47 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i and re.match(r'\d+_\d+', i)])
48 | dir_name = list_[-1]
49 | # /root/xtunerUITest/test/appPrepare/work_dir/20240204_204337/vis_data/eval_outputs_iter_49.txt
50 | eval_file = [i for i in os.listdir(os.path.join(self.work_dir, dir_name, 'vis_data')) if '.txt' in i]
51 | final_list = []
52 | if len(eval_file):
53 | final_list = ["iter_{}".format(int(i.split('_')[-1].split('.')[0])+1) for i in eval_file]
54 |
55 | return gr.Dropdown(choices=final_list, interactive=True)
56 |
57 | def dynamic_drop_down(self):
58 | self.iter_dir_list = sorted([i for i in os.listdir(self.work_dir) if '.pth' in i])
59 | return gr.Dropdown(choices=self.iter_dir_list, interactive=True)
60 |
61 | def reset_work_dir(self, root_dir):
62 | self.work_dir = f'{root_dir}/work_dir'
63 | self.get_log_path()
64 | self.find_iter_pth()
65 | print(f"resPlot -> self.work_dir={self.work_dir}\nself.log_file={self.log_file}\nself.iter_dir_list={self.iter_dir_list}")
66 |
67 | def lr_plot(self):
68 | self.get_log_path()
69 | y_axis_name = 'lr'
70 | return self.gr_line_plot(y_axis_name, self.log_file)
71 |
72 | def loss_plot(self):
73 | self.get_log_path()
74 | y_axis_name = 'loss'
75 | return self.gr_line_plot(y_axis_name, self.log_file)
76 |
77 | @staticmethod
78 | def make_plot(y_axis_name, log_path):
79 | ds = load_dataset('json', data_files=log_path)
80 | ds = ds['train'].to_pandas()
81 | ds = ds.rename(columns={'iter': 'iter_num'})
82 | # ['lr', 'data_time', 'loss', 'time', 'grad_norm', 'iter', 'memory', 'step']
83 | source = pd.DataFrame({
84 | 'iter_num': ds['iter_num'].map(int).tolist(),
85 | y_axis_name: ds[y_axis_name].map(float).tolist(),
86 | })
87 | base = alt.Chart(source).mark_line(
88 | point=alt.OverlayMarkDef(filled=False, fill="white")
89 | ).encode(x='iter_num',y=y_axis_name)
90 | return base
91 |
92 | @staticmethod
93 | def gr_line_plot(y_axis_name, log_path):
94 | ds = load_dataset('json', data_files=log_path)
95 | ds = ds['train'].to_pandas()
96 | ds = ds.rename(columns={'iter': 'iter_num'})
97 | source = pd.DataFrame({
98 | 'iter_num': ds['iter_num'].map(int).tolist(),
99 | y_axis_name: ds[y_axis_name].map(float).tolist(),
100 | })
101 | return gr.LinePlot(
102 | source,
103 | x="iter_num",
104 | x_title='iter_num',
105 | y=y_axis_name,
106 | y_title=y_axis_name,
107 | overlay_point=True,
108 | tooltip=["iter_num", y_axis_name],
109 | title=y_axis_name,
110 | height=300,
111 | width=500,
112 | )
113 |
114 | def draw(y_axis_name, log_path, save_path):
115 | """
116 | Args:
117 | y_axis_name: One of ('lr', 'loss')
118 | """
119 | ds = load_dataset('json', data_files=log_path)
120 | ds = ds['train']
121 | x = ds['iter']
122 | y = ds[y_axis_name]
123 | plt.figure(figsize=(10,5))
124 |
125 | plt.plot(x, y, marker='.')
126 |
127 | plt.title(f'Training Iterations vs {y_axis_name}')
128 | plt.xlabel('Training iterations')
129 | plt.ylabel(y_axis_name)
130 |
131 | plt.savefig(save_path)
132 |
133 |
134 | if __name__ == '__main__':
135 | draw('loss', './dummy_log.json', 'd1.png')
136 |
--------------------------------------------------------------------------------
/xtuner_run/example.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import torch
3 | from datasets import load_dataset
4 | from mmengine.dataset import DefaultSampler
5 | from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
6 | LoggerHook, ParamSchedulerHook)
7 | from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
8 | from peft import LoraConfig
9 | from torch.optim import AdamW
10 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
11 | BitsAndBytesConfig)
12 |
13 | from xtuner.dataset import process_hf_dataset
14 | from xtuner.dataset.collate_fns import default_collate_fn
15 | from xtuner.dataset.map_fns import code_alpaca_map_fn, template_map_fn_factory
16 | from xtuner.engine.hooks import DatasetInfoHook, EvaluateChatHook
17 | # from xtuner.engine.runner import TrainLoop
18 | from xtuner.model import SupervisedFinetune
19 | from xtuner.utils import PROMPT_TEMPLATE, SYSTEM_TEMPLATE
20 |
21 | #######################################################################
22 | # PART 1 Settings #
23 | #######################################################################
24 | # Model
25 | pretrained_model_name_or_path = 'internlm/internlm2-chat-7b'
26 |
27 | # Data
28 | data_path = 'HuggingFaceH4/CodeAlpaca_20K'
29 | prompt_template = PROMPT_TEMPLATE.internlm2_chat
30 | max_length = 2048
31 | pack_to_max_length = True
32 |
33 | # Scheduler & Optimizer
34 | batch_size = 1 # per_device
35 | accumulative_counts = 16
36 | dataloader_num_workers = 0
37 | max_epochs = 3
38 | optim_type = AdamW
39 | lr = 2e-4
40 | betas = (0.9, 0.999)
41 | weight_decay = 0
42 | max_norm = 1 # grad clip
43 | warmup_ratio = 0.03
44 |
45 | # Save
46 | save_steps = 500
47 | save_total_limit = 2 # Maximum checkpoints to keep (-1 means unlimited)
48 |
49 | # Evaluate the generation performance during the training
50 | evaluation_freq = 100
51 | SYSTEM = SYSTEM_TEMPLATE.coder
52 | evaluation_inputs = [
53 | ('写一个Python函数,将十六进制颜色代码(如#0066ee)转换为对应的'
54 | '红、绿、蓝(RGB)三个颜色分量值,并以元组的形式返回。'),
55 | ('Write a Python function that takes a hexadecimal color code '
56 | '(e.g., #0066ee) as input and converts it into the corresponding '
57 | 'red, green, and blue (RGB) color component values.')
58 | ]
59 |
60 | #######################################################################
61 | # PART 2 Model & Tokenizer #
62 | #######################################################################
63 | tokenizer = dict(
64 | type=AutoTokenizer.from_pretrained,
65 | pretrained_model_name_or_path=pretrained_model_name_or_path,
66 | trust_remote_code=True,
67 | padding_side='right')
68 |
69 | model = dict(
70 | type=SupervisedFinetune,
71 | llm=dict(
72 | type=AutoModelForCausalLM.from_pretrained,
73 | pretrained_model_name_or_path=pretrained_model_name_or_path,
74 | trust_remote_code=True,
75 | torch_dtype=torch.float16,
76 | quantization_config=dict(
77 | type=BitsAndBytesConfig,
78 | load_in_4bit=True,
79 | load_in_8bit=False,
80 | llm_int8_threshold=6.0,
81 | llm_int8_has_fp16_weight=False,
82 | bnb_4bit_compute_dtype=torch.float16,
83 | bnb_4bit_use_double_quant=True,
84 | bnb_4bit_quant_type='nf4')),
85 | lora=dict(
86 | type=LoraConfig,
87 | r=64,
88 | lora_alpha=16,
89 | lora_dropout=0.1,
90 | bias='none',
91 | task_type='CAUSAL_LM'))
92 |
93 | #######################################################################
94 | # PART 3 Dataset & Dataloader #
95 | #######################################################################
96 | train_dataset = dict(
97 | type=process_hf_dataset,
98 | dataset=dict(type=load_dataset, path=data_path),
99 | tokenizer=tokenizer,
100 | max_length=max_length,
101 | dataset_map_fn=code_alpaca_map_fn,
102 | template_map_fn=dict(
103 | type=template_map_fn_factory, template=prompt_template),
104 | remove_unused_columns=True,
105 | shuffle_before_pack=True,
106 | pack_to_max_length=pack_to_max_length)
107 |
108 | train_dataloader = dict(
109 | batch_size=batch_size,
110 | num_workers=dataloader_num_workers,
111 | dataset=train_dataset,
112 | sampler=dict(type=DefaultSampler, shuffle=True),
113 | collate_fn=dict(type=default_collate_fn))
114 |
115 | #######################################################################
116 | # PART 4 Scheduler & Optimizer #
117 | #######################################################################
118 | # optimizer
119 | optim_wrapper = dict(
120 | type=AmpOptimWrapper,
121 | optimizer=dict(
122 | type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
123 | clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
124 | accumulative_counts=accumulative_counts,
125 | loss_scale='dynamic',
126 | dtype='float16')
127 |
128 | # learning policy
129 | # More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md # noqa: E501
130 | param_scheduler = [
131 | dict(
132 | type=LinearLR,
133 | start_factor=1e-5,
134 | by_epoch=True,
135 | begin=0,
136 | end=warmup_ratio * max_epochs,
137 | convert_to_iter_based=True),
138 | dict(
139 | type=CosineAnnealingLR,
140 | eta_min=0.0,
141 | by_epoch=True,
142 | begin=warmup_ratio * max_epochs,
143 | T_max=max_epochs,
144 | convert_to_iter_based=True)
145 | ]
146 |
147 | # train, val, test setting
148 | # train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)
149 |
150 | #######################################################################
151 | # PART 5 Runtime #
152 | #######################################################################
153 | # Log the dialogue periodically during the training process, optional
154 | custom_hooks = [
155 | dict(type=DatasetInfoHook, tokenizer=tokenizer),
156 | dict(
157 | type=EvaluateChatHook,
158 | tokenizer=tokenizer,
159 | every_n_iters=evaluation_freq,
160 | evaluation_inputs=evaluation_inputs,
161 | system=SYSTEM,
162 | prompt_template=prompt_template)
163 | ]
164 |
165 | # configure default hooks
166 | default_hooks = dict(
167 | # record the time of every iteration.
168 | timer=dict(type=IterTimerHook),
169 | # print log every 10 iterations.
170 | logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
171 | # enable the parameter scheduler.
172 | param_scheduler=dict(type=ParamSchedulerHook),
173 | # save checkpoint per `save_steps`.
174 | checkpoint=dict(
175 | type=CheckpointHook,
176 | by_epoch=False,
177 | interval=save_steps,
178 | max_keep_ckpts=save_total_limit),
180 | # set sampler seed in distributed environment.
180 | sampler_seed=dict(type=DistSamplerSeedHook),
181 | )
182 |
183 | # configure environment
184 | env_cfg = dict(
185 | # whether to enable cudnn benchmark
186 | cudnn_benchmark=False,
187 | # set multi process parameters
188 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
189 | # set distributed parameters
190 | dist_cfg=dict(backend='nccl'),
191 | )
192 |
193 | # set visualizer
194 | visualizer = None
195 |
196 | # set log level
197 | log_level = 'INFO'
198 |
199 | # load from which checkpoint
200 | load_from = None
201 |
202 | # whether to resume training from the loaded checkpoint
203 | resume = False
204 |
205 | # Defaults to use random seed and disable `deterministic`
206 | randomness = dict(seed=None, deterministic=False)
207 |
208 | # set log processor
209 | log_processor = dict(by_epoch=False)
--------------------------------------------------------------------------------
/xtuner_run/kill_xtuner.sh:
--------------------------------------------------------------------------------
1 |
2 | pid=`ps -ef | grep -E "xtuner train.*" | grep -v "grep" | awk '{print $2}'`
3 | for id in $pid
4 | do
5 | kill -9 $id
6 | echo "killed $id"
7 | done;
8 |
9 | pid=`ps -ef | grep -E "python.*?train.py.*" | grep -v "grep" | awk '{print $2}'`
10 | for id in $pid
11 | do
12 | kill -9 $id
13 | echo "killed $id"
14 | done;
--------------------------------------------------------------------------------
/xtuner_run/shell_train.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-30
3 | # Author: Scc_hy
4 | # Func: 用shell 启动xtuner
5 | # ===========================================================================================
6 |
7 | from .train_utils import prepareConfig, prepareUtil, stop_thread
8 | import threading
9 | import os
10 | import gradio as gr
11 | CUR_DIR = os.path.dirname(__file__)
12 |
13 |
14 | class quickTrain:
15 | def __init__(self,
16 | work_dir,
17 | config_py_path,
18 | deepspeed_seed=None,
19 | resume_from_checkpoint=None,
20 | run_type='mmengine'):
21 | self.work_dir = work_dir
22 | self.config_py_path = config_py_path
23 | self.resume_from_checkpoint = resume_from_checkpoint
24 | self.run_type = run_type
25 | self.deepspeed_seed = deepspeed_seed
26 | self._t_handle_tr = None
27 | self.log_file = os.path.join(self.work_dir, '__xtuner_tr.log')
28 | self.remove_log_file()
29 | print(f'config_py_path={config_py_path}')
30 |
31 | def reset_resume_from_checkpoint(self, ckpt):
32 | self.resume_from_checkpoint = f'{self.work_dir}/{ckpt}'
33 | print(f"reset_resume_from_checkpoint({self.resume_from_checkpoint})")
34 |
35 | def reset_deepspeed(self, deepspeed):
36 | print(f"reset_deepspeed({deepspeed})")
37 | self.deepspeed_seed = deepspeed
38 |
39 | def reset_work_dir(self, local_path):
40 | print(f"reset_work_dir({local_path})")
41 | self.work_dir = os.path.join(local_path, 'work_dir')
42 | if not os.path.exists(self.work_dir):
43 | os.system(f'mkdir -p {self.work_dir}')
44 | self.log_file = os.path.join(self.work_dir, '__xtuner_tr.log')
45 | self.remove_log_file()
46 |
47 | def reset_cfg_py(self, cfg_py):
48 | print(f"reset_cfg_py({cfg_py})")
49 | self.config_py_path = cfg_py
50 |
51 | def remove_log_file(self):
52 | if os.path.exists(self.log_file):
53 | os.system(f'rm -rf {self.log_file}')
54 |
55 | def _quick_train(self, resume, progress=gr.Progress(track_tqdm=True)):
56 | self.remove_log_file()
57 | add_ = resume_ = ''
58 | if str(self.deepspeed_seed).lower() not in ['none', 'dropdown']:
59 | add_ = f'--deepspeed deepspeed_{self.deepspeed_seed} '
60 |
61 | if self.resume_from_checkpoint is not None:
62 | resume_ = f'--resume {self.resume_from_checkpoint}'
63 |
64 | if resume:
65 | exec_ = f'xtuner train {self.config_py_path} --work-dir {self.work_dir} {add_} {resume_} > {self.log_file} 2>&1'
66 | else:
67 | exec_ = f'xtuner train {self.config_py_path} --work-dir {self.work_dir} {add_} > {self.log_file} 2>&1'
68 |
69 | print(f'exec={exec_}')
70 | os.system(exec_)
71 |
72 | def _t_start(self, resume=0):
73 | self._t_handle_tr = threading.Thread(target=self._quick_train, args=(resume,), name=f'X-train-{self.run_type}', daemon=True)
74 | self._t_handle_tr.start()
75 |
76 | def quick_train(self, progress=gr.Progress(track_tqdm=True)):
77 | self._break_flag = False
78 | self._t_start(0)
79 | self._t_handle_tr.join()
80 | if self._break_flag:
81 | return f"Done! Xtuner was interrupted!\nwork_dir={self.work_dir}", self.work_dir
82 | return "Success", self.work_dir
83 |
84 | def resume_train(self, progress=gr.Progress(track_tqdm=True)):
85 | self._break_flag = False
86 | self._t_start(1)
87 | self._t_handle_tr.join()
88 | if self._break_flag:
89 | return f"Done! Xtuner was interrupted!\nRESUME work_dir={self.work_dir}", self.work_dir
90 | return "Success", self.work_dir
91 |
92 | def _tail(self, n=100):
93 | line_list = []
94 | with open(self.log_file, "rb") as f:
95 | f.seek(0, 2)
96 | while 1:
97 | if f.read(1) == b"\n":
98 | now_index = f.tell()
99 | line_list.append(f.readline())
100 | f.seek(now_index, 0)
101 | if len(line_list) >= n:
102 | return line_list[::-1]
103 | if f.tell() <= 1:
104 | f.seek(0, 0)
105 | line_list.append(f.readline())
106 | return line_list[::-1]
107 | f.seek(-2, 1)
108 |
109 | def read_log(self):
110 | if self._t_handle_tr is None:
111 | return ""
112 | if os.path.exists(self.log_file):
113 | # with open(self.log_file, 'r') as f:
114 | # res_ = f.readlines()
115 | # return ''.join(res_)
116 | line_list = self._tail(5)
117 | return b"".join(line_list).decode()
118 |
119 | def break_train(self):
120 | # 然后杀死该线程
121 | # 删除文件
122 | if self._t_handle_tr is not None:
123 | print('>>>>>>>>>>>>>>>>> break_train')
124 | stop_thread(self._t_handle_tr)
125 | os.system(f'sh {CUR_DIR}/kill_xtuner.sh')
126 | self._t_handle_tr = None
127 |
128 | self._break_flag = True
129 | return f"Done! Xtuner was interrupted!\nwork_dir={self.work_dir}", self.work_dir
130 |
131 |
--------------------------------------------------------------------------------
/xtuner_run/todo_list.md:
--------------------------------------------------------------------------------
1 |
2 | - [X] config prepare
3 | - [X] test
4 | - [X] model sft
5 | - [X] code
6 | - [X] app add func
7 | - [ ] break & continue
8 | - [ ] log print
9 |
10 |
--------------------------------------------------------------------------------
/xtuner_run/train.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-29
3 | # Author: Scc_hy
4 | # Func: 基于指定参数进行模型训练
5 | # nohup xtuner train ./internlm_chat_7b_qlora_oasst1_e3_copy.py > __xtuner.log &
6 | # pip install -U xtuner
7 | # ===========================================================================================
8 |
9 | from .train_utils import prepareConfig, prepareUtil, stop_thread
10 | import threading
11 | import os
12 | import gradio as gr
13 | from transformers import Trainer
14 | from xtuner.dataset.collate_fns import default_collate_fn
15 | from functools import partial
16 | from xtuner.dataset import process_hf_dataset
17 | from datasets import load_dataset
18 | from xtuner.dataset.map_fns import template_map_fn_factory
19 | from transformers import TrainingArguments
20 | import torch
21 | from peft import LoraConfig
22 | from transformers import (AutoModelForCausalLM, AutoTokenizer,
23 | BitsAndBytesConfig)
24 | from xtuner.model import SupervisedFinetune
25 | from xtuner.apis.datasets import alpaca_data_collator, alpaca_dataset
26 | from mmengine.runner import Runner
27 | import time
28 |
29 | def safe_load_data(file):
30 | tp = file.split('.')[-1]
31 | if tp == 'csv':
32 | return load_dataset('csv', data_files=dict(train=file))
33 | if 'json' in tp:
34 | return load_dataset('json', data_files=dict(train=file))
35 |
36 | # py
37 | return load_dataset(file, split='train')
38 |
39 |
40 | def saft_build_model(model_name_or_path,
41 | quantization_config=None,
42 | lora_config=None,
43 | return_tokenizer=True,
44 | qlora_flag=True):
45 | if quantization_config is None:
46 | quantization_config = BitsAndBytesConfig(
47 | load_in_4bit=True,
48 | load_in_8bit=False,
49 | llm_int8_threshold=6.0,
50 | llm_int8_has_fp16_weight=False,
51 | bnb_4bit_compute_dtype=torch.float16,
52 | bnb_4bit_use_double_quant=True,
53 | bnb_4bit_quant_type='nf4')
54 | if lora_config is None:
55 | lora_config = LoraConfig(
56 | r=64,
57 | lora_alpha=16,
58 | lora_dropout=0.1,
59 | bias='none',
60 | task_type='CAUSAL_LM')
61 |
62 | llm = AutoModelForCausalLM.from_pretrained(
63 | model_name_or_path,
64 | torch_dtype=torch.float16,
65 | trust_remote_code=True,
66 | quantization_config=quantization_config if qlora_flag else None)
67 |
68 | try:
69 | model = SupervisedFinetune(llm, lora=lora_config)
70 | except Exception as e:
71 | model = SupervisedFinetune(llm, lora=lora_config, use_activation_checkpointing=False)
72 |
73 | if return_tokenizer:
74 | tokenizer = AutoTokenizer.from_pretrained(
75 | model_name_or_path,
76 | trust_remote_code=True,
77 | encode_special_tokens=True)
78 | return model.llm, tokenizer
79 | else:
80 | return model.llm
81 |
82 |
83 |
84 | def mm_run(
85 | model_name_or_path,
86 | dataset_name_or_path,
87 | work_dir,
88 | xtuner_type='qlora',
89 | resume_from_checkpoint=None,
90 | progress=gr.Progress(track_tqdm=True)
91 | ):
92 | cfg_org = prepareConfig(
93 | model_name_or_path=model_name_or_path,
94 | dataset_name_or_path=dataset_name_or_path
95 | )
96 | if resume_from_checkpoint is not None:
97 | cfg_org.resume_from_checkpoint = resume_from_checkpoint
98 | pp = prepareUtil(cfg_org, work_dir=work_dir, lora_type=xtuner_type)
99 | cfg = pp.auto_prepare()
100 | if resume_from_checkpoint is not None:
101 | cfg['load_from'] = resume_from_checkpoint
102 | cfg['resume'] = True
103 | try:
104 | runner = Runner.from_cfg(cfg)
105 | # runner = Runner.from_cfg(org_cfg)
106 | runner.train()
107 | # runner.test()
108 | except Exception as e:
109 | print(f"mm_run ERROR: \n{e}")
110 | return f"mm_run ERROR: \n{e}"
111 | return pp.work_dir
112 |
113 |
114 | def hf_run(
115 | model_name_or_path,
116 | dataset_name_or_path,
117 | work_dir,
118 | xtuner_type='qlora',
119 | resume_from_checkpoint=None,
120 | progress=gr.Progress(track_tqdm=True)
121 | ):
122 | cfg = prepareConfig(
123 | model_name_or_path=model_name_or_path,
124 | dataset_name_or_path=dataset_name_or_path
125 | )
126 | cfg.output_dir = work_dir
127 | if 'LOCAL_RANK' not in os.environ:
128 | os.environ['LOCAL_RANK'] = str(cfg.local_rank)
129 |
130 | cfg_dict = cfg.to_tr_dict()
131 | if resume_from_checkpoint is not None:
132 | cfg_dict['resume_from_checkpoint'] = resume_from_checkpoint
133 | tr_args = TrainingArguments(**cfg_dict)
134 | print('=='*35)
135 | print('tr_args=', tr_args)
136 | print('=='*35)
137 | model, tokenizer = saft_build_model(
138 | model_name_or_path=cfg.model_name_or_path,
139 | return_tokenizer=True,
140 | qlora_flag=xtuner_type == 'qlora'
141 | )
142 | train_dataset = process_hf_dataset(
143 | dataset=safe_load_data(cfg.dataset_name_or_path),
144 | tokenizer=tokenizer,
145 | max_length=cfg.max_length,
146 | dataset_map_fn=cfg.dataset_map_fn,
147 | template_map_fn=template_map_fn_factory(template=cfg.task_prompt_template),
148 | pack_to_max_length=cfg.pack_to_max_length
149 | )
150 | # build trainer
151 | trainer = Trainer(
152 | model=model,
153 | args=tr_args,
154 | train_dataset=train_dataset,
155 | data_collator=partial(default_collate_fn, return_hf_format=True)
156 | )
157 | # training
158 | trainer.train(resume_from_checkpoint=cfg.resume_from_checkpoint)
159 | trainer.save_state()
160 | trainer.save_model(output_dir=tr_args.output_dir)
161 | return tr_args.output_dir
162 |
163 |
164 |
165 | class quickTrain:
166 | def __init__(self,
167 | model_name_or_path,
168 | dataset_name_or_path,
169 | work_dir,
170 | xtuner_type='qlora',
171 | resume_from_checkpoint=None,
172 | run_type='mmengine'):
173 | self.model_name_or_path = model_name_or_path
174 | self.dataset_name_or_path = dataset_name_or_path
175 | self.xtuner_type = xtuner_type
176 | self.resume_from_checkpoint = resume_from_checkpoint
177 | self.run_type = run_type
178 | self.work_dir = work_dir
179 | self._break_flag = False
180 | self._t_handle_tr = None
181 | self.log_file = None
182 | self.mm_run_res_ = None
183 |
184 | def get_log_path(self):
185 | list_ = sorted([i for i in os.listdir(self.work_dir) if '.' not in i])
186 | dir_name = list_[-2] if 'last_' in list_[-1] else list_[-1]
187 | self.log_file = os.path.join(self.work_dir, dir_name, f'{dir_name}.log')
188 |
189 | def set_model_path(self, model_path):
190 | print(f'set_model_path({model_path})')
191 | self.model_name_or_path = model_path
192 |
193 | def set_data_path(self, data_path):
194 | print(f'set_data_path({data_path})')
195 | self.dataset_name_or_path = data_path
196 |
197 | def set_xtuner_type(self, xtuner_type):
198 | print(f'set_xtuner_type({xtuner_type})')
199 | self.xtuner_type = xtuner_type
200 |
201 | def set_work_dir(self, work_dir):
202 | print(f'set_work_dir({work_dir})')
203 | self.work_dir = f'{work_dir}/work_dir'
204 | if not os.path.exists(self.work_dir):
205 | os.system(f'mkdir -p {self.work_dir}')
206 |
207 | def set_resume_from_checkpoint(self, work_dir=None):
208 | self.resume_from_checkpoint = work_dir
209 |
210 | def _t_start(self):
211 | self._t_handle_tr = threading.Thread(target=self._quick_train, name=f'X-train-{self.run_type}', daemon=True)
212 | self._t_handle_tr.start()
213 |
214 | def _quick_train(self, progress=gr.Progress(track_tqdm=True)):
215 | print(
216 | f'self.model_name_or_path={self.model_name_or_path}\nself.dataset_name_or_path={self.dataset_name_or_path}\nself.work_dir={self.work_dir}\nself.xtuner_type={self.xtuner_type}'
217 | )
218 | if self.run_type == 'mmengine':
219 | self.mm_run_res_ = None
220 | self.mm_run_res_ = mm_run(self.model_name_or_path, self.dataset_name_or_path, self.work_dir, self.xtuner_type, self.resume_from_checkpoint)
221 | return self.mm_run_res_
222 | return hf_run(self.model_name_or_path, self.dataset_name_or_path, self.work_dir, self.xtuner_type, self.resume_from_checkpoint)
223 |
224 | def read_log(self):
225 | if self.log_file is None:
226 | return ""
227 | if not os.path.exists(self.log_file):
228 | return ""
229 | with open(self.log_file , 'r') as f:
230 | read_res = f.readlines()
231 | read_res = ''.join(read_res) # [-20:]
232 | if self.mm_run_res_ is not None:
233 | return f'{read_res}\n{self.mm_run_res_}'
234 | return read_res
235 |
236 | def start_log(self):
237 | time.sleep(10)
238 | return "Start Training"
239 |
240 | def quick_train(self, progress=gr.Progress(track_tqdm=True)):
241 | self.log_file = None
242 | self._break_flag = False
243 | self._t_start()
244 | time.sleep(5)
245 | self.get_log_path()
246 | self._t_handle_tr.join()
247 | if self._break_flag:
248 | return "Done! Xtuner was interrupted!"
249 | return self.work_dir
250 |
251 | def break_train(self):
252 | # 然后杀死该线程
253 | # 删除文件
254 | if self._t_handle_tr is not None:
255 | print('>>>>>>>>>>>>>>>>> break_train')
256 | stop_thread(self._t_handle_tr)
257 | self._t_handle_tr = None
258 |
259 | self._break_flag = True
260 | return "Done! Xtuner was interrupted!"
261 |
262 |
263 |
264 | def main_test():
265 | model_ = '/root/share/model_repos/internlm-chat-7b'
266 | model_2 = '/root/share/model_repos/internlm2-chat-7b'
267 | TR_ = quickTrain(
268 | model_name_or_path=model_,
269 | dataset_name_or_path='/root/ft-medqa/MedQA2019-structured-train.jsonl',
270 | work_dir='./work_dir',
271 | xtuner_type='qlora',
272 | run_type='mmengine'
273 | )
274 | TR_.quick_train()
275 |
276 |
277 | if __name__ == '__main__':
278 | main_test()
279 |
280 |
281 |
282 |
283 |
--------------------------------------------------------------------------------
/xtuner_run/train_utils.py:
--------------------------------------------------------------------------------
1 | # python3
2 | # Create Date: 2024-01-29
3 | # Author: Scc_hy
4 | # Func: 参数准备
5 | # ===========================================================================================
6 |
7 | import os
8 | import transformers
9 | from dataclasses import dataclass, field
10 | from typing import List, Dict, ClassVar, Tuple, AnyStr, Callable
11 | import warnings
12 | import os
13 | import re
14 | import inspect
15 | import ctypes
16 | import inspect
17 | import ctypes
18 | warnings.filterwarnings(action='ignore')
19 |
20 |
21 | def _async_raise(tid, exctype):
22 | """Raises an exception in the threads with id tid"""
23 | if not inspect.isclass(exctype):
24 | raise TypeError("Only types can be raised (not instances)")
25 | res = ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), ctypes.py_object(exctype))
26 | if res == 0:
27 | raise ValueError("invalid thread id")
28 | elif res != 1:
29 | # """if it returns a number greater than one, you're in trouble,
30 | # and you should call it again with exc=NULL to revert the effect"""
31 | ctypes.pythonapi.PyThreadState_SetAsyncExc(tid, None)
32 | raise SystemError("PyThreadState_SetAsyncExc failed")
33 |
34 |
35 | def stop_thread(thread):
36 | try:
37 | _async_raise(thread.ident, SystemExit)
38 | except Exception as e:
39 | print(e)
40 |
41 |
42 | @dataclass
43 | class prepareConfig:
44 | model_name_or_path: AnyStr
45 | dataset_name_or_path: AnyStr
46 | do_train: bool = True
47 | save_strategy: AnyStr = 'epoch'
48 | lr_scheduler_type: AnyStr = 'cosine'
49 | logging_steps: int = 5
50 | framework: AnyStr = 'huggingface'
51 | output_dir: AnyStr = './work_dir'
52 | deepspeed: bool = None
53 | local_rank: int = -1
54 | seed: int = 42
55 | max_length: int = 2048
56 | pack_to_max_length: bool = True
57 | batch_size: int = 1
58 | per_device_train_batch_size: int = 1 # per_device
59 | per_device_eval_batch_size: int = 1 # per_device
60 | # accumulative_counts: int = 16
61 | gradient_accumulation_steps: int = 16
62 | dataloader_num_workers: int = 0
63 | max_epochs: int = 3
64 | num_train_epochs: float = 3.0 # TrainingArguments
65 | optim: AnyStr = "adamw_torch"
66 | # optim_args
67 | optim_type: AnyStr = 'bitsandbytes.optim.PagedAdamW32bit'
68 | learning_rate: float = 2e-4
69 | betas: Tuple[float, float] = (0.9, 0.999)
70 | adam_beta1: float = 0.9
71 | adam_beta2: float = 0.999
72 | weight_decay: float = 0
73 | max_norm: float = 1
74 | max_grad_norm: float = 1
75 | warmup_ratio: float = 0.03
76 | task_prompt_template: AnyStr = 'xtuner.utils.PROMPT_TEMPLATE.internlm_chat'
77 | # Save
78 | save_steps: int = 500
79 | save_total_limit: int = 2 # Maximum checkpoints to keep (-1 means unlimited)
80 | # Evaluate the generation performance during the training
81 | evaluation_freq: int = 500
82 | system: AnyStr = "" # SYSTEM_TEMPLATE.coder
83 | # prompt trick
84 | evaluation_inputs: List = field(default_factory=lambda: [
85 | '请给我介绍五个上海的景点', 'Please tell me five scenic spots in Shanghai'
86 | ])
87 | # set visualizer
88 | visualizer = None
89 | # set log level
90 | log_level: AnyStr = 'info'
91 | # load from which checkpoint
92 | resume_from_checkpoint: AnyStr = None
93 | # whether to resume training from the loaded checkpoint
94 | resume: bool = False
95 | # Defaults to use random seed and disable `deterministic`
96 | randomness: Dict = field(default_factory=lambda: dict(seed=None, deterministic=False))
97 | trust_remote_code: bool = True
98 | env_cfg: Dict = field(default_factory=lambda: dict(
99 | # whether to enable cudnn benchmark
100 | cudnn_benchmark=False,
101 | # set multi process parameters
102 | mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
103 | # set distributed parameters
104 | dist_cfg=dict(backend='nccl'),
105 | ))
106 | # Callable path
107 | dataset_map_fn:AnyStr = None
108 |
109 | def get_device_map(self):
110 | if self.deepspeed:
111 | self.device_map = None
112 | else:
113 | self.device_map = {
114 | '': int(os.environ.get('LOCAL_RANK', self.local_rank))
115 | }
116 |
117 | def to_tr_dict(self):
118 | self_dict = self.__dict__
119 | focus_key = [
120 | 'output_dir',
121 | 'deepspeed',
122 | 'local_rank',
123 | 'seed',
124 | 'per_device_train_batch_size',
125 | 'per_device_eval_batch_size',
126 | 'dataloader_num_workers',
127 | 'gradient_accumulation_steps',
128 | 'num_train_epochs',
129 | 'optim',
130 | 'weight_decay',
131 | 'adam_beta1',
132 | 'adam_beta2',
133 | 'max_grad_norm',
134 | 'warmup_ratio',
135 | 'save_steps',
136 | 'save_total_limit',
137 | 'log_level',
138 | 'resume_from_checkpoint',
139 | 'do_train',
140 | 'save_strategy',
141 | 'lr_scheduler_type',
142 | 'logging_steps',
143 | ]
144 |
145 | res = {}
146 | for k in focus_key:
147 | res[k] = self_dict[k]
148 | return res
149 |
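# NOTE: the keys collected by `to_tr_dict()` all correspond to parameters of
# `transformers.TrainingArguments`, which is presumably why `framework` defaults to
# 'huggingface'. A hypothetical usage sketch (not from the XtunerGUI codebase):
#
#     from transformers import TrainingArguments
#     cfg = prepareConfig(model_name_or_path='<model_path>', dataset_name_or_path='<data_path>')
#     hf_args = TrainingArguments(**cfg.to_tr_dict())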
150 |
151 | class prepareUtil:
152 | def __init__(self, cfg: prepareConfig, work_dir: AnyStr, lora_type: str='qlora'):
153 | self.cfg = cfg
154 | self.cfg.output_dir = work_dir
155 | self.work_dir = work_dir
156 | self.qlora_flag = lora_type == 'qlora'
157 | self._model, self._tokenizer = self.prepare_model_tokenizer()
158 |
159 | def auto_prepare(self):
160 | train_dataset, train_dataloader = self.prepare_data()
161 | optim_wrapper, param_scheduler = self.prepare_scheduler_optimizer()
162 | custom_hooks, default_hooks = self.prepare_hook()
163 | return dict(
164 | model=self._model,
165 | work_dir=self.cfg.output_dir,
166 | train_dataloader=train_dataloader,
167 | val_dataloader=None,
168 | test_dataloader=None,
169 | train_cfg=dict(by_epoch=True, max_epochs=self.cfg.max_epochs, val_interval=1),
170 | val_cfg=None,
171 | test_cfg=None,
172 | optim_wrapper=optim_wrapper,
173 | param_scheduler=param_scheduler,
174 | val_evaluator=None,
175 | test_evaluator=None,
176 | custom_hooks=custom_hooks,
177 | default_hooks=default_hooks,
178 | resume=self.cfg.resume_from_checkpoint is not None,
179 | env_cfg=self.cfg.env_cfg,
180 | visualizer=self.cfg.visualizer,
181 | log_level=self.cfg.log_level.upper(),
182 | randomness=self.cfg.randomness,
183 | launcher='none'
184 | )
185 |
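# NOTE: the dict returned by `auto_prepare()` lines up with the keyword arguments of
# `mmengine.runner.Runner`, so the intended consumer is presumably something like the
# following hypothetical sketch (not from the XtunerGUI codebase):
#
#     from mmengine.runner import Runner
#     cfg = prepareConfig(model_name_or_path='<model_path>', dataset_name_or_path='<data_path>')
#     runner = Runner(**prepareUtil(cfg, work_dir='./work_dir').auto_prepare())
#     runner.train()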
186 | def prepare_scheduler_optimizer(self):
187 | optim_wrapper = dict(
188 | type='mmengine.optim.AmpOptimWrapper',
189 | optimizer=dict(
190 | type=self.cfg.optim_type, lr=self.cfg.learning_rate, betas=self.cfg.betas, weight_decay=self.cfg.weight_decay),
191 | clip_grad=dict(max_norm=self.cfg.max_norm, error_if_nonfinite=False),
192 | accumulative_counts=self.cfg.gradient_accumulation_steps,
193 | loss_scale='dynamic',
194 | dtype='float16'
195 | )
196 | param_scheduler = dict(
197 | type='mmengine.optim.CosineAnnealingLR',
198 | eta_min=0.0,
199 | by_epoch=True,
200 | begin=self.cfg.warmup_ratio * self.cfg.max_epochs,
201 | T_max=self.cfg.max_epochs,
202 | convert_to_iter_based=True
203 | )
204 | return optim_wrapper, param_scheduler
205 |
206 | def prepare_model_tokenizer(self):
207 | tokenizer = dict(
208 | type='transformers.AutoTokenizer.from_pretrained',
209 | pretrained_model_name_or_path=self.cfg.model_name_or_path,
210 | trust_remote_code=True,
211 | padding_side='right'
212 | )
213 | model = dict(
214 | type='xtuner.model.SupervisedFinetune',
215 | llm=dict(
216 | type='transformers.AutoModelForCausalLM.from_pretrained',
217 | pretrained_model_name_or_path=self.cfg.model_name_or_path,
218 | trust_remote_code=True,
219 | torch_dtype='torch.float16',
220 | quantization_config=dict(
221 | type='transformers.BitsAndBytesConfig',
222 | load_in_4bit=True,
223 | load_in_8bit=False,
224 | llm_int8_threshold=6.0,
225 | llm_int8_has_fp16_weight=False,
226 | bnb_4bit_compute_dtype='torch.float16',
227 | bnb_4bit_use_double_quant=True,
228 | bnb_4bit_quant_type='nf4') if self.qlora_flag else None
229 | ),
230 | lora=dict(
231 | type='peft.LoraConfig',
232 | r=64,
233 | lora_alpha=16,
234 | lora_dropout=0.1,
235 | bias='none',
236 | task_type='CAUSAL_LM'))
237 | return model, tokenizer
238 |
239 | def prepare_hook(self):
240 | custom_hooks = [
241 | dict(type='xtuner.engine.DatasetInfoHook', tokenizer=self._tokenizer),
242 | dict(
243 | type='xtuner.engine.EvaluateChatHook',
244 | tokenizer=self._tokenizer,
245 | every_n_iters=self.cfg.evaluation_freq,
246 | evaluation_inputs=self.cfg.evaluation_inputs,
247 | system=self.cfg.system,
248 | prompt_template=self.cfg.task_prompt_template)
249 | ]
250 | # configure default hooks
251 | default_hooks = dict(
252 | checkpoint=dict(interval=1, type='mmengine.hooks.CheckpointHook'),
253 | logger=dict(interval=10, type='mmengine.hooks.LoggerHook'),
254 | param_scheduler=dict(type='mmengine.hooks.ParamSchedulerHook'),
255 | sampler_seed=dict(type='mmengine.hooks.DistSamplerSeedHook'),
256 | timer=dict(type='mmengine.hooks.IterTimerHook')
257 | )
258 | return custom_hooks, default_hooks
259 |
260 | def prepare_data(self):
261 | train_dataset = dict(
262 | type='xtuner.dataset.process_hf_dataset',
263 | dataset=self.safe_load_dataset(self.cfg.dataset_name_or_path),
264 | tokenizer=self._tokenizer,
265 | max_length=self.cfg.max_length,
266 | dataset_map_fn=self.cfg.dataset_map_fn,
267 | template_map_fn=dict(
268 | type='xtuner.dataset.map_fns.template_map_fn_factory',
269 | template=self.cfg.task_prompt_template),
270 | remove_unused_columns=True,
271 | shuffle_before_pack=True,
272 | pack_to_max_length=self.cfg.pack_to_max_length)
273 |
274 | train_dataloader = dict(
275 | batch_size=self.cfg.batch_size,
276 | num_workers=self.cfg.dataloader_num_workers,
277 | dataset=train_dataset,
278 | sampler=dict(type='mmengine.dataset.DefaultSampler', shuffle=True),
279 | collate_fn=dict(type='xtuner.dataset.collate_fns.default_collate_fn'))
280 | return train_dataset, train_dataloader
281 |
282 | def safe_load_dataset(self, file):
283 | load_tp = 'datasets.load_dataset'
284 | tp = file.split('.')[-1]
285 | if tp == 'csv':
286 | return dict(type=load_tp, path='csv', data_files=dict(train=file))
287 | if 'json' in tp:
288 | return dict(type=load_tp, path='json', data_files=dict(train=file))
289 | # py
290 | return dict(type=load_tp, path=file)
291 |
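# NOTE: `safe_load_dataset` dispatches on the file suffix: '.csv' files and
# '.json' / '.jsonl' files are loaded via `datasets.load_dataset('csv' | 'json',
# data_files=dict(train=file))`, while anything else is treated as a dataset
# name or script path and passed straight to `datasets.load_dataset(path=file)`.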
292 |
293 | def main_test():
294 | import transformers
295 | cfg = prepareConfig(
296 | model_name_or_path='/root/opencompass/InternLM/Shanghai_AI_Laboratory/internlm2-chat-7b',
297 | dataset_name_or_path='/root/ft-medqa/MedQA2019-structured-train.jsonl'
298 | )
299 | print(f'type(cfg.evaluation_inputs)={type(cfg.evaluation_inputs)}')
300 | print(cfg.evaluation_inputs)
301 | pp = prepareUtil(cfg, work_dir=cfg.output_dir)
302 | pp_res = pp.auto_prepare()
303 | # pp_res = cfg.to_tr_dict()
304 | print('--'*25)
305 | print(pp_res['custom_hooks'])
306 | print(pp_res['custom_hooks'][0]['type'])
307 | # print('=='*35)
308 |
309 |
310 | if __name__ == '__main__':
311 | main_test() # checked
312 |
313 |
314 |
--------------------------------------------------------------------------------