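├── README.md
└── Ziwei-Chatglm3-6B
    ├── .github
    │   ├── ISSUE_TEMPLATE
    │   │   ├── bug_report.yaml
    │   │   └── feature-request.yaml
    │   └── PULL_REQUEST_TEMPLATE
    │       └── pr_template.md
    ├── .gitignore
    ├── MODEL_LICENSE
    ├── composite_demo
    │   ├── .streamlit
    │   │   └── config.toml
    │   ├── assets
    │   │   ├── demo.png
    │   │   ├── emojis.png
    │   │   ├── heart.png
    │   │   └── tool.png
    │   ├── client.py
    │   ├── conversation.py
    │   ├── demo_chat.py
    │   ├── main.py
    │   ├── requirements.txt
    │   └── tool_registry.py
    ├── data
    │   └── output.json
    ├── finetune_demo
    │   ├── =1.3.0
    │   ├── README.md
    │   ├── arguments.py
    │   ├── configs
    │   │   └── deepspeed.json
    │   ├── finetune.py
    │   ├── inference.py
    │   ├── preprocess_utils.py
    │   ├── scripts
    │   │   ├── finetune_ds.sh
    │   │   ├── finetune_ds_multiturn.sh
    │   │   ├── finetune_pt.sh
    │   │   ├── finetune_pt_multiturn.sh
    │   │   ├── format_advertise_gen.py
    │   │   └── format_tool_alpaca.py
    │   └── trainer.py
    ├── openai_api
    │   ├── openai_api.py
    │   ├── openai_api_request.py
    │   ├── requirements.txt
    │   └── utils.py
    ├── requirements.txt
    ├── useChroma
    │   └── build_chroma.py
    └── with_prompt
        ├── prompt_web.py
        └── utils.py

/README.md:
--------------------------------------------------------------------------------
Ziwei Fortune-Telling Large Language Model
====
Ziwei model series
-------
* Ziwei-llama-7B: built on Chinese-LLaMA-Alpaca-7B through LoRA fine-tuning plus pre-training on a small text corpus. Compared with LLaMA, this base model has an expanded Chinese vocabulary and stronger Chinese-language ability. After training, its performance in the fortune-telling domain improves markedly.

* Ziwei-chatglm3-6B: built on ChatGLM3-6B through P-tuning. At the time of writing, this base model has the best Chinese-language performance among models under 10B parameters, and fine-tuning also brings a marked improvement in the fortune-telling domain.

Introduction
-------
This project aims to develop a fortune-telling large language model that uses Chinese AI models to answer questions and gives users an accurate and entertaining fortune-telling experience.


We will develop with two different models, one based on Chinese-LLaMA-Alpaca and one based on ChatGLM3, and compare them to understand how they differ in fortune-telling services. For data collection, we will use web crawlers to gather relevant question-answer pairs from the web, draw on books, blogs and other sources for more fortune-telling content, and use the ChatGPT 3.5 API to generate additional question-answer pairs that expand the dataset.
We will fine-tune the Chinese-LLaMA model with LoRA (Low-Rank Adaptation) to improve its performance in fortune-telling services. LoRA freezes the pretrained weights and trains small low-rank update matrices injected into the model's layers, which adapts the model to a specific domain at a fraction of the cost of full fine-tuning.
The ChatGLM3 model will be fine-tuned with P-Tuning v2. P-Tuning v2 is a prompt-tuning method that keeps the base model frozen and trains continuous prefix parameters added at every layer, so the model can be fine-tuned at low computational cost while improving its performance on the target task.

Using chain-of-thought prompting, LangChain and a vector database, we will build rich prompts that elicit more creative and vivid answers and make the fortune-telling experience more engaging. The vector database will also give us more accurate query matching and information retrieval.
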
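The retrieval-based prompt construction described above can be illustrated with a minimal sketch. It is not the project's implementation: the repository's own vector-store code appears to live in `useChroma/build_chroma.py`, which is not reproduced here. The character-bigram `embed` function is only a stand-in for a real embedding model, and `embed`, `cosine`, `retrieve`, `build_prompt` and the sample passages are hypothetical names chosen for illustration.

```python
import math
from collections import Counter


def embed(text: str) -> Counter:
    """Toy embedding (assumption): character-bigram counts instead of a real model."""
    text = text.lower()
    return Counter(text[i:i + 2] for i in range(len(text) - 1))


def cosine(a: Counter, b: Counter) -> float:
    """Cosine similarity between two sparse count vectors."""
    dot = sum(a[k] * b[k] for k in a.keys() & b.keys())
    norm_a = math.sqrt(sum(v * v for v in a.values()))
    norm_b = math.sqrt(sum(v * v for v in b.values()))
    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0


def retrieve(query: str, passages: list[str], top_k: int = 1) -> list[str]:
    """Return the top_k passages most similar to the query."""
    q = embed(query)
    ranked = sorted(passages, key=lambda p: cosine(q, embed(p)), reverse=True)
    return ranked[:top_k]


def build_prompt(query: str, passages: list[str], top_k: int = 1) -> str:
    """Combine retrieved reference text with a step-by-step instruction."""
    context = "\n".join(f"- {p}" for p in retrieve(query, passages, top_k))
    return (
        "You are a fortune-teller. Use the reference material below.\n"
        f"Reference:\n{context}\n"
        f"Question: {query}\n"
        "Reason step by step before giving the final reading."
    )


# Hypothetical knowledge snippets; a real store would hold many more.
PASSAGES = [
    "The Zhen hexagram represents action and challenge.",
    "Sagittarius is associated with optimism and exploration.",
]

if __name__ == "__main__":
    print(build_prompt("What does the Zhen hexagram mean?", PASSAGES))
```

In a real deployment the toy `embed` function would be replaced by an embedding model and the linear scan by a vector database query; the prompt-assembly step stays the same.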
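After fine-tuning both models, we will compare and evaluate their performance. Using metrics such as accuracy, answer completeness and degree of personalization, we will compare how the two models perform in fortune-telling services, along with their suitable scenarios, strengths and weaknesses.
Through this project we hope to offer users a new fortune-telling experience and to contribute more practice and exploration to the application of Chinese AI models.

Data
-------
Our dialogue data was generated by ChatGPT-3.5: 100,000 instruction fine-tuning samples in total, which were then cleaned and augmented. Samples are shown below (the dataset itself is in Chinese):

| Type | Data |
| ------ | ------ |
| llama data | "instruction": "Please give me a horoscope reading and tell me how my career will develop.", "input": "I am a Sagittarius and would like to know how my career will develop. Could you give me a horoscope reading?", "output": "Based on your sign, Sagittarius, the reading shows that your career will develop positively. You will meet many opportunities and challenges and will need your optimism and confidence to face them. Over time, you will gain recognition at work and chances for promotion. Remember to stay focused and work hard, while leaving time for personal interests and growth." |
| chatglm3 data | "prompt": "I drew a hexagram, and it is the Zhen hexagram. Please interpret it for me.", "response": "The Zhen hexagram represents action and challenge. It reminds you to keep a positive attitude when facing difficulties and obstacles, and to act boldly. Do not fear failure, because failure is part of success. The hexagram also reminds you to pay attention to your way of thinking and your decision-making, making sure they are sound and beneficial. Trust your abilities and press forward." |


The text data comes from e-books on fortune-telling and from crawling a number of fortune-telling websites.

Part of the data is available in the project's data folder.

Experimental results
-------
Example: for the question "My date of birth is October 5, 2002, 22:15. When will I have children?", we compare the answers before fine-tuning, after fine-tuning, and after adding prompt engineering.

| Model | Answer before fine-tuning | Answer after fine-tuning |
| ------ | ------ | ------ |
| Ziwei-Chatglm3-6B | As an AI assistant, I cannot predict an individual's fate, including when they will have children. Having children is a natural process influenced by many factors, including personal fertility, health and family circumstances. My advice is to face life positively, work toward your goals, and keep learning and growing. That way you can shape your own destiny. | As a fortune-teller, I will do a natal chart reading based on your date and time of birth to understand your current fertility situation and possible future opportunities to have children. First, your date of birth is October 5, 2002, 22:15, which is a Libra birth date and indicates that you are an introverted, deep and insightful person with a strong desire for and pursuit of life and love. In the chart reading, I find that your Sun is in Libra, which means you have good social skills and relationships and hold high expectations and standards for romance and marriage. Your Moon is in Sagittarius, which indicates a spirit of openness, freedom, adventure and exploration, and an open attitude toward love and family. According to your chart, your current fertility situation may involve some uncertainty and may need further examination and observation. Future opportunities to have children depend mainly on factors such as the state of your relationship with your partner and your health. I therefore suggest that you keep a positive and optimistic mindset, attend to the question of having children together with your partner, and seek help and advice from a medical professional when necessary. Finally, I want to remind you that, as a fortune-teller, my reading is only a reference and cannot fully determine your fate and future. Your fertility situation and opportunities also depend on your own efforts and choices. I hope you can stay positive, optimistic and confident in facing the challenges and opportunities ahead. |
| Ziwei-llama-7B | Sorry, as a fortune-teller, I cannot provide predictions or advice about individuals. However, you can understand your innermost thoughts and feelings through self-exploration, meditation or consulting a professional psychologist. | As a fortune-teller, I can tell you some relevant information, but I cannot determine your specific fate. According to astrological theory, a person's date, time and place of birth affect their natal chart and zodiac sign, and thereby their birth characters (bazi). However, these factors are not always deterministic, because every person's life course is unique and influenced by many other factors. As for your question, if you want to know when you will have children, you can use astrology to make a prediction. First, you need to understand your natal chart and zodiac sign, and then analyze them with astrological methods. In astrology, each sign has its own symbolism and meaning, and each natal chart has its own characteristics. By analyzing your chart and sign, you can therefore obtain information about when you will have children. Please note, however, that astrology is only a method of prediction and cannot guarantee 100% accuracy. It can only suggest possible situations and trends; the final outcome depends on your personal choices and actions. Moreover, because every person's life course is unique, even the same sign and chart can lead to different results. In short, if you want to know when you will have children, you can make a prediction through astrology. But please remember that this method is not one hundred percent accurate, and the final outcome depends on your own choices. Based on the information you provided, your chart shows that you will have the opportunity to have children within the next ten years. The specific timing, however, also depends on other factors such as personal wishes and health. I suggest that you maintain good health and a balanced life and watch for the right time; if it suits you, you can try to conceive. |

Usage
-------
>>Model download

* Ziwei-Chatglm3-6B:

Download the fine-tuned model weights from [Baidu Netdisk](https://pan.baidu.com/s/1KZzW8x8gBFM7I_PnbJBMcA?pwd=5qqd "Model weights").
Extraction code: 5qqd

>>Usage notes

1. Download both the fine-tuned weights from the netdisk and the original chatglm3-6B weights; the two together make up the complete model.

2. For inference, you can choose either the fine-tuned model or the model with added prompts.

3. You can continue to collect data and fine-tune with the provided scripts.


Project participants
-------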
This project was completed by the Network and Information Security Research Center at Harbin Institute of Technology (Weihai), under the supervision of Xu Yongdong (徐永东).

Contributors are listed below, in no particular order.

Research and development: Tong Zihe (佟梓赫), Qin Chenyang (秦晨阳), Liu Zongxin (刘宗鑫), Sha Kun (沙昆)

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
name: "\U0001F41B Bug Report"
description: Submit a bug report to help us improve ChatGLM3 / 提交一个 Bug 问题报告来帮助我们改进 ChatGLM3
body:
  - type: textarea
    id: system-info
    attributes:
      label: System Info / 系統信息
      description: Your operating environment / 您的运行环境信息
      placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
    validations:
      required: true

  - type: textarea
    id: who-can-help
    attributes:
      label: Who can help? / 谁可以帮助到您?
      description: |
        Your issue will be replied to more quickly if you can figure out the right person to tag with @
        All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.

        Please tag fewer than 3 people.

        如果您能找到合适的标签 @,您的问题会更快得到回复。
        所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。

        标记的人数应该不超过 3 个人。

        Related demo leader / 相关demo负责人 :
        - finetune_demo: @Btlmd
        - langchain_demo: @yincf
        - composite_demo: @abmfy

        If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem.

        如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。

      placeholder: "@Username ..."
38 | 39 | - type: checkboxes 40 | id: information-scripts-examples 41 | attributes: 42 | label: Information / 问题信息 43 | description: 'The problem arises when using: / 问题出现在' 44 | options: 45 | - label: "The official example scripts / 官方的示例脚本" 46 | - label: "My own modified scripts / 我自己修改的脚本和任务" 47 | 48 | - type: textarea 49 | id: reproduction 50 | validations: 51 | required: true 52 | attributes: 53 | label: Reproduction / 复现过程 54 | description: | 55 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 56 | If you have code snippets, error messages, stack traces, please provide them here as well. 57 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 58 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 59 | 60 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 61 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 62 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 63 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 64 | placeholder: | 65 | Steps to reproduce the behavior/复现Bug的步骤: 66 | 67 | 1. 68 | 2. 69 | 3. 70 | 71 | - type: textarea 72 | id: expected-behavior 73 | validations: 74 | required: true 75 | attributes: 76 | label: Expected behavior / 期待表现 77 | description: "A clear and concise description of what you would expect to happen. 
/简单描述您期望发生的事情。" -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new ChatGLM3 feature / 提交一个新的 ChatGLM3 的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 34 | 您的PR链接或者其他您能提供帮助的链接。 -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/.github/PULL_REQUEST_TEMPLATE/pr_template.md: -------------------------------------------------------------------------------- 1 | # Raise valuable PR / 提出有价值的PR 2 | 3 | ## Caution/ 注意事项: 4 | Users should keep the following points in mind when submitting PRs: 5 | 6 | 1. The proposed PR should be about this project. 7 | 2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs. 8 | 9 | 用户在提交PR时候应该注意以下几点: 10 | 11 | 1. 提出的PR应该是关于本项目的。 12 | 2. 
提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。 13 | 14 | ## 不应该提出的PR / PRs that should not be proposed 15 | 16 | If a developer proposes a PR about any of the following, it may be closed or Rejected. 17 | 18 | 1. those that don't describe improvement options. 19 | 2. multiple issues of different types combined in one PR. 20 | 3. The proposed PR is highly duplicative of already existing PRs. 21 | 22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。 23 | 24 | 1. 没有说明改进方案的。 25 | 2. 多个不同类型的问题合并在一个PR中的。 26 | 3. 提出的PR与已经存在的PR高度重复的。 27 | 28 | 29 | # 检查您的PR 30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分? 31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。 32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。 33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试? 34 | - [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题 -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | 3 | # finetune_demo: generated & downloaded files 4 | finetune_demo/output 5 | finetune_demo/data 6 | finetune_demo/formatted_data 7 | ToolAlpaca/ 8 | AdvertiseGen/ 9 | *.gz 10 | *.idea -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The ChatGLM3-6B License 2 | 3 | 1. 定义 4 | 5 | “许可方”是指分发其软件的 ChatGLM3-6B 模型团队。 6 | 7 | “软件”是指根据本许可提供的 ChatGLM3-6B 模型参数。 8 | 9 | 2. 
许可授予 10 | 11 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 12 | 13 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 14 | 15 | 3.限制 16 | 17 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 18 | 19 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 20 | 21 | 4.免责声明 22 | 23 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 24 | 25 | 5. 责任限制 26 | 27 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 28 | 29 | 6.争议解决 30 | 31 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 32 | 33 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 34 | 35 | 1. Definitions 36 | 37 | “Licensor” means the ChatGLM3-6B Model Team that distributes its Software. 38 | 39 | “Software” means the ChatGLM3-6B model parameters made available under this license. 40 | 41 | 2. License Grant 42 | 43 | Subject to the terms and conditions of this License, the Licensor hereby grants to you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license to use the Software. 44 | 45 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 46 | 47 | 3. Restriction 48 | 49 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes. 50 | 51 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 52 | 53 | 4. Disclaimer 54 | 55 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 56 | 57 | 5. Limitation of Liability 58 | 59 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 60 | 61 | 6. Dispute Resolution 62 | 63 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 64 | 65 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn. 
66 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/.streamlit/config.toml: -------------------------------------------------------------------------------- 1 | [theme] 2 | font = "monospace" -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/assets/demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/demo.png -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/assets/emojis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/emojis.png -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/assets/heart.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/heart.png -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/assets/tool.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/myCSAI/Ziwei/a8fcbeabdce8a4417e4517b07bb799131b2cf5e0/Ziwei-Chatglm3-6B/composite_demo/assets/tool.png -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import os 4 | import streamlit as st 5 | import torch 6 | 7 | from 
collections.abc import Iterable 8 | from typing import Any, Protocol 9 | from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token 10 | from transformers import AutoModel, AutoTokenizer, AutoConfig 11 | from transformers.generation.logits_process import LogitsProcessor 12 | from transformers.generation.utils import LogitsProcessorList 13 | 14 | from conversation import Conversation 15 | 16 | TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:' 17 | 18 | MODEL_PATH = os.environ.get('MODEL_PATH', '/home/admin1/桌面/chatglm3-6b') 19 | PT_PATH = os.environ.get('PT_PATH', None) 20 | PRE_SEQ_LEN = int(os.environ.get("PRE_SEQ_LEN", 128)) 21 | TOKENIZER_PATH = os.environ.get("TOKENIZER_PATH", MODEL_PATH) 22 | 23 | 24 | @st.cache_resource 25 | def get_client() -> Client: 26 | client = HFClient(MODEL_PATH, TOKENIZER_PATH, PT_PATH) 27 | return client 28 | 29 | 30 | class Client(Protocol): 31 | def generate_stream(self, 32 | system: str | None, 33 | tools: list[dict] | None, 34 | history: list[Conversation], 35 | **parameters: Any 36 | ) -> Iterable[TextGenerationStreamResponse]: 37 | ... 
38 | 39 | 40 | def stream_chat( 41 | self, tokenizer, query: str, 42 | history: list[tuple[str, str]] = None, 43 | role: str = "user", 44 | past_key_values=None, 45 | max_new_tokens: int = 256, 46 | do_sample=True, top_p=0.8, 47 | temperature=0.8, 48 | repetition_penalty=1.0, 49 | length_penalty=1.0, num_beams=1, 50 | logits_processor=None, 51 | return_past_key_values=False, 52 | **kwargs 53 | ): 54 | class InvalidScoreLogitsProcessor(LogitsProcessor): 55 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 56 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 57 | scores.zero_() 58 | scores[..., 5] = 5e4 59 | return scores 60 | 61 | if history is None: 62 | history = [] 63 | 64 | print("\n== Input ==\n", query) 65 | print("\n==History==\n", history) 66 | 67 | if logits_processor is None: 68 | logits_processor = LogitsProcessorList() 69 | logits_processor.append(InvalidScoreLogitsProcessor()) 70 | eos_token_id = [tokenizer.eos_token_id, tokenizer.get_command("<|user|>"), 71 | tokenizer.get_command("<|observation|>")] 72 | gen_kwargs = {"max_new_tokens": max_new_tokens, 73 | "do_sample": do_sample, 74 | "top_p": top_p, 75 | "temperature": temperature, 76 | "logits_processor": logits_processor, 77 | "repetition_penalty": repetition_penalty, 78 | "length_penalty": length_penalty, 79 | "num_beams": num_beams, 80 | **kwargs 81 | } 82 | 83 | if past_key_values is None: 84 | inputs = tokenizer.build_chat_input(query, history=history, role=role) 85 | else: 86 | inputs = tokenizer.build_chat_input(query, role=role) 87 | inputs = inputs.to(self.device) 88 | if past_key_values is not None: 89 | past_length = past_key_values[0][0].shape[0] 90 | if self.transformer.pre_seq_len is not None: 91 | past_length -= self.transformer.pre_seq_len 92 | inputs.position_ids += past_length 93 | attention_mask = inputs.attention_mask 94 | attention_mask = torch.cat((attention_mask.new_ones(1, past_length), attention_mask), dim=1) 95 | 
inputs['attention_mask'] = attention_mask 96 | history.append({"role": role, "content": query}) 97 | input_sequence_length = inputs['input_ids'].shape[1] 98 | if input_sequence_length + max_new_tokens >= self.config.seq_length: 99 | yield "Current input sequence length {} plus max_new_tokens {} is too long. The maximum model sequence length is {}. You may adjust the generation parameter to enable longer chat history.".format( 100 | input_sequence_length, max_new_tokens, self.config.seq_length 101 | ), history 102 | return 103 | 104 | if input_sequence_length > self.config.seq_length: 105 | yield "Current input sequence length {} exceeds maximum model sequence length {}. Unable to generate tokens.".format( 106 | input_sequence_length, self.config.seq_length 107 | ), history 108 | return 109 | 110 | for outputs in self.stream_generate(**inputs, past_key_values=past_key_values, 111 | eos_token_id=eos_token_id, return_past_key_values=return_past_key_values, 112 | **gen_kwargs): 113 | if return_past_key_values: 114 | outputs, past_key_values = outputs 115 | outputs = outputs.tolist()[0][len(inputs["input_ids"][0]):] 116 | response = tokenizer.decode(outputs) 117 | if response and response[-1] != "�": 118 | new_history = history 119 | if return_past_key_values: 120 | yield response, new_history, past_key_values 121 | else: 122 | yield response, new_history 123 | 124 | 125 | class HFClient(Client): 126 | def __init__(self, model_path: str, tokenizer_path: str, pt_checkpoint: str = None): 127 | self.model_path = model_path 128 | self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True) 129 | 130 | if pt_checkpoint is not None and os.path.exists(pt_checkpoint): 131 | config = AutoConfig.from_pretrained( 132 | model_path, 133 | trust_remote_code=True, 134 | pre_seq_len=PRE_SEQ_LEN 135 | ) 136 | self.model = AutoModel.from_pretrained( 137 | model_path, 138 | trust_remote_code=True, 139 | config=config, 140 | device_map="auto" 141 | ).eval() 142 | 
            # Load the P-tuning prefix encoder weights saved by the fine-tuning run.
            prefix_state_dict = torch.load(os.path.join(pt_checkpoint, "pytorch_model.bin"))
            new_prefix_state_dict = {}
            for k, v in prefix_state_dict.items():
                # Keep only prefix-encoder entries and strip their module prefix.
                if k.startswith("transformer.prefix_encoder."):
                    new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
            print("Loaded from pt checkpoints", new_prefix_state_dict.keys())
            self.model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
        else:
            # Use the constructor argument rather than the module-level MODEL_PATH,
            # so that a caller-supplied path is respected.
            self.model = (
                AutoModel.from_pretrained(
                    model_path,
                    trust_remote_code=True,
                    device_map="auto"
                ).eval())
            # plus .quantized() if you want to use quantized model

    def generate_stream(
            self,
            system: str | None,
            tools: list[dict] | None,
            history: list[Conversation],
            **parameters: Any
    ) -> Iterable[TextGenerationStreamResponse]:
        chat_history = [{
            'role': 'system',
            'content': system if not tools else TOOL_PROMPT,
        }]

        if tools:
            chat_history[0]['tools'] = tools

        # Everything except the last turn becomes history; role tags such as
        # "<|user|>" are reduced to plain role names ("user").
        for conversation in history[:-1]:
            chat_history.append({
                'role': str(conversation.role).removeprefix('<|').removesuffix('|>'),
                'content': conversation.content,
            })

        query = history[-1].content
        role = str(history[-1].role).removeprefix('<|').removesuffix('|>')
        text = ''
        for new_text, _ in stream_chat(
                self.model,
                self.tokenizer,
                query,
                chat_history,
                role,
                **parameters,
        ):
            # stream_chat yields the full text so far; emit only the new suffix.
            word = new_text.removeprefix(text)
            word_stripped = word.strip()
            text = new_text
            yield TextGenerationStreamResponse(
                generated_text=text,
                token=Token(
                    id=0,
                    logprob=0,
                    text=word,
                    special=word_stripped.startswith('<|') and word_stripped.endswith('|>'),
                )
            )

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/conversation.py:
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from enum import auto, Enum 3 | import json 4 | 5 | from PIL.Image import Image 6 | import streamlit as st 7 | from streamlit.delta_generator import DeltaGenerator 8 | 9 | TOOL_PROMPT = 'Answer the following questions as best as you can. You have access to the following tools:\n' 10 | 11 | class Role(Enum): 12 | SYSTEM = auto() 13 | USER = auto() 14 | ASSISTANT = auto() 15 | TOOL = auto() 16 | INTERPRETER = auto() 17 | OBSERVATION = auto() 18 | 19 | def __str__(self): 20 | match self: 21 | case Role.SYSTEM: 22 | return "<|system|>" 23 | case Role.USER: 24 | return "<|user|>" 25 | case Role.ASSISTANT | Role.TOOL | Role.INTERPRETER: 26 | return "<|assistant|>" 27 | case Role.OBSERVATION: 28 | return "<|observation|>" 29 | 30 | # Get the message block for the given role 31 | def get_message(self): 32 | # Compare by value here, because the enum object in the session state 33 | # is not the same as the enum cases here, due to streamlit's rerunning 34 | # behavior. 
35 | match self.value: 36 | case Role.SYSTEM.value: 37 | return 38 | case Role.USER.value: 39 | return st.chat_message(name="user", avatar="user") 40 | case Role.ASSISTANT.value: 41 | return st.chat_message(name="assistant", avatar="assistant") 42 | case Role.TOOL.value: 43 | return st.chat_message(name="tool", avatar="assistant") 44 | case Role.INTERPRETER.value: 45 | return st.chat_message(name="interpreter", avatar="assistant") 46 | case Role.OBSERVATION.value: 47 | return st.chat_message(name="observation", avatar="user") 48 | case _: 49 | st.error(f'Unexpected role: {self}') 50 | 51 | @dataclass 52 | class Conversation: 53 | role: Role 54 | content: str 55 | tool: str | None = None 56 | image: Image | None = None 57 | 58 | def __str__(self) -> str: 59 | print(self.role, self.content, self.tool) 60 | match self.role: 61 | case Role.SYSTEM | Role.USER | Role.ASSISTANT | Role.OBSERVATION: 62 | return f'{self.role}\n{self.content}' 63 | case Role.TOOL: 64 | return f'{self.role}{self.tool}\n{self.content}' 65 | case Role.INTERPRETER: 66 | return f'{self.role}interpreter\n{self.content}' 67 | 68 | # Human readable format 69 | def get_text(self) -> str: 70 | text = postprocess_text(self.content) 71 | match self.role.value: 72 | case Role.TOOL.value: 73 | text = f'Calling tool `{self.tool}`:\n\n{text}' 74 | case Role.INTERPRETER.value: 75 | text = f'{text}' 76 | case Role.OBSERVATION.value: 77 | text = f'Observation:\n```\n{text}\n```' 78 | return text 79 | 80 | # Display as a markdown block 81 | def show(self, placeholder: DeltaGenerator | None=None) -> str: 82 | if placeholder: 83 | message = placeholder 84 | else: 85 | message = self.role.get_message() 86 | if self.image: 87 | message.image(self.image) 88 | else: 89 | text = self.get_text() 90 | message.markdown(text) 91 | 92 | def preprocess_text( 93 | system: str | None, 94 | tools: list[dict] | None, 95 | history: list[Conversation], 96 | ) -> str: 97 | if tools: 98 | tools = json.dumps(tools, indent=4, 
                           ensure_ascii=False)

    prompt = f"{Role.SYSTEM}\n"
    prompt += system if not tools else TOOL_PROMPT
    if tools:
        tools = json.loads(tools)
        prompt += json.dumps(tools, ensure_ascii=False)
    for conversation in history:
        prompt += f'{conversation}'
    prompt += f'{Role.ASSISTANT}\n'
    return prompt

def postprocess_text(text: str) -> str:
    # Convert LaTeX-style math delimiters to the $ / $$ form that streamlit renders.
    # Raw strings avoid invalid escape sequences such as "\(".
    text = text.replace(r"\(", "$")
    text = text.replace(r"\)", "$")
    text = text.replace(r"\[", "$$")
    text = text.replace(r"\]", "$$")
    # Remove role tags that may leak into the generated text.
    text = text.replace("<|assistant|>", "")
    text = text.replace("<|observation|>", "")
    text = text.replace("<|system|>", "")
    text = text.replace("<|user|>", "")
    return text.strip()

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/composite_demo/demo_chat.py:
--------------------------------------------------------------------------------
import streamlit as st
from streamlit.delta_generator import DeltaGenerator

from client import get_client
from conversation import postprocess_text, preprocess_text, Conversation, Role

client = get_client()


# Append a conversation to the history and show it in a new markdown block
def append_conversation(
        conversation: Conversation,
        history: list[Conversation],
        placeholder: DeltaGenerator | None = None,
) -> None:
    history.append(conversation)
    conversation.show(placeholder)


def main(
        prompt_text: str,
        system_prompt: str,
        top_p: float = 0.8,
        temperature: float = 0.95,
        repetition_penalty: float = 1.0,
        max_new_tokens: int = 1024,
        retry: bool = False
):
    placeholder = st.empty()
    with placeholder.container():
        if 'chat_history' not in st.session_state:
            st.session_state.chat_history = []

    # An empty prompt without a retry request means "clear the history".
    if prompt_text == "" and not retry:
        print("\n== Clean ==\n")
st.session_state.chat_history = [] 37 | return 38 | 39 | history: list[Conversation] = st.session_state.chat_history 40 | for conversation in history: 41 | conversation.show() 42 | 43 | if retry: 44 | print("\n== Retry ==\n") 45 | last_user_conversation_idx = None 46 | for idx, conversation in enumerate(history): 47 | if conversation.role == Role.USER: 48 | last_user_conversation_idx = idx 49 | if last_user_conversation_idx is not None: 50 | prompt_text = history[last_user_conversation_idx].content 51 | del history[last_user_conversation_idx:] 52 | 53 | 54 | if prompt_text: 55 | prompt_text = prompt_text.strip() 56 | append_conversation(Conversation(Role.USER, prompt_text), history) 57 | placeholder = st.empty() 58 | message_placeholder = placeholder.chat_message(name="assistant", avatar="assistant") 59 | markdown_placeholder = message_placeholder.empty() 60 | 61 | output_text = '' 62 | for response in client.generate_stream( 63 | system_prompt, 64 | tools=None, 65 | history=history, 66 | do_sample=True, 67 | max_new_tokens=max_new_tokens, 68 | temperature=temperature, 69 | top_p=top_p, 70 | stop_sequences=[str(Role.USER)], 71 | repetition_penalty=repetition_penalty, 72 | ): 73 | token = response.token 74 | if response.token.special: 75 | print("\n==Output:==\n", output_text) 76 | match token.text.strip(): 77 | case '<|user|>': 78 | break 79 | case _: 80 | st.error(f'Unexpected special token: {token.text.strip()}') 81 | break 82 | output_text += response.token.text 83 | markdown_placeholder.markdown(postprocess_text(output_text + '▌')) 84 | 85 | append_conversation(Conversation( 86 | Role.ASSISTANT, 87 | postprocess_text(output_text), 88 | ), history, markdown_placeholder) -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/main.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | st.set_page_config( 3 | page_title="ChatGLM3-Ziwei Demo", 4 | 
page_icon=":robot:", 5 | layout='centered', 6 | initial_sidebar_state='expanded', 7 | ) 8 | 9 | 10 | import demo_chat 11 | from enum import Enum 12 | 13 | DEFAULT_SYSTEM_PROMPT = ''' 14 | You are ChatGLM3-Ziwei, a large language model trained by HITWH. Follow the user's instructions carefully. Respond using markdown. 15 | '''.strip() 16 | 17 | # Set the title of the demo 18 | st.title("ChatGLM3-Ziwei Demo") 19 | 20 | # Add your custom text here, with smaller font size 21 | 22 | 23 | class Mode(str, Enum): 24 | CHAT= '💬 Chat' 25 | 26 | 27 | with st.sidebar: 28 | top_p = st.slider( 29 | 'top_p', 0.0, 1.0, 0.8, step=0.01 30 | ) 31 | temperature = st.slider( 32 | 'temperature', 0.0, 1.5, 0.95, step=0.01 33 | ) 34 | repetition_penalty = st.slider( 35 | 'repetition_penalty', 0.0, 2.0, 1.1, step=0.01 36 | ) 37 | max_new_token = st.slider( 38 | 'Output length', 5, 32000, 256, step=1 39 | ) 40 | 41 | cols = st.columns(2) 42 | export_btn = cols[0] 43 | clear_history = cols[1].button("Clear History", use_container_width=True) 44 | retry = export_btn.button("Retry", use_container_width=True) 45 | 46 | system_prompt = st.text_area( 47 | label="System Prompt (Only for chat mode)", 48 | height=300, 49 | value=DEFAULT_SYSTEM_PROMPT, 50 | ) 51 | 52 | prompt_text = st.chat_input( 53 | 'Chat with ChatGLM3!', 54 | key='chat_input', 55 | ) 56 | 57 | tab = st.radio( 58 | 'Mode', 59 | [mode.value for mode in Mode], 60 | horizontal=True, 61 | label_visibility='hidden', 62 | ) 63 | 64 | if clear_history or retry: 65 | prompt_text = "" 66 | 67 | match tab: 68 | case Mode.CHAT: 69 | demo_chat.main( 70 | retry=retry, 71 | top_p=top_p, 72 | temperature=temperature, 73 | prompt_text=prompt_text, 74 | system_prompt=system_prompt, 75 | repetition_penalty=repetition_penalty, 76 | max_new_tokens=max_new_token 77 | ) 78 | 79 | case _: 80 | st.error(f'Unexpected tab: {tab}') 81 | -------------------------------------------------------------------------------- 
/Ziwei-Chatglm3-6B/composite_demo/requirements.txt: -------------------------------------------------------------------------------- 1 | huggingface_hub>=0.19.4 2 | pillow>=10.1.0 3 | pyyaml>=6.0.1 4 | requests>=2.31.0 5 | ipykernel>=6.26.0 6 | ipython>=8.18.1 7 | jupyter_client>=8.6.0 8 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/composite_demo/tool_registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is the tool registration part. By registering the tool, the model can call the tool. 3 | This code provides extended functionality to the model, enabling it to call and interact with a variety of utilities 4 | through defined interfaces. 5 | """ 6 | 7 | import copy 8 | import inspect 9 | from pprint import pformat 10 | import traceback 11 | from types import GenericAlias 12 | from typing import get_origin, Annotated 13 | import subprocess 14 | 15 | _TOOL_HOOKS = {} 16 | _TOOL_DESCRIPTIONS = {} 17 | 18 | 19 | def register_tool(func: callable): 20 | tool_name = func.__name__ 21 | tool_description = inspect.getdoc(func).strip() 22 | python_params = inspect.signature(func).parameters 23 | tool_params = [] 24 | for name, param in python_params.items(): 25 | annotation = param.annotation 26 | if annotation is inspect.Parameter.empty: 27 | raise TypeError(f"Parameter `{name}` missing type annotation") 28 | if get_origin(annotation) != Annotated: 29 | raise TypeError(f"Annotation type for `{name}` must be typing.Annotated") 30 | 31 | typ, (description, required) = annotation.__origin__, annotation.__metadata__ 32 | typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__ 33 | if not isinstance(description, str): 34 | raise TypeError(f"Description for `{name}` must be a string") 35 | if not isinstance(required, bool): 36 | raise TypeError(f"Required for `{name}` must be a bool") 37 | 38 | tool_params.append({ 39 | "name": name, 40 | 
"description": description, 41 | "type": typ, 42 | "required": required 43 | }) 44 | tool_def = { 45 | "name": tool_name, 46 | "description": tool_description, 47 | "params": tool_params 48 | } 49 | print("[registered tool] " + pformat(tool_def)) 50 | _TOOL_HOOKS[tool_name] = func 51 | _TOOL_DESCRIPTIONS[tool_name] = tool_def 52 | 53 | return func 54 | 55 | 56 | def dispatch_tool(tool_name: str, tool_params: dict) -> str: 57 | if tool_name not in _TOOL_HOOKS: 58 | return f"Tool `{tool_name}` not found. Please use a provided tool." 59 | tool_call = _TOOL_HOOKS[tool_name] 60 | try: 61 | ret = tool_call(**tool_params) 62 | except: 63 | ret = traceback.format_exc() 64 | return str(ret) 65 | 66 | 67 | def get_tools() -> dict: 68 | return copy.deepcopy(_TOOL_DESCRIPTIONS) 69 | 70 | 71 | # Tool Definitions 72 | 73 | @register_tool 74 | def random_number_generator( 75 | seed: Annotated[int, 'The random seed used by the generator', True], 76 | range: Annotated[tuple[int, int], 'The range of the generated numbers', True], 77 | ) -> int: 78 | """ 79 | Generates a random number x, s.t. 
range[0] <= x <= range[1] 80 | """ 81 | if not isinstance(seed, int): 82 | raise TypeError("Seed must be an integer") 83 | if not isinstance(range, tuple): 84 | raise TypeError("Range must be a tuple") 85 | if not isinstance(range[0], int) or not isinstance(range[1], int): 86 | raise TypeError("Range must be a tuple of integers") 87 | 88 | import random 89 | return random.Random(seed).randint(*range) 90 | 91 | 92 | @register_tool 93 | def get_weather( 94 | city_name: Annotated[str, 'The name of the city to be queried', True], 95 | ) -> str: 96 | """ 97 | Get the current weather for `city_name` 98 | """ 99 | 100 | if not isinstance(city_name, str): 101 | raise TypeError("City name must be a string") 102 | 103 | key_selection = { 104 | "current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"], 105 | } 106 | import requests 107 | try: 108 | resp = requests.get(f"https://wttr.in/{city_name}?format=j1") 109 | resp.raise_for_status() 110 | resp = resp.json() 111 | ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()} 112 | except Exception: 113 | import traceback 114 | ret = "Error encountered while fetching weather data!\n" + traceback.format_exc() 115 | 116 | return str(ret) 117 | 118 | 119 | @register_tool 120 | def get_shell( 121 | query: Annotated[str, 'The command should run in Linux shell', True], 122 | ) -> str: 123 | """ 124 | Use shell to run command 125 | """ 126 | if not isinstance(query, str): 127 | raise TypeError("Command must be a string") 128 | try: 129 | result = subprocess.run(query, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, 130 | text=True) 131 | return result.stdout 132 | except subprocess.CalledProcessError as e: 133 | return e.stderr 134 | 135 | 136 | if __name__ == "__main__": 137 | # print(dispatch_tool("get_shell", {"query": "pwd"})) 138 | print(get_tools()) --------------------------------------------------------------------------------
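The `register_tool` decorator in `tool_registry.py` builds each tool description from `typing.Annotated` metadata. The extraction step can be sketched in isolation as follows. This is a minimal illustration, not the repository's code: `describe_params` and `clamp` are hypothetical names introduced here.

```python
import inspect
from types import GenericAlias
from typing import Annotated, get_origin


def describe_params(func):
    """Collect name, type, description and required flag for each parameter.

    Hypothetical helper: mirrors the annotation handling of `register_tool`,
    but returns the parameter list instead of registering anything.
    """
    params = []
    for name, param in inspect.signature(func).parameters.items():
        annotation = param.annotation
        if get_origin(annotation) is not Annotated:
            raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
        # Annotated[T, description, required] keeps T in __origin__
        # and the extra arguments in __metadata__.
        typ = annotation.__origin__
        description, required = annotation.__metadata__
        # Parameterized generics such as tuple[int, int] have no useful __name__.
        type_name = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
        params.append({
            "name": name,
            "description": description,
            "type": type_name,
            "required": required,
        })
    return params


# Hypothetical tool used only to exercise the helper.
def clamp(
    value: Annotated[int, "The value to clamp", True],
    bounds: Annotated[tuple[int, int], "Inclusive lower and upper bound", False],
) -> int:
    low, high = bounds
    return max(low, min(high, value))


if __name__ == "__main__":
    for entry in describe_params(clamp):
        print(entry)
```

Keeping the description and the required flag inside the annotation means the signature is the single source of truth for the schema shown to the model, at the cost of rejecting any function whose parameters are not annotated this way.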
/Ziwei-Chatglm3-6B/finetune_demo/=1.3.0: -------------------------------------------------------------------------------- 1 | Looking in indexes: http://mirrors.aliyun.com/pypi/simple 2 | Collecting openai 3 | Downloading http://mirrors.aliyun.com/pypi/packages/8a/6c/f345662c586464cbd6185239ddea4d281b623db302115ff7a2bb27db1eea/openai-1.3.4-py3-none-any.whl (220 kB) 4 | ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 220.5/220.5 kB 86.3 kB/s eta 0:00:00 5 | Requirement already satisfied: anyio<4,>=3.5.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (3.7.1) 6 | Collecting distro<2,>=1.7.0 (from openai) 7 | Downloading http://mirrors.aliyun.com/pypi/packages/f4/2c/c90a3adaf0ddb70afe193f5ebfb539612af57cffe677c3126be533df3098/distro-1.8.0-py3-none-any.whl (20 kB) 8 | Requirement already satisfied: httpx<1,>=0.23.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (0.25.1) 9 | Requirement already satisfied: pydantic<3,>=1.9.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (2.5.1) 10 | Requirement already satisfied: tqdm>4 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (4.66.1) 11 | Requirement already satisfied: typing-extensions<5,>=4.5 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from openai) (4.8.0) 12 | Requirement already satisfied: idna>=2.8 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (3.4) 13 | Requirement already satisfied: sniffio>=1.1 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (1.3.0) 14 | Requirement already satisfied: exceptiongroup in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from anyio<4,>=3.5.0->openai) (1.1.3) 15 | Requirement already satisfied: certifi in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (2023.11.17) 16 | 
Requirement already satisfied: httpcore in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpx<1,>=0.23.0->openai) (1.0.2) 17 | Requirement already satisfied: annotated-types>=0.4.0 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from pydantic<3,>=1.9.0->openai) (0.6.0) 18 | Requirement already satisfied: pydantic-core==2.14.3 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from pydantic<3,>=1.9.0->openai) (2.14.3) 19 | Requirement already satisfied: h11<0.15,>=0.13 in /home/admin1/桌面/ChatGLM3-main/venv/lib/python3.10/site-packages (from httpcore->httpx<1,>=0.23.0->openai) (0.14.0) 20 | Installing collected packages: distro, openai 21 | Successfully installed distro-1.8.0 openai-1.3.4 22 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/README.md: -------------------------------------------------------------------------------- 1 | # ChatGLM3-6B Fine-tuning Examples 2 | 3 | This directory provides fine-tuning examples for the ChatGLM3-6B model, covering full-parameter fine-tuning and P-Tuning v2. Two data formats are supported: multi-turn conversation and input-output. 4 | 5 | If you have downloaded the model locally, replace every occurrence of `THUDM/chatglm3-6b` in this document and in the code with the local path so that the model is loaded from disk. 6 | 7 | The examples require `python>=3.9`. Besides the base `torch` dependency, the example code also needs 8 | 9 | ```bash 10 | pip install transformers==4.30.2 accelerate sentencepiece astunparse deepspeed 11 | ``` 12 | 13 | ## Multi-turn Conversation Format 14 | 15 | The multi-turn fine-tuning example follows the ChatGLM3 conversation format and applies a different `loss_mask` to each role, so that the `loss` for every reply in a conversation is computed in a single pass. 16 | 17 | ### Data Format and Preprocessing 18 | 19 | The example expects data files in the following format. 20 | 21 | If you only want to fine-tune the model's conversational ability, not its tool-calling ability, organize your data as follows. 22 | 23 | ```json 24 | [ 25 | { 26 | "conversations": [ 27 | { 28 | "role": "system", 29 | "content": "" 30 | }, 31 | { 32 | "role": "user", 33 | "content": "" 34 | }, 35 | { 36 | "role": "assistant", 37 | "content": "" 38 | }, 39 | // ... Multi Turn 40 | { 41 | "role": "user", 42 | "content": "" 43 | }, 44 | { 45 | "role": "assistant", 46 | "content": "" 47 | } 48 | ] 49 | } 50 | // ... 
51 | ] 52 | ``` 53 | 54 | **Note that with many fine-tuning steps, this approach will degrade the model's tool-calling ability.** 55 | 56 | If you want to fine-tune both the conversational and the tool-calling abilities, organize your data as follows. 57 | 58 | ```json 59 | [ 60 | { 61 | "tools": [ 62 | // available tools, format is not restricted 63 | ], 64 | "conversations": [ 65 | { 66 | "role": "system", 67 | "content": "" 68 | }, 69 | { 70 | "role": "user", 71 | "content": "" 72 | }, 73 | { 74 | "role": "assistant", 75 | "content": "" 76 | }, 77 | { 78 | "role": "tool", 79 | "name": "", 80 | "parameters": { 81 | "": "" 82 | }, 83 | "observation": "" 84 | // don't have to be string 85 | }, 86 | { 87 | "role": "assistant", 88 | "content": "" 89 | }, 90 | // ... Multi Turn 91 | { 92 | "role": "user", 93 | "content": "" 94 | }, 95 | { 96 | "role": "assistant", 97 | "content": "" 98 | } 99 | ] 100 | } 101 | // ... 102 | ] 103 | ``` 104 | 105 | - You do not need to insert a system prompt describing the tools by hand. During preprocessing, the `tools` field is formatted with `json.dumps(..., ensure_ascii=False)` and inserted as the first system prompt. 106 | 107 | - Each role may carry a `bool` field named `loss`, which indicates whether the content predicted for that turn takes part in the `loss` computation. When the field is absent, the example implementation computes no `loss` for `system` and `user` and computes `loss` for all other roles. 108 | 109 | - `tool` is not a native ChatGLM3 role. During preprocessing, each `tool` turn is automatically converted into an `assistant` turn that carries the tool-call `metadata` (`loss` computed by default) and an `observation` turn that holds the tool's return value (no `loss`). 110 | 111 | - Fine-tuning for the `Code interpreter` task is not implemented yet. 112 | 113 | - The `system` role is optional. If present, it must appear before the `user` role, and a complete conversation (single-turn or multi-turn) may contain the `system` role only once. 114 | 115 | As an example, we fine-tune on the ToolAlpaca dataset. First clone the [ToolAlpaca dataset](https://github.com/tangqiaoyu/ToolAlpaca), then run 116 | 117 | ```bash 118 | ./scripts/format_tool_alpaca.py --path "ToolAlpaca/data/train_data.json" 119 | ``` 120 | 121 | to convert the dataset into the format above. Here the tools are deliberately rendered in a natural-language form, `list[str]`, to observe how well the model understands tool definitions before and after fine-tuning. 122 | 123 | ### Fine-tuning the Model 124 | 125 | The following scripts provide reference ways to fine-tune the model. 126 | 127 | ```bash 128 | ./scripts/finetune_ds_multiturn.sh # full-parameter fine-tuning 129 | ./scripts/finetune_pt_multiturn.sh # P-Tuning v2 fine-tuning 130 | ``` 131 | 132 | ### Deployment 133 | 134 | The ChatGLM3 composite demo has been updated so that it can serve a fine-tuned model checkpoint. 135 | 136 | For full-parameter fine-tuning, deploy as follows 137 | 138 | ```bash 139 | cd ../composite_demo 140 | 
MODEL_PATH="path to finetuned model checkpoint" TOKENIZER_PATH="THUDM/chatglm3-6b" streamlit run main.py 141 | ``` 142 | 143 | For P-Tuning v2 fine-tuning, deploy as follows 144 | 145 | ```bash 146 | cd ../composite_demo 147 | MODEL_PATH="THUDM/chatglm3-6b" PT_PATH="path to p-tuning checkpoint" streamlit run main.py 148 | ``` 149 | 150 | ## Input-Output Format 151 | 152 | For the input-output format, the example expects the following input format 153 | 154 | ```json 155 | [ 156 | { 157 | "prompt": "", 158 | "response": "" 159 | } 160 | // ... 161 | ] 162 | ``` 163 | 164 | No role identifiers are added during preprocessing. 165 | 166 | As an example, we fine-tune on the AdvertiseGen dataset. Download the preprocessed AdvertiseGen dataset from [Google Drive](https://drive.google.com/file/d/13_vf0xRTQsyneRKdD1bZIr93vBGOczrk/view?usp=sharing) or [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/b3f119a008264b1cabd1/?dl=1), place the extracted `AdvertiseGen` directory in this directory, and run 167 | 168 | ```bash 169 | ./scripts/format_advertise_gen.py --path "AdvertiseGen/train.json" 170 | ``` 171 | 172 | to convert the dataset into the format above. 173 | 174 | ### Fine-tuning the Model 175 | 176 | The following scripts provide reference ways to fine-tune the model. 177 | 178 | ```bash 179 | ./scripts/finetune_ds.sh # full-parameter fine-tuning 180 | ./scripts/finetune_pt.sh # P-Tuning v2 fine-tuning 181 | ``` 182 | 183 | ### Inference Check 184 | 185 | For models fine-tuned on the input-output format, `inference.py` can be used for a basic inference check. 186 | 187 | ```bash 188 | python inference.py \ 189 | --pt-checkpoint "path to p-tuning checkpoint" \ 190 | --model THUDM/chatglm3-6b 191 | ``` 192 | 193 | ```bash 194 | python inference.py \ 195 | --tokenizer THUDM/chatglm3-6b \ 196 | --model "path to finetuned model checkpoint" 197 | ``` 198 | 199 | ### Tips 200 | 201 | 1. Before training starts, the fine-tuning code prints the preprocessing result for the first training sample, which looks like 202 | 203 | ```log 204 | Sanity Check >>>>>>>>>>>>> 205 | '[gMASK]': 64790 -> -100 206 | 'sop': 64792 -> -100 207 | '<|system|>': 64794 -> -100 208 | '': 30910 -> -100 209 | '\n': 13 -> -100 210 | 'Answer': 20115 -> -100 211 | 'the': 267 -> -100 212 | 'following': 1762 -> -100 213 | ... 
214 | 'know': 683 -> -100 215 | 'the': 267 -> -100 216 | 'response': 3010 -> -100 217 | 'details': 3296 -> -100 218 | '.': 30930 -> -100 219 | '<|assistant|>': 64796 -> -100 220 | '': 30910 -> 30910 221 | '\n': 13 -> 13 222 | 'I': 307 -> 307 223 | 'need': 720 -> 720 224 | 'to': 289 -> 289 225 | 'use': 792 -> 792 226 | ... 227 | <<<<<<<<<<<<< Sanity Check 228 | ``` 229 | 230 | Each line shows, in order, a detokenized string, its token_id, and its target_id. Check in the log whether the `loss_mask` for this part matches your expectations. If it does not, the code or the data may need adjusting. 231 | 232 | 2. Reference GPU memory usage 233 | 234 | - P-Tuning V2 with `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=2048` needs about 21GB of GPU memory. 235 | - For full-parameter fine-tuning, the configuration in `./scripts/finetune_ds_multiturn.sh` (`MAX_SEQ_LEN=2048`, `DEV_BATCH_SIZE=16`, `GRAD_ACCUMULARION_STEPS=1`) just fills 4 * 80GB of GPU memory. 236 | 237 | 3. If you run out of GPU memory, consider 238 | - Lowering `DEV_BATCH_SIZE` and raising `GRAD_ACCUMULARION_STEPS` 239 | - Adding `--quantization_bit 8` or `--quantization_bit 4`. 240 | - With `PRE_SEQ_LEN=128`, `DEV_BATCH_SIZE=1`, `GRAD_ACCUMULARION_STEPS=16`, `MAX_SEQ_LEN=1024`, `--quantization_bit 8` needs about 12GB of GPU memory and `--quantization_bit 4` needs about 7.6GB. 241 | 242 | ## References 243 | 244 | ``` 245 | @inproceedings{liu2022p, 246 | title={P-tuning: Prompt tuning can be comparable to fine-tuning across scales and tasks}, 247 | author={Liu, Xiao and Ji, Kaixuan and Fu, Yicheng and Tam, Weng and Du, Zhengxiao and Yang, Zhilin and Tang, Jie}, 248 | booktitle={Proceedings of the 60th Annual Meeting of the Association for Computational Linguistics (Volume 2: Short Papers)}, 249 | pages={61--68}, 250 | year={2022} 251 | } 252 | 253 | @misc{tang2023toolalpaca, 254 | title={ToolAlpaca: Generalized Tool Learning for Language Models with 3000 Simulated Cases}, 255 | author={Qiaoyu Tang and Ziliang Deng and Hongyu Lin and Xianpei Han and Qiao Liang and Le Sun}, 256 | year={2023}, 257 | eprint={2306.05301}, 258 | archivePrefix={arXiv}, 259 | primaryClass={cs.CL} 260 | } 261 | ``` 262 | 
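The role rules for the multi-turn format (optional single `system` turn placed before any `user` turn, `tool` turns with `name` and `parameters`, optional boolean `loss`) can be checked before training. The following is a minimal sketch under those assumptions; `check_conversation` is a hypothetical helper that is not part of this repository, and the sample data is made up for illustration.

```python
from typing import Any

# Roles accepted by the multi-turn format described in this README.
ALLOWED_ROLES = {"system", "user", "assistant", "tool"}


def check_conversation(sample: dict[str, Any]) -> list[str]:
    """Return a list of problems found in one training sample (empty if none)."""
    conversations = sample.get("conversations")
    if not isinstance(conversations, list) or not conversations:
        return ["`conversations` must be a non-empty list"]

    problems = []
    system_count = 0
    seen_user = False
    for idx, turn in enumerate(conversations):
        role = turn.get("role")
        if role not in ALLOWED_ROLES:
            problems.append(f"turn {idx}: unknown role {role!r}")
            continue
        if role == "system":
            system_count += 1
            if seen_user:
                problems.append(f"turn {idx}: `system` must come before any `user` turn")
        elif role == "user":
            seen_user = True
        if role == "tool":
            # Tool turns carry a tool name and its call parameters.
            for key in ("name", "parameters"):
                if key not in turn:
                    problems.append(f"turn {idx}: `tool` turn is missing `{key}`")
        elif "content" not in turn:
            problems.append(f"turn {idx}: `{role}` turn is missing `content`")
        if "loss" in turn and not isinstance(turn["loss"], bool):
            problems.append(f"turn {idx}: `loss` must be a bool")
    if system_count > 1:
        problems.append("at most one `system` turn is allowed per conversation")
    return problems


if __name__ == "__main__":
    sample = {
        "conversations": [
            {"role": "user", "content": "What is my fortune this year?"},
            {"role": "system", "content": "You are a fortune teller."},
        ]
    }
    for problem in check_conversation(sample):
        print(problem)
```

Running such a check over the whole training file before launching a job is cheaper than discovering a malformed sample from an unexpected `Sanity Check` printout.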
-------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/arguments.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | from typing import Optional 3 | 4 | 5 | @dataclass 6 | class ModelArguments: 7 | """ 8 | Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 9 | """ 10 | 11 | model_name_or_path: str = field( 12 | metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"} 13 | ) 14 | ptuning_checkpoint: str = field( 15 | default=None, metadata={"help": "Path to p-tuning v2 checkpoints"} 16 | ) 17 | config_name: Optional[str] = field( 18 | default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} 19 | ) 20 | tokenizer_name: Optional[str] = field( 21 | default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"} 22 | ) 23 | cache_dir: Optional[str] = field( 24 | default=None, 25 | metadata={"help": "Where to store the pretrained models downloaded from huggingface.co"}, 26 | ) 27 | use_fast_tokenizer: bool = field( 28 | default=True, 29 | metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."}, 30 | ) 31 | model_revision: str = field( 32 | default="main", 33 | metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, 34 | ) 35 | use_auth_token: bool = field( 36 | default=False, 37 | metadata={ 38 | "help": ( 39 | "Will use the token generated when running `huggingface-cli login` (necessary to use this script " 40 | "with private models)." 41 | ) 42 | }, 43 | ) 44 | resize_position_embeddings: Optional[bool] = field( 45 | default=None, 46 | metadata={ 47 | "help": ( 48 | "Whether to automatically resize the position embeddings if `max_source_length` exceeds " 49 | "the model's position embeddings." 
50 | ) 51 | }, 52 | ) 53 | quantization_bit: Optional[int] = field( 54 | default=None 55 | ) 56 | pre_seq_len: Optional[int] = field( 57 | default=None 58 | ) 59 | prefix_projection: bool = field( 60 | default=False 61 | ) 62 | 63 | 64 | @dataclass 65 | class DataTrainingArguments: 66 | """ 67 | Arguments pertaining to what data we are going to input our model for training and eval. 68 | """ 69 | train_file: Optional[str] = field( 70 | default=None, metadata={"help": "The input training data file (a jsonlines or csv file)."} 71 | ) 72 | 73 | max_seq_length: Optional[int] = field( 74 | default=2048, 75 | metadata={ 76 | "help": ( 77 | "The maximum total input sequence length after tokenization. Sequences longer " 78 | "than this will be truncated." 79 | ) 80 | }, 81 | ) 82 | 83 | max_source_length: Optional[int] = field( 84 | default=1024, 85 | metadata={ 86 | "help": ( 87 | "The maximum total input sequence length after tokenization. Sequences longer " 88 | "than this will be truncated, sequences shorter will be padded." 89 | ) 90 | }, 91 | ) 92 | max_target_length: Optional[int] = field( 93 | default=128, 94 | metadata={ 95 | "help": ( 96 | "The maximum total sequence length for target text after tokenization. Sequences longer " 97 | "than this will be truncated, sequences shorter will be padded." 98 | ) 99 | }, 100 | ) 101 | 102 | train_format: str = field( 103 | default=None, metadata={"help": "The format of the training data file (multi-turn or input-output)"}, 104 | ) 105 | 106 | overwrite_cache: bool = field( 107 | default=False, metadata={"help": "Overwrite the cached training and evaluation sets"} 108 | ) 109 | 110 | preprocessing_num_workers: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "The number of processes to use for the preprocessing."}, 113 | ) 114 | 115 | max_seq_length: Optional[int] = field( 116 | default=1024, 117 | metadata={ 118 | "help": ( 119 | "The maximum total input sequence length after tokenization. 
Sequences longer " 120 | "than this will be truncated, sequences shorter will be padded." 121 | ) 122 | }, 123 | ) 124 | 125 | pad_to_max_length: bool = field( 126 | default=False, 127 | metadata={ 128 | "help": ( 129 | "Whether to pad all samples to model maximum sentence length. " 130 | "If False, will pad the samples dynamically when batching to the maximum length in the batch. More " 131 | "efficient on GPU but very bad for TPU." 132 | ) 133 | }, 134 | ) 135 | 136 | max_train_samples: Optional[int] = field( 137 | default=None, 138 | metadata={ 139 | "help": ( 140 | "For debugging purposes or quicker training, truncate the number of training examples to this " 141 | "value if set." 142 | ) 143 | }, 144 | ) 145 | 146 | def __post_init__(self): 147 | extension = self.train_file.split(".")[-1] 148 | assert extension in {"jsonl", "json"}, "`train_file` should be a jsonl or a json file." 149 | 150 | assert self.train_format in {"multi-turn", "input-output"} -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/configs/deepspeed.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": "auto", 3 | "zero_allow_untested_optimizer": true, 4 | "fp16": { 5 | "enabled": "auto", 6 | "loss_scale": 0, 7 | "initial_scale_power": 16, 8 | "loss_scale_window": 1000, 9 | "hysteresis": 2, 10 | "min_loss_scale": 1 11 | }, 12 | "zero_optimization": { 13 | "stage": 2, 14 | "allgather_partitions": true, 15 | "allgather_bucket_size": 5e8, 16 | "overlap_comm": false, 17 | "reduce_scatter": true, 18 | "reduce_bucket_size": 5e8, 19 | "contiguous_gradients" : true 20 | } 21 | } -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/finetune.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2021 The 
HuggingFace Team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | """ 17 | Fine-tuning the library models for sequence to sequence. 18 | """ 19 | # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments. 20 | # Adapted from 21 | 22 | 23 | import logging 24 | import os 25 | import sys 26 | import torch 27 | import json 28 | import transformers 29 | from transformers import ( 30 | AutoConfig, 31 | AutoModel, 32 | AutoTokenizer, 33 | DataCollatorForSeq2Seq, 34 | HfArgumentParser, 35 | Seq2SeqTrainingArguments, 36 | set_seed, 37 | ) 38 | from trainer import PrefixTrainer 39 | 40 | from arguments import ModelArguments, DataTrainingArguments 41 | 42 | from preprocess_utils import sanity_check, MultiTurnDataset, InputOutputDataset 43 | 44 | logger = logging.getLogger(__name__) 45 | 46 | def main(): 47 | parser = HfArgumentParser((ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments)) 48 | if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): 49 | # If we pass only one argument to the script and it's the path to a json file, 50 | # let's parse it to get our arguments. 
51 | model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1])) 52 | else: 53 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 54 | 55 | # Setup logging 56 | logging.basicConfig( 57 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 58 | datefmt="%m/%d/%Y %H:%M:%S", 59 | handlers=[logging.StreamHandler(sys.stdout)], 60 | ) 61 | 62 | if training_args.should_log: 63 | # The default of training_args.log_level is passive, so we set log level at info here to have that default. 64 | transformers.utils.logging.set_verbosity_info() 65 | 66 | log_level = training_args.get_process_log_level() 67 | logger.setLevel(log_level) 68 | # datasets.utils.logging.set_verbosity(log_level) 69 | transformers.utils.logging.set_verbosity(log_level) 70 | transformers.utils.logging.enable_default_handler() 71 | transformers.utils.logging.enable_explicit_format() 72 | 73 | # Log on each process the small summary: 74 | logger.warning( 75 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " 76 | + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}" 77 | ) 78 | logger.info(f"Training/evaluation parameters {training_args}") 79 | 80 | # Set seed before initializing model. 
81 | set_seed(training_args.seed) 82 | 83 | # Load pretrained model and tokenizer 84 | config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 85 | config.pre_seq_len = model_args.pre_seq_len 86 | config.prefix_projection = model_args.prefix_projection 87 | 88 | tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path, trust_remote_code=True) 89 | 90 | if model_args.ptuning_checkpoint is not None: 91 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 92 | prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin")) 93 | new_prefix_state_dict = {} 94 | for k, v in prefix_state_dict.items(): 95 | if k.startswith("transformer.prefix_encoder."): 96 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 97 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 98 | else: 99 | model = AutoModel.from_pretrained(model_args.model_name_or_path, config=config, trust_remote_code=True) 100 | 101 | if model_args.quantization_bit is not None: 102 | print(f"Quantized to {model_args.quantization_bit} bit") 103 | model = model.quantize(model_args.quantization_bit) 104 | if model_args.pre_seq_len is not None: 105 | # P-tuning v2 106 | model = model.half() 107 | model.transformer.prefix_encoder.float() 108 | else: 109 | # Finetune 110 | model = model.float() 111 | 112 | with open(data_args.train_file, "r", encoding="utf-8") as f: 113 | if data_args.train_file.endswith(".json"): 114 | train_data = json.load(f) 115 | elif data_args.train_file.endswith(".jsonl"): 116 | train_data = [json.loads(line) for line in f] 117 | 118 | if data_args.train_format == "multi-turn": 119 | train_dataset = MultiTurnDataset( 120 | train_data, 121 | tokenizer, 122 | data_args.max_seq_length, 123 | ) 124 | elif data_args.train_format == "input-output": 125 | train_dataset = InputOutputDataset( 126 | train_data, 127 | tokenizer, 128 | 
data_args.max_source_length, 129 | data_args.max_target_length, 130 | ) 131 | else: 132 | raise ValueError(f"Unknown train format: {data_args.train_format}") 133 | if training_args.local_rank < 1: 134 | sanity_check(train_dataset[0]['input_ids'], train_dataset[0]['labels'], tokenizer) 135 | 136 | # Data collator 137 | data_collator = DataCollatorForSeq2Seq( 138 | tokenizer, 139 | model=model, 140 | label_pad_token_id=-100, 141 | pad_to_multiple_of=None, 142 | padding=False 143 | ) 144 | 145 | # Initialize our Trainer 146 | trainer = PrefixTrainer( 147 | model=model, 148 | args=training_args, 149 | train_dataset=train_dataset, 150 | tokenizer=tokenizer, 151 | data_collator=data_collator, 152 | save_changed=model_args.pre_seq_len is not None 153 | ) 154 | 155 | checkpoint = None 156 | if training_args.resume_from_checkpoint is not None: 157 | checkpoint = training_args.resume_from_checkpoint 158 | model.gradient_checkpointing_enable() 159 | model.enable_input_require_grads() 160 | trainer.train(resume_from_checkpoint=checkpoint) 161 | trainer.save_model() # Saves the tokenizer too for easy upload 162 | trainer.save_state() 163 | 164 | if __name__ == "__main__": 165 | main() 166 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from transformers import AutoConfig, AutoModel, AutoTokenizer 3 | import torch 4 | import os 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--pt-checkpoint", type=str, default=None, help="The checkpoint path") 8 | parser.add_argument("--model", type=str, default=None, help="main model weights") 9 | parser.add_argument("--tokenizer", type=str, default=None, help="main model weights") 10 | parser.add_argument("--pt-pre-seq-len", type=int, default=128, help="The pre-seq-len used in p-tuning") 11 | parser.add_argument("--device", type=str, 
default="cuda") 12 | parser.add_argument("--max-new-tokens", type=int, default=128) 13 | 14 | args = parser.parse_args() 15 | 16 | if args.tokenizer is None: 17 | args.tokenizer = args.model 18 | 19 | if args.pt_checkpoint: 20 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 21 | config = AutoConfig.from_pretrained(args.model, trust_remote_code=True, pre_seq_len=args.pt_pre_seq_len) 22 | model = AutoModel.from_pretrained(args.model, config=config, trust_remote_code=True) 23 | prefix_state_dict = torch.load(os.path.join(args.pt_checkpoint, "pytorch_model.bin")) 24 | new_prefix_state_dict = {} 25 | for k, v in prefix_state_dict.items(): 26 | if k.startswith("transformer.prefix_encoder."): 27 | new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v 28 | model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict) 29 | else: 30 | tokenizer = AutoTokenizer.from_pretrained(args.tokenizer, trust_remote_code=True) 31 | model = AutoModel.from_pretrained(args.model, trust_remote_code=True) 32 | 33 | model = model.to(args.device) 34 | 35 | while True: 36 | prompt = input("Prompt:") 37 | inputs = tokenizer(prompt, return_tensors="pt") 38 | inputs = inputs.to(args.device) 39 | response = model.generate(input_ids=inputs["input_ids"], max_length=inputs["input_ids"].shape[-1] + args.max_new_tokens) 40 | response = response[0, inputs["input_ids"].shape[-1]:] 41 | print("Response:", tokenizer.decode(response, skip_special_tokens=True)) 42 | 43 | #我的出生日期是2002年10月5日22时15分,生辰八字为壬寅壬子己未甲子,请预测一下我今年的运势 44 | #python inference.py --pt-checkpoint /media/admin1/BackupPlus/chatglm3-6b-pt/output_pt-20231224-165313-128-2e-2/checkpoint-1000 --model /home/admin1/桌面/chatglm3-6b 45 | # 1. 我的出生日期是2002年10月5日22时15分,我的命运如何? 46 | # 2. 我的出生日期是2002年10月5日22时15分我什么时候会结婚? 47 | # 3. 我的出生日期是2002年10月5日22时15分我什么时候会有孩子? 48 | # 4. 我的出生日期是2002年10月5日22时15分我的事业会有什么发展? 49 | # 5. 我的出生日期是2002年10月5日22时15分我会赚多少钱? 50 | # 6. 我的出生日期是2002年10月5日22时15分我什么时候会退休? 51 | # 7. 
我的出生日期是2002年10月5日22时15分我会有多少财富? 52 | # 8. 我的出生日期是2002年10月5日22时15分我的健康状况如何? 53 | # 9. 我的出生日期是2002年10月5日22时15分我什么时候会生病? 54 | # 10. 我的出生日期是2002年10月5日22时15分我的寿命是多长? -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/preprocess_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import ast 3 | import astunparse 4 | from transformers import PreTrainedTokenizer 5 | from torch.utils.data import Dataset 6 | from copy import deepcopy 7 | from typing import Dict, List 8 | 9 | # text constants 10 | FUNCTION_CALL_NAME = 'tool_call' 11 | FUNCTION_CALL_PREFIX = '```python\n' 12 | FUNCTION_CALL_POSTFIX = '\n```' 13 | TOOL_DEFINITION_PREFIX = 'Answer the following questions as best as you can. You have access to the following tools:\n' 14 | CONVERSATOIN_KEY = 'conversations' 15 | TOOL_DESC_KEY = 'tools' 16 | 17 | def format_function_call(function_name: str, parameters: Dict[str, str]): 18 | function_name = ast.Name(id=function_name) 19 | keywords = [ 20 | ast.keyword(arg=arg_name, value=ast.Constant(arg_value)) 21 | for arg_name, arg_value in parameters.items() 22 | ] 23 | func_call = ast.Call(func=function_name, args=[], keywords=keywords) 24 | return astunparse.unparse(func_call).strip() 25 | 26 | def format_conversation(item, tokenizer, conversation_key: str, tool_key: str): 27 | conversations = deepcopy(item[conversation_key]) 28 | 29 | # Note: `loss_mask` here means whether *the prediction* of the token should take loss 30 | tokens, loss_masks = [tokenizer.get_command("[gMASK]"), tokenizer.get_command("sop")], [0, 0] 31 | 32 | def _update(_tokens: List[int], value: int = 1): 33 | value = int(value) 34 | tokens.extend(_tokens) 35 | loss_masks.extend([value] * len(_tokens)) 36 | 37 | # insert system prompt for tools 38 | if tool_key in item: 39 | conversations.insert(0, 40 | { 41 | "role": "system", 42 | "content": TOOL_DEFINITION_PREFIX + 
json.dumps(item[tool_key], indent=4, ensure_ascii=False) 43 | } 44 | ) 45 | 46 | for idx, conv in enumerate(conversations): 47 | loss = conv.get("loss", True) 48 | if conv['role'] in {'system', 'user'}: 49 | loss = False 50 | if conv['role'] == 'tool': 51 | # function call python code 52 | value = FUNCTION_CALL_PREFIX + format_function_call(FUNCTION_CALL_NAME, conv["parameters"]) + FUNCTION_CALL_POSTFIX 53 | text = tokenizer.build_single_message("assistant", conv["name"], value) 54 | _update(text, loss) 55 | 56 | # function call result 57 | value = conv.get('observation', None) 58 | if not isinstance(value, str): 59 | value = json.dumps(value, ensure_ascii=False) 60 | text = tokenizer.build_single_message("observation", "", value) 61 | _update(text, False) 62 | else: 63 | text = tokenizer.build_single_message(conv['role'], "", conv["content"]) 64 | _update(text, loss) 65 | 66 | _update([tokenizer.eos_token_id], False) 67 | 68 | assert len(tokens) == len(loss_masks), f"length mismatch: {len(tokens)} vs {len(loss_masks)}" 69 | return tokens, loss_masks 70 | 71 | def sanity_check(tokens: List[int], target: List[int], tokenizer: PreTrainedTokenizer): 72 | print("Sanity Check >>>>>>>>>>>>>") 73 | for t, m in zip(tokens, target): 74 | decoded = tokenizer.tokenizer.index_special_tokens[t] \ 75 | if t in tokenizer.tokenizer.index_special_tokens \ 76 | else tokenizer.decode([t]) 77 | print("%20s: %6d -> %6d" % (repr(decoded), t, m)) 78 | print("<<<<<<<<<<<<< Sanity Check") 79 | 80 | assert len(tokens) == len(target), f"length mismatch: {len(tokens)} vs {len(target)}" 81 | 82 | class MultiTurnDataset(Dataset): 83 | def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_seq_length: int): 84 | super(MultiTurnDataset, self).__init__() 85 | self.tokenizer = tokenizer 86 | self.max_seq_length = max_seq_length 87 | self.data = data 88 | 89 | def __len__(self): 90 | return len(self.data) 91 | 92 | def __getitem__(self, i) -> dict: 93 | data_item = self.data[i] 94 
| tokens, loss_masks = format_conversation(data_item, self.tokenizer, CONVERSATOIN_KEY, TOOL_DESC_KEY) 95 | 96 | # labels are used inside the model 97 | target_based_loss_mask = [False] + loss_masks[:-1] 98 | labels = [(t if m else -100) for t, m in zip(tokens, target_based_loss_mask)] 99 | 100 | tokens = tokens[:self.max_seq_length] 101 | labels = labels[:self.max_seq_length] 102 | tokens += [self.tokenizer.pad_token_id] * (self.max_seq_length - len(tokens)) 103 | labels += [-100] * (self.max_seq_length - len(labels)) 104 | 105 | assert len(tokens) == len(labels), f"length mismatch: {len(tokens)} vs {len(labels)}" 106 | 107 | return { 108 | "input_ids": tokens, 109 | "labels": labels 110 | } 111 | 112 | class InputOutputDataset(Dataset): 113 | def __init__(self, data: List[dict], tokenizer: PreTrainedTokenizer, max_source_length: int, max_target_length: int): 114 | super(InputOutputDataset, self).__init__() 115 | self.tokenizer = tokenizer 116 | self.max_source_length = max_source_length 117 | self.max_target_length = max_target_length 118 | self.max_seq_length = max_source_length + max_target_length + 1 119 | self.data = data 120 | 121 | def __len__(self): 122 | return len(self.data) 123 | 124 | def __getitem__(self, i) -> dict: 125 | data_item = self.data[i] 126 | 127 | a_ids = self.tokenizer.encode(text=data_item['prompt'], add_special_tokens=True, truncation=True, 128 | max_length=self.max_source_length) 129 | b_ids = self.tokenizer.encode(text=data_item['response'], add_special_tokens=False, truncation=True, 130 | max_length=self.max_target_length) 131 | 132 | context_length = len(a_ids) 133 | input_ids = a_ids + b_ids + [self.tokenizer.eos_token_id] 134 | labels = [self.tokenizer.pad_token_id] * context_length + b_ids + [self.tokenizer.eos_token_id] 135 | 136 | pad_len = self.max_seq_length - len(input_ids) 137 | input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len 138 | labels = labels + [self.tokenizer.pad_token_id] * pad_len 139 | labels = [(l 
if l != self.tokenizer.pad_token_id else -100) for l in labels] 140 | 141 | assert len(input_ids) == len(labels), f"length mismatch: {len(input_ids)} vs {len(labels)}" 142 | 143 | return { 144 | "input_ids": input_ids, 145 | "labels": labels 146 | } 147 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_ds.sh: -------------------------------------------------------------------------------- 1 | #! /usr/bin/env bash 2 | 3 | set -ex 4 | 5 | LR=1e-4 6 | NUM_GPUS=4 7 | MAX_SOURCE_LEN=1024 8 | MAX_TARGET_LEN=128 9 | DEV_BATCH_SIZE=4 10 | GRAD_ACCUMULARION_STEPS=1 11 | MAX_STEP=500 12 | SAVE_INTERVAL=500 13 | 14 | RUN_NAME=advertise_gen_ft 15 | BASE_MODEL_PATH=THUDM/chatglm3-6b 16 | DATASET_PATH=formatted_data/advertise_gen.jsonl 17 | 18 | DATESTR=`date +%Y%m%d-%H%M%S` 19 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${LR} 20 | MASTER_PORT=$(shuf -n 1 -i 10000-65535) 21 | 22 | mkdir -p $OUTPUT_DIR 23 | 24 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \ 25 | --train_format input-output \ 26 | --train_file $DATASET_PATH \ 27 | --preprocessing_num_workers 1 \ 28 | --model_name_or_path $BASE_MODEL_PATH \ 29 | --output_dir $OUTPUT_DIR \ 30 | --max_source_length $MAX_SOURCE_LEN \ 31 | --max_target_length $MAX_TARGET_LEN \ 32 | --per_device_train_batch_size $DEV_BATCH_SIZE \ 33 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \ 34 | --max_steps $MAX_STEP \ 35 | --logging_steps 1 \ 36 | --save_steps $SAVE_INTERVAL \ 37 | --learning_rate $LR \ 38 | --fp16 \ 39 | --deepspeed configs/deepspeed.json 2>&1 | tee ${OUTPUT_DIR}/train.log 40 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_ds_multiturn.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | LR=1e-4 6 | NUM_GPUS=4 7 | MAX_SEQ_LEN=2048 8 | DEV_BATCH_SIZE=16 9 | GRAD_ACCUMULARION_STEPS=1 10 | MAX_STEP=200 11 | SAVE_INTERVAL=50 12 | 13 | DATESTR=`date +%Y%m%d-%H%M%S` 14 | RUN_NAME=tool_alpaca_ft 15 | DATASET_PATH=formatted_data/tool_alpaca.jsonl 16 | 17 | BASE_MODEL_PATH=THUDM/chatglm3-6b 18 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${LR} 19 | 20 | mkdir -p $OUTPUT_DIR 21 | 22 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \ 23 | --train_format multi-turn \ 24 | --train_file $DATASET_PATH \ 25 | --max_seq_length $MAX_SEQ_LEN \ 26 | --preprocessing_num_workers 1 \ 27 | --model_name_or_path $BASE_MODEL_PATH \ 28 | --output_dir $OUTPUT_DIR \ 29 | --per_device_train_batch_size $DEV_BATCH_SIZE \ 30 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \ 31 | --max_steps $MAX_STEP \ 32 | --logging_steps 1 \ 33 | --save_steps $SAVE_INTERVAL \ 34 | --fp16 \ 35 | --deepspeed configs/deepspeed.json 2>&1 | tee ${OUTPUT_DIR}/train.log 36 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_pt.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | PRE_SEQ_LEN=128 6 | LR=2e-2 7 | NUM_GPUS=1 8 | MAX_SOURCE_LEN=1024 9 | MAX_TARGET_LEN=128 10 | DEV_BATCH_SIZE=1 11 | GRAD_ACCUMULARION_STEPS=32 12 | MAX_STEP=1000 13 | SAVE_INTERVAL=500 14 | 15 | DATESTR=`date +%Y%m%d-%H%M%S` 16 | RUN_NAME=output_pt 17 | 18 | BASE_MODEL_PATH=/home/admin1/桌面/chatglm3-6b 19 | DATASET_PATH=formatted_data/output.jsonl 20 | OUTPUT_DIR=/media/admin1/BackupPlus/chatglm3-6b-pt/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR} 21 | 22 | mkdir -p $OUTPUT_DIR 23 | 24 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \ 25 | --train_format input-output \ 26 | --train_file $DATASET_PATH \ 27 | --preprocessing_num_workers 1 \ 28 | --model_name_or_path $BASE_MODEL_PATH \ 29 | --output_dir $OUTPUT_DIR \ 30 | --max_source_length $MAX_SOURCE_LEN \ 31 | --max_target_length $MAX_TARGET_LEN \ 32 | --per_device_train_batch_size $DEV_BATCH_SIZE \ 33 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \ 34 | --max_steps $MAX_STEP \ 35 | --logging_steps 1 \ 36 | --save_steps $SAVE_INTERVAL \ 37 | --learning_rate $LR \ 38 | --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log 39 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/finetune_pt_multiturn.sh: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env bash 2 | 3 | set -ex 4 | 5 | PRE_SEQ_LEN=128 6 | LR=2e-2 7 | NUM_GPUS=1 8 | MAX_SEQ_LEN=2048 9 | DEV_BATCH_SIZE=1 10 | GRAD_ACCUMULARION_STEPS=16 11 | MAX_STEP=1000 12 | SAVE_INTERVAL=500 13 | 14 | DATESTR=`date +%Y%m%d-%H%M%S` 15 | RUN_NAME=tool_alpaca_pt 16 | 17 | BASE_MODEL_PATH=THUDM/chatglm3-6b 18 | DATASET_PATH=formatted_data/tool_alpaca.jsonl 19 | OUTPUT_DIR=output/${RUN_NAME}-${DATESTR}-${PRE_SEQ_LEN}-${LR} 20 | 21 | mkdir -p $OUTPUT_DIR 22 | 23 | torchrun --standalone --nnodes=1 --nproc_per_node=$NUM_GPUS finetune.py \ 24 | --train_format multi-turn \ 25 | --train_file $DATASET_PATH \ 26 | --max_seq_length $MAX_SEQ_LEN \ 27 | --preprocessing_num_workers 1 \ 28 | --model_name_or_path $BASE_MODEL_PATH \ 29 | --output_dir $OUTPUT_DIR \ 30 | --per_device_train_batch_size $DEV_BATCH_SIZE \ 31 | --gradient_accumulation_steps $GRAD_ACCUMULARION_STEPS \ 32 | --max_steps $MAX_STEP \ 33 | --logging_steps 1 \ 34 | --save_steps $SAVE_INTERVAL \ 35 | --learning_rate $LR \ 36 | --pre_seq_len $PRE_SEQ_LEN 2>&1 | tee ${OUTPUT_DIR}/train.log 37 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/format_advertise_gen.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import json 4 | from collections import Counter 5 | from argparse import ArgumentParser 6 | import os 7 | 8 | # parser = ArgumentParser() 9 | # parser.add_argument("--path", type=str, required=True) 10 | # 11 | # args = parser.parse_args() 12 | 13 | with open("//data/output.json") as f: 14 | # data = [json.loads(line) for line in f] 15 | data = [] 16 | line = f.read(1) 17 | print(line) 18 | while line: 19 | if line == "{": 20 | j = "" 21 | while line != "}": 22 | if line != '\n': 23 | j += line 24 | line = f.read(1) 25 | j += "}" 26 | print(j) 27 | data += [json.loads(j)] 28 | line = f.read(1) 29 | for x in data: 30 | print(x['output']) 31 | train_examples = [{ 32 | "prompt": x['instruction']+x['input'], 33 | "response": x['output'] 34 | } for x in data] 35 | 36 | os.makedirs("../formatted_data", exist_ok=True) 37 | 38 | with open("../formatted_data/output.jsonl", "w") as f: 39 | for e in train_examples: 40 | f.write(json.dumps(e, ensure_ascii=False) + "\n") 41 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/scripts/format_tool_alpaca.py: -------------------------------------------------------------------------------- 1 | #! 
/usr/bin/env python 2 | 3 | import json 4 | from collections import Counter 5 | from argparse import ArgumentParser 6 | import os 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument("--path", type=str, required=True) 10 | 11 | args = parser.parse_args() 12 | 13 | with open(args.path) as f: 14 | data = json.load(f) 15 | 16 | train_examples = [] 17 | err_count = 0 18 | for setting in data: 19 | api_desc = [setting["NLDocumentation"]] 20 | for instance in setting["Instances"]: 21 | try: 22 | conv = [{ 23 | "role": "user", 24 | "content": instance['input'], 25 | }] 26 | for step in instance['intermediate_steps']: 27 | tool_name, params, react = step[0] 28 | step_thought = react.split("Action:")[0].strip() 29 | observation = step[1] 30 | conv.append({ 31 | "role": "assistant", 32 | "content": step_thought, 33 | }) 34 | conv.append({ 35 | "role": "tool", 36 | "name": tool_name, 37 | "parameters": json.loads(params), 38 | "observation": observation, 39 | }) 40 | conv.append({ 41 | "role": "assistant", 42 | "content": instance['Final Thought'] + "\n" + instance['output'], 43 | }) 44 | except: 45 | err_count += 1 46 | else: 47 | train_examples.append({ 48 | "tools": api_desc, 49 | "conversations": conv 50 | }) 51 | 52 | print("err_count:", err_count) 53 | print("train_examples:", len(train_examples)) 54 | print("conversation distribution:", Counter([len(e["conversations"]) for e in train_examples])) 55 | 56 | os.makedirs("../formatted_data", exist_ok=True) 57 | 58 | with open("../formatted_data/tool_alpaca.jsonl", "w") as f: 59 | for e in train_examples: 60 | f.write(json.dumps(e, ensure_ascii=False) + "\n") -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/finetune_demo/trainer.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2020-present the HuggingFace Inc. team. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | """ 16 | The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task. 17 | """ 18 | import os 19 | from typing import Optional 20 | from transformers import Trainer 21 | 22 | import torch 23 | from transformers.modeling_utils import PreTrainedModel, unwrap_model 24 | from transformers.utils import logging 25 | 26 | logger = logging.get_logger(__name__) 27 | 28 | WEIGHTS_NAME = "pytorch_model.bin" 29 | TRAINING_ARGS_NAME = "training_args.bin" 30 | 31 | 32 | class PrefixTrainer(Trainer): 33 | def __init__(self, *args, save_changed=False, **kwargs): 34 | self.save_changed = save_changed 35 | super().__init__(*args, **kwargs) 36 | 37 | def _save(self, output_dir: Optional[str] = None, state_dict=None): 38 | # If we are executing this function, we are the process zero, so we don't check for that. 39 | output_dir = output_dir if output_dir is not None else self.args.output_dir 40 | os.makedirs(output_dir, exist_ok=True) 41 | logger.info(f"Saving model checkpoint to {output_dir}") 42 | # Save a trained model and configuration using `save_pretrained()`. 
43 | # They can then be reloaded using `from_pretrained()` 44 | if not isinstance(self.model, PreTrainedModel): 45 | if isinstance(unwrap_model(self.model), PreTrainedModel): 46 | if state_dict is None: 47 | state_dict = self.model.state_dict() 48 | unwrap_model(self.model).save_pretrained(output_dir, state_dict=state_dict) 49 | else: 50 | logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") 51 | if state_dict is None: 52 | state_dict = self.model.state_dict() 53 | torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) 54 | else: 55 | if self.save_changed: 56 | print("Saving PrefixEncoder") 57 | state_dict = self.model.state_dict() 58 | filtered_state_dict = {} 59 | for k, v in self.model.named_parameters(): 60 | if v.requires_grad: 61 | filtered_state_dict[k] = state_dict[k] 62 | self.model.save_pretrained(output_dir, state_dict=filtered_state_dict) 63 | else: 64 | print("Saving the whole model") 65 | self.model.save_pretrained(output_dir, state_dict=state_dict) 66 | if self.tokenizer is not None: 67 | self.tokenizer.save_pretrained(output_dir) 68 | 69 | # Good practice: save your training arguments together with the trained model 70 | torch.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME)) 71 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/openai_api/openai_api.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Implements API for ChatGLM3-6B in OpenAI's format. (https://platform.openai.com/docs/api-reference/chat) 3 | # Usage: python openai_api.py 4 | # Visit http://localhost:8000/docs for documents. 
5 | 6 | # In OpenAI's API, max_tokens corresponds to HuggingFace's max_new_tokens, not max_length. 7 | # For example, setting max_tokens = 8192 for the 6B model raises an error, because after subtracting the history and the prompt the model cannot output that many tokens. 8 | 9 | import time 10 | from contextlib import asynccontextmanager 11 | from typing import List, Literal, Optional, Union 12 | 13 | import torch 14 | import uvicorn 15 | from fastapi import FastAPI, HTTPException 16 | from fastapi.middleware.cors import CORSMiddleware 17 | from loguru import logger 18 | from pydantic import BaseModel, Field 19 | from sse_starlette.sse import EventSourceResponse 20 | from transformers import AutoTokenizer, AutoModel 21 | 22 | from utils import process_response, generate_chatglm3, generate_stream_chatglm3 23 | 24 | 25 | @asynccontextmanager 26 | async def lifespan(app: FastAPI): # collects GPU memory 27 | yield 28 | if torch.cuda.is_available(): 29 | torch.cuda.empty_cache() 30 | torch.cuda.ipc_collect() 31 | 32 | 33 | app = FastAPI(lifespan=lifespan) 34 | 35 | app.add_middleware( 36 | CORSMiddleware, 37 | allow_origins=["*"], 38 | allow_credentials=True, 39 | allow_methods=["*"], 40 | allow_headers=["*"], 41 | ) 42 | 43 | 44 | class ModelCard(BaseModel): 45 | id: str 46 | object: str = "model" 47 | created: int = Field(default_factory=lambda: int(time.time())) 48 | owned_by: str = "owner" 49 | root: Optional[str] = None 50 | parent: Optional[str] = None 51 | permission: Optional[list] = None 52 | 53 | 54 | class ModelList(BaseModel): 55 | object: str = "list" 56 | data: List[ModelCard] = [] 57 | 58 | 59 | class FunctionCallResponse(BaseModel): 60 | name: Optional[str] = None 61 | arguments: Optional[str] = None 62 | 63 | 64 | class ChatMessage(BaseModel): 65 | role: Literal["user", "assistant", "system", "function"] 66 | content: str = None 67 | name: Optional[str] = None 68 | function_call: Optional[FunctionCallResponse] = None 69 | 70 | 71 | class DeltaMessage(BaseModel): 72 | role: Optional[Literal["user", "assistant", "system"]] = None 73 | content: Optional[str] = None 74
| function_call: Optional[FunctionCallResponse] = None 75 | 76 | class ChatCompletionRequest(BaseModel): 77 | model: str 78 | messages: List[ChatMessage] 79 | temperature: Optional[float] = 0.8 80 | top_p: Optional[float] = 0.8 81 | max_tokens: Optional[int] = None 82 | stream: Optional[bool] = False 83 | functions: Optional[Union[dict, List[dict]]] = None 84 | # Additional parameters 85 | repetition_penalty: Optional[float] = 1.1 86 | 87 | 88 | class ChatCompletionResponseChoice(BaseModel): 89 | index: int 90 | message: ChatMessage 91 | finish_reason: Literal["stop", "length", "function_call"] 92 | 93 | 94 | class ChatCompletionResponseStreamChoice(BaseModel): 95 | index: int 96 | delta: DeltaMessage 97 | finish_reason: Optional[Literal["stop", "length", "function_call"]] 98 | 99 | 100 | class UsageInfo(BaseModel): 101 | prompt_tokens: int = 0 102 | total_tokens: int = 0 103 | completion_tokens: Optional[int] = 0 104 | 105 | 106 | class ChatCompletionResponse(BaseModel): 107 | model: str 108 | object: Literal["chat.completion", "chat.completion.chunk"] 109 | choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]] 110 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 111 | usage: Optional[UsageInfo] = None 112 | 113 | 114 | @app.get("/v1/models", response_model=ModelList) 115 | async def list_models(): 116 | model_card = ModelCard(id="chatglm3-6b") 117 | return ModelList(data=[model_card]) 118 | 119 | 120 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse) 121 | async def create_chat_completion(request: ChatCompletionRequest): 122 | global model, tokenizer 123 | 124 | if len(request.messages) < 1 or request.messages[-1].role == "assistant": 125 | raise HTTPException(status_code=400, detail="Invalid request") 126 | 127 | gen_params = dict( 128 | messages=request.messages, 129 | temperature=request.temperature, 130 | top_p=request.top_p, 131 | max_tokens=request.max_tokens or 1024, 132 | 
echo=False, 133 | stream=request.stream, 134 | repetition_penalty=request.repetition_penalty, 135 | functions=request.functions, 136 | ) 137 | 138 | logger.debug(f"==== request ====\n{gen_params}") 139 | 140 | if request.stream: 141 | generate = predict(request.model, gen_params) 142 | return EventSourceResponse(generate, media_type="text/event-stream") 143 | 144 | response = generate_chatglm3(model, tokenizer, gen_params) 145 | usage = UsageInfo() 146 | 147 | function_call, finish_reason = None, "stop" 148 | if request.functions: 149 | try: 150 | function_call = process_response(response["text"], use_tool=True) 151 | except: 152 | logger.warning("Failed to parse tool call") 153 | 154 | if isinstance(function_call, dict): 155 | finish_reason = "function_call" 156 | function_call = FunctionCallResponse(**function_call) 157 | 158 | message = ChatMessage( 159 | role="assistant", 160 | content=response["text"], 161 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 162 | ) 163 | 164 | choice_data = ChatCompletionResponseChoice( 165 | index=0, 166 | message=message, 167 | finish_reason=finish_reason, 168 | ) 169 | 170 | task_usage = UsageInfo.model_validate(response["usage"]) 171 | for usage_key, usage_value in task_usage.model_dump().items(): 172 | setattr(usage, usage_key, getattr(usage, usage_key) + usage_value) 173 | 174 | return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage) 175 | 176 | 177 | async def predict(model_id: str, params: dict): 178 | global model, tokenizer 179 | 180 | choice_data = ChatCompletionResponseStreamChoice( 181 | index=0, 182 | delta=DeltaMessage(role="assistant"), 183 | finish_reason=None 184 | ) 185 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 186 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 187 | 188 | previous_text = "" 189 | for new_response in 
generate_stream_chatglm3(model, tokenizer, params): 190 | decoded_unicode = new_response["text"] 191 | delta_text = decoded_unicode[len(previous_text):] 192 | previous_text = decoded_unicode 193 | 194 | finish_reason = new_response["finish_reason"] 195 | if len(delta_text) == 0 and finish_reason != "function_call": 196 | continue 197 | 198 | function_call = None 199 | if finish_reason == "function_call": 200 | try: 201 | function_call = process_response(decoded_unicode, use_tool=True) 202 | except Exception: 203 | logger.warning("Failed to parse tool call") 204 | 205 | if isinstance(function_call, dict): 206 | function_call = FunctionCallResponse(**function_call) 207 | 208 | delta = DeltaMessage( 209 | content=delta_text, 210 | role="assistant", 211 | function_call=function_call if isinstance(function_call, FunctionCallResponse) else None, 212 | ) 213 | 214 | choice_data = ChatCompletionResponseStreamChoice( 215 | index=0, 216 | delta=delta, 217 | finish_reason=finish_reason 218 | ) 219 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 220 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 221 | 222 | choice_data = ChatCompletionResponseStreamChoice( 223 | index=0, 224 | delta=DeltaMessage(), 225 | finish_reason="stop" 226 | ) 227 | chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk") 228 | yield "{}".format(chunk.model_dump_json(exclude_unset=True)) 229 | yield '[DONE]' 230 | 231 | 232 | if __name__ == "__main__": 233 | 234 | model_path = "/home/admin1/桌面/chatglm3-6b" 235 | tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) 236 | model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda() 237 | 238 | # Multi-GPU support: replace the line above with the two lines below, and set num_gpus to the number of GPUs you actually have 239 | # from utils import load_model_on_gpus 240 | # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2) 241 | model = model.eval() 242 | 243 | uvicorn.run(app, host='0.0.0.0', 
port=8000, workers=1) 244 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/openai_api/openai_api_request.py: -------------------------------------------------------------------------------- 1 | # Test the response with curl 2 | # curl -X POST "http://127.0.0.1:8000/v1/chat/completions" \ 3 | # -H "Content-Type: application/json" \ 4 | # -d "{\"model\": \"chatglm3-6b\", \"messages\": [{\"role\": \"system\", \"content\": \"You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.\"}, {\"role\": \"user\", \"content\": \"你好,给我讲一个故事,大概100字\"}], \"stream\": false, \"max_tokens\": 100, \"temperature\": 0.8, \"top_p\": 0.8}" 5 | 6 | # Test the response with Python 7 | import requests 8 | import json 9 | 10 | base_url = "http://127.0.0.1:8000" # address of the local deployment, or the API address you use to reach the model 11 | 12 | def create_chat_completion(model, messages, use_stream=False): 13 | data = { 14 | "model": model, # model name 15 | "messages": messages, # conversation history 16 | "stream": use_stream, # whether to stream the response 17 | "max_tokens": 100, # maximum number of tokens to generate 18 | "temperature": 0.8, # sampling temperature 19 | "top_p": 0.8, # nucleus-sampling probability 20 | } 21 | 22 | response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream) 23 | if response.status_code == 200: 24 | if use_stream: 25 | # handle the streaming response 26 | for line in response.iter_lines(): 27 | if line: 28 | decoded_line = line.decode('utf-8')[6:] 29 | try: 30 | response_json = json.loads(decoded_line) 31 | content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "") 32 | print(content) 33 | except json.JSONDecodeError: 34 | print("Special Token:", decoded_line) 35 | else: 36 | # handle the non-streaming response 37 | decoded_line = response.json() 38 | print(decoded_line) 39 | content = decoded_line.get("choices", [{}])[0].get("message", {}).get("content", "") 40 | print(content) 41 | else: 42 | print("Error:", response.status_code) 43 | return None 44 | 45 | 46 | if __name__ == "__main__": 47 | chat_messages = [ 48 | { 49 | "role": "system", 
50 | "content": "You are ChatGLM3, a large language model trained by Zhipu.AI. Follow the user's instructions carefully. Respond using markdown.", 51 | }, 52 | { 53 | "role": "user", 54 | "content": "你好,给我讲一个故事,大概100字" 55 | }, 56 | { 57 | "role":"assistant", 58 | "content":"从前,有一个美丽的村庄,村子里的居民过着和谐的生活。有一天,村子里来了一只可爱的小狗,它一跃成为村民们的好朋友。小狗每天都会陪伴着大家,带给他们无尽的欢乐。无论是儿童还是老人,都为小狗的来到感到高兴。渐渐地,小狗成了村子的象征,它象征着友谊、忠诚和美好。" 59 | }, 60 | { 61 | "role": "user", 62 | "content": "能把这个故事继续写下去吗" 63 | } 64 | ] 65 | create_chat_completion("chatglm3-6b", chat_messages, use_stream=False) 66 | 67 | 68 | -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/openai_api/requirements.txt: -------------------------------------------------------------------------------- 1 | openai>=1.3.0 2 | pydantic>=2.5.1 -------------------------------------------------------------------------------- /Ziwei-Chatglm3-6B/openai_api/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import json 4 | import torch 5 | from torch.nn import Module 6 | from transformers import PreTrainedModel, PreTrainedTokenizer 7 | from transformers import AutoModel 8 | from transformers.generation.logits_process import LogitsProcessor 9 | from typing import Dict, Union, Optional, Tuple 10 | 11 | 12 | def auto_configure_device_map(num_gpus: int) -> Dict[str, int]: 13 | # transformer.word_embeddings takes 1 layer 14 | # transformer.final_layernorm and lm_head take 1 layer 15 | # transformer.layers takes 28 layers 16 | # 30 layers in total are distributed across num_gpus GPUs 17 | num_trans_layers = 28 18 | per_gpu_layers = 30 / num_gpus 19 | 20 | # bugfix: on Linux, the weight and input passed to torch.embedding end up on different devices, causing a RuntimeError 21 | # on Windows, model.device is set to transformer.word_embeddings.device 22 | # on Linux, model.device is set to lm_head.device 23 | # when chat or stream_chat is called, input_ids is moved to model.device 24 | # if transformer.word_embeddings.device differs from model.device, a RuntimeError is raised 25 | # 
so transformer.word_embeddings, transformer.final_layernorm, and lm_head are all placed on the first GPU 26 | # This file is derived from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py 27 | # with only minor changes here to support ChatGLM3 28 | device_map = { 29 | 'transformer.embedding.word_embeddings': 0, 30 | 'transformer.encoder.final_layernorm': 0, 31 | 'transformer.output_layer': 0, 32 | 'transformer.rotary_pos_emb': 0, 33 | 'lm_head': 0 34 | } 35 | 36 | used = 2 37 | gpu_target = 0 38 | for i in range(num_trans_layers): 39 | if used >= per_gpu_layers: 40 | gpu_target += 1 41 | used = 0 42 | assert gpu_target < num_gpus 43 | device_map[f'transformer.encoder.layers.{i}'] = gpu_target 44 | used += 1 45 | 46 | return device_map 47 | 48 | 49 | def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2, 50 | device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module: 51 | if num_gpus < 2 and device_map is None: 52 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda() 53 | else: 54 | from accelerate import dispatch_model 55 | 56 | model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half() 57 | 58 | if device_map is None: 59 | device_map = auto_configure_device_map(num_gpus) 60 | 61 | model = dispatch_model(model, device_map=device_map) 62 | 63 | return model 64 | 65 | 66 | class InvalidScoreLogitsProcessor(LogitsProcessor): 67 | def __call__( 68 | self, input_ids: torch.LongTensor, scores: torch.FloatTensor 69 | ) -> torch.FloatTensor: 70 | if torch.isnan(scores).any() or torch.isinf(scores).any(): 71 | scores.zero_() 72 | scores[..., 5] = 5e4 73 | return scores 74 | 75 | 76 | def process_response(output: str, use_tool: bool = False) -> Union[str, dict]: 77 | content = "" 78 | for response in output.split("<|assistant|>"): 79 | metadata, content = response.split("\n", maxsplit=1) 80 | if not metadata.strip(): 81 | content = content.strip() 82 | content = content.replace("[[训练时间]]", "2023年") 83 | else: 84 | if use_tool: 85 | 
content = "\n".join(content.split("\n")[1:-1]) 86 | 87 | def tool_call(**kwargs): 88 | return kwargs 89 | 90 | parameters = eval(content) 91 | content = { 92 | "name": metadata.strip(), 93 | "arguments": json.dumps(parameters, ensure_ascii=False) 94 | } 95 | else: 96 | content = { 97 | "name": metadata.strip(), 98 | "content": content 99 | } 100 | return content 101 | 102 | 103 | @torch.inference_mode() 104 | def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict): 105 | messages = params["messages"] 106 | functions = params["functions"] 107 | temperature = float(params.get("temperature", 1.0)) 108 | repetition_penalty = float(params.get("repetition_penalty", 1.0)) 109 | top_p = float(params.get("top_p", 1.0)) 110 | max_new_tokens = int(params.get("max_tokens", 256)) 111 | echo = params.get("echo", True) 112 | messages = process_chatglm_messages(messages, functions=functions) 113 | query, role = messages[-1]["content"], messages[-1]["role"] 114 | 115 | inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role) 116 | inputs = inputs.to(model.device) 117 | input_echo_len = len(inputs["input_ids"][0]) 118 | 119 | if input_echo_len >= model.config.seq_length: 120 | print(f"Input length larger than {model.config.seq_length}") 121 | 122 | eos_token_id = [ 123 | tokenizer.eos_token_id, 124 | tokenizer.get_command("<|user|>"), 125 | ] 126 | 127 | gen_kwargs = { 128 | "max_new_tokens": max_new_tokens, 129 | "do_sample": True if temperature > 1e-5 else False, 130 | "top_p": top_p, 131 | "repetition_penalty": repetition_penalty, 132 | "logits_processor": [InvalidScoreLogitsProcessor()], 133 | } 134 | if temperature > 1e-5: 135 | gen_kwargs["temperature"] = temperature 136 | 137 | total_len = 0 138 | for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs): 139 | total_ids = total_ids.tolist()[0] 140 | total_len = len(total_ids) 141 | if echo: 142 | output_ids = total_ids[:-1] 143 | 
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        # Skip chunks that end in an incomplete multi-byte character
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only last stream result contains finish_reason, we set finish_reason as stop
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()


def process_chatglm_messages(messages, functions=None):
    _messages = messages
    messages = []

    if functions:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": functions
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )

        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})
    return messages


def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Drain the stream and return only the final result
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue

            break

    return reply, stop_found

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/requirements.txt:
--------------------------------------------------------------------------------
protobuf
transformers>=4.30.2
cpm_kernels
torch>=2.0
gradio~=3.39
sentencepiece
accelerate
sse-starlette
streamlit>=1.24.0
fastapi>=0.95.1
uvicorn~=0.24.0
loguru~=0.7.2
# "chromadb" is the PyPI name of the Chroma vector store that langchain's Chroma wrapper imports
chromadb
openai~=1.3.4
langchain~=0.0.352
astunparse~=1.6.3
pydantic~=2.5.1
requests~=2.31.0

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/useChroma/build_chroma.py:
--------------------------------------------------------------------------------
import os
import sys
from langchain.document_loaders import DirectoryLoader
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import RecursiveCharacterTextSplitter
import openai
from langchain.document_loaders import PyPDFLoader

sys.path.append('../..')
# Read the key from the environment instead of hard-coding a secret in the repository
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Directory containing the PDF versions of the books to import
path = "/home/admin1/桌面/data_pdf"
embedding_function = SentenceTransformerEmbeddings(model_name="shibing624/text2vec-base-chinese")


def build_chromadb():
    if path:
        loader = DirectoryLoader(path, glob="**/*.pdf", loader_cls=PyPDFLoader)
        pages = loader.load()
        # print(len(pages))
        # print(page.page_content[0:500])
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=200,
            chunk_overlap=10
        )
        splits = text_splitter.split_documents(pages)
        chroma = Chroma.from_documents(splits, embedding_function, persist_directory="/home/admin1/桌面/chromaDb")
        chroma.persist()


if __name__ == '__main__':
    build_chromadb()
    db = Chroma(persist_directory="/home/admin1/桌面/chromaDb", embedding_function=embedding_function)
    # Smoke test: query the store for "火命" (fire destiny)
    docs = db.similarity_search("火命", k=1)
    print(docs)

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/with_prompt/prompt_web.py:
--------------------------------------------------------------------------------
import streamlit as st
import torch
import copy
from transformers import AutoModel, AutoTokenizer

in_response = []  # intermediate answers produced for the current question

# Set the page title, icon and layout
st.set_page_config(
    page_title="ChatGLM3-Ziwei 演示",
    page_icon=":robot:",
    layout="wide"
)

# Set to a model ID or a local directory path
model_path = "/home/admin1/桌面/chatglm3-6b"

@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # Multi-GPU support: replace the line above with the two lines below
    # and set num_gpus to the number of GPUs you actually have
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()
    return tokenizer, model

# Load the ChatGLM3 model and tokenizer
tokenizer, model = get_model()

# Initialise the history and past key values
if "history" not in st.session_state:
    st.session_state.history = []
if "past_key_values" not in st.session_state:
    st.session_state.past_key_values = None

# Set max_length, top_p and temperature
max_length = st.sidebar.slider("max_length", 0, 32768, 8192, step=1)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)

# Clear the conversation history
buttonClean = st.sidebar.button("清理会话历史", key="clean")
if buttonClean:
    st.session_state.history = []
    st.session_state.past_key_values = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Render the chat history
for i, message in enumerate(st.session_state.history):
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            st.markdown(message["content"])
    else:
        with st.chat_message(name="assistant", avatar="assistant"):
            st.markdown(message["content"])

# Input and output placeholders
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()

# Prompt helper: replace the first i occurrences of placeholder a with b
def template_change(template, a, b, i):
    template1 = template.replace(a, b, i)
    return template1

def get_classify(question, history, past_key_values):
    # Prompt (Chinese): "Decide which divination category or topic the question
    # belongs to. Output only the category, i.e. one or a few words, not a paragraph."
    template_string = """
    请判断下列问题属于占卜的哪一种分类或主题
    注意:你只需要输出你判断的主题分类,即输出一个或几个词语,而不是一段话。
    ###问题:{question}
    """
    # Fill in the variables
    prompt = template_change(template_string, '{question}', question, 1)
    #print('prompt 1: '+prompt)
    i = 0
    # Consume the stream; only the final value of classify is used
    for classify, history, past_key_values in model.stream_chat(
        tokenizer,
        prompt,
        history,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        i += 1
    print("1: " + classify)

    return classify


# Generate one candidate answer and store it in in_response
def prompt_main(question, history, past_key_values, theme, num):
    global in_response
    # Define the template string
    # (unused template line) ###主题:{theme}
    # Prompt (Chinese): "You are now a fortune teller. Talk with me based on the
    # question below, and explain your answer." followed by the question and topic.
    template_string = """
    你现在是一位占卜师,你需要根据下面的问题与我对话,回答需要解释答案。
    ###问题: {question}
    对话主题:{theme}
    """
    # (unused prompt sentence) If answering requires information from the asker,
    # do not make it up; ask for the relevant information instead.
    # Build the prompt from the template string
    # Fill in the variables (theme=new_question)
    prompt0 = template_change(template_string, '{question}', question, 1)
    prompt = template_change(prompt0, '{theme}', theme, 1)

    i = 0
    for inresponse, history, past_key_values in model.stream_chat(
        tokenizer,
        prompt,
        history,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        i += 1
    in_response.append(inresponse)
    print("2:" + in_response[num])

# Merge the candidate answers into a single reply
def prompt_merge(num, question, history, past_key_values):
    global in_response
    reply = ''
    for i in range(num):
        # Label each candidate as "第 i 段文字" (passage i)
        reply += '\n第' + str(i + 1) + '段文字: '
        reply += in_response[i]
    #print('reply : ' + reply)
    # Define the template string
    # Prompt (Chinese): "Rewrite and merge the following {num} passages into one
    # fluent passage. Do not open with 'The rewritten text is as follows:'. Do not
    # repeat similar content. Cover the distinct content of all {num} passages."
    template_string = """
    请把下面{num}段文字改写合并为一段流畅的文字
    回答时开头不要出现【改写后的文字如下:】
    整合后的文字不能重复出现相似的内容
    整合后的文字应该尽量包含{num}段文字里不同的内容
    ###{reply}
    """
    # Fill in the variables ({num} appears twice in the template)
    prompt0 = template_change(template_string, '{num}', str(num), 2)
    prompt = template_change(prompt0, '{reply}', reply, 1)
    #print('prompt 3: ' + prompt)
    i = 0
    for out_response, history, past_key_values in model.stream_chat(
        tokenizer,
        prompt,
        history,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        # Stream the merged answer into the assistant placeholder
        message_placeholder.markdown(out_response)
    print("3:" + out_response)
    # for i in range(len(history)):
    #     print("history ",history[i])
    return history[-1]

# Get the user input
prompt_text = st.chat_input("请输入您的问题")

# If the user entered something, generate a reply
if prompt_text:
    input_placeholder.markdown(prompt_text)
    history = st.session_state.history
    past_key_values = st.session_state.past_key_values
    # Keep a pristine copy so every intermediate call starts from the same history
    history1 = copy.deepcopy(st.session_state.history)
    past_key_values1 = st.session_state.past_key_values

    num = 4  # number of candidate answers to generate and merge
    theme = get_classify(prompt_text, history, past_key_values)
    history = copy.deepcopy(history1)
    for i in range(num):
        prompt_main(prompt_text, history, past_key_values, theme, i)
        history = copy.deepcopy(history1)
    history.append(prompt_merge(num, prompt_text, history, past_key_values))
    # for i in range(len(history)):
    #     print("history1 ",history[i])
    h = []
    i = 0
    # Run the raw question once more against an empty history; only the
    # user turn h[-2] is kept, and it overwrites history[-2] below
    for response, h, past_key_values in model.stream_chat(
        tokenizer,
        prompt_text,
        h,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        i += 1
    history[-2] = h[-2]
    # for i in range(len(history)):
    #     print('history: ' , history[i])


    # for response, history, past_key_values in model.stream_chat(
    #     tokenizer,
    #     prompt_text,
    #     history,
    #     past_key_values=past_key_values,
    #     max_length=max_length,
    #     top_p=top_p,
    #     temperature=temperature,
    #     return_past_key_values=True,
    # ):
    #     message_placeholder.markdown(response)

    # Update the history and past key values
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values
# streamlit run prompt_web.py
# Example question: 请你给我算算今年运势 ("Please tell me my fortune for this year")

--------------------------------------------------------------------------------
/Ziwei-Chatglm3-6B/with_prompt/utils.py:
--------------------------------------------------------------------------------
import os
from typing import Dict, Union, Optional
from torch.nn import Module
from transformers import AutoModel


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings counts as 1 layer
    # transformer.final_layernorm and lm_head count as 1 layer
    # transformer.layers counts as 28 layers
    # 30 layers in total are distributed across num_gpus cards
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: on Linux, the weight and input passed to torch.embedding end up on
    # different devices, which raises a RuntimeError
    # On Windows, model.device is set to transformer.word_embeddings.device
    # On Linux, model.device is set to lm_head.device
    # When chat or stream_chat is called, input_ids is placed on model.device
    # If transformer.word_embeddings.device and model.device differ, a RuntimeError results
    # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head
    # are all placed on the first card
    # This file comes from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
    # Only minor changes are made here to support ChatGLM3
    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map


def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    if num_gpus < 2 and device_map is None:
        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
    else:
        from accelerate import dispatch_model

        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()

        if device_map is None:
            device_map = auto_configure_device_map(num_gpus)

        model = dispatch_model(model, device_map=device_map)

    return model

--------------------------------------------------------------------------------
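The layer split that `auto_configure_device_map` in `with_prompt/utils.py` produces can be checked without loading a model or importing torch. The following is a minimal sketch, assuming the same greedy rule as that function (30 logical layers, two of them pinned to GPU 0); the helper name `layer_split` is hypothetical and not part of the repository.

```python
from collections import Counter
from typing import Dict


def layer_split(num_gpus: int, num_trans_layers: int = 28) -> Dict[int, int]:
    """Return {gpu_index: transformer_layer_count} (hypothetical helper).

    Mirrors the loop in auto_configure_device_map: the embedding and the
    output head count as two extra layers that always live on GPU 0.
    """
    per_gpu_layers = (num_trans_layers + 2) / num_gpus
    used = 2        # GPU 0 starts out holding the two pinned layers
    gpu_target = 0
    counts: Counter = Counter()
    for _ in range(num_trans_layers):
        if used >= per_gpu_layers:
            # Current GPU is full: move on to the next one
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        counts[gpu_target] += 1
        used += 1
    return dict(counts)


if __name__ == "__main__":
    for n in (1, 2, 4):
        print(n, layer_split(n))
```

With two GPUs this gives 13 transformer layers on GPU 0 (plus the two pinned layers) and 15 on GPU 1, so the logical load is balanced even though the raw layer counts differ.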