├── README.md
└── Ziwei-Chatglm3-6B
    ├── .github
    │   ├── ISSUE_TEMPLATE
    │   │   ├── bug_report.yaml
    │   │   └── feature-request.yaml
    │   └── PULL_REQUEST_TEMPLATE
    │       └── pr_template.md
    ├── .gitignore
    ├── MODEL_LICENSE
    ├── composite_demo
    │   ├── .streamlit
    │   │   └── config.toml
    │   ├── assets
    │   │   ├── demo.png
    │   │   ├── emojis.png
    │   │   ├── heart.png
    │   │   └── tool.png
    │   ├── client.py
    │   ├── conversation.py
    │   ├── demo_chat.py
    │   ├── main.py
    │   ├── requirements.txt
    │   └── tool_registry.py
    ├── data
    │   └── output.json
    ├── finetune_demo
    │   ├── =1.3.0
    │   ├── README.md
    │   ├── arguments.py
    │   ├── configs
    │   │   └── deepspeed.json
    │   ├── finetune.py
    │   ├── inference.py
    │   ├── preprocess_utils.py
    │   ├── scripts
    │   │   ├── finetune_ds.sh
    │   │   ├── finetune_ds_multiturn.sh
    │   │   ├── finetune_pt.sh
    │   │   ├── finetune_pt_multiturn.sh
    │   │   ├── format_advertise_gen.py
    │   │   └── format_tool_alpaca.py
    │   └── trainer.py
    ├── openai_api
    │   ├── openai_api.py
    │   ├── openai_api_request.py
    │   ├── requirements.txt
    │   └── utils.py
    ├── requirements.txt
    ├── useChroma
    │   └── build_chroma.py
    └── with_prompt
        ├── prompt_web.py
        └── utils.py
/README.md:
--------------------------------------------------------------------------------
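1 | Ziwei Fortune-Telling Large Language Model
2 | ====
3 | Ziwei Model Series
4 | -------
5 | * Ziwei-llama-7B: built on Chinese-LLaMA-Alpaca-7B and trained with LoRA fine-tuning plus continued pre-training on a small batch of text. Compared with LLaMA, this base model has an expanded Chinese vocabulary and stronger Chinese-language ability.
6 | After training, its performance in the fortune-telling domain improved significantly.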
7 |
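8 | * Ziwei-Chatglm3-6B: built on ChatGLM3-6B and trained with P-Tuning v2. At the time of writing, this base model was regarded as having the best Chinese-language performance among models under 10B parameters.
9 | After fine-tuning, its performance in the fortune-telling domain also improved significantly.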
10 |
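11 | Introduction
12 | -------
13 | This project aims to build a fortune-telling large language model. It uses Chinese AI models to answer users' questions and to provide an accurate and entertaining fortune-telling experience.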
14 |
15 |
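16 | We develop two versions on different base models, one on Chinese-LLaMA-Alpaca and one on ChatGLM3, and compare them to see how their performance differs in a fortune-telling service. For data collection, we crawl relevant question-answer pairs from the web and gather additional fortune-telling content from books, blogs, and other sources. We also generate more QA pairs with the ChatGPT-3.5 API to expand the dataset.
17 | We fine-tune the Chinese-LLaMA model with LoRA (Low-Rank Adaptation) to improve its performance in the fortune-telling service. LoRA freezes the pretrained weights and trains a pair of small low-rank matrices alongside selected weight matrices, so the model can be adapted to a specific domain while updating only a small fraction of its parameters.
18 | For ChatGLM3 we use P-Tuning v2, a deep prompt-tuning method that learns continuous prompt (prefix) vectors for every transformer layer while keeping the backbone weights frozen. This fine-tunes the model at low computational cost and improves its performance on the target task.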
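To make the LoRA description concrete, here is a minimal pure-Python sketch of the low-rank update W' = W + (alpha / r) · B · A. In this sketch, W is a frozen d × k weight, B is d × r, and A is r × k. The helper names (`matmul`, `lora_merge`, `lora_param_counts`) are hypothetical and not part of this repository. The 4096 × 4096, r = 8 figures are illustrative, not the project's actual training settings.

```python
# Minimal sketch of a LoRA update (hypothetical helpers, not code from this repo).
# W (d x k) stays frozen; only B (d x r) and A (r x k) are trained.

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain matrix product of an (n x m) and an (m x p) matrix."""
    return [
        [sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def lora_merge(w: Matrix, b: Matrix, a: Matrix, alpha: float, r: int) -> Matrix:
    """Return W + (alpha / r) * (B @ A), the effective weight after LoRA."""
    delta = matmul(b, a)
    scale = alpha / r
    return [
        [w[i][j] + scale * delta[i][j] for j in range(len(w[0]))]
        for i in range(len(w))
    ]


def lora_param_counts(d: int, k: int, r: int) -> tuple[int, int]:
    """Trainable parameters: full fine-tuning of W vs. the LoRA factors B and A."""
    return d * k, r * (d + k)


if __name__ == "__main__":
    w = [[1.0, 0.0], [0.0, 1.0]]
    # With B initialised to zeros, the merged weight equals W, so training
    # starts from the pretrained model's behaviour.
    print(lora_merge(w, [[0.0], [0.0]], [[0.5, 2.0]], alpha=2.0, r=1))
    full, lora = lora_param_counts(4096, 4096, r=8)
    print(full, lora, f"{lora / full:.2%}")
```

Running it prints the unchanged identity matrix and then `16777216 65536 0.39%`. In other words, for a single 4096 × 4096 layer, rank-8 LoRA trains about 0.4% of the parameters that full fine-tuning would update.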
19 |
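20 | Using chain-of-thought prompting, LangChain, and a vector database, we build rich prompts that elicit more creative and vivid answers and make the fortune-telling experience more engaging. The vector database also provides more accurate query matching and information retrieval.
21 | After fine-tuning both models, we compare and evaluate them on metrics including accuracy, answer richness, and degree of personalization. The goal is to understand how the two models differ in the fortune-telling service, and which scenarios, strengths, and weaknesses each has.
22 | With this project we hope to offer users a new kind of fortune-telling experience and to contribute more practical applications of Chinese AI models.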
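At its core, the retrieval step above embeds the user's question and returns the stored passages whose vectors are closest to it. The repository builds its store with Chroma (`useChroma/build_chroma.py`). The snippet below does not use Chroma's API. It is a stdlib-only sketch of cosine-similarity top-k search, in which toy 3-dimensional vectors and made-up document ids stand in for real embeddings.

```python
import math


def cosine(u: list[float], v: list[float]) -> float:
    """Cosine similarity between two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k(query: list[float], store: dict[str, list[float]], k: int = 2) -> list[str]:
    """Return the ids of the k stored vectors most similar to the query."""
    ranked = sorted(store, key=lambda doc_id: cosine(query, store[doc_id]), reverse=True)
    return ranked[:k]


# Toy "embeddings" with hypothetical ids; a real store holds model-generated
# vectors for chunks of the fortune-telling texts.
store = {
    "zhen_hexagram": [0.9, 0.1, 0.0],
    "sagittarius_career": [0.1, 0.9, 0.1],
    "birth_chart_children": [0.2, 0.3, 0.9],
}
print(top_k([0.8, 0.2, 0.1], store, k=1))  # ['zhen_hexagram']
```

The retrieved passages are then placed into the prompt as context. A production vector store replaces the linear scan with an approximate nearest-neighbour index, but the ranking criterion remains the same.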
23 |
24 | Data
25 | -------
26 | Our dialogue data were generated with ChatGPT-3.5, for a total of 100,000 instruction-tuning samples, which were then cleaned and augmented. Samples:
27 |
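28 | | Type | Data |
29 | | ------ | ------ |
30 | | LLaMA data | "instruction": "Please give me a horoscope prediction and tell me how my career will develop.", "input": "I'm a Sagittarius and would like to know how my career will develop. Can you give me a horoscope prediction?", "output": "Based on your sign, Sagittarius, the prediction shows that your career will develop positively. You will meet many opportunities and challenges and will need to draw on your optimism and confidence to face them. Over time, you will earn recognition and chances for promotion at work. Remember to stay focused and work hard, while also leaving time to pursue your personal interests and growth." |
31 | | ChatGLM3 data | "prompt": "I drew a hexagram: Zhen (震). Please interpret it for me.", "response": "The Zhen hexagram stands for action and challenge. It reminds you to stay positive when facing difficulties and obstacles, and to have the courage to act. Don't be afraid of failure, because failure is part of success. It also reminds you to pay attention to how you think and make decisions, and to make sure they are wise and beneficial. Believe in your abilities and forge ahead." |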
32 |
33 |
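34 | The text data come from e-books on fortune-telling and from crawling several fortune-telling websites.
35 |
36 | Part of the data is available in the project's `data` folder.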
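The two sample rows use different schemas. The LLaMA data follow the Alpaca-style `instruction`/`input`/`output` fields, while the ChatGLM3 data use `prompt`/`response`. To reuse the LLaMA samples for the ChatGLM3 fine-tune, you could use a conversion along the following lines. Joining `instruction` and `input` with a newline is our assumption, not the project's documented preprocessing. `alpaca_to_chatglm` is a hypothetical helper, and the record in the example is a toy record.

```python
import json


def alpaca_to_chatglm(record: dict) -> dict:
    """Map an instruction/input/output record to a prompt/response record.

    Assumption (not from the project): the instruction and the optional input
    are joined with a newline to form the prompt; empty parts are skipped.
    """
    parts = [record.get("instruction", "").strip(), record.get("input", "").strip()]
    prompt = "\n".join(p for p in parts if p)
    return {"prompt": prompt, "response": record["output"].strip()}


# Toy record for illustration only.
sample = {
    "instruction": "Give me a horoscope reading about my career.",
    "input": "I'm a Sagittarius.",
    "output": "Your career is likely to develop positively.",
}
# ensure_ascii=False keeps Chinese text readable when the real data are converted.
print(json.dumps(alpaca_to_chatglm(sample), ensure_ascii=False))
```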
37 |
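38 | Experimental Results
39 | -------
40 | Example: for the question “My date of birth is October 5, 2002, at 22:15. When will I have children?”, we compare the answer before fine-tuning with the answer after fine-tuning plus
41 | prompt engineering.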
42 |
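43 | | Model | Answer before fine-tuning | Answer after fine-tuning |
44 | | ------ | ------ | ------ |
45 | | Ziwei-Chatglm3-6B | As an AI assistant, I cannot predict an individual's fate, including when they will have children. Having children is a natural process influenced by many factors, including personal fertility, health, and family environment. My advice is to face life positively, work hard toward your goals, and keep learning and growing. That way, you can shape your own destiny. | As a fortune teller, I will do a birth-chart reading based on your date and time of birth to understand your current fertility situation and possible future opportunities to have children. First, your date of birth is October 5, 2002, at 22:15, which is a Libra birth date. It indicates that you are an introverted, deep, and insightful person with a strong desire for and pursuit of life and love. In the chart reading, I found that your Sun is in Libra, which means you have good social skills and relationships, and high expectations and requirements for love and marriage. At the same time, your Moon is in Sagittarius, which means you have an open, free, adventurous, and exploratory spirit, and an open attitude toward love and family. According to your chart reading, your current fertility situation may involve some uncertainty and may require further examination and observation. Future opportunities to have children mainly depend on your relationship with your partner, your health, and other factors. Therefore, I suggest you keep a positive and optimistic mindset, pay attention to fertility together with your partner, and seek help and advice from a professional doctor when necessary. Finally, I want to remind you that, as a fortune teller, my reading is only a reference and cannot fully determine your fate and future. Your fertility and opportunities also depend on your own efforts and choices. I hope you can keep a positive, optimistic, and confident mindset as you face the challenges and opportunities ahead. |
46 | | Ziwei-llama-7B | Sorry, as a fortune teller, I cannot provide personal predictions or advice. However, you can understand your innermost thoughts and feelings through self-exploration, meditation, or by consulting a professional psychologist. | As a fortune teller, I can tell you some relevant information, but I cannot determine your specific fate. According to astrological theory, a person's date, time, and place of birth affect their natal chart and zodiac sign, and thereby their Eight Characters (Bazi). However, these factors are not always deterministic, because every person's life course is unique and shaped by many other factors. As for your question, if you want to know when you will have children, you can make a prediction through astrology. First, you need to understand your natal chart and zodiac sign, and then analyze them with astrological methods. In astrology, each sign has different symbols and meanings, and each natal chart has different characteristics. Therefore, by analyzing your natal chart and sign, you can obtain information about when you will have children. However, please note that astrology is only a method of prediction and cannot guarantee 100% accuracy. It can only suggest possible situations and trends, and the final outcome depends on your personal choices and actions. In addition, because every person's life course is unique, even the same sign and natal chart can lead to different outcomes. In short, if you want to know when you will have children, you can make a prediction through astrology. But remember that this method is not 100% accurate, and the final outcome depends on your own choices. Based on the information you provided, your natal chart shows that you will have the opportunity to have children within the next ten years. However, the specific timing also depends on other factors, such as personal wishes and health. I suggest you stay healthy, keep a balanced life, and pay attention to timing; if the time is right, you can try to conceive. |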
47 |
48 | Usage
49 | -------
50 | >>Model download
51 |
52 | * Ziwei-Chatglm3-6B:
53 |
54 | Download the fine-tuned model weights from [Baidu Netdisk](https://pan.baidu.com/s/1KZzW8x8gBFM7I_PnbJBMcA?pwd=5qqd "Model weights").
55 | Extraction code: 5qqd
56 |
57 | >>Model usage notes
58 |
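59 | 1. Download the fine-tuned weights from the Netdisk link and the original ChatGLM3-6B weights. Use the two together to form the complete model.
60 |
61 | 2. For inference, you can use either the fine-tuned model or the model with the added prompt.
62 |
63 | 3. You can keep collecting data and continue fine-tuning with the provided scripts.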
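In `composite_demo/client.py`, the two sets of weights are combined through environment variables:

- `MODEL_PATH` points to the original ChatGLM3-6B weights.
- `PT_PATH` points to the directory that holds the fine-tuned P-Tuning checkpoint. The loader reads `pytorch_model.bin` from it.
- `PRE_SEQ_LEN` must match the prefix length used during training. The default is 128.

The helper below is a hypothetical pre-flight check, not part of the repo. It follows the same logic, with an illustrative default for `MODEL_PATH`, and reports whether the demo would load the prefix weights or fall back to the base model.

```python
import os
from pathlib import Path

# Hypothetical pre-flight check (not part of the repo) that mirrors how
# composite_demo/client.py reads its environment variables.
DEFAULT_MODEL_PATH = "THUDM/chatglm3-6b"  # illustrative default only


def describe_model_setup(env: dict[str, str]) -> dict:
    """Report which weights the demo would load for the given environment."""
    model_path = env.get("MODEL_PATH", DEFAULT_MODEL_PATH)
    pt_path = env.get("PT_PATH")
    # client.py only loads the P-Tuning prefix encoder if PT_PATH exists on disk.
    use_prefix = pt_path is not None and os.path.exists(pt_path)
    checkpoint = Path(pt_path) / "pytorch_model.bin" if use_prefix else None
    return {
        "model_path": model_path,
        "tokenizer_path": env.get("TOKENIZER_PATH", model_path),
        "pre_seq_len": int(env.get("PRE_SEQ_LEN", 128)),
        "mode": "p-tuning prefix" if use_prefix else "base model only",
        "checkpoint_found": checkpoint is not None and checkpoint.is_file(),
    }


if __name__ == "__main__":
    print(describe_model_setup(dict(os.environ)))
```

For example, from inside `composite_demo` you could run `MODEL_PATH=/path/to/chatglm3-6b PT_PATH=/path/to/ptuning-checkpoint PRE_SEQ_LEN=128 streamlit run main.py`. The paths are placeholders.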
64 |
65 |
66 |
67 |
68 |
69 |
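70 | Project Participants
71 | -------
72 | This project was completed by the Network and Information Security Research Center of Harbin Institute of Technology (Weihai), under the supervision of 徐永东 (Xu Yongdong).
73 |
74 | Contributors, listed in no particular order:
75 |
76 | Research and development: 佟梓赫 (Tong Zihe), 秦晨阳 (Qin Chenyang), 刘宗鑫 (Liu Zongxin), 沙昆 (Sha Kun)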
77 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve ChatGLM3 / 提交一个 Bug 问题报告来帮助我们改进 ChatGLM3
3 | body:
4 |   - type: textarea
5 |     id: system-info
6 |     attributes:
7 |       label: System Info / 系统信息
8 |       description: Your operating environment / 您的运行环境信息
9 |       placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
10 |     validations:
11 |       required: true
12 |
13 |   - type: textarea
14 |     id: who-can-help
15 |     attributes:
16 |       label: Who can help? / 谁可以帮助到您?
17 |       description: |
18 |         Your issue will be answered more quickly if you tag the right person with @.
19 |         All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.
20 |
21 |         Please tag no more than 3 people.
22 |
23 |         如果您能找到合适的标签 @,您的问题会更快得到回复。
24 |         所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。
25 |
26 |         标记的人数应该不超过 3 个人。
27 |
28 |         Related demo leader / 相关demo负责人 :
29 |         - finetune_demo: @Btlmd
30 |         - langchain_demo: @yincf
31 |         - composite_demo: @abmfy
32 |
33 |         If the bug is not in one of these three subsections, you do not need to specify a helper. Our maintainer will find the right person in the development group to solve the problem.
34 |
35 |         如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。
36 |
37 |       placeholder: "@Username ..."
38 |
39 |   - type: checkboxes
40 |     id: information-scripts-examples
41 |     attributes:
42 |       label: Information / 问题信息
43 |       description: 'The problem arises when using: / 问题出现在'
44 |       options:
45 |         - label: "The official example scripts / 官方的示例脚本"
46 |         - label: "My own modified scripts / 我自己修改的脚本和任务"
47 |
48 |   - type: textarea
49 |     id: reproduction
50 |     validations:
51 |       required: true
52 |     attributes:
53 |       label: Reproduction / 复现过程
54 |       description: |
55 |         Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
56 |         If you have code snippets, error messages, stack traces, please provide them here as well.
57 |         Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
58 |         Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
59 |
60 |         请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
61 |         如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
62 |         请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
63 |         请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
64 |       placeholder: |
65 |         Steps to reproduce the behavior/复现Bug的步骤:
66 |
67 |         1.
68 |         2.
69 |         3.
70 |
71 |   - type: textarea
72 |     id: expected-behavior
73 |     validations:
74 |       required: true
75 |     attributes:
76 |       label: Expected behavior / 期待表现
77 |       description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/.github/ISSUE_TEMPLATE/feature-request.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a request for a new ChatGLM3 feature / 提交一个新的 ChatGLM3 的功能建议
3 | labels: [ "feature" ]
4 | body:
5 |   - type: textarea
6 |     id: feature-request
7 |     validations:
8 |       required: true
9 |     attributes:
10 |       label: Feature request / 功能建议
11 |       description: |
12 |         A brief description of the proposed feature. Links to relevant papers and code are welcome.
13 |         对功能建议的简述。最好提供对应的论文和代码链接
14 |
15 |   - type: textarea
16 |     id: motivation
17 |     validations:
18 |       required: true
19 |     attributes:
20 |       label: Motivation / 动机
21 |       description: |
22 |         Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
23 |         您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
24 |
25 |   - type: textarea
26 |     id: contribution
27 |     validations:
28 |       required: true
29 |     attributes:
30 |       label: Your contribution / 您的贡献
31 |       description: |
32 |
33 |         Your PR link or any other link you can help with.
34 |         您的PR链接或者其他您能提供帮助的链接。
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/.github/PULL_REQUEST_TEMPLATE/pr_template.md:
--------------------------------------------------------------------------------
1 | # Raise valuable PR / 提出有价值的PR
2 |
3 | ## Caution/ 注意事项:
4 | Users should keep the following points in mind when submitting PRs:
5 |
6 | 1. The proposed PR should be about this project.
7 | 2. The proposed PR should be focused. If you have multiple ideas or optimizations, submit them as separate PRs.
8 |
9 | 用户在提交PR时候应该注意以下几点:
10 |
11 | 1. 提出的PR应该是关于本项目的。
12 | 2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
13 |
14 | ## 不应该提出的PR / PRs that should not be proposed
15 |
16 | If a developer submits a PR of any of the following kinds, it may be closed or rejected.
17 |
18 | 1. PRs that do not describe the proposed improvement.
19 | 2. PRs that combine multiple issues of different types.
20 | 3. PRs that largely duplicate existing PRs.
21 |
22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
23 |
24 | 1. 没有说明改进方案的。
25 | 2. 多个不同类型的问题合并在一个PR中的。
26 | 3. 提出的PR与已经存在的PR高度重复的。
27 |
28 |
29 | # Check your PR / 检查您的PR
30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试?
34 | - [ ] Does your PR address only one issue? / 您的PR是否仅针对一个问题
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__
2 |
3 | # finetune_demo: generated & downloaded files
4 | finetune_demo/output
5 | finetune_demo/data
6 | finetune_demo/formatted_data
7 | ToolAlpaca/
8 | AdvertiseGen/
9 | *.gz
10 | *.idea
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The ChatGLM3-6B License
2 |
3 | 1. 定义
4 |
5 | “许可方”是指分发其软件的 ChatGLM3-6B 模型团队。
6 |
7 | “软件”是指根据本许可提供的 ChatGLM3-6B 模型参数。
8 |
9 | 2. 许可授予
10 |
11 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
12 |
13 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
14 |
15 | 3.限制
16 |
17 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
18 |
19 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
20 |
21 | 4.免责声明
22 |
23 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
24 |
25 | 5. 责任限制
26 |
27 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
28 |
29 | 6.争议解决
30 |
31 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
32 |
33 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
34 |
35 | 1. Definitions
36 |
37 | “Licensor” means the ChatGLM3-6B Model Team that distributes its Software.
38 |
39 | “Software” means the ChatGLM3-6B model parameters made available under this license.
40 |
41 | 2. License Grant
42 |
43 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software.
44 |
45 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
46 |
47 | 3. Restriction
48 |
49 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
50 |
51 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
52 |
53 | 4. Disclaimer
54 |
55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
56 |
57 | 5. Limitation of Liability
58 |
59 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
60 |
61 | 6. Dispute Resolution
62 |
63 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
64 |
65 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
66 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/.streamlit/config.toml:
--------------------------------------------------------------------------------
1 | [theme]
2 | font = "monospace"
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/assets/demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/demo.png
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/assets/emojis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/emojis.png
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/assets/heart.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/heart.png
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/assets/tool.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/tool.png
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import os
4 | import streamlit as st
5 | import torch
6 |
7 | from collections.abc import Iterable
8 | from typing import Any, Protocol
9 | from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
10 | from transformers import AutoModel, AutoTokenizer, AutoConfig
11 | from transformers.generation.logits_process import LogitsProcessor
12 | from transformers.generation.utils import LogitsProcessorList
13 |
14 | from conversation import Conversation
15 |
16 | TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:'
17 |
18 | MODEL_PATH = os.environ.get('MODEL_PATH', 'THUDM/chatglm3-6b')
19 | PT_PATH = os.environ.get('PT_PATH', None)
20 | PRE_SEQ_LEN = int(os.environ.get("PRE_SEQ_LEN", 128))
21 | TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH)
22 |
23 |
24 | @st.cache_resource
25 | def get_client() -> Client:
26 |     client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH)
27 |     return client
28 |
29 |
30 | class Client(Protocol):
31 |     def generate_stream(self,
32 |                         system: str | None,
33 |                         tools: list[dict] | None,
34 |                         history: list[Conversation],
35 |                         **parameters: Any
36 |                         ) -> Iterable[TextGenerationStreamResponse]:
37 |         ...
38 |
39 |
40 | def stream_chat(
41 |         self, tokenizer, query: str,
42 |         history: list[dict] | None = None,
43 |         role: str = "user",
44 |         past_key_values=None,
45 |         max_new_tokens: int = 256,
46 |         do_sample=True, top_p=0.8,
47 |         temperature=0.8,
48 |         repetition_penalty=1.0,
49 |         length_penalty=1.0, num_beams=1,
50 |         logits_processor=None,
51 |         return_past_key_values=False,
52 |         **kwargs
53 | ):
54 |     class InvalidScoreLogitsProcessor(LogitsProcessor):
55 |         def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
56 |             if torch.isnan(scores).any() or torch.isinf(scores).any():
57 |                 scores.zero_()
58 |                 scores[..., 5] = 5e4
59 |             return scores
60 |
61 |     if history is None:
62 |         history = []
63 |
64 |     print("\n== Input ==\n", query)
65 |     print("\n==History==\n", history)
66 |
67 |     if logits_processor is None:
68 |         logits_processor = LogitsProcessorList()
69 |     logits_processor.append(InvalidScoreLogitsProcessor())
70 |     eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"),
71 |                     tokenizer.get_command("<|observation|>")]
72 |     gen_kwargs = {"max_new_tokens": max_new_tokens,
73 |                   "do_sample": do_sample,
74 |                   "top_p": top_p,
75 |                   "temperature": temperature,
76 |                   "logits_processor": logits_processor,
77 |                   "repetition_penalty": repetition_penalty,
78 |                   "length_penalty": length_penalty,
79 |                   "num_beams": num_beams,
80 |                   **kwargs
81 |                   }
82 |
83 |     if past_key_values is None:
84 |         inputs = tokenizer.build_chat_input(query, history=history, role=role)
85 |     else:
86 |         inputs = tokenizer.build_chat_input(query, role=role)
87 |     inputs = inputs.to(self.device)
88 |     if past_key_values is not None:
89 |         past_length = past_key_values[0][0].shape[0]
90 |         if self.transformer.pre_seq_len is not None:
91 |             past_length -= self.transformer.pre_seq_len
92 |         inputs.position_ids += past_length
93 |         attention_mask = inputs.attention_mask
94 |         attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1)
95 |         inputs['attention_mask'] = attention_mask
96 |     history.append({"role": role, "content": query})
97 |     input_sequence_length = inputs['input_ids'].shape[1]
98 |     if input_sequence_length + max_new_tokens >= self.config.seq_length:
99 |         yield "Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format(
100 |             input_sequence_length, max_new_tokens, self.config.seq_length
101 |         ), history
102 |         return
103 |
104 |     if input_sequence_length > self.config.seq_length:
105 |         yield "Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.".format(
106 |             input_sequence_length, self.config.seq_length
107 |         ), history
108 |         return
109 |
110 |     for outputs in self.stream_generate(**inputs, past_key_values=past_key_values,
111 |                                         eos_token_id=eos_token_id, return_past_key_values=return_past_key_values,
112 |                                         **gen_kwargs):
113 |         if return_past_key_values:
114 |             outputs, past_key_values = outputs
115 |         outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):]
116 |         response = tokenizer.decode(outputs)
117 |         if response and response[-1] != "�":
118 |             new_history = history
119 |             if return_past_key_values:
120 |                 yield response, new_history, past_key_values
121 |             else:
122 |                 yield response, new_history
123 |
124 |
125 | class HFClient(Client):
126 |     def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str | None = None):
127 |         self.model_path = model_path
128 |         self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
129 |
130 |         if pt_checkpoint is not None and os.path.exists(pt_checkpoint):
131 |             config = AutoConfig.from_pretrained(
132 |                 model_path,
133 |                 trust_remote_code=True,
134 |                 pre_seq_len=PRE_SEQ_LEN
135 |             )
136 |             self.model = AutoModel.from_pretrained(
137 |                 model_path,
138 |                 trust_remote_code=True,
139 |                 config=config,
140 |                 device_map="auto"
141 |             ).eval()
142 |             prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
143 |             new_prefix_state_dict = {}
144 |             for k, v in prefix_state_dict.items():
145 |                 if k.startswith("transformer.prefix_encoder."):
146 |                     new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
147 |             print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
148 |             self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
149 |         else:
150 |             self.model = (
151 |                 AutoModel.from_pretrained(
152 |                     model_path,
153 |                     trust_remote_code=True,
154 |                     device_map="auto"
155 |                 ).eval())
156 |             # Append .quantize(4) or .quantize(8) (from ChatGLM3's remote code) to use a quantized model
157 |
158 |     def generate_stream(
159 |             self,
160 |             system: str | None,
161 |             tools: list[dict] | None,
162 |             history: list[Conversation],
163 |             **parameters: Any
164 |     ) -> Iterable[TextGenerationStreamResponse]:
165 |         chat_history = [{
166 |             'role': 'system',
167 |             'content': system if not tools else TOOL_PROMPT,
168 |         }]
169 |
170 |         if tools:
171 |             chat_history[0]['tools'] = tools
172 |
173 |         for conversation in history[:-1]:
174 |             chat_history.append({
175 |                 'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
176 |                 'content': conversation.content,
177 |             })
178 |
179 |         query = history[-1].content
180 |         role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
181 |         text = ''
182 |         for new_text, _ in stream_chat(
183 |                 self.model,
184 |                 self.tokenizer,
185 |                 query,
186 |                 chat_history,
187 |                 role,
188 |                 **parameters,
189 |         ):
190 |             word = new_text.removeprefix(text)
191 |             word_stripped = word.strip()
192 |             text = new_text
193 |             yield TextGenerationStreamResponse(
194 |                 generated_text=text,
195 |                 token=Token(
196 |                     id=0,
197 |                     logprob=0,
198 |                     text=word,
199 |                     special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
200 |                 )
201 |             )
202 |
--------------------------------------------------------------------------------
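`HFClient.__init__` keeps only the P-Tuning v2 prefix-encoder weights from the saved checkpoint. It strips the `transformer.prefix_encoder.` prefix so that the keys match `model.transformer.prefix_encoder`. Below is the same key remapping, separated from torch so it can be checked on its own. `extract_prefix_state` is a hypothetical name, the checkpoint keys are illustrative, and plain strings stand in for tensors.

```python
# Hypothetical stand-alone version of the key remapping in HFClient.__init__.
PREFIX = "transformer.prefix_encoder."


def extract_prefix_state(state_dict: dict) -> dict:
    """Keep only prefix-encoder entries and drop the module path from their keys."""
    return {
        key.removeprefix(PREFIX): value
        for key, value in state_dict.items()
        if key.startswith(PREFIX)
    }


# Illustrative checkpoint keys; strings stand in for tensors.
checkpoint = {
    "transformer.prefix_encoder.embedding.weight": "prefix-tensor",
    "transformer.output_layer.weight": "frozen-tensor",
}
print(extract_prefix_state(checkpoint))  # {'embedding.weight': 'prefix-tensor'}
```

Suppose `load_state_dict` then reports a size mismatch. The likely cause is that `PRE_SEQ_LEN` differs from the value used in training, because the prefix embedding has one row per prefix position.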
/Ziwei-Chatglm3-6B/composite_demo/conversation.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | from enum import auto, Enum
3 | import json
4 |
5 | from PIL.Image import Image
6 | import streamlit as st
7 | from streamlit.delta_generator import DeltaGenerator
8 |
9 | TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n'
10 |
11 | class Role(Enum):
12 |     SYSTEM = auto()
13 |     USER = auto()
14 |     ASSISTANT = auto()
15 |     TOOL = auto()
16 |     INTERPRETER = auto()
17 |     OBSERVATION = auto()
18 |
19 |     def __str__(self):
20 |         match self:
21 |             case Role.SYSTEM:
22 |                 return "<|system|>"
23 |             case Role.USER:
24 |                 return "<|user|>"
25 |             case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER:
26 |                 return "<|assistant|>"
27 |             case Role.OBSERVATION:
28 |                 return "<|observation|>"
29 |
30 |     # Get the message block for the given role
31 |     def get_message(self):
32 |         # Compare by value here, because the enum object in the session state
33 |         # is not the same as the enum cases here, due to streamlit's rerunning
34 |         # behavior.
35 |         match self.value:
36 |             case Role.SYSTEM.value:
37 |                 return
38 |             case Role.USER.value:
39 |                 return st.chat_message(name="user", avatar="user")
40 |             case Role.ASSISTANT.value:
41 |                 return st.chat_message(name="assistant", avatar="assistant")
42 |             case Role.TOOL.value:
43 |                 return st.chat_message(name="tool", avatar="assistant")
44 |             case Role.INTERPRETER.value:
45 |                 return st.chat_message(name="interpreter", avatar="assistant")
46 |             case Role.OBSERVATION.value:
47 |                 return st.chat_message(name="observation", avatar="user")
48 |             case _:
49 |                 st.error(f'Unexpected role: {self}')
50 |
51 | @dataclass
52 | class Conversation:
53 |     role: Role
54 |     content: str
55 |     tool: str | None = None
56 |     image: Image | None = None
57 |
58 |     def __str__(self) -> str:
59 |         print(self.role, self.content, self.tool)
60 |         match self.role:
61 |             case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION:
62 |                 return f'{self.role}\n{self.content}'
63 |             case Role.TOOL:
64 |                 return f'{self.role}{self.tool}\n{self.content}'
65 |             case Role.INTERPRETER:
66 |                 return f'{self.role}interpreter\n{self.content}'
67 |
68 |     # Human readable format
69 |     def get_text(self) -> str:
70 |         text = postprocess_text(self.content)
71 |         match self.role.value:
72 |             case Role.TOOL.value:
73 |                 text = f'Calling tool `{self.tool}`:\n\n{text}'
74 |             case Role.INTERPRETER.value:
75 |                 text = f'{text}'
76 |             case Role.OBSERVATION.value:
77 |                 text = f'Observation:\n```\n{text}\n```'
78 |         return text
79 |
80 |     # Display as a markdown block
81 |     def show(self, placeholder: DeltaGenerator | None = None) -> None:
82 |         if placeholder:
83 |             message = placeholder
84 |         else:
85 |             message = self.role.get_message()
86 |         if self.image:
87 |             message.image(self.image)
88 |         else:
89 |             text = self.get_text()
90 |             message.markdown(text)
91 |
92 | def preprocess_text(
93 |         system: str | None,
94 |         tools: list[dict] | None,
95 |         history: list[Conversation],
96 | ) -> str:
97 |     if tools:
98 |         tools = json.dumps(tools, indent=4, ensure_ascii=False)
99 |
100 |     prompt = f"{Role.SYSTEM}\n"
101 |     prompt += system if not tools else TOOL_PROMPT
102 |     if tools:
103 |         tools = json.loads(tools)
104 |         prompt += json.dumps(tools, ensure_ascii=False)
105 |     for conversation in history:
106 |         prompt += f'{conversation}'
107 |     prompt += f'{Role.ASSISTANT}\n'
108 |     return prompt
109 |
110 | def postprocess_text(text: str) -> str:
111 |     # Raw strings avoid invalid-escape warnings; the matched text is unchanged.
112 |     text = text.replace(r"\(", "$").replace(r"\)", "$")
113 |     text = text.replace(r"\[", "$$")
114 |     text = text.replace(r"\]", "$$")
115 |     text = text.replace("<|assistant|>", "")
116 |     text = text.replace("<|observation|>", "")
117 |     text = text.replace("<|system|>", "")
118 |     text = text.replace("<|user|>", "")
119 |     return text.strip()
--------------------------------------------------------------------------------
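`preprocess_text` flattens a conversation into ChatGLM3's role-tagged prompt. The result has three parts:

- a `<|system|>` block,
- each turn, rendered by `Conversation.__str__` as `<|role|>` followed by a newline and the content,
- a trailing `<|assistant|>` line that cues the model to answer.

No newline is inserted between one turn's content and the next role tag. Below is a stdlib-only sketch of the same layout for plain (role, text) pairs, without tools, images, or Streamlit. `build_chat_prompt` is a hypothetical name. Also note that the Streamlit client passes the history to `tokenizer.build_chat_input` instead, so this sketch describes `preprocess_text` only.

```python
def build_chat_prompt(system: str, turns: list[tuple[str, str]]) -> str:
    """Reproduce preprocess_text's layout for the no-tools case (hypothetical helper).

    Each turn is (role, content) with role in {"user", "assistant", "observation"}.
    """
    prompt = f"<|system|>\n{system}"
    for role, content in turns:
        # Conversation.__str__ renders a turn as "<|role|>\n<content>".
        prompt += f"<|{role}|>\n{content}"
    # A trailing assistant tag asks the model to produce the next reply.
    return prompt + "<|assistant|>\n"


print(build_chat_prompt("You are ChatGLM3-Ziwei.", [("user", "I drew the Zhen hexagram.")]))
```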
/Ziwei-Chatglm3-6B/composite_demo/demo_chat.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | from streamlit.delta_generator import DeltaGenerator
3 |
4 | from client import get_client
5 | from conversation import postprocess_text, preprocess_text, Conversation, Role
6 |
7 | client = get_client()
8 |
9 |
10 | # Append a conversation to the history and show it in a new markdown block
11 | def append_conversation(
12 |         conversation: Conversation,
13 |         history: list[Conversation],
14 |         placeholder: DeltaGenerator | None = None,
15 | ) -> None:
16 |     history.append(conversation)
17 |     conversation.show(placeholder)
18 |
19 |
20 | def main(
21 |         prompt_text: str,
22 |         system_prompt: str,
23 |         top_p: float = 0.8,
24 |         temperature: float = 0.95,
25 |         repetition_penalty: float = 1.0,
26 |         max_new_tokens: int = 1024,
27 |         retry: bool = False
28 | ):
29 |     placeholder = st.empty()
30 |     with placeholder.container():
31 |         if 'chat_history' not in st.session_state:
32 |             st.session_state.chat_history = []
33 |
34 |         if prompt_text == "" and not retry:
35 |             print("\n== Clean ==\n")
36 |             st.session_state.chat_history = []
37 |             return
38 |
39 |         history: list[Conversation] = st.session_state.chat_history
40 |         for conversation in history:
41 |             conversation.show()
42 |
43 |     if retry:
44 |         print("\n== Retry ==\n")
45 |         last_user_conversation_idx = None
46 |         for idx, conversation in enumerate(history):
47 |             if conversation.role == Role.USER:
48 |                 last_user_conversation_idx = idx
49 |         if last_user_conversation_idx is not None:
50 |             prompt_text = history[last_user_conversation_idx].content
51 |             del history[last_user_conversation_idx:]
52 |
53 |
54 |     if prompt_text:
55 |         prompt_text = prompt_text.strip()
56 |         append_conversation(Conversation(Role.USER, prompt_text), history)
57 |         placeholder = st.empty()
58 |         message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant")
59 |         markdown_placeholder = message_placeholder.empty()
60 |
61 |         output_text = ''
62 |         for response in client.generate_stream(
63 |                 system_prompt,
64 |                 tools=None,
65 |                 history=history,
66 |                 do_sample=True,
67 |                 max_new_tokens=max_new_tokens,
68 |                 temperature=temperature,
69 |                 top_p=top_p,
70 |                 stop_sequences=[str(Role.USER)],
71 |                 repetition_penalty=repetition_penalty,
72 |         ):
73 |             token = response.token
74 |             if response.token.special:
75 |                 print("\n==Output:==\n", output_text)
76 |                 match token.text.strip():
77 |                     case '<|user|>':
78 |                         break
79 |                     case _:
80 |                         st.error(f'Unexpected special token: {token.text.strip()}')
81 |                         break
82 |             output_text += response.token.text
83 |             markdown_placeholder.markdown(postprocess_text(output_text + '▌'))
84 |
85 |         append_conversation(Conversation(
86 |             Role.ASSISTANT,
87 |             postprocess_text(output_text),
88 |         ), history, markdown_placeholder)
--------------------------------------------------------------------------------
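On every step, `stream_chat` yields the full decoded response generated so far. `HFClient.generate_stream` turns those cumulative snapshots into per-step deltas with `removeprefix` and flags a delta as special when it looks like `<|...|>`. `demo_chat.main` stops as soon as it sees `<|user|>`. The loop below reproduces that flow with plain strings, so the logic can be tested without a model. `stream_deltas` and `collect_reply` are hypothetical names.

```python
from collections.abc import Iterable, Iterator


def stream_deltas(snapshots: Iterable[str]) -> Iterator[tuple[str, bool]]:
    """Yield (new_text, is_special) for each cumulative snapshot (hypothetical helper)."""
    text = ""
    for snapshot in snapshots:
        word = snapshot.removeprefix(text)
        stripped = word.strip()
        text = snapshot
        yield word, stripped.startswith("<|") and stripped.endswith("|>")


def collect_reply(snapshots: Iterable[str]) -> str:
    """Accumulate deltas like demo_chat.main, stopping at the first special token."""
    output = ""
    for word, special in stream_deltas(snapshots):
        if special:
            # demo_chat breaks on <|user|> and shows st.error for any other
            # special token; this sketch simply stops in both cases.
            break
        output += word
    return output


print(collect_reply(["The Zhen", "The Zhen hexagram", "The Zhen hexagram<|user|>"]))
```

This prints `The Zhen hexagram`. The trailing `<|user|>` is treated as a stop signal rather than shown to the user.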
/Ziwei-Chatglm3-6B/composite_demo/main.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | st.set_page_config(
3 |     page_title="ChatGLM3-Ziwei Demo",
4 |     page_icon=":robot:",
5 |     layout='centered',
6 |     initial_sidebar_state='expanded',
7 | )
8 |
9 |
10 | import demo_chat
11 | from enum import Enum
12 |
13 | DEFAULT_SYSTEM_PROMPT = '''
14 | You are ChatGLM3-Ziwei, a large language model trained by HITWH. Follow the user's instructions carefully. Respond using markdown.
15 | '''.strip()
16 |
17 | # Set the title of the demo
18 | st.title("ChatGLM3-Ziwei Demo")
19 |
20 | # Add your custom text here, with smaller font size
21 |
22 |
23 | class Mode(str, Enum):
24 |     CHAT = '💬 Chat'
25 |
26 |
27 | with st.sidebar:
28 |     top_p = st.slider(
29 |         'top_p', 0.0, 1.0, 0.8, step=0.01
30 |     )
31 |     temperature = st.slider(
32 |         'temperature', 0.01, 1.5, 0.95, step=0.01
33 |     )
34 |     repetition_penalty = st.slider(
35 |         'repetition_penalty', 0.01, 2.0, 1.1, step=0.01
36 |     )
37 |     max_new_token = st.slider(
38 |         'Output length', 5, 32000, 256, step=1
39 |     )
40 |
41 |     cols = st.columns(2)
42 |     export_btn = cols[0]
43 |     clear_history = cols[1].button("Clear History", use_container_width=True)
44 |     retry = export_btn.button("Retry", use_container_width=True)
45 |
46 |     system_prompt = st.text_area(
47 |         label="System Prompt (Only for chat mode)",
48 |         height=300,
49 |         value=DEFAULT_SYSTEM_PROMPT,
50 |     )
51 |
52 | prompt_text = st.chat_input(
53 | 'Chat with ChatGLM3!',
54 | key='chat_input',
55 | )
56 |
57 | tab = st.radio(
58 | 'Mode',
59 | [mode.value for mode in Mode],
60 | horizontal=True,
61 | label_visibility='hidden',
62 | )
63 |
64 | if clear_history or retry:
65 | prompt_text = ""
66 |
67 | match tab:
68 | case Mode.CHAT:
69 | demo_chat.main(
70 | retry=retry,
71 | top_p=top_p,
72 | temperature=temperature,
73 | prompt_text=prompt_text,
74 | system_prompt=system_prompt,
75 | repetition_penalty=repetition_penalty,
76 | max_new_tokens=max_new_token
77 | )
78 |
79 | case _:
80 | st.error(f'Unexpected tab: {tab}')
81 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/requirements.txt:
--------------------------------------------------------------------------------
1 | huggingface_hub>=0.19.4
2 | pillow>=10.1.0
3 | pyyaml>=6.0.1
4 | requests>=2.31.0
5 | ipykernel>=6.26.0
6 | ipython>=8.18.1
7 | jupyter_client>=8.6.0
8 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/tool_registry.py:
--------------------------------------------------------------------------------
1 | """
2 | Tool registration: functions decorated with `register_tool` become tools that the model can call.
3 | Each tool exposes its name, docstring and annotated parameters, so the model can discover and invoke a variety of utilities
4 | through a uniform interface.
5 | """
6 |
7 | import copy
8 | import inspect
9 | from pprint import pformat
10 | import traceback
11 | from types import GenericAlias
12 | from typing import get_origin, Annotated
13 | import subprocess
14 |
15 | _TOOL_HOOKS = {}
16 | _TOOL_DESCRIPTIONS = {}
17 |
18 |
19 | def register_tool(func: callable):
20 | tool_name = func.__name__
21 | tool_description = inspect.getdoc(func).strip()
22 | python_params = inspect.signature(func).parameters
23 | tool_params = []
24 | for name, param in python_params.items():
25 | annotation = param.annotation
26 | if annotation is inspect.Parameter.empty:
27 | raise TypeError(f"Parameter `{name}` missing type annotation")
28 | if get_origin(annotation) != Annotated:
29 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
30 |
31 | typ, (description, required) = annotation.__origin__, annotation.__metadata__
32 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
33 | if not isinstance(description, str):
34 | raise TypeError(f"Description for `{name}` must be a string")
35 | if not isinstance(required, bool):
36 | raise TypeError(f"Required for `{name}` must be a bool")
37 |
38 | tool_params.append({
39 | "name": name,
40 | "description": description,
41 | "type": typ,
42 | "required": required
43 | })
44 | tool_def = {
45 | "name": tool_name,
46 | "description": tool_description,
47 | "params": tool_params
48 | }
49 | print("[registered tool] " + pformat(tool_def))
50 | _TOOL_HOOKS[tool_name] = func
51 | _TOOL_DESCRIPTIONS[tool_name] = tool_def
52 |
53 | return func
54 |
55 |
56 | def dispatch_tool(tool_name: str, tool_params: dict) -> str:
57 | if tool_name not in _TOOL_HOOKS:
58 | return f"Tool `{tool_name}` not found. Please use a provided tool."
59 | tool_call = _TOOL_HOOKS[tool_name]
60 | try:
61 | ret = tool_call(**tool_params)
62 | except Exception:
63 | ret = traceback.format_exc()
64 | return str(ret)
65 |
66 |
67 | def get_tools() -> dict:
68 | return copy.deepcopy(_TOOL_DESCRIPTIONS)
69 |
70 |
71 | # Tool Definitions
72 |
73 | @register_tool
74 | def random_number_generator(
75 | seed: Annotated[int, 'The random seed used by the generator', True],
76 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True],
77 | ) -> int:
78 | """
79 | Generates a random number x, s.t. range[0] <= x <= range[1]
80 | """
81 | if not isinstance(seed, int):
82 | raise TypeError("Seed must be an integer")
83 | if not isinstance(range, tuple):
84 | raise TypeError("Range must be a tuple")
85 | if not isinstance(range[0], int) or not isinstance(range[1], int):
86 | raise TypeError("Range must be a tuple of integers")
87 |
88 | import random
89 | return random.Random(seed).randint(*range)
90 |
91 |
92 | @register_tool
93 | def get_weather(
94 | city_name: Annotated[str, 'The name of the city to be queried', True],
95 | ) -> str:
96 | """
97 | Get the current weather for `city_name`
98 | """
99 |
100 | if not isinstance(city_name, str):
101 | raise TypeError("City name must be a string")
102 |
103 | key_selection = {
104 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
105 | }
106 | import requests
107 | try:
108 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1", timeout=10)
109 | resp.raise_for_status()
110 | resp = resp.json()
111 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
112 | except Exception:
113 | import traceback
114 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
115 |
116 | return str(ret)
117 |
118 |
119 | @register_tool
120 | def get_shell(
121 | query: Annotated[str, 'The command to run in a Linux shell', True],
122 | ) -> str:
123 | """
124 | Run a command in a Linux shell and return its stdout (or stderr if it fails)
125 | """
126 | if not isinstance(query, str):
127 | raise TypeError("Command must be a string")
128 | try:
129 | result = subprocess.run(query, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE,
130 | text=True)
131 | return result.stdout
132 | except subprocess.CalledProcessError as e:
133 | return e.stderr
134 |
135 |
136 | if __name__ == "__main__":
137 | # print(dispatch_tool("get_shell", {"query": "pwd"}))
138 | print(get_tools())
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/=1.3.0:
--------------------------------------------------------------------------------
1 | Looking in indexes: http://mirrors.aliyun.com/pypi/simple
2 | Collecting openai
3 | Downloading http://mirrors.aliyun.com/pypi/packages/8a/6c/f345662c586464cbd6185239ddea4d281b623db302115ff7a2bb27db1eea/openai-1.3.4-py3-none-any.whl (220 kB)
4 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 220.5/220.5 kB 86.3 kB/s eta 0:00:00
5 | Requirement already satisfied: anyio<4,>=3.5.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (3.7.1)
6 | Collecting distro<2,>=1.7.0 (from openai)
7 | Downloading http://mirrors.aliyun.com/pypi/packages/f4/2c/c90a3adaf0ddb70afe193f5ebfb539612af57cffe677c3126be533df3098/distro-1.8.0-py3-none-any.whl (20 kB)
8 | Requirement already satisfied: httpx<1,>=0.23.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (0.25.1)
9 | Requirement already satisfied: pydantic<3,>=1.9.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (2.5.1)
10 | Requirement already satisfied: tqdm>4 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (4.66.1)
11 | Requirement already satisfied: typing-extensions<5,>=4.5 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (4.8.0)
12 | Requirement already satisfied: idna>=2.8 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (3.4)
13 | Requirement already satisfied: sniffio>=1.1 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (1.3.0)
14 | Requirement already satisfied: exceptiongroup in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (1.1.3)
15 | Requirement already satisfied: certifi in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17)
16 | Requirement already satisfied: httpcore in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2)
17 | Requirement already satisfied: annotated-types>=0.4.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0)
18 | Requirement already satisfied: pydantic-core==2.14.3 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.3)
19 | Requirement already satisfied: h11<0.15,>=0.13 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpcore->httpx<1,>=0.23.0->openai) (0.14.0)
20 | Installing collected packages: distro, openai
21 | Successfully installed distro-1.8.0 openai-1.3.4
22 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/README.md:
--------------------------------------------------------------------------------
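1 | # ChatGLM3-6B Fine-tuning Examples
2 |
3 | This directory provides fine-tuning examples for the ChatGLM3-6B model, covering both full-parameter fine-tuning and P-Tuning v2. Two data formats are covered: multi-turn conversations and input-output pairs.
4 |
5 | If you have downloaded the model locally, replace every `THUDM/chatglm3-6b` in this document and in the code with the local path so that the model is loaded from disk.
6 |
7 | The examples require `python>=3.9`. Besides the base `torch` dependency, the example code also needs: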
8 |
9 | ```bash
10 | pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed
11 | ```
12 |
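13 | ## Multi-turn Conversation Format
14 |
15 | The multi-turn example follows the ChatGLM3 conversation format. It assigns a different `loss_mask` to each role, so that the `loss` for every reply in a multi-turn conversation is computed in a single pass.
16 |
17 | ### Data Format and Preprocessing
18 |
19 | The data files use the following formats.
20 |
21 | If you only want to fine-tune the model's conversational ability, and not its tool-calling ability, organize the data as follows.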
22 |
23 | ```json
24 | [
25 | {
26 | "conversations": [
27 | {
28 | "role": "system",
29 | "content": ""
30 | },
31 | {
32 | "role": "user",
33 | "content": ""
34 | },
35 | {
36 | "role": "assistant",
37 | "content": ""
38 | },
39 | // ... Multi Turn
40 | {
41 | "role": "user",
42 | "content": ""
43 | },
44 | {
45 | "role": "assistant",
46 | "content": ""
47 | }
48 | ]
49 | }
50 | // ...
51 | ]
52 | ```
53 |
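54 | **Note: when fine-tuning runs for many steps, this format degrades the model's tool-calling ability.**
55 |
56 | If you want to fine-tune both the conversational and the tool-calling abilities, organize the data as follows.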
57 |
58 | ```json
59 | [
60 | {
61 | "tools": [
62 | // available tools, format is not restricted
63 | ],
64 | "conversations": [
65 | {
66 | "role": "system",
67 | "content": ""
68 | },
69 | {
70 | "role": "user",
71 | "content": ""
72 | },
73 | {
74 | "role": "assistant",
75 | "content": ""
76 | },
77 | {
78 | "role": "tool",
79 | "name": "",
80 | "parameters": {
81 | "": ""
82 | },
83 | "observation": ""
84 | // does not have to be a string
85 | },
86 | {
87 | "role": "assistant",
88 | "content": ""
89 | },
90 | // ... Multi Turn
91 | {
92 | "role": "user",
93 | "content": ""
94 | },
95 | {
96 | "role": "assistant",
97 | "content": ""
98 | }
99 | ]
100 | }
101 | // ...
102 | ]
103 | ```
104 |
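105 | - The system prompt describing the tools does not need to be inserted by hand. During preprocessing, the `tools` field is formatted with `json.dumps(..., indent=4, ensure_ascii=False)`, prefixed with `TOOL_DEFINITION_PREFIX`, and inserted as the first system prompt.
106 |
107 | - Each role may carry a `bool` `loss` field that controls whether the content it predicts contributes to the `loss`. If the field is absent, the `loss` is computed for every role except `system` and `user`. In the sample implementation (`format_conversation` in `preprocess_utils.py`), `system` and `user` never contribute to the `loss`, even when the field is set.
108 |
109 | - `tool` is not a native ChatGLM3 role. During preprocessing, each `tool` entry is automatically converted into two messages. The first is an `assistant` message carrying the tool-call `metadata`, which is included in the `loss` by default. The second is an `observation` message holding the tool's return value, which is excluded from the `loss`.
110 |
111 | - Fine-tuning for the `Code interpreter` task is not implemented yet.
112 |
113 | - The `system` role is optional. If present, it must appear before the `user` role. A complete conversation, whether single-turn or multi-turn, may contain only one `system` message.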
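During the `tool` to `assistant` conversion, `format_function_call` in `preprocess_utils.py` renders the tool parameters as a Python call to `tool_call`. That call is wrapped in a Python code fence, and the tool's `name` goes into the message metadata.

The repository renders the call with the third-party `astunparse`. The sketch below swaps in the standard-library `ast.unparse` (Python 3.9+), so it approximates the repository's code rather than reproducing it. The parameter values are illustrative.

```python
import ast
from typing import Any, Dict


def format_function_call(function_name: str, parameters: Dict[str, Any]) -> str:
    """Render a keyword-only call such as tool_call(city_name='Beijing')."""
    call = ast.Call(
        func=ast.Name(id=function_name, ctx=ast.Load()),
        args=[],
        keywords=[
            ast.keyword(arg=name, value=ast.Constant(value=value, kind=None))
            for name, value in parameters.items()
        ],
    )
    return ast.unparse(call)  # stdlib replacement for astunparse.unparse(...).strip()


rendered = format_function_call("tool_call", {"city_name": "Beijing", "days": 3})
```

Because the output is ordinary Python source, it can be parsed back with `ast.parse(..., mode="eval")` to check that the keywords survived intact.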
114 |
115 | As an example, we fine-tune on the ToolAlpaca dataset. First clone the [ToolAlpaca dataset](https://github.com/tangqiaoyu/ToolAlpaca), then run:
116 |
117 | ```bash
118 | ./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json"
119 | ```
120 |
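121 | This converts the dataset into the format above. Here we deliberately convert the tools into natural-language descriptions of type `list[str]`. This lets us observe how well the model understands tool definitions before and after fine-tuning.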
122 |
123 | ### Fine-tuning the Model
124 |
125 | The following scripts show a reference way to fine-tune the model.
126 |
127 | ```bash
128 | ./scripts/finetune_ds_multiturn.sh # full-parameter fine-tuning
129 | ./scripts/finetune_pt_multiturn.sh # P-Tuning v2 fine-tuning
130 | ```
131 |
132 | ### Deployment
133 |
134 | We have updated the ChatGLM3 composite demo so that it can serve fine-tuned model checkpoints.
135 |
136 | For full-parameter fine-tuning, deploy with:
137 |
138 | ```bash
139 | cd ../composite_demo
140 | MODEL_PATH="path to finetuned model checkpoint" TOKENIZER_PATH="THUDM/chatglm3-6b" streamlit run main.py
141 | ```
142 |
143 | For P-Tuning v2, deploy with:
144 |
145 | ```bash
146 | cd ../composite_demo
147 | MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit run main.py
148 | ```
149 |
150 | ## Input-Output Format
151 |
152 | For the input-output format, the sample data uses the following structure:
153 |
154 | ```json
155 | [
156 | {
157 | "prompt": "",
158 | "response": ""
159 | }
160 | // ...
161 | ]
162 | ```
163 |
164 | No role tokens are added during preprocessing.
165 |
166 | As an example, we fine-tune on the AdvertiseGen dataset. Download the preprocessed AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), and place the extracted `AdvertiseGen` directory in this directory. Then run:
167 |
168 | ```bash
169 | ./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json"
170 | ```
171 |
172 | This converts the dataset into the format above.
173 |
174 | ### Fine-tuning the Model
175 |
176 | The following scripts show a reference way to fine-tune the model.
177 |
178 | ```bash
179 | ./scripts/finetune_ds.sh # full-parameter fine-tuning
180 | ./scripts/finetune_pt.sh # P-Tuning v2 fine-tuning
181 | ```
182 |
183 | ### Inference Check
184 |
185 | For models fine-tuned on the input-output format, `inference.py` can be used for a basic inference check.
186 |
187 | ```bash
188 | python inference.py \
189 | --pt-checkpoint "path to p-tuning checkpoint" \
190 | --model THUDM/chatglm3-6b
191 | ```
192 |
193 | ```bash
194 | python inference.py \
195 | --tokenizer THUDM/chatglm3-6b \
196 | --model "path to finetuned model checkpoint"
197 | ```
198 |
199 | ### Tips
200 |
201 | 1. Before training starts, the fine-tuning code prints the preprocessing result of the first training sample, for example:
202 |
203 | ```log
204 | Sanity Check >>>>>>>>>>>>>
205 | '[gMASK]': 64790 -> -100
206 | 'sop': 64792 -> -100
207 | '<|system|>': 64794 -> -100
208 | '': 30910 -> -100
209 | '\n': 13 -> -100
210 | 'Answer': 20115 -> -100
211 | 'the': 267 -> -100
212 | 'following': 1762 -> -100
213 | ...
214 | 'know': 683 -> -100
215 | 'the': 267 -> -100
216 | 'response': 3010 -> -100
217 | 'details': 3296 -> -100
218 | '.': 30930 -> -100
219 | '<|assistant|>': 64796 -> -100
220 | '': 30910 -> 30910
221 | '\n': 13 -> 13
222 | 'I': 307 -> 307
223 | 'need': 720 -> 720
224 | 'to': 289 -> 289
225 | 'use': 792 -> 792
226 | ...
227 | <<<<<<<<<<<<< Sanity Check
228 | ```
229 |
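230 | Each line shows, in order, a detokenized string, its token_id and its target_id. A target_id of -100 means the token is excluded from the loss. Check this part of the log to confirm that the `loss_mask` is as expected; if it is not, you may need to adjust the code or the data.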
231 |
232 | 2. Reference GPU memory usage:
233 |
234 | - P-Tuning v2 with `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=2048` needs about 21 GB of GPU memory.
235 | - For full-parameter fine-tuning, the configuration in `./scripts/finetune_ds_multiturn.sh` (`MAX_SEQ_LEN=2048`, `DEV_BATCH_SIZE=16`, `GRAD_ACCUMULARION_STEPS=1`) uses almost exactly 4 × 80 GB of GPU memory.
236 |
237 | 3. If you run out of GPU memory, consider:
238 | - lowering `DEV_BATCH_SIZE` and raising `GRAD_ACCUMULARION_STEPS`;
239 | - adding `--quantization_bit 8` or `--quantization_bit 4`;
240 | - as a reference, with `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=1024`, `--quantization_bit 8` needs about 12 GB of GPU memory and `--quantization_bit 4` about 7.6 GB.
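The first suggestion works because, in data-parallel training, the number of samples per optimizer step is a product of three factors: the per-device batch size, the gradient-accumulation steps and the number of GPUs.

Below is a small worked example, assuming plain data parallelism. The GPU count for the P-Tuning v2 figure is not stated, so that case is computed per GPU.

```python
def effective_batch_size(dev_batch_size: int, grad_accumulation_steps: int, num_gpus: int = 1) -> int:
    """Samples that contribute to one optimizer step under data parallelism."""
    return dev_batch_size * grad_accumulation_steps * num_gpus


pt_per_gpu = effective_batch_size(1, 16)                       # P-Tuning v2 settings above
full_ft = effective_batch_size(16, 1, num_gpus=4)              # finetune_ds_multiturn.sh on 4 GPUs
smaller_micro_batch = effective_batch_size(4, 4, num_gpus=4)   # hypothetical: 1/4 the micro-batch, same total
```

Lowering `DEV_BATCH_SIZE` while raising `GRAD_ACCUMULARION_STEPS` by the same factor keeps the effective batch unchanged and lowers peak activation memory. The cost is more forward and backward passes per optimizer step.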
241 |
242 | ## References
243 |
244 | ```
245 | @inproceedings{liu2022p,
246 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks},
247 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie},
248 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)},
249 | pages={61--68},
250 | year={2022}
251 | }
252 |
253 | @misc{tang2023toolalpaca,
254 | title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases},
255 | author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun},
256 | year={2023},
257 | eprint={2306.05301},
258 | archivePrefix={arXiv},
259 | primaryClass={cs.CL}
260 | }
261 | ```
262 |
--------------------------------------------------------------------------------
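The Sanity Check log in the README above can be reproduced from the `loss_mask` values returned by `format_conversation`. The comment in `preprocess_utils.py` says a mask entry refers to *the prediction* made at that token. A reading consistent with the log is that the mask is shifted right by one position before being turned into labels: the `<|assistant|>` token itself maps to -100, while the tokens after it keep their ids.

The sketch below assumes exactly that shift. The dataset class that actually builds the labels is not part of this excerpt.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def build_labels(tokens: List[int], loss_masks: List[int]) -> List[int]:
    """Convert per-token loss masks into labels, assuming a one-position shift.

    loss_masks[i] is read as "the prediction made at token i (i.e. of token
    i + 1) is trained", so label i is kept only when loss_masks[i - 1] is set.
    """
    shifted = [0] + loss_masks[:-1]
    return [tok if keep else IGNORE_INDEX for tok, keep in zip(tokens, shifted)]


# Token ids taken from the Sanity Check log; the "..." part of the system turn is omitted:
# [gMASK], sop, <|system|>, '', '\n', <|assistant|>, '', '\n', 'I'
tokens = [64790, 64792, 64794, 30910, 13, 64796, 30910, 13, 307]
loss_masks = [0, 0, 0, 0, 0, 1, 1, 1, 1]  # system turn off, assistant turn on
labels = build_labels(tokens, loss_masks)
```

Under this reading, the final `eos_token_id` (appended with mask 0) still receives a real label, because its label comes from the mask of the last reply token. That is how the model learns where a reply ends.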
/Ziwei-Chatglm3-6B/finetune_demo/arguments.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass, field
2 | from typing import Optional
3 |
4 |
5 | @dataclass
6 | class ModelArguments:
7 | """
8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
9 | """
10 |
11 | model_name_or_path: str = field(
12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"}
13 | )
14 | ptuning_checkpoint: Optional[str] = field(
15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"}
16 | )
17 | config_name: Optional[str] = field(
18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
19 | )
20 | tokenizer_name: Optional[str] = field(
21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
22 | )
23 | cache_dir: Optional[str] = field(
24 | default=None,
25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"},
26 | )
27 | use_fast_tokenizer: bool = field(
28 | default=True,
29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
30 | )
31 | model_revision: str = field(
32 | default="main",
33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."},
34 | )
35 | use_auth_token: bool = field(
36 | default=False,
37 | metadata={
38 | "help": (
39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script "
40 | "with private models)."
41 | )
42 | },
43 | )
44 | resize_position_embeddings: Optional[bool] = field(
45 | default=None,
46 | metadata={
47 | "help": (
48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds "
49 | "the model's position embeddings."
50 | )
51 | },
52 | )
53 | quantization_bit: Optional[int] = field(
54 | default=None
55 | )
56 | pre_seq_len: Optional[int] = field(
57 | default=None
58 | )
59 | prefix_projection: bool = field(
60 | default=False
61 | )
62 |
63 |
64 | @dataclass
65 | class DataTrainingArguments:
66 | """
67 | Arguments pertaining to what data we are going to input our model for training and eval.
68 | """
69 | train_file: Optional[str] = field(
70 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."}
71 | )
72 |
83 | max_source_length: Optional[int] = field(
84 | default=1024,
85 | metadata={
86 | "help": (
87 | "The maximum total input sequence length after tokenization. Sequences longer "
88 | "than this will be truncated, sequences shorter will be padded."
89 | )
90 | },
91 | )
92 | max_target_length: Optional[int] = field(
93 | default=128,
94 | metadata={
95 | "help": (
96 | "The maximum total sequence length for target text after tokenization. Sequences longer "
97 | "than this will be truncated, sequences shorter will be padded."
98 | )
99 | },
100 | )
101 |
102 | train_format: str = field(
103 | default=None, metadata={"help": "The format of the training data file (multi-turn or input-output)"},
104 | )
105 |
106 | overwrite_cache: bool = field(
107 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
108 | )
109 |
110 | preprocessing_num_workers: Optional[int] = field(
111 | default=None,
112 | metadata={"help": "The number of processes to use for the preprocessing."},
113 | )
114 |
115 | max_seq_length: Optional[int] = field(
116 | default=1024,
117 | metadata={
118 | "help": (
119 | "The maximum total input sequence length after tokenization. Sequences longer "
120 | "than this will be truncated, sequences shorter will be padded."
121 | )
122 | },
123 | )
124 |
125 | pad_to_max_length: bool = field(
126 | default=False,
127 | metadata={
128 | "help": (
129 | "Whether to pad all samples to model maximum sentence length. "
130 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More "
131 | "efficient on GPU but very bad for TPU."
132 | )
133 | },
134 | )
135 |
136 | max_train_samples: Optional[int] = field(
137 | default=None,
138 | metadata={
139 | "help": (
140 | "For debugging purposes or quicker training, truncate the number of training examples to this "
141 | "value if set."
142 | )
143 | },
144 | )
145 |
146 | def __post_init__(self):
147 | extension = self.train_file.split(".")[-1]
148 | assert extension in {"jsonl", "json"}, "`train_file` should be a jsonl or a json file."
149 |
150 | assert self.train_format in {"multi-turn", "input-output"}
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/configs/deepspeed.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": "auto",
3 | "zero_allow_untested_optimizer": true,
4 | "fp16": {
5 | "enabled": "auto",
6 | "loss_scale": 0,
7 | "initial_scale_power": 16,
8 | "loss_scale_window": 1000,
9 | "hysteresis": 2,
10 | "min_loss_scale": 1
11 | },
12 | "zero_optimization": {
13 | "stage": 2,
14 | "allgather_partitions": true,
15 | "allgather_bucket_size": 5e8,
16 | "overlap_comm": false,
17 | "reduce_scatter": true,
18 | "reduce_bucket_size": 5e8,
19 | "contiguous_gradients" : true
20 | }
21 | }
--------------------------------------------------------------------------------
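In this config, `"loss_scale": 0` selects dynamic fp16 loss scaling, starting at 2^`initial_scale_power` = 65536. The sketch below is a simplified model of how such a scaler reacts to gradient overflows, using the values above:

- `hysteresis` overflows are tolerated before the scale is halved.
- The scale never drops below `min_loss_scale`.
- The scale doubles again after `loss_scale_window` overflow-free steps.

It is an illustration under those assumptions, not DeepSpeed's implementation, which tracks more state.

```python
class DynamicLossScaler:
    """Simplified dynamic fp16 loss scaler mirroring the config values above."""

    def __init__(self, initial_scale_power: int = 16, window: int = 1000,
                 hysteresis: int = 2, min_scale: float = 1.0) -> None:
        self.scale = 2.0 ** initial_scale_power
        self.window = window              # loss_scale_window
        self.hysteresis = hysteresis      # overflows tolerated before shrinking
        self.min_scale = min_scale        # min_loss_scale
        self._overflows_left = hysteresis
        self._clean_steps = 0

    def update(self, overflow: bool) -> None:
        """Adjust the scale after one step, given whether gradients overflowed."""
        if overflow:
            self._clean_steps = 0
            if self._overflows_left <= 1:
                self.scale = max(self.scale / 2, self.min_scale)  # halve, but not below the floor
            else:
                self._overflows_left -= 1  # tolerate this overflow without shrinking
        else:
            self._clean_steps += 1
            if self._clean_steps % self.window == 0:
                self.scale *= 2  # a full window without overflow: grow again
                self._overflows_left = self.hysteresis
```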
/Ziwei-Chatglm3-6B/finetune_demo/finetune.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2021 The HuggingFace Team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 | """
17 | Fine-tuning the library models for sequence to sequence.
18 | """
19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
20 | # Adapted from
21 |
22 |
23 | import logging
24 | import os
25 | import sys
26 | import torch
27 | import json
28 | import transformers
29 | from transformers import (
30 | AutoConfig,
31 | AutoModel,
32 | AutoTokenizer,
33 | DataCollatorForSeq2Seq,
34 | HfArgumentParser,
35 | Seq2SeqTrainingArguments,
36 | set_seed,
37 | )
38 | from trainer import PrefixTrainer
39 |
40 | from arguments import ModelArguments, DataTrainingArguments
41 |
42 | from preprocess_utils import sanity_check, MultiTurnDataset, InputOutputDataset
43 |
44 | logger = logging.getLogger(__name__)
45 |
46 | def main():
47 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
48 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
49 | # If we pass only one argument to the script and it's the path to a json file,
50 | # let's parse it to get our arguments.
51 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
52 | else:
53 | model_args, data_args, training_args = parser.parse_args_into_dataclasses()
54 |
55 | # Setup logging
56 | logging.basicConfig(
57 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
58 | datefmt="%m/%d/%Y %H:%M:%S",
59 | handlers=[logging.StreamHandler(sys.stdout)],
60 | )
61 |
62 | if training_args.should_log:
63 | # The default of training_args.log_level is passive, so we set log level at info here to have that default.
64 | transformers.utils.logging.set_verbosity_info()
65 |
66 | log_level = training_args.get_process_log_level()
67 | logger.setLevel(log_level)
68 | # datasets.utils.logging.set_verbosity(log_level)
69 | transformers.utils.logging.set_verbosity(log_level)
70 | transformers.utils.logging.enable_default_handler()
71 | transformers.utils.logging.enable_explicit_format()
72 |
73 | # Log on each process the small summary:
74 | logger.warning(
75 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
76 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
77 | )
78 | logger.info(f"Training/evaluation parameters {training_args}")
79 |
80 | # Set seed before initializing model.
81 | set_seed(training_args.seed)
82 |
83 | # Load pretrained model and tokenizer
84 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
85 | config.pre_seq_len = model_args.pre_seq_len
86 | config.prefix_projection = model_args.prefix_projection
87 |
88 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
89 |
90 | if model_args.ptuning_checkpoint is not None:
91 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
92 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
93 | new_prefix_state_dict = {}
94 | for k, v in prefix_state_dict.items():
95 | if k.startswith("transformer.prefix_encoder."):
96 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
97 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
98 | else:
99 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True)
100 |
101 | if model_args.quantization_bit is not None:
102 | print(f"Quantized to {model_args.quantization_bit} bit")
103 | model = model.quantize(model_args.quantization_bit)
104 | if model_args.pre_seq_len is not None:
105 | # P-tuning v2
106 | model = model.half()
107 | model.transformer.prefix_encoder.float()
108 | else:
109 | # Finetune
110 | model = model.float()
111 |
112 | with open(data_args.train_file, "r", encoding="utf-8") as f:
113 | if data_args.train_file.endswith(".json"):
114 | train_data = json.load(f)
115 | elif data_args.train_file.endswith(".jsonl"):
116 | train_data = [json.loads(line) for line in f]
117 |
118 | if data_args.train_format == "multi-turn":
119 | train_dataset = MultiTurnDataset(
120 | train_data,
121 | tokenizer,
122 | data_args.max_seq_length,
123 | )
124 | elif data_args.train_format == "input-output":
125 | train_dataset = InputOutputDataset(
126 | train_data,
127 | tokenizer,
128 | data_args.max_source_length,
129 | data_args.max_target_length,
130 | )
131 | else:
132 | raise ValueError(f"Unknown train format: {data_args.train_format}")
133 | if training_args.local_rank < 1:
134 | sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer)
135 |
136 | # Data collator
137 | data_collator = DataCollatorForSeq2Seq(
138 | tokenizer,
139 | model=model,
140 | label_pad_token_id=-100,
141 | pad_to_multiple_of=None,
142 | padding=False
143 | )
144 |
145 | # Initialize our Trainer
146 | trainer = PrefixTrainer(
147 | model=model,
148 | args=training_args,
149 | train_dataset=train_dataset,
150 | tokenizer=tokenizer,
151 | data_collator=data_collator,
152 | save_changed=model_args.pre_seq_len is not None
153 | )
154 |
155 | checkpoint = None
156 | if training_args.resume_from_checkpoint is not None:
157 | checkpoint = training_args.resume_from_checkpoint
158 | model.gradient_checkpointing_enable()
159 | model.enable_input_require_grads()
160 | trainer.train(resume_from_checkpoint=checkpoint)
161 | trainer.save_model() # Saves the tokenizer too for easy upload
162 | trainer.save_state()
163 |
164 | if __name__ == "__main__":
165 | main()
166 |
--------------------------------------------------------------------------------
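Both `finetune.py` and `inference.py` load a P-Tuning v2 checkpoint the same way. They keep only the `transformer.prefix_encoder.` entries of the saved state dict and strip that prefix, so the keys match `model.transformer.prefix_encoder`.

The key-filtering step can be factored into a small helper. `extract_prefix_encoder_state` is a hypothetical name, and plain values stand in for the tensors that `torch.load` would return.

```python
from typing import Dict, TypeVar

V = TypeVar("V")

PREFIX = "transformer.prefix_encoder."


def extract_prefix_encoder_state(state_dict: Dict[str, V], prefix: str = PREFIX) -> Dict[str, V]:
    """Return the prefix-encoder entries of a checkpoint with the module prefix removed."""
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

With real tensors, the result would be passed to `model.transformer.prefix_encoder.load_state_dict(...)`, exactly as the loops in both scripts do today.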
/Ziwei-Chatglm3-6B/finetune_demo/inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from transformers import AutoConfig, AutoModel, AutoTokenizer
3 | import torch
4 | import os
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path")
8 | parser.add_argument("--model", type=str, default=None, help="main model weights")
9 | parser.add_argument("--tokenizer", type=str, default=None, help="tokenizer path (defaults to --model)")
10 | parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning")
11 | parser.add_argument("--device", type=str, default="cuda")
12 | parser.add_argument("--max-new-tokens", type=int, default=128)
13 |
14 | args = parser.parse_args()
15 |
16 | if args.tokenizer is None:
17 | args.tokenizer = args.model
18 |
19 | if args.pt_checkpoint:
20 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
21 | config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, pre_seq_len=args.pt_pre_seq_len)
22 | model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True)
23 | prefix_state_dict = torch.load(os.path.join(args.pt_checkpoint, "pytorch_model.bin"))
24 | new_prefix_state_dict = {}
25 | for k, v in prefix_state_dict.items():
26 | if k.startswith("transformer.prefix_encoder."):
27 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
28 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
29 | else:
30 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True)
31 | model = AutoModel.from_pretrained(args.model, trust_remote_code=True)
32 |
33 | model = model.to(args.device)
34 |
35 | while True:
36 | prompt = input("Prompt:")
37 | inputs = tokenizer(prompt, return_tensors="pt")
38 | inputs = inputs.to(args.device)
39 | response = model.generate(input_ids=inputs["input_ids"], max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens)
40 | response = response[0, inputs["input_ids"].shape[-1]:]
41 | print("Response:", tokenizer.decode(response, skip_special_tokens=True))
42 |
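43 | # Example prompt: My date of birth is 22:15 on October 5, 2002, and my Eight Characters (bazi) are 壬寅壬子己未甲子. Please predict my fortune for this year.
44 | #python inference.py --pt-checkpoint /media/admin1/BackupPlus/chatglm3-6b-pt/output_pt-20231224-165313-128-2e-2/checkpoint-1000 --model /home/admin1/桌面/chatglm3-6b
45 | # 1. My date of birth is 22:15 on October 5, 2002. What does my destiny hold?
46 | # 2. My date of birth is 22:15 on October 5, 2002. When will I get married?
47 | # 3. My date of birth is 22:15 on October 5, 2002. When will I have children?
48 | # 4. My date of birth is 22:15 on October 5, 2002. How will my career develop?
49 | # 5. My date of birth is 22:15 on October 5, 2002. How much money will I earn?
50 | # 6. My date of birth is 22:15 on October 5, 2002. When will I retire?
51 | # 7. My date of birth is 22:15 on October 5, 2002. How much wealth will I have?
52 | # 8. My date of birth is 22:15 on October 5, 2002. How is my health?
53 | # 9. My date of birth is 22:15 on October 5, 2002. When will I get sick?
54 | # 10. My date of birth is 22:15 on October 5, 2002. How long will I live?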
--------------------------------------------------------------------------------
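`argparse` turns dashed option names into underscored attributes, so `--pt-pre-seq-len` is read back as `args.pt_pre_seq_len`. This value must match the `pre_seq_len` used during P-Tuning training, because the prefix encoder's embedding has one row per prefix position. A mismatched value would make `load_state_dict` fail on a shape mismatch.

Below is a trimmed sketch of the script's argument handling: parsing only, with no model loading.

```python
import argparse
from typing import List, Optional


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt-checkpoint", type=str, default=None, help="P-Tuning v2 checkpoint path")
    parser.add_argument("--model", type=str, default=None, help="main model weights")
    parser.add_argument("--tokenizer", type=str, default=None, help="tokenizer path (defaults to --model)")
    parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="pre_seq_len used in P-Tuning")
    parser.add_argument("--max-new-tokens", type=int, default=128)
    args = parser.parse_args(argv)
    if args.tokenizer is None:
        args.tokenizer = args.model  # fall back to the tokenizer stored with the model
    return args
```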
/Ziwei-Chatglm3-6B/finetune_demo/preprocess_utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import ast
3 | import astunparse
4 | from transformers import PreTrainedTokenizer
5 | from torch.utils.data import Dataset
6 | from copy import deepcopy
7 | from typing import Dict, List
8 |
9 | # text constants
10 | FUNCTION_CALL_NAME = 'tool_call'
11 | FUNCTION_CALL_PREFIX = '```python\n'
12 | FUNCTION_CALL_POSTFIX = '\n```'
13 | TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n'
14 | CONVERSATOIN_KEY = 'conversations'
15 | TOOL_DESC_KEY = 'tools'
16 |
17 | def format_function_call(function_name: str, parameters: Dict[str, str]):
18 | function_name = ast.Name(id=function_name)
19 | keywords = [
20 | ast.keyword(arg=arg_name, value=ast.Constant(arg_value))
21 | for arg_name, arg_value in parameters.items()
22 | ]
23 | func_call = ast.Call(func=function_name, args=[], keywords=keywords)
24 | return astunparse.unparse(func_call).strip()
25 |
26 | def format_conversation(item, tokenizer, conversation_key: str, tool_key: str):
27 | conversations = deepcopy(item[conversation_key])
28 |
29 | # Note: `loss_mask` here means whether *the prediction* of the token should take loss
30 | tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0]
31 |
32 | def _update(_tokens: List[int], value: int = 1):
33 | value = int(value)
34 | tokens.extend(_tokens)
35 | loss_masks.extend([value] * len(_tokens))
36 |
37 | # insert system prompt for tools
38 | if tool_key in item:
39 | conversations.insert(0,
40 | {
41 | "role": "system",
42 | "content": TOOL_DEFINITION_PREFIX + json.dumps(item[tool_key], indent=4, ensure_ascii=False)
43 | }
44 | )
45 |
46 | for idx, conv in enumerate(conversations):
47 | loss = conv.get("loss", True)
48 | if conv['role'] in {'system', 'user'}:
49 | loss = False
50 | if conv['role'] == 'tool':
51 | # function call python code
52 | value = FUNCTION_CALL_PREFIX + format_function_call(FUNCTION_CALL_NAME, conv["parameters"]) + FUNCTION_CALL_POSTFIX
53 | text = tokenizer.build_single_message("assistant", conv["name"], value)
54 | _update(text, loss)
55 |
56 | # function call result
57 | value = conv.get('observation', None)
58 | if not isinstance(value, str):
59 | value = json.dumps(value, ensure_ascii=False)
60 | text = tokenizer.build_single_message("observation", "", value)
61 | _update(text, False)
62 | else:
63 | text = tokenizer.build_single_message(conv['role'], "", conv["content"])
64 | _update(text, loss)
65 |
66 | _update([tokenizer.eos_token_id], False)
67 |
68 | assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}"
69 | return tokens, loss_masks
70 |
71 | def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer):
72 | print("Sanity Check >>>>>>>>>>>>>")
73 | for t, m in zip(tokens, target):
74 | decoded = tokenizer.tokenizer.index_special_tokens[t] \
75 | if t in tokenizer.tokenizer.index_special_tokens \
76 | else tokenizer.decode([t])
77 | print("%20s: %6d -> %6d" % (repr(decoded), t, m))
78 | print("<<<<<<<<<<<<< Sanity Check")
79 |
80 | assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}"
81 |
82 | class MultiTurnDataset(Dataset):
83 | def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int):
84 | super(MultiTurnDataset, self).__init__()
85 | self.tokenizer = tokenizer
86 | self.max_seq_length = max_seq_length
87 | self.data = data
88 |
89 | def __len__(self):
90 | return len(self.data)
91 |
92 | def __getitem__(self, i) -> dict:
93 | data_item = self.data[i]
94 | tokens, loss_masks = format_conversation(data_item, self.tokenizer, CONVERSATOIN_KEY, TOOL_DESC_KEY)
95 |
96 | # labels are used inside the model
97 | target_based_loss_mask = [False] + loss_masks[:-1]
98 | labels = [(t if m else -100) for t, m in zip(tokens, target_based_loss_mask)]
99 |
100 | tokens = tokens[:self.max_seq_length]
101 | labels = labels[:self.max_seq_length]
102 | tokens += [self.tokenizer.pad_token_id] * (self.max_seq_length - len(tokens))
103 | labels += [-100] * (self.max_seq_length - len(labels))
104 |
105 | assert len(tokens) == len(labels), f"length mismatch: {len(tokens)} vs {len(labels)}"
106 |
107 | return {
108 | "input_ids": tokens,
109 | "labels": labels
110 | }
111 |
112 | class InputOutputDataset(Dataset):
113 | def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int):
114 | super(InputOutputDataset, self).__init__()
115 | self.tokenizer = tokenizer
116 | self.max_source_length = max_source_length
117 | self.max_target_length = max_target_length
118 | self.max_seq_length = max_source_length + max_target_length + 1
119 | self.data = data
120 |
121 | def __len__(self):
122 | return len(self.data)
123 |
124 | def __getitem__(self, i) -> dict:
125 | data_item = self.data[i]
126 |
127 | a_ids = self.tokenizer.encode(text=data_item['prompt'], add_special_tokens=True, truncation=True,
128 | max_length=self.max_source_length)
129 | b_ids = self.tokenizer.encode(text=data_item['response'], add_special_tokens=False, truncation=True,
130 | max_length=self.max_target_length)
131 |
132 | context_length = len(a_ids)
133 | input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id]
134 | labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id]
135 |
136 | pad_len = self.max_seq_length - len(input_ids)
137 | input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
138 | labels = labels + [self.tokenizer.pad_token_id] * pad_len
139 | labels = [(l if l != self.tokenizer.pad_token_id else -100) for l in labels]
140 |
141 | assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}"
142 |
143 | return {
144 | "input_ids": input_ids,
145 | "labels": labels
146 | }
147 |
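The two dataset classes above differ mainly in how they build `labels`. The pure-Python sketch below runs both schemes on toy token ids. The ids and the names `multi_turn_labels` and `input_output_labels` are illustrative and not part of the repo. The sketch assumes the model shifts labels by one position internally and ignores label `-100`. Hugging Face causal-LM heads generally behave this way, and so, to our reading, does ChatGLM3's own modeling code.

```python
from typing import List, Tuple

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss


def multi_turn_labels(tokens: List[int], loss_masks: List[int],
                      max_len: int, pad_id: int) -> Tuple[List[int], List[int]]:
    # loss_masks[j] == 1 means the prediction made *at* position j (of token j+1) is trained.
    # Label i is the target predicted at position i-1, so shift the mask right by one.
    target_mask = [0] + loss_masks[:-1]
    labels = [t if m else IGNORE_INDEX for t, m in zip(tokens, target_mask)]
    # Truncate, then pad inputs with pad_id and labels with IGNORE_INDEX
    tokens, labels = tokens[:max_len], labels[:max_len]
    tokens = tokens + [pad_id] * (max_len - len(tokens))
    labels = labels + [IGNORE_INDEX] * (max_len - len(labels))
    return tokens, labels


def input_output_labels(prompt_ids: List[int], response_ids: List[int], eos_id: int,
                        pad_id: int, max_len: int) -> Tuple[List[int], List[int]]:
    # Prompt positions are masked; only the response and the closing EOS are targets
    input_ids = prompt_ids + response_ids + [eos_id]
    labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids + [eos_id]
    pad_len = max_len - len(input_ids)
    return input_ids + [pad_id] * pad_len, labels + [IGNORE_INDEX] * pad_len


# Toy multi-turn sequence: [gMASK, sop, <|user|>, user text, <|assistant|>, reply, EOS]
toks, labs = multi_turn_labels([1, 2, 3, 4, 5, 6, 7], [0, 0, 0, 0, 1, 1, 0], max_len=9, pad_id=0)
print(labs)  # [-100, -100, -100, -100, -100, 6, 7, -100, -100]
```

In this toy sequence the assistant role token carries mask 1. After the shift, the reply token and the token that follows the turn (here the EOS) become targets, so the model also learns where to stop.

`input_output_labels` differs slightly from `InputOutputDataset`: it writes `-100` directly instead of passing through `pad_token_id`. This also avoids masking a real response token whose id happens to equal the pad id.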
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_ds.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | set -ex
4 |
5 | LR=1e-4
6 | NUM_GPUS=4
7 | MAX_SOURCE_LEN=1024
8 | MAX_TARGET_LEN=128
9 | DEV_BATCH_SIZE=4
10 | GRAD_ACCUMULARION_STEPS=1
11 | MAX_STEP=500
12 | SAVE_INTERVAL=500
13 |
14 | RUN_NAME=advertise_gen_ft
15 | BASE_MODEL_PATH=THUDM/chatglm3-6b
16 | DATASET_PATH=formatted_data/advertise_gen.jsonl
17 |
18 | DATESTR=`date +%Y%m%d-%H%M%S`
19 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${LR}
20 | MASTER_PORT=$(shuf -n 1 -i 10000-65535)
21 |
22 | mkdir -p $OUTPUT_DIR
23 |
24 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
25 | --train_format input-output \
26 | --train_file $DATASET_PATH \
27 | --preprocessing_num_workers 1 \
28 | --model_name_or_path $BASE_MODEL_PATH \
29 | --output_dir $OUTPUT_DIR \
30 | --max_source_length $MAX_SOURCE_LEN \
31 | --max_target_length $MAX_TARGET_LEN \
32 | --per_device_train_batch_size $DEV_BATCH_SIZE \
33 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
34 | --max_steps $MAX_STEP \
35 | --logging_steps 1 \
36 | --save_steps $SAVE_INTERVAL \
37 | --learning_rate $LR \
38 | --fp16 \
39 | --deepspeed configs/deepspeed.json 2>&1 | tee ${OUTPUT_DIR}/train.log
40 |
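With data-parallel training under `torchrun`, each optimizer step sees `DEV_BATCH_SIZE × NUM_GPUS × GRAD_ACCUMULARION_STEPS` sequences. The four scripts therefore trade GPUs against accumulation steps. Below is a quick sketch of that arithmetic, using the values hard-coded in `finetune_demo/scripts`. The helper name is illustrative.

```python
def effective_batch(per_device: int, num_gpus: int, grad_accum: int) -> int:
    # Number of sequences contributing to one optimizer step
    return per_device * num_gpus * grad_accum


# Values copied from the four scripts in finetune_demo/scripts
runs = {
    "finetune_ds.sh": dict(per_device=4, num_gpus=4, grad_accum=1, max_steps=500),
    "finetune_ds_multiturn.sh": dict(per_device=16, num_gpus=4, grad_accum=1, max_steps=200),
    "finetune_pt.sh": dict(per_device=1, num_gpus=1, grad_accum=32, max_steps=1000),
    "finetune_pt_multiturn.sh": dict(per_device=1, num_gpus=1, grad_accum=16, max_steps=1000),
}

for name, r in runs.items():
    bs = effective_batch(r["per_device"], r["num_gpus"], r["grad_accum"])
    print(f"{name}: {bs} sequences/step, {bs * r['max_steps']} sequences over MAX_STEP")
```

For example, `finetune_pt.sh` reaches a batch of 32 on a single GPU purely through gradient accumulation. It consumes 32,000 sequences in 1,000 steps. If `formatted_data/output.jsonl` is smaller than that, the run makes several passes over it.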
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_ds_multiturn.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | set -ex
4 |
5 | LR=1e-4
6 | NUM_GPUS=4
7 | MAX_SEQ_LEN=2048
8 | DEV_BATCH_SIZE=16
9 | GRAD_ACCUMULARION_STEPS=1
10 | MAX_STEP=200
11 | SAVE_INTERVAL=50
12 |
13 | DATESTR=`date +%Y%m%d-%H%M%S`
14 | RUN_NAME=tool_alpaca_ft
15 | DATASET_PATH=formatted_data/tool_alpaca.jsonl
16 |
17 | BASE_MODEL_PATH=THUDM/chatglm3-6b
18 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${LR}
19 |
20 | mkdir -p $OUTPUT_DIR
21 |
22 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
23 | --train_format multi-turn \
24 | --train_file $DATASET_PATH \
25 | --max_seq_length $MAX_SEQ_LEN \
26 | --preprocessing_num_workers 1 \
27 | --model_name_or_path $BASE_MODEL_PATH \
28 | --output_dir $OUTPUT_DIR \
29 | --per_device_train_batch_size $DEV_BATCH_SIZE \
30 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
31 | --max_steps $MAX_STEP \
32 | --logging_steps 1 \
33 |     --save_steps $SAVE_INTERVAL \
34 |     --learning_rate $LR \
35 |     --fp16 \
36 |     --deepspeed configs/deepspeed.json 2>&1 | tee ${OUTPUT_DIR}/train.log
37 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_pt.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | set -ex
4 |
5 | PRE_SEQ_LEN=128
6 | LR=2e-2
7 | NUM_GPUS=1
8 | MAX_SOURCE_LEN=1024
9 | MAX_TARGET_LEN=128
10 | DEV_BATCH_SIZE=1
11 | GRAD_ACCUMULARION_STEPS=32
12 | MAX_STEP=1000
13 | SAVE_INTERVAL=500
14 |
15 | DATESTR=`date +%Y%m%d-%H%M%S`
16 | RUN_NAME=output_pt
17 |
18 | BASE_MODEL_PATH=/home/admin1/桌面/chatglm3-6b
19 | DATASET_PATH=formatted_data/output.jsonl
20 | OUTPUT_DIR=/media/admin1/BackupPlus/chatglm3-6b-pt/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR}
21 |
22 | mkdir -p $OUTPUT_DIR
23 |
24 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
25 | --train_format input-output \
26 | --train_file $DATASET_PATH \
27 | --preprocessing_num_workers 1 \
28 | --model_name_or_path $BASE_MODEL_PATH \
29 | --output_dir $OUTPUT_DIR \
30 | --max_source_length $MAX_SOURCE_LEN \
31 | --max_target_length $MAX_TARGET_LEN \
32 | --per_device_train_batch_size $DEV_BATCH_SIZE \
33 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
34 | --max_steps $MAX_STEP \
35 | --logging_steps 1 \
36 | --save_steps $SAVE_INTERVAL \
37 | --learning_rate $LR \
38 | --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log
39 |
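P-Tuning v2 trains only a prefix encoder, which makes the size of the trainable part easy to estimate. The sketch below rests on several assumptions, not on anything confirmed in this repo:

- ChatGLM3-6B's modeling code sizes the encoder as an embedding with `pre_seq_len` rows.
- Each row holds a key and a value vector per layer, giving `num_layers * kv_channels * multi_query_group_num * 2` values. This assumes the default `prefix_projection=False`.
- The model config uses `num_layers=28`, `kv_channels=128` and `multi_query_group_num=2`.

`prefix_encoder_params` is a hypothetical helper.

```python
def prefix_encoder_params(pre_seq_len: int, num_layers: int,
                          kv_channels: int, multi_query_group_num: int) -> int:
    # One embedding row per virtual prefix token; each row stores a key and a value
    # vector for every layer (assumed layout, prefix_projection=False)
    kv_size = num_layers * 2 * kv_channels * multi_query_group_num
    return pre_seq_len * kv_size


# PRE_SEQ_LEN=128 from finetune_pt.sh, with the assumed ChatGLM3-6B config values
print(prefix_encoder_params(128, 28, 128, 2))  # 1835008
```

That comes to about 1.8 million trainable parameters against roughly six billion frozen ones. For comparison, the P-Tuning scripts use `LR=2e-2`, while the full-parameter DeepSpeed scripts use `1e-4`.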
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_pt_multiturn.sh:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env bash
2 |
3 | set -ex
4 |
5 | PRE_SEQ_LEN=128
6 | LR=2e-2
7 | NUM_GPUS=1
8 | MAX_SEQ_LEN=2048
9 | DEV_BATCH_SIZE=1
10 | GRAD_ACCUMULARION_STEPS=16
11 | MAX_STEP=1000
12 | SAVE_INTERVAL=500
13 |
14 | DATESTR=`date +%Y%m%d-%H%M%S`
15 | RUN_NAME=tool_alpaca_pt
16 |
17 | BASE_MODEL_PATH=THUDM/chatglm3-6b
18 | DATASET_PATH=formatted_data/tool_alpaca.jsonl
19 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR}
20 |
21 | mkdir -p $OUTPUT_DIR
22 |
23 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \
24 | --train_format multi-turn \
25 | --train_file $DATASET_PATH \
26 | --max_seq_length $MAX_SEQ_LEN \
27 | --preprocessing_num_workers 1 \
28 | --model_name_or_path $BASE_MODEL_PATH \
29 | --output_dir $OUTPUT_DIR \
30 | --per_device_train_batch_size $DEV_BATCH_SIZE \
31 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \
32 | --max_steps $MAX_STEP \
33 | --logging_steps 1 \
34 | --save_steps $SAVE_INTERVAL \
35 | --learning_rate $LR \
36 | --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log
37 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/format_advertise_gen.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | import json
4 | from collections import Counter
5 | from argparse import ArgumentParser
6 | import os
7 |
8 | # parser = ArgumentParser()
9 | # parser.add_argument("--path", type=str, required=True)
10 | #
11 | # args = parser.parse_args()
12 |
13 | with open("//data/output.json", encoding="utf-8") as f:
14 |     text = f.read()
15 |
16 | # output.json may be a JSON array or a bare sequence of objects. Jump to each "{"
17 | # and let the decoder consume one complete object: braces inside string values are
18 | # handled, strict=False tolerates raw line breaks inside strings, and a truncated
19 | # file raises an error instead of looping forever at EOF.
20 | decoder = json.JSONDecoder(strict=False)
21 | data = []
22 | pos = text.find("{")
23 | while pos != -1:
24 |     obj, end = decoder.raw_decode(text, pos)
25 |     data.append(obj)
26 |     pos = text.find("{", end)
27 | print("records loaded:", len(data))
28 |
29 | # Alpaca-style records: prompt = instruction + input, response = output
30 | train_examples = [{
31 |     "prompt": x['instruction'] + x['input'],
32 |     "response": x['output']
33 | } for x in data]
34 |
35 | os.makedirs("../formatted_data", exist_ok=True)
36 |
37 | with open("../formatted_data/output.jsonl", "w", encoding="utf-8") as f:
38 |     for e in train_examples:
39 |         f.write(json.dumps(e, ensure_ascii=False) + "\n")
40 |
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/scripts/format_tool_alpaca.py:
--------------------------------------------------------------------------------
1 | #! /usr/bin/env python
2 |
3 | import json
4 | from collections import Counter
5 | from argparse import ArgumentParser
6 | import os
7 |
8 | parser = ArgumentParser()
9 | parser.add_argument("--path", type=str, required=True)
10 |
11 | args = parser.parse_args()
12 |
13 | with open(args.path) as f:
14 | data = json.load(f)
15 |
16 | train_examples = []
17 | err_count = 0
18 | for setting in data:
19 | api_desc = [setting["NLDocumentation"]]
20 | for instance in setting["Instances"]:
21 | try:
22 | conv = [{
23 | "role": "user",
24 | "content": instance['input'],
25 | }]
26 | for step in instance['intermediate_steps']:
27 | tool_name, params, react = step[0]
28 | step_thought = react.split("Action:")[0].strip()
29 | observation = step[1]
30 | conv.append({
31 | "role": "assistant",
32 | "content": step_thought,
33 | })
34 | conv.append({
35 | "role": "tool",
36 | "name": tool_name,
37 | "parameters": json.loads(params),
38 | "observation": observation,
39 | })
40 | conv.append({
41 | "role": "assistant",
42 | "content": instance['Final Thought'] + "\n" + instance['output'],
43 | })
44 |     except Exception:  # malformed instance: count it and skip it
45 | err_count += 1
46 | else:
47 | train_examples.append({
48 | "tools": api_desc,
49 | "conversations": conv
50 | })
51 |
52 | print("err_count:", err_count)
53 | print("train_examples:", len(train_examples))
54 | print("conversation distribution:", Counter([len(e["conversations"]) for e in train_examples]))
55 |
56 | os.makedirs("../formatted_data", exist_ok=True)
57 |
58 | with open("../formatted_data/tool_alpaca.jsonl", "w") as f:
59 | for e in train_examples:
60 | f.write(json.dumps(e, ensure_ascii=False) + "\n")
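The loop above turns each ToolAlpaca instance into a conversation. The turns alternate: an assistant "thought" (the ReAct text before `Action:`), then a `tool` turn carrying the parsed arguments and the observation. Below is the same transformation for a single instance, as a testable sketch. The function name and the sample record are made up for illustration; the field names come from the script.

```python
import json
from typing import Any, Dict, List


def instance_to_conversation(instance: Dict[str, Any]) -> List[Dict[str, Any]]:
    # Mirrors the per-instance loop of format_tool_alpaca.py
    conv: List[Dict[str, Any]] = [{"role": "user", "content": instance["input"]}]
    for step in instance["intermediate_steps"]:
        tool_name, params, react = step[0]  # (tool name, JSON-encoded arguments, ReAct text)
        conv.append({"role": "assistant", "content": react.split("Action:")[0].strip()})
        conv.append({
            "role": "tool",
            "name": tool_name,
            "parameters": json.loads(params),
            "observation": step[1],
        })
    conv.append({"role": "assistant",
                 "content": instance["Final Thought"] + "\n" + instance["output"]})
    return conv


# Hypothetical sample record
example = {
    "input": "What's the weather in Beijing?",
    "intermediate_steps": [[
        ["getWeather", '{"city": "Beijing"}', "I should look up the weather.\nAction: getWeather"],
        "Sunny, 25C",
    ]],
    "Final Thought": "I now know the answer.",
    "output": "It is sunny and 25C in Beijing.",
}
for turn in instance_to_conversation(example):
    print(turn["role"], "->", turn.get("content", turn.get("parameters")))
```

`format_conversation` in `preprocess_utils.py` later renders each `tool` turn as two messages. The first is an assistant message containing the `tool_call(...)` code, which can carry loss. The second is an `observation` message, which never does.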
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/finetune_demo/trainer.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2020-present the HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """
16 | The Trainer class, to easily train a 🤗 Transformers model from scratch or fine-tune it on a new task.
17 | """
18 | import os
19 | from typing import Optional
20 | from transformers import Trainer
21 |
22 | import torch
23 | from transformers.modeling_utils import PreTrainedModel, unwrap_model
24 | from transformers.utils import logging
25 |
26 | logger = logging.get_logger(__name__)
27 |
28 | WEIGHTS_NAME = "pytorch_model.bin"
29 | TRAINING_ARGS_NAME = "training_args.bin"
30 |
31 |
32 | class PrefixTrainer(Trainer):
33 | def __init__(self, *args, save_changed=False, **kwargs):
34 | self.save_changed = save_changed
35 | super().__init__(*args, **kwargs)
36 |
37 | def _save(self, output_dir: Optional[str] = None, state_dict=None):
38 | # If we are executing this function, we are the process zero, so we don't check for that.
39 | output_dir = output_dir if output_dir is not None else self.args.output_dir
40 | os.makedirs(output_dir, exist_ok=True)
41 | logger.info(f"Saving model checkpoint to {output_dir}")
42 | # Save a trained model and configuration using `save_pretrained()`.
43 | # They can then be reloaded using `from_pretrained()`
44 | if not isinstance(self.model, PreTrainedModel):
45 | if isinstance(unwrap_model(self.model), PreTrainedModel):
46 | if state_dict is None:
47 | state_dict = self.model.state_dict()
48 | unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict)
49 | else:
50 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
51 | if state_dict is None:
52 | state_dict = self.model.state_dict()
53 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
54 | else:
55 | if self.save_changed:
56 | print("Saving PrefixEncoder")
57 | state_dict = self.model.state_dict()
58 | filtered_state_dict = {}
59 | for k, v in self.model.named_parameters():
60 | if v.requires_grad:
61 | filtered_state_dict[k] = state_dict[k]
62 | self.model.save_pretrained(output_dir, state_dict=filtered_state_dict)
63 | else:
64 | print("Saving the whole model")
65 | self.model.save_pretrained(output_dir, state_dict=state_dict)
66 | if self.tokenizer is not None:
67 | self.tokenizer.save_pretrained(output_dir)
68 |
69 | # Good practice: save your training arguments together with the trained model
70 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))
71 |
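With `save_changed=True`, `PrefixTrainer._save` writes only the parameters that require gradients. A P-Tuning checkpoint therefore holds the prefix encoder rather than a full copy of the base weights. Below is the filtering step in isolation, as a dependency-free sketch. `SimpleNamespace` objects stand in for `torch.nn.Parameter`, and the parameter names and helper name are illustrative.

```python
from types import SimpleNamespace
from typing import Any, Dict, Iterable, Tuple


def trainable_state_dict(named_parameters: Iterable[Tuple[str, Any]],
                         state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Keep only entries whose parameter is being trained (requires_grad=True)
    # and drop the frozen base-model weights
    return {name: state_dict[name] for name, param in named_parameters if param.requires_grad}


# Toy stand-ins for (name, parameter) pairs; only `requires_grad` matters here
params = [
    ("transformer.prefix_encoder.embedding.weight", SimpleNamespace(requires_grad=True)),
    ("transformer.encoder.layers.0.self_attention.query_key_value.weight",
     SimpleNamespace(requires_grad=False)),
]
state = {name: f"<tensor {name}>" for name, _ in params}
print(list(trainable_state_dict(params, state)))
```

When a filtered checkpoint like this is loaded back, it has to be applied on top of the original base model. On its own it contains no transformer weights.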
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/openai_api/openai_api.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Implements API for ChatGLM3-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat)
3 | # Usage: python openai_api.py
4 | # Visit http://localhost:8000/docs for documents.
5 |
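6 | # In the OpenAI API, max_tokens corresponds to Hugging Face's max_new_tokens, not max_length.
7 | # For example, with the 6B model, setting max_tokens = 8192 raises an error: once the history and prompt are subtracted, the model cannot generate that many tokens.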
8 |
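The note at the top of this file implies a budget rule: the prompt tokens plus `max_tokens` must fit within the context window. Below is a minimal sketch of a hypothetical `clamp_max_new_tokens` helper that a server could apply before generation. The 8192 default is an assumption about ChatGLM3-6B's `config.seq_length`, which `utils.py` reads at runtime.

```python
def clamp_max_new_tokens(requested: int, prompt_tokens: int, seq_length: int = 8192) -> int:
    # max_tokens maps to max_new_tokens, so the prompt (history included) and the
    # generated tokens share the same context window
    remaining = seq_length - prompt_tokens
    if remaining <= 0:
        raise ValueError(f"prompt ({prompt_tokens} tokens) already fills the {seq_length}-token context")
    return min(requested, remaining)


print(clamp_max_new_tokens(8192, prompt_tokens=300))  # 7892
print(clamp_max_new_tokens(1024, prompt_tokens=300))  # 1024
```

Clamping returns a shorter answer instead of an error. If the caller should be told that the request was too large, raising an HTTP 400 is the other option.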
9 | import time
10 | from contextlib import asynccontextmanager
11 | from typing import List, Literal, Optional, Union
12 |
13 | import torch
14 | import uvicorn
15 | from fastapi import FastAPI, HTTPException
16 | from fastapi.middleware.cors import CORSMiddleware
17 | from loguru import logger
18 | from pydantic import BaseModel, Field
19 | from sse_starlette.sse import EventSourceResponse
20 | from transformers import AutoTokenizer, AutoModel
21 |
22 | from utils import process_response, generate_chatglm3, generate_stream_chatglm3
23 |
24 |
25 | @asynccontextmanager
26 | async def lifespan(app: FastAPI): # collects GPU memory
27 | yield
28 | if torch.cuda.is_available():
29 | torch.cuda.empty_cache()
30 | torch.cuda.ipc_collect()
31 |
32 |
33 | app = FastAPI(lifespan=lifespan)
34 |
35 | app.add_middleware(
36 | CORSMiddleware,
37 | allow_origins=["*"],
38 | allow_credentials=True,
39 | allow_methods=["*"],
40 | allow_headers=["*"],
41 | )
42 |
43 |
44 | class ModelCard(BaseModel):
45 | id: str
46 | object: str = "model"
47 | created: int = Field(default_factory=lambda: int(time.time()))
48 | owned_by: str = "owner"
49 | root: Optional[str] = None
50 | parent: Optional[str] = None
51 | permission: Optional[list] = None
52 |
53 |
54 | class ModelList(BaseModel):
55 | object: str = "list"
56 | data: List[ModelCard] = []
57 |
58 |
59 | class FunctionCallResponse(BaseModel):
60 | name: Optional[str] = None
61 | arguments: Optional[str] = None
62 |
63 |
64 | class ChatMessage(BaseModel):
65 | role: Literal["user", "assistant", "system", "function"]
66 |     content: Optional[str] = None
67 | name: Optional[str] = None
68 | function_call: Optional[FunctionCallResponse] = None
69 |
70 |
71 | class DeltaMessage(BaseModel):
72 | role: Optional[Literal["user", "assistant", "system"]] = None
73 | content: Optional[str] = None
74 | function_call: Optional[FunctionCallResponse] = None
75 |
76 | class ChatCompletionRequest(BaseModel):
77 | model: str
78 | messages: List[ChatMessage]
79 | temperature: Optional[float] = 0.8
80 | top_p: Optional[float] = 0.8
81 | max_tokens: Optional[int] = None
82 | stream: Optional[bool] = False
83 | functions: Optional[Union[dict, List[dict]]] = None
84 | # Additional parameters
85 | repetition_penalty: Optional[float] = 1.1
86 |
87 |
88 | class ChatCompletionResponseChoice(BaseModel):
89 | index: int
90 | message: ChatMessage
91 | finish_reason: Literal["stop", "length", "function_call"]
92 |
93 |
94 | class ChatCompletionResponseStreamChoice(BaseModel):
95 | index: int
96 | delta: DeltaMessage
97 | finish_reason: Optional[Literal["stop", "length", "function_call"]]
98 |
99 |
100 | class UsageInfo(BaseModel):
101 | prompt_tokens: int = 0
102 | total_tokens: int = 0
103 | completion_tokens: Optional[int] = 0
104 |
105 |
106 | class ChatCompletionResponse(BaseModel):
107 | model: str
108 | object: Literal["chat.completion", "chat.completion.chunk"]
109 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
110 | created: Optional[int] = Field(default_factory=lambda: int(time.time()))
111 | usage: Optional[UsageInfo] = None
112 |
113 |
114 | @app.get("/v1/models", response_model=ModelList)
115 | async def list_models():
116 | model_card = ModelCard(id="chatglm3-6b")
117 | return ModelList(data=[model_card])
118 |
119 |
120 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
121 | async def create_chat_completion(request: ChatCompletionRequest):
122 | global model, tokenizer
123 |
124 | if len(request.messages) < 1 or request.messages[-1].role == "assistant":
125 | raise HTTPException(status_code=400, detail="Invalid request")
126 |
127 | gen_params = dict(
128 | messages=request.messages,
129 | temperature=request.temperature,
130 | top_p=request.top_p,
131 | max_tokens=request.max_tokens or 1024,
132 | echo=False,
133 | stream=request.stream,
134 | repetition_penalty=request.repetition_penalty,
135 | functions=request.functions,
136 | )
137 |
138 | logger.debug(f"==== request ====\n{gen_params}")
139 |
140 | if request.stream:
141 | generate = predict(request.model, gen_params)
142 | return EventSourceResponse(generate, media_type="text/event-stream")
143 |
144 | response = generate_chatglm3(model, tokenizer, gen_params)
145 | usage = UsageInfo()
146 |
147 | function_call, finish_reason = None, "stop"
148 | if request.functions:
149 | try:
150 | function_call = process_response(response["text"], use_tool=True)
151 |         except Exception:
152 | logger.warning("Failed to parse tool call")
153 |
154 | if isinstance(function_call, dict):
155 | finish_reason = "function_call"
156 | function_call = FunctionCallResponse(**function_call)
157 |
158 | message = ChatMessage(
159 | role="assistant",
160 | content=response["text"],
161 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
162 | )
163 |
164 | choice_data = ChatCompletionResponseChoice(
165 | index=0,
166 | message=message,
167 | finish_reason=finish_reason,
168 | )
169 |
170 | task_usage = UsageInfo.model_validate(response["usage"])
171 | for usage_key, usage_value in task_usage.model_dump().items():
172 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
173 |
174 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage)
175 |
176 |
177 | async def predict(model_id: str, params: dict):
178 | global model, tokenizer
179 |
180 | choice_data = ChatCompletionResponseStreamChoice(
181 | index=0,
182 | delta=DeltaMessage(role="assistant"),
183 | finish_reason=None
184 | )
185 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
186 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
187 |
188 | previous_text = ""
189 | for new_response in generate_stream_chatglm3(model, tokenizer, params):
190 | decoded_unicode = new_response["text"]
191 | delta_text = decoded_unicode[len(previous_text):]
192 | previous_text = decoded_unicode
193 |
194 | finish_reason = new_response["finish_reason"]
195 | if len(delta_text) == 0 and finish_reason != "function_call":
196 | continue
197 |
198 | function_call = None
199 | if finish_reason == "function_call":
200 | try:
201 | function_call = process_response(decoded_unicode, use_tool=True)
202 |             except Exception:
203 |                 logger.warning("Failed to parse tool call")
204 |
205 | if isinstance(function_call, dict):
206 | function_call = FunctionCallResponse(**function_call)
207 |
208 | delta = DeltaMessage(
209 | content=delta_text,
210 | role="assistant",
211 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
212 | )
213 |
214 | choice_data = ChatCompletionResponseStreamChoice(
215 | index=0,
216 | delta=delta,
217 | finish_reason=finish_reason
218 | )
219 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
220 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
221 |
222 | choice_data = ChatCompletionResponseStreamChoice(
223 | index=0,
224 | delta=DeltaMessage(),
225 | finish_reason="stop"
226 | )
227 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
228 | yield "{}".format(chunk.model_dump_json(exclude_unset=True))
229 | yield '[DONE]'
230 |
231 |
232 | if __name__ == "__main__":
233 |
234 | model_path = "/home/admin1/桌面/chatglm3-6b"
235 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
236 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
237 |
238 |     # Multi-GPU support: replace the single line above with the two lines below, and set num_gpus to the actual number of GPUs
239 |     # from utils import load_model_on_gpus
240 |     # model = load_model_on_gpus(model_path, num_gpus=2)
241 | model = model.eval()
242 |
243 | uvicorn.run(app, host='0.0.0.0', port=8000, workers=1)
244 |
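In `predict`, `generate_stream_chatglm3` yields the full decoded text on every step. The server then slices off what it has already sent, via `delta_text = decoded_unicode[len(previous_text):]`. Below is the same cumulative-to-incremental conversion in isolation, as a minimal sketch. `iter_deltas` is a hypothetical name.

```python
from typing import Iterable, Iterator


def iter_deltas(cumulative_texts: Iterable[str]) -> Iterator[str]:
    # Each input is the whole text generated so far; emit only the new suffix
    previous = ""
    for text in cumulative_texts:
        delta = text[len(previous):]
        previous = text
        if delta:  # predict() likewise skips empty deltas (unless a function call finished)
            yield delta


print(list(iter_deltas(["He", "Hello", "Hello", "Hello, world"])))  # ['He', 'llo', ', world']
```

The slicing only works if each new text extends the previous one. The `response[-1] != "�"` check in `utils.py` helps preserve that property, because it holds back a step whose decoded text ends in an incomplete multi-byte character.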
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/openai_api/openai_api_request.py:
--------------------------------------------------------------------------------
1 | # Test the endpoint with curl
2 | # curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \
3 | # -H "Content-Type: application/json" \
4 | # -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}"
5 |
6 | # Test the endpoint from Python
7 | import requests
8 | import json
9 |
10 | base_url = "http://127.0.0.1:8000"  # address of the local deployment, or of whichever API endpoint serves the model
11 |
12 | def create_chat_completion(model, messages, use_stream=False):
13 | data = {
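14 |         "model": model,  # model name
15 |         "messages": messages,  # conversation history
16 |         "stream": use_stream,  # whether to stream the response
17 |         "max_tokens": 100,  # maximum number of tokens to generate
18 |         "temperature": 0.8,  # sampling temperature
19 |         "top_p": 0.8,  # nucleus (top-p) sampling threshold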
20 | }
21 |
22 | response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream)
23 | if response.status_code == 200:
24 | if use_stream:
25 |             # handle a streamed response
26 |             for line in response.iter_lines():
27 |                 if line:
28 |                     decoded_line = line.decode('utf-8')[6:]  # drop the "data: " SSE prefix
29 |                     try:
30 |                         response_json = json.loads(decoded_line)
31 |                         content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
32 |                         print(content)
33 |                     except json.JSONDecodeError:  # e.g. the final "[DONE]" marker
34 |                         print("Special Token:", decoded_line)
35 |         else:
36 |             # handle a non-streamed response
37 |             decoded_line = response.json()
38 |             print(decoded_line)
39 |             content = decoded_line.get("choices", [{}])[0].get("message", {}).get("content", "")
40 | print(content)
41 | else:
42 | print("Error:", response.status_code)
43 | return None
44 |
45 |
46 | if __name__ == "__main__":
47 | chat_messages = [
48 | {
49 | "role": "system",
50 | "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.",
51 | },
52 | {
53 | "role": "user",
54 | "content": "你好,给我讲一个故事,大概100字"
55 | },
56 | {
57 | "role":"assistant",
58 | "content":"从前,有一个美丽的村庄,村子里的居民过着和谐的生活。有一天,村子里来了一只可爱的小狗,它一跃成为村民们的好朋友。小狗每天都会陪伴着大家,带给他们无尽的欢乐。无论是儿童还是老人,都为小狗的来到感到高兴。渐渐地,小狗成了村子的象征,它象征着友谊、忠诚和美好。"
59 | },
60 | {
61 | "role": "user",
62 | "content": "能把这个故事继续写下去吗"
63 | }
64 | ]
65 | create_chat_completion("chatglm3-6b", chat_messages, use_stream=False)
66 |
67 |
68 |
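The streaming branch above slices the first six characters (`data: `) off every line and treats anything that fails to parse as a special token. Below is a slightly more defensive sketch. It assumes the server frames each chunk as a standard Server-Sent Events `data:` line, which is what sse-starlette's `EventSourceResponse` emits. `collect_stream_content` is a hypothetical helper that joins the deltas and stops at `[DONE]`.

```python
import json
from typing import Iterable, List


def collect_stream_content(lines: Iterable[bytes]) -> str:
    parts: List[str] = []
    for raw in lines:
        if not raw:
            continue  # blank lines separate SSE events
        line = raw.decode("utf-8")
        if not line.startswith("data:"):
            continue  # skip comment lines (e.g. keep-alive pings) and other SSE fields
        payload = line[len("data:"):].strip()
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        delta = chunk.get("choices", [{}])[0].get("delta", {})
        parts.append(delta.get("content") or "")  # the first chunk only carries the role
    return "".join(parts)


# Usage with a live server: collect_stream_content(response.iter_lines())
sample = [
    b'data: {"choices": [{"index": 0, "delta": {"role": "assistant"}}]}',
    b'',
    b'data: {"choices": [{"index": 0, "delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"index": 0, "delta": {"content": "lo"}}]}',
    b'data: [DONE]',
]
print(collect_stream_content(sample))  # Hello
```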
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/openai_api/requirements.txt:
--------------------------------------------------------------------------------
1 | openai>=1.3.0
2 | pydantic>=2.5.1
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/openai_api/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import gc
3 | import json
4 | import torch
5 | from torch.nn import Module
6 | from transformers import PreTrainedModel, PreTrainedTokenizer
7 | from transformers import AutoModel
8 | from transformers.generation.logits_process import LogitsProcessor
9 | from typing import Dict, Union, Optional, Tuple
10 |
11 |
12 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
13 |     # transformer.word_embeddings takes 1 layer
14 |     # transformer.final_layernorm and lm_head take 1 layer
15 |     # transformer.layers take 28 layers
16 |     # 30 layers in total are distributed across the num_gpus GPUs
17 | num_trans_layers = 28
18 | per_gpu_layers = 30 / num_gpus
19 |
20 |     # bugfix: on Linux, the weight and input passed to torch.embedding can end up on different devices, raising a RuntimeError
21 |     # On Windows, model.device is set to transformer.word_embeddings.device
22 |     # On Linux, model.device is set to lm_head.device
23 |     # When chat or stream_chat is called, input_ids are moved to model.device
24 |     # If transformer.word_embeddings.device differs from model.device, a RuntimeError follows
25 |     # so transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU
26 |     # This file is adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
27 |     # with only minor changes here to support ChatGLM3
28 | device_map = {
29 | 'transformer.embedding.word_embeddings': 0,
30 | 'transformer.encoder.final_layernorm': 0,
31 | 'transformer.output_layer': 0,
32 | 'transformer.rotary_pos_emb': 0,
33 | 'lm_head': 0
34 | }
35 |
36 | used = 2
37 | gpu_target = 0
38 | for i in range(num_trans_layers):
39 | if used >= per_gpu_layers:
40 | gpu_target += 1
41 | used = 0
42 | assert gpu_target < num_gpus
43 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target
44 | used += 1
45 |
46 | return device_map
47 |
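`auto_configure_device_map` fills GPUs greedily. Each card gets a budget of 30 / `num_gpus` layer slots, and GPU 0 starts with 2 slots already used by the embedding and output modules. The sketch below reproduces that scheme and counts how many transformer layers land on each card. `auto_device_map` is a hypothetical name, and it omits the other modules pinned to GPU 0.

```python
from collections import Counter
from typing import Dict


def auto_device_map(num_gpus: int, num_layers: int = 28) -> Dict[str, int]:
    # Budget per GPU: transformer layers + 2 slots for the embedding and output modules
    per_gpu = (num_layers + 2) / num_gpus
    device_map = {"transformer.embedding.word_embeddings": 0, "transformer.output_layer": 0}
    used, gpu = 2, 0  # GPU 0 already hosts the two extra "layers"
    for i in range(num_layers):
        if used >= per_gpu:
            gpu, used = gpu + 1, 0
        device_map[f"transformer.encoder.layers.{i}"] = gpu
        used += 1
    return device_map


for n in (2, 3, 4):
    layers = Counter(g for k, g in auto_device_map(n).items() if ".layers." in k)
    print(n, dict(sorted(layers.items())))
```

For 2, 3 and 4 GPUs this places 13/15, 8/10/10 and 6/8/8/6 transformer layers per card. GPU 0 always ends up with fewer layers because it also holds the embedding and output head.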
48 |
49 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
50 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
51 | if num_gpus < 2 and device_map is None:
52 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
53 | else:
54 | from accelerate import dispatch_model
55 |
56 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
57 |
58 | if device_map is None:
59 | device_map = auto_configure_device_map(num_gpus)
60 |
61 | model = dispatch_model(model, device_map=device_map)
62 |
63 | return model
64 |
65 |
66 | class InvalidScoreLogitsProcessor(LogitsProcessor):
67 | def __call__(
68 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor
69 | ) -> torch.FloatTensor:
70 | if torch.isnan(scores).any() or torch.isinf(scores).any():
71 | scores.zero_()
72 | scores[..., 5] = 5e4
73 | return scores
74 |
75 |
76 | def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
77 | content = ""
78 | for response in output.split("<|assistant|>"):
79 | metadata, content = response.split("\n", maxsplit=1)
80 | if not metadata.strip():
81 | content = content.strip()
82 | content = content.replace("[[训练时间]]", "2023年")
83 | else:
84 | if use_tool:
85 | content = "\n".join(content.split("\n")[1:-1])
86 |
87 | def tool_call(**kwargs):
88 | return kwargs
89 |
90 | parameters = eval(content)
91 | content = {
92 | "name": metadata.strip(),
93 | "arguments": json.dumps(parameters, ensure_ascii=False)
94 | }
95 | else:
96 | content = {
97 | "name": metadata.strip(),
98 | "content": content
99 | }
100 | return content
101 |
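`process_response` recovers tool arguments by calling `eval` on model output, with a local `tool_call` that simply returns its kwargs. Because that text is generated by the model, parsing it is a stricter option than executing it. Below is a sketch of that swap; it is not what the repo does. It uses the standard-library `ast.parse` and `ast.literal_eval` and accepts only literal keyword values. `parse_tool_call` is a hypothetical name.

```python
import ast
import json
from typing import Any, Dict


def parse_tool_call(code: str, expected_name: str = "tool_call") -> Dict[str, Any]:
    # Parse the generated call as an expression instead of executing it
    node = ast.parse(code.strip(), mode="eval").body
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Name)
            and node.func.id == expected_name):
        raise ValueError(f"expected a call to {expected_name}()")
    if node.args:
        raise ValueError("only keyword arguments are supported")
    params: Dict[str, Any] = {}
    for kw in node.keywords:
        if kw.arg is None:  # reject **kwargs unpacking
            raise ValueError("**kwargs is not supported")
        # literal_eval accepts only literals (strings, numbers, lists, dicts, ...), never calls
        params[kw.arg] = ast.literal_eval(kw.value)
    return params


arguments = json.dumps(parse_tool_call("tool_call(city='Beijing', days=3)"), ensure_ascii=False)
print(arguments)  # {"city": "Beijing", "days": 3}
```

Malformed code still raises `SyntaxError` from `ast.parse`. A caller would catch that exception alongside `ValueError`, in the same way `openai_api.py` already wraps `process_response` in `try`/`except`.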
102 |
103 | @torch.inference_mode()
104 | def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
105 |     messages = params["messages"]
106 |     functions = params["functions"]
107 |     temperature = float(params.get("temperature", 1.0))
108 |     repetition_penalty = float(params.get("repetition_penalty", 1.0))
109 |     top_p = float(params.get("top_p", 1.0))
110 |     max_new_tokens = int(params.get("max_tokens", 256))
111 |     echo = params.get("echo", True)
112 |     messages = process_chatglm_messages(messages, functions=functions)
113 |     query, role = messages[-1]["content"], messages[-1]["role"]
114 |
115 |     inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
116 |     inputs = inputs.to(model.device)
117 |     input_echo_len = len(inputs["input_ids"][0])
118 |
119 |     if input_echo_len >= model.config.seq_length:
120 |         print(f"Input length larger than {model.config.seq_length}")
121 |
122 |     eos_token_id = [
123 |         tokenizer.eos_token_id,
124 |         tokenizer.get_command("<|user|>"),
125 |     ]
126 |
127 |     gen_kwargs = {
128 |         "max_new_tokens": max_new_tokens,
129 |         "do_sample": True if temperature > 1e-5 else False,
130 |         "top_p": top_p,
131 |         "repetition_penalty": repetition_penalty,
132 |         "logits_processor": [InvalidScoreLogitsProcessor()],
133 |     }
134 |     if temperature > 1e-5:
135 |         gen_kwargs["temperature"] = temperature
136 |
137 |     total_len = 0
138 |     for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
139 |         total_ids = total_ids.tolist()[0]
140 |         total_len = len(total_ids)
141 |         if echo:
142 |             output_ids = total_ids[:-1]
143 |         else:
144 |             output_ids = total_ids[input_echo_len:-1]
145 |
146 |         response = tokenizer.decode(output_ids)
147 |         if response and response[-1] != "�":
148 |             response, stop_found = apply_stopping_strings(response, ["<|observation|>"])
149 |
150 |             yield {
151 |                 "text": response,
152 |                 "usage": {
153 |                     "prompt_tokens": input_echo_len,
154 |                     "completion_tokens": total_len - input_echo_len,
155 |                     "total_tokens": total_len,
156 |                 },
157 |                 "finish_reason": "function_call" if stop_found else None,
158 |             }
159 |
160 |             if stop_found:
161 |                 break
162 |
163 |     # Only the last stream result carries a finish_reason, which is set to "stop"
164 |     ret = {
165 |         "text": response,
166 |         "usage": {
167 |             "prompt_tokens": input_echo_len,
168 |             "completion_tokens": total_len - input_echo_len,
169 |             "total_tokens": total_len,
170 |         },
171 |         "finish_reason": "stop",
172 |     }
173 |     yield ret
174 |
175 |     gc.collect()
176 |     torch.cuda.empty_cache()
177 |
178 |
179 | def process_chatglm_messages(messages, functions=None):
180 |     _messages = messages
181 |     messages = []
182 |
183 |     if functions:
184 |         messages.append(
185 |             {
186 |                 "role": "system",
187 |                 "content": "Answer the following questions as best as you can. You have access to the following tools:",
188 |                 "tools": functions
189 |             }
190 |         )
191 |
192 |     for m in _messages:
193 |         role, content, func_call = m.role, m.content, m.function_call
194 |         if role == "function":
195 |             messages.append(
196 |                 {
197 |                     "role": "observation",
198 |                     "content": content
199 |                 }
200 |             )
201 |
202 |         elif role == "assistant" and func_call is not None:
203 |             for response in content.split("<|assistant|>"):
204 |                 metadata, sub_content = response.split("\n", maxsplit=1)
205 |                 messages.append(
206 |                     {
207 |                         "role": role,
208 |                         "metadata": metadata,
209 |                         "content": sub_content.strip()
210 |                     }
211 |                 )
212 |         else:
213 |             messages.append({"role": role, "content": content})
214 |     return messages
215 |
216 |
217 | def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
218 |     for response in generate_stream_chatglm3(model, tokenizer, params):
219 |         pass
220 |     return response
221 |
222 |
223 | def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
224 |     stop_found = False
225 |     for string in stop_strings:
226 |         idx = reply.find(string)
227 |         if idx != -1:
228 |             reply = reply[:idx]
229 |             stop_found = True
230 |             break
231 |
232 |     if not stop_found:
233 |         # If a partial stop string such as "\nYo" was generated just before "\nYou:" is completed, trim it
234 |         for string in stop_strings:
235 |             for j in range(len(string) - 1, 0, -1):
236 |                 if reply[-j:] == string[:j]:
237 |                     reply = reply[:-j]
238 |                     break
239 |             else:
240 |                 continue
241 |
242 |             break
243 |
244 |     return reply, stop_found
245 |
--------------------------------------------------------------------------------
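`process_response` above recovers tool-call arguments by running `eval` on model output of the form `tool_call(key=value, ...)`, which would execute arbitrary Python if the model ever emitted it. A minimal sketch of a safer alternative, assuming the tool call is always a single `tool_call(...)` expression with literal keyword arguments, parses the text with Python's `ast` module instead of executing it. The function name `parse_tool_call` is hypothetical and not part of ChatGLM3 or this repository.

```python
import ast
import json
from typing import Any, Dict


def parse_tool_call(code: str) -> Dict[str, Any]:
    """Parse `tool_call(k=v, ...)` without executing it.

    Only keyword arguments whose values are Python literals
    (str, int, float, bool, None, list, dict, tuple) are accepted.
    """
    tree = ast.parse(code.strip(), mode="eval")
    call = tree.body
    if not (isinstance(call, ast.Call)
            and isinstance(call.func, ast.Name)
            and call.func.id == "tool_call"):
        raise ValueError("expected a single tool_call(...) expression")
    if call.args:
        raise ValueError("tool_call only takes keyword arguments")
    params: Dict[str, Any] = {}
    for kw in call.keywords:
        if kw.arg is None:  # reject **kwargs unpacking
            raise ValueError("**kwargs unpacking is not allowed")
        # literal_eval raises ValueError for anything that is not a plain literal
        params[kw.arg] = ast.literal_eval(kw.value)
    return params


if __name__ == "__main__":
    args = parse_tool_call('tool_call(city="北京", days=3)')
    print(json.dumps(args, ensure_ascii=False))  # {"city": "北京", "days": 3}
```

The result can be passed to `json.dumps(..., ensure_ascii=False)` exactly where `parameters` is used today; an expression such as `tool_call(x=__import__("os").getcwd())` is rejected with `ValueError` instead of being run.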
/Ziwei-Chatglm3-6B/requirements.txt:
--------------------------------------------------------------------------------
1 | protobuf
2 | transformers>=4.30.2
3 | cpm_kernels
4 | torch>=2.0
5 | gradio~=3.39
6 | sentencepiece
7 | accelerate
8 | sse-starlette
9 | streamlit>=1.27.0
10 | fastapi>=0.95.1
11 | uvicorn~=0.24.0
12 | loguru~=0.7.2
13 | chromadb
14 | sentence-transformers
15 | pypdf
16 | openai~=1.3.4
17 | langchain~=0.0.352
18 | astunparse~=1.6.3
19 | pydantic~=2.5.1
20 | requests~=2.31.0
--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/useChroma/build_chroma.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from langchain.document_loaders import DirectoryLoader, PyPDFLoader
4 | from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
5 | from langchain.vectorstores import Chroma
6 | from langchain.text_splitter import RecursiveCharacterTextSplitter
7 | import openai
8 |
9 | sys.path.append('../..')
10 | openai.api_key = os.environ.get("OPENAI_API_KEY")  # read from the environment; never commit API keys
11 | # Folder containing the PDF versions of the books to import
12 | path = "/home/admin1/桌面/data_pdf"
13 | embedding_function = SentenceTransformerEmbeddings(model_name="shibing624/text2vec-base-chinese")
14 |
15 |
16 | def build_chromadb():
17 |     if path:
18 |         loader = DirectoryLoader(path, glob="**/*.pdf", loader_cls=PyPDFLoader)
19 |         pages = loader.load()
20 |         # print(len(pages))
21 |         # print(pages[0].page_content[0:500])
22 |         text_splitter = RecursiveCharacterTextSplitter(
23 |             chunk_size=200,
24 |             chunk_overlap=10
25 |         )
26 |         splits = text_splitter.split_documents(pages)
27 |         chroma = Chroma.from_documents(splits, embedding_function, persist_directory="/home/admin1/桌面/chromaDb")
28 |         chroma.persist()
29 |
30 |
31 | if __name__ == '__main__':
32 |     build_chromadb()
33 |     db = Chroma(persist_directory="/home/admin1/桌面/chromaDb", embedding_function=embedding_function)
34 |     docs = db.similarity_search("火命", k=1)
35 |     print(docs)
36 |
--------------------------------------------------------------------------------
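`build_chromadb` splits the PDF pages with `chunk_size=200` and `chunk_overlap=10`. As a rough intuition only, the simplified fixed-window splitter below shows what those two numbers mean: each chunk starts `chunk_size - chunk_overlap = 190` characters after the previous one, so neighbouring chunks share 10 characters. This is not LangChain's algorithm: `RecursiveCharacterTextSplitter` first tries to break on separators such as paragraph breaks, newlines and spaces, so its real chunks are usually shorter and aligned to those boundaries. The function name `fixed_window_chunks` is hypothetical.

```python
from typing import List


def fixed_window_chunks(text: str, chunk_size: int = 200, chunk_overlap: int = 10) -> List[str]:
    """Split text into windows of at most chunk_size characters.

    Consecutive windows overlap by chunk_overlap characters
    (simplified illustration; no separator handling).
    """
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("need 0 <= chunk_overlap < chunk_size")
    step = chunk_size - chunk_overlap  # 190 with the values used in build_chroma.py
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reaches the end of the text
    return chunks


if __name__ == "__main__":
    parts = fixed_window_chunks("紫" * 500)
    print([len(p) for p in parts])  # [200, 200, 120]
```

The small overlap keeps a sentence that straddles a chunk boundary partly visible in both chunks, which helps `similarity_search` return a usable passage.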
/Ziwei-Chatglm3-6B/with_prompt/prompt_web.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import torch
3 | import copy
4 | from transformers import AutoModel, AutoTokenizer
5 |
6 | in_response = []  # intermediate answers collected by prompt_main()
7 |
8 | # Set the page title, icon and layout
9 | st.set_page_config(
10 |     page_title="ChatGLM3-Ziwei 演示",
11 |     page_icon=":robot:",
12 |     layout="wide"
13 | )
14 |
15 | # Set to a model ID or a local folder path
16 | model_path = "/home/admin1/桌面/chatglm3-6b"
17 |
18 | @st.cache_resource
19 | def get_model():
20 |     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
21 |     model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
22 |     # Multi-GPU support: replace the line above with the two lines below and set num_gpus to your actual GPU count
23 |     # from utils import load_model_on_gpus
24 |     # model = load_model_on_gpus(model_path, num_gpus=2)
25 |     model = model.eval()
26 |     return tokenizer, model
27 |
28 | # Load the ChatGLM3 model and tokenizer
29 | tokenizer, model = get_model()
30 |
31 | # Initialize the chat history and past key values
32 | if "history" not in st.session_state:
33 |     st.session_state.history = []
34 | if "past_key_values" not in st.session_state:
35 |     st.session_state.past_key_values = None
36 |
37 | # Set max_length, top_p and temperature
38 | max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
39 | top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
40 | temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)
41 |
42 | # Clear the conversation history
43 | buttonClean = st.sidebar.button("清理会话历史", key="clean")
44 | if buttonClean:
45 |     st.session_state.history = []
46 |     st.session_state.past_key_values = None
47 |     if torch.cuda.is_available():
48 |         torch.cuda.empty_cache()
49 |     st.rerun()
50 |
51 | # Render the chat history
52 | for i, message in enumerate(st.session_state.history):
53 |     if message["role"] == "user":
54 |         with st.chat_message(name="user", avatar="user"):
55 |             st.markdown(message["content"])
56 |     else:
57 |         with st.chat_message(name="assistant", avatar="assistant"):
58 |             st.markdown(message["content"])
59 |
60 | # Input and output boxes
61 | with st.chat_message(name="user", avatar="user"):
62 |     input_placeholder = st.empty()
63 | with st.chat_message(name="assistant", avatar="assistant"):
64 |     message_placeholder = st.empty()
65 |
66 | # Prompt helper: replace the first `i` occurrences of placeholder `a` with `b`
67 | def template_change(template, a, b, i):
68 |     template1 = template.replace(a, b, i)
69 |     return template1
70 |
71 | def get_classify(question, history, past_key_values):
72 |     template_string = """
73 | 请判断下列问题属于占卜的哪一种分类或主题
74 | 注意:你只需要输出你判断的主题分类,即输出一个或几个词语,而不是一段话。
75 | ###问题:{question}
76 | """
77 |     # Fill in the variables
78 |     prompt = template_change(template_string, '{question}', question, 1)
79 |     # print('prompt 1: ' + prompt)
80 |     i = 0
81 |     for classify, history, past_key_values in model.stream_chat(
82 |             tokenizer,
83 |             prompt,
84 |             history,
85 |             past_key_values=past_key_values,
86 |             max_length=max_length,
87 |             top_p=top_p,
88 |             temperature=temperature,
89 |             return_past_key_values=True,
90 |     ):
91 |         i += 1
92 |     print("1: " + classify)
93 |
94 |     return classify
95 |
96 |
97 | # Generate one of several candidate answers
98 | def prompt_main(question, history, past_key_values, theme, num):
99 |     global in_response
100 |     # Define the template string
101 |     #### Theme: {theme}
102 |     template_string = """
103 | 你现在是一位占卜师,你需要根据下面的问题与我对话,回答需要解释答案。
104 | ###问题: {question}
105 | 对话主题:{theme}
106 | """
107 |     # If answering the question requires information from the asker, do not fabricate it; ask for it instead.
108 |     # Create a prompt from the template string
109 |     # Fill in the variables, theme=new_question
110 |     prompt0 = template_change(template_string, '{question}', question, 1)
111 |     prompt = template_change(prompt0, '{theme}', theme, 1)
112 |
113 |     i = 0
114 |     for inresponse, history, past_key_values in model.stream_chat(
115 |             tokenizer,
116 |             prompt,
117 |             history,
118 |             past_key_values=past_key_values,
119 |             max_length=max_length,
120 |             top_p=top_p,
121 |             temperature=temperature,
122 |             return_past_key_values=True,
123 |     ):
124 |         i += 1
125 |     in_response.append(inresponse)  # store only the final, complete answer of this call
126 |     print("2:" + in_response[num])
127 |
128 | # Merge the candidate answers into one reply
129 | def prompt_merge(num, question, history, past_key_values):
130 |     global in_response
131 |     reply = ''
132 |     for i in range(num):
133 |         reply += '\n第' + str(i + 1) + '段文字: '
134 |         reply += in_response[i]
135 |     # print('reply : ' + reply)
136 |     # Define the template string
137 |     template_string = """
138 | 请把下面{num}段文字改写合并为一段流畅的文字
139 | 回答时开头不要出现【改写后的文字如下:】
140 | 整合后的文字不能重复出现相似的内容
141 | 整合后的文字应该尽量包含{num}段文字里不同的内容
142 | ###{reply}
143 | """
144 |     # Fill in the variables
145 |     prompt0 = template_change(template_string, '{num}', str(num), 2)
146 |     prompt = template_change(prompt0, '{reply}', reply, 1)
147 |     # print('prompt 3: ' + prompt)
148 |     i = 0
149 |     for out_response, history, past_key_values in model.stream_chat(
150 |             tokenizer,
151 |             prompt,
152 |             history,
153 |             past_key_values=past_key_values,
154 |             max_length=max_length,
155 |             top_p=top_p,
156 |             temperature=temperature,
157 |             return_past_key_values=True,
158 |     ):
159 |         message_placeholder.markdown(out_response)
160 |     print("3:" + out_response)
161 |     # for i in range(len(history)):
162 |     #     print("history ", history[i])
163 |     return history[-1]
164 |
165 | # Read the user's input
166 | prompt_text = st.chat_input("请输入您的问题")
167 |
168 | # If the user entered something, generate a reply
169 | if prompt_text:
170 |     input_placeholder.markdown(prompt_text)
171 |     history = st.session_state.history
172 |     past_key_values = st.session_state.past_key_values
173 |     history1 = copy.deepcopy(st.session_state.history)
174 |
175 |     num = 4
176 |     # Classify the question, generate `num` candidate answers, then merge them.
177 |     # Each call gets a fresh copy of the history because stream_chat() appends to it in place.
178 |     theme = get_classify(prompt_text, history, past_key_values)
179 |     history = copy.deepcopy(history1)
180 |     for i in range(num):
181 |         prompt_main(prompt_text, history, past_key_values, theme, i)
182 |         history = copy.deepcopy(history1)
183 |     history.append(prompt_merge(num, prompt_text, history, past_key_values))
184 |     # stream_chat() recorded the merge prompt as the user turn; store the original question instead
185 |     history[-2] = {"role": "user", "content": prompt_text}
186 |     # The KV cache does not contain this turn, so reset it; the next turn is then rebuilt
187 |     # from `history`, which holds the merged answer the user actually saw
188 |     past_key_values = None
189 |
190 |     # Simpler single-pass alternative without classification and merging:
191 |     # for response, history, past_key_values in model.stream_chat(
192 |     #         tokenizer,
193 |     #         prompt_text,
194 |     #         history,
195 |     #         past_key_values=past_key_values,
196 |     #         max_length=max_length,
197 |     #         top_p=top_p,
198 |     #         temperature=temperature,
199 |     #         return_past_key_values=True,
200 |     # ):
201 |     #     message_placeholder.markdown(response)
202 |
203 |     # Update the history and past key values
204 |     st.session_state.history = history
205 |     st.session_state.past_key_values = past_key_values
206 | # Run with: streamlit run prompt_web.py
207 | # Example question: 请你给我算算今年运势 ("Please tell my fortune for this year")
--------------------------------------------------------------------------------
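`prompt_web.py` answers each question in three stages: `get_classify` asks the model for a topic, `prompt_main` is called `num` times to collect candidate answers, and `prompt_merge` asks the model to rewrite the candidates into one reply. The sketch below reproduces only that control flow, with a stand-in `generate` callable in place of `model.stream_chat`, so it runs without a GPU. The prompts are abbreviated English stand-ins for the Chinese templates, and `fill`, `answer_with_merge` and `fake_generate` are hypothetical names.

```python
from typing import Callable, List


def fill(template: str, placeholder: str, value: str, count: int = -1) -> str:
    # Same idea as template_change(): str.replace with an occurrence limit (-1 = all)
    return template.replace(placeholder, value, count)


def answer_with_merge(question: str, generate: Callable[[str], str], num: int = 4) -> str:
    # Stage 1: classify the question into a fortune-telling topic
    theme = generate(fill("Classify this question: {question}", "{question}", question))
    # Stage 2: collect `num` independent candidate answers for that topic
    candidates: List[str] = []
    for _ in range(num):
        prompt = fill(fill("Question: {question}\nTopic: {theme}", "{question}", question), "{theme}", theme)
        candidates.append(generate(prompt))  # keep only the final text of each answer
    # Stage 3: ask the model to merge the numbered candidates into one reply
    reply = "".join(f"\nPassage {i + 1}: {c}" for i, c in enumerate(candidates))
    merge_prompt = fill("Merge these {num} passages into one:{reply}", "{num}", str(num))
    merge_prompt = fill(merge_prompt, "{reply}", reply)
    return generate(merge_prompt)


if __name__ == "__main__":
    calls: List[str] = []

    def fake_generate(prompt: str) -> str:
        # Stand-in for the model: record the prompt and return a numbered answer
        calls.append(prompt)
        return f"answer{len(calls)}"

    final = answer_with_merge("How is my luck this year?", fake_generate, num=2)
    print(len(calls), final)  # 4 answer4
```

One question therefore costs `num + 2` full generations. The real script additionally re-copies `history` before every call because, in ChatGLM3's remote code, `stream_chat` appends the query to the history list it is given.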
/Ziwei-Chatglm3-6B/with_prompt/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Dict, Union, Optional
3 | from torch.nn import Module
4 | from transformers import AutoModel
5 |
6 |
7 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
8 |     # transformer.word_embeddings counts as 1 layer
9 |     # transformer.final_layernorm and lm_head together count as 1 layer
10 |     # transformer.layers account for 28 layers
11 |     # In total, 30 layers are distributed across num_gpus GPUs
12 |     num_trans_layers = 28
13 |     per_gpu_layers = 30 / num_gpus
14 |
15 |     # bugfix: on Linux, the weight and input passed to torch.embedding can end up on different devices, raising a RuntimeError
16 |     # On Windows, model.device is set to transformer.word_embeddings.device
17 |     # On Linux, model.device is set to lm_head.device
18 |     # When chat or stream_chat is called, input_ids are moved to model.device
19 |     # If transformer.word_embeddings.device differs from model.device, a RuntimeError is raised
20 |     # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU
21 |     # This file is adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
22 |     # with only minor changes here to support ChatGLM3
23 |     device_map = {
24 |         'transformer.embedding.word_embeddings': 0,
25 |         'transformer.encoder.final_layernorm': 0,
26 |         'transformer.output_layer': 0,
27 |         'transformer.rotary_pos_emb': 0,
28 |         'lm_head': 0
29 |     }
30 |
31 |     used = 2
32 |     gpu_target = 0
33 |     for i in range(num_trans_layers):
34 |         if used >= per_gpu_layers:
35 |             gpu_target += 1
36 |             used = 0
37 |         assert gpu_target < num_gpus
38 |         device_map[f'transformer.encoder.layers.{i}'] = gpu_target
39 |         used += 1
40 |
41 |     return device_map
42 |
43 |
44 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
45 |                        device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
46 |     if num_gpus < 2 and device_map is None:
47 |         model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
48 |     else:
49 |         from accelerate import dispatch_model
50 |
51 |         model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()
52 |
53 |         if device_map is None:
54 |             device_map = auto_configure_device_map(num_gpus)
55 |
56 |         model = dispatch_model(model, device_map=device_map)
57 |
58 |     return model
--------------------------------------------------------------------------------
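`auto_configure_device_map` treats the model as 30 slots (28 transformer blocks plus 2 slots for the embedding and the output modules, both pinned to GPU 0) and fills each GPU up to `30 / num_gpus` slots. The helper below, a hypothetical name not part of the repository, repeats the same loop but only counts how many transformer blocks land on each GPU, so the split can be checked without loading a model.

```python
from collections import Counter
from typing import Dict


def layers_per_gpu(num_gpus: int, num_trans_layers: int = 28, total_slots: int = 30) -> Dict[int, int]:
    # Mirrors the loop in auto_configure_device_map, counting blocks instead of naming them
    per_gpu_layers = total_slots / num_gpus
    used = 2  # embedding + final layernorm/output head already occupy GPU 0
    gpu_target = 0
    placement: Counter = Counter()
    for _ in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        placement[gpu_target] += 1
        used += 1
    return dict(placement)


if __name__ == "__main__":
    print(layers_per_gpu(2))  # {0: 13, 1: 15}
    print(layers_per_gpu(4))  # {0: 6, 1: 8, 2: 8, 3: 6}
```

GPU 0 receives two fewer transformer blocks than an even split because it also hosts the embedding and output modules, which is the point of starting `used` at 2.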