├── .deepspeed_env ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature-request.yaml └── PULL_REQUEST_TEMPLATE │ └── pr_template.md ├── .gitignore ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── README_zh.md ├── assets ├── WECHAT.md ├── chat-min.png ├── chat.png ├── cogagent_function.jpg ├── cogagent_function_cn.jpg ├── cogagent_main_demo.jpg ├── compare-min.png ├── compare.png ├── llava-comparison-min.png ├── method-min.png ├── method.png ├── metrics-min.png ├── metrics.png ├── pear_grounding.png ├── web_demo-min.png ├── web_demo.png └── wechat.jpg ├── basic_demo ├── cli_demo_hf.py ├── cli_demo_sat.py └── web_demo.py ├── composite_demo ├── client.py ├── conversation.py ├── demo_agent_cogagent.py ├── demo_chat_cogagent.py ├── demo_chat_cogvlm.py ├── main.py └── utils.py ├── dataset.md ├── dataset_zh.md ├── finetune_demo ├── evaluate_cogagent.sh ├── evaluate_cogagent_demo.py ├── evaluate_cogvlm.sh ├── evaluate_cogvlm_demo.py ├── finetune_cogagent_demo.py ├── finetune_cogagent_lora.sh ├── finetune_cogvlm_demo.py ├── finetune_cogvlm_lora.sh └── test_config_bf16.json ├── openai_demo ├── demo.jpg ├── openai_api.py └── openai_api_request.py ├── requirements.txt └── utils ├── __init__.py ├── merge_model.py ├── models ├── __init__.py ├── cogagent_model.py ├── cogvlm_model.py ├── eva_clip_L_hf.py ├── eva_clip_model.py └── mixin.py ├── split_dataset.py └── utils ├── __init__.py ├── chat.py ├── dataset.py ├── grounding_parser.py ├── language.py ├── template.py └── vision.py /.deepspeed_env: -------------------------------------------------------------------------------- 1 | SAT_HOME=~/.sat_models 2 | LOCAL_WORLD_SIZE=8 -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve ChatGLM3 / 提交一个 Bug 问题报告来帮助我们改进 ChatGLM3 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: textarea 14 | id: who-can-help 15 | attributes: 16 | label: Who can help? / 谁可以帮助到您? 17 | description: | 18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @ 19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person. 20 | 21 | Please tag fewer than 3 people. 22 | 23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。 24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。 25 | 26 | 标记的人数应该不超过 1 个人。 27 | 28 | Related demo leader / 相关demo负责人 : 29 | - finetune_demo: @1049451037 30 | - composite_demo: @zR 31 | - openai_demo: @zR 32 | 33 | 34 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem. 35 | 36 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。 37 | 38 | placeholder: "@Username ..." 
39 | 40 | - type: checkboxes 41 | id: information-scripts-examples 42 | attributes: 43 | label: Information / 问题信息 44 | description: 'The problem arises when using: / 问题出现在' 45 | options: 46 | - label: "The official example scripts / 官方的示例脚本" 47 | - label: "My own modified scripts / 我自己修改的脚本和任务" 48 | 49 | - type: textarea 50 | id: reproduction 51 | validations: 52 | required: true 53 | attributes: 54 | label: Reproduction / 复现过程 55 | description: | 56 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 57 | If you have code snippets, error messages, stack traces, please provide them here as well. 58 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 59 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 60 | 61 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 62 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 63 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 64 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 65 | placeholder: | 66 | Steps to reproduce the behavior/复现Bug的步骤: 67 | 68 | 1. 69 | 2. 70 | 3. 71 | 72 | - type: textarea 73 | id: expected-behavior 74 | validations: 75 | required: true 76 | attributes: 77 | label: Expected behavior / 期待表现 78 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new ChatGLM3 feature / 提交一个新的 ChatGLM3 的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 34 | 您的PR链接或者其他您能提供帮助的链接。 -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE/pr_template.md: -------------------------------------------------------------------------------- 1 | # Raise valuable PR / 提出有价值的PR 2 | 3 | ## Caution/ 注意事项: 4 | Users should keep the following points in mind when submitting PRs: 5 | 6 | 1. The proposed PR should be about this project. 7 | 2. the proposed PR should be relevant, if there are multiple ideas and optimizations, they should be assigned to different PRs. 8 | 9 | 用户在提交PR时候应该注意以下几点: 10 | 11 | 1. 提出的PR应该是关于本项目的。 12 | 2. 
提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。 13 | 14 | ## 不应该提出的PR / PRs that should not be proposed 15 | 16 | If a developer proposes a PR about any of the following, it may be closed or Rejected. 17 | 18 | 1. those that don't describe improvement options. 19 | 2. multiple issues of different types combined in one PR. 20 | 3. The proposed PR is highly duplicative of already existing PRs. 21 | 22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。 23 | 24 | 1. 没有说明改进方案的。 25 | 2. 多个不同类型的问题合并在一个PR中的。 26 | 3. 提出的PR与已经存在的PR高度重复的。 27 | 28 | 29 | # 检查您的PR 30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分? 31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。 32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。 33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试? 34 | - [ ] Are your PRs for only one issue / 您的PR是否仅针对一个问题 -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .hypothesis/ 2 | __pycache__ 3 | output.png 4 | fewshot-data/ 5 | checkpoints/ 6 | records.db 7 | server.py 8 | examples/*grounding.png 9 | archive* 10 | hostfile 11 | runs/ 12 | *.idea/ 13 | .DS_Store -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 CogVLM team @ Zhipu AI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | The CogVLM License 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means the CogVLM Model Team that distributes its Software. 6 | 7 | “Software” means the CogVLM model parameters made available under this license. 8 | 9 | 2. 
License Grant 10 | 11 | Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license. 12 | This license permits you to use all open-source models in this repository for academic research free. Users who wish to use the models for commercial purposes must register [here](https://open.bigmodel.cn/mla/form). 13 | Registered users may use the models for commercial activities free of charge, but must comply with all terms and conditions of this license. 14 | The license notice shall be included in all copies or substantial portions of the Software. 15 | 16 | 3. Restriction 17 | 18 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes. 19 | 20 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings. 21 | 22 | 4. Disclaimer 23 | 24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 25 | 26 | 5. Limitation of Liability 27 | 28 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 29 | 30 | 6. Dispute Resolution 31 | 32 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing. 33 | 34 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn. 35 | 36 | 7. Llama2 and EVA-CLIP2 License 37 | 38 | For CogVLM-17B version, Llama2 license conditions (https://ai.meta.com/llama/license/) and EVA license conditions (MIT, https://github.com/baaivision/EVA/blob/master/LICENSE) Also applies to model weights. 39 | 40 | 41 | 1. 定义 42 | 43 | “许可方”是指分发其软件的 CogVLM 模型团队。 44 | 45 | “软件”是指根据本许可提供的 CogVLM 模型参数。 46 | 47 | 2. 许可授予 48 | 49 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。 50 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。 51 | 经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。 52 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。 53 | 54 | 3.限制 55 | 56 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。 57 | 58 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。 59 | 60 | 4.免责声明 61 | 62 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。 63 | 64 | 5. 
责任限制 65 | 66 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。 67 | 68 | 6.争议解决 69 | 70 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。 71 | 72 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。 73 | 74 | 7. Llama2 和 EVA-CLIP2 许可 75 | 76 | 针对 CogVLM-17B 版本, Llama2 许可条件 (https://ai.meta.com/llama/license/) 和 EVA 许可条件 (MIT, https://github.com/baaivision/EVA/blob/master/LICENSE) 同时适用于模型权重。 -------------------------------------------------------------------------------- /assets/WECHAT.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 | 扫码关注公众号,加入「ChatGLM交流群」 5 | Scan the QR code to follow the official account and join the "ChatGLM Discussion Group" 6 |
7 | -------------------------------------------------------------------------------- /assets/chat-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/chat-min.png -------------------------------------------------------------------------------- /assets/chat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/chat.png -------------------------------------------------------------------------------- /assets/cogagent_function.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_function.jpg -------------------------------------------------------------------------------- /assets/cogagent_function_cn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_function_cn.jpg -------------------------------------------------------------------------------- /assets/cogagent_main_demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_main_demo.jpg -------------------------------------------------------------------------------- /assets/compare-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/compare-min.png -------------------------------------------------------------------------------- /assets/compare.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/compare.png -------------------------------------------------------------------------------- /assets/llava-comparison-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/llava-comparison-min.png -------------------------------------------------------------------------------- /assets/method-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/method-min.png -------------------------------------------------------------------------------- /assets/method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/method.png -------------------------------------------------------------------------------- /assets/metrics-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/metrics-min.png -------------------------------------------------------------------------------- /assets/metrics.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/metrics.png -------------------------------------------------------------------------------- /assets/pear_grounding.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/pear_grounding.png -------------------------------------------------------------------------------- /assets/web_demo-min.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/web_demo-min.png -------------------------------------------------------------------------------- /assets/web_demo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/web_demo.png -------------------------------------------------------------------------------- /assets/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/wechat.jpg -------------------------------------------------------------------------------- /basic_demo/cli_demo_hf.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a demo for using CogAgent and CogVLM in CLI 3 | Make sure you have installed vicuna-7b-v1.5 tokenizer model (https://huggingface.co/lmsys/vicuna-7b-v1.5), full checkpoint of vicuna-7b-v1.5 LLM is not required. 4 | In this demo, We us chat template, you can use others to replace such as 'vqa'. 5 | Strongly suggest to use GPU with bfloat16 support, otherwise, it will be slow. 6 | Mention that only one picture can be processed at one conversation, which means you can not replace or insert another picture during the conversation. 
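For reference, a typical invocation looks like this (an illustrative command; substitute the checkpoint and tokenizer paths you actually use): python cli_demo_hf.py --from_pretrained THUDM/cogagent-chat-hf --local_tokenizer lmsys/vicuna-7b-v1.5 --bf16. Add --quant 4 to load the model with 4-bit quantization when GPU memory is limited.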
7 | """ 8 | 9 | import argparse 10 | import torch 11 | 12 | from PIL import Image 13 | from transformers import AutoModelForCausalLM, LlamaTokenizer 14 | 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits') 17 | parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt') 18 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 19 | parser.add_argument("--fp16", action="store_true") 20 | parser.add_argument("--bf16", action="store_true") 21 | 22 | args = parser.parse_args() 23 | MODEL_PATH = args.from_pretrained 24 | TOKENIZER_PATH = args.local_tokenizer 25 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' 26 | 27 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH) 28 | if args.bf16: 29 | torch_type = torch.bfloat16 30 | else: 31 | torch_type = torch.float16 32 | 33 | print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE)) 34 | 35 | if args.quant: 36 | model = AutoModelForCausalLM.from_pretrained( 37 | MODEL_PATH, 38 | torch_dtype=torch_type, 39 | low_cpu_mem_usage=True, 40 | load_in_4bit=True, 41 | trust_remote_code=True 42 | ).eval() 43 | else: 44 | model = AutoModelForCausalLM.from_pretrained( 45 | MODEL_PATH, 46 | torch_dtype=torch_type, 47 | low_cpu_mem_usage=True, 48 | load_in_4bit=args.quant is not None, 49 | trust_remote_code=True 50 | ).to(DEVICE).eval() 51 | 52 | text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:" 53 | 54 | while True: 55 | image_path = input("image path >>>>> ") 56 | if image_path == '': 57 | print('You did not enter image path, the following will be a plain text conversation.') 58 | image = None 59 | text_only_first_query = True 60 | else: 61 | image = Image.open(image_path).convert('RGB') 62 | 63 | history = [] 64 | 65 | while True: 66 | query = input("Human:") 67 | if query == "clear": 68 | break 69 | 70 | if image is None: 71 | if text_only_first_query: 72 | query = text_only_template.format(query) 73 | text_only_first_query = False 74 | else: 75 | old_prompt = '' 76 | for _, (old_query, response) in enumerate(history): 77 | old_prompt += old_query + " " + response + "\n" 78 | query = old_prompt + "USER: {} ASSISTANT:".format(query) 79 | 80 | if image is None: 81 | input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base') 82 | else: 83 | input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image]) 84 | 85 | inputs = { 86 | 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE), 87 | 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE), 88 | 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE), 89 | 'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None, 90 | } 91 | if 'cross_images' in input_by_model and input_by_model['cross_images']: 92 | inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]] 93 | 94 | # add any transformers params here. 
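# e.g. to switch from greedy decoding to sampling (illustrative values): gen_kwargs = {"max_length": 2048, "do_sample": True, "temperature": 0.9, "top_p": 0.8}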
95 | gen_kwargs = {"max_length": 2048, 96 | "do_sample": False} # "temperature": 0.9 97 | with torch.no_grad(): 98 | outputs = model.generate(**inputs, **gen_kwargs) 99 | outputs = outputs[:, inputs['input_ids'].shape[1]:] 100 | response = tokenizer.decode(outputs[0]) 101 | response = response.split("")[0] 102 | print("\nCog:", response) 103 | history.append((query, response)) 104 | -------------------------------------------------------------------------------- /basic_demo/cli_demo_sat.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | import os, sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | import argparse 7 | from sat.model.mixins import CachedAutoregressiveMixin 8 | from sat.quantization.kernels import quantize 9 | from sat.model import AutoModel 10 | 11 | 12 | from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor 13 | from utils.models import CogAgentModel, CogVLMModel 14 | 15 | def main(): 16 | parser = argparse.ArgumentParser() 17 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') 18 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling') 19 | parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling') 20 | parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') 21 | parser.add_argument("--chinese", action='store_true', help='Chinese interface') 22 | parser.add_argument("--version", type=str, default="chat", choices=['chat', 'vqa', 'chat_old', 'base'], help='version of language process. if there is \"text_processor_version\" in model_config.json, this option will be overwritten') 23 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') 24 | 25 | parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt') 26 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 27 | parser.add_argument("--fp16", action="store_true") 28 | parser.add_argument("--bf16", action="store_true") 29 | parser.add_argument("--stream_chat", action="store_true") 30 | args = parser.parse_args() 31 | rank = int(os.environ.get('RANK', 0)) 32 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 33 | args = parser.parse_args() 34 | 35 | # load model 36 | model, model_args = AutoModel.from_pretrained( 37 | args.from_pretrained, 38 | args=argparse.Namespace( 39 | deepspeed=None, 40 | local_rank=rank, 41 | rank=rank, 42 | world_size=world_size, 43 | model_parallel_size=world_size, 44 | mode='inference', 45 | skip_init=True, 46 | use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, 47 | device='cpu' if args.quant else 'cuda', 48 | **vars(args) 49 | ), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {}) 50 | model = model.eval() 51 | from sat.mpu import get_model_parallel_world_size 52 | assert world_size == get_model_parallel_world_size(), "world size must equal to model parallel size for cli_demo!" 
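# RANK and WORLD_SIZE are read from the environment above, so multi-GPU model parallelism is driven by the launcher, e.g. (assuming a torchrun-style launch): torchrun --standalone --nnodes=1 --nproc-per-node=2 cli_demo_sat.py --from_pretrained cogagent-chat --bf16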
53 | 54 | language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version 55 | print("[Language processor version]:", language_processor_version) 56 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=language_processor_version) 57 | image_processor = get_image_processor(model_args.eva_args["image_size"][0]) 58 | cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None 59 | 60 | if args.quant: 61 | quantize(model, args.quant) 62 | if torch.cuda.is_available(): 63 | model = model.cuda() 64 | 65 | 66 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 67 | 68 | text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length) 69 | 70 | if args.chinese: 71 | if rank == 0: 72 | print('欢迎使用 CogAgent-CLI ,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序') 73 | else: 74 | if rank == 0: 75 | print('Welcome to CogAgent-CLI. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.') 76 | with torch.no_grad(): 77 | while True: 78 | history = None 79 | cache_image = None 80 | if args.chinese: 81 | if rank == 0: 82 | image_path = [input("请输入图像路径或URL: ")] 83 | else: 84 | image_path = [None] 85 | else: 86 | if rank == 0: 87 | image_path = [input("Please enter the image path or URL: ")] 88 | else: 89 | image_path = [None] 90 | if world_size > 1: 91 | torch.distributed.broadcast_object_list(image_path, 0) 92 | image_path = image_path[0] 93 | assert image_path is not None 94 | 95 | if image_path == 'stop': 96 | break 97 | 98 | if args.chinese: 99 | if rank == 0: 100 | query = [input("用户:")] 101 | else: 102 | query = [None] 103 | else: 104 | if rank == 0: 105 | query = [input("User: ")] 106 | else: 107 | query = [None] 108 | if world_size > 1: 109 | torch.distributed.broadcast_object_list(query, 0) 110 | query = query[0] 111 | assert query is not None 112 | 113 | while True: 114 | if query == "clear": 115 | break 116 | if query == "stop": 117 | sys.exit(0) 118 | try: 119 | response, history, cache_image = chat( 120 | image_path, 121 | model, 122 | text_processor_infer, 123 | image_processor, 124 | query, 125 | history=history, 126 | cross_img_processor=cross_image_processor, 127 | image=cache_image, 128 | max_length=args.max_length, 129 | top_p=args.top_p, 130 | temperature=args.temperature, 131 | top_k=args.top_k, 132 | invalid_slices=text_processor_infer.invalid_slices, 133 | args=args 134 | ) 135 | except Exception as e: 136 | print(e) 137 | break 138 | if rank == 0 and not args.stream_chat: 139 | if args.chinese: 140 | print("模型:"+response) 141 | else: 142 | print("Model: "+response) 143 | image_path = None 144 | if args.chinese: 145 | if rank == 0: 146 | query = [input("用户:")] 147 | else: 148 | query = [None] 149 | else: 150 | if rank == 0: 151 | query = [input("User: ")] 152 | else: 153 | query = [None] 154 | if world_size > 1: 155 | torch.distributed.broadcast_object_list(query, 0) 156 | query = query[0] 157 | assert query is not None 158 | 159 | 160 | if __name__ == "__main__": 161 | main() 162 | -------------------------------------------------------------------------------- /basic_demo/web_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is a simple web demo of the CogVLM and CogAgent models, designed for easy and quick demonstrations. 
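A typical launch command (illustrative; point --from_pretrained at the checkpoint you are using) is: python web_demo.py --from_pretrained cogagent-chat --version chat --bf16.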
3 | For a more sophisticated user interface, users are encouraged to refer to the 'composite_demo', 4 | which is built with a more aesthetically pleasing Streamlit framework. 5 | 6 | Usage: 7 | - Use the interface to upload images and enter text prompts to interact with the models. 8 | 9 | Requirements: 10 | - Gradio (only 3.x,4.x is not support) and other necessary Python dependencies must be installed. 11 | - Proper model checkpoints should be accessible as specified in the script. 12 | 13 | Note: This demo is ideal for a quick showcase of the CogVLM and CogAgent models. For a more comprehensive and interactive 14 | experience, refer to the 'composite_demo'. 15 | """ 16 | import gradio as gr 17 | import os, sys 18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 19 | 20 | from PIL import Image 21 | import torch 22 | import time 23 | from sat.model.mixins import CachedAutoregressiveMixin 24 | from sat.mpu import get_model_parallel_world_size 25 | from sat.model import AutoModel 26 | 27 | 28 | from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor, parse_response 29 | from utils.models import CogAgentModel, CogVLMModel 30 | 31 | 32 | 33 | DESCRIPTION = '''

CogVLM / CogAgent ''' 34 | 35 | NOTES = 'This app is adapted from https://github.com/THUDM/CogVLM. It is recommended to check out the repo if you want to see the details of our models, CogVLM & CogAgent.' 36 | 37 | MAINTENANCE_NOTICE1 = 'Hint 1: If the app reports "Something went wrong, connection error out", please turn off your proxy and retry.
Hint 2: If you upload a large size of image like 10MB, it may take some time to upload and process. Please be patient and wait.' 38 | 39 | 40 | AGENT_NOTICE = 'Hint 1: To use Agent function, please use the prompts for agents.' 41 | 42 | GROUNDING_NOTICE = 'Hint 2: To use Grounding function, please use the prompts for grounding.' 43 | 44 | 45 | 46 | 47 | default_chatbox = [("", "Hi, What do you want to know about this image?")] 48 | 49 | 50 | model = image_processor = text_processor_infer = None 51 | 52 | is_grounding = False 53 | 54 | def process_image_without_resize(image_prompt): 55 | image = Image.open(image_prompt) 56 | # print(f"height:{image.height}, width:{image.width}") 57 | timestamp = int(time.time()) 58 | file_ext = os.path.splitext(image_prompt)[1] 59 | filename_grounding = f"examples/{timestamp}_grounding{file_ext}" 60 | return image, filename_grounding 61 | 62 | from sat.quantization.kernels import quantize 63 | 64 | def load_model(args): 65 | model, model_args = AutoModel.from_pretrained( 66 | args.from_pretrained, 67 | args=argparse.Namespace( 68 | deepspeed=None, 69 | local_rank=0, 70 | rank=0, 71 | world_size=world_size, 72 | model_parallel_size=world_size, 73 | mode='inference', 74 | fp16=args.fp16, 75 | bf16=args.bf16, 76 | skip_init=True, 77 | use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False, 78 | device='cpu' if args.quant else 'cuda'), 79 | overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {} 80 | ) 81 | model = model.eval() 82 | assert world_size == get_model_parallel_world_size(), "world size must equal to model parallel size for cli_demo!" 83 | 84 | language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version 85 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=language_processor_version) 86 | image_processor = get_image_processor(model_args.eva_args["image_size"][0]) 87 | cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None 88 | 89 | if args.quant: 90 | quantize(model, args.quant) 91 | if torch.cuda.is_available(): 92 | model = model.cuda() 93 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 94 | 95 | text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length) 96 | 97 | return model, image_processor, cross_image_processor, text_processor_infer 98 | 99 | 100 | def post( 101 | input_text, 102 | temperature, 103 | top_p, 104 | top_k, 105 | image_prompt, 106 | result_previous, 107 | hidden_image, 108 | state 109 | ): 110 | result_text = [(ele[0], ele[1]) for ele in result_previous] 111 | for i in range(len(result_text)-1, -1, -1): 112 | if result_text[i][0] == "" or result_text[i][0] == None: 113 | del result_text[i] 114 | print(f"history {result_text}") 115 | 116 | global model, image_processor, cross_image_processor, text_processor_infer, is_grounding 117 | 118 | try: 119 | with torch.no_grad(): 120 | pil_img, image_path_grounding = process_image_without_resize(image_prompt) 121 | response, _, cache_image = chat( 122 | image_path="", 123 | model=model, 124 | text_processor=text_processor_infer, 125 | img_processor=image_processor, 126 | query=input_text, 127 | history=result_text, 128 | cross_img_processor=cross_image_processor, 129 | image=pil_img, 130 | max_length=2048, 131 | top_p=top_p, 132 | temperature=temperature, 133 | top_k=top_k, 134 | invalid_slices=text_processor_infer.invalid_slices if 
hasattr(text_processor_infer, "invalid_slices") else [], 135 | no_prompt=False, 136 | args=state['args'] 137 | ) 138 | except Exception as e: 139 | print("error message", e) 140 | result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.')) 141 | return "", result_text, hidden_image 142 | 143 | answer = response 144 | if is_grounding: 145 | parse_response(pil_img, answer, image_path_grounding) 146 | new_answer = answer.replace(input_text, "") 147 | result_text.append((input_text, new_answer)) 148 | result_text.append((None, (image_path_grounding,))) 149 | else: 150 | result_text.append((input_text, answer)) 151 | print(result_text) 152 | print('finished') 153 | return "", result_text, hidden_image 154 | 155 | 156 | def clear_fn(value): 157 | return "", default_chatbox, None 158 | 159 | def clear_fn2(value): 160 | return default_chatbox 161 | 162 | 163 | def main(args): 164 | global model, image_processor, cross_image_processor, text_processor_infer, is_grounding 165 | model, image_processor, cross_image_processor, text_processor_infer = load_model(args) 166 | is_grounding = 'grounding' in args.from_pretrained 167 | 168 | gr.close_all() 169 | 170 | with gr.Blocks(css='style.css') as demo: 171 | state = gr.State({'args': args}) 172 | 173 | gr.Markdown(DESCRIPTION) 174 | gr.Markdown(NOTES) 175 | 176 | 177 | with gr.Row(): 178 | with gr.Column(scale=5): 179 | with gr.Group(): 180 | gr.Markdown(AGENT_NOTICE) 181 | gr.Markdown(GROUNDING_NOTICE) 182 | input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.') 183 | 184 | with gr.Row(): 185 | run_button = gr.Button('Generate') 186 | clear_button = gr.Button('Clear') 187 | 188 | image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None) 189 | 190 | with gr.Row(): 191 | temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature') 192 | top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P') 193 | top_k = gr.Slider(maximum=100, value=10, minimum=1, step=1, label='Top K') 194 | 195 | with gr.Column(scale=5): 196 | result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")], height=600) 197 | hidden_image_hash = gr.Textbox(visible=False) 198 | 199 | 200 | gr.Markdown(MAINTENANCE_NOTICE1) 201 | 202 | print(gr.__version__) 203 | run_button.click(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state], 204 | outputs=[input_text, result_text, hidden_image_hash]) 205 | input_text.submit(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state], 206 | outputs=[input_text, result_text, hidden_image_hash]) 207 | clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt]) 208 | image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text]) 209 | image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text]) 210 | 211 | 212 | # demo.queue(concurrency_count=10) 213 | demo.launch() 214 | 215 | 216 | if __name__ == '__main__': 217 | import argparse 218 | parser = argparse.ArgumentParser() 219 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence') 220 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling') 221 | parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling') 222 | 
parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling') 223 | parser.add_argument("--version", type=str, default="chat", choices=['chat', 'vqa', 'chat_old', 'base'], help='version of language process. if there is \"text_processor_version\" in model_config.json, this option will be overwritten') 224 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits') 225 | parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt') 226 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 227 | parser.add_argument("--fp16", action="store_true") 228 | parser.add_argument("--bf16", action="store_true") 229 | parser.add_argument("--stream_chat", action="store_true") 230 | args = parser.parse_args() 231 | rank = int(os.environ.get('RANK', 0)) 232 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 233 | args = parser.parse_args() 234 | main(args) 235 | -------------------------------------------------------------------------------- /composite_demo/client.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | from threading import Thread 3 | 4 | import streamlit as st 5 | import torch 6 | import warnings 7 | import os 8 | 9 | from typing import Any, Protocol 10 | from collections.abc import Iterable 11 | from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token 12 | from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM 13 | from conversation import Conversation 14 | 15 | # Check if GPU supports bfloat16 16 | 17 | if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8: 18 | torch_type = torch.bfloat16 19 | else: 20 | torch_type = torch.float16 21 | warnings.warn("Your GPU does not support bfloat16 type, use fp16 instead") 22 | 23 | # if you use all of Our model, include cogagent-chat cogvlm-chat cogvlm-grounding and put it in different devices, you can do like this. 24 | models_info = { 25 | 'tokenizer': { 26 | 'path': os.environ.get('TOKENIZER_PATH', 'lmsys/vicuna-7b-v1.5'), 27 | }, 28 | 'agent_chat': { 29 | 'path': os.environ.get('MODEL_PATH_AGENT_CHAT', 'THUDM/cogagent-chat-hf'), 30 | 'device': ['cuda:0'] 31 | }, 32 | 'vlm_chat': { 33 | 'path': os.environ.get('MODEL_PATH_VLM_CHAT', 'THUDM/cogvlm-chat-hf'), 34 | 'device': ['cuda:3'] 35 | }, 36 | 'vlm_grounding': { 37 | 'path': os.environ.get('MODEL_PATH_VLM_GROUNDING','THUDM/cogvlm-grounding-generalist-hf'), 38 | 'device': ['cuda:6'] 39 | } 40 | } 41 | 42 | 43 | # if you just use one model, use like this 44 | # models_info = { 45 | # 'tokenizer': { 46 | # 'path': os.environ.get('TOKENIZER_PATH', 'lmsys/vicuna-7b-v1.5'), 47 | # }, 48 | # 'agent_chat': { 49 | # 'path': os.environ.get('MODEL_PATH_AGENT_CHAT', 'THUDM/cogagent-chat-hf'), 50 | # 'device': ['cuda:0'] 51 | # }, 52 | 53 | 54 | 55 | @st.cache_resource 56 | def get_client() -> Client: 57 | client = HFClient(models_info) 58 | return client 59 | 60 | 61 | def process_history(history: list[Conversation]): 62 | """ 63 | Process the input history to extract the query and the history pairs. 64 | Args: 65 | History(list[Conversation]): A list of Conversation objects representing all conversations. 66 | Returns: 67 | query(str): The current user input string. 68 | history_pairs(list[(str,str)]): A list of (user, assistant) pairs. 69 | last_user_image(Image): The last user image. 
Only the latest image. 70 | 71 | """ 72 | history_pairs = [] 73 | query = "" 74 | last_user_image = None 75 | 76 | user_text = None 77 | for i, conversation in enumerate(history): 78 | if conversation.role == conversation.role.USER: 79 | user_text = conversation.content 80 | if conversation.image: 81 | last_user_image = conversation.image 82 | 83 | if i == len(history) - 1: 84 | query = conversation.content 85 | 86 | else: 87 | if user_text is not None: 88 | history_pairs.append((user_text, conversation.content)) 89 | user_text = None 90 | return query, history_pairs, last_user_image 91 | 92 | 93 | class Client(Protocol): 94 | def generate_stream(self, 95 | history: list[Conversation], 96 | grounding: bool = False, 97 | model_use: str = 'agent_chat', 98 | **parameters: Any 99 | ) -> Iterable[TextGenerationStreamResponse]: 100 | ... 101 | 102 | 103 | class HFClient(Client): 104 | """ 105 | The HFClient class manages the interaction with various large language models 106 | for text generation tasks. It supports handling multiple models, each designated 107 | for a specific task like chatting or grounding. 108 | 109 | Args: 110 | models_info (dict): A dictionary containing the configuration for each model. 111 | The dictionary format is: 112 | - 'tokenizer': Path and settings for the tokenizer. 113 | - 'agent_chat': Path and settings for the CogAgent-chat-18B model. 114 | - 'vlm_chat': Path and settings for the CogVLM-chat-17B model. 115 | - 'vlm_grounding': Path and settings for the CogVLM-grounding-17B model. 116 | 117 | The class loads each model based on the provided information and assigns it to the 118 | specified CUDA device. It also handles the tokenizer used across all models. 119 | """ 120 | def __init__(self, models_info): 121 | self.models = {} 122 | self.tokenizer = AutoTokenizer.from_pretrained(models_info['tokenizer']['path'], trust_remote_code=True) 123 | for model_name, model_info in models_info.items(): 124 | if model_name != 'tokenizer': 125 | self.models[model_name] = [] 126 | for device in model_info['device']: 127 | model = AutoModelForCausalLM.from_pretrained( 128 | model_info['path'], 129 | torch_dtype=torch_type, 130 | low_cpu_mem_usage=True, 131 | trust_remote_code=True, 132 | ).to(device).eval() 133 | self.models[model_name].append(model) 134 | 135 | def select_best_gpu(self, model_name): 136 | min_memory_used = None 137 | selected_model = None 138 | 139 | for model in self.models[model_name]: 140 | device = next(model.parameters()).device 141 | mem_used = torch.cuda.memory_allocated(device=device) 142 | 143 | if min_memory_used is None or mem_used < min_memory_used: 144 | min_memory_used = mem_used 145 | selected_model = model 146 | 147 | return selected_model 148 | 149 | def generate_stream(self, 150 | history: list, 151 | grounding: bool = False, 152 | model_use: str = 'agent_chat', 153 | **parameters: Any 154 | ) -> Iterable[TextGenerationStreamResponse]: 155 | """ 156 | Generates a stream of text responses based on the input history and selected model. 157 | 158 | This method facilitates a chat-like interaction with the models. Depending on the 159 | model selected and whether grounding is enabled, it alters the behavior of the text 160 | generation process. 161 | 162 | Args: 163 | history (list[Conversation]): A list of Conversation objects representing the 164 | dialogue history. 165 | grounding (bool, optional): A flag to indicate whether grounding should be used 166 | in the generation process. Defaults to False. 
167 | model_use (str, optional): The key name of the model to be used for the generation. 168 | Defaults to 'agent_chat'. 169 | **parameters (Any): Additional parameters that may be required for the generation 170 | process. 171 | 172 | Yields: 173 | Iterable[TextGenerationStreamResponse]: A stream of text generation responses, each 174 | encapsulating a generated piece of text. 175 | 176 | The method selects the appropriate model based on `model_use`, processes the input 177 | history, and feeds it into the model to generate text. It uses threading to handle 178 | the generation process efficiently. 179 | """ 180 | query, history, image = process_history(history) 181 | if grounding: 182 | query += "(with grounding)" 183 | 184 | model = self.select_best_gpu(model_use) 185 | device = next(model.parameters()).device 186 | 187 | # Print user input info 188 | 189 | print("\n== Input ==\n", query) 190 | print("\n==History==\n", history) 191 | print("\n== Model ==\n\n", model.config.name_or_path) 192 | print("\n== Device ==\n\n", device) 193 | 194 | input_by_model = model.build_conversation_input_ids( 195 | self.tokenizer, 196 | query=query, 197 | history=history, 198 | images=[image] 199 | ) 200 | inputs = { 201 | 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(device), 202 | 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(device), 203 | 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(device), 204 | 'images': [[input_by_model['images'][0].to(device).to(torch_type)]], 205 | } 206 | 207 | # CogVLM model do not have param 'cross_images', Only CogAgent have. 208 | 209 | if 'cross_images' in input_by_model and input_by_model['cross_images']: 210 | inputs['cross_images'] = [[input_by_model['cross_images'][0].to(device).to(torch_type)]] 211 | 212 | # Use TextIteratorStreamer for streaming generation like huggingface. 213 | 214 | streamer = TextIteratorStreamer(self.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True) 215 | parameters['streamer'] = streamer 216 | gen_kwargs = {**parameters, **inputs} 217 | with torch.no_grad(): 218 | thread = Thread(target=model.generate, kwargs=gen_kwargs) 219 | thread.start() 220 | for next_text in streamer: 221 | yield TextGenerationStreamResponse( 222 | token=Token( 223 | id=0, 224 | logprob=0, 225 | text=next_text, 226 | special=False, 227 | ) 228 | ) 229 | -------------------------------------------------------------------------------- /composite_demo/conversation.py: -------------------------------------------------------------------------------- 1 | import requests 2 | import re 3 | import streamlit as st 4 | 5 | from dataclasses import dataclass 6 | from enum import auto, Enum 7 | from PIL.Image import Image 8 | from PIL import ImageDraw 9 | from streamlit.delta_generator import DeltaGenerator 10 | 11 | 12 | class Role(Enum): 13 | """ 14 | CogVLM | CogAgent Only have 2 roles: USER, ASSISTANT 15 | 16 | Represents the roles in a conversation, specifically for CogVLM and CogAgent applications. 17 | 18 | There are two roles available: 19 | - USER: The user of the system, typically the one asking questions or initiating conversation. 20 | - ASSISTANT: The system or AI assistant responding to the user's queries. 21 | 22 | Methods: 23 | get_message(self): 24 | Retrieves a Streamlit chat message component based on the role. For the USER role, it 25 | returns a chat message with the name "user" and user avatar. 
For the ASSISTANT role, 26 | it returns a chat message with the name "assistant" and assistant avatar. 27 | """ 28 | 29 | USER = auto() 30 | ASSISTANT = auto() 31 | 32 | def get_message(self): 33 | 34 | match self.value: 35 | case Role.USER.value: 36 | return st.chat_message(name="user", avatar="user") 37 | case Role.ASSISTANT.value: 38 | return st.chat_message(name="assistant", avatar="assistant") 39 | case _: 40 | st.error(f'Unexpected role: {self}') 41 | 42 | 43 | @dataclass 44 | class Conversation: 45 | """ 46 | Represents a single conversation turn within a dialogue. 47 | Attributes: 48 | role (Role): The role of the speaker in the conversation (USER or ASSISTANT). 49 | content (str): The textual content of the conversation turn. 50 | image (Image, optional): An optional image associated with the conversation turn. 51 | content_show (str, optional): The content to be displayed in the WebUI. This may differ 52 | from `content` if translation or other processing is applied. 53 | translate (bool, optional): Whether to translate the content of the conversation turn. 54 | 55 | Methods: 56 | __str__(self) -> str: 57 | Returns a string representation of the conversation turn, including the role and content. 58 | 59 | show(self, placeholder: DeltaGenerator | None = None) -> str: 60 | Displays the conversation turn in the WebUI. If `placeholder` is provided, the content 61 | is shown in the specified Streamlit container. Otherwise, it uses the message style 62 | determined by the role. 63 | """ 64 | 65 | role: Role = Role.USER 66 | content: str = "" 67 | image: Image | None = None 68 | content_show: str | None = None 69 | translate: bool = False 70 | 71 | def __str__(self) -> str: 72 | print(self.role, self.content) 73 | match self.role: 74 | case Role.USER | Role.ASSISTANT: 75 | return f'{self.role}\n{self.content}' 76 | 77 | def show(self, placeholder: DeltaGenerator | None = None) -> str: 78 | """ 79 | show in markdown formate 80 | """ 81 | if placeholder: 82 | message = placeholder 83 | else: 84 | message = self.role.get_message() 85 | 86 | # for Chinese WebUI show 87 | if self.role == Role.USER: 88 | if self.translate: 89 | self.content = translate_baidu(self.content_show, source_lan="zh", target_lan="en") 90 | if self.content == "error": 91 | self.content_show = "Please Enter your Baidu Translation API Key in function translate_baidu()" 92 | else: 93 | self.content = self.content_show 94 | if self.role == Role.ASSISTANT: 95 | if self.translate: 96 | self.content_show = translate_baidu(self.content, source_lan="en", target_lan="zh") 97 | else: 98 | self.content_show = self.content 99 | 100 | self.content_show = self.content_show.replace('\n', ' \n') 101 | 102 | message.markdown(self.content_show) 103 | if self.image: 104 | message.image(self.image) 105 | 106 | 107 | def preprocess_text(history: list[Conversation], ) -> str: 108 | """ 109 | Prepares the conversation history for processing by concatenating the content of each turn. 110 | Args: 111 | history (list[Conversation]): The conversation history, a list of Conversation objects. 112 | 113 | Returns: 114 | str: A single string that concatenates the content of each conversation turn, followed by 115 | the ASSISTANT role indicator. This string is suitable for use as input to a text generation model. 
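    Example (illustrative sketch, assuming a short two-turn history):

        history = [
            Conversation(role=Role.USER, content="Describe the image."),
            Conversation(role=Role.ASSISTANT, content="A cat sits on a sofa."),
        ]
        prompt = preprocess_text(history)
        # Each turn is rendered via Conversation.__str__ ("<role>\n<content>"),
        # and the trailing Role.ASSISTANT marker cues the model to answer next.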
116 | """ 117 | 118 | prompt = "" 119 | for conversation in history: 120 | prompt += f'{conversation}' 121 | prompt += f'{Role.ASSISTANT}\n' 122 | return prompt 123 | 124 | 125 | def postprocess_text(template: str, text: str) -> str: 126 | """ 127 | Post-processes the generated text by incorporating it into a given template. 128 | Args: 129 | template (str): A template string containing a placeholder for the generated text. 130 | text (str): The generated text to be incorporated into the template. 131 | 132 | Returns: 133 | str: The template with the generated text replacing the placeholder. 134 | """ 135 | quoted_text = f'"{text.strip()}"' 136 | return template.replace("", quoted_text).strip() if template != "" else text.strip() 137 | 138 | 139 | def postprocess_image(text: str, img: Image) -> (str, Image): 140 | """ 141 | Processes the given text to identify and draw bounding boxes on the provided image. 142 | This function searches for patterns in the text that represent coordinates for bounding 143 | boxes and draws rectangles on the image at these coordinates. Each box is drawn in a 144 | different color for distinction. 145 | Args: 146 | text (str): The text containing bounding box coordinates in a specific pattern. 147 | img (Image): The image on which to draw the bounding boxes. 148 | Returns: 149 | tuple[str, Image]: The processed text with additional annotations for each bounding 150 | box, and the image with the drawn bounding boxes. 151 | """ 152 | colors = ["red", "green", "blue", "yellow", "purple", "orange"] 153 | 154 | # Updated pattern to match single or multiple coordinate groups 155 | pattern = r"\[\[([\d,]+(?:;[\d,]+)*)\]\]" 156 | matches = re.findall(pattern, text) 157 | draw = ImageDraw.Draw(img) 158 | 159 | if not matches: 160 | return text, None 161 | 162 | for i, match in enumerate(matches): 163 | # Splitting the matched string into individual coordinate groups 164 | coords_groups = match.split(';') 165 | 166 | # Determining the color for the current match 167 | color = colors[i % len(colors)] 168 | 169 | for coords_str in coords_groups: 170 | coords = coords_str.split(',') 171 | 172 | if len(coords) == 4: # Rectangle 173 | scaled_coords = ( 174 | int(float(coords[0]) * 0.001 * img.width), 175 | int(float(coords[1]) * 0.001 * img.height), 176 | int(float(coords[2]) * 0.001 * img.width), 177 | int(float(coords[3]) * 0.001 * img.height) 178 | ) 179 | draw.rectangle(scaled_coords, outline=color, width=3) 180 | elif len(coords) == 2: # Point 181 | scaled_coords = ( 182 | int(float(coords[0]) * 0.001 * img.width), 183 | int(float(coords[1]) * 0.001 * img.height) 184 | ) 185 | radius = 5 186 | draw.ellipse([scaled_coords[0] - radius, scaled_coords[1] - radius, 187 | scaled_coords[0] + radius, scaled_coords[1] + radius], 188 | fill=color) 189 | 190 | return text, img 191 | 192 | def translate_baidu(translate_text, source_lan, target_lan): 193 | """ 194 | Translates text using Baidu's translation service. (if you are not use English) 195 | 196 | This function sends a request to the Baidu translation API to translate the provided text 197 | from the source language to the target language. 198 | 199 | Args: 200 | translate_text (str): The text to be translated. 201 | source_lan (str): The source language code (e.g., "en" for English). 202 | target_lan (str): The target language code (e.g., "zh" for Chinese). 203 | 204 | Returns: 205 | str: The translated text. Returns "error" in case of an exception. 
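    Example (illustrative; a valid Baidu access token must be appended to the
    request URL below, which is left empty in this demo):

        english = translate_baidu("你好，世界", source_lan="zh", target_lan="en")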
206 | """ 207 | url = "https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1?access_token=" 208 | headers = {'Content-Type': 'application/json'} 209 | payload = { 210 | 'q': translate_text, 211 | 'from': source_lan, 212 | 'to': target_lan 213 | } 214 | try: 215 | r = requests.post(url, json=payload, headers=headers) 216 | result = r.json() 217 | final_translation = '' 218 | 219 | for item in result['result']['trans_result']: 220 | final_translation += item['dst'] + '\n' 221 | except Exception as e: 222 | print(e) 223 | return "error" 224 | return final_translation 225 | -------------------------------------------------------------------------------- /composite_demo/demo_agent_cogagent.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | import base64 3 | import streamlit as st 4 | import re 5 | 6 | from streamlit.delta_generator import DeltaGenerator 7 | from client import get_client 8 | from conversation import postprocess_text, Conversation, Role, postprocess_image 9 | from PIL import Image 10 | from utils import images_are_same 11 | 12 | client = get_client() 13 | 14 | 15 | def append_conversation( 16 | conversation: Conversation, 17 | history: list[Conversation], 18 | placeholder: DeltaGenerator | None = None, 19 | ) -> None: 20 | history.append(conversation) 21 | conversation.show(placeholder) 22 | 23 | 24 | def main( 25 | top_p: float = 0.8, 26 | temperature: float = 0.95, 27 | prompt_text: str = "", 28 | metadata: str = "", 29 | top_k: int = 2, 30 | max_new_tokens: int = 2048, 31 | grounding: bool = False, 32 | retry: bool = False, 33 | template: str = "" 34 | ): 35 | if 'chat_history' not in st.session_state: 36 | st.session_state.chat_history = [] 37 | 38 | if prompt_text == "" and retry == False: 39 | print("\n== Clean ==\n") 40 | st.session_state.chat_history = [] 41 | return 42 | 43 | history: list[Conversation] = st.session_state.chat_history 44 | for conversation in history: 45 | conversation.show() 46 | 47 | if retry: 48 | print("\n== Retry ==\n") 49 | last_user_conversation_idx = None 50 | for idx, conversation in enumerate(history): 51 | if conversation.role == Role.USER: 52 | last_user_conversation_idx = idx 53 | if last_user_conversation_idx is not None: 54 | prompt_text = history[last_user_conversation_idx].content_show 55 | del history[last_user_conversation_idx:] 56 | 57 | if prompt_text: 58 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None 59 | image.thumbnail((1120, 1120)) 60 | image_input = image 61 | if history and image: 62 | last_user_image = next( 63 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None) 64 | if last_user_image and images_are_same(image, last_user_image): 65 | image_input = None 66 | 67 | # Not necessary to clear history 68 | # else: 69 | # # new picture means new conversation 70 | # st.session_state.chat_history = [] 71 | # history = [] 72 | 73 | # Set conversation 74 | if re.search('[\u4e00-\u9fff]', prompt_text): 75 | translate = True 76 | else: 77 | translate = False 78 | 79 | user_conversation = Conversation( 80 | role=Role.USER, 81 | translate=translate, 82 | content_show=prompt_text.strip() if retry else postprocess_text(template=template, 83 | text=prompt_text.strip()), 84 | image=image_input 85 | ) 86 | append_conversation(user_conversation, history) 87 | placeholder = st.empty() 88 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant") 89 | 
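        # Reserve a single Streamlit slot for the streaming reply: the chat_message
        # wrapper keeps the assistant name/avatar, and the .empty() call below yields
        # a container that is overwritten on every streamed chunk.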
assistant_conversation = assistant_conversation.empty() 90 | 91 | # steam Answer 92 | output_text = '' 93 | for response in client.generate_stream( 94 | model_use='agent_chat', 95 | grounding=grounding, 96 | history=history, 97 | do_sample=True, 98 | max_new_tokens=max_new_tokens, 99 | temperature=temperature, 100 | top_p=top_p, 101 | top_k=top_k, 102 | ): 103 | output_text += response.token.text 104 | assistant_conversation.markdown(output_text.strip() + '▌') 105 | 106 | ## Final Answer with image. 107 | print("\n==Output:==\n", output_text) 108 | content_output, image_output = postprocess_image(output_text, image) 109 | assistant_conversation = Conversation( 110 | role=Role.ASSISTANT, 111 | content=content_output, 112 | image=image_output, 113 | translate=translate, 114 | ) 115 | append_conversation( 116 | conversation=assistant_conversation, 117 | history=history, 118 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant"), 119 | ) 120 | -------------------------------------------------------------------------------- /composite_demo/demo_chat_cogagent.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import base64 3 | import re 4 | 5 | from PIL import Image 6 | from io import BytesIO 7 | from streamlit.delta_generator import DeltaGenerator 8 | from client import get_client 9 | from utils import images_are_same 10 | from conversation import Conversation, Role, postprocess_image, postprocess_text 11 | 12 | client = get_client() 13 | 14 | 15 | def append_conversation( 16 | conversation: Conversation, 17 | history: list[Conversation], 18 | placeholder: DeltaGenerator | None = None, 19 | ) -> None: 20 | history.append(conversation) 21 | conversation.show(placeholder) 22 | 23 | 24 | def main( 25 | top_p: float = 0.8, 26 | temperature: float = 0.95, 27 | prompt_text: str = "", 28 | metadata: str = "", 29 | top_k: int = 2, 30 | max_new_tokens: int = 2048, 31 | grounding: bool = False, 32 | retry: bool = False, 33 | template: str = "", 34 | ): 35 | if 'chat_history' not in st.session_state: 36 | st.session_state.chat_history = [] 37 | 38 | if prompt_text == "" and retry == False: 39 | print("\n== Clean ==\n") 40 | st.session_state.chat_history = [] 41 | return 42 | 43 | history: list[Conversation] = st.session_state.chat_history 44 | for conversation in history: 45 | conversation.show() 46 | if retry: 47 | last_user_conversation_idx = None 48 | for idx, conversation in enumerate(history): 49 | if conversation.role == Role.USER: 50 | last_user_conversation_idx = idx 51 | if last_user_conversation_idx is not None: 52 | prompt_text = history[last_user_conversation_idx].content_show 53 | del history[last_user_conversation_idx:] 54 | 55 | if prompt_text: 56 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None 57 | image.thumbnail((1120, 1120)) 58 | image_input = image 59 | if history and image: 60 | last_user_image = next( 61 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None) 62 | if last_user_image and images_are_same(image, last_user_image): 63 | image_input = None 64 | else: 65 | st.session_state.chat_history = [] 66 | history = [] 67 | 68 | # Set conversation 69 | if re.search('[\u4e00-\u9fff]', prompt_text): 70 | translate = True 71 | else: 72 | translate = False 73 | 74 | user_conversation = Conversation( 75 | role=Role.USER, 76 | translate=translate, 77 | content_show=prompt_text.strip() if retry else 
postprocess_text(template=template, 78 | text=prompt_text.strip()), 79 | image=image_input 80 | ) 81 | append_conversation(user_conversation, history) 82 | placeholder = st.empty() 83 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant") 84 | assistant_conversation = assistant_conversation.empty() 85 | 86 | # steam Answer 87 | output_text = '' 88 | for response in client.generate_stream( 89 | model_use='agent_chat', 90 | grounding=grounding, 91 | history=history, 92 | do_sample=True, 93 | max_new_tokens=max_new_tokens, 94 | temperature=temperature, 95 | top_p=top_p, 96 | top_k=top_k, 97 | ): 98 | output_text += response.token.text 99 | assistant_conversation.markdown(output_text.strip() + '▌') 100 | 101 | print("\n==Output:==\n", output_text) 102 | content_output, image_output = postprocess_image(output_text, image) 103 | assistant_conversation = Conversation( 104 | role=Role.ASSISTANT, 105 | content=content_output, 106 | image=image_output, 107 | translate=translate 108 | ) 109 | append_conversation( 110 | conversation=assistant_conversation, 111 | history=history, 112 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant") 113 | ) 114 | -------------------------------------------------------------------------------- /composite_demo/demo_chat_cogvlm.py: -------------------------------------------------------------------------------- 1 | import streamlit as st 2 | import base64 3 | import re 4 | 5 | from PIL import Image 6 | from io import BytesIO 7 | from streamlit.delta_generator import DeltaGenerator 8 | from client import get_client 9 | from utils import images_are_same 10 | from conversation import Conversation, Role, postprocess_image, postprocess_text 11 | 12 | client = get_client() 13 | 14 | 15 | def append_conversation( 16 | conversation: Conversation, 17 | history: list[Conversation], 18 | placeholder: DeltaGenerator | None = None, 19 | ) -> None: 20 | history.append(conversation) 21 | conversation.show(placeholder) 22 | 23 | 24 | def main( 25 | top_p: float = 0.8, 26 | temperature: float = 0.95, 27 | prompt_text: str = "", 28 | metadata: str = "", 29 | top_k: int = 2, 30 | max_new_tokens: int = 2048, 31 | grounding: bool = False, 32 | retry: bool = False, 33 | template: str = "", 34 | ): 35 | if 'chat_history' not in st.session_state: 36 | st.session_state.chat_history = [] 37 | 38 | if prompt_text == "" and retry == False: 39 | print("\n== Clean ==\n") 40 | st.session_state.chat_history = [] 41 | return 42 | 43 | history: list[Conversation] = st.session_state.chat_history 44 | for conversation in history: 45 | conversation.show() 46 | if retry: 47 | last_user_conversation_idx = None 48 | for idx, conversation in enumerate(history): 49 | if conversation.role == Role.USER: 50 | last_user_conversation_idx = idx 51 | if last_user_conversation_idx is not None: 52 | prompt_text = history[last_user_conversation_idx].content_show 53 | del history[last_user_conversation_idx:] 54 | 55 | if prompt_text: 56 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None 57 | image.thumbnail((1120, 1120)) 58 | image_input = image 59 | if history and image: 60 | last_user_image = next( 61 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None) 62 | if last_user_image and images_are_same(image, last_user_image): 63 | image_input = None 64 | else: 65 | st.session_state.chat_history = [] 66 | history = [] 67 | 68 | # Set conversation 69 | if re.search('[\u4e00-\u9fff]', 
prompt_text): 70 | translate = True 71 | else: 72 | translate = False 73 | 74 | user_conversation = Conversation( 75 | role=Role.USER, 76 | translate=translate, 77 | content_show=prompt_text.strip() if retry else postprocess_text(template=template, 78 | text=prompt_text.strip()), 79 | image=image_input 80 | ) 81 | append_conversation(user_conversation, history) 82 | placeholder = st.empty() 83 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant") 84 | assistant_conversation = assistant_conversation.empty() 85 | 86 | # steam Answer 87 | output_text = '' 88 | for response in client.generate_stream( 89 | model_use='vlm_grounding' if grounding else 'vlm_chat', 90 | grounding=False, 91 | history=history, 92 | do_sample=True, 93 | max_new_tokens=max_new_tokens, 94 | temperature=temperature, 95 | top_p=top_p, 96 | top_k=top_k, 97 | ): 98 | output_text += response.token.text 99 | assistant_conversation.markdown(output_text.strip() + '▌') 100 | 101 | print("\n==Output:==\n", output_text) 102 | content_output, image_output = postprocess_image(output_text, image) 103 | assistant_conversation = Conversation( 104 | role=Role.ASSISTANT, 105 | content=content_output, 106 | image=image_output, 107 | translate=translate 108 | ) 109 | append_conversation( 110 | conversation=assistant_conversation, 111 | history=history, 112 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant") 113 | ) 114 | -------------------------------------------------------------------------------- /composite_demo/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | This is a demo using the chat version about CogAgent and CogVLM in WebDEMO 3 | 4 | Make sure you have installed the vicuna-7b-v1.5 tokenizer model (https://huggingface.co/lmsys/vicuna-7b-v1.5), 5 | and a full checkpoint of vicuna-7b-v1.5 LLM is not required. 6 | 7 | Mention that only one image can be processed in a conversation, which means you cannot replace or insert another image 8 | during the conversation. 9 | 10 | 11 | The models_info parameter is explained as follows 12 | tokenizer: tokenizer model using vicuna-7b-v1.5 model 13 | agent_chat: Use the CogAgent-chat-18B model to complete the conversation task 14 | vlm_chat: Use the CogVLM-chat-17B model to complete the conversation task 15 | vlm_grounding: Use CogVLM-grounding-17B model to complete the Grounding task 16 | 17 | Web Demo user operation logic is as follows: 18 | CogVLM-Chat -> grounding? - yes -> Choose a template -> CogVLM-grounding-17B 19 | - no -> CogVLM-chat-17B (without grounding) 20 | 21 | CogAgent-Chat -> CogAgent-chat-18B (Only QA,without Grounding) 22 | 23 | CogAgent-Agent -> CogAgent-chat-18B 24 | -> Choose a template -> grounding? - yes -> prompt + (with grounding) 25 | - no -> prompt 26 | 27 | CogAgent-vqa-hf are not included in this demo, but you can use it in the same way as CogAgent-chat-18B 28 | and used it in CogAgent-Chat 29 | """ 30 | 31 | import streamlit as st 32 | 33 | st.set_page_config( 34 | page_title="CogVLM & CogAgent Demo", 35 | page_icon=":robot:", 36 | layout='centered', 37 | initial_sidebar_state='expanded', 38 | ) 39 | 40 | from enum import Enum 41 | from utils import encode_file_to_base64, templates_agent_cogagent, template_grounding_cogvlm 42 | import demo_chat_cogvlm, demo_agent_cogagent, demo_chat_cogagent 43 | 44 | st.markdown("
CogAgent & CogVLM Chat Demo
", unsafe_allow_html=True) 45 | st.markdown( 46 | "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n 请根据文档的引导说明来尝试demo,以便理解demo的布局设计 \n", 47 | unsafe_allow_html=True) 48 | 49 | 50 | class Mode(str, Enum): 51 | CogVLM_Chat, CogAgent_Chat, CogAgent_Agent = '💬CogVLM-Chat', '🧑‍💻 CogAgent-Chat', '💡 CogAgent-Agent' 52 | 53 | 54 | with st.sidebar: 55 | top_p = st.slider( 56 | 'top_p', 0.0, 1.0, 0.8, step=0.01 57 | ) 58 | temperature = st.slider( 59 | 'temperature', 0.01, 1.0, 0.90, step=0.01 60 | ) 61 | top_k = st.slider( 62 | 'top_k', 1, 20, 5, step=1 63 | ) 64 | max_new_token = st.slider( 65 | 'Output length', 1, 2048, 2048, step=1 66 | ) 67 | 68 | uploaded_file = st.file_uploader("Choose an image...", type=['.jpg', '.png', '.jpeg'], accept_multiple_files=False) 69 | 70 | cols = st.columns(2) 71 | export_btn = cols[0] 72 | clear_history = cols[1].button("Clear History", use_container_width=True) 73 | retry = export_btn.button("Retry", use_container_width=True) 74 | 75 | prompt_text = st.chat_input( 76 | 'Chat with CogAgent | CogVLM', 77 | key='chat_input', 78 | ) 79 | 80 | tab = st.radio( 81 | 'Mode', 82 | [mode.value for mode in Mode], 83 | horizontal=True, 84 | label_visibility='hidden', 85 | ) 86 | 87 | selected_template_grounding_cogvlm = "" 88 | with st.sidebar: 89 | grounding = st.checkbox("Grounding") 90 | if tab == Mode.CogVLM_Chat or tab == Mode.CogAgent_Chat: 91 | if grounding: 92 | selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm) 93 | 94 | if tab == Mode.CogAgent_Agent: 95 | with st.sidebar: 96 | selected_template_agent_cogagent = st.selectbox("Template For Agent", templates_agent_cogagent) 97 | 98 | if clear_history or retry: 99 | prompt_text = "" 100 | 101 | match tab: 102 | case Mode.CogVLM_Chat: 103 | st.info("This option uses cogvlm-chat and cogvlm-grounding model.") 104 | if uploaded_file is not None: 105 | demo_chat_cogvlm.main( 106 | retry=retry, 107 | top_p=top_p, 108 | top_k=top_k, 109 | temperature=temperature, 110 | prompt_text=prompt_text, 111 | metadata=encode_file_to_base64(uploaded_file), 112 | max_new_tokens=max_new_token, 113 | grounding=grounding, 114 | template=selected_template_grounding_cogvlm 115 | ) 116 | else: 117 | st.error(f'Please upload an image to start') 118 | 119 | case Mode.CogAgent_Chat: 120 | st.info("This option uses cogagent-chat model.") 121 | if uploaded_file is not None: 122 | demo_chat_cogagent.main( 123 | retry=retry, 124 | top_p=top_p, 125 | top_k=top_k, 126 | temperature=temperature, 127 | prompt_text=prompt_text, 128 | metadata=encode_file_to_base64(uploaded_file), 129 | max_new_tokens=max_new_token, 130 | grounding=grounding, 131 | template=selected_template_grounding_cogvlm 132 | ) 133 | else: 134 | st.error(f'Please upload an image to start') 135 | 136 | case Mode.CogAgent_Agent: 137 | st.info("This option uses cogagent-chat model with agent template.") 138 | if uploaded_file is not None: 139 | demo_agent_cogagent.main( 140 | retry=retry, 141 | top_p=top_p, 142 | top_k=top_k, 143 | temperature=temperature, 144 | prompt_text=prompt_text, 145 | metadata=encode_file_to_base64(uploaded_file), 146 | max_new_tokens=max_new_token, 147 | grounding=grounding, 148 | template=selected_template_agent_cogagent 149 | ) 150 | else: 151 | st.error(f'Please upload an image to start') 152 | case _: 153 | st.error(f'Unexpected tab: {tab}') 154 | -------------------------------------------------------------------------------- /dataset.md: 
-------------------------------------------------------------------------------- 1 | # CogVLM-SFT-311K: Bilingual Visual Instruction Data in CogVLM SFT 2 | 3 | CogVLM-SFT-311K is the primary aligned corpus used in the initial training of CogVLM v1.0. The process of constructing this dataset is as follows: 4 | 1. Approximately 3500 high-quality data samples were selected from the open source [MiniGPT-4](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align), known as minigpt4-3500. 5 | 2. Minigpt4-3500 was integrated with [Llava-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and translated into Chinese through a language model. 6 | 3. We discovered significant noise in the detailed description part of minigpt4-3500 and Llava-instruct. Thus, we corrected these Chinese corpora and retranslated them into English. 7 | 8 | ## License 9 | 10 | + Due to non-commercial agreements, we did not use these data in the bilingual version of CogVLM or any other models involving commercialization. 11 | + The dataset license adheres to:
Attribution-NonCommercial 4.0 International. It should abide by the policy of OpenAI: https://openai.com/policies/terms-of-use 12 | This will not allow you to use these data for any **commercial activitiesI**. 13 | 14 | ## Dataset Address 15 | 16 | + [CogVLM-SFT-311K](https://huggingface.co/datasets/THUDM/CogVLM-SFT-311K) 17 | 18 | ## Dataset Information 19 | 20 | The dataset contains three folders corresponding to the mixed part of minigpt4-3500 and llava, the llava solo conversation, and the multi-turn conversation datasets. Their layout is as follows: 21 | ``` 22 | .CogVLM-SFT-311K 23 | ├── llava_details-minigpt4_3500_formate 24 | ├── llava_instruction_multi_conversations_formate 25 | └── llava_instruction_single_conversation_formate 26 | ``` 27 | In our open-source data, the datasets are distributed as follows: 28 | ``` 29 | .llava_details-minigpt4_3500_formate 30 | ├── images 31 | │ └── 00000001.jpg 32 | └── labels 33 | └── 00000001.json 34 | ``` 35 | Images are stored in the images folder, while labels containing the corresponding image description or dialogue are stored in the labels folder. 36 | 37 | ## Dataset Quantity 38 | 39 | + llava_details-minigpt4_3500_formate: 22,464 images and descriptions 40 | + llava_instruction_multi_conversations_formate: 56,673 images and multi-turn conversations 41 | + llava_instruction_single_conversation_formate: 76,634 images and single-turn conversations 42 | 43 | Dataset Format 44 | 45 | Caption format for image description 46 | ``` 47 | { 48 | "captions": [ 49 | { 50 | "role": "caption", 51 | "content": "The photograph features a beach scene with a group of people surfing in the ocean. There are ten individuals spaced out in the water, sitting or lying on various sizes of body boards. The surfers line up in a row, close to the sandy shoreline. Some of the body boards are visible floating on the surface of the water, with the surfers riding on them or nearby." 52 | } 53 | ] 54 | } 55 | ``` 56 | Conversation format for image dialogue 57 | ``` 58 | { 59 | "conversations": [ 60 | { 61 | "role": "user", 62 | "content": "What can be inferred about the zebras' behavior and surroundings?" 63 | }, 64 | { 65 | "role": "assistant", 66 | "content": "Based on the image, we can infer that the two zebras are likely seeking relief from the sun's heat, as they are standing side by side under the branches of a thorny tree. This shade-providing tree offers some respite from the sun, possibly during the hottest part of the day. The zebras are in a green field with grass, providing them with an ideal environment to graze and eat while staying near their source of shelter. This shows that the zebras' behavior is influenced by the conditions and available resources in their surroundings. It also highlights that these animals adopt strategies to adapt to the fluctuating conditions of their environment, such as cooperation and seeking shelter, to survive and thrive in their natural habitat." 67 | } 68 | ] 69 | } 70 | ``` 71 | 72 | ## References 73 | This project utilizes data and concepts based on the following research papers: 74 | - Zhu, D., Chen, J., Shen, X., Li, X., & Elhoseiny, M. (2023). MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models. arXiv preprint arXiv:2304.10592. 75 | - Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning. arXiv:2304.08485. 
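## Loading Example (Illustrative)

A minimal sketch (not part of the official tooling) showing how one sample from the layout above could be read; the folder and file names simply follow the structure listed under "Dataset Information":

```
import json
from pathlib import Path

from PIL import Image

# One of the three sub-datasets described above.
root = Path("CogVLM-SFT-311K/llava_details-minigpt4_3500_formate")
sample_id = "00000001"

# The image and its matching label share the same file stem.
image = Image.open(root / "images" / f"{sample_id}.jpg").convert("RGB")
with open(root / "labels" / f"{sample_id}.json", encoding="utf-8") as f:
    label = json.load(f)

# Caption-style files carry a "captions" list; dialogue-style files carry
# a "conversations" list of {"role", "content"} turns.
text = label.get("captions", label.get("conversations"))
print(image.size, text)
```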
-------------------------------------------------------------------------------- /dataset_zh.md: -------------------------------------------------------------------------------- 1 | # CogVLM-SFT-311K:CogVLM SFT 中的双语视觉指令数据集 2 | 3 | CogVLM-SFT-311K 是我们在训练 **CogVLM v1.0** 最初版本时使用的主要对齐语料库。此数据集的构建过程如下: 4 | 1. 从开源的 [MiniGPT-4](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) 中选取了大约3500个高质量数据样本,称为 minigpt4-3500。 5 | 2. 将 minigpt4-3500 与 [Llava-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 整合,并通过语言模型翻译获得中文部分。 6 | 3. 我们发现在 minigpt4-3500 和 Llava-instruct 的详细描述部分存在许多噪声。因此,我们纠正了这两部分的中文语料,并将纠正后的语料重新翻译成英语。 7 | 8 | ## 许可证 9 | + 由于非商业协议限制,我们没有在 CogVLM的双语版本 和其他任何 涉及商业化的模型 中使用这些数据。 10 | + 数据集许可证遵守:
Attribution-NonCommercial 4.0 International It should abide by the policy of OpenAI: https://openai.com/policies/terms-of-use 11 | 这将不允许你使用这些数据进行任何 **商业化行为**。 12 | 13 | ## 数据集地址 14 | 15 | + [CogVLM-SFT-311K](https://huggingface.co/datasets/THUDM/CogVLM-SFT-311K) 16 | 17 | ## 数据集信息 18 | 数据集共有三个文件夹,分别对应混合 minigpt4-3500 与llava混合的一部分数据集,llava 单论对话和多轮对话数据集。其布局如下: 19 | ``` 20 | .CogVLM-SFT-311K 21 | ├── llava_details-minigpt4_3500_formate 22 | ├── llava_instruction_multi_conversations_formate 23 | └── llava_instruction_single_conversation_formate 24 | ``` 25 | 在我们开源的数据中,数据集按照以下格式分布 26 | ``` 27 | .llava_details-minigpt4_3500_formate 28 | ├── images 29 | │ └── 00000001.jpg 30 | └── labels 31 | └── 00000001.json 32 | ``` 33 | 其中,images存储图像数据,而labels存储这张图像对应的描述或对话。 34 | 35 | ## 数据集数量 36 | + llava_details-minigpt4_3500_formate 22464 张图片和描述 37 | + llava_instruction_muti_conversations_formate 56673 张图片和多轮连续对话 38 | + llava_instruction_single_conversation_formate 76634 张图片和单轮对话 39 | ## 数据集格式 40 | 41 | 图像描述 Caption 格式 42 | ``` 43 | { 44 | "captions": [ 45 | { 46 | "role": "caption", 47 | "content": "这张照片展示了一男一女,他们都穿着红嘿色衣服,坐在桌子旁,在参加活动时面带微笑。这对情侣正在一起摆姿势,背景中还有其他人。在整个场景中可以看到另外几个人,一些人坐在附近的桌子旁,另一些人坐得离这对夫妇更近。各种餐桌设置,如酒杯和杯子,被放置在桌子上。排在最前面的男子打着领带。" 48 | } 49 | ] 50 | } 51 | ``` 52 | 图像对话 Json 格式 53 | ``` 54 | { 55 | "conversations": [ 56 | { 57 | "role": "user", 58 | "content": "在尝试保持这个浴室有序时可能会遇到哪些挑战?", 59 | }, 60 | { 61 | "role": "assistant", 62 | "content": "在尝试保持这空间和各种物品的存在,例如杯子和手机占用了可用的空间。在图像中,浴室柜台上只有一个水槽可用,这导致多个用户可能必须共享空间放置自己的物品。这可能导致杂乱和混乱外的储物解决方案,如架子、橱柜或墙壁挂架,以保持浴室有序并最小化柜台上的杂乱。" 63 | }, 64 | ] 65 | } 66 | ``` 67 | 68 | ## References 69 | This project utilizes data and concepts based on the following research papers: 70 | - Zhu, D., Chen, J., Shen, X., Li, X., & Elhoseiny, M. (2023). MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models. arXiv preprint arXiv:2304.10592. 71 | - Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning. arXiv:2304.08485. -------------------------------------------------------------------------------- /finetune_demo/evaluate_cogagent.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | # export PATH=/usr/local/cuda/bin:$PATH 3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 4 | 5 | NUM_GPUS_PER_WORKER=8 6 | MP_SIZE=1 7 | 8 | script_path=$(realpath $0) 9 | script_dir=$(dirname $script_path) 10 | main_dir=$(dirname $script_dir) 11 | MODEL_TYPE="cogagent-chat" 12 | VERSION="chat" 13 | # Tips: max_length should be longer than 256, to accomodate low-resolution image tokens 14 | MODEL_ARGS="--from_pretrained ./checkpoints/ft_cogagent_model \ 15 | --max_length 400 \ 16 | --local_tokenizer lmsys/vicuna-7b-v1.5 \ 17 | --version $VERSION" 18 | 19 | OPTIONS_SAT="SAT_HOME=~/.sat_models" 20 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" 21 | HOST_FILE_PATH="hostfile" 22 | 23 | train_data="./archive_split/train" 24 | test_data="./archive_split/test" 25 | 26 | gpt_options=" \ 27 | --experiment-name finetune-$MODEL_TYPE \ 28 | --model-parallel-size ${MP_SIZE} \ 29 | --mode finetune \ 30 | --train-iters 0 \ 31 | --resume-dataloader \ 32 | $MODEL_ARGS \ 33 | --train-data ${train_data} \ 34 | --test-data ${test_data} \ 35 | --distributed-backend nccl \ 36 | --lr-decay-style cosine \ 37 | --warmup .02 \ 38 | --checkpoint-activations \ 39 | --save-interval 200 \ 40 | --eval-interval 200 \ 41 | --save "./checkpoints" \ 42 | --strict-eval \ 43 | --eval-batch-size 1 \ 44 | --split 1. \ 45 | --deepspeed_config test_config_bf16.json \ 46 | --skip-init \ 47 | --seed 2023 48 | " 49 | 50 | 51 | 52 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_cogagent_demo.py ${gpt_options}" 53 | echo ${run_cmd} 54 | eval ${run_cmd} 55 | 56 | set +x -------------------------------------------------------------------------------- /finetune_demo/evaluate_cogagent_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | import sys 5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 6 | 7 | from sat import mpu, get_args, get_tokenizer 8 | from sat.training.deepspeed_training import training_main 9 | from sat.helpers import print_rank0 10 | from collections import defaultdict 11 | from functools import partial 12 | 13 | from utils.models import FineTuneTestCogAgentModel 14 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor 15 | 16 | 17 | def data_collator(examples, cross_image_processor=None): 18 | def to_tensor(value): 19 | """Converts lists or numpy arrays to tensors.""" 20 | if isinstance(value, list): 21 | return torch.tensor(value) 22 | elif isinstance(value, np.ndarray): 23 | return torch.from_numpy(value) 24 | return value 25 | 26 | def concatenate_tensors(attribute, key): 27 | """Concatenates tensors for a specific attribute and key.""" 28 | if attribute is None: 29 | return torch.cat([ex[key] for ex in examples if isinstance(ex[key], torch.Tensor)]) 30 | else: 31 | return torch.cat([ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)]) 32 | 33 | # Convert all lists and numpy arrays in examples to tensors 34 | for example in examples: 35 | for key, value in example.items(): 36 | example[key] = to_tensor(value) 37 | 38 | # Extract and concatenate attributes from examples 39 | img_args = {} 40 | for attribute in ['vision', 'cross']: 41 | if attribute == 'cross' and cross_image_processor is None: 42 | continue 43 | 44 | if attribute in examples[-1]: # Using the 
last example as reference 45 | for key in examples[-1][attribute]: 46 | tensor_key = f"{attribute}_{key}" 47 | tensors_to_concatenate = [ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)] 48 | if tensors_to_concatenate: 49 | img_args[tensor_key] = concatenate_tensors(attribute, key) 50 | else: 51 | img_args[tensor_key] = examples[-1][attribute][key] 52 | 53 | # Remove 'vision' and 'cross' keys from examples 54 | for example in examples: 55 | example.pop('vision', None) 56 | example.pop('cross', None) 57 | 58 | # Create model_args by concatenating tensors and copying other attributes 59 | model_args = {key: concatenate_tensors(None, key) 60 | if isinstance(examples[-1][key], torch.Tensor) else examples[-1][key] 61 | for key in examples[-1] 62 | } 63 | 64 | # Merge img_args into model_args 65 | model_args.update(img_args) 66 | return model_args 67 | 68 | def broadcast_auto(data_dict): 69 | # Classify keys based on their data type 70 | tensor_keys_by_dtype = defaultdict(list) 71 | non_tensor_keys = [] 72 | 73 | for key, value in data_dict.items(): 74 | if isinstance(value, torch.Tensor): 75 | tensor_keys_by_dtype[value.dtype].append(key) 76 | else: 77 | non_tensor_keys.append(key) 78 | 79 | # Broadcast tensor data and collect in a new dictionary 80 | broadcasted_data = {} 81 | for dtype, keys in tensor_keys_by_dtype.items(): 82 | broadcasted_data.update(mpu.broadcast_data(keys, data_dict, dtype)) 83 | 84 | # Add non-tensor data to the new dictionary 85 | for key in non_tensor_keys: 86 | broadcasted_data[key] = data_dict[key] 87 | 88 | return broadcasted_data 89 | 90 | def get_batch(data_iterator, args, timers): 91 | # Broadcast data. 92 | timers('data loader').start() 93 | if data_iterator is not None: 94 | data = next(data_iterator) 95 | else: 96 | data = None 97 | timers('data loader').stop() 98 | data_b = broadcast_auto(data) 99 | for k in data_b: 100 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long: 101 | if args.fp16: 102 | data_b[k] = data_b[k].half() 103 | elif args.bf16: 104 | data_b[k] = data_b[k].bfloat16() 105 | return data_b 106 | 107 | from torch.nn import CrossEntropyLoss 108 | import numpy as np 109 | 110 | from sat.model.mixins import CachedAutoregressiveMixin 111 | from sat.generation.autoregressive_sampling import filling_sequence 112 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy 113 | 114 | 115 | def chat(model, tokenizer, tokens, 116 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs): 117 | inputs = tokens.to(model.parameters().__next__().device)[0] 118 | seq = torch.cat( 119 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0 120 | ) 121 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id]) 122 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id], 123 | # num_beams=num_beams, consider_end=True) 124 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask']) 125 | output = filling_sequence( 126 | model, seq, 127 | batch_size=1, 128 | strategy=strategy, 129 | get_masks_and_position_ids=get_func, 130 | **kwargs 131 | )[0] # drop memory 132 | 133 | return output 134 | 135 | 136 | def forward_step_eval(data_iterator, model, args, timers): 137 | def compute_metrics(eval_preds): 138 | preds, labels, 
device = eval_preds 139 | preds = preds.unsqueeze(0) 140 | if isinstance(preds, tuple): 141 | preds = preds[0] 142 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 143 | if args.ignore_pad_token_for_loss: 144 | # Replace -100 in the labels as we can't decode them. 145 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 146 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 147 | 148 | score_dict = { 149 | "acc": [], 150 | "acc_w/o_case": [], 151 | } 152 | for pred, label in zip(decoded_preds, decoded_labels): 153 | if args.rank == 0: 154 | print('pred', pred, 'label', label, flush=True) 155 | if pred == label: 156 | score_dict['acc'].append(1.) 157 | else: 158 | score_dict['acc'].append(0.) 159 | if pred.lower() == label.lower(): 160 | score_dict['acc_w/o_case'].append(1.) 161 | else: 162 | score_dict['acc_w/o_case'].append(0.) 163 | 164 | 165 | for k, v in score_dict.items(): 166 | score_dict[k] = float(np.mean(v)) 167 | return score_dict 168 | 169 | # Get the batch. 170 | timers('batch generator').start() 171 | data_b = get_batch( 172 | data_iterator, args, timers) 173 | timers('batch generator').stop() 174 | 175 | context_len = int(data_b['context_length'][0]) 176 | tokens = data_b['input_ids'][:, :context_len] 177 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len] 178 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len] 179 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len] 180 | 181 | data_b.pop('input_ids') 182 | data_b.pop('attention_mask') 183 | data_b.pop('position_ids') 184 | labels = data_b.pop('labels') 185 | qid = data_b.pop('question_id') 186 | 187 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 188 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:] 189 | # print(outputs) 190 | model.del_mixin('auto-regressive') 191 | 192 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in 193 | compute_metrics( 194 | (outputs.cpu(), labels.cpu(), outputs.device)).items()} 195 | 196 | 197 | from torch.nn import CrossEntropyLoss 198 | def forward_step(data_iterator, model, args, timers): 199 | """Forward step.""" 200 | 201 | # Get the batch. 
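    # get_batch() broadcasts the collated tensors to every model-parallel rank and
    # casts floating-point tensors to fp16/bf16 according to the training args.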
202 | timers('batch generator').start() 203 | data_b = get_batch( 204 | data_iterator, args, timers) 205 | labels = data_b.pop('labels') 206 | timers('batch generator').stop() 207 | logits = model(**data_b)[0] 208 | lm_logits = logits.to(torch.float32) 209 | # Shift so that tokens < n predict n 210 | shift_labels = labels[..., 1:].contiguous() 211 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous() 212 | # Flatten the tokens 213 | loss_fct = CrossEntropyLoss(ignore_index=-100) 214 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 215 | loss = loss.to(torch.float32) 216 | 217 | return loss, {'loss': loss} 218 | 219 | from utils.utils import ItemDataset 220 | def create_dataset_function(image_processor, text_processor, cross_image_processor, path, args): 221 | dataset = ItemDataset(image_processor, text_processor, args, path, cross_image_processor=cross_image_processor) 222 | return dataset 223 | 224 | if __name__ == '__main__': 225 | py_parser = argparse.ArgumentParser(add_help=False) 226 | py_parser.add_argument('--max_length', type=int) 227 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false') 228 | py_parser.add_argument("--version", type=str, default="chat", help='version to interact with') 229 | py_parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt') 230 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 231 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true') 232 | py_parser = FineTuneTestCogAgentModel.add_model_specific_args(py_parser) 233 | known, args_list = py_parser.parse_known_args() 234 | args = get_args(args_list) 235 | args = argparse.Namespace(**vars(args), **vars(known)) 236 | if args.use_qlora: 237 | args.device = 'cpu' 238 | 239 | model, args = FineTuneTestCogAgentModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {}) 240 | if args.use_qlora and torch.cuda.is_available(): 241 | model = model.to('cuda') 242 | from utils.utils import llama2_tokenizer 243 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 244 | image_processor = get_image_processor(args.eva_args["image_size"][0]) 245 | cross_image_processor = get_image_processor(args.cross_image_pix) 246 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) 247 | 248 | training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor, cross_image_processor), collate_fn=partial(data_collator, cross_image_processor=cross_image_processor), forward_step_eval=forward_step_eval) -------------------------------------------------------------------------------- /finetune_demo/evaluate_cogvlm.sh: -------------------------------------------------------------------------------- 1 | #! 
/bin/bash 2 | # export PATH=/usr/local/cuda/bin:$PATH 3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 4 | 5 | NUM_GPUS_PER_WORKER=8 6 | MP_SIZE=1 7 | 8 | script_path=$(realpath $0) 9 | script_dir=$(dirname $script_path) 10 | main_dir=$(dirname $script_dir) 11 | MODEL_TYPE="cogvlm-base-490" 12 | VERSION="base" 13 | MODEL_ARGS="--from_pretrained ./checkpoints/merged_lora_490 \ 14 | --max_length 1288 \ 15 | --lora_rank 10 \ 16 | --use_lora \ 17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \ 18 | --version $VERSION" 19 | # Tips: If training models of resolution 244, you can set --max_length smaller 20 | 21 | 22 | OPTIONS_SAT="SAT_HOME=~/.sat_models" 23 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" 24 | HOST_FILE_PATH="hostfile" 25 | 26 | train_data="./archive_split/train" 27 | test_data="./archive_split/test" 28 | 29 | gpt_options=" \ 30 | --experiment-name finetune-$MODEL_TYPE \ 31 | --model-parallel-size ${MP_SIZE} \ 32 | --mode finetune \ 33 | --train-iters 0 \ 34 | --resume-dataloader \ 35 | $MODEL_ARGS \ 36 | --train-data ${train_data} \ 37 | --test-data ${test_data} \ 38 | --distributed-backend nccl \ 39 | --lr-decay-style cosine \ 40 | --warmup .02 \ 41 | --checkpoint-activations \ 42 | --save-interval 200 \ 43 | --eval-interval 200 \ 44 | --save "./checkpoints" \ 45 | --strict-eval \ 46 | --eval-batch-size 1 \ 47 | --split 1. \ 48 | --deepspeed_config test_config_bf16.json \ 49 | --skip-init \ 50 | --seed 2023 51 | " 52 | 53 | 54 | 55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_cogvlm_demo.py ${gpt_options}" 56 | echo ${run_cmd} 57 | eval ${run_cmd} 58 | 59 | set +x -------------------------------------------------------------------------------- /finetune_demo/evaluate_cogvlm_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from functools import partial 5 | import sys 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | from sat import mpu, get_args, get_tokenizer 9 | from sat.training.deepspeed_training import training_main 10 | from sat.helpers import print_rank0 11 | from utils.models import FineTuneTestCogVLMModel 12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor 13 | 14 | 15 | def data_collator(examples): 16 | examples = [ex for ex in examples if len(ex) > 0] # drop {} 17 | for example in examples: 18 | for k in example: 19 | if isinstance(example[k], list): 20 | example[k] = torch.tensor(example[k]) 21 | elif isinstance(example[k], np.ndarray): 22 | example[k] = torch.from_numpy(example[k]) 23 | img_args = {} 24 | tmp_example = examples[0] 25 | for k in tmp_example['vision']: 26 | if type(tmp_example['vision'][k]) is torch.Tensor: 27 | img_args['vision_'+k] = torch.cat([example['vision'][k] for example in examples]) 28 | else: 29 | img_args['vision_'+k] = example['vision'][k] 30 | for example in examples: 31 | example.pop('vision') 32 | if 'cross' in example: 33 | example.pop('cross') 34 | 35 | model_args = {} 36 | tmp_example = examples[0] 37 | for k in tmp_example: 38 | if type(tmp_example[k]) is torch.Tensor: 39 | model_args[k] = torch.cat([example[k] for example in examples]) 40 | else: 41 | model_args[k] = tmp_example[k] 42 | model_args.update(img_args) 43 | return model_args 44 | 45 | from collections import defaultdict 46 | 47 | def 
broadcast_auto(data_dict): 48 | type2list = defaultdict(list) 49 | other = [] 50 | for k in data_dict: 51 | if type(data_dict[k]) is torch.Tensor: 52 | type2list[data_dict[k].dtype].append(k) 53 | else: 54 | other.append(k) 55 | new_data = {} 56 | for k in type2list: 57 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k)) 58 | for k in other: 59 | new_data[k] = data_dict[k] 60 | return new_data 61 | 62 | def get_batch(data_iterator, args, timers): 63 | # Broadcast data. 64 | timers('data loader').start() 65 | if data_iterator is not None: 66 | data = next(data_iterator) 67 | else: 68 | data = None 69 | timers('data loader').stop() 70 | data_b = broadcast_auto(data) 71 | for k in data_b: 72 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long: 73 | if args.fp16: 74 | data_b[k] = data_b[k].half() 75 | elif args.bf16: 76 | data_b[k] = data_b[k].bfloat16() 77 | return data_b 78 | 79 | from torch.nn import CrossEntropyLoss 80 | import numpy as np 81 | 82 | from sat.model.mixins import CachedAutoregressiveMixin 83 | from sat.generation.autoregressive_sampling import filling_sequence 84 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy 85 | 86 | 87 | def chat(model, tokenizer, tokens, 88 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs): 89 | inputs = tokens.to(model.parameters().__next__().device)[0] 90 | seq = torch.cat( 91 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0 92 | ) 93 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id]) 94 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id], 95 | # num_beams=num_beams, consider_end=True) 96 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask']) 97 | output = filling_sequence( 98 | model, seq, 99 | batch_size=1, 100 | strategy=strategy, 101 | get_masks_and_position_ids=get_func, 102 | **kwargs 103 | )[0] # drop memory 104 | 105 | return output 106 | 107 | 108 | def forward_step_eval(data_iterator, model, args, timers): 109 | def compute_metrics(eval_preds): 110 | preds, labels, device = eval_preds 111 | preds = preds.unsqueeze(0) 112 | if isinstance(preds, tuple): 113 | preds = preds[0] 114 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 115 | if args.ignore_pad_token_for_loss: 116 | # Replace -100 in the labels as we can't decode them. 117 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 118 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 119 | 120 | score_dict = { 121 | "acc": [], 122 | "acc_w/o_case": [], 123 | } 124 | for pred, label in zip(decoded_preds, decoded_labels): 125 | if args.rank == 0: 126 | print('pred', pred, 'label', label, flush=True) 127 | if pred == label: 128 | score_dict['acc'].append(1.) 129 | else: 130 | score_dict['acc'].append(0.) 131 | if pred.lower() == label.lower(): 132 | score_dict['acc_w/o_case'].append(1.) 133 | else: 134 | score_dict['acc_w/o_case'].append(0.) 135 | 136 | 137 | for k, v in score_dict.items(): 138 | score_dict[k] = float(np.mean(v)) 139 | return score_dict 140 | 141 | # Get the batch. 
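    # Evaluation keeps only the prompt context from the batch, generates the answer
    # autoregressively via chat(), and compares the decoded output against the label.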
142 | timers('batch generator').start() 143 | data_b = get_batch( 144 | data_iterator, args, timers) 145 | timers('batch generator').stop() 146 | 147 | context_len = int(data_b['context_length'][0]) 148 | tokens = data_b['input_ids'][:, :context_len] 149 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len] 150 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len] 151 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len] 152 | 153 | data_b.pop('input_ids') 154 | data_b.pop('attention_mask') 155 | data_b.pop('position_ids') 156 | labels = data_b.pop('labels') 157 | qid = data_b.pop('question_id') 158 | 159 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 160 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:] 161 | # print(outputs) 162 | model.del_mixin('auto-regressive') 163 | 164 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in 165 | compute_metrics( 166 | (outputs.cpu(), labels.cpu(), outputs.device)).items()} 167 | 168 | 169 | from torch.nn import CrossEntropyLoss 170 | def forward_step(data_iterator, model, args, timers): 171 | """Forward step.""" 172 | 173 | # Get the batch. 174 | timers('batch generator').start() 175 | data_b = get_batch( 176 | data_iterator, args, timers) 177 | labels = data_b.pop('labels') 178 | timers('batch generator').stop() 179 | logits = model(**data_b)[0] 180 | lm_logits = logits.to(torch.float32) 181 | # Shift so that tokens < n predict n 182 | shift_labels = labels[..., 1:].contiguous() 183 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous() 184 | # Flatten the tokens 185 | loss_fct = CrossEntropyLoss(ignore_index=-100) 186 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 187 | loss = loss.to(torch.float32) 188 | 189 | return loss, {'loss': loss} 190 | 191 | from utils.utils import ItemDataset 192 | def create_dataset_function(image_processor, text_processor, path, args): 193 | dataset = ItemDataset(image_processor, text_processor, args, path) 194 | return dataset 195 | 196 | if __name__ == '__main__': 197 | py_parser = argparse.ArgumentParser(add_help=False) 198 | py_parser.add_argument('--max_length', type=int) 199 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false') 200 | py_parser.add_argument("--version", type=str, default="chat", help='version to interact with') 201 | py_parser.add_argument("--from_pretrained", type=str, default="cogvlm-chat", help='pretrained ckpt') 202 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 203 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true') 204 | py_parser = FineTuneTestCogVLMModel.add_model_specific_args(py_parser) 205 | known, args_list = py_parser.parse_known_args() 206 | args = get_args(args_list) 207 | args = argparse.Namespace(**vars(args), **vars(known)) 208 | if args.use_qlora: 209 | args.device = 'cpu' 210 | 211 | model, args = FineTuneTestCogVLMModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {}) 212 | if args.use_qlora and torch.cuda.is_available(): 213 | model = model.to('cuda') 214 | from utils.utils import llama2_tokenizer 215 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 216 | image_processor = 
get_image_processor(args.eva_args["image_size"][0]) 217 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) 218 | 219 | training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval) -------------------------------------------------------------------------------- /finetune_demo/finetune_cogagent_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from functools import partial 5 | import sys 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | from sat import mpu, get_args, get_tokenizer 9 | from sat.training.deepspeed_training import training_main 10 | from sat.helpers import print_rank0 11 | from utils.models import FineTuneTrainCogAgentModel 12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor 13 | 14 | def disable_untrainable_params(self): 15 | total_trainable = 0 16 | # enable = ['vit'] 17 | enable = ["encoder", "cross_attention", "linear_proj", 'mlp.vision', 'rotary.vision', 'eoi', 'boi', 'vit'] 18 | if self.args.use_ptuning: 19 | enable.extend(['ptuning']) 20 | if self.args.use_lora or self.args.use_qlora: 21 | enable.extend(['matrix_A', 'matrix_B']) 22 | for n, p in self.named_parameters(): 23 | flag = False 24 | for e in enable: 25 | if type(e) is tuple: 26 | if e[0].lower() in n.lower() and e[1].lower() in n.lower() and 55 > int(n[:n.find('.mlp')].split('.')[-1]) > 45: 27 | flag = True 28 | break 29 | else: 30 | if e.lower() in n.lower(): 31 | flag = True 32 | break 33 | if not flag: 34 | p.requires_grad_(False) 35 | else: 36 | total_trainable += p.numel() 37 | if 'encoder' in n or 'vit' in n: 38 | p.lr_scale = 0.1 39 | print_rank0(n) 40 | print_rank0("***** Total trainable parameters: "+str(total_trainable)+" *****") 41 | 42 | FineTuneTrainCogAgentModel.disable_untrainable_params = disable_untrainable_params 43 | 44 | def data_collator(examples, cross_image_processor=None): 45 | def to_tensor(value): 46 | """Converts lists or numpy arrays to tensors.""" 47 | if isinstance(value, list): 48 | return torch.tensor(value) 49 | elif isinstance(value, np.ndarray): 50 | return torch.from_numpy(value) 51 | return value 52 | 53 | def concatenate_tensors(attribute, key): 54 | """Concatenates tensors for a specific attribute and key.""" 55 | if attribute is None: 56 | return torch.cat([ex[key] for ex in examples if isinstance(ex[key], torch.Tensor)]) 57 | else: 58 | return torch.cat([ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)]) 59 | 60 | # Convert all lists and numpy arrays in examples to tensors 61 | for example in examples: 62 | for key, value in example.items(): 63 | example[key] = to_tensor(value) 64 | 65 | # Extract and concatenate attributes from examples 66 | img_args = {} 67 | for attribute in ['vision', 'cross']: 68 | if attribute == 'cross' and cross_image_processor is None: 69 | continue 70 | 71 | if attribute in examples[-1]: # Using the last example as reference 72 | for key in examples[-1][attribute]: 73 | tensor_key = f"{attribute}_{key}" 74 | tensors_to_concatenate = [ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)] 75 | if tensors_to_concatenate: 76 | img_args[tensor_key] = concatenate_tensors(attribute, key) 77 | else: 
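                    # No tensors to concatenate for this key: fall back to the raw
                    # value taken from the last example (plain, non-tensor metadata).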
78 | img_args[tensor_key] = examples[-1][attribute][key] 79 | 80 | # Remove 'vision' and 'cross' keys from examples 81 | for example in examples: 82 | example.pop('vision', None) 83 | example.pop('cross', None) 84 | 85 | # Create model_args by concatenating tensors and copying other attributes 86 | model_args = {key: concatenate_tensors(None, key) 87 | if isinstance(examples[-1][key], torch.Tensor) else examples[-1][key] 88 | for key in examples[-1] 89 | } 90 | 91 | # Merge img_args into model_args 92 | model_args.update(img_args) 93 | return model_args 94 | 95 | 96 | from collections import defaultdict 97 | 98 | def broadcast_auto(data_dict): 99 | type2list = defaultdict(list) 100 | other = [] 101 | for k in data_dict: 102 | if type(data_dict[k]) is torch.Tensor: 103 | type2list[data_dict[k].dtype].append(k) 104 | else: 105 | other.append(k) 106 | new_data = {} 107 | for k in type2list: 108 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k)) 109 | for k in other: 110 | new_data[k] = data_dict[k] 111 | return new_data 112 | 113 | def get_batch(data_iterator, args, timers): 114 | # Broadcast data. 115 | timers('data loader').start() 116 | if data_iterator is not None: 117 | data = next(data_iterator) 118 | else: 119 | data = None 120 | timers('data loader').stop() 121 | data_b = broadcast_auto(data) 122 | for k in data_b: 123 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long: 124 | if args.fp16: 125 | data_b[k] = data_b[k].half() 126 | elif args.bf16: 127 | data_b[k] = data_b[k].bfloat16() 128 | return data_b 129 | 130 | from torch.nn import CrossEntropyLoss 131 | import numpy as np 132 | 133 | from sat.model.mixins import CachedAutoregressiveMixin 134 | from sat.generation.autoregressive_sampling import filling_sequence 135 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy 136 | 137 | 138 | def chat(model, tokenizer, tokens, 139 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs): 140 | inputs = tokens.to(model.parameters().__next__().device)[0] 141 | seq = torch.cat( 142 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0 143 | ) 144 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id]) 145 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id], 146 | # num_beams=num_beams, consider_end=True) 147 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask']) 148 | output = filling_sequence( 149 | model, seq, 150 | batch_size=1, 151 | strategy=strategy, 152 | get_masks_and_position_ids=get_func, 153 | **kwargs 154 | )[0] # drop memory 155 | 156 | return output 157 | 158 | 159 | def forward_step_eval(data_iterator, model, args, timers): 160 | def compute_metrics(eval_preds): 161 | preds, labels, device = eval_preds 162 | preds = preds.unsqueeze(0) 163 | if isinstance(preds, tuple): 164 | preds = preds[0] 165 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 166 | if args.ignore_pad_token_for_loss: 167 | # Replace -100 in the labels as we can't decode them. 
168 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 169 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 170 | 171 | score_dict = { 172 | "acc": [], 173 | "acc_w/o_case": [], 174 | } 175 | for pred, label in zip(decoded_preds, decoded_labels): 176 | if args.rank == 0: 177 | print('pred', pred, 'label', label, flush=True) 178 | if pred == label: 179 | score_dict['acc'].append(1.) 180 | else: 181 | score_dict['acc'].append(0.) 182 | if pred.lower() == label.lower(): 183 | score_dict['acc_w/o_case'].append(1.) 184 | else: 185 | score_dict['acc_w/o_case'].append(0.) 186 | 187 | 188 | for k, v in score_dict.items(): 189 | score_dict[k] = float(np.mean(v)) 190 | return score_dict 191 | 192 | # Get the batch. 193 | timers('batch generator').start() 194 | data_b = get_batch( 195 | data_iterator, args, timers) 196 | timers('batch generator').stop() 197 | 198 | context_len = int(data_b['context_length'][0]) 199 | tokens = data_b['input_ids'][:, :context_len] 200 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len] 201 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len] 202 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len] 203 | 204 | data_b.pop('input_ids') 205 | data_b.pop('attention_mask') 206 | data_b.pop('position_ids') 207 | labels = data_b.pop('labels') 208 | qid = data_b.pop('question_id') 209 | 210 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 211 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:] 212 | # print(outputs) 213 | model.del_mixin('auto-regressive') 214 | 215 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in 216 | compute_metrics( 217 | (outputs.cpu(), labels.cpu(), outputs.device)).items()} 218 | 219 | 220 | from torch.nn import CrossEntropyLoss 221 | def forward_step(data_iterator, model, args, timers): 222 | """Forward step.""" 223 | 224 | # Get the batch. 
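# Note on the loss computed further below in this function (descriptive only): it is standard next-token
# prediction. shift_labels = labels[..., 1:] drops the first label, and the matching logits are taken from
# the tail of the sequence offset by one position, so the logit at each position is trained to predict the
# following token; positions labelled -100 are ignored via CrossEntropyLoss(ignore_index=-100).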
225 | timers('batch generator').start() 226 | data_b = get_batch( 227 | data_iterator, args, timers) 228 | labels = data_b.pop('labels') 229 | timers('batch generator').stop() 230 | logits = model(**data_b)[0] 231 | lm_logits = logits.to(torch.float32) 232 | # Shift so that tokens < n predict n 233 | shift_labels = labels[..., 1:].contiguous() 234 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous() 235 | # Flatten the tokens 236 | loss_fct = CrossEntropyLoss(ignore_index=-100) 237 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 238 | loss = loss.to(torch.float32) 239 | 240 | return loss, {'loss': loss} 241 | 242 | from utils.utils import ItemDataset 243 | def create_dataset_function(image_processor, text_processor, cross_image_processor, path, args): 244 | dataset = ItemDataset(image_processor, text_processor, args, path, cross_image_processor=cross_image_processor) 245 | return dataset 246 | 247 | from sat.model.finetune.lora2 import LoraMixin 248 | from sat.model.finetune.prompt_tuning import PTuningV2Mixin 249 | 250 | if __name__ == '__main__': 251 | py_parser = argparse.ArgumentParser(add_help=False) 252 | py_parser.add_argument('--max_length', type=int) 253 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false') 254 | py_parser.add_argument("--version", type=str, default="chat", choices=["chat", "vqa"], help='version to interact with') 255 | py_parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt') 256 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 257 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true') 258 | py_parser = FineTuneTrainCogAgentModel.add_model_specific_args(py_parser) 259 | known, args_list = py_parser.parse_known_args() 260 | args = get_args(args_list) 261 | args = argparse.Namespace(**vars(args), **vars(known)) 262 | if args.use_qlora: 263 | args.device = 'cpu' 264 | 265 | model, args = FineTuneTrainCogAgentModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {}) 266 | if args.use_ptuning: # TODO: wait for SAT updating 267 | model.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len)) 268 | 269 | if args.use_lora: 270 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True) 271 | model.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True) 272 | elif args.use_qlora: 273 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True) 274 | 275 | if args.use_qlora and torch.cuda.is_available(): 276 | model = model.to('cuda') 277 | from utils.utils import llama2_tokenizer 278 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 279 | image_processor = get_image_processor(args.eva_args["image_size"][0]) 280 | cross_image_processor = get_image_processor(args.cross_image_pix) 281 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) 282 | 283 | model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor, 
cross_image_processor), collate_fn=partial(data_collator, cross_image_processor=cross_image_processor), forward_step_eval=forward_step_eval) 284 | if args.use_lora: 285 | model.get_mixin("lora").merge_lora() 286 | model.get_mixin("eva").vit_model.get_mixin("lora").merge_lora() 287 | args.use_lora = False 288 | args.save = "checkpoints/merged_lora_cogagent" 289 | from sat.training.model_io import save_checkpoint 290 | save_checkpoint(1, model, None, None, args) -------------------------------------------------------------------------------- /finetune_demo/finetune_cogagent_lora.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # export PATH=/usr/local/cuda/bin:$PATH 3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 4 | 5 | NUM_GPUS_PER_WORKER=8 6 | MP_SIZE=1 7 | 8 | script_path=$(realpath $0) 9 | script_dir=$(dirname $script_path) 10 | main_dir=$(dirname $script_dir) 11 | MODEL_TYPE="cogagent-chat" 12 | VERSION="chat" 13 | MODEL_ARGS="--from_pretrained $MODEL_TYPE \ 14 | --max_length 400 \ 15 | --lora_rank 50 \ 16 | --use_lora \ 17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \ 18 | --version $VERSION" 19 | # TIPS: max_length includes the low-resolution image sequence (which has 256 tokens) 20 | 21 | OPTIONS_SAT="SAT_HOME=~/.sat_models" 22 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" 23 | HOST_FILE_PATH="hostfile" 24 | 25 | train_data="./archive_split/train" 26 | valid_data="./archive_split/valid" 27 | 28 | gpt_options=" \ 29 | --experiment-name finetune-$MODEL_TYPE \ 30 | --model-parallel-size ${MP_SIZE} \ 31 | --mode finetune \ 32 | --train-iters 2000 \ 33 | --resume-dataloader \ 34 | $MODEL_ARGS \ 35 | --train-data ${train_data} \ 36 | --valid-data ${valid_data} \ 37 | --distributed-backend nccl \ 38 | --lr-decay-style cosine \ 39 | --warmup .02 \ 40 | --checkpoint-activations \ 41 | --vit_checkpoint_activations \ 42 | --save-interval 200 \ 43 | --eval-interval 200 \ 44 | --save "./checkpoints" \ 45 | --eval-iters 10 \ 46 | --eval-batch-size 1 \ 47 | --split 1.
\ 48 | --deepspeed_config test_config_bf16.json \ 49 | --skip-init \ 50 | --seed 2023 51 | " 52 | 53 | 54 | 55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogagent_demo.py ${gpt_options}" 56 | echo ${run_cmd} 57 | eval ${run_cmd} 58 | 59 | set +x -------------------------------------------------------------------------------- /finetune_demo/finetune_cogvlm_demo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import argparse 4 | from functools import partial 5 | import sys 6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 7 | 8 | from sat import mpu, get_args, get_tokenizer 9 | from sat.training.deepspeed_training import training_main 10 | from sat.helpers import print_rank0 11 | from utils.models import FineTuneTrainCogVLMModel 12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor 13 | 14 | def disable_untrainable_params(self): 15 | total_trainable = 0 16 | enable = [('mlp', 'vit')] 17 | if self.args.use_ptuning: 18 | enable.extend(['ptuning']) 19 | if self.args.use_lora or self.args.use_qlora: 20 | enable.extend(['matrix_A', 'matrix_B']) 21 | for n, p in self.named_parameters(): 22 | flag = False 23 | for e in enable: 24 | if type(e) is tuple: 25 | if e[0].lower() in n.lower() and e[1].lower() in n.lower() and 55 > int(n[:n.find('.mlp')].split('.')[-1]) > 45: 26 | flag = True 27 | break 28 | else: 29 | if e.lower() in n.lower(): 30 | flag = True 31 | break 32 | if not flag: 33 | p.requires_grad_(False) 34 | else: 35 | total_trainable += p.numel() 36 | print_rank0(n) 37 | print_rank0("***** Total trainable parameters: "+str(total_trainable)+" *****") 38 | 39 | FineTuneTrainCogVLMModel.disable_untrainable_params = disable_untrainable_params 40 | 41 | def data_collator(examples): 42 | examples = [ex for ex in examples if len(ex) > 0] # drop {} 43 | for example in examples: 44 | for k in example: 45 | if isinstance(example[k], list): 46 | example[k] = torch.tensor(example[k]) 47 | elif isinstance(example[k], np.ndarray): 48 | example[k] = torch.from_numpy(example[k]) 49 | img_args = {} 50 | tmp_example = examples[0] 51 | for k in tmp_example['vision']: 52 | if type(tmp_example['vision'][k]) is torch.Tensor: 53 | img_args['vision_'+k] = torch.cat([example['vision'][k] for example in examples]) 54 | else: 55 | img_args['vision_'+k] = example['vision'][k] 56 | for example in examples: 57 | example.pop('vision') 58 | if 'cross' in example: 59 | example.pop('cross') 60 | 61 | model_args = {} 62 | tmp_example = examples[0] 63 | for k in tmp_example: 64 | if type(tmp_example[k]) is torch.Tensor: 65 | model_args[k] = torch.cat([example[k] for example in examples]) 66 | else: 67 | model_args[k] = tmp_example[k] 68 | model_args.update(img_args) 69 | return model_args 70 | 71 | from collections import defaultdict 72 | 73 | def broadcast_auto(data_dict): 74 | type2list = defaultdict(list) 75 | other = [] 76 | for k in data_dict: 77 | if type(data_dict[k]) is torch.Tensor: 78 | type2list[data_dict[k].dtype].append(k) 79 | else: 80 | other.append(k) 81 | new_data = {} 82 | for k in type2list: 83 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k)) 84 | for k in other: 85 | new_data[k] = data_dict[k] 86 | return new_data 87 | 88 | def get_batch(data_iterator, args, timers): 89 | # Broadcast data. 
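# What follows: fetch the next batch from the data iterator (on ranks that have one), broadcast every
# tensor across the model-parallel group via mpu.broadcast_data (broadcast_auto groups keys by dtype),
# then cast floating-point tensors to fp16 or bf16 according to args.fp16 / args.bf16.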
90 | timers('data loader').start() 91 | if data_iterator is not None: 92 | data = next(data_iterator) 93 | else: 94 | data = None 95 | timers('data loader').stop() 96 | data_b = broadcast_auto(data) 97 | for k in data_b: 98 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long: 99 | if args.fp16: 100 | data_b[k] = data_b[k].half() 101 | elif args.bf16: 102 | data_b[k] = data_b[k].bfloat16() 103 | return data_b 104 | 105 | from torch.nn import CrossEntropyLoss 106 | import numpy as np 107 | 108 | from sat.model.mixins import CachedAutoregressiveMixin 109 | from sat.generation.autoregressive_sampling import filling_sequence 110 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy 111 | 112 | 113 | def chat(model, tokenizer, tokens, 114 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs): 115 | inputs = tokens.to(model.parameters().__next__().device)[0] 116 | seq = torch.cat( 117 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0 118 | ) 119 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id]) 120 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id], 121 | # num_beams=num_beams, consider_end=True) 122 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask']) 123 | output = filling_sequence( 124 | model, seq, 125 | batch_size=1, 126 | strategy=strategy, 127 | get_masks_and_position_ids=get_func, 128 | **kwargs 129 | )[0] # drop memory 130 | 131 | return output 132 | 133 | 134 | def forward_step_eval(data_iterator, model, args, timers): 135 | def compute_metrics(eval_preds): 136 | preds, labels, device = eval_preds 137 | preds = preds.unsqueeze(0) 138 | if isinstance(preds, tuple): 139 | preds = preds[0] 140 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True) 141 | if args.ignore_pad_token_for_loss: 142 | # Replace -100 in the labels as we can't decode them. 143 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id) 144 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True) 145 | 146 | score_dict = { 147 | "acc": [], 148 | "acc_w/o_case": [], 149 | } 150 | for pred, label in zip(decoded_preds, decoded_labels): 151 | if args.rank == 0: 152 | print('pred', pred, 'label', label, flush=True) 153 | if pred == label: 154 | score_dict['acc'].append(1.) 155 | else: 156 | score_dict['acc'].append(0.) 157 | if pred.lower() == label.lower(): 158 | score_dict['acc_w/o_case'].append(1.) 159 | else: 160 | score_dict['acc_w/o_case'].append(0.) 161 | 162 | 163 | for k, v in score_dict.items(): 164 | score_dict[k] = float(np.mean(v)) 165 | return score_dict 166 | 167 | # Get the batch. 
168 | timers('batch generator').start() 169 | data_b = get_batch( 170 | data_iterator, args, timers) 171 | timers('batch generator').stop() 172 | 173 | context_len = int(data_b['context_length'][0]) 174 | tokens = data_b['input_ids'][:, :context_len] 175 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len] 176 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len] 177 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len] 178 | 179 | data_b.pop('input_ids') 180 | data_b.pop('attention_mask') 181 | data_b.pop('position_ids') 182 | labels = data_b.pop('labels') 183 | qid = data_b.pop('question_id') 184 | 185 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin()) 186 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:] 187 | # print(outputs) 188 | model.del_mixin('auto-regressive') 189 | 190 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in 191 | compute_metrics( 192 | (outputs.cpu(), labels.cpu(), outputs.device)).items()} 193 | 194 | 195 | from torch.nn import CrossEntropyLoss 196 | def forward_step(data_iterator, model, args, timers): 197 | """Forward step.""" 198 | 199 | # Get the batch. 200 | timers('batch generator').start() 201 | data_b = get_batch( 202 | data_iterator, args, timers) 203 | labels = data_b.pop('labels') 204 | timers('batch generator').stop() 205 | logits = model(**data_b)[0] 206 | lm_logits = logits.to(torch.float32) 207 | # Shift so that tokens < n predict n 208 | shift_labels = labels[..., 1:].contiguous() 209 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous() 210 | # Flatten the tokens 211 | loss_fct = CrossEntropyLoss(ignore_index=-100) 212 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) 213 | loss = loss.to(torch.float32) 214 | 215 | return loss, {'loss': loss} 216 | 217 | from utils.utils import ItemDataset 218 | def create_dataset_function(image_processor, text_processor, path, args): 219 | dataset = ItemDataset(image_processor, text_processor, args, path) 220 | return dataset 221 | 222 | from sat.model.finetune.lora2 import LoraMixin 223 | from sat.model.finetune.prompt_tuning import PTuningV2Mixin 224 | 225 | if __name__ == '__main__': 226 | py_parser = argparse.ArgumentParser(add_help=False) 227 | py_parser.add_argument('--max_length', type=int) 228 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false') 229 | py_parser.add_argument("--version", type=str, default="chat_old", help='version to interact with') 230 | py_parser.add_argument("--from_pretrained", type=str, default="cogvlm-chat", help='pretrained ckpt') 231 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path') 232 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true') 233 | py_parser = FineTuneTrainCogVLMModel.add_model_specific_args(py_parser) 234 | known, args_list = py_parser.parse_known_args() 235 | args = get_args(args_list) 236 | args = argparse.Namespace(**vars(args), **vars(known)) 237 | if args.use_qlora: 238 | args.device = 'cpu' 239 | 240 | model, args = FineTuneTrainCogVLMModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {}) 241 | if args.use_ptuning: 242 | model.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, 
args.num_attention_heads, args.pre_seq_len)) 243 | if args.use_lora: 244 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True) 245 | model.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True) 246 | elif args.use_qlora: 247 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True) 248 | 249 | if args.use_qlora and torch.cuda.is_available(): 250 | model = model.to('cuda') 251 | from utils.utils import llama2_tokenizer 252 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version) 253 | image_processor = get_image_processor(args.eva_args["image_size"][0]) 254 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length) 255 | 256 | model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval) 257 | if args.use_lora: 258 | model.get_mixin("lora").merge_lora() 259 | model.get_mixin("eva").vit_model.get_mixin("lora").merge_lora() 260 | args.use_lora = False 261 | args.save = "checkpoints/merged_lora_cogvlm{}".format(args.eva_args["image_size"][0]) 262 | from sat.training.model_io import save_checkpoint 263 | save_checkpoint(1, model, None, None, args) -------------------------------------------------------------------------------- /finetune_demo/finetune_cogvlm_lora.sh: -------------------------------------------------------------------------------- 1 | #! /bin/bash 2 | # export PATH=/usr/local/cuda/bin:$PATH 3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH 4 | 5 | NUM_GPUS_PER_WORKER=8 6 | MP_SIZE=1 7 | 8 | script_path=$(realpath $0) 9 | script_dir=$(dirname $script_path) 10 | main_dir=$(dirname $script_dir) 11 | MODEL_TYPE="cogvlm-base-490" 12 | VERSION="base" 13 | MODEL_ARGS="--from_pretrained $MODEL_TYPE \ 14 | --max_length 1288 \ 15 | --lora_rank 10 \ 16 | --use_lora \ 17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \ 18 | --version $VERSION" 19 | # Tips: If training models of resolution 224, you can set a smaller --max_length 20 | 21 | OPTIONS_SAT="SAT_HOME=~/.sat_models" 22 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER" 23 | HOST_FILE_PATH="hostfile" 24 | 25 | train_data="./archive_split/train" 26 | valid_data="./archive_split/valid" 27 | 28 | gpt_options=" \ 29 | --experiment-name finetune-$MODEL_TYPE \ 30 | --model-parallel-size ${MP_SIZE} \ 31 | --mode finetune \ 32 | --train-iters 800 \ 33 | --resume-dataloader \ 34 | $MODEL_ARGS \ 35 | --train-data ${train_data} \ 36 | --valid-data ${valid_data} \ 37 | --distributed-backend nccl \ 38 | --lr-decay-style cosine \ 39 | --warmup .02 \ 40 | --checkpoint-activations \ 41 | --vit_checkpoint_activations \ 42 | --save-interval 200 \ 43 | --eval-interval 200 \ 44 | --save "./checkpoints" \ 45 | --eval-iters 10 \ 46 | --eval-batch-size 1 \ 47 | --split 1.
\ 48 | --deepspeed_config test_config_bf16.json \ 49 | --skip-init \ 50 | --seed 2023 51 | " 52 | 53 | 54 | 55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogvlm_demo.py ${gpt_options}" 56 | echo ${run_cmd} 57 | eval ${run_cmd} 58 | 59 | set +x -------------------------------------------------------------------------------- /finetune_demo/test_config_bf16.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 4, 3 | "gradient_accumulation_steps": 1, 4 | "gradient_clipping": 0.1, 5 | "zero_optimization": { 6 | "stage": 2, 7 | "contiguous_gradients": false, 8 | "overlap_comm": true, 9 | "reduce_scatter": true, 10 | "reduce_bucket_size": 4e7, 11 | "allgather_bucket_size": 1e8, 12 | "load_from_fp32_weights": false 13 | }, 14 | "offload_optimizer": { 15 | "device": "cpu", 16 | "pin_memory": true 17 | }, 18 | "zero_allow_untested_optimizer": true, 19 | "bf16": { 20 | "enabled": true 21 | }, 22 | "optimizer": { 23 | "type": "Adam", 24 | "params": { 25 | "lr": 0.00001, 26 | "betas": [ 27 | 0.9, 28 | 0.95 29 | ], 30 | "eps": 1e-8, 31 | "weight_decay": 5e-2 32 | } 33 | }, 34 | "activation_checkpointing": { 35 | "partition_activations": false, 36 | "contiguous_memory_optimization": false, 37 | "cpu_checkpointing": false 38 | }, 39 | "wall_clock_breakdown": false 40 | } 41 | -------------------------------------------------------------------------------- /openai_demo/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/openai_demo/demo.jpg -------------------------------------------------------------------------------- /openai_demo/openai_api_request.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script is designed to mimic the OpenAI API interface with CogVLM & CogAgent Chat 3 | It demonstrates how to integrate image and text-based input to generate a response. 4 | Currently, the model can only handle a single image. 5 | Therefore, do not use this script to process multiple images in one conversation. (includes images from history) 6 | And it only works on the chat model, not the base model. 7 | """ 8 | import requests 9 | import json 10 | import base64 11 | 12 | base_url = "http://127.0.0.1:8000" 13 | 14 | 15 | def create_chat_completion(model, messages, temperature=0.8, max_tokens=2048, top_p=0.8, use_stream=False): 16 | """ 17 | This function sends a request to the chat API to generate a response based on the given messages. 18 | 19 | Args: 20 | model (str): The name of the model to use for generating the response. 21 | messages (list): A list of message dictionaries representing the conversation history. 22 | temperature (float): Controls randomness in response generation. Higher values lead to more random responses. 23 | max_tokens (int): The maximum length of the generated response. 24 | top_p (float): Controls diversity of response by filtering less likely options. 25 | use_stream (bool): Determines whether to use a streaming response or a single response. 26 | 27 | The function constructs a JSON payload with the specified parameters and sends a POST request to the API. 28 | It then handles the response, either as a stream (for ongoing responses) or a single message. 
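    Expected `messages` format (illustrative sketch; it mirrors `simple_image_chat` below):
        [{"role": "user",
          "content": [{"type": "text", "text": "What's in this image?"},
                      {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}}]}]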
29 | """ 30 | 31 | data = { 32 | "model": model, 33 | "messages": messages, 34 | "stream": use_stream, 35 | "max_tokens": max_tokens, 36 | "temperature": temperature, 37 | "top_p": top_p, 38 | } 39 | 40 | response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream) 41 | if response.status_code == 200: 42 | if use_stream: 43 | # 处理流式响应 44 | for line in response.iter_lines(): 45 | if line: 46 | decoded_line = line.decode('utf-8')[6:] 47 | try: 48 | response_json = json.loads(decoded_line) 49 | content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "") 50 | print(content) 51 | except: 52 | print("Special Token:", decoded_line) 53 | else: 54 | # 处理非流式响应 55 | decoded_line = response.json() 56 | content = decoded_line.get("choices", [{}])[0].get("message", "").get("content", "") 57 | print(content) 58 | else: 59 | print("Error:", response.status_code) 60 | return None 61 | 62 | 63 | def encode_image(image_path): 64 | """ 65 | Encodes an image file into a base64 string. 66 | Args: 67 | image_path (str): The path to the image file. 68 | 69 | This function opens the specified image file, reads its content, and encodes it into a base64 string. 70 | The base64 encoding is used to send images over HTTP as text. 71 | """ 72 | 73 | with open(image_path, "rb") as image_file: 74 | return base64.b64encode(image_file.read()).decode("utf-8") 75 | 76 | 77 | def simple_image_chat(use_stream=True, img_path=None): 78 | """ 79 | Facilitates a simple chat interaction involving an image. 80 | 81 | Args: 82 | use_stream (bool): Specifies whether to use streaming for chat responses. 83 | img_path (str): Path to the image file to be included in the chat. 84 | 85 | This function encodes the specified image and constructs a predefined conversation involving the image. 86 | It then calls `create_chat_completion` to generate a response from the model. 87 | The conversation includes asking about the content of the image and a follow-up question. 88 | """ 89 | 90 | img_url = f"data:image/jpeg;base64,{encode_image(img_path)}" 91 | messages = [ 92 | { 93 | "role": "user", 94 | "content": [ 95 | { 96 | "type": "text", 97 | "text": "What’s in this image?", 98 | }, 99 | { 100 | "type": "image_url", 101 | "image_url": { 102 | "url": img_url 103 | }, 104 | }, 105 | ], 106 | }, 107 | { 108 | "role": "assistant", 109 | "content": "The image displays a wooden boardwalk extending through a vibrant green grassy wetland. The sky is partly cloudy with soft, wispy clouds, indicating nice weather. Vegetation is seen on either side of the boardwalk, and trees are present in the background, suggesting that this area might be a natural reserve or park designed for ecological preservation and outdoor recreation. The boardwalk allows visitors to explore the area without disturbing the natural habitat.", 110 | }, 111 | { 112 | "role": "user", 113 | "content": "Do you think this is a spring or winter photo?" 
114 | }, 115 | ] 116 | create_chat_completion("cogvlm-chat-17b", messages=messages, use_stream=use_stream) 117 | 118 | 119 | if __name__ == "__main__": 120 | simple_image_chat(use_stream=False, img_path="demo.jpg") 121 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | SwissArmyTransformer>=0.4.9 2 | transformers>=4.36.2 3 | xformers>=0.0.22 4 | torch>=2.1.0 5 | torchvision>=0.16.2 6 | spacy>=3.6.0 7 | pillow>=10.2.0 8 | deepspeed>=0.13.1 9 | seaborn>=0.13.2 10 | loguru~=0.7.2 11 | streamlit>=1.31.0 12 | timm>=0.9.12 13 | accelerate>=0.26.1 14 | pydantic>=2.6.0 15 | 16 | # for openai demo 17 | openai>=1.16.0 18 | sse-starlette>=1.8.2 19 | fastapi>=0.110.1 20 | httpx>=0.27.0 21 | uvicorn>=0.29.0 22 | jsonlines>=4.0.0 23 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/utils/__init__.py -------------------------------------------------------------------------------- /utils/merge_model.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | import os, sys 3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 4 | 5 | import torch 6 | import argparse 7 | from models.cogvlm_model import FineTuneTestCogVLMModel 8 | from sat.training.model_io import save_checkpoint 9 | 10 | def main(): 11 | parser = argparse.ArgumentParser() 12 | parser.add_argument("--version", type=str, default="base", help='version to interact with') 13 | parser.add_argument("--from_pretrained", type=str, default="checkpoints/merged_lora", help='pretrained ckpt') 14 | parser.add_argument("--fp16", action="store_true") 15 | parser.add_argument("--bf16", action="store_true") 16 | args = parser.parse_args() 17 | rank = int(os.environ.get('RANK', 0)) 18 | world_size = int(os.environ.get('WORLD_SIZE', 1)) 19 | parser = FineTuneTestCogVLMModel.add_model_specific_args(parser) 20 | args = parser.parse_args() 21 | 22 | # load model 23 | model, model_args = FineTuneTestCogVLMModel.from_pretrained( 24 | args.from_pretrained, 25 | args=argparse.Namespace( 26 | deepspeed=None, 27 | local_rank=rank, 28 | rank=rank, 29 | world_size=world_size, 30 | model_parallel_size=world_size, 31 | mode='inference', 32 | skip_init=True, 33 | use_gpu_initialization=True if torch.cuda.is_available() else False, 34 | device='cuda', 35 | **vars(args) 36 | ), url='local', overwrite_args={'model_parallel_size': 1}) 37 | model = model.eval() 38 | model_args.save = './checkpoints/merged_model_{}'.format(model_args.eva_args["image_size"][0]) 39 | save_checkpoint(1, model, None, None, model_args) 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /utils/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cogagent_model import CogAgentModel, FineTuneTrainCogAgentModel, FineTuneTestCogAgentModel 2 | from .cogvlm_model import CogVLMModel, FineTuneTrainCogVLMModel, FineTuneTestCogVLMModel -------------------------------------------------------------------------------- /utils/models/cogagent_model.py: -------------------------------------------------------------------------------- 1 | from 
sat.model.official.llama_model import LLaMAModel 2 | import json 3 | import torch 4 | from functools import partial 5 | from sat.model.base_model import BaseMixin 6 | import torch.nn as nn 7 | import numpy as np 8 | from sat.resources.urls import MODEL_URLS 9 | 10 | from .eva_clip_L_hf import Eva2LargeEncoder 11 | from .mixin import LlamaVisionExpertFCMixin, LlamaVisionExpertAttnMixin 12 | 13 | 14 | MODEL_URLS["cogagent-chat"] = "r2://cogagent-chat.zip" 15 | MODEL_URLS["cogagent-vqa"] = "r2://cogagent-vqa.zip" 16 | 17 | 18 | class GLU(nn.Module): 19 | def __init__(self, args, in_features): 20 | super().__init__() 21 | self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False) 22 | self.norm1 = nn.LayerNorm(args.hidden_size) 23 | self.act1 = nn.GELU() 24 | self.act2 = nn.functional.silu 25 | self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False) 26 | self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False) 27 | self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False) 28 | 29 | def forward(self, x): 30 | x = self.linear_proj(x) 31 | x = self.act1(self.norm1(x)) 32 | x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) 33 | x = self.dense_4h_to_h(x) 34 | return x 35 | 36 | from .eva_clip_model import EVA2CLIPModel 37 | import argparse 38 | from copy import deepcopy 39 | def override_dist_dtype_device_args(args, b={}): 40 | if args.mode == 'inference': 41 | minimal_args = argparse.Namespace( 42 | world_size=args.world_size, 43 | rank=args.rank, 44 | local_rank=args.local_rank, 45 | skip_init=args.skip_init, 46 | use_gpu_initialization=args.use_gpu_initialization, 47 | deepspeed=args.deepspeed, 48 | bf16=args.bf16, 49 | fp16=args.fp16, 50 | mode=args.mode, 51 | device=args.device 52 | ) 53 | else: 54 | minimal_args = argparse.Namespace( 55 | world_size=args.world_size, 56 | rank=args.rank, 57 | local_rank=args.local_rank, 58 | skip_init=args.skip_init, 59 | use_gpu_initialization=args.use_gpu_initialization, 60 | deepspeed=args.deepspeed, 61 | bf16=args.bf16, 62 | fp16=args.fp16, 63 | mode=args.mode, 64 | checkpoint_activations=args.checkpoint_activations if not hasattr(args, 'vit_checkpoint_activations') else args.vit_checkpoint_activations, 65 | checkpoint_num_layers=args.checkpoint_num_layers, 66 | device=args.device, 67 | hidden_dropout=0., 68 | attention_dropout=0., 69 | ) 70 | if hasattr(args, 'model_parallel_size'): 71 | b['model_parallel_size'] = args.model_parallel_size 72 | return argparse.Namespace(**deepcopy(b), **vars(minimal_args)) 73 | 74 | 75 | class ExternalVisionModel(BaseMixin): 76 | '''A combination of vit and a linear projection''' 77 | def __init__(self, args, vitclass): 78 | ''' 79 | args: the args to initialize the vit model 80 | vitclass: the class of VIT model, must be a subclass of BaseModel 81 | project_dim: the dimension of the projection layer 82 | default_load: the default load path for the vit model 83 | model_parallel_size: the model parallel size for the vit model 84 | ''' 85 | super().__init__() 86 | self.vit = vitclass() 87 | # self.ppx = nn.Embedding(80, 1024) 88 | # self.ppy = nn.Embedding(80, 1024) 89 | # nn.init.uniform_(self.ppx.weight.data) 90 | # nn.init.uniform_(self.ppy.weight.data) 91 | 92 | # self.pos_embed = nn.Parameter( 93 | # torch.from_numpy(get_2d_sincos_pos_embed(1024, 80)).float() 94 | # ) 95 | cross_image_length = (args.cross_image_pix//14)**2 96 | self.pos_embed = nn.Parameter( 97 | torch.zeros(cross_image_length, 1024) 98 | ) 99 | 100 | def 
forward(self, *args, **kw_args): 101 | enc = self.vit(*args, **kw_args) 102 | # i = torch.arange(80, device=enc.device) 103 | # j = torch.arange(80, device=enc.device) 104 | # posx = self.ppx(i).unsqueeze(0).repeat(80, 1, 1) 105 | # posy = self.ppy(j).unsqueeze(1).repeat(1, 80, 1) 106 | # pos = (posx + posy).view(-1, 1024).unsqueeze(0) 107 | 108 | # return enc + pos + self.pos_embed.unsqueeze(0) 109 | return enc + self.pos_embed.unsqueeze(0) 110 | 111 | class ImageMixin(BaseMixin): 112 | def __init__(self, args): 113 | super().__init__() 114 | vit_args = override_dist_dtype_device_args(args, args.eva_args) 115 | self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args))) 116 | self.in_features = 1792 117 | self.linear_proj = GLU(args, self.in_features) 118 | self.image_length = args.image_length 119 | self.boi = nn.Parameter(torch.zeros(1, 1, args.hidden_size)) 120 | self.eoi = nn.Parameter(torch.zeros(1, 1, args.hidden_size)) 121 | 122 | # self.ppx = nn.Embedding(16,1792) 123 | # self.ppy = nn.Embedding(16,1792) 124 | 125 | # self.pos_embed = nn.Parameter( 126 | # torch.from_numpy(get_2d_sincos_pos_embed(1792, 16)).float() 127 | # ) 128 | self.pos_embed = nn.Parameter( 129 | torch.zeros(self.image_length, 1792) 130 | ) 131 | 132 | def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args): 133 | vision_inputs = {} 134 | for k in kw_args: 135 | if k.startswith('vision_') and k != 'vision_expert_mask': 136 | vision_inputs[k[7:]] = kw_args[k] 137 | if input_ids.shape[1] == 1 or not vision_inputs: 138 | return self.transformer.word_embeddings(input_ids) 139 | image_emb = self.vit_model(**vision_inputs)[0] 140 | 141 | # i = torch.arange(16, device=image_emb.device) 142 | # j = torch.arange(16, device=image_emb.device) 143 | # posx = self.ppx(i).unsqueeze(0).repeat(16, 1, 1) 144 | # posy = self.ppy(j).unsqueeze(1).repeat(1, 16, 1) 145 | # pos = (posx + posy).view(256, -1).unsqueeze(0) 146 | # image_emb = image_emb + pos + self.pos_embed.unsqueeze(0) 147 | image_emb = image_emb + self.pos_embed.unsqueeze(0) 148 | 149 | image_emb = self.linear_proj(image_emb) 150 | 151 | image_embed_mask = kw_args['image_embed_mask'] 152 | word_embedding = self.transformer.word_embeddings(input_ids).clone() 153 | word_embedding[image_embed_mask.bool()] = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1).reshape(-1, image_emb.shape[-1]) 154 | 155 | return word_embedding.contiguous() 156 | 157 | class CogAgentModel(LLaMAModel): 158 | def __init__(self, args, transformer=None, **kwargs): 159 | super().__init__(args, transformer=transformer, **kwargs) 160 | self.image_length = args.image_length 161 | self.cross_image_pix = args.cross_image_pix 162 | self.add_mixin("eva", ImageMixin(args)) 163 | self.del_mixin("mlp") 164 | self.add_mixin("mlp", LlamaVisionExpertFCMixin(args.hidden_size, args.inner_hidden_size, args.num_layers, 32)) 165 | self.del_mixin("rotary") 166 | self.add_mixin("rotary", LlamaVisionExpertAttnMixin(args.hidden_size, args.num_attention_heads, args.num_layers, 32)) 167 | 168 | cross_model = ExternalVisionModel(args, vitclass=partial(Eva2LargeEncoder, image_size=self.cross_image_pix)) 169 | # if args.mode != 'inference': 170 | # cross_model.vit.model.set_grad_checkpointing(True) 171 | self.add_mixin("encoder", cross_model) 172 | 173 | @classmethod 174 | def add_model_specific_args(cls, parser): 175 | group = parser.add_argument_group('CogAgent', 'CogAgent Configurations') 176 | group.add_argument('--image_length', 
type=int, default=256) 177 | group.add_argument('--cross_image_pix', type=int, default=1120) # Standard CogAgent use 1120; if you want to adjust this param, finetune the model first. 178 | group.add_argument('--eva_args', type=json.loads, default={}) 179 | return super().add_model_specific_args(parser) 180 | 181 | def forward(self, input_ids, vision_expert_mask, image_embed_mask, **kwargs): 182 | 183 | cross_inputs = {} 184 | for k in kwargs: 185 | if k.startswith('cross_'): 186 | cross_inputs[k[6:]] = kwargs[k] 187 | if kwargs.get("mems_cross") is not None: 188 | kwargs['encoder_outputs'] = kwargs["mems_cross"][0] 189 | else: 190 | outputs = self.get_mixin('encoder')(**cross_inputs) 191 | kwargs['encoder_outputs'] = outputs 192 | kwargs['cross_attention_mask'] = cross_inputs['attention_mask'] 193 | 194 | if input_ids.shape[1] > 1: 195 | return super().forward(input_ids=input_ids, vision_expert_mask=vision_expert_mask, image_embed_mask=image_embed_mask, **kwargs) 196 | return super().forward(input_ids=input_ids, **kwargs) 197 | 198 | 199 | class FineTuneTrainCogAgentModel(CogAgentModel): 200 | def __init__(self, args, transformer=None, **kw_args): 201 | super().__init__(args, transformer=transformer, **kw_args) 202 | self.args = args 203 | # If you want to use model parallel with a mp_size=1 checkpoint, and meanwhile you also want to use lora, 204 | # you have to add_mixin after loading model checkpoint. 205 | 206 | @classmethod 207 | def add_model_specific_args(cls, parser): 208 | group = parser.add_argument_group('CogAgent-finetune', 'CogAgent finetune Configurations') 209 | group.add_argument('--pre_seq_len', type=int, default=8) 210 | group.add_argument('--lora_rank', type=int, default=10) 211 | group.add_argument('--use_ptuning', action="store_true") 212 | group.add_argument('--use_lora', action="store_true") 213 | group.add_argument('--use_qlora', action="store_true") 214 | group.add_argument('--layer_range', nargs='+', type=int, default=None) 215 | return super().add_model_specific_args(parser) 216 | 217 | 218 | from sat.model.finetune import PTuningV2Mixin 219 | from sat.model.finetune.lora2 import LoraMixin 220 | class FineTuneTestCogAgentModel(CogAgentModel): 221 | def __init__(self, args, transformer=None, **kw_args): 222 | super().__init__(args, transformer=transformer, **kw_args) 223 | if args.use_ptuning: 224 | self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len)) 225 | if args.use_lora: 226 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True) 227 | self.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True) 228 | elif args.use_qlora: 229 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True) 230 | self.args = args 231 | 232 | @classmethod 233 | def add_model_specific_args(cls, parser): 234 | group = parser.add_argument_group('CogAgent-finetune', 'CogAgent finetune Configurations') 235 | group.add_argument('--pre_seq_len', type=int, default=8) 236 | group.add_argument('--lora_rank', type=int, default=10) 237 | group.add_argument('--use_ptuning', action="store_true") 238 | group.add_argument('--use_lora', action="store_true") 239 | group.add_argument('--use_qlora', action="store_true") 240 | group.add_argument('--layer_range', nargs='+', type=int, default=None) 241 | 
return super().add_model_specific_args(parser) 242 | -------------------------------------------------------------------------------- /utils/models/cogvlm_model.py: -------------------------------------------------------------------------------- 1 | from sat.model.official.llama_model import LLaMAModel 2 | import json 3 | import torch 4 | from sat.model.base_model import BaseMixin 5 | import torch.nn as nn 6 | from .mixin import LlamaVisionExpertFCMixin, LlamaVisionExpertAttnMixin 7 | 8 | from sat.resources.urls import MODEL_URLS 9 | 10 | MODEL_URLS["cogvlm-base-224"] = "r2://cogvlm-base-224.zip" 11 | MODEL_URLS["cogvlm-base-490"] = "r2://cogvlm-base-490.zip" 12 | MODEL_URLS["cogvlm-chat-v1.1"] = "r2://cogvlm-chat-v1.1.zip" 13 | MODEL_URLS["cogvlm-grounding-base"] = "r2://cogvlm-grounding-base.zip" 14 | MODEL_URLS["cogvlm-grounding-generalist-v1.1"] = "r2://cogvlm-grounding-generalist-v1.1.zip" 15 | 16 | 17 | class GLU(nn.Module): 18 | def __init__(self, args, in_features): 19 | super().__init__() 20 | self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False) 21 | self.norm1 = nn.LayerNorm(args.hidden_size) 22 | self.act1 = nn.GELU() 23 | self.act2 = nn.functional.silu 24 | self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False) 25 | self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False) 26 | self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False) 27 | 28 | def forward(self, x): 29 | x = self.linear_proj(x) 30 | x = self.act1(self.norm1(x)) 31 | x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x) 32 | x = self.dense_4h_to_h(x) 33 | return x 34 | 35 | from .eva_clip_model import EVA2CLIPModel 36 | import argparse 37 | from copy import deepcopy 38 | def override_dist_dtype_device_args(args, b={}): 39 | if args.mode == 'inference': 40 | minimal_args = argparse.Namespace( 41 | world_size=args.world_size, 42 | rank=args.rank, 43 | local_rank=args.local_rank, 44 | skip_init=args.skip_init, 45 | use_gpu_initialization=args.use_gpu_initialization, 46 | deepspeed=args.deepspeed, 47 | bf16=args.bf16, 48 | fp16=args.fp16, 49 | mode=args.mode, 50 | device=args.device 51 | ) 52 | else: 53 | minimal_args = argparse.Namespace( 54 | world_size=args.world_size, 55 | rank=args.rank, 56 | local_rank=args.local_rank, 57 | skip_init=args.skip_init, 58 | use_gpu_initialization=args.use_gpu_initialization, 59 | deepspeed=args.deepspeed, 60 | bf16=args.bf16, 61 | fp16=args.fp16, 62 | mode=args.mode, 63 | checkpoint_activations=args.checkpoint_activations if not hasattr(args, 'vit_checkpoint_activations') else args.vit_checkpoint_activations, 64 | checkpoint_num_layers=args.checkpoint_num_layers, 65 | device=args.device, 66 | hidden_dropout=0., 67 | attention_dropout=0., 68 | ) 69 | if hasattr(args, 'model_parallel_size'): 70 | b['model_parallel_size'] = args.model_parallel_size 71 | return argparse.Namespace(**deepcopy(b), **vars(minimal_args)) 72 | 73 | class ImageMixin(BaseMixin): 74 | def __init__(self, args): 75 | super().__init__() 76 | vit_args = override_dist_dtype_device_args(args, args.eva_args) 77 | self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args))) 78 | self.in_features = 1792 79 | self.linear_proj = GLU(args, self.in_features) 80 | self.image_length = args.image_length 81 | self.boi = nn.Parameter(torch.zeros(1, 1, args.hidden_size)) 82 | self.eoi = nn.Parameter(torch.zeros(1, 1, args.hidden_size)) 83 | 84 | def word_embedding_forward(self, input_ids, output_cross_layer, 
**kw_args): 85 | vision_inputs = {} 86 | for k in kw_args: 87 | if k.startswith('vision_') and k != 'vision_expert_mask': 88 | vision_inputs[k[7:]] = kw_args[k] 89 | if input_ids.shape[1] == 1 or not vision_inputs: 90 | return self.transformer.word_embeddings(input_ids) 91 | image_emb = self.vit_model(**vision_inputs)[0] 92 | image_emb = self.linear_proj(image_emb) 93 | 94 | image_embed_mask = kw_args['image_embed_mask'] 95 | word_embedding = self.transformer.word_embeddings(input_ids).clone() 96 | word_embedding[image_embed_mask.bool()] = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1).reshape(-1, image_emb.shape[-1]) 97 | return word_embedding.contiguous() 98 | 99 | 100 | class CogVLMModel(LLaMAModel): 101 | def __init__(self, args, transformer=None, **kwargs): 102 | super().__init__(args, transformer=transformer, **kwargs) 103 | self.image_length = args.image_length 104 | self.add_mixin("eva", ImageMixin(args)) 105 | self.del_mixin("mlp") 106 | self.add_mixin("mlp", LlamaVisionExpertFCMixin(args.hidden_size, args.inner_hidden_size, args.num_layers, 32)) 107 | self.del_mixin("rotary") 108 | self.add_mixin("rotary", LlamaVisionExpertAttnMixin(args.hidden_size, args.num_attention_heads, args.num_layers, 32)) 109 | 110 | @classmethod 111 | def add_model_specific_args(cls, parser): 112 | group = parser.add_argument_group('CogVLM', 'CogVLM Configurations') 113 | group.add_argument('--image_length', type=int, default=256) 114 | group.add_argument('--eva_args', type=json.loads, default={}) 115 | return super().add_model_specific_args(parser) 116 | 117 | def forward(self, input_ids, vision_expert_mask, image_embed_mask, **kwargs): 118 | if input_ids.shape[1] > 1: 119 | return super().forward(input_ids=input_ids, vision_expert_mask=vision_expert_mask, image_embed_mask=image_embed_mask, **kwargs) 120 | return super().forward(input_ids=input_ids, **kwargs) 121 | 122 | 123 | class FineTuneTrainCogVLMModel(CogVLMModel): 124 | def __init__(self, args, transformer=None, **kw_args): 125 | super().__init__(args, transformer=transformer, **kw_args) 126 | self.args = args 127 | # If you want to use model parallel with a mp_size=1 checkpoint, and meanwhile you also want to use lora, 128 | # you have to add_mixin after loading model checkpoint. 
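        # Illustrative ordering, mirroring the __main__ block of finetune_demo/finetune_cogvlm_demo.py:
        #   model, args = FineTuneTrainCogVLMModel.from_pretrained(args.from_pretrained, args, ...)
        #   model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)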
129 | 130 | @classmethod 131 | def add_model_specific_args(cls, parser): 132 | group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations') 133 | group.add_argument('--pre_seq_len', type=int, default=8) 134 | group.add_argument('--lora_rank', type=int, default=10) 135 | group.add_argument('--use_ptuning', action="store_true") 136 | group.add_argument('--use_lora', action="store_true") 137 | group.add_argument('--use_qlora', action="store_true") 138 | group.add_argument('--layer_range', nargs='+', type=int, default=None) 139 | return super().add_model_specific_args(parser) 140 | 141 | 142 | from sat.model.finetune import PTuningV2Mixin 143 | from sat.model.finetune.lora2 import LoraMixin 144 | class FineTuneTestCogVLMModel(CogVLMModel): 145 | def __init__(self, args, transformer=None, **kw_args): 146 | super().__init__(args, transformer=transformer, **kw_args) 147 | if args.use_ptuning: 148 | self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len)) 149 | if args.use_lora: 150 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True) 151 | self.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True) 152 | elif args.use_qlora: 153 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True) 154 | self.args = args 155 | 156 | @classmethod 157 | def add_model_specific_args(cls, parser): 158 | group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations') 159 | group.add_argument('--pre_seq_len', type=int, default=8) 160 | group.add_argument('--lora_rank', type=int, default=10) 161 | group.add_argument('--use_ptuning', action="store_true") 162 | group.add_argument('--use_lora', action="store_true") 163 | group.add_argument('--use_qlora', action="store_true") 164 | group.add_argument('--layer_range', nargs='+', type=int, default=None) 165 | return super().add_model_specific_args(parser) 166 | -------------------------------------------------------------------------------- /utils/models/eva_clip_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from sat.model.base_model import BaseModel 3 | from sat.model.mixins import BaseMixin 4 | from sat.model.official.vit_model import ViTProperty, ImagePatchEmbeddingMixin, InterpolatedPositionEmbeddingMixin, gelu 5 | from sat import mpu 6 | 7 | class IdentityMixin(BaseMixin): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def final_forward(self, logits, **kwargs): 12 | return logits[:, 1:] 13 | 14 | import xformers.ops as xops 15 | class XAttn(BaseMixin): 16 | def __init__(self, head_dim): 17 | super().__init__() 18 | self.scale = head_dim ** -0.5 19 | 20 | def attention_fn(self, query_layer, key_layer, value_layer, attention_mask, 21 | attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs): 22 | dropout_p = 0. 
# xformers does not support dropout for eva hidden size 23 | 24 | query_layer = query_layer.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C 25 | key_layer = key_layer.permute(0, 2, 1, 3) 26 | value_layer = value_layer.permute(0, 2, 1, 3) 27 | 28 | out = xops.memory_efficient_attention( 29 | query_layer, key_layer, value_layer, 30 | p=dropout_p, 31 | scale=self.scale, 32 | ) 33 | return out 34 | 35 | def attention_forward(self, hidden_states, mask, **kw_args): 36 | self = self.transformer.layers[kw_args['layer_id']].attention 37 | attention_fn = self.hooks['attention_fn'] 38 | 39 | mixed_raw_layer = self.query_key_value(hidden_states) 40 | 41 | B, N, C = hidden_states.shape 42 | mixed_raw_layer = mixed_raw_layer.reshape(B, N, 3, self.num_attention_heads_per_partition, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C 43 | query_layer, key_layer, value_layer = mixed_raw_layer[0], mixed_raw_layer[1], mixed_raw_layer[2] 44 | 45 | dropout_fn = self.attention_dropout if self.training else None 46 | 47 | context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) 48 | 49 | context_layer = context_layer.view(B, N, -1) 50 | output = self.dense(context_layer) 51 | 52 | if self.training: 53 | output = self.output_dropout(output) 54 | return output 55 | 56 | class NewLayerForward(BaseMixin): 57 | def __init__(self): 58 | super().__init__() 59 | 60 | def layer_forward(self, hidden_states, mask, *args, **kw_args): 61 | ''' 62 | hidden_states: [batch, seq_len, hidden_size] 63 | mask: [(1, 1), seq_len, seq_len] 64 | ''' 65 | self = self.transformer.layers[kw_args['layer_id']] 66 | 67 | attention_input = hidden_states 68 | 69 | # Self attention. 70 | attention_output = self.input_layernorm(self.attention(attention_input, mask, **kw_args)) 71 | 72 | # DropPath for attention 73 | if self.training and self.drop_path > 0.: 74 | if mpu.get_cuda_rng_tracker is not None: 75 | # drop_path must use model parallel rng tracker 76 | # the tracker is initialized as seed of `seed + model_parallel_rank` 77 | # deepspeed act-ckpt record the model parallel tracker states 78 | with mpu.get_cuda_rng_tracker().fork(): 79 | # drop_path percentage 0, others 1/(1-p) 80 | random_tensor = (1-self.drop_path 81 | + torch.rand((attention_output.shape[0],), dtype=attention_output.dtype, device=attention_output.device)).floor_() / (1-self.drop_path) 82 | attention_output = random_tensor.view(-1, 1, 1) * attention_output 83 | 84 | # Residual connection. 85 | hidden_states = attention_input + attention_output 86 | mlp_input = hidden_states 87 | 88 | # MLP. 89 | mlp_output = self.post_attention_layernorm(self.mlp(mlp_input, **kw_args)) 90 | 91 | # DropPath for mlp 92 | if self.training and self.drop_path > 0.: 93 | if mpu.get_cuda_rng_tracker is not None: 94 | with mpu.get_cuda_rng_tracker().fork(): 95 | random_tensor = (1-self.drop_path 96 | + torch.rand((mlp_output.shape[0],), dtype=mlp_output.dtype, device=mlp_output.device)).floor_() / (1-self.drop_path) 97 | mlp_output = random_tensor.view(-1, 1, 1) * mlp_output 98 | 99 | # Second residual connection. 
100 | output = mlp_input + mlp_output 101 | 102 | return output 103 | 104 | class EVA2CLIPModel(BaseModel): 105 | def __init__(self, args, transformer=None, **kwargs): 106 | property = ViTProperty(args.image_size, args.patch_size, args.pre_len, args.post_len) 107 | args.max_sequence_length = property.pre_len + property.num_patches + property.post_len 108 | if 'activation_func' not in kwargs: 109 | kwargs['activation_func'] = gelu 110 | super().__init__(args, transformer=transformer, **kwargs) 111 | self.transformer.property = property 112 | self.add_mixin("patch_embedding", ImagePatchEmbeddingMixin(args.in_channels, args.hidden_size, property)) 113 | self.add_mixin("pos_embedding", InterpolatedPositionEmbeddingMixin()) 114 | self.add_mixin("final", IdentityMixin()) 115 | self.add_mixin("newpost", NewLayerForward()) 116 | self.add_mixin("xattn", XAttn(args.hidden_size // args.num_attention_heads)) 117 | 118 | @classmethod 119 | def add_model_specific_args(cls, parser): 120 | group = parser.add_argument_group('EVA2CLIP', 'EVA2CLIP Configurations') 121 | group.add_argument('--image-size', nargs='+', type=int, default=[224, 224]) 122 | group.add_argument('--pre-len', type=int, default=1) # [cls] by default 123 | group.add_argument('--post-len', type=int, default=0) # empty by default, but sometimes with special tokens, such as [det] in yolos. 124 | group.add_argument('--in-channels', type=int, default=3) 125 | group.add_argument('--patch-size', type=int, default=16) 126 | return parser 127 | 128 | -------------------------------------------------------------------------------- /utils/models/mixin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from sat.transformer_defaults import attention_fn_default 5 | from sat.model.base_model import BaseMixin, non_conflict 6 | from sat.mpu.layers import ColumnParallelLinear, RowParallelLinear 7 | from sat.mpu.utils import split_tensor_along_last_dim 8 | from sat import mpu 9 | 10 | 11 | class LlamaVisionExpertFCMixin(BaseMixin): 12 | def __init__(self, in_features, hidden_features, num_layers=32, num_vision_layers=0, vision_layer_range=None, 13 | params_dtype=torch.float, device=torch.device('cpu')): 14 | super().__init__() 15 | 16 | self.num_layers = num_layers 17 | self.num_vision_layers = num_vision_layers 18 | if vision_layer_range is None: 19 | vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))] 20 | self.vision_layer_range = vision_layer_range 21 | self.gate_proj = nn.ModuleList([ColumnParallelLinear( 22 | in_features, 23 | hidden_features, 24 | gather_output=False, 25 | init_method=None, 26 | bias=False, 27 | params_dtype=params_dtype, 28 | module=self, 29 | name="dense_h_to_4h_gate", 30 | skip_init=True, 31 | device=device 32 | ) for i in range(num_layers)]) 33 | # Trainable vision expert parameters 34 | vision_dense_h_to_4h_list = [] 35 | vision_dense_4h_to_h_list = [] 36 | gate_proj_list = [] 37 | 38 | 39 | for i in vision_layer_range: 40 | vision_dense_h_to_4h = ColumnParallelLinear( 41 | in_features, 42 | hidden_features, 43 | gather_output=False, 44 | init_method=None, 45 | bias=False, 46 | params_dtype=params_dtype, 47 | module=self, 48 | name="vision_dense_h_to_4h", 49 | skip_init=True, 50 | device=device 51 | ) 52 | 53 | # Project back to h. 
54 | vision_dense_4h_to_h = RowParallelLinear( 55 | hidden_features, 56 | in_features, 57 | input_is_parallel=True, 58 | init_method=None, 59 | bias=False, 60 | params_dtype=params_dtype, 61 | module=self, 62 | name="vision_dense_4h_to_h", 63 | skip_init=True, 64 | device=device 65 | ) 66 | 67 | gate_proj = ColumnParallelLinear( 68 | in_features, 69 | hidden_features, 70 | gather_output=False, 71 | init_method=None, 72 | bias=False, 73 | params_dtype=params_dtype, 74 | module=self, 75 | name="vision_gate_proj", 76 | skip_init=True, 77 | device=device 78 | ) 79 | 80 | vision_dense_h_to_4h_list.append(vision_dense_h_to_4h) 81 | vision_dense_4h_to_h_list.append(vision_dense_4h_to_h) 82 | gate_proj_list.append(gate_proj) 83 | 84 | self.vision_dense_h_to_4h_list = nn.ModuleDict([ 85 | (str(layer_id), vision_dense_h_to_4h) 86 | for layer_id, vision_dense_h_to_4h in zip(vision_layer_range, vision_dense_h_to_4h_list) 87 | ]) 88 | self.vision_dense_4h_to_h_list = nn.ModuleDict([ 89 | (str(layer_id), vision_dense_4h_to_h) 90 | for layer_id, vision_dense_4h_to_h in zip(vision_layer_range, vision_dense_4h_to_h_list) 91 | ]) 92 | self.vision_gate_proj = nn.ModuleDict([ 93 | (str(layer_id), gate_proj) 94 | for layer_id, gate_proj in zip(vision_layer_range, gate_proj_list) 95 | ]) 96 | 97 | def mlp_forward(self, hidden_states, **kw_args): 98 | mixin_self = self 99 | self = self.transformer.layers[kw_args['layer_id']].mlp 100 | if "vision_expert_mask" in kw_args: 101 | vision_expert_mask = kw_args['vision_expert_mask'] 102 | else: 103 | vision_expert_mask = None 104 | 105 | layer_id_key = str(int(kw_args['layer_id'])) 106 | 107 | if kw_args['layer_id'] in mixin_self.vision_layer_range and (vision_expert_mask is not None) and vision_expert_mask.any(): 108 | vision_dense_h_to_4h = mixin_self.vision_dense_h_to_4h_list[layer_id_key] 109 | vision_dense_4h_to_h = mixin_self.vision_dense_4h_to_h_list[layer_id_key] 110 | vision_gate_proj = mixin_self.vision_gate_proj[layer_id_key] 111 | output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device) 112 | 113 | language_hidden_state = hidden_states[~vision_expert_mask.bool()] 114 | language_intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](language_hidden_state)) * self.dense_h_to_4h(language_hidden_state) 115 | output[~vision_expert_mask.bool()] = self.dense_4h_to_h(language_intermediate_parallel) # language_output 116 | 117 | vision_hidden_state = hidden_states[vision_expert_mask.bool()] 118 | vision_intermediate_parallel = vision_dense_h_to_4h(vision_hidden_state) 119 | gate_output = vision_gate_proj(vision_hidden_state) 120 | 121 | vision_intermediate_parallel *= self.activation_func(gate_output) 122 | output[vision_expert_mask.bool()] = vision_dense_4h_to_h(vision_intermediate_parallel) # vision_output 123 | else: 124 | intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](hidden_states)) * self.dense_h_to_4h(hidden_states) 125 | output = self.dense_4h_to_h(intermediate_parallel) 126 | 127 | return output.contiguous() 128 | 129 | def copy_param(self): 130 | with torch.no_grad(): 131 | for i in self.vision_layer_range: 132 | self.vision_gate_proj[str(i)].weight.data.copy_(self.gate_proj[i].weight.data) 133 | self.vision_dense_4h_to_h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_4h_to_h.weight.data) 134 | self.vision_dense_h_to_4h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_h_to_4h.weight.data) 135 | 136 | from 
sat.mpu import get_model_parallel_world_size 137 | from sat.mpu.utils import divide 138 | from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding 139 | 140 | class LlamaVisionExpertAttnMixin(BaseMixin): 141 | def __init__(self, hidden_size, num_heads, num_layers=28, num_vision_layers=0, use_vision_expert=True, vision_layer_range=None, 142 | params_dtype=torch.float, device=torch.device('cpu')): 143 | super().__init__() 144 | 145 | world_size = get_model_parallel_world_size() 146 | self.hidden_size = hidden_size 147 | self.num_attention_heads = num_heads 148 | self.hidden_size_per_attention_head = divide(hidden_size, num_heads) 149 | self.num_attention_heads_per_partition = divide(num_heads, world_size) 150 | self.inner_hidden_size = num_heads * self.hidden_size_per_attention_head 151 | 152 | self.rotary_emb = FastRotaryEmbedding( 153 | hidden_size // num_heads, pos_idx_in_fp32=False 154 | ) 155 | 156 | self.num_vision_layers = num_vision_layers 157 | self.num_layers = num_layers 158 | if vision_layer_range is None: 159 | vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))] 160 | self.vision_layer_range = vision_layer_range 161 | 162 | self.use_vision_expert = use_vision_expert 163 | # Trainable vision expert parameters 164 | 165 | if self.use_vision_expert: 166 | vision_query_key_value_list = [] 167 | vision_dense_list = [] 168 | for i in vision_layer_range: 169 | vision_query_key_value = ColumnParallelLinear( 170 | hidden_size, 171 | 3 * hidden_size, 172 | stride=3, 173 | gather_output=False, 174 | init_method=None, 175 | bias=False, 176 | params_dtype=params_dtype, 177 | module=self, 178 | name="vision_query_key_value", 179 | skip_init=True, 180 | device=device 181 | ) 182 | 183 | vision_dense = RowParallelLinear( 184 | self.inner_hidden_size, 185 | hidden_size, 186 | input_is_parallel=True, 187 | init_method=None, 188 | bias=False, 189 | params_dtype=params_dtype, 190 | module=self, 191 | name="vision_dense", 192 | skip_init=True, 193 | device=device, 194 | final_bias=False 195 | ) 196 | 197 | vision_query_key_value_list.append(vision_query_key_value) 198 | vision_dense_list.append(vision_dense) 199 | 200 | self.vision_query_key_value_list = nn.ModuleDict([ 201 | (str(layer_id), vision_query_key_value) 202 | for layer_id, vision_query_key_value in zip(vision_layer_range, vision_query_key_value_list) 203 | ]) 204 | self.vision_dense_list = nn.ModuleDict([ 205 | (str(layer_id), vision_dense) 206 | for layer_id, vision_dense in zip(vision_layer_range, vision_dense_list) 207 | ]) 208 | 209 | def attention_forward(self, hidden_states, mask, **kw_args): 210 | mixin_self = self 211 | self = self.transformer.layers[kw_args['layer_id']].attention 212 | attention_fn = attention_fn_default 213 | if 'attention_fn' in self.hooks: 214 | attention_fn = self.hooks['attention_fn'] 215 | if "vision_expert_mask" in kw_args: 216 | vision_expert_mask = kw_args['vision_expert_mask'] 217 | else: 218 | vision_expert_mask = None 219 | 220 | layer_id_key = str(int(kw_args['layer_id'])) 221 | if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and ( 222 | vision_expert_mask is not None) and vision_expert_mask.any(): 223 | shape = list(hidden_states.shape) 224 | parallel_size = mpu.get_model_parallel_world_size() 225 | shape[-1] = shape[-1] * 3 // parallel_size 226 | vision_query_key_value = mixin_self.vision_query_key_value_list[layer_id_key] 227 | mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, 
device=hidden_states.device) 228 | language_hidden_states = hidden_states[~vision_expert_mask.bool()] 229 | vision_hidden_states = hidden_states[vision_expert_mask.bool()] 230 | mixed_raw_layer[~vision_expert_mask.bool()] = self.query_key_value( 231 | language_hidden_states) # language_mixed_raw_layer 232 | mixed_raw_layer[vision_expert_mask.bool()] = vision_query_key_value( 233 | vision_hidden_states) # vision_mixed_raw_layer 234 | else: 235 | mixed_raw_layer = self.query_key_value(hidden_states) 236 | 237 | (mixed_query_layer, 238 | mixed_key_layer, 239 | mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3) 240 | 241 | dropout_fn = self.attention_dropout if self.training else None 242 | 243 | query_layer = self._transpose_for_scores(mixed_query_layer) 244 | key_layer = self._transpose_for_scores(mixed_key_layer) 245 | value_layer = self._transpose_for_scores(mixed_value_layer) 246 | 247 | query_layer, key_layer = mixin_self.rotary_emb(query_layer,key_layer, kw_args['position_ids'], max_seqlen=kw_args['position_ids'].max()+1, layer_id=kw_args['layer_id']) 248 | 249 | context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args) 250 | 251 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous() 252 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,) 253 | context_layer = context_layer.view(*new_context_layer_shape) 254 | 255 | if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and ( 256 | vision_expert_mask is not None) and vision_expert_mask.any(): 257 | vision_dense = mixin_self.vision_dense_list[layer_id_key] 258 | parallel_size = mpu.get_model_parallel_world_size() 259 | target_shape = context_layer.shape[:-1] + (context_layer.shape[-1] * parallel_size,) 260 | output = torch.empty(target_shape, dtype=hidden_states.dtype, device=hidden_states.device) 261 | output[~vision_expert_mask.bool()] = self.dense(context_layer[~vision_expert_mask.bool()]) # language 262 | output[vision_expert_mask.bool()] = vision_dense(context_layer[vision_expert_mask.bool()]) # vision 263 | else: 264 | output = self.dense(context_layer) 265 | 266 | if self.training: 267 | output = self.output_dropout(output) 268 | return output.contiguous() 269 | 270 | def copy_param(self): 271 | with torch.no_grad(): 272 | for i in self.vision_layer_range: 273 | self.vision_query_key_value_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.query_key_value.weight.data) 274 | self.vision_dense_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.dense.weight.data) -------------------------------------------------------------------------------- /utils/split_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | 4 | def find_all_files(path, suffix=".jpg"): 5 | target_files = [] 6 | for cur_dir, _, files in os.walk(path, followlinks=True): 7 | for f in files: 8 | if f.endswith(suffix): 9 | target_files.append(os.path.join(cur_dir, f)) 10 | print(f'find {len(target_files)} files...') 11 | return target_files 12 | 13 | all_files = find_all_files('archive') 14 | os.makedirs("archive_split", exist_ok=True) 15 | os.makedirs("archive_split/train", exist_ok=True) 16 | os.makedirs("archive_split/valid", exist_ok=True) 17 | os.makedirs("archive_split/test", exist_ok=True) 18 | 19 | import random 20 | random.seed(2023) 21 | random.shuffle(all_files) 22 | train = all_files[:8000] 23 | valid = 
all_files[8000:8000+500] 24 | test = all_files[8000+500:8000+500+1500] 25 | 26 | print("building train") 27 | for file in train: 28 | shutil.move(file, os.path.join("archive_split/train", file.split("/")[-1])) 29 | print("building valid") 30 | for file in valid: 31 | shutil.move(file, os.path.join("archive_split/valid", file.split("/")[-1])) 32 | print("building test") 33 | for file in test: 34 | shutil.move(file, os.path.join("archive_split/test", file.split("/")[-1])) 35 | print("done") -------------------------------------------------------------------------------- /utils/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .chat import chat 2 | from .language import llama2_tokenizer, llama2_text_processor, llama2_text_processor_inference 3 | from .vision import get_image_processor 4 | from .grounding_parser import parse_response 5 | from .dataset import ItemDataset -------------------------------------------------------------------------------- /utils/utils/chat.py: -------------------------------------------------------------------------------- 1 | # -*- encoding: utf-8 -*- 2 | ''' 3 | @File : chat.py 4 | @Time : 2023/05/08 19:10:08 5 | @Author : Ming Ding 6 | @Contact : dm18@mails.tsinghua.edu.cn 7 | ''' 8 | 9 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any 10 | import requests 11 | from PIL import Image 12 | from io import BytesIO 13 | 14 | import torch 15 | from sat.generation.autoregressive_sampling import filling_sequence, stream_filling_sequence, get_masks_and_position_ids_default 16 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy 17 | from sat.mpu import get_model_parallel_rank 18 | 19 | def process_image(image_path, img_processor, cross_img_processor, image): 20 | if image is None: 21 | if image_path.startswith("http"): 22 | response = requests.get(image_path, timeout=10) 23 | image = Image.open(BytesIO(response.content)) 24 | else: 25 | image = Image.open(image_path) 26 | 27 | if image is not None and isinstance(image, Image.Image): 28 | pil_img = image.convert('RGB') 29 | img_dict = img_processor(pil_img) 30 | cross_img_dict = cross_img_processor(pil_img) if cross_img_processor is not None else {} 31 | ret = (img_dict, pil_img, cross_img_dict) 32 | else: 33 | ret = image 34 | return ret 35 | 36 | def chat(image_path, model, text_processor, img_processor, 37 | query: str, history: List[Tuple[str, str]] = None, cross_img_processor=None, image: Image = None, 38 | max_length: int = 4096, top_p=0.95, top_k=5, temperature=0.95, repetition_penalty=1.0, 39 | invalid_slices=[], no_prompt=False, args=None 40 | ): 41 | if image is None: 42 | assert image_path is not None 43 | if not history: 44 | history = [] 45 | 46 | if no_prompt: 47 | query = '' 48 | prompt = text_processor.history_to_prompt(query, history) 49 | 50 | (torch_image, pil_img, cross_image) = process_image(image_path, img_processor, cross_img_processor, image) 51 | 52 | if torch_image is not None: 53 | for k in torch_image: 54 | if type(torch_image[k]) is torch.Tensor and torch_image[k].dtype is not torch.int and torch_image[k].dtype is not torch.long: 55 | torch_image[k] = torch_image[k].to(torch.bfloat16 if args.bf16 else torch.float16) 56 | if type(torch_image[k]) is torch.Tensor: 57 | torch_image[k] = torch_image[k].to(next(model.parameters()).device) 58 | 59 | if cross_image is not None: 60 | for k in cross_image: 61 | if type(cross_image[k]) is torch.Tensor and cross_image[k].dtype is not torch.int and 
cross_image[k].dtype is not torch.long: 62 | cross_image[k] = cross_image[k].to(torch.bfloat16 if args.bf16 else torch.float16) 63 | if type(cross_image[k]) is torch.Tensor: 64 | cross_image[k] = cross_image[k].to(next(model.parameters()).device) 65 | 66 | inputs_dic = text_processor(prompt) 67 | for k in inputs_dic: 68 | if type(inputs_dic[k]) is torch.Tensor and inputs_dic[k].dtype is not torch.int and inputs_dic[k].dtype is not torch.long: 69 | inputs_dic[k] = inputs_dic[k].to(torch.bfloat16 if args.bf16 else torch.float16) 70 | if type(inputs_dic[k]) is torch.Tensor: 71 | inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).device) 72 | input_ids = inputs_dic['input_ids'].to(model.parameters().__next__().device)[0] 73 | 74 | if max_length-len(input_ids) <= 1: 75 | response = "The prompt exceeds the context length limit, please try again." 76 | return response, history, (torch_image, pil_img) 77 | 78 | seq = torch.cat( 79 | [input_ids, torch.tensor([-1]*(max_length-len(input_ids)), device=input_ids.device)], dim=0 80 | ) 81 | strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id], 82 | invalid_slices=invalid_slices, repetition_penalty=repetition_penalty) 83 | # use beam search to get a better result 84 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id], 85 | # num_beams=5, consider_end=True, repetition_penalty=repetition_penalty) 86 | get_func = text_processor.get_func(input_ids, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default 87 | 88 | img_inputs = {'vision_'+k: v for k, v in torch_image.items()} 89 | if cross_image is not None: 90 | img_inputs = {**img_inputs, **{'cross_'+k:v for k,v in cross_image.items()}} 91 | inputs_dic.pop('input_ids') 92 | inputs = {**img_inputs, **inputs_dic} 93 | 94 | if args.stream_chat: 95 | filling_stream = stream_filling_sequence( 96 | model, seq, 97 | batch_size=1, 98 | get_masks_and_position_ids=get_func, 99 | strategy=strategy, 100 | **inputs 101 | ) 102 | if get_model_parallel_rank() == 0: 103 | if 'chinese' in args and not args.chinese: 104 | print("Model: ", end='') 105 | else: 106 | print("模型:", end='') 107 | offset = len(text_processor.tokenizer.decode(input_ids)) 108 | for tokens, mems in filling_stream: 109 | torch.cuda.empty_cache() 110 | tmp_response = text_processor.tokenizer.decode(tokens[0]) 111 | if tmp_response[-1] != "�": 112 | if get_model_parallel_rank() == 0: 113 | tmp_response_offseted = tmp_response[offset:] 114 | if hasattr(text_processor, 'process_response'): 115 | tmp_response_offseted = text_processor.process_response(tmp_response_offseted) 116 | print(tmp_response_offseted, end='', flush=True) 117 | offset = len(tmp_response) 118 | if get_model_parallel_rank() == 0: 119 | print() 120 | output = strategy.finalize(tokens, mems)[0] 121 | 122 | response = text_processor.tokenizer.decode(output[0]) 123 | else: 124 | output = filling_sequence( 125 | model, seq, 126 | batch_size=1, 127 | get_masks_and_position_ids=get_func, 128 | strategy=strategy, 129 | **inputs 130 | )[0] # drop memory 131 | 132 | # --------------- 133 | # port from inference_glm.py, more general than chat mode 134 | # clip -1s and fill back generated things into seq 135 | if type(output) is not list: 136 | output_list = output.tolist() 137 | else: 138 | output_list = output 139 | 140 | response = text_processor.tokenizer.decode(output_list[0]) 141 | # 
print('original:', response) 142 | if hasattr(text_processor, 'process_response'): 143 | response = text_processor.process_response(response) 144 | response = response.split(text_processor.sep)[-1].strip() 145 | if get_model_parallel_rank() == 0: 146 | from utils.utils.grounding_parser import parse_response 147 | parse_response(pil_img, response) 148 | history = history + [(query, response)] 149 | return response, history, (torch_image, pil_img, cross_image) 150 | -------------------------------------------------------------------------------- /utils/utils/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import logging 3 | import random 4 | import logging 5 | import jsonlines 6 | from io import BytesIO 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from sat.helpers import print_rank0 10 | 11 | def find_all_files(path, suffix=".jpg"): 12 | target_files = [] 13 | for cur_dir, _, files in os.walk(path, followlinks=True): 14 | for f in files: 15 | if f.endswith(suffix): 16 | target_files.append(os.path.join(cur_dir, f)) 17 | print_rank0(f'find {len(target_files)} files...') 18 | return target_files 19 | 20 | class ItemDataset(Dataset): 21 | def __init__(self, image_processor, text_processor, args, data_dirs, cross_image_processor=None, **kwargs): 22 | super().__init__() 23 | self.data = self.load_data(data_dirs) 24 | self.image_processor, self.text_processor, self.cross_image_processor = image_processor, text_processor, cross_image_processor 25 | 26 | def process_img(self, img): 27 | img_dict = {'vision': self.image_processor(img)} 28 | if self.cross_image_processor: 29 | img_dict.update({'cross': self.cross_image_processor(img)}) 30 | return img_dict 31 | 32 | def process_text(self, answer, prompt): 33 | return self.text_processor(answer, prompt) 34 | 35 | def load_data(self, data_dir): 36 | all_files = find_all_files(data_dir, suffix=".jpg") 37 | print_rank0(f"find {len(all_files)} samples in all...") 38 | return all_files 39 | 40 | def __len__(self): 41 | return len(self.data) 42 | 43 | def __getitem__(self, index): 44 | data = self.data[index] 45 | # img 46 | try: 47 | img = Image.open(data).convert('RGB') 48 | except Exception as e: 49 | print_rank0(e, level=logging.WARNING) 50 | return {} 51 | img_dict = self.process_img(img) 52 | # text 53 | label = data.split('/')[-1].split('.')[0] 54 | uni_key = label 55 | text_dict = self.process_text(label, "CAPTCHA:") 56 | if text_dict is None: 57 | print_rank0(f"Process text failed. 
Please check the max_target_length & max_source_length.\n The data is {data}", level=logging.WARNING) 58 | return {} 59 | # other attr 60 | ret = {**img_dict, **text_dict, "question_id": uni_key} 61 | return ret -------------------------------------------------------------------------------- /utils/utils/grounding_parser.py: -------------------------------------------------------------------------------- 1 | import seaborn as sns 2 | from PIL import Image, ImageDraw, ImageFont 3 | import matplotlib.font_manager 4 | import spacy 5 | import re 6 | 7 | nlp = spacy.load("en_core_web_sm") 8 | 9 | def draw_boxes(image, boxes, texts, output_fn='output.png'): 10 | box_width = 5 11 | color_palette = sns.color_palette("husl", len(boxes)) 12 | colors = [(int(r*255), int(g*255), int(b*255)) for r, g, b in color_palette] 13 | 14 | width, height = image.size 15 | absolute_boxes = [[(int(box[0] * width), int(box[1] * height), int(box[2] * width), int(box[3] * height)) for box in b] for b in boxes] 16 | 17 | overlay = Image.new('RGBA', image.size, (255, 255, 255, 0)) 18 | draw = ImageDraw.Draw(overlay) 19 | font_path = sorted(matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf'))[0] 20 | font = ImageFont.truetype(font_path, size=26) 21 | 22 | for box, text, color in zip(absolute_boxes, texts, colors): 23 | for b in box: 24 | draw.rectangle(b, outline=color, width=box_width) 25 | if not text: 26 | continue 27 | splited_text = text.split('\n') 28 | num_lines = len(splited_text) 29 | text_width, text_height = font.getbbox(splited_text[0])[-2:] 30 | y_start = b[3] - text_height * num_lines - box_width 31 | if b[2] - b[0] < 100 or b[3] - b[1] < 100: 32 | y_start = b[3] 33 | for i, line in enumerate(splited_text): 34 | text_width, text_height = font.getbbox(line)[-2:] 35 | x = b[0] + box_width 36 | y = y_start + text_height * i 37 | draw.rectangle([x, y, x+text_width, y+text_height], fill=(128, 128, 128, 160)) 38 | draw.text((x, y), line, font=font, fill=(255, 255, 255)) 39 | img_with_overlay = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB') 40 | img_with_overlay.save(output_fn) 41 | 42 | def boxstr_to_boxes(box_str): 43 | boxes = [[int(y)/1000 for y in x.split(',')] for x in box_str.split(';') if x.replace(',', '').isdigit()] 44 | return boxes 45 | 46 | def text_to_dict(text): 47 | doc = nlp(text) 48 | 49 | box_matches = list(re.finditer(r'\[\[([^\]]+)\]\]', text)) 50 | box_positions = [match.start() for match in box_matches] 51 | 52 | noun_phrases = [] 53 | boxes = [] 54 | 55 | for match, box_position in zip(box_matches, box_positions): 56 | nearest_np_start = max([0] + [chunk.start_char for chunk in doc.noun_chunks if chunk.end_char <= box_position]) 57 | noun_phrase = text[nearest_np_start:box_position].strip() 58 | if noun_phrase and noun_phrase[-1] == '?': 59 | noun_phrase = text[:box_position].strip() 60 | box_string = match.group(1) 61 | 62 | noun_phrases.append(noun_phrase) 63 | boxes.append(boxstr_to_boxes(box_string)) 64 | 65 | pairs = [] 66 | for noun_phrase, box_string in zip(noun_phrases, boxes): 67 | pairs.append((noun_phrase.lower(), box_string)) 68 | return dict(pairs) 69 | 70 | def parse_response(img, response, output_fn='output.png'): 71 | img = img.convert('RGB') 72 | width, height = img.size 73 | ratio = min(1920 / width, 1080 / height) 74 | new_width = int(width * ratio) 75 | new_height = int(height * ratio) 76 | new_img = img.resize((new_width, new_height), Image.LANCZOS) 77 | pattern = r"\[\[(.*?)\]\]" 78 | positions = re.findall(pattern, response) 
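    # Annotation (not part of the original file): grounded responses embed boxes as
    # "[[x0,y0,x1,y1;...]]" substrings with coordinates on a 0-1000 scale; text_to_dict and
    # boxstr_to_boxes below pair each box group with its nearest preceding noun phrase and
    # rescale the coordinates to the 0-1 range, and draw_boxes then maps them to pixels.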
79 | boxes = [[[int(y) for y in x.split(',')] for x in pos.split(';') if x.replace(',', '').isdigit()] for pos in positions] 80 | dic = text_to_dict(response) 81 | if not dic: 82 | texts = [] 83 | boxes = [] 84 | else: 85 | texts, boxes = zip(*dic.items()) 86 | draw_boxes(new_img, boxes, texts, output_fn=output_fn) -------------------------------------------------------------------------------- /utils/utils/language.py: -------------------------------------------------------------------------------- 1 | def base_history_to_prompt(self, query, history): 2 | prompt = '' + query 3 | return prompt 4 | 5 | def chat_history_to_prompt(self, query, history): 6 | prompt = " [INST] " 7 | for i, (old_query, response) in enumerate(history): 8 | prompt += old_query + " [/INST] " + response + " [INST] " 9 | prompt += query + " [/INST] " 10 | return prompt 11 | 12 | def vqa_history_to_prompt(self, query, history): 13 | # Only support single round chat in vqa mode 14 | prompt = "Question: " 15 | # for i, (old_query, response) in enumerate(history): 16 | # prompt += old_query + " Short answer: " + response + " Question: " 17 | prompt += query + " Short answer:" 18 | return prompt 19 | 20 | def chat_old_history_to_prompt(self, query, history): 21 | prompt = "Question: " 22 | for i, (old_query, response) in enumerate(history): 23 | prompt += old_query + " Answer: " + response + "\nQuestion: " 24 | prompt += query + " Answer:" 25 | return prompt 26 | 27 | _history_to_prompt = { 28 | "base": base_history_to_prompt, 29 | "chat": chat_history_to_prompt, 30 | "vqa": vqa_history_to_prompt, 31 | "chat_old": chat_old_history_to_prompt, # for cogvlm-v1.1 32 | } 33 | 34 | from transformers import LlamaTokenizer 35 | 36 | def llama2_tokenizer(tokenizer_path, signal_type="base"): 37 | tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path) 38 | if tokenizer.pad_token_id is None: 39 | tokenizer.pad_token_id = 32000 40 | tokenizer.boi = "[IMG]" 41 | tokenizer.eoi = "[/IMG]" 42 | assert signal_type in ["base", "chat", "vqa", "chat_old"] 43 | tokenizer.signal_type = signal_type 44 | return tokenizer 45 | 46 | import re 47 | import numpy as np 48 | import torch 49 | 50 | class llama2_text_processor: 51 | def __init__(self, tokenizer, max_target_length=2048, image_length=257, model=None): 52 | self.tokenizer = tokenizer 53 | self.max_target_length = max_target_length 54 | self.image_length = image_length 55 | 56 | def __call__(self, caption, prompt=""): 57 | if '' not in prompt: 58 | prompt = self.replace_tags_with_empty(prompt) 59 | # caption = self.replace_tags_with_empty(caption) 60 | history = [] 61 | prompt = self.history_to_prompt(prompt, history) 62 | 63 | input_ids = [self.tokenizer.bos_token_id] 64 | 65 | prompt_splits = prompt.split('') 66 | caption_splits = caption.split('') 67 | if len(prompt_splits) > 0: 68 | input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False)) 69 | for tokens in prompt_splits[1:]: 70 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False) 71 | input_ids.extend(tokens_with_img) 72 | context_length = len(input_ids) + (len(prompt_splits)-1) * (self.image_length + 1) 73 | if context_length > self.max_target_length - 10: 74 | return None 75 | if len(caption_splits) > 0: 76 | input_ids.extend(self.tokenizer.encode(caption_splits[0], add_special_tokens=False)) 77 | for tokens in caption_splits[1:]: 78 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False) 79 | input_ids.extend(tokens_with_img) 80 | 81 | if 
len(input_ids) > self.max_target_length - self.image_length - 5: 82 | input_ids = input_ids[:self.max_target_length - self.image_length - 5] 83 | 84 | input_ids += [self.tokenizer.eos_token_id] 85 | 86 | while -100 in input_ids: 87 | img_idx = input_ids.index(-100) 88 | input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx+1:] 89 | 90 | image_position = [] 91 | while -1 in input_ids: 92 | img_idx = input_ids.index(-1) 93 | input_ids[img_idx] = 0 94 | image_position.append(img_idx) 95 | 96 | image_embed_mask = [0] * len(input_ids) 97 | vision_expert_mask = [0] * len(input_ids) 98 | image_rope_mask = [0] * len(input_ids) 99 | for idx in image_position: 100 | image_embed_mask[idx-self.image_length-1: idx+1] = [1] * (self.image_length + 2) 101 | vision_expert_mask[idx-self.image_length-1: idx] = [1] * (self.image_length + 1) 102 | image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length 103 | attention_mask = [1] * len(input_ids) 104 | labels = [-100] * context_length + input_ids[context_length:] 105 | 106 | pad_len = self.max_target_length - len(input_ids) 107 | input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len 108 | attention_mask = attention_mask + [1] * pad_len 109 | vision_expert_mask = vision_expert_mask + [0] * pad_len 110 | image_embed_mask = image_embed_mask + [0] * pad_len 111 | image_rope_mask = image_rope_mask + [0] * pad_len 112 | np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0)) 113 | labels = labels + [-100] * pad_len 114 | 115 | for idx in image_position: 116 | labels[idx-self.image_length-1: idx+1] = [-100] * (self.image_length + 2) 117 | 118 | position_ids = [] 119 | pid = -1 120 | for i in range(len(input_ids)): 121 | if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]): 122 | pid += 1 123 | position_ids.append(pid) 124 | 125 | input_ids = torch.tensor(input_ids).unsqueeze(0) 126 | labels = torch.tensor(labels).unsqueeze(0) 127 | attention_mask = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0) 128 | image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0) 129 | vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0) 130 | image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0) 131 | position_ids = torch.tensor(position_ids).unsqueeze(0) 132 | context_length = torch.tensor(context_length).unsqueeze(0).long() 133 | return {'input_ids': input_ids, 'labels': labels, 'position_ids': position_ids, 'attention_mask': attention_mask, 'image_embed_mask': image_embed_mask, 134 | 'context_length': context_length, 'image_position': image_position, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask 135 | } 136 | 137 | def history_to_prompt(self, query, history): 138 | return _history_to_prompt[self.tokenizer.signal_type](self, query, history) 139 | 140 | def replace_tags_with_empty(self, text): 141 | return re.sub('|||', '', text) 142 | 143 | from functools import partial 144 | def get_masks_and_position_ids(seq, image_logits_mask): 145 | tokens = seq.unsqueeze(0) 146 | 147 | attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device) 148 | attention_mask.tril_() 149 | attention_mask.unsqueeze_(1) 150 | 151 | position_ids = [] 152 | pid = -1 153 | for i in range(len(image_logits_mask[0])): 154 | if image_logits_mask[0][i] == 0 or (i > 0 and image_logits_mask[0][i] != image_logits_mask[0][i - 1]): 155 | pid += 1 156 | position_ids.append(pid) 157 | for i in 
range(tokens.shape[1]-image_logits_mask.shape[1]): 158 | pid += 1 159 | position_ids.append(pid) 160 | position_ids = torch.tensor(position_ids, dtype=torch.long, device=tokens.device) 161 | position_ids = position_ids.unsqueeze(0) 162 | 163 | return tokens, attention_mask, position_ids 164 | 165 | class llama2_text_processor_inference: 166 | def __init__(self, tokenizer, max_target_length=1024, image_length=257, model=None, no_prompt=False, english=True): 167 | self.tokenizer = tokenizer 168 | self.max_target_length = max_target_length 169 | self.image_length = image_length 170 | if self.tokenizer.signal_type == "chat": 171 | self.sep = "[/INST]" 172 | elif self.tokenizer.signal_type == "vqa": 173 | self.sep = " Short answer:" 174 | elif self.tokenizer.signal_type == "chat_old": 175 | self.sep = " Answer:" 176 | else: 177 | self.sep = "" 178 | 179 | self.invalid_slices = [] 180 | self.no_eoi = True 181 | 182 | def __call__(self, prompt=""): 183 | if '' not in prompt: 184 | prompt = self.replace_tags_with_empty(prompt) 185 | # caption = self.replace_tags_with_empty(caption) 186 | history = [] 187 | prompt = self.history_to_prompt(prompt, history) 188 | 189 | input_ids = [self.tokenizer.bos_token_id] 190 | 191 | prompt_splits = prompt.split('') 192 | if len(prompt_splits) > 0: 193 | input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False)) 194 | for tokens in prompt_splits[1:]: 195 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False) 196 | input_ids.extend(tokens_with_img) 197 | 198 | while -100 in input_ids: 199 | img_idx = input_ids.index(-100) 200 | input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx + 1:] 201 | 202 | image_position = [] 203 | while -1 in input_ids: 204 | img_idx = input_ids.index(-1) 205 | input_ids[img_idx] = 0 206 | image_position.append(img_idx) 207 | 208 | image_embed_mask = [0] * len(input_ids) 209 | vision_expert_mask = [0] * len(input_ids) 210 | image_rope_mask = [0] * len(input_ids) 211 | for idx in image_position: 212 | image_embed_mask[idx - self.image_length - 1: idx + 1] = [1] * (self.image_length + 2) 213 | vision_expert_mask[idx - self.image_length - 1: idx] = [1] * (self.image_length + 1) 214 | image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length 215 | 216 | input_ids = torch.tensor(input_ids).unsqueeze(0) 217 | image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0) 218 | vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0) 219 | image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0) 220 | return {'input_ids': input_ids, 'image_embed_mask': image_embed_mask, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask} 221 | 222 | def history_to_prompt(self, query, history): 223 | return _history_to_prompt[self.tokenizer.signal_type](self, query, history) 224 | 225 | def replace_tags_with_empty(self, text): 226 | return re.sub('|||', '', text) 227 | 228 | def process_response(self, response): 229 | return response.replace('', '') 230 | 231 | def get_func(self, inputs, **kwargs): 232 | get_func = partial(get_masks_and_position_ids, image_logits_mask=kwargs['image_rope_mask']) 233 | return get_func -------------------------------------------------------------------------------- /utils/utils/vision.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | from torchvision.transforms.functional import InterpolationMode 3 | import 
torch 4 | 5 | class BlipImageEvalProcessor: 6 | def __init__(self, image_size=384, mean=None, std=None): 7 | super().__init__() 8 | if mean is None: 9 | mean = (0.48145466, 0.4578275, 0.40821073) 10 | if std is None: 11 | std = (0.26862954, 0.26130258, 0.27577711) 12 | 13 | self.normalize = transforms.Normalize(mean, std) 14 | 15 | self.transform = transforms.Compose( 16 | [ 17 | transforms.Resize( 18 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC 19 | ), 20 | transforms.ToTensor(), 21 | self.normalize, 22 | ] 23 | ) 24 | 25 | def __call__(self, item): 26 | return self.transform(item) 27 | 28 | from functools import partial 29 | 30 | def blip2_image_processor_func_with_inputs(image_processor, image): 31 | return {'image': image_processor(image).unsqueeze(0), 'input_ids': torch.zeros(1, 1, dtype=torch.long), 'position_ids': None, 'attention_mask': torch.ones(1, 1, dtype=torch.long)} 32 | 33 | def get_image_processor(image_size): 34 | return partial(blip2_image_processor_func_with_inputs, BlipImageEvalProcessor(image_size)) --------------------------------------------------------------------------------
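A minimal sketch of how the helpers in utils/utils fit together (not part of the repository): it assumes a SAT CogVLM/CogAgent checkpoint has already been loaded elsewhere as model, with a matching args namespace, and the tokenizer path and image size below are illustrative placeholders.

from argparse import Namespace
from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor

# model = ...  # a SAT CogVLM/CogAgent model, loaded elsewhere (loading is out of scope here)
args = Namespace(bf16=True, stream_chat=False)  # minimal flags that chat() reads

tokenizer = llama2_tokenizer("path/to/vicuna-7b-v1.5", signal_type="chat")      # example path
text_processor = llama2_text_processor_inference(tokenizer, image_length=257)   # match the checkpoint's vision config
image_processor = get_image_processor(224)                                      # example input size

response, history, _ = chat(
    "assets/chat.png", model, text_processor, image_processor,
    query="Describe this image.", history=[], args=args,
)
print(response)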