├── .gitattributes
├── .github
    ├── ISSUE_TEMPLATE
    │   ├── ISSUE_TEMPLATE_EN.yml
    │   ├── ISSUE_TEMPLATE_ZH.yml
    │   └── config.yml
    └── workflows
    │   └── stale.yml
├── .gitignore
├── CITATION.cff
├── LICENSE
├── README.md
├── README_EN.md
├── examples
    ├── README.md
    ├── alpaca-2-13b.md
    └── alpaca-2-7b.md
├── notebooks
    └── gradio_web_demo.ipynb
├── pics
    ├── banner.png
    ├── models.png
    └── screencast.gif
├── prompts
    ├── README.md
    ├── alpaca-2-long.txt
    └── alpaca-2.txt
├── requirements.txt
└── scripts
    ├── README.md
    ├── attn_and_long_ctx_patches.py
    ├── ceval
        ├── eval.py
        ├── evaluator.py
        ├── llama_evaluator.py
        └── subject_mapping.json
    ├── cmmlu
        ├── categories.py
        ├── eval.py
        ├── evaluator.py
        └── llama2_evaluator.py
    ├── inference
        ├── flash_attn_patch_for_inference.py
        ├── gradio_demo.py
        ├── inference_hf.py
        └── speculative_sample.py
    ├── langchain
        ├── doc.txt
        ├── langchain_qa.py
        └── langchain_sum.py
    ├── llama-cpp
        ├── README.md
        ├── chat.sh
        └── server_curl_example.sh
    ├── longbench
        ├── config
        │   ├── dataset2maxlen.json
        │   └── dataset2prompt.json
        ├── eval.py
        ├── metrics.py
        ├── pred_llama2.py
        └── requirements.txt
    ├── merge_llama2_with_chinese_lora_low_mem.py
    ├── openai_server_demo
        ├── README.md
        ├── README_vllm.md
        ├── openai_api_protocol.py
        ├── openai_api_protocol_vllm.py
        ├── openai_api_server.py
        └── openai_api_server_vllm.py
    ├── privategpt
        ├── README.md
        ├── privateGPT.py
        └── privateGPT_refine.py
    ├── tokenizer
        ├── special_tokens_map.json
        ├── tokenizer.model
        └── tokenizer_config.json
    └── training
        ├── build_dataset.py
        ├── ds_zero2_no_offload.json
        ├── peft
            ├── __init__.py
            ├── mapping.py
            ├── peft_model.py
            ├── tuners
            │   ├── __init__.py
            │   ├── lora.py
            │   ├── p_tuning.py
            │   ├── prefix_tuning.py
            │   └── prompt_tuning.py
            └── utils
            │   ├── __init__.py
            │   ├── adapters_utils.py
            │   ├── config.py
            │   ├── other.py
            │   └── save_and_load.py
        ├── run_clm_pt_with_peft.py
        ├── run_clm_sft_with_peft.py
        ├── run_pt.sh
        └── run_sft.sh

/.gitattributes:
--------------------------------------------------------------------------------
1 | 
notebooks/** linguist-vendored 2 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE_EN.yml: -------------------------------------------------------------------------------- 1 | name: English Issue Template 2 | description: For questions related to this project. We prioritize issues that provide relatively complete information. 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: 💡 For open discussions, please visit the [Discussion Space](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/discussions). Please do not use the Issues section for open discussions. Thank you. 8 | - type: checkboxes 9 | id: mustchecks 10 | attributes: 11 | label: Check before submitting issues 12 | description: Please check the following items before asking questions. Use the search function to find issues related to your problem. 13 | options: 14 | - label: Make sure to pull the latest code, as some issues and bugs have already been fixed. 15 | required: true 16 | - label: I have read the [Wiki](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki) and [FAQ section](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/FAQ) AND searched for similar issues, but did not find a similar problem or solution. 17 | required: true 18 | - label: Third-party plugin issues (e.g., [llama.cpp](https://github.com/ggerganov/llama.cpp), [LangChain](https://github.com/hwchase17/langchain), [text-generation-webui](https://github.com/oobabooga/text-generation-webui)) - we recommend also checking the corresponding project for solutions. 19 | required: true 20 | - type: dropdown 21 | id: question-type 22 | attributes: 23 | label: Type of Issue 24 | description: Please select the type of issue that best matches your problem 25 | options: 26 | - Download issue 27 | - Model conversion and merging 28 | - Model training and fine-tuning 29 | - Model inference 30 | - Model quantization and deployment 31 | - Performance issue 32 | - Other issues 33 | - type: dropdown 34 | id: 
model-type 35 | attributes: 36 | label: Base Model 37 | description: Please select the base model. If the issue involves multiple models, select the most relevant one and list all models in the main text. 38 | options: 39 | - Chinese-LLaMA-2 (7B/13B) 40 | - Chinese-Alpaca-2 (7B/13B) 41 | - Chinese-LLaMA-2-16K (7B/13B) 42 | - Chinese-Alpaca-2-16K (7B/13B) 43 | - Others 44 | - type: dropdown 45 | id: operating-system 46 | attributes: 47 | label: Operating System 48 | description: Please select your operating system 49 | options: 50 | - Windows 51 | - macOS 52 | - Linux 53 | - type: textarea 54 | id: question-detailed 55 | attributes: 56 | label: Describe your issue in detail 57 | description: Please describe your problem in as much detail as possible. **For code-related issues, please provide the complete command to reproduce the problem.** This will help us locate the issue quickly. 58 | value: | 59 | ``` 60 | # Please copy-and-paste your command here. 61 | ``` 62 | - type: textarea 63 | id: dependencies 64 | attributes: 65 | label: Dependencies (must be provided for code-related issues) 66 | description: Please provide the versions of common dependencies such as transformers, peft, and torch, e.g., via `pip list | grep -E 'transformers|peft|torch|sentencepiece|bitsandbytes'` 67 | value: | 68 | ``` 69 | # Please copy-and-paste your dependencies here. 70 | ``` 71 | - type: textarea 72 | id: logs 73 | attributes: 74 | label: Execution logs or screenshots 75 | description: Please provide logs as text (upload a file if the content is too long) or, alternatively, screenshots of the run. 76 | value: | 77 | ``` 78 | # Please copy-and-paste your logs here.
79 | ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/ISSUE_TEMPLATE_ZH.yml: -------------------------------------------------------------------------------- 1 | name: 中文提问模板 2 | description: 与本项目相关的问题提问,我们会优先查阅内容相对完整的issue。 3 | 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: 💡 开放式讨论请移步[讨论区](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/discussions),请勿以issue形式提问,谢谢。 8 | - type: checkboxes 9 | id: mustchecks 10 | attributes: 11 | label: 提交前必须检查以下项目 12 | description: 请在提问前检查以下项目,善用搜索功能查找与自己问题相关的issue。 13 | options: 14 | - label: 请确保使用的是仓库最新代码(git pull),一些问题已被解决和修复。 15 | required: true 16 | - label: 我已阅读[项目文档](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki)和[FAQ章节](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/常见问题)并且已在Issue中对问题进行了搜索,没有找到相似问题和解决方案。 17 | required: true 18 | - label: 第三方插件问题:例如[llama.cpp](https://github.com/ggerganov/llama.cpp)、[LangChain](https://github.com/hwchase17/langchain)、[text-generation-webui](https://github.com/oobabooga/text-generation-webui)等,同时建议到对应的项目中查找解决方案。 19 | required: true 20 | - type: dropdown 21 | id: question-type 22 | attributes: 23 | label: 问题类型 24 | description: 请选择最符合的问题类型 25 | options: 26 | - 下载问题 27 | - 模型转换和合并 28 | - 模型训练与精调 29 | - 模型推理 30 | - 模型量化和部署 31 | - 效果问题 32 | - 其他问题 33 | - type: dropdown 34 | id: model-type 35 | attributes: 36 | label: 基础模型 37 | description: 请提供问题涉及的具体模型。 38 | options: 39 | - Chinese-LLaMA-2 (7B/13B) 40 | - Chinese-Alpaca-2 (7B/13B) 41 | - Chinese-LLaMA-2-16K (7B/13B) 42 | - Chinese-Alpaca-2-16K (7B/13B) 43 | - Others 44 | - type: dropdown 45 | id: operating-system 46 | attributes: 47 | label: 操作系统 48 | description: 请提供操作系统类型 49 | options: 50 | - Windows 51 | - macOS 52 | - Linux 53 | - type: textarea 54 | id: question-detailed 55 | attributes: 56 | label: 详细描述问题 57 | description: 请尽量具体地描述遇到的问题,**代码程序类问题务必给出完整运行命令**,这将有助于快速定位问题所在。 58 | value: | 59 | ``` 60 | # 请在此处粘贴运行代码(请粘贴在本代码块里) 61 | ``` 62 | - type: textarea 63 
| id: dependencies 64 | attributes: 65 | label: 依赖情况(代码类问题务必提供) 66 | description: 请提供transformers, peft, torch等常规依赖库的版本:`pip list | grep -E 'transformers|peft|torch|sentencepiece|bitsandbytes'` 67 | value: | 68 | ``` 69 | # 请在此处粘贴依赖情况(请粘贴在本代码块里) 70 | ``` 71 | - type: textarea 72 | id: logs 73 | attributes: 74 | label: 运行日志或截图 75 | description: 请优先提供文本形式的log(过长内容请上传文件),粘贴内容放在markdown代码块。或者提供截图形式的运行记录。 76 | value: | 77 | ``` 78 | # 请在此处粘贴运行日志(请粘贴在本代码块里) 79 | ``` -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/config.yml: -------------------------------------------------------------------------------- 1 | blank_issues_enabled: false -------------------------------------------------------------------------------- /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | # This workflow warns and then closes issues and PRs that have had no activity for a specified amount of time. 2 | # 3 | # You can adjust the behavior by modifying this file. 4 | # For more information, see: 5 | # https://github.com/actions/stale 6 | name: Mark stale issues and pull requests 7 | 8 | on: 9 | schedule: 10 | - cron: '0 22 * * *' 11 | 12 | jobs: 13 | stale: 14 | 15 | runs-on: ubuntu-latest 16 | permissions: 17 | issues: write 18 | pull-requests: write 19 | 20 | steps: 21 | - uses: actions/stale@v8 22 | with: 23 | repo-token: ${{ secrets.GITHUB_TOKEN }} 24 | stale-issue-message: 'This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your consideration.' 25 | stale-pr-message: 'This PR has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your consideration.'
26 | stale-issue-label: 'stale' 27 | stale-pr-label: 'stale' 28 | operations-per-run: 500 29 | close-issue-message: 'Closing this issue since no updates have been observed. Feel free to reopen it if you need further assistance.' 30 | days-before-stale: 30 31 | days-before-close: 30 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | */.DS_Store 3 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find our resources useful, please cite our paper as follows." 3 | authors: 4 | - family-names: "Cui" 5 | given-names: "Yiming" 6 | orcid: "https://orcid.org/0000-0002-2452-375X" 7 | - family-names: "Yang" 8 | given-names: "Ziqing" 9 | - family-names: "Yao" 10 | given-names: "Xin" 11 | title: "Chinese LLaMA and Alpaca 2" 12 | version: 1.0 13 | date-released: 2023-07-28 14 | url: "https://github.com/ymcui/Chinese-LLaMA-Alpaca-2" 15 | preferred-citation: 16 | type: article 17 | authors: 18 | - family-names: "Cui" 19 | given-names: "Yiming" 20 | orcid: "https://orcid.org/0000-0002-2452-375X" 21 | - family-names: "Yang" 22 | given-names: "Ziqing" 23 | - family-names: "Yao" 24 | given-names: "Xin" 25 | title: "Efficient and Effective Text Encoding for Chinese LLaMA and Alpaca" 26 | journal: "arXiv pre-print" 27 | year: 2023 28 | url: "https://arxiv.org/abs/2304.08177" -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions.
9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. 
For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright 2023 Yiming Cui, Ziqing Yang, Xin Yao 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | ## 输出示例 2 | 3 | 本目录针对Chinese-Alpaca-2模型给出参考输出样例,其目的是帮助用户快速了解模型输出情况,同时也有助于排查下载的模型是否和预期输出一致。输出样本来自于模型在线对战题库(共10个类别),每个类别选择3道题进行展示。 4 | 5 | - [Chinese-Alpaca-2-7B输出样例](./alpaca-2-7b.md) 6 | - [Chinese-Alpaca-2-13B输出样例](./alpaca-2-13b.md) 7 | 8 | **📊 模型在线对战**:[http://llm-arena.ymcui.com](http://llm-arena.ymcui.com/) 9 | -------------------------------------------------------------------------------- /pics/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymcui/Chinese-LLaMA-Alpaca-2/2a334d1634c857a7f02f885026d02ac4b469479d/pics/banner.png -------------------------------------------------------------------------------- /pics/models.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymcui/Chinese-LLaMA-Alpaca-2/2a334d1634c857a7f02f885026d02ac4b469479d/pics/models.png 
-------------------------------------------------------------------------------- /pics/screencast.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ymcui/Chinese-LLaMA-Alpaca-2/2a334d1634c857a7f02f885026d02ac4b469479d/pics/screencast.gif -------------------------------------------------------------------------------- /prompts/README.md: -------------------------------------------------------------------------------- 1 | ## 系统指令 System Prompts 2 | 3 | ### alpaca-2.txt (default) 4 | 5 | 这个文件是训练时采用的默认系统指令,内容极简,因此回复长度上略短于一代Pro系列模型。 6 | 7 | This file is the default system prompt used in the SFT phase. Because it is minimal, responses may be slightly shorter than those of the first-generation Pro series models. 8 | 9 | ### alpaca-2-long.txt 10 | 11 | 这个文件是增加模型回复内容长度的系统指令示例,用户可根据实际情况自行参照修改。但建议保留最原始的`alpaca-2.txt`中的内容,在此基础上进行自定义系统指令的编写。 12 | 13 | This file is a sample system prompt designed to produce longer responses. Users can adapt it as needed. However, we suggest keeping the original content of `alpaca-2.txt` and building your customized prompt on top of it. 14 | -------------------------------------------------------------------------------- /prompts/alpaca-2-long.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。 -------------------------------------------------------------------------------- /prompts/alpaca-2.txt: -------------------------------------------------------------------------------- 1 | You are a helpful assistant.
你是一个乐于助人的助手。 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | peft==0.3.0 2 | torch==2.0.1 3 | transformers==4.35.0 4 | sentencepiece==0.1.99 5 | bitsandbytes==0.41.1 -------------------------------------------------------------------------------- /scripts/README.md: -------------------------------------------------------------------------------- 1 | # 代码与脚本 Code and Scripts 2 | 3 | ### training/ 4 | 5 | 预训练与指令精调代码,Wiki: 6 | 7 | - 预训练:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh) 8 | - 指令精调:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) 9 | 10 | Pre-training and instruction fine-tuning code, Wiki: 11 | 12 | - Pre-training: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_en 13 | - Instruction fine-tuning: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_en 14 | 15 | ### inference/ 16 | 17 | 使用🤗transformers进行推理,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_zh) 18 | 19 | Inference using 🤗transformers, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/inference_with_transformers_en 20 | 21 | ### openai_server_demo/ 22 | 23 | 使用fastapi实现的仿OPENAI API风格的服务器,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_zh) 24 | 25 | An OpenAI API-style server implemented with FastAPI, Wiki: [https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/api_calls_en) 26 | 27 | ### ceval/ 28 | 29 | 
C-Eval评测脚本,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_zh) 30 | 31 | Evaluation script for C-Eval, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/ceval_en 32 | 33 | ### cmmlu/ 34 | 35 | CMMLU评测脚本,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_zh) 36 | 37 | Evaluation script for CMMLU, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/cmmlu_en 38 | 39 | ### longbench/ 40 | 41 | LongBench评测脚本,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/longbench_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/longbench_zh) 42 | 43 | Evaluation script for LongBench, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/longbench_en 44 | 45 | ### llama-cpp/ 46 | 47 | llama.cpp启动脚本、server脚本,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh) 48 | 49 | Launch and server scripts for llama.cpp, Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en 50 | 51 | 52 | ### attn_and_long_ctx_patches.py 53 | 54 | Memory efficient attention补丁和NTK上下文拓展方法补丁。 55 | 56 | Patches for memory-efficient attention and NTK-based context-size scaling. 57 | 58 | ### merge_llama2_with_chinese_lora_low_mem.py 59 | 60 | 低资源版合并LLaMA-2/Alpaca-2 LoRA脚本,Wiki:[https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/manual_conversion_zh](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/manual_conversion_zh) 61 | 62 | Script for merging LLaMA-2/Alpaca-2 LoRA (low-resource version).
Wiki: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/manual_conversion_en 63 | 64 | ### tokenizer/ 65 | 66 | Chinese-LLaMA-2 & Chinese-Alpaca-2 tokenizer -------------------------------------------------------------------------------- /scripts/attn_and_long_ctx_patches.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from typing import Optional, Tuple, Union 4 | import transformers 5 | from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half 6 | import math 7 | 8 | try: 9 | from xformers import ops as xops 10 | except ImportError: 11 | xops = None 12 | print( 13 | "Xformers is not installed correctly. If you want to use memory_efficient_attention use the following command to install Xformers\npip install xformers." 14 | ) 15 | 16 | 17 | STORE_KV_BEFORE_ROPE = False 18 | USE_MEM_EFF_ATTENTION = False 19 | ALPHA = 1.0 20 | AUTO_COEFF = 1.0 21 | SCALING_FACTOR = None 22 | 23 | 24 | def apply_rotary_pos_emb_single(q, cos, sin, position_ids): 25 | # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
26 | cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] 27 | sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] 28 | cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 29 | sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] 30 | q_embed = (q * cos) + (rotate_half(q) * sin) 31 | return q_embed 32 | 33 | 34 | def xformers_forward( 35 | self, 36 | hidden_states: torch.Tensor, 37 | attention_mask: Optional[torch.Tensor] = None, 38 | position_ids: Optional[torch.LongTensor] = None, 39 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 40 | output_attentions: bool = False, 41 | use_cache: bool = False, 42 | padding_mask=None, 43 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 44 | bsz, q_len, _ = hidden_states.size() 45 | 46 | query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 47 | key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 48 | value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) 49 | 50 | kv_seq_len = key_states.shape[-2] 51 | past_kv_len = 0 52 | if past_key_value is not None: 53 | past_kv_len = past_key_value[0].shape[-2] 54 | kv_seq_len += past_kv_len 55 | 56 | if STORE_KV_BEFORE_ROPE is False: 57 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 58 | query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) 59 | # [bsz, nh, t, hd] 60 | 61 | if past_key_value is not None: 62 | # reuse k, v, self_attention 63 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 64 | value_states = torch.cat([past_key_value[1], value_states], dim=2) 65 | 66 | past_key_value = (key_states, value_states) if use_cache else None 67 | else: 68 | if past_key_value is not None: 69 | # reuse k, v, self_attention 70 | key_states = torch.cat([past_key_value[0], key_states], dim=2) 71 | value_states = 
torch.cat([past_key_value[1], value_states], dim=2) 72 | past_key_value = (key_states, value_states) if use_cache else None 73 | 74 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 75 | 76 | query_states = apply_rotary_pos_emb_single(query_states, cos, sin, position_ids) 77 | position_ids = torch.arange(kv_seq_len, dtype=torch.long, device=cos.device) 78 | position_ids = position_ids.unsqueeze(0).view(-1, kv_seq_len) 79 | key_states = apply_rotary_pos_emb_single(key_states, cos, sin, position_ids) 80 | 81 | pad_query = False 82 | if xops is not None and USE_MEM_EFF_ATTENTION: 83 | attn_weights = None 84 | query_states = query_states.transpose(1, 2) 85 | key_states = key_states.transpose(1, 2) 86 | value_states = value_states.transpose(1, 2) 87 | if query_states.size(1)==1 and key_states.size(1)>1: 88 | attn_bias = None 89 | elif query_states.size(1)<key_states.size(1) and key_states.size(1)>1 and past_kv_len > 0: 90 | attn_bias = xops.LowerTriangularMask() 91 | query_states = torch.cat( 92 | ( 93 | torch.full( 94 | (bsz, past_kv_len, self.num_heads, self.head_dim), 95 | 0.0, 96 | dtype=query_states.dtype, 97 | device=query_states.device, 98 | ), 99 | query_states, 100 | ), 101 | dim=1, 102 | ) 103 | pad_query = True 104 | else: 105 | attn_bias = xops.LowerTriangularMask() 106 | attn_output = xops.memory_efficient_attention( 107 | query_states, key_states, value_states, attn_bias=attn_bias, p=0) 108 | else: 109 | attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) 110 | 111 | if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): 112 | raise ValueError( 113 | f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" 114 | f" {attn_weights.size()}" 115 | ) 116 | 117 | if attention_mask is not None: 118 | if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): 119 | raise ValueError( 120 | f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" 121 | ) 122 |
attn_weights = attn_weights + attention_mask 123 | attn_weights = torch.max( 124 | attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device) 125 | ) 126 | 127 | # upcast attention to fp32 128 | attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) 129 | attn_output = torch.matmul(attn_weights, value_states) 130 | 131 | if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): 132 | raise ValueError( 133 | f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" 134 | f" {attn_output.size()}" 135 | ) 136 | 137 | attn_output = attn_output.transpose(1, 2) 138 | if pad_query: 139 | attn_output = attn_output[:,past_kv_len:] 140 | attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) 141 | 142 | attn_output = self.o_proj(attn_output) 143 | 144 | if not output_attentions: 145 | attn_weights = None 146 | 147 | return attn_output, attn_weights, past_key_value 148 | 149 | 150 | old_init = transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ 151 | 152 | 153 | def _set_cos_sin_cache(self, seq_len, device, dtype): 154 | self.max_seq_len_cached = seq_len 155 | t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32) 156 | t = t / self.scaling_factor 157 | 158 | freqs = torch.einsum("i,j->ij", t, self.ntk_inv_freq.to(device)) 159 | # Different from paper, but it uses a different permutation in order to obtain the same calculation 160 | emb = torch.cat((freqs, freqs), dim=-1) 161 | self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) 162 | self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) 163 | 164 | 165 | def adaptive_ntk_init(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=None): 166 | self.alpha = ALPHA 167 | if SCALING_FACTOR is None: 168 | self.scaling_factor = scaling_factor or 1.0 169 | else: 170 | self.scaling_factor = 
SCALING_FACTOR 171 | if isinstance(ALPHA,(float,int)): 172 | base = base * ALPHA ** (dim / (dim-2)) 173 | self.base = base 174 | elif ALPHA=='auto': 175 | self.base = base 176 | else: 177 | raise ValueError(ALPHA) 178 | old_init(self, dim, max_position_embeddings, base, device) 179 | self.ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim)) 180 | 181 | self._set_cos_sin_cache = _set_cos_sin_cache 182 | self._set_cos_sin_cache( 183 | self, seq_len=max_position_embeddings, device=self.ntk_inv_freq.device, dtype=torch.get_default_dtype() 184 | ) 185 | 186 | 187 | def adaptive_ntk_forward(self, x, seq_len=None): 188 | if seq_len > self.max_seq_len_cached: 189 | if isinstance(self.alpha,(float,int)): 190 | self._set_cos_sin_cache(self, seq_len=seq_len, device=x.device, dtype=x.dtype) 191 | elif self.alpha=='auto': 192 | t = torch.arange(seq_len, device=x.device, dtype=torch.float32) 193 | t = t / self.scaling_factor 194 | dim = self.dim 195 | alpha = (seq_len / (self.max_position_embeddings/2) - 1) * AUTO_COEFF 196 | base = self.base * alpha ** (dim / (dim-2)) 197 | ntk_inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(x.device) / dim )) 198 | 199 | freqs = torch.einsum("i,j->ij", t, ntk_inv_freq) 200 | emb = torch.cat((freqs, freqs), dim=-1).to(x.device) 201 | cos_cached = emb.cos() 202 | sin_cached = emb.sin() 203 | return ( 204 | cos_cached[:seq_len].to(dtype=x.dtype), 205 | sin_cached[:seq_len].to(dtype=x.dtype) 206 | ) 207 | return ( 208 | self.cos_cached[:seq_len].to(dtype=x.dtype), 209 | self.sin_cached[:seq_len].to(dtype=x.dtype) 210 | ) 211 | 212 | 213 | def apply_attention_patch( 214 | use_memory_efficient_attention=False, 215 | store_kv_before_rope=False 216 | ): 217 | global USE_MEM_EFF_ATTENTION, STORE_KV_BEFORE_ROPE 218 | if use_memory_efficient_attention is True and xops is not None: 219 | USE_MEM_EFF_ATTENTION = use_memory_efficient_attention 220 | print("USE_XFORMERS_ATTENTION: ", USE_MEM_EFF_ATTENTION) 221 | 
STORE_KV_BEFORE_ROPE = store_kv_before_rope 222 | print("STORE_KV_BEFORE_ROPE:", STORE_KV_BEFORE_ROPE) 223 | transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward 224 | 225 | 226 | def apply_ntk_scaling_patch(alpha: Union[float,str], scaling_factor: Optional[float] = None): 227 | global ALPHA 228 | global SCALING_FACTOR 229 | ALPHA = alpha 230 | SCALING_FACTOR = scaling_factor 231 | try: 232 | ALPHA = float(ALPHA) 233 | except ValueError: 234 | if ALPHA!="auto": 235 | raise ValueError(f"Alpha can only be a float or 'auto', but got {ALPHA}") 236 | print(f"Applying NTK scaling with ALPHA={ALPHA}") 237 | if scaling_factor is None: 238 | print("The value of scaling factor will be read from the model config file, or set to 1.") 239 | else: 240 | print(f"Warning: scaling factor is set to {SCALING_FACTOR}. " 241 | "If you set the value by hand, do not forget to update " 242 | "max_position_embeddings in the model config file.") 243 | 244 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.__init__ = adaptive_ntk_init 245 | if hasattr(transformers.models.llama.modeling_llama,'LlamaLinearScalingRotaryEmbedding'): 246 | transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding.__init__ = adaptive_ntk_init 247 | transformers.models.llama.modeling_llama.LlamaRotaryEmbedding.forward = adaptive_ntk_forward -------------------------------------------------------------------------------- /scripts/ceval/eval.py: -------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | 3 | import os 4 | import argparse 5 | import pandas as pd 6 | import torch 7 | import json 8 | from llama_evaluator import Llama_Evaluator 9 | 10 | import time 11 | choices = ["A", "B", "C", "D"] 12 | 13 | def main(args, evaluator,take): 14 | assert os.path.exists("subject_mapping.json"), "subject_mapping.json not found!"
15 | with open("subject_mapping.json") as f: 16 | subject_mapping = json.load(f) 17 | filenames = os.listdir("data/val") 18 | subject_list = [val_file.replace("_val.csv","") for val_file in filenames] 19 | accuracy, summary = {}, {} 20 | 21 | run_date=time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) 22 | output_dir = args.output_dir 23 | save_result_dir=os.path.join(output_dir,f"take{take}") 24 | if not os.path.exists(save_result_dir): 25 | os.makedirs(save_result_dir,exist_ok=True) 26 | 27 | all_answers = {} 28 | for index,subject_name in enumerate(subject_list): 29 | print(f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_path} with subject of {subject_name}!") 30 | val_file_path=os.path.join('data/val',f'{subject_name}_val.csv') 31 | dev_file_path=os.path.join('data/dev',f'{subject_name}_dev.csv') 32 | test_file_path=os.path.join('data/test',f'{subject_name}_test.csv') 33 | 34 | val_df=pd.read_csv(val_file_path) if args.do_test is False else pd.read_csv(test_file_path) 35 | dev_df=pd.read_csv(dev_file_path) if args.few_shot else None 36 | 37 | correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df, 38 | save_result_dir=save_result_dir if args.do_save_csv else None, 39 | few_shot=args.few_shot, 40 | cot=args.cot, 41 | with_prompt=args.with_prompt, 42 | constrained_decoding=args.constrained_decoding, 43 | do_test=args.do_test) 44 | print(f"Subject: {subject_name}") 45 | print(f"Acc: {correct_ratio}") 46 | accuracy[subject_name] = correct_ratio 47 | summary[subject_name] = {"score":correct_ratio, 48 | "num":len(val_df), 49 | "correct":correct_ratio*len(val_df)/100} 50 | all_answers[subject_name] = answers 51 | 52 | json.dump(all_answers,open(save_result_dir+'/submission.json','w'),ensure_ascii=False,indent=4) 53 | print("Accuracy:") 54 | for k, v in accuracy.items(): 55 | print(k, ": ", v) 56 | 57 | 58 | total_num = 0 59 | total_correct = 0 60 | summary['grouped'] = { 61 | "STEM": {"correct": 0.0, "num": 
0}, 62 | "Social Science": {"correct": 0.0, "num": 0}, 63 | "Humanities": {"correct": 0.0, "num": 0}, 64 | "Other": {"correct": 0.0, "num": 0} 65 | } 66 | for subj, info in subject_mapping.items(): 67 | group = info[2] 68 | summary['grouped'][group]["num"] += summary[subj]['num'] 69 | summary['grouped'][group]["correct"] += summary[subj]['correct'] 70 | for group, info in summary['grouped'].items(): 71 | info['score'] = info["correct"] / info["num"] 72 | total_num += info["num"] 73 | total_correct += info["correct"] 74 | summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct} 75 | 76 | json.dump(summary,open(save_result_dir+'/summary.json','w'),ensure_ascii=False,indent=2) 77 | 78 | 79 | if __name__ == "__main__": 80 | parser = argparse.ArgumentParser() 81 | parser.add_argument("--model_path", type=str) 82 | parser.add_argument("--cot",choices=["False","True"], default="False") 83 | parser.add_argument("--few_shot", choices=["False","True"], default="True") 84 | parser.add_argument("--ntrain", "-k", type=int, default=5) 85 | parser.add_argument("--with_prompt", choices=["False","True"], default="False") 86 | parser.add_argument("--constrained_decoding", choices=["False","True"], default="True") 87 | parser.add_argument("--temperature",type=float,default=0.2) 88 | parser.add_argument("--n_times", default=1,type=int) 89 | parser.add_argument("--do_save_csv", choices=["False","True"], default="False") 90 | parser.add_argument("--output_dir", type=str) 91 | parser.add_argument("--do_test", choices=["False","True"], default="False") 92 | parser.add_argument("--verbose", action="store_true", help="Print detailed information of each example.") 93 | 94 | args = parser.parse_args() 95 | 96 | args.cot = args.cot == "True" 97 | args.few_shot = args.few_shot == "True" 98 | args.with_prompt = args.with_prompt == "True" 99 | args.constrained_decoding = args.constrained_decoding == "True" 100 | args.do_test = args.do_test == "True" 101 | 
args.do_save_csv = args.do_save_csv == "True" 102 | if args.constrained_decoding is True: 103 | args.n_times=max(args.n_times,1) 104 | print(args) 105 | 106 | device = torch.device(0) 107 | print(device) 108 | evaluator=Llama_Evaluator( 109 | choices=choices, 110 | k=args.ntrain, 111 | model_path=args.model_path, 112 | device=device, 113 | temperature = args.temperature, 114 | verbose = args.verbose 115 | ) 116 | for i in range(args.n_times): 117 | main(args,evaluator=evaluator,take=i) 118 | -------------------------------------------------------------------------------- /scripts/ceval/evaluator.py: -------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | 3 | import string 4 | class Evaluator: 5 | def __init__(self, choices, model_name, k=-1): 6 | self.choices = choices 7 | self.model_name = model_name 8 | self.k = k 9 | self.puncs = list(string.punctuation) 10 | 11 | def format_example(self, line, include_answer=True): 12 | example = line['question'] 13 | for choice in self.choices: 14 | example += f'\n{choice}. 
{line[f"{choice}"]}' 15 | example += '\n答案:' 16 | if include_answer: 17 | example += f'{line["answer"]}\n\n' 18 | return example 19 | 20 | def generate_few_shot_prompt(self, subject, dev_df): 21 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 22 | k = self.k 23 | if self.k == -1: 24 | k = dev_df.shape[0] 25 | for i in range(k): 26 | prompt += self.format_example(dev_df.iloc[i, :]) 27 | return prompt 28 | 29 | def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, save_result_dir=None): 30 | pass 31 | 32 | def normalize_answer(self,s): 33 | 34 | def white_space_fix(text): 35 | return ' '.join(text.split()) 36 | 37 | def remove_punc(text): 38 | exclude=set(self.puncs) 39 | return ''.join(ch for ch in text if ch not in exclude) 40 | 41 | def lower(text): 42 | return text.lower() 43 | 44 | return white_space_fix(remove_punc(lower(s))) 45 | 46 | def exact_match(self,pred, target): 47 | return self.normalize_answer(pred)==self.normalize_answer(target) 48 | -------------------------------------------------------------------------------- /scripts/ceval/llama_evaluator.py: -------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | 3 | import os 4 | import re 5 | from tqdm import tqdm 6 | import random 7 | import numpy as np 8 | import torch 9 | from transformers import AutoModelForCausalLM, LlamaTokenizer 10 | from transformers import GenerationConfig 11 | from evaluator import Evaluator 12 | 13 | DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 
你是一个乐于助人的助手。""" 14 | 15 | 16 | class Llama_Evaluator(Evaluator): 17 | def __init__(self, choices, k, model_path, device, temperature=0.2, verbose=False): 18 | super(Llama_Evaluator, self).__init__(choices, model_path, k) 19 | load_type = torch.float16 20 | self.model_path = model_path 21 | self.device = device 22 | self.verbose = verbose 23 | self.tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) 24 | self.model = AutoModelForCausalLM.from_pretrained( 25 | model_path, 26 | load_in_8bit=False, 27 | torch_dtype=load_type, 28 | low_cpu_mem_usage=True, 29 | device_map='auto', 30 | trust_remote_code=True) 31 | self.generation_config = GenerationConfig( 32 | temperature=temperature, 33 | top_k=40, 34 | top_p=0.9, 35 | do_sample=True, 36 | num_beams=1, 37 | repetition_penalty=1.1, 38 | max_new_tokens=20 39 | ) 40 | 41 | self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0] 42 | self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0] 43 | self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0] 44 | self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0] 45 | self.A_id = self.tokenizer.encode(":A")[-1] 46 | self.B_id = self.tokenizer.encode(":B")[-1] 47 | self.C_id = self.tokenizer.encode(":C")[-1] 48 | self.D_id = self.tokenizer.encode(":D")[-1] 49 | 50 | 51 | def eval_subject(self, subject_name, 52 | test_df, 53 | dev_df=None, 54 | few_shot=False, 55 | cot=False, 56 | save_result_dir=None, 57 | with_prompt=False, 58 | constrained_decoding=False, 59 | do_test=False): 60 | all_answers = {} 61 | if constrained_decoding is True: 62 | self.generation_config.output_scores = True 63 | self.generation_config.return_dict_in_generate = True 64 | self.generation_config.max_new_tokens = 1 65 | self.generation_config.top_p = 1.0 66 | self.generation_config.top_k = 0 67 | 68 | correct_num = 0 69 | if save_result_dir: 70 | result = [] 71 | score = [] 72 | if few_shot: 73 | if with_prompt: 74 | history = 
self.generate_alpaca2_few_shot_prompt(subject_name, dev_df, cot=cot) 75 | else: 76 | history = self.generate_llama2_few_shot_prompt(subject_name, dev_df, cot=cot) 77 | else: 78 | history = '' 79 | answers = ['NA'] * len(test_df) if do_test is True else list(test_df['answer']) 80 | for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)): 81 | question = self.format_example(row, include_answer=False, cot=cot,with_prompt=with_prompt) 82 | instruction = question 83 | if with_prompt: 84 | prompt_template = ( 85 | "[INST] <<SYS>>\n" 86 | "{system_prompt}\n" 87 | "<</SYS>>\n\n" 88 | "{instruction} [/INST]" 89 | ) 90 | 91 | instruction = prompt_template.format_map({'instruction': instruction,'system_prompt':DEFAULT_SYSTEM_PROMPT}) 92 | instruction = history + instruction 93 | inputs = self.tokenizer(instruction, return_tensors="pt") 94 | generation_output = self.model.generate( 95 | input_ids = inputs["input_ids"].to(self.device), 96 | attention_mask = inputs['attention_mask'].to(self.device), 97 | eos_token_id=self.tokenizer.eos_token_id, 98 | pad_token_id=self.tokenizer.pad_token_id, 99 | generation_config = self.generation_config 100 | ) 101 | 102 | batch_size, length = inputs.input_ids.shape 103 | if constrained_decoding is True: 104 | logits = generation_output.scores[0][0] 105 | 106 | logits = logits.float().cpu().detach() 107 | choices1_logits = logits[[self.sA_id,self.sB_id,self.sC_id,self.sD_id]] 108 | choices2_logits = logits[[self.A_id,self.B_id,self.C_id,self.D_id]] 109 | choicesAll_logits = (choices1_logits + choices2_logits).numpy() 110 | assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits))) 111 | ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)] 112 | response = self.tokenizer.decode([logits.argmax(-1).item()]) 113 | else: 114 | response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True) 115 | ans, direct_extract = self.extract_answer(row, response) 116 | if ans ==
answers[row_index]: 117 | correct_num += 1 118 | correct = 1 119 | else: 120 | correct = 0 121 | if self.verbose is True: 122 | print(f"\n======={str(row_index)}=======") 123 | print(f"question: {question}\n") 124 | print(f"response: {response}\n") 125 | print(f"extracted answer: {ans}") 126 | print(f"ground truth: {answers[row_index]} \n") 127 | if save_result_dir: 128 | result.append(response) 129 | score.append(correct) 130 | 131 | all_answers[str(row_index)] = ans 132 | 133 | correct_ratio = 100*correct_num/len(answers) 134 | 135 | if save_result_dir: 136 | test_df['model_output'] = result 137 | test_df['correctness'] = score 138 | test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv')) 139 | 140 | return correct_ratio, all_answers 141 | 142 | def format_example(self, line, include_answer=True, cot=False, with_prompt=False): 143 | example = line['question'] 144 | for choice in self.choices: 145 | example += f'\n{choice}. {line[f"{choice}"]}' 146 | if include_answer: 147 | if cot: 148 | example += "\n答案:让我们一步一步思考,\n" + \ 149 | line["explanation"] + f"\n所以答案是{line['answer']}。\n\n" 150 | else: 151 | example += '\n答案:' + line["answer"] + '\n\n' 152 | else: 153 | if with_prompt is False: 154 | if cot: 155 | example += "\n答案:让我们一步一步思考,\n1." 156 | else: 157 | example += '\n答案:' 158 | else: 159 | if cot: 160 | example += "\n答案是什么?让我们一步一步思考,\n1." 
161 | else: 162 | example += '\n答案:' 163 | return example 164 | 165 | def generate_llama2_few_shot_prompt(self, subject, dev_df, cot=False): 166 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 167 | k = self.k 168 | if self.k == -1: 169 | k = dev_df.shape[0] 170 | for i in range(k): 171 | prompt += self.format_example( 172 | dev_df.iloc[i, :], 173 | include_answer=True, 174 | cot=cot 175 | ) 176 | return prompt 177 | 178 | def generate_alpaca2_few_shot_prompt(self, subject, dev_df, cot=False): 179 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 180 | prompt_template = ( 181 | "[INST] <<SYS>>\n" 182 | "{system_prompt}\n" 183 | "<</SYS>>\n\n" 184 | "{instruction} [/INST]好的,我会结合{subject}相关知识回答" 185 | ) 186 | 187 | prompt = prompt_template.format_map({'instruction':prompt,'system_prompt':DEFAULT_SYSTEM_PROMPT,'subject':subject}) 188 | k = self.k 189 | if self.k == -1: 190 | k = dev_df.shape[0] 191 | for i in range(k): 192 | line = dev_df.iloc[i, :] 193 | q=line['question'] 194 | for choice in self.choices: 195 | q += f'\n{choice}.
{line[f"{choice}"]}' 196 | 197 | a = line['answer'] 198 | prompt += "[INST] "+q+"\n答案:[/INST]"+a+"\n" 199 | return prompt 200 | 201 | def extract_answer(self, line, gen_ans): 202 | m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M) 203 | if len(m) > 0 and m[-1] in self.choices: 204 | return m[-1], True 205 | answer_patterns = [ 206 | r'([ABCD])是正确的', 207 | r'选项([ABCD])正确', 208 | r'答案为([ABCD])', 209 | r'答案是([ABCD])', 210 | r'答案([ABCD])', 211 | r'选择([ABCD])', 212 | r'答案:([ABCD])', 213 | r'选择答案([ABCD])' 214 | ] 215 | # Regex extraction: try common answer phrasings in order 216 | for answer_pattern in answer_patterns: 217 | m = re.search(answer_pattern, gen_ans, re.M) 218 | if m: 219 | answer = m.group(1) 220 | return answer, False 221 | # Fallback: take the first choice letter (A-D) that appears anywhere in the response 222 | m = re.findall(r'[ABCD]', gen_ans, re.M) 223 | if len(m) >= 1: 224 | answer = m[0] 225 | return answer, False 226 | # Fallback: match the text of one of the options and map it back to its letter 227 | choices_dict = {} 228 | pattern = "" 229 | for c in self.choices: 230 | choices_dict[str(line[f'{c}'])] = c 231 | pattern += re.escape(str(line[f'{c}']))+"|" 232 | pattern = pattern[:-1] 233 | m = re.findall(pattern, gen_ans, re.M) 234 | print("w/ escape:",repr(pattern),gen_ans,(len(m)>=1)) 235 | if len(m) >= 1: 236 | answer = choices_dict[m[0]] 237 | return answer, False 238 | return random.choice('ABCD'), False 239 | -------------------------------------------------------------------------------- /scripts/ceval/subject_mapping.json: -------------------------------------------------------------------------------- 1 | { 2 | "computer_network": [ 3 | "Computer Network", 4 | "\u8ba1\u7b97\u673a\u7f51\u7edc", 5 | "STEM" 6 | ], 7 | "operating_system": [ 8 | "Operating System", 9 | "\u64cd\u4f5c\u7cfb\u7edf", 10 | "STEM" 11 | ], 12 | "computer_architecture": [ 13 | "Computer Architecture", 14 | "\u8ba1\u7b97\u673a\u7ec4\u6210", 15 | "STEM" 16 | ], 17 | "college_programming": [ 18 | "College Programming", 19 | "\u5927\u5b66\u7f16\u7a0b", 20 | "STEM" 21 | ], 22 | "college_physics": [ 23
| "College Physics", 24 | "\u5927\u5b66\u7269\u7406", 25 | "STEM" 26 | ], 27 | "college_chemistry": [ 28 | "College Chemistry", 29 | "\u5927\u5b66\u5316\u5b66", 30 | "STEM" 31 | ], 32 | "advanced_mathematics": [ 33 | "Advanced Mathematics", 34 | "\u9ad8\u7b49\u6570\u5b66", 35 | "STEM" 36 | ], 37 | "probability_and_statistics": [ 38 | "Probability and Statistics", 39 | "\u6982\u7387\u7edf\u8ba1", 40 | "STEM" 41 | ], 42 | "discrete_mathematics": [ 43 | "Discrete Mathematics", 44 | "\u79bb\u6563\u6570\u5b66", 45 | "STEM" 46 | ], 47 | "electrical_engineer": [ 48 | "Electrical Engineer", 49 | "\u6ce8\u518c\u7535\u6c14\u5de5\u7a0b\u5e08", 50 | "STEM" 51 | ], 52 | "metrology_engineer": [ 53 | "Metrology Engineer", 54 | "\u6ce8\u518c\u8ba1\u91cf\u5e08", 55 | "STEM" 56 | ], 57 | "high_school_mathematics": [ 58 | "High School Mathematics", 59 | "\u9ad8\u4e2d\u6570\u5b66", 60 | "STEM" 61 | ], 62 | "high_school_physics": [ 63 | "High School Physics", 64 | "\u9ad8\u4e2d\u7269\u7406", 65 | "STEM" 66 | ], 67 | "high_school_chemistry": [ 68 | "High School Chemistry", 69 | "\u9ad8\u4e2d\u5316\u5b66", 70 | "STEM" 71 | ], 72 | "high_school_biology": [ 73 | "High School Biology", 74 | "\u9ad8\u4e2d\u751f\u7269", 75 | "STEM" 76 | ], 77 | "middle_school_mathematics": [ 78 | "Middle School Mathematics", 79 | "\u521d\u4e2d\u6570\u5b66", 80 | "STEM" 81 | ], 82 | "middle_school_biology": [ 83 | "Middle School Biology", 84 | "\u521d\u4e2d\u751f\u7269", 85 | "STEM" 86 | ], 87 | "middle_school_physics": [ 88 | "Middle School Physics", 89 | "\u521d\u4e2d\u7269\u7406", 90 | "STEM" 91 | ], 92 | "middle_school_chemistry": [ 93 | "Middle School Chemistry", 94 | "\u521d\u4e2d\u5316\u5b66", 95 | "STEM" 96 | ], 97 | "veterinary_medicine": [ 98 | "Veterinary Medicine", 99 | "\u517d\u533b\u5b66", 100 | "STEM" 101 | ], 102 | "college_economics": [ 103 | "College Economics", 104 | "\u5927\u5b66\u7ecf\u6d4e\u5b66", 105 | "Social Science" 106 | ], 107 | "business_administration": [ 108 | "Business 
Administration", 109 | "\u5de5\u5546\u7ba1\u7406", 110 | "Social Science" 111 | ], 112 | "marxism": [ 113 | "Marxism", 114 | "\u9a6c\u514b\u601d\u4e3b\u4e49\u57fa\u672c\u539f\u7406", 115 | "Social Science" 116 | ], 117 | "mao_zedong_thought": [ 118 | "Mao Zedong Thought", 119 | "\u6bdb\u6cfd\u4e1c\u601d\u60f3\u548c\u4e2d\u56fd\u7279\u8272\u793e\u4f1a\u4e3b\u4e49\u7406\u8bba\u4f53\u7cfb\u6982\u8bba", 120 | "Social Science" 121 | ], 122 | "education_science": [ 123 | "Education Science", 124 | "\u6559\u80b2\u5b66", 125 | "Social Science" 126 | ], 127 | "teacher_qualification": [ 128 | "Teacher Qualification", 129 | "\u6559\u5e08\u8d44\u683c", 130 | "Social Science" 131 | ], 132 | "high_school_politics": [ 133 | "High School Politics", 134 | "\u9ad8\u4e2d\u653f\u6cbb", 135 | "Social Science" 136 | ], 137 | "high_school_geography": [ 138 | "High School Geography", 139 | "\u9ad8\u4e2d\u5730\u7406", 140 | "Social Science" 141 | ], 142 | "middle_school_politics": [ 143 | "Middle School Politics", 144 | "\u521d\u4e2d\u653f\u6cbb", 145 | "Social Science" 146 | ], 147 | "middle_school_geography": [ 148 | "Middle School Geography", 149 | "\u521d\u4e2d\u5730\u7406", 150 | "Social Science" 151 | ], 152 | "modern_chinese_history": [ 153 | "Modern Chinese History", 154 | "\u8fd1\u4ee3\u53f2\u7eb2\u8981", 155 | "Humanities" 156 | ], 157 | "ideological_and_moral_cultivation": [ 158 | "Ideological and Moral Cultivation", 159 | "\u601d\u60f3\u9053\u5fb7\u4fee\u517b\u4e0e\u6cd5\u5f8b\u57fa\u7840", 160 | "Humanities" 161 | ], 162 | "logic": [ 163 | "Logic", 164 | "\u903b\u8f91\u5b66", 165 | "Humanities" 166 | ], 167 | "law": [ 168 | "Law", 169 | "\u6cd5\u5b66", 170 | "Humanities" 171 | ], 172 | "chinese_language_and_literature": [ 173 | "Chinese Language and Literature", 174 | "\u4e2d\u56fd\u8bed\u8a00\u6587\u5b66", 175 | "Humanities" 176 | ], 177 | "art_studies": [ 178 | "Art Studies", 179 | "\u827a\u672f\u5b66", 180 | "Humanities" 181 | ], 182 | "professional_tour_guide": [ 183 | 
"Professional Tour Guide", 184 | "\u5bfc\u6e38\u8d44\u683c", 185 | "Humanities" 186 | ], 187 | "legal_professional": [ 188 | "Legal Professional", 189 | "\u6cd5\u5f8b\u804c\u4e1a\u8d44\u683c", 190 | "Humanities" 191 | ], 192 | "high_school_chinese": [ 193 | "High School Chinese", 194 | "\u9ad8\u4e2d\u8bed\u6587", 195 | "Humanities" 196 | ], 197 | "high_school_history": [ 198 | "High School History", 199 | "\u9ad8\u4e2d\u5386\u53f2", 200 | "Humanities" 201 | ], 202 | "middle_school_history": [ 203 | "Middle School History", 204 | "\u521d\u4e2d\u5386\u53f2", 205 | "Humanities" 206 | ], 207 | "civil_servant": [ 208 | "Civil Servant", 209 | "\u516c\u52a1\u5458", 210 | "Other" 211 | ], 212 | "sports_science": [ 213 | "Sports Science", 214 | "\u4f53\u80b2\u5b66", 215 | "Other" 216 | ], 217 | "plant_protection": [ 218 | "Plant Protection", 219 | "\u690d\u7269\u4fdd\u62a4", 220 | "Other" 221 | ], 222 | "basic_medicine": [ 223 | "Basic Medicine", 224 | "\u57fa\u7840\u533b\u5b66", 225 | "Other" 226 | ], 227 | "clinical_medicine": [ 228 | "Clinical Medicine", 229 | "\u4e34\u5e8a\u533b\u5b66", 230 | "Other" 231 | ], 232 | "urban_and_rural_planner": [ 233 | "Urban and Rural Planner", 234 | "\u6ce8\u518c\u57ce\u4e61\u89c4\u5212\u5e08", 235 | "Other" 236 | ], 237 | "accountant": [ 238 | "Accountant", 239 | "\u6ce8\u518c\u4f1a\u8ba1\u5e08", 240 | "Other" 241 | ], 242 | "fire_engineer": [ 243 | "Fire Engineer", 244 | "\u6ce8\u518c\u6d88\u9632\u5de5\u7a0b\u5e08", 245 | "Other" 246 | ], 247 | "environmental_impact_assessment_engineer": [ 248 | "Environmental Impact Assessment Engineer", 249 | "\u73af\u5883\u5f71\u54cd\u8bc4\u4ef7\u5de5\u7a0b\u5e08", 250 | "Other" 251 | ], 252 | "tax_accountant": [ 253 | "Tax Accountant", 254 | "\u7a0e\u52a1\u5e08", 255 | "Other" 256 | ], 257 | "physician": [ 258 | "Physician", 259 | "\u533b\u5e08\u8d44\u683c", 260 | "Other" 261 | ] 262 | } -------------------------------------------------------------------------------- /scripts/cmmlu/categories.py: 
-------------------------------------------------------------------------------- 1 | # This code is modified from CMMLU Project: https://github.com/haonan-li/CMMLU 2 | name_en2zh = { 3 | "agronomy": "农学", 4 | "anatomy": "解剖学", 5 | "ancient_chinese": "古汉语", 6 | "arts": "艺术学", 7 | "astronomy": "天文学", 8 | "business_ethics": "商业伦理", 9 | "chinese_civil_service_exam": "中国公务员考试", 10 | "chinese_driving_rule": "中国驾驶规则", 11 | "chinese_food_culture": "中国饮食文化", 12 | "chinese_foreign_policy": "中国外交政策", 13 | "chinese_history":"中国历史", 14 | "chinese_literature": "中国文学", 15 | "chinese_teacher_qualification": "中国教师资格", 16 | "clinical_knowledge": "临床知识", 17 | "college_actuarial_science":"大学精算学", 18 | "college_education":"大学教育学", 19 | "college_engineering_hydrology": "大学工程水文学", 20 | "college_law": "大学法律", 21 | "college_mathematics": "大学数学", 22 | "college_medical_statistics":"大学医学统计", 23 | "college_medicine": "大学医学", 24 | "computer_science": "计算机科学", 25 | "computer_security": "计算机安全", 26 | "conceptual_physics": "概念物理学", 27 | "construction_project_management": "建设工程管理", 28 | "economics": "经济学", 29 | "education": "教育学", 30 | "electrical_engineering": "电气工程", 31 | "elementary_chinese":"小学语文", 32 | "elementary_commonsense":"小学常识", 33 | "elementary_information_and_technology": "小学信息技术", 34 | "elementary_mathematics": "初等数学", 35 | "ethnology": "民族学", 36 | "food_science": "食品科学", 37 | "genetics": "遗传学", 38 | "global_facts": "全球事实", 39 | "high_school_biology": "高中生物", 40 | "high_school_chemistry": "高中化学", 41 | "high_school_geography": "高中地理", 42 | "high_school_mathematics": "高中数学", 43 | "high_school_physics": "高中物理学", 44 | "high_school_politics": "高中政治", 45 | "human_sexuality": "人类性行为", 46 | "international_law": "国际法学", 47 | "journalism": "新闻学", 48 | "jurisprudence": "法理学", 49 | "legal_and_moral_basis": "法律与道德基础", 50 | "logical": "逻辑学", 51 | "machine_learning": "机器学习", 52 | "management": "管理学", 53 | "marketing": "市场营销", 54 | "marxist_theory": "马克思主义理论", 55 | "modern_chinese": "现代汉语", 56 | 
"nutrition": "营养学", 57 | "philosophy": "哲学", 58 | "professional_accounting": "专业会计", 59 | "professional_law": "专业法学", 60 | "professional_medicine": "专业医学", 61 | "professional_psychology": "专业心理学", 62 | "public_relations": "公共关系", 63 | "security_study":"安全研究", 64 | "sociology": "社会学", 65 | "sports_science": "体育学", 66 | "traditional_chinese_medicine": "中医中药", 67 | "virology": "病毒学", 68 | "world_history":"世界历史", 69 | "world_religions": "世界宗教", 70 | } 71 | 72 | subcategories = { 73 | "agronomy": ['other'], 74 | "anatomy": ['biology'], 75 | "ancient_chinese": ['linguistics','china specific'], 76 | "arts": ['arts'], 77 | "astronomy": ['physics'], 78 | "business_ethics": ['business'], 79 | "chinese_civil_service_exam": ['politics','china specific'], 80 | "chinese_driving_rule": ['other','china specific'], 81 | "chinese_food_culture": ['culture','china specific'], 82 | "chinese_foreign_policy": ['politics','china specific'], 83 | "chinese_history":['history','china specific'], 84 | "chinese_literature": ['literature','china specific'], 85 | "chinese_teacher_qualification": ['education','china specific'], 86 | "college_actuarial_science":['math'], 87 | "college_education":['education'], 88 | "college_engineering_hydrology": ['engineering'], 89 | "college_law": ['law'], 90 | "college_mathematics": ['math'], 91 | "college_medical_statistics":['statistics'], 92 | "clinical_knowledge": ['other'], 93 | "college_medicine": ['other'], 94 | "computer_science": ['computer science'], 95 | "computer_security": ['other'], 96 | "conceptual_physics": ['physics'], 97 | "construction_project_management": ['other','china specific'], 98 | "economics": ['economics'], 99 | "education": ['education'], 100 | "elementary_chinese":['linguistics','china specific'], 101 | "elementary_commonsense":['other','china specific'], 102 | "elementary_information_and_technology": ['other'], 103 | "electrical_engineering": ['engineering'], 104 | "elementary_mathematics": ['math'], 105 | "ethnology": 
['culture','china specific'], 106 | "food_science": ['other'], 107 | "genetics": ['biology'], 108 | "global_facts": ['global'], 109 | "high_school_biology": ['biology'], 110 | "high_school_chemistry": ['chemistry'], 111 | "high_school_geography": ['geography'], 112 | "high_school_mathematics": ['math'], 113 | "high_school_physics": ['physics'], 114 | "high_school_politics": ['politics','china specific'], 115 | "human_sexuality": ['other'], 116 | "international_law": ['law'], 117 | "journalism": ['sociology'], 118 | "jurisprudence": ['law'], 119 | "legal_and_moral_basis": ['other'], 120 | "logical": ['philosophy'], 121 | "machine_learning": ['computer science'], 122 | "management": ['business'], 123 | "marketing": ['business'], 124 | "marxist_theory": ['philosophy'], 125 | "modern_chinese": ['linguistics','china specific'], 126 | "nutrition": ['other'], 127 | "philosophy": ['philosophy'], 128 | "professional_accounting": ['business'], 129 | "professional_law": ['law'], 130 | "professional_medicine": ['other'], 131 | "professional_psychology": ['psychology'], 132 | "public_relations": ['politics'], 133 | "security_study": ['politics'], 134 | "sociology": ['culture'], 135 | "sports_science": ['other'], 136 | "traditional_chinese_medicine": ['other','china specific'], 137 | "virology": ['biology'], 138 | "world_history":['history'], 139 | "world_religions": ['global'], 140 | } 141 | 142 | categories = { 143 | "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering", "statistics"], 144 | "Humanities": ["history", "philosophy", "law", "arts", "literature", "global"], 145 | "Social Science": ['linguistics',"business", "politics", "culture", "economics", "geography", "psychology", "education", "sociology"], 146 | "Other":["other"], 147 | "China specific": ["china specific"], 148 | } 149 | -------------------------------------------------------------------------------- /scripts/cmmlu/eval.py: 
-------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | import os 3 | import argparse 4 | import pandas as pd 5 | import torch 6 | import json 7 | from llama2_evaluator import Llama_Evaluator 8 | from glob import glob 9 | import time 10 | from collections import defaultdict 11 | from categories import name_en2zh, subcategories, categories 12 | choices = ["A", "B", "C", "D"] 13 | 14 | category2subject = defaultdict(list) 15 | for k,v in categories.items(): 16 | for subject, subcat in subcategories.items(): 17 | for c in subcat: 18 | if c in v: 19 | category2subject[k].append(subject) 20 | category2subject_list = defaultdict(list) 21 | for key,value in category2subject.items(): 22 | for val in value: 23 | category2subject_list[val]=[val,name_en2zh[val],key] 24 | category2subject=category2subject_list 25 | choices = ["A", "B", "C", "D"] 26 | 27 | def main(args, evaluator,take): 28 | 29 | subject_mapping = category2subject #json.load(f) 30 | filenames = [s.split('/')[-1] for s in glob(args.input_dir+"/test/*csv")] 31 | subject_list = [val_file.replace(".csv","") for val_file in filenames] 32 | accuracy, summary = {}, {} 33 | 34 | run_date=time.strftime('%Y-%m-%d_%H-%M-%S',time.localtime(time.time())) 35 | output_dir = args.output_dir 36 | save_result_dir=os.path.join(output_dir,f"take{take}") 37 | if not os.path.exists(save_result_dir): 38 | os.makedirs(save_result_dir,exist_ok=True) 39 | 40 | all_answers = {} 41 | for index,subject_name in enumerate(subject_list): 42 | print(f"{index/len(subject_list)} Inference starts at {run_date} on {args.model_path} with subject of {subject_name}!") 43 | val_file_path=os.path.join(args.input_dir+'/test',f'{subject_name}.csv') 44 | dev_file_path=os.path.join(args.input_dir+'/dev',f'{subject_name}.csv') 45 | 46 | val_df=pd.read_csv(val_file_path) 47 | dev_df=pd.read_csv(dev_file_path) if args.few_shot else None 48 | 49 | 
correct_ratio, answers = evaluator.eval_subject(subject_name, val_df, dev_df, 50 | save_result_dir=save_result_dir if args.do_save_csv else None, 51 | few_shot=args.few_shot, 52 | cot=args.cot, 53 | with_prompt=args.with_prompt, 54 | constrained_decoding=args.constrained_decoding, 55 | do_test=False) 56 | print(f"Subject: {subject_name}") 57 | print(f"Acc: {correct_ratio}") 58 | accuracy[subject_name] = correct_ratio 59 | summary[subject_name] = {"score":correct_ratio, 60 | "num":len(val_df), 61 | "correct":correct_ratio*len(val_df)/100} 62 | all_answers[subject_name] = answers 63 | 64 | json.dump(all_answers,open(save_result_dir+'/submission.json','w'),ensure_ascii=False,indent=4) 65 | print("\n\nModel:",args.model_path) 66 | print("Accuracy:") 67 | for k, v in accuracy.items(): 68 | print(k, ": ", v) 69 | 70 | 71 | total_num = 0 72 | total_correct = 0 73 | summary['grouped'] = { 74 | "China specific": {"correct": 0.0, "num": 0}, 75 | "STEM": {"correct": 0.0, "num": 0}, 76 | "Social Science": {"correct": 0.0, "num": 0}, 77 | "Humanities": {"correct": 0.0, "num": 0}, 78 | "Other": {"correct": 0.0, "num": 0} 79 | } 80 | for subj, info in subject_mapping.items(): 81 | group = info[2] 82 | summary['grouped'][group]["num"] += summary[subj]['num'] 83 | summary['grouped'][group]["correct"] += summary[subj]['correct'] 84 | for group, info in summary['grouped'].items(): 85 | info['score'] = info["correct"] / info["num"] 86 | total_num += info["num"] 87 | total_correct += info["correct"] 88 | summary['All'] = {"score": total_correct / total_num, "num": total_num, "correct": total_correct} 89 | 90 | json.dump(summary,open(save_result_dir+'/summary.json','w'),ensure_ascii=False,indent=2) 91 | 92 | 93 | 94 | 95 | if __name__ == "__main__": 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument("--ntrain", "-k", type=int, default=5) 98 | parser.add_argument("--model_path", type=str) 99 | parser.add_argument("--cot",choices=["False","True"], default="False") 100 | 
parser.add_argument("--few_shot", choices=["False","True"], default="True") 101 | parser.add_argument("--with_prompt", choices=["False","True"], default="False") 102 | parser.add_argument("--constrained_decoding", choices=["False","True"], default="False") 103 | parser.add_argument("--temperature",type=float,default=0.2) 104 | parser.add_argument("--n_times", default=1,type=int) 105 | parser.add_argument("--do_save_csv", choices=["False","True"], default="False") 106 | parser.add_argument("--output_dir", type=str) 107 | parser.add_argument("--input_dir", type=str) 108 | parser.add_argument("--verbose", action="store_true", help="Print detailed information of each example.") 109 | 110 | args = parser.parse_args() 111 | 112 | args.cot = args.cot == "True" 113 | args.few_shot = args.few_shot == "True" 114 | args.with_prompt = args.with_prompt == "True" 115 | args.do_save_csv = args.do_save_csv == "True" 116 | args.constrained_decoding = args.constrained_decoding == "True" 117 | if args.constrained_decoding is True: 118 | args.n_times=max(args.n_times,1) 119 | print(args) 120 | 121 | device = torch.device(0) 122 | print(device) 123 | evaluator=Llama_Evaluator( 124 | choices=choices, 125 | k=args.ntrain, 126 | model_path=args.model_path, 127 | device=device, 128 | temperature = args.temperature, 129 | verbose = args.verbose 130 | ) 131 | for i in range(args.n_times): 132 | main(args,evaluator=evaluator,take=i) 133 | -------------------------------------------------------------------------------- /scripts/cmmlu/evaluator.py: -------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | import string 3 | class Evaluator: 4 | def __init__(self, choices, model_path, k=-1): 5 | self.choices = choices 6 | self.model_path = model_path 7 | self.k = k 8 | self.puncs = list(string.punctuation) 9 | 10 | def format_example(self, line, include_answer=True): 11 | example = 
line['question'] 12 | # print(example) 13 | for choice in self.choices: 14 | example += f'\n{choice}. {line[f"{choice}"]}' 15 | example += '\n答案:' 16 | if include_answer: 17 | example += f'{line["answer"]}\n\n' 18 | return example 19 | 20 | def generate_few_shot_prompt(self, subject, dev_df): 21 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 22 | k = self.k 23 | if self.k == -1: 24 | k = dev_df.shape[0] 25 | for i in range(k): 26 | prompt += self.format_example(dev_df.iloc[i, :]) 27 | return prompt 28 | 29 | def eval_subject(self, subject_name, test_df, dev_df=None, few_shot=False, save_result_dir=None): 30 | pass 31 | 32 | def normalize_answer(self,s): 33 | 34 | def white_space_fix(text): 35 | return ' '.join(text.split()) 36 | 37 | def remove_punc(text): 38 | exclude=set(self.puncs) 39 | return ''.join(ch for ch in text if ch not in exclude) 40 | 41 | def lower(text): 42 | return text.lower() 43 | 44 | return white_space_fix(remove_punc(lower(s))) 45 | 46 | def exact_match(self,pred, target): 47 | return self.normalize_answer(pred)==self.normalize_answer(target) 48 | -------------------------------------------------------------------------------- /scripts/cmmlu/llama2_evaluator.py: -------------------------------------------------------------------------------- 1 | # This code is modified from C-Eval Project: https://github.com/SJTU-LIT/ceval 2 | 3 | import os 4 | import re 5 | from tqdm import tqdm 6 | import random 7 | import numpy as np 8 | import torch 9 | from transformers import AutoModelForCausalLM, LlamaTokenizer 10 | from transformers import GenerationConfig 11 | from evaluator import Evaluator 12 | 13 | class Llama_Evaluator(Evaluator): 14 | def __init__(self, choices, k, model_path, device, temperature=0.2, verbose=False): 15 | super(Llama_Evaluator, self).__init__(choices, model_path, k) 16 | load_type = torch.float16 17 | self.model_path = model_path 18 | self.device = device 19 | self.verbose = verbose 20 | self.tokenizer = 
LlamaTokenizer.from_pretrained(model_path, legacy=True) 21 | self.model = AutoModelForCausalLM.from_pretrained( 22 | model_path, 23 | load_in_8bit=False, 24 | torch_dtype=load_type, 25 | low_cpu_mem_usage=True, 26 | device_map='auto', 27 | trust_remote_code=True) 28 | self.generation_config = GenerationConfig( 29 | temperature=temperature, 30 | top_k=40, 31 | top_p=0.9, 32 | do_sample=True, 33 | num_beams=1, 34 | repetition_penalty=1.1, 35 | max_new_tokens=20 36 | ) 37 | 38 | self.sA_id = self.tokenizer.encode("A", add_special_tokens=False)[0] 39 | self.sB_id = self.tokenizer.encode("B", add_special_tokens=False)[0] 40 | self.sC_id = self.tokenizer.encode("C", add_special_tokens=False)[0] 41 | self.sD_id = self.tokenizer.encode("D", add_special_tokens=False)[0] 42 | self.A_id = self.tokenizer.encode(":A")[-1] 43 | self.B_id = self.tokenizer.encode(":B")[-1] 44 | self.C_id = self.tokenizer.encode(":C")[-1] 45 | self.D_id = self.tokenizer.encode(":D")[-1] 46 | 47 | 48 | def eval_subject(self, subject_name, 49 | test_df, 50 | dev_df=None, 51 | few_shot=False, 52 | cot=False, 53 | save_result_dir=None, 54 | with_prompt=False, 55 | constrained_decoding=False, 56 | do_test=False): 57 | all_answers = {} 58 | if constrained_decoding is True: 59 | self.generation_config.output_scores = True 60 | self.generation_config.return_dict_in_generate = True 61 | self.generation_config.max_new_tokens = 1 62 | self.generation_config.top_p = 1.0 63 | self.generation_config.top_k = 0 64 | 65 | correct_num = 0 66 | if save_result_dir: 67 | result = [] 68 | score = [] 69 | if few_shot: 70 | if with_prompt: 71 | history = self.generate_few_shot_prompt(subject_name, dev_df, cot=cot) 72 | else: 73 | history = self.generate_few_shot_noprompt(subject_name, dev_df, cot=cot) 74 | else: 75 | history = '' 76 | answers = ['NA'] * len(test_df) if do_test is True else list(test_df['Answer']) 77 | for row_index, row in tqdm(test_df.iterrows(), total=len(test_df)): 78 | question = 
self.format_example(row, include_answer=False, cot=cot,with_prompt=with_prompt) 79 | instruction = question 80 | if with_prompt: 81 | DEFAULT_SYSTEM_PROMPT = """你是一个乐于助人的助手。""" 82 | prompt_template = ( 83 | "[INST] <<SYS>>\n" 84 | "{system_prompt}\n" 85 | "<</SYS>>\n\n" 86 | "{instruction} [/INST]" 87 | ) 88 | 89 | instruction = prompt_template.format_map({'instruction': instruction,'system_prompt':DEFAULT_SYSTEM_PROMPT}) 90 | instruction=history+instruction 91 | 92 | inputs = self.tokenizer(instruction, return_tensors="pt") 93 | generation_output = self.model.generate( 94 | input_ids = inputs["input_ids"].to(self.device), 95 | attention_mask = inputs['attention_mask'].to(self.device), 96 | eos_token_id=self.tokenizer.eos_token_id, 97 | pad_token_id=self.tokenizer.pad_token_id, 98 | generation_config = self.generation_config 99 | ) 100 | 101 | _, length = inputs.input_ids.shape 102 | if constrained_decoding is True: 103 | logits = generation_output.scores[0][0] 104 | 105 | logits = logits.float().cpu().detach() 106 | choices1_logits = logits[[self.sA_id,self.sB_id,self.sC_id,self.sD_id]] 107 | choices2_logits = logits[[self.A_id,self.B_id,self.C_id,self.D_id]] 108 | choicesAll_logits = (choices1_logits + choices2_logits).numpy() 109 | assert not (np.any(np.isinf(choicesAll_logits)) or np.any(np.isnan(choicesAll_logits))) 110 | ans = {0: "A", 1: "B", 2: "C", 3: "D"}[np.argmax(choicesAll_logits)] 111 | response = self.tokenizer.decode([logits.argmax(-1).item()]) 112 | else: 113 | response = self.tokenizer.decode(generation_output[0, length:], skip_special_tokens=True) 114 | ans, _ = self.extract_answer(row, response) 115 | if ans == answers[row_index]: 116 | correct_num += 1 117 | correct = 1 118 | else: 119 | correct = 0 120 | if self.verbose is True: 121 | print(f"\n======={str(row_index)}=======") 122 | print(f"question: {question}\n") 123 | print(f"response: {response}\n") 124 | print(f"extracted answer: {ans}") 125 | print(f"ground truth: {answers[row_index]} \n") 126 | if 
save_result_dir: 127 | result.append(response) 128 | score.append(correct) 129 | 130 | all_answers[str(row_index)] = ans 131 | 132 | correct_ratio = 100*correct_num/len(answers) 133 | 134 | if save_result_dir: 135 | test_df['model_output'] = result 136 | test_df['correctness'] = score 137 | test_df.to_csv(os.path.join(save_result_dir, f'{subject_name}_test.csv')) 138 | 139 | return correct_ratio, all_answers 140 | 141 | def format_example(self, line, include_answer=True, cot=False, with_prompt=False): 142 | example = line['Question'] 143 | suffix = "" 144 | for choice in self.choices: 145 | example += f'\n{choice}. {line[f"{choice}"]}' 146 | if include_answer: 147 | if cot: 148 | example += "\n答案:让我们一步一步思考,\n" + \ 149 | line["explanation"] + f"\n所以答案是{line['Answer']}。\n\n" 150 | else: 151 | example += '\n答案:' + suffix + line["Answer"] + '\n\n' 152 | else: 153 | if with_prompt is False: 154 | if cot: 155 | example += "\n答案:让我们一步一步思考,\n1." 156 | else: 157 | example += '\n答案:' + suffix 158 | else: 159 | if cot: 160 | example += "\n答案是什么?让我们一步一步思考,\n1." 
161 | else: 162 | example += '\n答案:' 163 | return example 164 | def generate_few_shot_noprompt(self, subject, dev_df, cot=False): 165 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 166 | k = self.k 167 | if self.k == -1: 168 | k = dev_df.shape[0] 169 | for i in range(k): 170 | prompt += self.format_example( 171 | dev_df.iloc[i, :], 172 | include_answer=True, 173 | cot=cot 174 | ) 175 | return prompt 176 | 177 | def generate_few_shot_prompt(self, subject, dev_df, cot=False): 178 | DEFAULT_SYSTEM_PROMPT = """你是一个乐于助人的助手。""" 179 | prompt = f"以下是中国关于{subject}考试的单项选择题,请选出其中的正确答案。\n\n" 180 | prompt_template = ( 181 | "[INST] <<SYS>>\n" 182 | "{system_prompt}\n" 183 | "<</SYS>>\n\n" 184 | "{instruction} [/INST]好的,我会结合{subject}相关知识回答" 185 | ) 186 | 187 | prompt = prompt_template.format_map({'instruction':prompt,'system_prompt':DEFAULT_SYSTEM_PROMPT,"subject":subject}) 188 | k = self.k 189 | if self.k == -1: 190 | k = dev_df.shape[0] 191 | for i in range(k): 192 | line=dev_df.iloc[i, :] 193 | q=line['Question'] 194 | for choice in self.choices: 195 | q += f'\n{choice}. 
{line[f"{choice}"]}' 196 | 197 | a=line['Answer'] 198 | prompt+="[INST] "+q+"\n答案:[/INST]"+a+"\n" 199 | 200 | return prompt 201 | 202 | def extract_answer(self, line, gen_ans): 203 | m = re.findall(r'所以答案是(.+?)。', gen_ans, re.M) 204 | if len(m) > 0 and m[-1] in self.choices: 205 | return m[-1], True 206 | answer_patterns = [ 207 | r'([ABCD])是正确的', 208 | r'选项([ABCD])正确', 209 | r'答案为([ABCD])', 210 | r'答案是([ABCD])', 211 | r'答案([ABCD])', 212 | r'选择([ABCD])', 213 | r'答案:([ABCD])', 214 | r'选择答案([ABCD])' 215 | ] 216 | # RE extraction 217 | for answer_pattern in answer_patterns: 218 | m = re.search(answer_pattern, gen_ans, re.M) 219 | if m: 220 | answer = m.group(1) 221 | return answer, False 222 | # only containing one choice-character 223 | m = re.findall(r'[ABCD]', gen_ans, re.M) 224 | if len(m) >= 1: 225 | answer = m[0] 226 | return answer, False 227 | choices_dict = {} 228 | pattern = "" 229 | for c in self.choices: 230 | choices_dict[str(line[f'{c}'])] = c 231 | pattern += re.escape(str(line[f'{c}']))+"|" 232 | pattern = pattern[:-1] 233 | m = re.findall(pattern, gen_ans, re.M) 234 | print("w/ escape:",repr(pattern),gen_ans,(len(m)>=1)) 235 | if len(m) >= 1: 236 | answer = choices_dict[m[0]] 237 | return answer, False 238 | return random.choice('ABCD'), False 239 | -------------------------------------------------------------------------------- /scripts/inference/flash_attn_patch_for_inference.py: -------------------------------------------------------------------------------- 1 | # Below code is based on https://github.com/lm-sys/FastChat/blob/main/fastchat/train/llama_flash_attn_monkey_patch.py. 2 | from typing import Optional, Tuple 3 | import torch 4 | 5 | import transformers 6 | 7 | from einops import rearrange 8 | try: 9 | from flash_attn.flash_attn_interface import flash_attn_with_kvcache 10 | except ImportError: 11 | flash_attn_with_kvcache = None 12 | print( 13 | "FlashAttention-2 is not installed correctly. 
If you want to use flash attention to inference, flash-attention >= 2.2 is needed. " 14 | "Please check the usage in https://github.com/Dao-AILab/flash-attention for more details." 15 | ) 16 | 17 | 18 | def forward( 19 | self, 20 | hidden_states: torch.Tensor, 21 | attention_mask: Optional[torch.Tensor] = None, 22 | position_ids: Optional[torch.Tensor] = None, 23 | past_key_value: Optional[Tuple[torch.Tensor]] = None, 24 | output_attentions: bool = False, 25 | use_cache: bool = False, 26 | padding_mask=None, 27 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 28 | """Input shape: Batch x Time x Channel 29 | 30 | attention_mask: [bsz, q_len] 31 | """ 32 | bsz, q_len, _ = hidden_states.size() 33 | 34 | query_states = ( 35 | self.q_proj(hidden_states) 36 | .view(bsz, q_len, self.num_heads, self.head_dim) 37 | ) 38 | key_states = ( 39 | self.k_proj(hidden_states) 40 | .view(bsz, q_len, self.num_heads, self.head_dim) 41 | ) 42 | value_states = ( 43 | self.v_proj(hidden_states) 44 | .view(bsz, q_len, self.num_heads, self.head_dim) 45 | ) 46 | 47 | kv_seq_len = key_states.shape[1] 48 | past_kv_len = 0 49 | if past_key_value is not None: 50 | past_kv_len = past_key_value[0].shape[-2] 51 | kv_seq_len += past_kv_len 52 | 53 | cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) 54 | rotary_dim = cos.shape[-1] 55 | cos, sin = cos.squeeze(0,1)[:,:rotary_dim//2].contiguous(), sin.squeeze(0,1)[:,:rotary_dim//2].contiguous() 56 | 57 | if past_key_value is not None: 58 | key_cache = torch.cat([past_key_value[0].transpose(1, 2), key_states], dim=1) 59 | value_cache = torch.cat([past_key_value[1].transpose(1, 2), value_states], dim=1) 60 | else: 61 | key_cache = key_states 62 | value_cache = value_states 63 | 64 | assert not output_attentions, "output_attentions is not supported" 65 | 66 | q = query_states # [bsz, q_len, nh, hd] 67 | k, v = key_states, value_states # [bsz, q_len, nh, hd] 68 | 69 | output = flash_attn_with_kvcache( 70 | q, 
key_cache, value_cache, k, v, rotary_cos=cos, rotary_sin=sin, cache_seqlens=past_kv_len, softmax_scale=None, causal=True, rotary_interleaved=False 71 | ) 72 | output = rearrange(output, "b s h d -> b s (h d)", b=bsz) 73 | 74 | past_key_value = (key_cache[:,:kv_seq_len].transpose(1,2), value_cache[:,:kv_seq_len].transpose(1,2)) if use_cache else None 75 | 76 | output = self.o_proj(output) 77 | 78 | return output, None, past_key_value 79 | 80 | 81 | # Disable the transformation of the attention mask in LlamaModel as the flash attention 82 | # requires the attention mask to be the same as the key_padding_mask 83 | def _prepare_decoder_attention_mask( 84 | self, attention_mask, input_shape, inputs_embeds, past_key_values_length 85 | ): 86 | return attention_mask 87 | 88 | 89 | def replace_llama_attn_with_flash_attn(): 90 | if flash_attn_with_kvcache != None: 91 | print("USE_FLASH_ATTENTION: ", True) 92 | transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = _prepare_decoder_attention_mask 93 | transformers.models.llama.modeling_llama.LlamaAttention.forward = forward 94 | else: 95 | print("USE_FLASH_ATTENTION: ", False) 96 | -------------------------------------------------------------------------------- /scripts/langchain/doc.txt: -------------------------------------------------------------------------------- 1 | 李白[注 1](701年5月19日—762年11月30日),字太白,号青莲居士,中国唐朝诗人。李白自言祖籍陇西成纪(今甘肃静宁西南),汉飞将军李广后裔,西凉武昭王李暠之后,与李唐皇室同宗。 2 | 一说其幼时内迁,寄籍剑南道绵州昌隆(今四川省江油市青莲镇)。一说先人隋末被窜于碎叶,出生于碎叶,属唐安西都护府(今吉尔吉斯斯坦共和国楚河州托克马克市)。有“诗仙”、“诗侠”、“酒仙”、“谪仙人”等称呼,活跃于盛唐[1],为杰出的浪漫主义诗人。与杜甫合称“李杜”[注 2]。被贺知章呼为“天上谪仙”、“李谪仙”。 3 | 李白的诗歌在唐朝已被选进殷璠编选的《河岳英灵集》、于敦煌石室发现的《唐写本唐人选唐诗》、韦庄编选的《又玄集》和韦縠编选的《才调集》。唐文宗御封李白的诗歌、裴旻的剑舞、张旭的草书称为“三绝”[2]。其作品想像奇特丰富,风格雄奇浪漫,意境独特,清新俊逸;善于利用夸饰与譬喻等手法、自然优美的词句,表现出奔放的情感。诗句行云流水,浑然天成。李白诗篇传诵千年,众多诗句已成经典,清赵翼称:“李杜诗篇万口传”(例如“抽刀断水水更流,举杯消愁愁更愁”等,更被谱入曲)。李白在诗歌的艺术成就被认为是中国浪漫主义诗歌的巅峰。诗作在全唐诗收录于卷161至卷185。有《李太白集》传世。杜甫曾经这样评价过李白的文章:“笔落惊风雨,诗成泣鬼神”、“白也诗无敌,飘然思不群”。 4 | 生平 5 | 早年 6 | 
据《新唐书》记载李白为兴圣皇帝(凉武昭王李暠)九世孙[3],如果按照这个说法李白与李唐诸王实际上同宗,应是唐太宗李世民的同辈族弟。亦有野史说其祖是李建成或李元吉,因为被李世民族灭而逃往西域;但此说缺乏佐证,且李建成、李元吉诸子尚在幼年即在玄武门之变后全数被害,留有亲生后嗣的可能性很小。据《旧唐书》记载,李白之父李客为任城尉。更为了学习而隐居。 7 | 李白于武则天大足元年(701年)[4]出生,关于其出生地有多种说法,现在主要有剑南道绵州昌隆县(今四川省江油市)[5]青莲乡(今青莲镇)和西域的碎叶(Suyab,位于今吉尔吉斯托克马克附近)[6]这两种说法,其中后一种说法认为李白直到四岁时(705年)才跟随他的父亲李客迁居蜀地,入籍绵州。李白自四岁(705年)接受启蒙教育,从景云元年(710年)开始,李白开始读诸子史籍[7],开元三年时十四岁(715年)——喜好作赋、剑术、奇书、神仙:“十五观奇书,做赋凌相如”。在青年时期开始在中国各地游历。开元五年左右,李白曾拜撰写《长短经》的赵蕤为师,学习一年有余,这段时期的学习对李白产生了深远的影响。开元六年,在戴天山(约在四川省昌隆县北五十里处)大明寺读书。二十五岁时只身出四川,开始了广泛漫游,南到洞庭湘江,东至吴、越,寓居在安陆(今湖北省安陆市)、应山(今湖北省广水市)。 8 | 中年 9 | 李白曾经在唐玄宗天宝元年(742年)供奉翰林。有一次皇帝因酒酣问李白说:“我朝与天后(武后)之朝何如?”白曰:“天后朝政出多门,国由奸幸,任人之道,如小儿市瓜,不择香味,惟拣肥大者;我朝任人如淘沙取金,剖石采用,皆得其精粹者。”玄宗听后大笑不止[8][9]。但是由于他桀骜不驯的性格,所以仅仅不到两年他就离开了长安。据说是因为他作的《清平调》得罪了当时宠冠后宫的杨贵妃(因李白命“力士脱靴”,高力士引以为大耻,因而以言语诱使杨贵妃认为“可怜飞燕倚新妆”几句是讽刺她)而不容于宫中[注 3]。天宝三年(745年)“恳求还山,帝赐金放还”,离开长安。 10 | 后在洛阳与另两位著名诗人杜甫、高适相识,并结为好友。 11 | 晚年 12 | 天宝十一年(752年)李白年届五十二岁,北上途中游广平郡邯郸、临洺、清漳等地。十月,抵幽州。初有立功边疆思想,在边地习骑射。后发现安禄山野心,登黄金台痛哭。不久即离幽州南下。 13 | 安史之乱爆发时,李白游华山,南下回宣城,后上庐山。756年12月,李白被三次邀请,下山赴寻阳入永王李璘幕僚[10]。永王触怒唐肃宗被杀后,李白也获罪入狱。幸得郭子仪力保,方得免死,改为流徙夜郎(今贵州关岭县一带),在途经巫山时遇赦,此时他已经59岁。(参见李璘之乱) 14 | 李白晚年在江南一带漂泊。在他61岁时,听到太尉李光弼率领大军讨伐安史叛军,于是他北上准备追随李光弼从军杀敌,但是中途因病折回。第二年,李白投奔他的族叔、当时在当涂(今属安徽省马鞍山)当县令的李阳冰。同年11月,李白病逝于寓所,终年61岁,葬当涂龙山。唐宪宗元和十二年(817年),宣歙观察使范传正根据李白生前“志在青山”的遗愿,将其墓迁至当涂青山。 15 | 去世 16 | 《新唐书》记载,唐代宗继位后以左拾遗召李白,但李白当时已去世。 17 | 李阳冰在《草堂集序》中说李白是病死的[11];皮日休在诗作中记载,李白是患“腐胁疾”而死的[12]。 18 | 《旧唐书》则记载,李白流放虽然遇赦,但因途中饮酒过度,醉死于宣城。中国民间有“太白捞月”的传说:李白在舟中赏月,饮酒大醉,想要跳下船至水里捞月而溺死[13][14][15];在民间的求签活动中亦有“太白捞月”一签文,乃是下下签[16]。 19 | 作品 20 | 李白一生创作大量的诗歌,绝大多数已散佚[17],流传至今的只有九百多首。他的诗歌创作涉及的中国古典诗歌的题材非常广泛,而且在许多题材都有名作出现,而且因为际遇的不同,每个时期的诗风都有所不同。 -------------------------------------------------------------------------------- /scripts/langchain/langchain_qa.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--file_path', required=True, type=str) 5 | 
parser.add_argument('--embedding_path', required=True, type=str) 6 | parser.add_argument('--model_path', required=True, type=str) 7 | parser.add_argument('--gpu_id', default="0", type=str) 8 | parser.add_argument('--chain_type', default="refine", type=str) 9 | args = parser.parse_args() 10 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 11 | file_path = args.file_path 12 | embedding_path = args.embedding_path 13 | model_path = args.model_path 14 | 15 | import torch 16 | from langchain.llms.huggingface_pipeline import HuggingFacePipeline 17 | from langchain.text_splitter import RecursiveCharacterTextSplitter 18 | from langchain.vectorstores import FAISS 19 | from langchain.document_loaders import TextLoader 20 | from langchain.prompts import PromptTemplate 21 | from langchain.chains import RetrievalQA 22 | from langchain.embeddings.huggingface import HuggingFaceEmbeddings 23 | 24 | prompt_template = ( 25 | "[INST] <<SYS>>\n" 26 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 27 | "<</SYS>>\n\n" 28 | "{context}\n{question} [/INST]" 29 | ) 30 | 31 | refine_prompt_template = ( 32 | "[INST] <<SYS>>\n" 33 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 34 | "<</SYS>>\n\n" 35 | "这是原始问题: {question}\n" 36 | "已有的回答: {existing_answer}\n" 37 | "现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。" 38 | "\n\n" 39 | "{context_str}\n" 40 | "\n\n" 41 | "请根据新的文段,进一步完善你的回答。" 42 | " [/INST]" 43 | ) 44 | 45 | initial_qa_template = ( 46 | "[INST] <<SYS>>\n" 47 | "You are a helpful assistant. 
你是一个乐于助人的助手。\n" 48 | "<</SYS>>\n\n" 49 | "以下为背景知识:\n" 50 | "{context_str}" 51 | "\n" 52 | "请根据以上背景知识, 回答这个问题:{question}。" 53 | " [/INST]" 54 | ) 55 | 56 | 57 | if __name__ == '__main__': 58 | load_type = torch.float16 59 | if not torch.cuda.is_available(): 60 | raise RuntimeError("No CUDA GPUs are available.") 61 | 62 | loader = TextLoader(file_path) 63 | documents = loader.load() 64 | text_splitter = RecursiveCharacterTextSplitter( 65 | chunk_size=600, chunk_overlap=100) 66 | texts = text_splitter.split_documents(documents) 67 | 68 | print("Loading the embedding model...") 69 | embeddings = HuggingFaceEmbeddings(model_name=embedding_path) 70 | docsearch = FAISS.from_documents(texts, embeddings) 71 | 72 | print("loading LLM...") 73 | model = HuggingFacePipeline.from_model_id(model_id=model_path, 74 | task="text-generation", 75 | device=0, 76 | pipeline_kwargs={ 77 | "max_new_tokens": 400, 78 | "do_sample": True, 79 | "temperature": 0.2, 80 | "top_k": 40, 81 | "top_p": 0.9, 82 | "repetition_penalty": 1.1}, 83 | model_kwargs={ 84 | "torch_dtype": load_type, 85 | "low_cpu_mem_usage": True, 86 | "trust_remote_code": True} 87 | ) 88 | 89 | if args.chain_type == "stuff": 90 | PROMPT = PromptTemplate( 91 | template=prompt_template, input_variables=["context", "question"] 92 | ) 93 | chain_type_kwargs = {"prompt": PROMPT} 94 | qa = RetrievalQA.from_chain_type( 95 | llm=model, 96 | chain_type="stuff", 97 | retriever=docsearch.as_retriever(search_kwargs={"k": 1}), 98 | chain_type_kwargs=chain_type_kwargs) 99 | 100 | elif args.chain_type == "refine": 101 | refine_prompt = PromptTemplate( 102 | input_variables=["question", "existing_answer", "context_str"], 103 | template=refine_prompt_template, 104 | ) 105 | initial_qa_prompt = PromptTemplate( 106 | input_variables=["context_str", "question"], 107 | template=initial_qa_template, 108 | ) 109 | chain_type_kwargs = {"question_prompt": initial_qa_prompt, "refine_prompt": refine_prompt} 110 | qa = RetrievalQA.from_chain_type( 111 | 
llm=model, chain_type="refine", 112 | retriever=docsearch.as_retriever(search_kwargs={"k": 1}), 113 | chain_type_kwargs=chain_type_kwargs) 114 | 115 | while True: 116 | query = input("请输入问题:") 117 | if len(query.strip())==0: 118 | break 119 | print(qa.run(query)) 120 | -------------------------------------------------------------------------------- /scripts/langchain/langchain_sum.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--file_path', required=True, type=str) 5 | parser.add_argument('--model_path', required=True, type=str) 6 | parser.add_argument('--gpu_id', default="0", type=str) 7 | parser.add_argument('--chain_type', default="refine", type=str) 8 | args = parser.parse_args() 9 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu_id 10 | file_path = args.file_path 11 | model_path = args.model_path 12 | 13 | import torch 14 | from langchain.llms.huggingface_pipeline import HuggingFacePipeline 15 | from langchain.text_splitter import RecursiveCharacterTextSplitter 16 | from langchain.prompts import PromptTemplate 17 | from langchain.chains.summarize import load_summarize_chain 18 | 19 | prompt_template = ( 20 | "[INST] <<SYS>>\n" 21 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 22 | "<</SYS>>\n\n" 23 | "请为以下文字写一段摘要:\n{text} [/INST]" 24 | ) 25 | refine_template = ( 26 | "[INST] <<SYS>>\n" 27 | "You are a helpful assistant. 
你是一个乐于助人的助手。\n" 28 | "<</SYS>>\n\n" 29 | "已有一段摘要:{existing_answer}\n" 30 | "现在还有一些文字,(如果有需要)你可以根据它们完善现有的摘要。" 31 | "\n" 32 | "{text}\n" 33 | "\n" 34 | "如果这段文字没有用,返回原来的摘要即可。请你生成一个最终的摘要。" 35 | " [/INST]" 36 | ) 37 | 38 | 39 | if __name__ == '__main__': 40 | load_type = torch.float16 41 | if not torch.cuda.is_available(): 42 | raise RuntimeError("No CUDA GPUs are available.") 43 | 44 | text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=100, length_function=len) 45 | with open(file_path) as f: 46 | text = f.read() 47 | docs = text_splitter.create_documents([text]) 48 | 49 | print("loading LLM...") 50 | model = HuggingFacePipeline.from_model_id(model_id=model_path, 51 | task="text-generation", 52 | device=0, 53 | pipeline_kwargs={ 54 | "max_new_tokens": 400, 55 | "do_sample": True, 56 | "temperature": 0.2, 57 | "top_k": 40, 58 | "top_p": 0.9, 59 | "repetition_penalty": 1.1}, 60 | model_kwargs={ 61 | "torch_dtype" : load_type, 62 | "low_cpu_mem_usage" : True, 63 | "trust_remote_code": True} 64 | ) 65 | 66 | PROMPT = PromptTemplate(template=prompt_template, input_variables=["text"]) 67 | REFINE_PROMPT = PromptTemplate( 68 | template=refine_template,input_variables=["existing_answer", "text"], 69 | ) 70 | 71 | if args.chain_type == "stuff": 72 | chain = load_summarize_chain(model, chain_type="stuff", prompt=PROMPT) 73 | elif args.chain_type == "refine": 74 | chain = load_summarize_chain(model, chain_type="refine", question_prompt=PROMPT, refine_prompt=REFINE_PROMPT) 75 | print(chain.run(docs)) 76 | -------------------------------------------------------------------------------- /scripts/llama-cpp/README.md: -------------------------------------------------------------------------------- 1 | ## llama.cpp相关示例脚本 2 | 3 | 具体使用方法参考:https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_zh 4 | 5 | Detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/llamacpp_en 6 | 7 | ### chat.sh 8 | 9 | 用于与Alpaca-2系列模型进行对话交流。 10 | 11 | Chat with 
Alpaca-2 models. 12 | 13 | ### server_curl_example.sh 14 | 15 | 架设server后使用curl调用示例。 16 | 17 | An example of using curl to call the API after setting up the server. 18 | -------------------------------------------------------------------------------- /scripts/llama-cpp/chat.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # temporary script to chat with Chinese Alpaca-2 model 4 | # usage: ./chat.sh alpaca2-ggml-model-path your-first-instruction 5 | 6 | SYSTEM_PROMPT='You are a helpful assistant. 你是一个乐于助人的助手。' 7 | # SYSTEM_PROMPT='You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。' # Try this one if you prefer longer responses. 8 | MODEL_PATH=$1 9 | FIRST_INSTRUCTION=$2 10 | 11 | ./main -m "$MODEL_PATH" \ 12 | --color -i -c 4096 -t 8 --temp 0.5 --top_k 40 --top_p 0.9 --repeat_penalty 1.1 \ 13 | --in-prefix-bos --in-prefix ' [INST] ' --in-suffix ' [/INST]' -p \ 14 | "[INST] <<SYS>> 15 | $SYSTEM_PROMPT 16 | <</SYS>> 17 | 18 | $FIRST_INSTRUCTION [/INST]" 19 | -------------------------------------------------------------------------------- /scripts/llama-cpp/server_curl_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # NOTE: start the server first before running this script. 4 | # usage: ./server_curl_example.sh your-instruction 5 | 6 | SYSTEM_PROMPT='You are a helpful assistant. 你是一个乐于助人的助手。' 7 | # SYSTEM_PROMPT='You are a helpful assistant. 你是一个乐于助人的助手。请你提供专业、有逻辑、内容真实、有价值的详细回复。' # Try this one if you prefer longer responses. 
8 | INSTRUCTION=$1 9 | ALL_PROMPT="[INST] <<SYS>>\n$SYSTEM_PROMPT\n<</SYS>>\n\n$INSTRUCTION [/INST]" 10 | CURL_DATA="{\"prompt\": \"$ALL_PROMPT\",\"n_predict\": 128}" 11 | 12 | curl --request POST \ 13 | --url http://localhost:8080/completion \ 14 | --header "Content-Type: application/json" \ 15 | --data "$CURL_DATA" 16 | -------------------------------------------------------------------------------- /scripts/longbench/config/dataset2maxlen.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": 128, 3 | "qasper": 128, 4 | "multifieldqa_en": 64, 5 | "multifieldqa_zh": 64, 6 | "hotpotqa": 32, 7 | "2wikimqa": 32, 8 | "musique": 32, 9 | "dureader": 128, 10 | "gov_report": 512, 11 | "qmsum": 512, 12 | "multi_news": 512, 13 | "vcsum": 512, 14 | "trec": 64, 15 | "triviaqa": 32, 16 | "samsum": 128, 17 | "lsht": 64, 18 | "passage_count": 32, 19 | "passage_retrieval_en": 32, 20 | "passage_retrieval_zh": 32, 21 | "lcc": 64, 22 | "repobench-p": 64 23 | } -------------------------------------------------------------------------------- /scripts/longbench/config/dataset2prompt.json: -------------------------------------------------------------------------------- 1 | { 2 | "narrativeqa": "You are given a story, which can be either a novel or a movie script, and a question. Answer the question asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nStory: {context}\n\nNow, answer the question based on the story asconcisely as you can, using a single phrase if possible. Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 3 | "qasper": "You are given a scientific article and a question. Answer the question as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". 
Do not provide any explanation.\n\nArticle: {context}\n\n Answer the question based on the above article as concisely as you can, using a single phrase or sentence if possible. If the question cannot be answered based on the information in the article, write \"unanswerable\". If the question is a yes/no question, answer \"yes\", \"no\", or \"unanswerable\". Do not provide any explanation.\n\nQuestion: {input}\n\nAnswer:", 4 | "multifieldqa_en": "Read the following text and answer briefly.\n\n{context}\n\nNow, answer the following question based on the above text, only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 5 | "multifieldqa_zh": "阅读以下文字并用中文简短回答:\n\n{context}\n\n现在请基于上面的文章回答下面的问题,只告诉我答案,不要输出任何其他字词。\n\n问题:{input}\n回答:", 6 | "hotpotqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 7 | "2wikimqa": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 8 | "musique": "Answer the question based on the given passages. Only give me the answer and do not output any other words.\n\nThe following are given passages.\n{context}\n\nAnswer the question based on the given passages. Only give me the answer and do not output any other words.\n\nQuestion: {input}\nAnswer:", 9 | "dureader": "请基于给定的文章回答下述问题。\n\n文章:{context}\n\n请基于上述文章回答下面的问题。\n\n问题:{input}\n回答:", 10 | "gov_report": "You are given a report by a government agency. 
Write a one-page summary of the report.\n\nReport:\n{context}\n\nNow, write a one-page summary of the report.\n\nSummary:", 11 | "qmsum": "You are given a meeting transcript and a query containing a question or instruction. Answer the query in one or more sentences.\n\nTranscript:\n{context}\n\nNow, answer the query based on the above meeting transcript in one or more sentences.\n\nQuery: {input}\nAnswer:", 12 | "multi_news": "You are given several news passages. Write a one-page summary of all news. \n\nNews:\n{context}\n\nNow, write a one-page summary of all the news.\n\nSummary:", 13 | "vcsum": "下面有一段会议记录,请你阅读后,写一段总结,总结会议的内容。\n会议记录:\n{context}\n\n会议总结:", 14 | "trec": "Please determine the type of the question below. Here are some examples of questions.\n\n{context}\n{input}", 15 | "triviaqa": "Answer the question based on the given passage. Only give me the answer and do not output any other words. The following are some examples.\n\n{context}\n\n{input}", 16 | "samsum": "Summarize the dialogue into a few short sentences. The following are some examples.\n\n{context}\n\n{input}", 17 | "lsht": "请判断给定新闻的类别,下面是一些例子。\n\n{context}\n{input}", 18 | "passage_count": "There are some paragraphs below sourced from Wikipedia. Some of them may be duplicates. Please carefully read these paragraphs and determine how many unique paragraphs there are after removing duplicates. In other words, how many non-repeating paragraphs are there in total?\n\n{context}\n\nPlease enter the final count of unique paragraphs after removing duplicates. The output format should only contain the number, such as 1, 2, 3, and so on.\n\nThe final answer is: ", 19 | "passage_retrieval_en": "Here are 30 paragraphs from Wikipedia, along with an abstract. Please determine which paragraph the abstract is from.\n\n{context}\n\nThe following is an abstract.\n\n{input}\n\nPlease enter the number of the paragraph that the abstract is from. 
The answer format must be like \"Paragraph 1\", \"Paragraph 2\", etc.\n\nThe answer is: ", 20 | "passage_retrieval_zh": "以下是若干段落文字,以及其中一个段落的摘要。请确定给定的摘要出自哪一段。\n\n{context}\n\n下面是一个摘要\n\n{input}\n\n请输入摘要所属段落的编号。答案格式必须是\"段落1\",\"段落2\"等格式\n\n答案是:", 21 | "lcc": "Please complete the code given below. \n{context}Next line of code:\n", 22 | "repobench-p": "Please complete the code given below. \n{context}{input}Next line of code:\n" 23 | } -------------------------------------------------------------------------------- /scripts/longbench/eval.py: -------------------------------------------------------------------------------- 1 | # The script is from https://github.com/THUDM/LongBench 2 | import os 3 | import json 4 | import argparse 5 | import numpy as np 6 | 7 | from metrics import ( 8 | qa_f1_score, 9 | rouge_zh_score, 10 | qa_f1_zh_score, 11 | rouge_score, 12 | classification_score, 13 | retrieval_score, 14 | retrieval_zh_score, 15 | count_score, 16 | code_sim_score, 17 | ) 18 | 19 | dataset2metric = { 20 | "narrativeqa": qa_f1_score, 21 | "qasper": qa_f1_score, 22 | "multifieldqa_en": qa_f1_score, 23 | "multifieldqa_zh": qa_f1_zh_score, 24 | "hotpotqa": qa_f1_score, 25 | "2wikimqa": qa_f1_score, 26 | "musique": qa_f1_score, 27 | "dureader": rouge_zh_score, 28 | "gov_report": rouge_score, 29 | "qmsum": rouge_score, 30 | "multi_news": rouge_score, 31 | "vcsum": rouge_zh_score, 32 | "trec": classification_score, 33 | "triviaqa": qa_f1_score, 34 | "samsum": rouge_score, 35 | "lsht": classification_score, 36 | "passage_retrieval_en": retrieval_score, 37 | "passage_count": count_score, 38 | "passage_retrieval_zh": retrieval_zh_score, 39 | "lcc": code_sim_score, 40 | "repobench-p": code_sim_score, 41 | } 42 | 43 | def parse_args(args=None): 44 | parser = argparse.ArgumentParser() 45 | parser.add_argument('--output_dir') 46 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 47 | return parser.parse_args(args) 48 | 49 | def scorer_e(dataset, 
predictions, answers, lengths, all_classes): 50 | scores = {"0-4k": [], "4-8k": [], "8k+": []} 51 | for (prediction, ground_truths, length) in zip(predictions, answers, lengths): 52 | score = 0. 53 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 54 | prediction = prediction.lstrip('\n').split('\n')[0] 55 | for ground_truth in ground_truths: 56 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 57 | if length < 4000: 58 | scores["0-4k"].append(score) 59 | elif length < 8000: 60 | scores["4-8k"].append(score) 61 | else: 62 | scores["8k+"].append(score) 63 | for key in scores.keys(): 64 | scores[key] = round(100 * np.mean(scores[key]), 2) 65 | return scores 66 | 67 | def scorer(dataset, predictions, answers, all_classes): 68 | total_score = 0. 69 | for (prediction, ground_truths) in zip(predictions, answers): 70 | score = 0. 71 | if dataset in ["trec", "triviaqa", "samsum", "lsht"]: 72 | prediction = prediction.lstrip('\n').split('\n')[0] 73 | for ground_truth in ground_truths: 74 | score = max(score, dataset2metric[dataset](prediction, ground_truth, all_classes=all_classes)) 75 | total_score += score 76 | return round(100 * total_score / len(predictions), 2) 77 | 78 | if __name__ == '__main__': 79 | args = parse_args() 80 | scores = dict() 81 | if args.e: 82 | path = f"{args.output_dir}/pred_e/" 83 | else: 84 | path = f"{args.output_dir}/pred/" 85 | all_files = os.listdir(path) 86 | print("Evaluating on:", all_files) 87 | for filename in all_files: 88 | if not filename.endswith("jsonl"): 89 | continue 90 | predictions, answers, lengths = [], [], [] 91 | dataset = filename.split('.')[0] 92 | with open(f"{path}{filename}", "r", encoding="utf-8") as f: 93 | print(filename) 94 | for line in f: 95 | data = json.loads(line) 96 | predictions.append(data["pred"]) 97 | answers.append(data["answers"]) 98 | all_classes = data["all_classes"] 99 | if "length" in data: 100 | lengths.append(data["length"]) 101 | if args.e: 102 | 
score = scorer_e(dataset, predictions, answers, lengths, all_classes) 103 | else: 104 | score = scorer(dataset, predictions, answers, all_classes) 105 | scores[dataset] = score 106 | if args.e: 107 | out_path = f"{args.output_dir}/pred_e/result.json" 108 | else: 109 | out_path = f"{args.output_dir}/pred/result.json" 110 | with open(out_path, "w") as f: 111 | json.dump(scores, f, ensure_ascii=False, indent=4) 112 | -------------------------------------------------------------------------------- /scripts/longbench/metrics.py: -------------------------------------------------------------------------------- 1 | # The script is from https://github.com/THUDM/LongBench 2 | import re 3 | import string 4 | 5 | import jieba 6 | from fuzzywuzzy import fuzz 7 | import difflib 8 | 9 | from collections import Counter 10 | from rouge import Rouge 11 | 12 | def normalize_answer(s): 13 | """Lower text and remove punctuation, articles and extra whitespace.""" 14 | 15 | def remove_articles(text): 16 | return re.sub(r"\b(a|an|the)\b", " ", text) 17 | 18 | def white_space_fix(text): 19 | return " ".join(text.split()) 20 | 21 | def remove_punc(text): 22 | exclude = set(string.punctuation) 23 | return "".join(ch for ch in text if ch not in exclude) 24 | 25 | def lower(text): 26 | return text.lower() 27 | 28 | return white_space_fix(remove_articles(remove_punc(lower(s)))) 29 | 30 | 31 | def normalize_zh_answer(s): 32 | """Lower text and remove punctuation, extra whitespace.""" 33 | 34 | def white_space_fix(text): 35 | return "".join(text.split()) 36 | 37 | def remove_punc(text): 38 | cn_punctuation = "！？｡。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏." 
39 | all_punctuation = set(string.punctuation + cn_punctuation) 40 | return "".join(ch for ch in text if ch not in all_punctuation) 41 | 42 | def lower(text): 43 | return text.lower() 44 | 45 | return white_space_fix(remove_punc(lower(s))) 46 | 47 | def count_score(prediction, ground_truth, **kwargs): 48 | numbers = re.findall(r"\d+", prediction) 49 | right_num = 0 50 | for number in numbers: 51 | if str(number) == str(ground_truth): 52 | right_num += 1 53 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 54 | return float(final_score) 55 | 56 | def retrieval_score(prediction, ground_truth, **kwargs): 57 | pattern = r'Paragraph (\d+)' 58 | matches = re.findall(pattern, ground_truth) 59 | ground_truth_id = matches[0] 60 | numbers = re.findall(r"\d+", prediction) 61 | right_num = 0 62 | for number in numbers: 63 | if str(number) == str(ground_truth_id): 64 | right_num += 1 65 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 66 | return float(final_score) 67 | 68 | def retrieval_zh_score(prediction, ground_truth, **kwargs): 69 | pattern = r'段落(\d+)' 70 | matches = re.findall(pattern, ground_truth) 71 | ground_truth_id = matches[0] 72 | numbers = re.findall(r"\d+", prediction) 73 | right_num = 0 74 | for number in numbers: 75 | if str(number) == str(ground_truth_id): 76 | right_num += 1 77 | final_score = 0.0 if len(numbers) == 0 else right_num / len(numbers) 78 | return float(final_score) 79 | 80 | def code_sim_score(prediction, ground_truth, **kwargs): 81 | all_lines = prediction.lstrip('\n').split('\n') 82 | prediction = "" 83 | for line in all_lines: 84 | if ('`' not in line) and ('#' not in line) and ('//' not in line): 85 | prediction = line 86 | break 87 | return (fuzz.ratio(prediction, ground_truth) / 100) 88 | 89 | def classification_score(prediction, ground_truth, **kwargs): 90 | em_match_list = [] 91 | all_classes = kwargs["all_classes"] 92 | for class_name in all_classes: 93 | if class_name in prediction: 94 | 
em_match_list.append(class_name) 95 | for match_term in list(em_match_list): 96 | if match_term in ground_truth and match_term != ground_truth: 97 | em_match_list.remove(match_term) 98 | if len(em_match_list) != 0: 99 | if ground_truth in em_match_list: 100 | score = (1.0 / len(em_match_list)) 101 | else: 102 | score = 0.0 103 | else: 104 | best_match = None 105 | highest_similarity = 0 106 | for string in all_classes: 107 | similarity = difflib.SequenceMatcher(None, string, prediction).ratio() 108 | if similarity > highest_similarity: 109 | highest_similarity = similarity 110 | best_match = string 111 | score = float(best_match == ground_truth) 112 | return score 113 | 114 | def rouge_score(prediction, ground_truth, **kwargs): 115 | rouge = Rouge() 116 | try: 117 | scores = rouge.get_scores([prediction], [ground_truth], avg=True) 118 | except Exception: 119 | return 0.0 120 | return scores["rouge-l"]["f"] 121 | 122 | def rouge_zh_score(prediction, ground_truth, **kwargs): 123 | prediction = " ".join(list(jieba.cut(prediction, cut_all=False))) 124 | ground_truth = " ".join(list(jieba.cut(ground_truth, cut_all=False))) 125 | score = rouge_score(prediction, ground_truth) 126 | return score 127 | 128 | def f1_score(prediction, ground_truth, **kwargs): 129 | common = Counter(prediction) & Counter(ground_truth) 130 | num_same = sum(common.values()) 131 | if num_same == 0: 132 | return 0 133 | precision = 1.0 * num_same / len(prediction) 134 | recall = 1.0 * num_same / len(ground_truth) 135 | f1 = (2 * precision * recall) / (precision + recall) 136 | return f1 137 | 138 | def qa_f1_score(prediction, ground_truth, **kwargs): 139 | normalized_prediction = normalize_answer(prediction) 140 | normalized_ground_truth = normalize_answer(ground_truth) 141 | 142 | prediction_tokens = normalized_prediction.split() 143 | ground_truth_tokens = normalized_ground_truth.split() 144 | return f1_score(prediction_tokens, ground_truth_tokens) 145 | 146 | 147 | def qa_f1_zh_score(prediction, 
ground_truth, **kwargs): 148 | prediction_tokens = list(jieba.cut(prediction, cut_all=False)) 149 | ground_truth_tokens = list(jieba.cut(ground_truth, cut_all=False)) 150 | prediction_tokens = [normalize_zh_answer(token) for token in prediction_tokens] 151 | ground_truth_tokens = [normalize_zh_answer(token) for token in ground_truth_tokens] 152 | prediction_tokens = [token for token in prediction_tokens if len(token) > 0] 153 | ground_truth_tokens = [token for token in ground_truth_tokens if len(token) > 0] 154 | return f1_score(prediction_tokens, ground_truth_tokens) 155 | -------------------------------------------------------------------------------- /scripts/longbench/pred_llama2.py: -------------------------------------------------------------------------------- 1 | # The script is modified from https://github.com/THUDM/LongBench/blob/main/pred.py 2 | from datasets import load_dataset 3 | import torch 4 | import random 5 | import numpy as np 6 | import json 7 | from transformers import LlamaTokenizer, AutoModelForCausalLM 8 | from transformers import BitsAndBytesConfig 9 | from tqdm import tqdm 10 | import os 11 | import argparse 12 | import sys 13 | parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 14 | sys.path.append(parent_dir) 15 | from attn_and_long_ctx_patches import apply_attention_patch, apply_ntk_scaling_patch 16 | 17 | dir_path = os.path.dirname(os.path.realpath(__file__)) 18 | 19 | DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 
你是一个乐于助人的助手。""" 20 | 21 | TEMPLATE = ( 22 | "[INST] <<SYS>>\n" 23 | "{system_prompt}\n" 24 | "<</SYS>>\n\n" 25 | "{instruction} [/INST]" 26 | ) 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--model_path', type=str) 30 | parser.add_argument('--load_in_4bit',action='store_true') 31 | parser.add_argument('--load_in_8bit',action='store_true') 32 | parser.add_argument('--predict_on',type=str, default='zh') 33 | parser.add_argument('--output_dir',type=str, default='pred') 34 | parser.add_argument('--gpus',type=str, default=None) 35 | parser.add_argument('--max_length',type=int, default=4096-512) 36 | parser.add_argument('--alpha', type=str, default="auto", help="The scaling factor of the NTK method; can be a float or 'auto'.") 37 | parser.add_argument('--with_inst', choices=['true','false','auto'], default = 'false', 38 | help="Whether to use the system prompt and template of Chinese-Alpaca-2 when constructing the instructions.") 39 | parser.add_argument('--e', action='store_true', help="Evaluate on LongBench-E") 40 | parser.add_argument('--use_flash_attention_2', action='store_true', help="Use flash attention to replace the LLaMA attention") 41 | parser.add_argument('--use_ntk', action='store_true', help="Use dynamic NTK scaling to extend the context window") 42 | 43 | 44 | args = parser.parse_args() 45 | 46 | model_path = args.model_path 47 | load_in_4bit = args.load_in_4bit 48 | load_in_8bit = args.load_in_8bit 49 | predict_on = args.predict_on 50 | output_dir = args.output_dir 51 | gpus=args.gpus 52 | max_length = args.max_length 53 | alpha = args.alpha 54 | 55 | DO_SAMPLE = True 56 | TEMPERATURE = 0.2 57 | REPETITION_PENALTY = 1.1 58 | TOP_P = 0.95 59 | TOP_K = 40 60 | 61 | if gpus is not None: 62 | os.environ["CUDA_VISIBLE_DEVICES"] = gpus 63 | apply_attention_patch(use_memory_efficient_attention=True) 64 | if args.use_ntk: 65 | apply_ntk_scaling_patch(args.alpha) 66 | 67 | 68 | def fill_llama2_prompt_template(instruction, with_inst = True, with_system_prompt = True, 
system_prompt = DEFAULT_SYSTEM_PROMPT): 69 | if with_inst is False: 70 | return instruction 71 | if with_system_prompt is True: 72 | return TEMPLATE.format_map({'instruction': instruction,'system_prompt': system_prompt}) 73 | else: 74 | return f"[INST] {instruction} [/INST]" 75 | 76 | 77 | def get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device): 78 | preds = [] 79 | for json_obj in tqdm(data): 80 | prompt = prompt_format.format(**json_obj) 81 | # truncate to fit max_length (we suggest truncating in the middle, since the left and right sides may contain crucial instructions) 82 | tokenized_prompt = tokenizer(prompt, truncation=False, return_tensors="pt").input_ids[0] 83 | if len(tokenized_prompt) > max_length: 84 | half = int(max_length/2) 85 | prompt = tokenizer.decode(tokenized_prompt[:half], skip_special_tokens=True)+tokenizer.decode(tokenized_prompt[-half:], skip_special_tokens=True) 86 | if args.with_inst == 'auto': 87 | if dataset not in ["trec", "triviaqa", "samsum", "lsht", "lcc", "repobench-p"]: # chat models are better off without the chat template on these tasks 88 | prompt = fill_llama2_prompt_template(instruction=prompt) 89 | elif args.with_inst == 'true': 90 | prompt = fill_llama2_prompt_template(instruction=prompt, with_inst = True) 91 | else: 92 | prompt = fill_llama2_prompt_template(instruction=prompt, with_inst = False) 93 | 94 | input_data = tokenizer(prompt, truncation=False, return_tensors="pt").to(device) 95 | context_length = input_data.input_ids.shape[-1] 96 | if dataset == "samsum": # prevent illegal output on samsum (the model endlessly repeats "\nDialogue"); might be a prompting issue 97 | output = model.generate( 98 | **input_data, 99 | max_new_tokens=max_gen, 100 | num_beams=1, 101 | do_sample=DO_SAMPLE, 102 | repetition_penalty = REPETITION_PENALTY, 103 | top_p = TOP_P, 104 | top_k = TOP_K, 105 | temperature=TEMPERATURE, 106 | min_length=context_length+1, 107 | eos_token_id=[tokenizer.eos_token_id, 
tokenizer.encode("\n", add_special_tokens=False)[-1]], 108 | )[0] 109 | else: 110 | output = model.generate( 111 | **input_data, 112 | max_new_tokens=max_gen, 113 | num_beams=1, 114 | do_sample=DO_SAMPLE, 115 | repetition_penalty = REPETITION_PENALTY, 116 | top_p = TOP_P, 117 | top_k = TOP_K, 118 | temperature=TEMPERATURE 119 | )[0] 120 | pred = tokenizer.decode(output[context_length:], skip_special_tokens=True) 121 | #print(pred) 122 | preds.append({"pred": pred, "answers": json_obj["answers"], "all_classes": json_obj["all_classes"], "length": json_obj["length"]}) 123 | return preds 124 | 125 | def seed_everything(seed): 126 | torch.manual_seed(seed) 127 | torch.cuda.manual_seed(seed) 128 | np.random.seed(seed) 129 | random.seed(seed) 130 | torch.backends.cudnn.benchmark = False 131 | torch.backends.cudnn.deterministic = True 132 | torch.cuda.manual_seed_all(seed) 133 | 134 | if __name__ == '__main__': 135 | seed_everything(42) 136 | load_type = torch.float16 137 | if torch.cuda.is_available(): 138 | device = torch.device(0) 139 | else: 140 | device = torch.device('cpu') 141 | 142 | if args.e: 143 | en_datasets = [ "hotpotqa","2wikimqa", 144 | "qasper", "multifieldqa_en", "gov_report", 145 | "trec", "samsum", "triviaqa", 146 | "passage_count", "passage_retrieval_en", "multi_news"] 147 | zh_datasets = [] 148 | code_datasets = [ "lcc", "repobench-p" ] 149 | if not os.path.exists(f"{output_dir}/pred_e"): 150 | os.makedirs(f"{output_dir}/pred_e") 151 | else: 152 | en_datasets = [ "hotpotqa","2wikimqa", "musique", "narrativeqa", 153 | "qasper", "multifieldqa_en", "gov_report", 154 | "qmsum", "trec", "samsum", "triviaqa", 155 | "passage_count", "passage_retrieval_en", "multi_news"] 156 | zh_datasets = [ "dureader", "multifieldqa_zh", 157 | "vcsum","lsht", "passage_retrieval_zh"] 158 | code_datasets = [ "lcc", "repobench-p" ] 159 | 160 | if not os.path.exists(f"{output_dir}/pred"): 161 | os.makedirs(f"{output_dir}/pred") 162 | 163 | datasets = [] 164 | for data_type in 
predict_on.split(','): 165 | if data_type == 'zh': 166 | datasets += zh_datasets 167 | elif data_type == 'en': 168 | datasets += en_datasets 169 | elif data_type == 'code': 170 | datasets += code_datasets 171 | print(datasets) 172 | 173 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 174 | 175 | tokenizer = LlamaTokenizer.from_pretrained(model_path, legacy=True) 176 | model = None 177 | if args.load_in_4bit or args.load_in_8bit: 178 | quantization_config = BitsAndBytesConfig( 179 | load_in_4bit=args.load_in_4bit, 180 | load_in_8bit=args.load_in_8bit, 181 | bnb_4bit_compute_dtype=load_type, 182 | ) 183 | model = AutoModelForCausalLM.from_pretrained( 184 | model_path, 185 | torch_dtype=load_type, 186 | low_cpu_mem_usage=True, 187 | device_map='auto', 188 | quantization_config=quantization_config if (args.load_in_4bit or args.load_in_8bit) else None, 189 | use_flash_attention_2=args.use_flash_attention_2, 190 | trust_remote_code=True 191 | ) 192 | model = model.eval() 193 | model_vocab_size = model.get_input_embeddings().weight.size(0) 194 | print(f"Vocab of the base model: {model_vocab_size}") 195 | tokenizer_vocab_size = len(tokenizer) 196 | print(f"Vocab of the tokenizer: {tokenizer_vocab_size}") 197 | 198 | # we design specific prompt format and max generation length for each task, feel free to modify them to optimize model output 199 | dataset2prompt = json.load(open(dir_path + "/config/dataset2prompt.json", "r")) 200 | dataset2maxlen = json.load(open(dir_path + "/config/dataset2maxlen.json", "r")) 201 | # predict on each dataset 202 | for dataset in datasets: 203 | print(f"Loading dataset {dataset}") 204 | if args.e: 205 | data = load_dataset('THUDM/LongBench', dataset+'_e', split='test') 206 | output_path = f"{output_dir}/pred_e/{dataset}.jsonl" 207 | else: 208 | data = load_dataset('THUDM/LongBench', dataset, split='test') 209 | output_path = f"{output_dir}/pred/{dataset}.jsonl" 210 | prompt_format = dataset2prompt[dataset] 211 | 
max_gen = dataset2maxlen[dataset] 212 | preds = get_pred(model, tokenizer, data, max_length, max_gen, prompt_format, dataset, device) 213 | with open(output_path, "w", encoding="utf-8") as f: 214 | for pred in preds: 215 | json.dump(pred, f, ensure_ascii=False) 216 | f.write('\n') 217 | -------------------------------------------------------------------------------- /scripts/longbench/requirements.txt: -------------------------------------------------------------------------------- 1 | datasets 2 | tqdm 3 | rouge 4 | jieba 5 | fuzzywuzzy 6 | torch 7 | transformers==4.35.0 8 | einops -------------------------------------------------------------------------------- /scripts/openai_server_demo/README.md: -------------------------------------------------------------------------------- 1 | # OPENAI API DEMO 2 | 3 | > For more details on the OpenAI API: 4 | 5 | This is a simple server demo built with FastAPI that mimics the OpenAI API style. You can use it to quickly build personal websites and other fun web demos backed by Chinese large language models. 6 | 7 | ## Deployment 8 | 9 | Install the dependencies 10 | ``` shell 11 | pip install fastapi uvicorn shortuuid sse_starlette 12 | ``` 13 | 14 | Launch the server 15 | ``` shell 16 | python scripts/openai_server_demo/openai_api_server.py --base_model /path/to/base_model --lora_model /path/to/lora_model --gpus 0,1 17 | ``` 18 | 19 | ### Arguments 20 | 21 | `--base_model {base_model}`: Directory containing the LLaMA-2 model weights and configuration files in HF format. This can be a merged Chinese Alpaca-2 model (in which case `--lora_model` is not needed) or an original LLaMA-2 model converted to HF format (in which case `--lora_model` must be provided). 22 | 23 | `--lora_model {lora_model}`: Directory containing the extracted Chinese Alpaca-2 LoRA files; a 🤗Model Hub model name can also be used. If not provided, only the model specified by `--base_model` is loaded. 24 | 25 | `--tokenizer_path {tokenizer_path}`: Directory containing the corresponding tokenizer. If not provided, it defaults to the value of `--lora_model`; if `--lora_model` is not provided either, it defaults to `--base_model`. 26 | 27 | `--only_cpu`: Run inference on the CPU only. 28 | 29 | `--gpus {gpu_ids}`: IDs of the GPU devices to use; defaults to 0. For multiple GPUs, separate the IDs with commas, e.g. 0,1,2 30 | 31 | `--load_in_8bit`: Run inference with the model loaded in 8-bit. This saves GPU memory but may degrade output quality. 32 | 33 | `--alpha {alpha}`: Scaling factor for extending the context length with the NTK method, which increases the input length the model can handle. Defaults to 1. If you are unsure how to set it, keep the default or set it to `"auto"`. 34 | 35 | `--use_ntk`: Extend the context length with the NTK method. This only applies to the base and 16K versions; the 64K version does not need this flag. 36 | 37 | 
`--use_flash_attention_2`: Use FlashAttention-2 to accelerate inference. 38 | 39 | ## API Documentation 40 | 41 | ### Completion (文字接龙) 42 | 43 | > Professor Hung-yi Lee renders "completion" in Chinese as 文字接龙 (a "word-chain" game). 44 | 45 | The most basic endpoint: it takes a prompt and returns the model's completion of it. 46 | 47 | The API demo has a built-in prompt template. Your prompt is inserted into the instruction template, so it should read like an instruction rather than a chat message. 48 | 49 | #### Quick start with the completion endpoint 50 | 51 | Request command: 52 | 53 | ``` shell 54 | curl http://localhost:19327/v1/completions \ 55 | -H "Content-Type: application/json" \ 56 | -d '{ 57 | "prompt": "告诉我中国的首都在哪里" 58 | }' 59 | ``` 60 | 61 | JSON response: 62 | 63 | ``` json 64 | { 65 | "id": "cmpl-3watqWsbmYgbWXupsSik7s", 66 | "object": "text_completion", 67 | "created": 1686067311, 68 | "model": "chinese-llama-alpaca-2", 69 | "choices": [ 70 | { 71 | "index": 0, 72 | "text": "中国的首都是北京。" 73 | } 74 | ] 75 | } 76 | ``` 77 | 78 | #### Advanced parameters for the completion endpoint 79 | 80 | Request command: 81 | 82 | ``` shell 83 | curl http://localhost:19327/v1/completions \ 84 | -H "Content-Type: application/json" \ 85 | -d '{ 86 | "prompt": "告诉我中国和美国分别各有哪些优点缺点", 87 | "max_tokens": 90, 88 | "temperature": 0.7, 89 | "num_beams": 4, 90 | "top_k": 40 91 | }' 92 | ``` 93 | 94 | JSON response: 95 | 96 | ``` json 97 | { 98 | "id": "cmpl-PvVwfMq2MVWHCBKiyYJfKM", 99 | "object": "text_completion", 100 | "created": 1686149471, 101 | "model": "chinese-llama-alpaca-2", 102 | "choices": [ 103 | { 104 | "index": 0, 105 | "text": "中国的优点是拥有丰富的文化和历史,而美国的优点是拥有先进的科技和经济体系。" 106 | } 107 | ] 108 | } 109 | ``` 110 | 111 | #### Advanced parameters of the completion endpoint explained 112 | 113 | > For more details on decoding strategies, see . That article explains the three decoding strategies that LLaMA may use: greedy decoding, random sampling and beam search. These strategies underlie advanced parameters such as top_k, top_p, temperature and num_beams. 114 | 115 | `prompt`: The prompt to complete. 116 | 117 | `max_tokens`: The maximum number of tokens to generate. 118 | 119 | `temperature`: Sampling temperature, between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic. The higher the temperature, the more likely lower-probability tokens are to be sampled. 120 | 121 | `num_beams`: The number of beams used when decoding with beam search. With num_beams=1 (and sampling disabled), this reduces to greedy decoding. 122 | 123 | `top_k`: In sampling mode (random 
sampling), only the top_k most probable tokens are kept as candidates, and the next token is sampled from them. 124 | 125 | `top_p`: In sampling mode, only the smallest set of most probable tokens whose cumulative probability exceeds top_p is kept as candidates for sampling. Lower values make the output less random. For example, suppose top_p is 0.6 and the five most probable tokens have probabilities {0.23, 0.20, 0.18, 0.11, 0.10}. The first three tokens already have a cumulative probability of 0.61, so the 4th and later tokens are filtered out and only the top three remain as candidates. 126 | 127 | `repetition_penalty`: Repetition penalty. For details, see this article: . 128 | 129 | `do_sample`: Enables random sampling. Defaults to true. 130 | 131 | ### Chat (chat completion) 132 | 133 | The chat endpoint supports multi-turn conversations. 134 | 135 | #### Quick start with the chat endpoint 136 | 137 | Request command: 138 | 139 | ``` shell 140 | curl http://localhost:19327/v1/chat/completions \ 141 | -H "Content-Type: application/json" \ 142 | -d '{ 143 | "messages": [ 144 | {"role": "user","content": "给我讲一些有关杭州的故事吧"} 145 | ], 146 | "repetition_penalty": 1.0 147 | }' 148 | ``` 149 | 150 | JSON response: 151 | 152 | ``` json 153 | { 154 | "id": "chatcmpl-5L99pYoW2ov5ra44Ghwupt", 155 | "object": "chat.completion", 156 | "created": 1686143170, 157 | "model": "chinese-llama-alpaca-2", 158 | "choices": [ 159 | { 160 | "index": 0, 161 | "message": { 162 | "role": "user", 163 | "content": "给我讲一些有关杭州的故事吧" 164 | } 165 | }, 166 | { 167 | "index": 1, 168 | "message": { 169 | "role": "assistant", 170 | "content": "好的,请问您对杭州有什么特别的偏好吗?" 
171 | } 172 | } 173 | ] 174 | } 175 | ``` 176 | 177 | #### Multi-turn conversation 178 | 179 | Request command: 180 | 181 | ``` shell 182 | curl http://localhost:19327/v1/chat/completions \ 183 | -H "Content-Type: application/json" \ 184 | -d '{ 185 | "messages": [ 186 | {"role": "user","content": "给我讲一些有关杭州的故事吧"}, 187 | {"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"}, 188 | {"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"} 189 | ], 190 | "repetition_penalty": 1.0 191 | }' 192 | ``` 193 | 194 | JSON response: 195 | 196 | ``` json 197 | { 198 | "id": "chatcmpl-hmvrQNPGYTcLtmYruPJbv6", 199 | "object": "chat.completion", 200 | "created": 1686143439, 201 | "model": "chinese-llama-alpaca-2", 202 | "choices": [ 203 | { 204 | "index": 0, 205 | "message": { 206 | "role": "user", 207 | "content": "给我讲一些有关杭州的故事吧" 208 | } 209 | }, 210 | { 211 | "index": 1, 212 | "message": { 213 | "role": "assistant", 214 | "content": "好的,请问您对杭州有什么特别的偏好吗?" 215 | } 216 | }, 217 | { 218 | "index": 2, 219 | "message": { 220 | "role": "user", 221 | "content": "我比较喜欢和西湖,可以给我讲一下西湖吗" 222 | } 223 | }, 224 | { 225 | "index": 3, 226 | "message": { 227 | "role": "assistant", 228 | "content": "是的,西湖是杭州最著名的景点之一,它被誉为“人间天堂”。 <\\s>" 229 | } 230 | } 231 | ] 232 | } 233 | ``` 234 | 235 | #### Advanced parameters of the chat endpoint explained 236 | 237 | `messages`: The conversation history, given as a list of messages, each with a `role` and `content`. 238 | 239 | `max_tokens`: The maximum number of tokens to generate. 240 | 241 | `temperature`: Sampling temperature, between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic. The higher the temperature, the more likely lower-probability tokens are to be sampled. 242 | 243 | `num_beams`: The number of beams used when decoding with beam search. With num_beams=1 (and sampling disabled), this reduces to greedy decoding. 244 | 245 | `top_k`: In sampling mode (random sampling), only the top_k most probable tokens are kept as candidates, and the next token is sampled from them. 246 | 247 | `top_p`: In sampling mode, only the smallest set of most probable tokens whose cumulative probability exceeds top_p is kept as candidates for sampling. Lower values make the output less random. For example, suppose top_p is 0.6 and the five most probable tokens have probabilities [0.23, 0.20, 0.18, 0.11, 0.10]. The first three tokens already have a cumulative probability of 0.61, so the 4th and later tokens are filtered out and only the top three remain as candidates. 248 | 249 | `repetition_penalty`: Repetition penalty. For details, see this article: . 250 | 251 | `do_sample`: Enables random sampling. Defaults to true. 252 | 253 | `stream`: 
Streams the response in OpenAI format. Defaults to false; when set to true, data is streamed back in the OpenAI format, so the server can act as the backend for applications built on the ChatGPT API. 254 | 255 | ### Text embedding 256 | 257 | Text embeddings have many uses, including but not limited to question answering over large documents, summarizing the contents of a book, and finding the memories most relevant to the current user input for a large language model. 258 | 259 | Request command: 260 | 261 | ``` shell 262 | curl http://localhost:19327/v1/embeddings \ 263 | -H "Content-Type: application/json" \ 264 | -d '{ 265 | "input": "今天天气真不错" 266 | }' 267 | ``` 268 | 269 | JSON response body: 270 | 271 | ``` json 272 | { 273 | "object": "list", 274 | "data": [ 275 | { 276 | "object": "embedding", 277 | "embedding": [ 278 | 0.003643923671916127, 279 | -0.0072653163224458694, 280 | 0.0075545101426541805, 281 | ...., 282 | 0.0045851171016693115 283 | ], 284 | "index": 0 285 | } 286 | ], 287 | "model": "chinese-llama-alpaca-2" 288 | } 289 | ``` 290 | 291 | The length of the embedding vector equals the hidden size of the model in use; for example, with a 7B model the embedding length is 4096. 292 | -------------------------------------------------------------------------------- /scripts/openai_server_demo/README_vllm.md: -------------------------------------------------------------------------------- 1 | # OPENAI API DEMO 2 | 3 | > More detailed OpenAI API information: 4 | 5 | This is a simple FastAPI-based server demo that mimics the OpenAI API. You can use it to quickly build personal websites and other web demos on top of Chinese large language models. 6 | 7 | This implementation uses vLLM as the LLM backend. It does not yet support loading LoRA models, CPU-only deployment, or 8-bit inference. 8 | 9 | ## Deployment 10 | 11 | Install dependencies 12 | ``` shell 13 | pip install fastapi uvicorn shortuuid vllm fschat 14 | ``` 15 | 16 | Launch the server 17 | ``` shell 18 | python scripts/openai_server_demo/openai_api_server_vllm.py --model /path/to/base_model --tokenizer-mode slow --served-model-name chinese-llama-alpaca-2 19 | ``` 20 | 21 | ### Parameters 22 | 23 | `--model {base_model}`: Directory containing the HF-format LLaMA-2 model weights and config files; this can be a merged Chinese-Alpaca-2 model 24 | 25 | `--tokenizer {tokenizer_path}`: Directory containing the corresponding tokenizer. If omitted, it defaults to the value of `--model` 26 | 27 | `--tokenizer-mode {tokenizer-mode}`: Tokenizer mode. For LLaMA/LLaMA-2-based models, always use `slow` 28 | 29 | `--tensor-parallel-size {tensor_parallel_size}`: Number of GPUs to use. Defaults to 1 30 | 31 | `--served-model-name {served-model-name}`: 
Model name used in the API. When serving Chinese-Alpaca-2 models, the name must contain `chinese-llama-alpaca-2` 32 | 33 | `--host {host_name}`: Host name of the service. Defaults to `localhost` 34 | 35 | `--port {port}`: Port of the service. Defaults to `8000` 36 | 37 | ## API documentation 38 | 39 | ### Text completion 40 | 41 | > Professor Hung-yi Lee translates "completion" into Chinese as 文字接龙 (roughly, "word chain"). 42 | 43 | The most basic endpoint: given a prompt, it returns the model's completion. 44 | 45 | The API demo has a built-in prompt template: your prompt is wrapped in the instruction template, so it should read like an instruction rather than a conversational turn. 46 | 47 | #### Quick start with the completion endpoint 48 | 49 | Request command: 50 | 51 | ``` shell 52 | curl http://localhost:8000/v1/completions \ 53 | -H "Content-Type: application/json" \ 54 | -d '{ 55 | "model": "chinese-llama-alpaca-2", 56 | "prompt": "告诉我中国的首都在哪里" 57 | }' 58 | ``` 59 | 60 | JSON response body: 61 | 62 | ``` json 63 | { 64 | "id": "cmpl-41234d71fa034ec3ae90bbf6b5be7", 65 | "object": "text_completion", 66 | "created": 1690870733, 67 | "model": "chinese-llama-alpaca-2", 68 | "choices": [ 69 | { 70 | "index": 0, 71 | "text": "中国的首都是北京。" 72 | } 73 | ] 74 | } 75 | ``` 76 | 77 | #### Advanced completion parameters 78 | 79 | Request command: 80 | 81 | ``` shell 82 | curl http://localhost:8000/v1/completions \ 83 | -H "Content-Type: application/json" \ 84 | -d '{ 85 | "model": "chinese-llama-alpaca-2", 86 | "prompt": "告诉我中国和美国分别各有哪些优点缺点", 87 | "max_tokens": 90, 88 | "temperature": 0.7, 89 | "num_beams": 4, 90 | "top_k": 40 91 | }' 92 | ``` 93 | 94 | JSON response body: 95 | 96 | ``` json 97 | { 98 | "id": "cmpl-ceca9906bf0a429989e850368cc3f893", 99 | "object": "text_completion", 100 | "created": 1690870952, 101 | "model": "chinese-llama-alpaca-2", 102 | "choices": [ 103 | { 104 | "index": 0, 105 | "text": "中国的优点是拥有丰富的文化和历史,而美国的优点是拥有先进的科技和经济体系。" 106 | } 107 | ] 108 | } 109 | ``` 110 | 111 | #### Advanced completion parameter reference 112 | 113 | > For more details on decoding strategies, see . That article explains the three decoding strategies used with LLaMA: Greedy Decoding, Random Sampling, and Beam Search. These strategies underlie advanced parameters such as top_k, top_p, and temperature. 114 | 115 | `prompt`: The prompt to complete. 116 | 117 | `max_tokens`: Maximum number of new tokens to generate. Defaults to 512. 118 | 119 | `temperature`: 
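Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic; a higher temperature flattens the token distribution used for random sampling. Defaults to 0.2. 120 | 121 | `use_beam_search`: Use beam search. Defaults to `false`, in which case random sampling is used 122 | 123 | `n`: Number of output sequences to return. Defaults to 1 124 | 125 | `best_of`: Number of candidate sequences generated, from which the top `n` are returned; with beam search, this is the beam width. Defaults to `n` 126 | 127 | `top_k`: During random sampling, the top_k highest-probability tokens are kept as candidates. Defaults to 40 128 | 129 | `top_p`: During random sampling, only the smallest set of highest-probability tokens whose cumulative probability reaches top_p is kept as candidates, so a lower top_p makes the output less random. For example, with top_p set to 0.6 and the top five token probabilities {0.23, 0.20, 0.18, 0.11, 0.10}, the first three tokens already have a cumulative probability of 0.61, so the fourth and later tokens are filtered out and only the top three are sampled from. Defaults to 0.9 130 | 131 | `presence_penalty`: Presence penalty in the range -2 to 2. Values above 0 encourage the model to use new tokens; values below 0 encourage repetition. Defaults to 1.0 in this demo's request schema (`openai_api_protocol_vllm.py`). 132 | 133 | `stream`: When set to `true`, the response is streamed. Defaults to `false`. 134 | 135 | 136 | ### Chat (chat completion) 137 | 138 | The chat endpoint supports multi-turn conversations. 139 | 140 | #### Quick start with the chat endpoint 141 | 142 | Request command: 143 | 144 | ``` shell 145 | curl http://localhost:8000/v1/chat/completions \ 146 | -H "Content-Type: application/json" \ 147 | -d '{ 148 | "model": "chinese-llama-alpaca-2", 149 | "messages": [ 150 | {"role": "user","content": "给我讲一些有关杭州的故事吧"} 151 | ] 152 | }' 153 | ``` 154 | 155 | JSON response body: 156 | 157 | ``` json 158 | { 159 | "id": "cmpl-8fc1b6356cf64681a41a8739445a8cf8", 160 | "object": "chat.completion", 161 | "created": 1690872695, 162 | "model": "chinese-llama-alpaca-2", 163 | "choices": [ 164 | { 165 | "index": 0, 166 | "message": { 167 | "role": "assistant", 168 | "content": "好的,请问您对杭州有什么特别的偏好吗?"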
169 | } 170 | } 171 | ] 172 | } 173 | ``` 174 | 175 | #### Multi-turn conversation 176 | 177 | Request command: 178 | 179 | ``` shell 180 | curl http://localhost:8000/v1/chat/completions \ 181 | -H "Content-Type: application/json" \ 182 | -d '{ 183 | "model": "chinese-llama-alpaca-2", 184 | "messages": [ 185 | {"role": "user","content": "给我讲一些有关杭州的故事吧"}, 186 | {"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"}, 187 | {"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"} 188 | ], 189 | "repetition_penalty": 1.0 190 | }' 191 | ``` 192 | 193 | JSON response body: 194 | 195 | ``` json 196 | { 197 | "id": "cmpl-02bf36497d3543c980ca2ae8cc4feb63", 198 | "object": "chat.completion", 199 | "created": 1690872676, 200 | "model": "chinese-llama-alpaca-2", 201 | "choices": [ 202 | { 203 | "index": 0, 204 | "message": { 205 | "role": "assistant", 206 | "content": "是的,西湖是杭州最著名的景点之一,它被誉为“人间天堂”。 <\\s>" 207 | } 208 | } 209 | ] 210 | } 211 | ``` 212 | 213 | #### Advanced chat parameters 214 | 215 | `messages`: The conversation history, given as a list of messages, each with a `role` and a `content` field. 216 | 217 | `max_tokens`: Maximum number of new tokens to generate. Defaults to 512. 218 | 219 | `temperature`: Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic; a higher temperature flattens the token distribution used for random sampling. Defaults to 0.2. 220 | 221 | `use_beam_search`: Use beam search. Defaults to `false`, in which case random sampling is used 222 | 223 | `n`: Number of output sequences to return. Defaults to 1 224 | 225 | `best_of`: Number of candidate sequences generated, from which the top `n` are returned; with beam search, this is the beam width. Defaults to `n` 226 | 227 | `top_k`: During random sampling, the top_k highest-probability tokens are kept as candidates. Defaults to 40 228 | 229 | `top_p`: During random sampling, only the smallest set of highest-probability tokens whose cumulative probability reaches top_p is kept as candidates, so a lower top_p makes the output less random. For example, with top_p set to 0.6 and the top five token probabilities {0.23, 0.20, 0.18, 0.11, 0.10}, the first three tokens already have a cumulative probability of 0.61, so the fourth and later tokens are filtered out and only the top three are sampled from. Defaults to 0.9 230 | 231 | `presence_penalty`: Presence penalty in the range -2 to 2. Values above 0 encourage the model to use new tokens; values below 0 encourage repetition. Defaults to 1.0 in this demo's request schema (`openai_api_protocol_vllm.py`). 232 | 233 | `stream`: When set to `true`, the response is streamed. Defaults to `false`. 234 | -------------------------------------------------------------------------------- /scripts/openai_server_demo/openai_api_protocol.py: -------------------------------------------------------------------------------- 1 | from
typing import Optional, List, Dict, Any, Union, Literal 2 | 3 | import time 4 | 5 | import shortuuid 6 | from pydantic import BaseModel, Field 7 | 8 | 9 | class ChatCompletionRequest(BaseModel): 10 | model: str = "chinese-llama-alpaca-2" 11 | messages: Union[str, List[Dict[str, str]]] 12 | temperature: Optional[float] = 0.2 13 | top_p: Optional[float] = 0.9 14 | top_k: Optional[int] = 40 15 | n: Optional[int] = 1 16 | max_tokens: Optional[int] = 512 17 | num_beams: Optional[int] = 1 18 | stop: Optional[Union[str, List[str]]] = None 19 | stream: Optional[bool] = False 20 | repetition_penalty: Optional[float] = 1.1 21 | user: Optional[str] = None 22 | do_sample: Optional[bool] = True 23 | 24 | 25 | class ChatMessage(BaseModel): 26 | role: str 27 | content: str 28 | 29 | 30 | class DeltaMessage(BaseModel): 31 | role: Optional[Literal["user", "assistant", "system"]] = None 32 | content: Optional[str] = None 33 | 34 | 35 | class ChatCompletionResponseChoice(BaseModel): 36 | index: int 37 | message: ChatMessage 38 | 39 | 40 | class ChatCompletionResponseStreamChoice(BaseModel): 41 | index: int 42 | delta: DeltaMessage 43 | finish_reason: Optional[Literal["stop", "length"]] 44 | 45 | 46 | class ChatCompletionResponse(BaseModel): 47 | id: str = Field(default_factory=lambda: f"chatcmpl-{shortuuid.random()}") 48 | object: str = "chat.completion" 49 | created: int = Field(default_factory=lambda: int(time.time())) 50 | model: str = "chinese-llama-alpaca-2" 51 | choices: List[ 52 | Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice] 53 | ] 54 | 55 | 56 | class EmbeddingsRequest(BaseModel): 57 | input: Union[str, List[Any]] 58 | user: Optional[str] = None 59 | 60 | 61 | class EmbeddingsResponse(BaseModel): 62 | object: str = "list" 63 | data: List[Dict[str, Any]] 64 | model: str = "chinese-llama-alpaca-2" 65 | 66 | 67 | class CompletionRequest(BaseModel): 68 | prompt: Union[str, List[Any]] 69 | temperature: Optional[float] = 0.2 70 | n: Optional[int] = 1 71 | 
max_tokens: Optional[int] = 512 72 | stop: Optional[Union[str, List[str]]] = None 73 | stream: Optional[bool] = False 74 | top_p: Optional[float] = 0.9 75 | top_k: Optional[int] = 40 76 | num_beams: Optional[int] = 1 77 | logprobs: Optional[int] = None 78 | echo: Optional[bool] = False 79 | repetition_penalty: Optional[float] = 1.1 80 | user: Optional[str] = None 81 | do_sample: Optional[bool] = True 82 | 83 | 84 | class CompletionResponseChoice(BaseModel): 85 | index: int 86 | text: str 87 | 88 | 89 | class CompletionResponse(BaseModel): 90 | id: Optional[str] = Field(default_factory=lambda: f"cmpl-{shortuuid.random()}") 91 | object: Optional[str] = "text_completion" 92 | created: Optional[int] = Field(default_factory=lambda: int(time.time())) 93 | model: Optional[str] = "chinese-llama-alpaca-2" 94 | choices: List[CompletionResponseChoice] 95 | -------------------------------------------------------------------------------- /scripts/openai_server_demo/openai_api_protocol_vllm.py: -------------------------------------------------------------------------------- 1 | import time 2 | from typing import Dict, List, Literal, Optional, Union 3 | 4 | from pydantic import BaseModel, Field 5 | 6 | from vllm.utils import random_uuid 7 | 8 | 9 | class ErrorResponse(BaseModel): 10 | object: str = "error" 11 | message: str 12 | type: str 13 | param: Optional[str] = None 14 | code: Optional[str] = None 15 | 16 | 17 | class ModelPermission(BaseModel): 18 | id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") 19 | object: str = "model_permission" 20 | created: int = Field(default_factory=lambda: int(time.time())) 21 | allow_create_engine: bool = False 22 | allow_sampling: bool = True 23 | allow_logprobs: bool = True 24 | allow_search_indices: bool = False 25 | allow_view: bool = True 26 | allow_fine_tuning: bool = False 27 | organization: str = "*" 28 | group: Optional[str] = None 29 | is_blocking: bool = False 30 | 31 | 32 | class ModelCard(BaseModel): 33 | id: str
34 | object: str = "model" 35 | created: int = Field(default_factory=lambda: int(time.time())) 36 | owned_by: str = "vllm" 37 | root: Optional[str] = None 38 | parent: Optional[str] = None 39 | permission: List[ModelPermission] = Field(default_factory=list) 40 | 41 | 42 | class ModelList(BaseModel): 43 | object: str = "list" 44 | data: List[ModelCard] = Field(default_factory=list) 45 | 46 | 47 | class UsageInfo(BaseModel): 48 | prompt_tokens: int = 0 49 | total_tokens: int = 0 50 | completion_tokens: Optional[int] = 0 51 | 52 | 53 | class ChatCompletionRequest(BaseModel): 54 | model: str 55 | messages: Union[str, List[Dict[str, str]]] 56 | temperature: Optional[float] = 0.2 57 | top_p: Optional[float] = 0.9 58 | n: Optional[int] = 1 59 | max_tokens: Optional[int] = 512 60 | stop: Optional[Union[str, List[str]]] = Field(default_factory=list) 61 | stream: Optional[bool] = False 62 | presence_penalty: Optional[float] = 1.0 63 | frequency_penalty: Optional[float] = 0.0 64 | logit_bias: Optional[Dict[str, float]] = None 65 | user: Optional[str] = None 66 | # Additional parameters supported by vLLM 67 | best_of: Optional[int] = None 68 | top_k: Optional[int] = 40 69 | ignore_eos: Optional[bool] = False 70 | use_beam_search: Optional[bool] = False 71 | 72 | 73 | class CompletionRequest(BaseModel): 74 | model: str 75 | prompt: Union[str, List[str]] 76 | suffix: Optional[str] = None 77 | max_tokens: Optional[int] = 512 78 | temperature: Optional[float] = 0.2 79 | top_p: Optional[float] = 0.9 80 | n: Optional[int] = 1 81 | stream: Optional[bool] = False 82 | logprobs: Optional[int] = None 83 | echo: Optional[bool] = False 84 | stop: Optional[Union[str, List[str]]] = Field(default_factory=list) 85 | presence_penalty: Optional[float] = 1.0 86 | frequency_penalty: Optional[float] = 0.0 87 | best_of: Optional[int] = None 88 | logit_bias: Optional[Dict[str, float]] = None 89 | user: Optional[str] = None 90 | # Additional parameters supported by vLLM 91 | top_k: Optional[int] = 40 
92 | ignore_eos: Optional[bool] = False 93 | use_beam_search: Optional[bool] = False 94 | 95 | 96 | class LogProbs(BaseModel): 97 | text_offset: List[int] = Field(default_factory=list) 98 | token_logprobs: List[Optional[float]] = Field(default_factory=list) 99 | tokens: List[str] = Field(default_factory=list) 100 | top_logprobs: List[Optional[Dict[str, 101 | float]]] = Field(default_factory=list) 102 | 103 | 104 | class CompletionResponseChoice(BaseModel): 105 | index: int 106 | text: str 107 | logprobs: Optional[LogProbs] = None 108 | finish_reason: Optional[Literal["stop", "length"]] = None 109 | 110 | 111 | class CompletionResponse(BaseModel): 112 | id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") 113 | object: str = "text_completion" 114 | created: int = Field(default_factory=lambda: int(time.time())) 115 | model: str 116 | choices: List[CompletionResponseChoice] 117 | usage: UsageInfo 118 | 119 | 120 | class CompletionResponseStreamChoice(BaseModel): 121 | index: int 122 | text: str 123 | logprobs: Optional[LogProbs] = None 124 | finish_reason: Optional[Literal["stop", "length"]] = None 125 | 126 | 127 | class CompletionStreamResponse(BaseModel): 128 | id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") 129 | object: str = "text_completion" 130 | created: int = Field(default_factory=lambda: int(time.time())) 131 | model: str 132 | choices: List[CompletionResponseStreamChoice] 133 | 134 | 135 | class ChatMessage(BaseModel): 136 | role: str 137 | content: str 138 | 139 | 140 | class ChatCompletionResponseChoice(BaseModel): 141 | index: int 142 | message: ChatMessage 143 | finish_reason: Optional[Literal["stop", "length"]] = None 144 | 145 | 146 | class ChatCompletionResponse(BaseModel): 147 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 148 | object: str = "chat.completion" 149 | created: int = Field(default_factory=lambda: int(time.time())) 150 | model: str 151 | choices: List[ChatCompletionResponseChoice] 
152 | usage: UsageInfo 153 | 154 | 155 | class DeltaMessage(BaseModel): 156 | role: Optional[str] = None 157 | content: Optional[str] = None 158 | 159 | 160 | class ChatCompletionResponseStreamChoice(BaseModel): 161 | index: int 162 | delta: DeltaMessage 163 | finish_reason: Optional[Literal["stop", "length"]] = None 164 | 165 | 166 | class ChatCompletionStreamResponse(BaseModel): 167 | id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") 168 | object: str = "chat.completion.chunk" 169 | created: int = Field(default_factory=lambda: int(time.time())) 170 | model: str 171 | choices: List[ChatCompletionResponseStreamChoice] 172 | -------------------------------------------------------------------------------- /scripts/privategpt/README.md: -------------------------------------------------------------------------------- 1 | ## privateGPT相关示例脚本 2 | 3 | 具体使用方法参考:https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_zh 4 | 5 | Detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en 6 | 7 | The following codes are adapted from https://github.com/imartinez/privateGPT/blob/main/privateGPT.py 8 | 9 | ### privateGPT.py 10 | 11 | 嵌套Alpaca-2指令模板的主程序入口示例代码。由于第三方库更新频繁,请勿直接使用。建议对照教程自行修改。 12 | 13 | Example with Alpaca-2 template. Please do not use this script directly, as third-party library may change over time. Please follow our wiki to adapt to new code. 14 | 15 | ### privateGPT_refine.py 16 | 17 | 使用`refine`策略的主程序入口示例代码。 18 | 19 | Example that uses `refine` strategy. 
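The `refine` strategy can be pictured as a simple loop: answer from the first retrieved chunk, then ask the model to revise that answer once for each additional chunk. The sketch below is a library-free illustration under stated assumptions, not LangChain's `RefineDocumentsChain` implementation: `refine_answer` and the stub model are hypothetical, and the two templates copy the Chinese prompts from `privateGPT_refine.py`, written with the LLaMA-2 `<<SYS>>`/`<</SYS>>` system-prompt markers.

```python
from typing import Callable, List

# Templates mirror the prompts in privateGPT_refine.py (LLaMA-2 chat format).
INITIAL_TEMPLATE = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "以下为背景知识:\n{context_str}\n"
    "请根据以上背景知识,回答这个问题:{question} [/INST]"
)

REFINE_TEMPLATE = (
    "[INST] <<SYS>>\n"
    "You are a helpful assistant. 你是一个乐于助人的助手。\n"
    "<</SYS>>\n\n"
    "这是原始问题:{question}\n"
    "已有的回答: {existing_answer}\n"
    "现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。"
    "\n\n{context_str}\n\n"
    "请根据新的文段,进一步完善你的回答。 [/INST]"
)


def refine_answer(llm: Callable[[str], str], question: str, chunks: List[str]) -> str:
    """Hypothetical helper: draft an answer from the first chunk, then refine it once per remaining chunk."""
    if not chunks:
        raise ValueError("refine needs at least one retrieved chunk")
    # Step 1: initial answer grounded in the first retrieved chunk.
    answer = llm(INITIAL_TEMPLATE.format(context_str=chunks[0], question=question))
    # Step 2: each later chunk gets one LLM call that may revise the current answer.
    for chunk in chunks[1:]:
        prompt = REFINE_TEMPLATE.format(
            question=question, existing_answer=answer, context_str=chunk
        )
        answer = llm(prompt)
    return answer


if __name__ == "__main__":
    prompts: List[str] = []

    def fake_llm(prompt: str) -> str:
        # Stand-in for a real model: records the prompt and returns a numbered answer.
        prompts.append(prompt)
        return f"answer-{len(prompts)}"

    print(refine_answer(fake_llm, "西湖在哪里?", ["chunk A", "chunk B", "chunk C"]))  # answer-3
    print(len(prompts))  # 3
```

With N retrieved chunks this makes N sequential LLM calls, which is why `refine` is slower than the single-call `stuff` chain used in `privateGPT.py`, in exchange for not having to fit every chunk into one prompt.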
20 | -------------------------------------------------------------------------------- /scripts/privategpt/privateGPT.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from dotenv import load_dotenv 3 | from langchain.chains import RetrievalQA 4 | from langchain.embeddings import HuggingFaceEmbeddings 5 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 6 | from langchain.vectorstores import Chroma 7 | from langchain.llms import GPT4All, LlamaCpp 8 | import os 9 | import argparse 10 | import time 11 | 12 | load_dotenv() 13 | 14 | embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME") 15 | persist_directory = os.environ.get('PERSIST_DIRECTORY') 16 | 17 | model_type = os.environ.get('MODEL_TYPE') 18 | model_path = os.environ.get('MODEL_PATH') 19 | model_n_ctx = os.environ.get('MODEL_N_CTX') 20 | model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8)) 21 | target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4)) 22 | 23 | from constants import CHROMA_SETTINGS 24 | 25 | def main(): 26 | # Parse the command line arguments 27 | args = parse_arguments() 28 | embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name) 29 | db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS) 30 | retriever = db.as_retriever(search_kwargs={"k": target_source_chunks}) 31 | # activate/deactivate the streaming StdOut callback for LLMs 32 | callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()] 33 | # Prepare the LLM 34 | match model_type: 35 | case "LlamaCpp": 36 | llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_ctx=model_n_ctx, 37 | n_gpu_layers=1, n_batch=model_n_batch, callbacks=callbacks, n_threads=8, verbose=False) 38 | case "GPT4All": 39 | llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False) 
40 | case _default: 41 | # raise exception if model_type is not supported 42 | raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All") 43 | 44 | # The following is specifically designed for Chinese-Alpaca-2 45 | # For detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en 46 | alpaca2_prompt_template = ( 47 | "[INST] <<SYS>>\n" 48 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 49 | "<</SYS>>\n\n" 50 | "{context}\n\n{question} [/INST]" 51 | ) 52 | from langchain import PromptTemplate 53 | input_with_prompt = PromptTemplate(template=alpaca2_prompt_template, input_variables=["context", "question"]) 54 | 55 | qa = RetrievalQA.from_chain_type( 56 | llm=llm, chain_type="stuff", retriever=retriever, 57 | return_source_documents= not args.hide_source, 58 | chain_type_kwargs={"prompt": input_with_prompt}) 59 | 60 | # Interactive questions and answers 61 | while True: 62 | query = input("\nEnter a query: ") 63 | if query == "exit": 64 | break 65 | if query.strip() == "": 66 | continue 67 | 68 | # Get the answer from the chain 69 | start = time.time() 70 | res = qa(query) 71 | answer, docs = res['result'], [] if args.hide_source else res['source_documents'] 72 | end = time.time() 73 | 74 | # Print the result 75 | print("\n\n> Question:") 76 | print(query) 77 | print(f"\n> Answer (took {round(end - start, 2)} s.):") 78 | print(answer) 79 | 80 | # Print the relevant sources used for the answer 81 | for document in docs: 82 | print("\n> " + document.metadata["source"] + ":") 83 | print(document.page_content) 84 | 85 | def parse_arguments(): 86 | parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, ' 87 | 'using the power of LLMs.') 88 | parser.add_argument("--hide-source", "-S", action='store_true', 89 | help='Use this flag to disable printing of source documents used for answers.') 90 | 91 | parser.add_argument("--mute-stream", 
"-M", 92 | action='store_true', 93 | help='Use this flag to disable the streaming StdOut callback for LLMs.') 94 | 95 | return parser.parse_args() 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /scripts/privategpt/privateGPT_refine.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | from dotenv import load_dotenv 3 | from langchain.chains import RetrievalQA 4 | from langchain.embeddings import HuggingFaceEmbeddings 5 | from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler 6 | from langchain.vectorstores import Chroma 7 | from langchain.llms import GPT4All, LlamaCpp 8 | import os 9 | import argparse 10 | import time 11 | 12 | load_dotenv() 13 | 14 | embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME") 15 | persist_directory = os.environ.get('PERSIST_DIRECTORY') 16 | 17 | model_type = os.environ.get('MODEL_TYPE') 18 | model_path = os.environ.get('MODEL_PATH') 19 | model_n_ctx = os.environ.get('MODEL_N_CTX') 20 | model_n_batch = int(os.environ.get('MODEL_N_BATCH', 8)) 21 | target_source_chunks = int(os.environ.get('TARGET_SOURCE_CHUNKS', 4)) 22 | 23 | from constants import CHROMA_SETTINGS 24 | 25 | def main(): 26 | # Parse the command line arguments 27 | args = parse_arguments() 28 | embeddings = HuggingFaceEmbeddings(model_name=embeddings_model_name) 29 | db = Chroma(persist_directory=persist_directory, embedding_function=embeddings, client_settings=CHROMA_SETTINGS) 30 | retriever = db.as_retriever(search_kwargs={"k": target_source_chunks}) 31 | # activate/deactivate the streaming StdOut callback for LLMs 32 | callbacks = [] if args.mute_stream else [StreamingStdOutCallbackHandler()] 33 | # Prepare the LLM 34 | match model_type: 35 | case "LlamaCpp": 36 | llm = LlamaCpp(model_path=model_path, max_tokens=model_n_ctx, n_ctx=model_n_ctx, 37 | n_gpu_layers=1, 
n_batch=model_n_batch, callbacks=callbacks, n_threads=8, verbose=False) 38 | case "GPT4All": 39 | llm = GPT4All(model=model_path, max_tokens=model_n_ctx, backend='gptj', n_batch=model_n_batch, callbacks=callbacks, verbose=False) 40 | case _default: 41 | # raise exception if model_type is not supported 42 | raise Exception(f"Model type {model_type} is not supported. Please choose one of the following: LlamaCpp, GPT4All") 43 | 44 | # The following is specifically designed for Chinese-Alpaca-2 45 | # For detailed usage: https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/privategpt_en 46 | alpaca2_refine_prompt_template = ( 47 | "[INST] <<SYS>>\n" 48 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 49 | "<</SYS>>\n\n" 50 | "这是原始问题:{question}\n" 51 | "已有的回答: {existing_answer}\n" 52 | "现在还有一些文字,(如果有需要)你可以根据它们完善现有的回答。" 53 | "\n\n{context_str}\n\n" 54 | "请根据新的文段,进一步完善你的回答。 [/INST]" 55 | ) 56 | 57 | alpaca2_initial_prompt_template = ( 58 | "[INST] <<SYS>>\n" 59 | "You are a helpful assistant. 你是一个乐于助人的助手。\n" 60 | "<</SYS>>\n\n" 61 | "以下为背景知识:\n{context_str}\n" 62 | "请根据以上背景知识,回答这个问题:{question} [/INST]" 63 | ) 64 | 65 | from langchain import PromptTemplate 66 | refine_prompt = PromptTemplate( 67 | input_variables=["question", "existing_answer", "context_str"], 68 | template=alpaca2_refine_prompt_template, 69 | ) 70 | initial_qa_prompt = PromptTemplate( 71 | input_variables=["context_str", "question"], 72 | template=alpaca2_initial_prompt_template, 73 | ) 74 | chain_type_kwargs = {"question_prompt": initial_qa_prompt, "refine_prompt": refine_prompt} 75 | qa = RetrievalQA.from_chain_type( 76 | llm=llm, chain_type="refine", 77 | retriever=retriever, return_source_documents= not args.hide_source, 78 | chain_type_kwargs=chain_type_kwargs) 79 | 80 | # Interactive questions and answers 81 | while True: 82 | query = input("\nEnter a query: ") 83 | if query == "exit": 84 | break 85 | if query.strip() == "": 86 | continue 87 | 88 | # Get the answer from the chain 89 | start = time.time() 90 | res = 
qa(query) 91 | answer, docs = res['result'], [] if args.hide_source else res['source_documents'] 92 | end = time.time() 93 | 94 | # Print the result 95 | print("\n\n> Question:") 96 | print(query) 97 | print(f"\n> Answer (took {round(end - start, 2)} s.):") 98 | print(answer) 99 | 100 | # Print the relevant sources used for the answer 101 | for document in docs: 102 | print("\n> " + document.metadata["source"] + ":") 103 | print(document.page_content) 104 | 105 | def parse_arguments(): 106 | parser = argparse.ArgumentParser(description='privateGPT: Ask questions to your documents without an internet connection, ' 107 | 'using the power of LLMs.') 108 | parser.add_argument("--hide-source", "-S", action='store_true', 109 | help='Use this flag to disable printing of source documents used for answers.') 110 | 111 | parser.add_argument("--mute-stream", "-M", 112 | action='store_true', 113 | help='Use this flag to disable the streaming StdOut callback for LLMs.') 114 | 115 | return parser.parse_args() 116 | 117 | 118 | if __name__ == "__main__": 119 | main() 120 | -------------------------------------------------------------------------------- /scripts/tokenizer/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | { 2 | "bos_token": { 3 | "content": "<s>", 4 | "lstrip": false, 5 | "normalized": true, 6 | "rstrip": false, 7 | "single_word": false 8 | }, 9 | "eos_token": { 10 | "content": "</s>", 11 | "lstrip": false, 12 | "normalized": true, 13 | "rstrip": false, 14 | "single_word": false 15 | }, 16 | "pad_token": "", 17 | "unk_token": { 18 | "content": "<unk>", 19 | "lstrip": false, 20 | "normalized": true, 21 | "rstrip": false, 22 | "single_word": false 23 | } 24 | } 25 | -------------------------------------------------------------------------------- /scripts/tokenizer/tokenizer.model: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ymcui/Chinese-LLaMA-Alpaca-2/2a334d1634c857a7f02f885026d02ac4b469479d/scripts/tokenizer/tokenizer.model -------------------------------------------------------------------------------- /scripts/tokenizer/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "add_bos_token": true, 3 | "add_eos_token": false, 4 | "bos_token": { 5 | "__type": "AddedToken", 6 | "content": "<s>", 7 | "lstrip": false, 8 | "normalized": true, 9 | "rstrip": false, 10 | "single_word": false 11 | }, 12 | "clean_up_tokenization_spaces": false, 13 | "eos_token": { 14 | "__type": "AddedToken", 15 | "content": "</s>", 16 | "lstrip": false, 17 | "normalized": true, 18 | "rstrip": false, 19 | "single_word": false 20 | }, 21 | "model_max_length": 1000000000000000019884624838656, 22 | "pad_token": null, 23 | "sp_model_kwargs": {}, 24 | "tokenizer_class": "LlamaTokenizer", 25 | "unk_token": { 26 | "__type": "AddedToken", 27 | "content": "<unk>", 28 | "lstrip": false, 29 | "normalized": true, 30 | "rstrip": false, 31 | "single_word": false 32 | }, 33 | "use_fast": false 34 | } 35 | -------------------------------------------------------------------------------- /scripts/training/build_dataset.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from typing import Dict, Sequence, Union, List 5 | import datasets 6 | import torch 7 | from datasets import load_dataset, concatenate_datasets 8 | import transformers 9 | 10 | 11 | IGNORE_INDEX = -100 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | PROMPT_TEMPLATE = ( 16 | "[INST] <<SYS>>\n" 17 | "You are a helpful assistant. 
你是一个乐于助人的助手。\n" 18 | "<</SYS>>\n\n{instruction} [/INST]" 19 | ) 20 | 21 | def build_instruction_dataset(data_path: Union[List[str],str], 22 | tokenizer: transformers.PreTrainedTokenizer, 23 | max_seq_length: int, data_cache_dir = None, 24 | preprocessing_num_workers = None, 25 | ): 26 | 27 | def tokenization(examples): 28 | sources = [] 29 | targets = [] 30 | prompt = PROMPT_TEMPLATE 31 | for instruction, input, output in zip(examples['instruction'],examples['input'],examples['output']): 32 | if input is not None and input !="": 33 | instruction = instruction+'\n'+input 34 | source = prompt.format_map({'instruction':instruction}) 35 | target = f"{output}{tokenizer.eos_token}" 36 | 37 | sources.append(source) 38 | targets.append(target) 39 | 40 | tokenized_sources = tokenizer(sources,return_attention_mask=False) 41 | tokenized_targets = tokenizer(targets,return_attention_mask=False,add_special_tokens=False) 42 | 43 | all_input_ids = [] 44 | all_labels = [] 45 | for s,t in zip(tokenized_sources['input_ids'],tokenized_targets['input_ids']): 46 | input_ids = torch.LongTensor(s + t)[:max_seq_length] 47 | labels = torch.LongTensor([IGNORE_INDEX] * len(s) + t)[:max_seq_length] 48 | assert len(input_ids) == len(labels) 49 | all_input_ids.append(input_ids) 50 | all_labels.append(labels) 51 | 52 | results = {'input_ids':all_input_ids, 'labels': all_labels} 53 | return results 54 | 55 | 56 | logging.warning("building dataset...") 57 | all_datasets = [] 58 | 59 | if not isinstance(data_path,(list,tuple)): 60 | data_path = [data_path] 61 | for file in data_path: 62 | 63 | if data_cache_dir is None: 64 | data_cache_dir = str(os.path.dirname(file)) 65 | cache_path = os.path.join(data_cache_dir,os.path.basename(file).split('.')[0]+f"_{max_seq_length}") 66 | os.makedirs(cache_path, exist_ok=True) 67 | try: 68 | processed_dataset = datasets.load_from_disk(cache_path) 69 | logger.info(f'training datasets-{file} has been loaded from disk') 70 | except Exception: 71 | raw_dataset = 
load_dataset("json", data_files=file, cache_dir=cache_path) 72 | tokenization_func = tokenization 73 | tokenized_dataset = raw_dataset.map( 74 | tokenization_func, 75 | batched=True, 76 | num_proc=preprocessing_num_workers, 77 | remove_columns=["instruction","input","output"], 78 | keep_in_memory=False, 79 | desc="preprocessing on dataset", 80 | ) 81 | processed_dataset = tokenized_dataset 82 | processed_dataset.save_to_disk(cache_path) 83 | processed_dataset.set_format('torch') 84 | all_datasets.append(processed_dataset['train']) 85 | all_datasets = concatenate_datasets(all_datasets) 86 | return all_datasets 87 | 88 | @dataclass 89 | class DataCollatorForSupervisedDataset(object): 90 | """Collate examples for supervised fine-tuning.""" 91 | 92 | tokenizer: transformers.PreTrainedTokenizer 93 | 94 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 95 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 96 | input_ids = torch.nn.utils.rnn.pad_sequence( 97 | input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id 98 | ) 99 | labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100) 100 | return dict( 101 | input_ids=input_ids, 102 | labels=labels, 103 | attention_mask=input_ids.ne(self.tokenizer.pad_token_id), 104 | ) 105 | -------------------------------------------------------------------------------- /scripts/training/ds_zero2_no_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 100, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1e-10 9 | }, 10 | 11 | "zero_optimization": { 12 | "stage": 2, 13 | "allgather_partitions": true, 14 | "allgather_bucket_size": 1e8, 15 | "overlap_comm": true, 16 | "reduce_scatter": true, 17 | "reduce_bucket_size": 1e8, 18 | "contiguous_gradients": true 19 | 
}, 20 | 21 | "gradient_accumulation_steps": "auto", 22 | "gradient_clipping": "auto", 23 | "steps_per_print": 2000, 24 | "train_batch_size": "auto", 25 | "train_micro_batch_size_per_gpu": "auto", 26 | "wall_clock_breakdown": false 27 | } 28 | -------------------------------------------------------------------------------- /scripts/training/peft/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all. 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | __version__ = "0.3.0.dev0" 21 | 22 | from .mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING, PEFT_TYPE_TO_CONFIG_MAPPING, get_peft_config, get_peft_model 23 | from .peft_model import ( 24 | PeftModel, 25 | PeftModelForCausalLM, 26 | PeftModelForSeq2SeqLM, 27 | PeftModelForSequenceClassification, 28 | PeftModelForTokenClassification, 29 | ) 30 | from .tuners import ( 31 | LoraConfig, 32 | LoraModel, 33 | PrefixEncoder, 34 | PrefixTuningConfig, 35 | PromptEmbedding, 36 | PromptEncoder, 37 | PromptEncoderConfig, 38 | PromptEncoderReparameterizationType, 39 | PromptTuningConfig, 40 | PromptTuningInit, 41 | ) 42 | from .utils import ( 43 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 44 | PeftConfig, 45 | PeftType, 46 | PromptLearningConfig, 47 | TaskType, 48 | bloom_model_postprocess_past_key_value, 49 | get_peft_model_state_dict, 50 | # prepare_model_for_int8_training, 51 | set_peft_model_state_dict, 52 | shift_tokens_right, 53 | ) 54 | -------------------------------------------------------------------------------- /scripts/training/peft/mapping.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | from .peft_model import ( 17 | PeftModel, 18 | PeftModelForCausalLM, 19 | PeftModelForSeq2SeqLM, 20 | PeftModelForSequenceClassification, 21 | PeftModelForTokenClassification, 22 | ) 23 | from .tuners import LoraConfig, PrefixTuningConfig, PromptEncoderConfig, PromptTuningConfig 24 | from .utils import PromptLearningConfig 25 | 26 | 27 | MODEL_TYPE_TO_PEFT_MODEL_MAPPING = { 28 | "SEQ_CLS": PeftModelForSequenceClassification, 29 | "SEQ_2_SEQ_LM": PeftModelForSeq2SeqLM, 30 | "CAUSAL_LM": PeftModelForCausalLM, 31 | "TOKEN_CLS": PeftModelForTokenClassification, 32 | } 33 | 34 | PEFT_TYPE_TO_CONFIG_MAPPING = { 35 | "PROMPT_TUNING": PromptTuningConfig, 36 | "PREFIX_TUNING": PrefixTuningConfig, 37 | "P_TUNING": PromptEncoderConfig, 38 | "LORA": LoraConfig, 39 | } 40 | 41 | TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING = { 42 | "t5": ["q", "v"], 43 | "mt5": ["q", "v"], 44 | "bart": ["q_proj", "v_proj"], 45 | "gpt2": ["c_attn"], 46 | "bloom": ["query_key_value"], 47 | "opt": ["q_proj", "v_proj"], 48 | "gptj": ["q_proj", "v_proj"], 49 | "gpt_neox": ["query_key_value"], 50 | "gpt_neo": ["q_proj", "v_proj"], 51 | "bert": ["query", "value"], 52 | "roberta": ["query", "value"], 53 | "xlm-roberta": ["query", "value"], 54 | "electra": ["query", "value"], 55 | "deberta-v2": ["query_proj", "value_proj"], 56 | "deberta": ["in_proj"], 57 | "layoutlm": ["query", "value"], 58 | "llama": ["q_proj", "v_proj"], 59 | "chatglm": ["query_key_value"], 60 | } 61 | 62 | 63 | def get_peft_config(config_dict): 64 | """ 65 | Returns a Peft config object from a dictionary. 66 | 67 | Args: 68 | config_dict (`Dict[str, Any]`): Dictionary containing the configuration parameters. 
69 | """ 70 | 71 | return PEFT_TYPE_TO_CONFIG_MAPPING[config_dict["peft_type"]](**config_dict) 72 | 73 | 74 | def _prepare_prompt_learning_config(peft_config, model_config): 75 | if peft_config.num_layers is None: 76 | if "num_hidden_layers" in model_config: 77 | num_layers = model_config["num_hidden_layers"] 78 | elif "num_layers" in model_config: 79 | num_layers = model_config["num_layers"] 80 | elif "n_layer" in model_config: 81 | num_layers = model_config["n_layer"] 82 | else: 83 | raise ValueError("Please specify `num_layers` in `peft_config`") 84 | peft_config.num_layers = num_layers 85 | 86 | if peft_config.token_dim is None: 87 | if "hidden_size" in model_config: 88 | token_dim = model_config["hidden_size"] 89 | elif "n_embd" in model_config: 90 | token_dim = model_config["n_embd"] 91 | elif "d_model" in model_config: 92 | token_dim = model_config["d_model"] 93 | else: 94 | raise ValueError("Please specify `token_dim` in `peft_config`") 95 | peft_config.token_dim = token_dim 96 | 97 | if peft_config.num_attention_heads is None: 98 | if "num_attention_heads" in model_config: 99 | num_attention_heads = model_config["num_attention_heads"] 100 | elif "n_head" in model_config: 101 | num_attention_heads = model_config["n_head"] 102 | elif "num_heads" in model_config: 103 | num_attention_heads = model_config["num_heads"] 104 | elif "encoder_attention_heads" in model_config: 105 | num_attention_heads = model_config["encoder_attention_heads"] 106 | else: 107 | raise ValueError("Please specify `num_attention_heads` in `peft_config`") 108 | peft_config.num_attention_heads = num_attention_heads 109 | 110 | if getattr(peft_config, "encoder_hidden_size", None) is None: 111 | setattr(peft_config, "encoder_hidden_size", peft_config.token_dim) 112 | 113 | return peft_config 114 | 115 | 116 | def _prepare_lora_config(peft_config, model_config): 117 | if peft_config.target_modules is None: 118 | if model_config["model_type"] not in 
TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING: 119 
| raise ValueError("Please specify `target_modules` in `peft_config`") 120 | peft_config.target_modules = TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING[model_config["model_type"]] 121 | if len(peft_config.target_modules) == 1: 122 | peft_config.fan_in_fan_out = True 123 | peft_config.enable_lora = [True, False, True] 124 | if peft_config.inference_mode: 125 | peft_config.merge_weights = True 126 | return peft_config 127 | 128 | 129 | def get_peft_model(model, peft_config): 130 | """ 131 | Returns a Peft model object from a model and a config. 132 | 133 | Args: 134 | model ([`transformers.PreTrainedModel`]): Model to be wrapped. 135 | peft_config ([`PeftConfig`]): Configuration object containing the parameters of the Peft model. 136 | """ 137 | 138 | model_config = model.config.to_dict() 139 | peft_config.base_model_name_or_path = model.__dict__.get("name_or_path", None) 140 | if peft_config.task_type not in MODEL_TYPE_TO_PEFT_MODEL_MAPPING.keys(): 141 | peft_config = _prepare_lora_config(peft_config, model_config) 142 | return PeftModel(model, peft_config) 143 | if not isinstance(peft_config, PromptLearningConfig): 144 | peft_config = _prepare_lora_config(peft_config, model_config) 145 | else: 146 | peft_config = _prepare_prompt_learning_config(peft_config, model_config) 147 | return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](model, peft_config) 148 | -------------------------------------------------------------------------------- /scripts/training/peft/tuners/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 
7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | from .lora import LoraConfig, LoraModel 21 | from .p_tuning import PromptEncoder, PromptEncoderConfig, PromptEncoderReparameterizationType 22 | from .prefix_tuning import PrefixEncoder, PrefixTuningConfig 23 | from .prompt_tuning import PromptEmbedding, PromptTuningConfig, PromptTuningInit 24 | -------------------------------------------------------------------------------- /scripts/training/peft/tuners/p_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import enum 17 | import warnings 18 | from dataclasses import dataclass, field 19 | from typing import Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptEncoderReparameterizationType(str, enum.Enum): 27 | MLP = "MLP" 28 | LSTM = "LSTM" 29 | 30 | 31 | @dataclass 32 | class PromptEncoderConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`~peft.PromptEncoder`]. 35 | 36 | Args: 37 | encoder_reparameterization_type 38 | (Union[[`PromptEncoderReparameterizationType`], `str`]): The type of reparameterization to use. 39 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 40 | encoder_num_layers (`int`): The number of layers of the prompt encoder. 41 | encoder_dropout (`float`): The dropout probability of the prompt encoder. 42 | """ 43 | 44 | encoder_reparameterization_type: Union[str, PromptEncoderReparameterizationType] = field( 45 | default=PromptEncoderReparameterizationType.MLP, 46 | metadata={"help": "How to reparameterize the prompt encoder"}, 47 | ) 48 | encoder_hidden_size: int = field( 49 | default=None, 50 | metadata={"help": "The hidden size of the prompt encoder"}, 51 | ) 52 | encoder_num_layers: int = field( 53 | default=2, 54 | metadata={"help": "The number of layers of the prompt encoder"}, 55 | ) 56 | encoder_dropout: float = field( 57 | default=0.0, 58 | metadata={"help": "The dropout of the prompt encoder"}, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.P_TUNING 63 | 64 | 65 | # Based on https://github.com/NVIDIA/NeMo/blob/main/nemo/collections/nlp/modules/common/prompt_encoder.py 66 | # with some refactor 67 | class PromptEncoder(torch.nn.Module): 68 | """ 69 | The prompt encoder network that is used to generate the virtual token embeddings for p-tuning. 70 | 71 | Args: 72 | config ([`PromptEncoderConfig`]): The configuration of the prompt encoder. 
73 | 74 | Example:: 75 | 76 | >>> from peft import PromptEncoder, PromptEncoderConfig >>> config = PromptEncoderConfig( 77 | peft_type="P_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768, 78 | num_transformer_submodules=1, num_attention_heads=12, num_layers=12, 79 | encoder_reparameterization_type="MLP", encoder_hidden_size=768 80 | ) 81 | >>> prompt_encoder = PromptEncoder(config) 82 | 83 | **Attributes**: 84 | - **embedding** ([`~torch.nn.Embedding`]) -- The embedding layer of the prompt encoder. 85 | - **mlp_head** ([`~torch.nn.Sequential`]) -- The MLP head of the prompt encoder if `inference_mode=False`. 86 | - **lstm_head** ([`~torch.nn.LSTM`]) -- The LSTM head of the prompt encoder if `inference_mode=False` and 87 | `encoder_reparameterization_type="LSTM"`. 88 | - **token_dim** (`int`) -- The hidden embedding dimension of the base transformer model. 89 | - **input_size** (`int`) -- The input size of the prompt encoder. 90 | - **output_size** (`int`) -- The output size of the prompt encoder. 91 | - **hidden_size** (`int`) -- The hidden size of the prompt encoder. 92 | - **total_virtual_tokens** (`int`): The total number of virtual tokens of the 93 | prompt encoder. 94 | - **encoder_type** (Union[[`PromptEncoderReparameterizationType`], `str`]): 95 | The encoder type of the prompt encoder. 
96 | 97 | 98 | Input shape: (batch_size, total_virtual_tokens) 99 | 100 | Output shape: (batch_size, total_virtual_tokens, token_dim) 101 | """ 102 | 103 | def __init__(self, config): 104 | super().__init__() 105 | self.token_dim = config.token_dim 106 | self.input_size = self.token_dim 107 | self.output_size = self.token_dim 108 | self.hidden_size = config.encoder_hidden_size 109 | self.total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 110 | self.encoder_type = config.encoder_reparameterization_type 111 | 112 | # embedding 113 | self.embedding = torch.nn.Embedding(self.total_virtual_tokens, self.token_dim) 114 | if not config.inference_mode: 115 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 116 | lstm_dropout = config.encoder_dropout 117 | num_layers = config.encoder_num_layers 118 | # LSTM 119 | self.lstm_head = torch.nn.LSTM( 120 | input_size=self.input_size, 121 | hidden_size=self.hidden_size, 122 | num_layers=num_layers, 123 | dropout=lstm_dropout, 124 | bidirectional=True, 125 | batch_first=True, 126 | ) 127 | 128 | self.mlp_head = torch.nn.Sequential( 129 | torch.nn.Linear(self.hidden_size * 2, self.hidden_size * 2), 130 | torch.nn.ReLU(), 131 | torch.nn.Linear(self.hidden_size * 2, self.output_size), 132 | ) 133 | 134 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 135 | warnings.warn( 136 | f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used." 137 | ) 138 | layers = [ 139 | torch.nn.Linear(self.input_size, self.hidden_size), 140 | torch.nn.ReLU(), 141 | torch.nn.Linear(self.hidden_size, self.hidden_size), 142 | torch.nn.ReLU(), 143 | torch.nn.Linear(self.hidden_size, self.output_size), 144 | ] 145 | self.mlp_head = torch.nn.Sequential(*layers) 146 | 147 | else: 148 | raise ValueError("Prompt encoder type not recognized. 
Please use one of MLP (recommended) or LSTM.") 149 | 150 | def forward(self, indices): 151 | input_embeds = self.embedding(indices) 152 | if self.encoder_type == PromptEncoderReparameterizationType.LSTM: 153 | output_embeds = self.mlp_head(self.lstm_head(input_embeds)[0]) 154 | elif self.encoder_type == PromptEncoderReparameterizationType.MLP: 155 | output_embeds = self.mlp_head(input_embeds) 156 | else: 157 | raise ValueError("Prompt encoder type not recognized. Please use one of MLP (recommended) or LSTM.") 158 | 159 | return output_embeds 160 | -------------------------------------------------------------------------------- /scripts/training/peft/tuners/prefix_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | 17 | from dataclasses import dataclass, field 18 | 19 | import torch 20 | 21 | from ..utils import PeftType, PromptLearningConfig 22 | 23 | 24 | @dataclass 25 | class PrefixTuningConfig(PromptLearningConfig): 26 | """ 27 | This is the configuration class to store the configuration of a [`~peft.PrefixEncoder`]. 28 | 29 | Args: 30 | encoder_hidden_size (`int`): The hidden size of the prompt encoder. 31 | prefix_projection (`bool`): Whether to project the prefix embeddings. 
32 | """ 33 | 34 | encoder_hidden_size: int = field( 35 | default=None, 36 | metadata={"help": "The hidden size of the encoder"}, 37 | ) 38 | prefix_projection: bool = field( 39 | default=False, 40 | metadata={"help": "Whether to project the prefix tokens"}, 41 | ) 42 | 43 | def __post_init__(self): 44 | self.peft_type = PeftType.PREFIX_TUNING 45 | 46 | 47 | # Based on https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py 48 | # with some refactor 49 | class PrefixEncoder(torch.nn.Module): 50 | r""" 51 | The torch.nn model to encode the prefix 52 | 53 | Args: 54 | config ([`PrefixTuningConfig`]): The configuration of the prefix encoder. 55 | 56 | Example:: 57 | 58 | >>> from peft import PrefixEncoder, PrefixTuningConfig >>> config = PrefixTuningConfig( 59 | peft_type="PREFIX_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768, 60 | num_transformer_submodules=1, num_attention_heads=12, num_layers=12, encoder_hidden_size=768 61 | ) 62 | >>> prefix_encoder = PrefixEncoder(config) 63 | 64 | 65 | **Attributes**: 66 | - **embedding** (`torch.nn.Embedding`) -- 67 | The embedding layer of the prefix encoder. 68 | - **transform** (`torch.nn.Sequential`) -- The 69 | two-layer MLP to transform the prefix embeddings if `prefix_projection` is `True`. 70 | - **prefix_projection** (`bool`) -- Whether to project the prefix embeddings. 
71 | 72 | Input shape: (batch_size, num_virtual_tokens) 73 | 74 | Output shape: (batch_size, num_virtual_tokens, 2*layers*hidden) 75 | """ 76 | 77 | def __init__(self, config): 78 | super().__init__() 79 | self.prefix_projection = config.prefix_projection 80 | token_dim = config.token_dim 81 | num_layers = config.num_layers 82 | encoder_hidden_size = config.encoder_hidden_size 83 | num_virtual_tokens = config.num_virtual_tokens 84 | if self.prefix_projection and not config.inference_mode: 85 | # Use a two-layer MLP to encode the prefix 86 | self.embedding = torch.nn.Embedding(num_virtual_tokens, token_dim) 87 | self.transform = torch.nn.Sequential( 88 | torch.nn.Linear(token_dim, encoder_hidden_size), 89 | torch.nn.Tanh(), 90 | torch.nn.Linear(encoder_hidden_size, num_layers * 2 * token_dim), 91 | ) 92 | else: 93 | self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim) 94 | 95 | def forward(self, prefix: torch.Tensor): 96 | if self.prefix_projection: 97 | prefix_tokens = self.embedding(prefix) 98 | past_key_values = self.transform(prefix_tokens) 99 | else: 100 | past_key_values = self.embedding(prefix) 101 | return past_key_values 102 | -------------------------------------------------------------------------------- /scripts/training/peft/tuners/prompt_tuning.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | import enum 17 | import math 18 | from dataclasses import dataclass, field 19 | from typing import Optional, Union 20 | 21 | import torch 22 | 23 | from ..utils import PeftType, PromptLearningConfig 24 | 25 | 26 | class PromptTuningInit(str, enum.Enum): 27 | TEXT = "TEXT" 28 | RANDOM = "RANDOM" 29 | 30 | 31 | @dataclass 32 | class PromptTuningConfig(PromptLearningConfig): 33 | """ 34 | This is the configuration class to store the configuration of a [`~peft.PromptEmbedding`]. 35 | 36 | Args: 37 | prompt_tuning_init (Union[[`PromptTuningInit`], `str`]): The initialization of the prompt embedding. 38 | prompt_tuning_init_text ( Optional[`str`]): The text to initialize the prompt embedding. 39 | Only used if `prompt_tuning_init` is `TEXT` 40 | tokenizer_name_or_path ( Optional[`str`]): The name or path of the tokenizer. 41 | Only used if `prompt_tuning_init` is `TEXT` 42 | """ 43 | 44 | prompt_tuning_init: Union[PromptTuningInit, str] = field( 45 | default=PromptTuningInit.RANDOM, 46 | metadata={"help": "How to initialize the prompt tuning parameters"}, 47 | ) 48 | prompt_tuning_init_text: Optional[str] = field( 49 | default=None, 50 | metadata={ 51 | "help": "The text to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 52 | }, 53 | ) 54 | tokenizer_name_or_path: Optional[str] = field( 55 | default=None, 56 | metadata={ 57 | "help": "The tokenizer to use for prompt tuning initialization. Only used if prompt_tuning_init is `TEXT`" 58 | }, 59 | ) 60 | 61 | def __post_init__(self): 62 | self.peft_type = PeftType.PROMPT_TUNING 63 | 64 | 65 | class PromptEmbedding(torch.nn.Module): 66 | """ 67 | The model to encode virtual tokens into prompt embeddings. 68 | 69 | Args: 70 | config ([`PromptTuningConfig`]): The configuration of the prompt embedding. 
71 | word_embeddings (`torch.nn.Module`): The word embeddings of the base transformer model. 72 | 73 | **Attributes**: 74 | **embedding** (`torch.nn.Embedding`) -- The embedding layer of the prompt embedding. 75 | 76 | Example:: 77 | 78 | >>> from peft import PromptEmbedding, PromptTuningConfig >>> config = PromptTuningConfig( 79 | peft_type="PROMPT_TUNING", task_type="SEQ_2_SEQ_LM", num_virtual_tokens=20, token_dim=768, 80 | num_transformer_submodules=1, num_attention_heads=12, num_layers=12, prompt_tuning_init="TEXT", 81 | prompt_tuning_init_text="Predict if sentiment of this review is positive, negative or neutral", 82 | tokenizer_name_or_path="t5-base", 83 | ) 84 | >>> # t5_model.shared is the word embeddings of the base model >>> prompt_embedding = PromptEmbedding(config, 85 | t5_model.shared) 86 | 87 | 88 | Input Shape: (batch_size, total_virtual_tokens) 89 | 90 | Output Shape: (batch_size, total_virtual_tokens, token_dim) 91 | """ 92 | 93 | def __init__(self, config, word_embeddings): 94 | super().__init__() 95 | 96 | total_virtual_tokens = config.num_virtual_tokens * config.num_transformer_submodules 97 | self.embedding = torch.nn.Embedding(total_virtual_tokens, config.token_dim) 98 | if config.prompt_tuning_init == PromptTuningInit.TEXT: 99 | from transformers import AutoTokenizer 100 | 101 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name_or_path) 102 | init_text = config.prompt_tuning_init_text 103 | init_token_ids = tokenizer(init_text)["input_ids"] 104 | # Trim or iterate until num_text_tokens matches total_virtual_tokens 105 | num_text_tokens = len(init_token_ids) 106 | if num_text_tokens > total_virtual_tokens: 107 | init_token_ids = init_token_ids[:total_virtual_tokens] 108 | elif num_text_tokens < total_virtual_tokens: 109 | num_reps = math.ceil(total_virtual_tokens / num_text_tokens) 110 | init_token_ids = init_token_ids * num_reps 111 | init_token_ids = init_token_ids[:total_virtual_tokens] 112 | 113 | word_embedding_weights = 
word_embeddings(torch.LongTensor(init_token_ids)).detach().clone() 114 | word_embedding_weights = word_embedding_weights.to(torch.float32) 115 | self.embedding.weight = torch.nn.Parameter(word_embedding_weights) 116 | 117 | def forward(self, indices): 118 | # Just get embeddings 119 | prompt_embeddings = self.embedding(indices) 120 | return prompt_embeddings 121 | -------------------------------------------------------------------------------- /scripts/training/peft/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # There's no way to ignore "F401 '...' imported but unused" warnings in this 3 | # module, but to preserve other warnings. So, don't check this module at all 4 | 5 | # coding=utf-8 6 | # Copyright 2023-present the HuggingFace Inc. team. 7 | # 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 10 | # You may obtain a copy of the License at 11 | # 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | # 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 
19 | 20 | from .adapters_utils import CONFIG_NAME, WEIGHTS_NAME 21 | from .config import PeftConfig, PeftType, PromptLearningConfig, TaskType 22 | from .other import ( 23 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING, 24 | _set_trainable, 25 | bloom_model_postprocess_past_key_value, 26 | # prepare_model_for_int8_training, 27 | shift_tokens_right, 28 | transpose, 29 | ) 30 | from .save_and_load import get_peft_model_state_dict, set_peft_model_state_dict 31 | -------------------------------------------------------------------------------- /scripts/training/peft/utils/adapters_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | WEIGHTS_NAME = "adapter_model.bin" 16 | CONFIG_NAME = "adapter_config.json" 17 | 18 | # TODO: add automapping and superclass here? 19 | -------------------------------------------------------------------------------- /scripts/training/peft/utils/config.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 
6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | import enum 16 | import json 17 | import os 18 | from dataclasses import asdict, dataclass, field 19 | from typing import Optional, Union 20 | 21 | from huggingface_hub import hf_hub_download 22 | from transformers.utils import PushToHubMixin 23 | 24 | from .adapters_utils import CONFIG_NAME 25 | 26 | 27 | class PeftType(str, enum.Enum): 28 | PROMPT_TUNING = "PROMPT_TUNING" 29 | P_TUNING = "P_TUNING" 30 | PREFIX_TUNING = "PREFIX_TUNING" 31 | LORA = "LORA" 32 | 33 | 34 | class TaskType(str, enum.Enum): 35 | SEQ_CLS = "SEQ_CLS" 36 | SEQ_2_SEQ_LM = "SEQ_2_SEQ_LM" 37 | CAUSAL_LM = "CAUSAL_LM" 38 | 39 | 40 | @dataclass 41 | class PeftConfigMixin(PushToHubMixin): 42 | r""" 43 | This is the base configuration class for PEFT adapter models. It contains all the methods that are common to all 44 | PEFT adapter models. This class inherits from `transformers.utils.PushToHubMixin` which contains the methods to 45 | push your model to the Hub. The method `save_pretrained` will save the configuration of your adapter model in a 46 | directory. The method `from_pretrained` will load the configuration of your adapter model from a directory. 47 | 48 | Args: 49 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 
50 | """ 51 | peft_type: Optional[PeftType] = field(default=None, metadata={"help": "The type of PEFT model."}) 52 | 53 | @property 54 | def __dict__(self): 55 | return asdict(self) 56 | 57 | def to_dict(self): 58 | return self.__dict__ 59 | 60 | def save_pretrained(self, save_directory, **kwargs): 61 | r""" 62 | This method saves the configuration of your adapter model in a directory. 63 | 64 | Args: 65 | save_directory (`str`): 66 | The directory where the configuration will be saved. 67 | **kwargs: 68 | Additional keyword arguments passed along to the `transformers.utils.PushToHubMixin.push_to_hub` 69 | method. 70 | """ 71 | if os.path.isfile(save_directory): 72 | raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") 73 | 74 | os.makedirs(save_directory, exist_ok=True) 75 | 76 | output_dict = self.__dict__ 77 | output_path = os.path.join(save_directory, CONFIG_NAME) 78 | 79 | # save it 80 | with open(output_path, "w") as writer: 81 | writer.write(json.dumps(output_dict, indent=2, sort_keys=True)) 82 | 83 | @classmethod 84 | def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): 85 | r""" 86 | This method loads the configuration of your adapter model from a directory. 87 | 88 | Args: 89 | pretrained_model_name_or_path (`str`): 90 | The directory or the hub-id where the configuration is saved. 91 | **kwargs: 92 | Additional keyword arguments passed along to the child class initialization. 
93 | """ 94 | if os.path.isfile(os.path.join(pretrained_model_name_or_path, CONFIG_NAME)): 95 | config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME) 96 | else: 97 | try: 98 | config_file = hf_hub_download(pretrained_model_name_or_path, CONFIG_NAME) 99 | except Exception: 100 | raise ValueError(f"Can't find '{CONFIG_NAME}' at '{pretrained_model_name_or_path}'") 101 | 102 | loaded_attributes = cls.from_json_file(config_file) 103 | 104 | config = cls(**kwargs) 105 | 106 | for key, value in loaded_attributes.items(): 107 | if hasattr(config, key): 108 | setattr(config, key, value) 109 | 110 | return config 111 | 112 | @classmethod 113 | def from_json_file(cls, path_json_file, **kwargs): 114 | r""" 115 | Loads a configuration file from a json file. 116 | 117 | Args: 118 | path_json_file (`str`): 119 | The path to the json file. 120 | """ 121 | with open(path_json_file, "r") as file: 122 | json_object = json.load(file) 123 | 124 | return json_object 125 | 126 | 127 | @dataclass 128 | class PeftConfig(PeftConfigMixin): 129 | """ 130 | This is the base configuration class to store the configuration of a :class:`~peft.PeftModel`. 131 | 132 | Args: 133 | peft_type (Union[[`~peft.utils.config.PeftType`], `str`]): The type of Peft method to use. 134 | task_type (Union[[`~peft.utils.config.TaskType`], `str`]): The type of task to perform. 135 | inference_mode (`bool`, defaults to `False`): Whether to use the Peft model in inference mode.
136 | """ 137 | 138 | base_model_name_or_path: str = field(default=None, metadata={"help": "The name of the base model to use."}) 139 | peft_type: Union[str, PeftType] = field(default=None, metadata={"help": "Peft type"}) 140 | task_type: Union[str, TaskType] = field(default=None, metadata={"help": "Task type"}) 141 | inference_mode: bool = field(default=False, metadata={"help": "Whether to use inference mode"}) 142 | 143 | 144 | @dataclass 145 | class PromptLearningConfig(PeftConfig): 146 | """ 147 | This is the base configuration class to store the configuration of a [`~peft.PrefixTuning`], 148 | [`~peft.PromptEncoder`], or [`~peft.PromptTuning`] model. 149 | 150 | Args: 151 | num_virtual_tokens (`int`): The number of virtual tokens to use. 152 | token_dim (`int`): The hidden embedding dimension of the base transformer model. 153 | num_transformer_submodules (`int`): The number of transformer submodules in the base transformer model. 154 | num_attention_heads (`int`): The number of attention heads in the base transformer model. 155 | num_layers (`int`): The number of layers in the base transformer model.
156 | """ 157 | 158 | num_virtual_tokens: int = field(default=None, metadata={"help": "Number of virtual tokens"}) 159 | token_dim: int = field( 160 | default=None, metadata={"help": "The hidden embedding dimension of the base transformer model"} 161 | ) 162 | num_transformer_submodules: Optional[int] = field( 163 | default=None, metadata={"help": "Number of transformer submodules"} 164 | ) 165 | num_attention_heads: Optional[int] = field(default=None, metadata={"help": "Number of attention heads"}) 166 | num_layers: Optional[int] = field(default=None, metadata={"help": "Number of transformer layers"}) 167 | -------------------------------------------------------------------------------- /scripts/training/peft/utils/other.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import torch 17 | 18 | 19 | # needed for prefix-tuning of bloom model 20 | def bloom_model_postprocess_past_key_value(past_key_values): 21 | past_key_values = torch.cat(past_key_values) 22 | total_layers, batch_size, num_attention_heads, num_virtual_tokens, head_dim = past_key_values.shape 23 | keys = past_key_values[: total_layers // 2] 24 | keys = keys.transpose(2, 3).reshape( 25 | total_layers // 2, batch_size * num_attention_heads, head_dim, num_virtual_tokens 26 | ) 27 | values = past_key_values[total_layers // 2 :] 28 | values = values.reshape(total_layers // 2, batch_size * num_attention_heads, num_virtual_tokens, head_dim) 29 | 30 | return tuple(zip(keys, values)) 31 | 32 | 33 | TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING = { 34 | "bloom": bloom_model_postprocess_past_key_value, 35 | } 36 | 37 | 38 | # copied from transformers.models.bart.modeling_bart 39 | def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int): 40 | """ 41 | Shift input ids one token to the right. 42 | 43 | Args: 44 | input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): input ids 45 | pad_token_id (`int`): The id of the `padding` token. 46 | decoder_start_token_id (`int`): The id of the `start` token. 
47 | """ 48 | shifted_input_ids = input_ids.new_zeros(input_ids.shape) 49 | shifted_input_ids[:, 1:] = input_ids[:, :-1].clone() 50 | shifted_input_ids[:, 0] = decoder_start_token_id 51 | 52 | if pad_token_id is None: 53 | raise ValueError("self.model.config.pad_token_id has to be defined.") 54 | # replace possible -100 values in labels by `pad_token_id` 55 | shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id) 56 | 57 | return shifted_input_ids 58 | 59 | 60 | def _set_trainable(model): 61 | if model.modules_to_save is not None: 62 | for name, param in model.named_parameters(): 63 | if any(module_name in name for module_name in model.modules_to_save): 64 | param.requires_grad = True 65 | 66 | 67 | def fsdp_auto_wrap_policy(model): 68 | import functools 69 | import os 70 | 71 | from accelerate import FullyShardedDataParallelPlugin 72 | from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy 73 | 74 | from ..tuners import PrefixEncoder, PromptEmbedding, PromptEncoder 75 | 76 | def lambda_policy_fn(module): 77 | if ( 78 | len(list(module.named_children())) == 0 79 | and getattr(module, "weight", None) is not None 80 | and module.weight.requires_grad 81 | ): 82 | return True 83 | return False 84 | 85 | lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) 86 | transformer_wrap_policy = functools.partial( 87 | transformer_auto_wrap_policy, 88 | transformer_layer_cls=( 89 | PrefixEncoder, 90 | PromptEncoder, 91 | PromptEmbedding, 92 | FullyShardedDataParallelPlugin.get_module_class_from_name( 93 | model, os.environ.get("FSDP_TRANSFORMER_CLS_TO_WRAP", "") 94 | ), 95 | ), 96 | ) 97 | 98 | auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) 99 | return auto_wrap_policy 100 | 101 | 102 | def transpose(weight, fan_in_fan_out): 103 | return weight.T if fan_in_fan_out else weight 104 | 
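`_set_trainable` above (and the `modules_to_save` branch of `get_peft_model_state_dict` in save_and_load.py) decides which parameters to keep trainable with a plain substring test on parameter names. The pure-Python sketch below reproduces only that matching rule, without torch, so its effect is easy to inspect. `select_trainable` is a hypothetical helper, and the parameter names are an illustrative subset in Hugging Face LLaMA style, not read from a real checkpoint; it also assumes the training script has already split the comma-separated `--modules_to_save` value into a list.

```python
from typing import Iterable, List, Optional


def select_trainable(param_names: Iterable[str], modules_to_save: Optional[List[str]]) -> List[str]:
    """Return the names that _set_trainable's matching rule would unfreeze.

    Mirrors `any(module_name in name for module_name in modules_to_save)`:
    a parameter is selected if ANY entry is a substring of its name.
    """
    if modules_to_save is None:
        return []
    return [name for name in param_names if any(m in name for m in modules_to_save)]


# Hypothetical, illustrative subset of LLaMA-style parameter names.
names = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.down_proj.weight",
    "lm_head.weight",
]

# run_pt.sh / run_sft.sh set modules_to_save="embed_tokens,lm_head".
print(select_trainable(names, "embed_tokens,lm_head".split(",")))
# ['model.embed_tokens.weight', 'lm_head.weight']

# Matching is by substring, so a short key such as "head" also selects lm_head.
print(select_trainable(names, ["head"]))
# ['lm_head.weight']
```

Because the test is a substring match rather than an exact module-name match, short or generic entries can unfreeze more parameters than intended; full module names such as `embed_tokens` and `lm_head` keep the selection narrow.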
-------------------------------------------------------------------------------- /scripts/training/peft/utils/save_and_load.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023-present the HuggingFace Inc. team. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from .config import PeftType 17 | 18 | 19 | def get_peft_model_state_dict(model, state_dict=None): 20 | """ 21 | Get the state dict of the Peft model. 22 | 23 | Args: 24 | model ([`PeftModel`]): The Peft model. When using torch.nn.DistributedDataParallel, DeepSpeed or FSDP, 25 | the model should be the underlying model/unwrapped model (i.e. model.module). 26 | state_dict (`dict`, *optional*, defaults to `None`): 27 | The state dict of the model. If not provided, the state dict of the model 28 | will be used. 
29 | """ 30 | if state_dict is None: 31 | state_dict = model.state_dict() 32 | if model.peft_config.peft_type == PeftType.LORA: 33 | # to_return = lora_state_dict(model, bias=model.peft_config.bias) 34 | # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` 35 | # to work directly with the state dict, which is necessary when using DeepSpeed or FSDP 36 | bias = model.peft_config.bias 37 | if bias == "none": 38 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k} 39 | elif bias == "all": 40 | to_return = {k: state_dict[k] for k in state_dict if "lora_" in k or "bias" in k} 41 | elif bias == "lora_only": 42 | to_return = {} 43 | for k in state_dict: 44 | if "lora_" in k: 45 | to_return[k] = state_dict[k] 46 | bias_name = k.split("lora_")[0] + "bias" 47 | if bias_name in state_dict: 48 | to_return[bias_name] = state_dict[bias_name] 49 | else: 50 | raise NotImplementedError 51 | else: 52 | to_return = {} 53 | if model.peft_config.inference_mode: 54 | prompt_embeddings = model.prompt_encoder.embedding.weight 55 | else: 56 | prompt_embeddings = model.get_prompt_embedding_to_save() 57 | to_return["prompt_embeddings"] = prompt_embeddings 58 | if model.modules_to_save is not None: 59 | for key, value in state_dict.items(): 60 | if any(module_name in key for module_name in model.modules_to_save): 61 | to_return[key] = value 62 | return to_return 63 | 64 | 65 | def set_peft_model_state_dict(model, peft_model_state_dict): 66 | """ 67 | Set the state dict of the Peft model. 68 | 69 | Args: 70 | model ([`PeftModel`]): The Peft model. 71 | peft_model_state_dict (`dict`): The state dict of the Peft model.
72 | """ 73 | 74 | model.load_state_dict(peft_model_state_dict, strict=False) 75 | if model.peft_config.peft_type != PeftType.LORA: 76 | model.prompt_encoder.embedding.load_state_dict( 77 | {"weight": peft_model_state_dict["prompt_embeddings"]}, strict=True 78 | ) 79 | return model 80 | -------------------------------------------------------------------------------- /scripts/training/run_pt.sh: -------------------------------------------------------------------------------- 1 | # 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh) 2 | # Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/pt_scripts_zh) carefully before running the script 3 | lr=2e-4 4 | lora_rank=64 5 | lora_alpha=128 6 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 7 | modules_to_save="embed_tokens,lm_head" 8 | lora_dropout=0.05 9 | 10 | pretrained_model=path/to/hf/llama-2/dir 11 | chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir 12 | dataset_dir=path/to/pt/data/dir 13 | data_cache=temp_data_cache_dir 14 | per_device_train_batch_size=1 15 | gradient_accumulation_steps=8 16 | block_size=512 17 | output_dir=output_dir 18 | 19 | deepspeed_config_file=ds_zero2_no_offload.json 20 | 21 | torchrun --nnodes 1 --nproc_per_node 1 run_clm_pt_with_peft.py \ 22 | --deepspeed ${deepspeed_config_file} \ 23 | --model_name_or_path ${pretrained_model} \ 24 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 25 | --dataset_dir ${dataset_dir} \ 26 | --data_cache_dir ${data_cache} \ 27 | --validation_split_percentage 0.001 \ 28 | --per_device_train_batch_size ${per_device_train_batch_size} \ 29 | --do_train \ 30 | --seed $RANDOM \ 31 | --fp16 \ 32 | --num_train_epochs 1 \ 33 | --lr_scheduler_type cosine \ 34 | --learning_rate ${lr} \ 35 | --warmup_ratio 0.05 \ 36 | --weight_decay 0.01 \ 37 | --logging_strategy steps \ 38 | --logging_steps 10 \ 39 | --save_strategy steps \ 40 | --save_total_limit 3 \ 41 | --save_steps 200 \ 42 | 
--gradient_accumulation_steps ${gradient_accumulation_steps} \ 43 | --preprocessing_num_workers 8 \ 44 | --block_size ${block_size} \ 45 | --output_dir ${output_dir} \ 46 | --overwrite_output_dir \ 47 | --ddp_timeout 30000 \ 48 | --logging_first_step True \ 49 | --lora_rank ${lora_rank} \ 50 | --lora_alpha ${lora_alpha} \ 51 | --trainable ${lora_trainable} \ 52 | --lora_dropout ${lora_dropout} \ 53 | --modules_to_save ${modules_to_save} \ 54 | --torch_dtype float16 \ 55 | --load_in_kbits 16 \ 56 | --save_safetensors False \ 57 | --gradient_checkpointing \ 58 | --ddp_find_unused_parameters False 59 | -------------------------------------------------------------------------------- /scripts/training/run_sft.sh: -------------------------------------------------------------------------------- 1 | # 运行脚本前请仔细阅读wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) 2 | # Read the wiki(https://github.com/ymcui/Chinese-LLaMA-Alpaca-2/wiki/sft_scripts_zh) carefully before running the script 3 | lr=1e-4 4 | lora_rank=64 5 | lora_alpha=128 6 | lora_trainable="q_proj,v_proj,k_proj,o_proj,gate_proj,down_proj,up_proj" 7 | modules_to_save="embed_tokens,lm_head" 8 | lora_dropout=0.05 9 | 10 | pretrained_model=path/to/hf/llama-2/or/chinese-llama-2/dir/or/model_id 11 | chinese_tokenizer_path=path/to/chinese-llama-2/tokenizer/dir 12 | dataset_dir=path/to/sft/data/dir 13 | per_device_train_batch_size=1 14 | per_device_eval_batch_size=1 15 | gradient_accumulation_steps=8 16 | max_seq_length=512 17 | output_dir=output_dir 18 | validation_file=validation_file_name 19 | 20 | deepspeed_config_file=ds_zero2_no_offload.json 21 | 22 | torchrun --nnodes 1 --nproc_per_node 1 run_clm_sft_with_peft.py \ 23 | --deepspeed ${deepspeed_config_file} \ 24 | --model_name_or_path ${pretrained_model} \ 25 | --tokenizer_name_or_path ${chinese_tokenizer_path} \ 26 | --dataset_dir ${dataset_dir} \ 27 | --per_device_train_batch_size ${per_device_train_batch_size} \ 28 | 
--per_device_eval_batch_size ${per_device_eval_batch_size} \ 29 | --do_train \ 30 | --do_eval \ 31 | --seed $RANDOM \ 32 | --fp16 \ 33 | --num_train_epochs 1 \ 34 | --lr_scheduler_type cosine \ 35 | --learning_rate ${lr} \ 36 | --warmup_ratio 0.03 \ 37 | --weight_decay 0 \ 38 | --logging_strategy steps \ 39 | --logging_steps 10 \ 40 | --save_strategy steps \ 41 | --save_total_limit 3 \ 42 | --evaluation_strategy steps \ 43 | --eval_steps 100 \ 44 | --save_steps 200 \ 45 | --gradient_accumulation_steps ${gradient_accumulation_steps} \ 46 | --preprocessing_num_workers 8 \ 47 | --max_seq_length ${max_seq_length} \ 48 | --output_dir ${output_dir} \ 49 | --overwrite_output_dir \ 50 | --ddp_timeout 30000 \ 51 | --logging_first_step True \ 52 | --lora_rank ${lora_rank} \ 53 | --lora_alpha ${lora_alpha} \ 54 | --trainable ${lora_trainable} \ 55 | --lora_dropout ${lora_dropout} \ 56 | --modules_to_save ${modules_to_save} \ 57 | --torch_dtype float16 \ 58 | --validation_file ${validation_file} \ 59 | --load_in_kbits 16 \ 60 | --save_safetensors False \ 61 | --gradient_checkpointing \ 62 | --ddp_find_unused_parameters False 63 | --------------------------------------------------------------------------------
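run_pt.sh and run_sft.sh share the same batch and LoRA settings. As a rough sanity check, the arithmetic below sketches what those values imply, assuming the Hugging Face Trainer convention that the effective batch size is per-device batch size × gradient accumulation steps × number of processes, and the standard LoRA scaling factor lora_alpha / lora_rank. `LaunchConfig` and the helper functions are hypothetical names for illustration only and are not part of the repository; the token count is an upper bound, since SFT examples may be shorter than max_seq_length.

```python
from dataclasses import dataclass


@dataclass
class LaunchConfig:
    """Hypothetical container for the values hard-coded in run_pt.sh / run_sft.sh."""
    per_device_train_batch_size: int = 1
    gradient_accumulation_steps: int = 8
    nnodes: int = 1          # torchrun --nnodes
    nproc_per_node: int = 1  # torchrun --nproc_per_node
    seq_len: int = 512       # block_size (pt) or max_seq_length (sft)
    lora_rank: int = 64
    lora_alpha: int = 128


def effective_batch_size(cfg: LaunchConfig) -> int:
    # Sequences consumed per optimizer step across all processes.
    world_size = cfg.nnodes * cfg.nproc_per_node
    return cfg.per_device_train_batch_size * cfg.gradient_accumulation_steps * world_size


def max_tokens_per_step(cfg: LaunchConfig) -> int:
    # Upper bound: each sequence holds at most seq_len tokens.
    return effective_batch_size(cfg) * cfg.seq_len


def lora_scaling(cfg: LaunchConfig) -> float:
    # Standard LoRA multiplies the low-rank update B @ A by alpha / r.
    return cfg.lora_alpha / cfg.lora_rank


if __name__ == "__main__":
    cfg = LaunchConfig()
    print(effective_batch_size(cfg))  # 8
    print(max_tokens_per_step(cfg))   # 4096
    print(lora_scaling(cfg))          # 2.0
```

Raising `--nproc_per_node` multiplies the effective batch size (for example, 4 GPUs gives 32 with the other values unchanged), so one may want to lower gradient_accumulation_steps to keep the optimizer-step batch comparable across hardware setups.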