├── .deepspeed_env
├── .github
│   ├── ISSUE_TEMPLATE
│   │   ├── bug_report.yaml
│   │   └── feature-request.yaml
│   └── PULL_REQUEST_TEMPLATE
│       └── pr_template.md
├── .gitignore
├── LICENSE
├── MODEL_LICENSE
├── README.md
├── README_zh.md
├── assets
│   ├── WECHAT.md
│   ├── chat-min.png
│   ├── chat.png
│   ├── cogagent_function.jpg
│   ├── cogagent_function_cn.jpg
│   ├── cogagent_main_demo.jpg
│   ├── compare-min.png
│   ├── compare.png
│   ├── llava-comparison-min.png
│   ├── method-min.png
│   ├── method.png
│   ├── metrics-min.png
│   ├── metrics.png
│   ├── pear_grounding.png
│   ├── web_demo-min.png
│   ├── web_demo.png
│   └── wechat.jpg
├── basic_demo
│   ├── cli_demo_hf.py
│   ├── cli_demo_sat.py
│   └── web_demo.py
├── composite_demo
│   ├── client.py
│   ├── conversation.py
│   ├── demo_agent_cogagent.py
│   ├── demo_chat_cogagent.py
│   ├── demo_chat_cogvlm.py
│   ├── main.py
│   └── utils.py
├── dataset.md
├── dataset_zh.md
├── finetune_demo
│   ├── evaluate_cogagent.sh
│   ├── evaluate_cogagent_demo.py
│   ├── evaluate_cogvlm.sh
│   ├── evaluate_cogvlm_demo.py
│   ├── finetune_cogagent_demo.py
│   ├── finetune_cogagent_lora.sh
│   ├── finetune_cogvlm_demo.py
│   ├── finetune_cogvlm_lora.sh
│   └── test_config_bf16.json
├── openai_demo
│   ├── demo.jpg
│   ├── openai_api.py
│   └── openai_api_request.py
├── requirements.txt
└── utils
    ├── __init__.py
    ├── merge_model.py
    ├── models
    │   ├── __init__.py
    │   ├── cogagent_model.py
    │   ├── cogvlm_model.py
    │   ├── eva_clip_L_hf.py
    │   ├── eva_clip_model.py
    │   └── mixin.py
    ├── split_dataset.py
    └── utils
        ├── __init__.py
        ├── chat.py
        ├── dataset.py
        ├── grounding_parser.py
        ├── language.py
        ├── template.py
        └── vision.py
/.deepspeed_env:
--------------------------------------------------------------------------------
1 | SAT_HOME=~/.sat_models
2 | LOCAL_WORLD_SIZE=8
--------------------------------------------------------------------------------
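The DeepSpeed launcher reads .deepspeed_env and exports each KEY=VALUE line into the environment of every worker process. The short Python sketch below is illustrative only (it is not a file in this repository); it shows how such variables are typically picked up at runtime, and the fallback values are assumptions.

# Illustrative sketch (not part of the repository): reading the two variables
# defined in .deepspeed_env after the launcher has exported them to each rank.
import os

sat_home = os.environ.get("SAT_HOME", os.path.expanduser("~/.sat_models"))  # SwissArmyTransformer checkpoint cache
local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "1"))             # processes (GPUs) per node, assumed default 1

print(f"SAT checkpoints are cached under: {sat_home}")
print(f"Local world size (GPUs per node): {local_world_size}")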
/.github/ISSUE_TEMPLATE/bug_report.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve CogVLM & CogAgent / 提交一个 Bug 问题报告来帮助我们改进 CogVLM 和 CogAgent
3 | body:
4 | - type: textarea
5 | id: system-info
6 | attributes:
7 | label: System Info / 系統信息
8 | description: Your operating environment / 您的运行环境信息
9 | placeholder: Includes Cuda version, Transformers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Transformers版本,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)...
10 | validations:
11 | required: true
12 |
13 | - type: textarea
14 | id: who-can-help
15 | attributes:
16 | label: Who can help? / 谁可以帮助到您?
17 | description: |
18 | Your issue will be replied to more quickly if you can figure out the right person to tag with @
19 | All issues are read by one of the maintainers, so if you don't know who to tag, just leave this blank and our maintainer will ping the right person.
20 |
21 | Please tag fewer than 3 people.
22 |
23 | 如果您能找到合适的标签 @,您的问题会更快得到回复。
24 | 所有问题都会由我们的维护者阅读,如果您不知道该标记谁,只需留空,我们的维护人员会找到合适的开发组成员来解决问题。
25 |
26 | 标记的人数应该不超过 1 个人。
27 |
28 | Related demo leader / 相关demo负责人 :
29 | - finetune_demo: @1049451037
30 | - composite_demo: @zR
31 | - openai_demo: @zR
32 |
33 |
34 | If it's not a bug in these three subsections, you may not specify the helper. Our maintainer will find the right person in the development group to solve the problem.
35 |
36 | 如果不是这三个子版块的bug,您可以不指明帮助者,我们的维护人员会找到合适的开发组成员来解决问题。
37 |
38 | placeholder: "@Username ..."
39 |
40 | - type: checkboxes
41 | id: information-scripts-examples
42 | attributes:
43 | label: Information / 问题信息
44 | description: 'The problem arises when using: / 问题出现在'
45 | options:
46 | - label: "The official example scripts / 官方的示例脚本"
47 | - label: "My own modified scripts / 我自己修改的脚本和任务"
48 |
49 | - type: textarea
50 | id: reproduction
51 | validations:
52 | required: true
53 | attributes:
54 | label: Reproduction / 复现过程
55 | description: |
56 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit.
57 | If you have code snippets, error messages, stack traces, please provide them here as well.
58 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
59 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code.
60 |
61 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。
62 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。
63 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
64 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。
65 | placeholder: |
66 | Steps to reproduce the behavior/复现Bug的步骤:
67 |
68 | 1.
69 | 2.
70 | 3.
71 |
72 | - type: textarea
73 | id: expected-behavior
74 | validations:
75 | required: true
76 | attributes:
77 | label: Expected behavior / 期待表现
78 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。"
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yaml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a request for a new CogVLM & CogAgent feature / 提交一个新的 CogVLM 和 CogAgent 的功能建议
3 | labels: [ "feature" ]
4 | body:
5 | - type: textarea
6 | id: feature-request
7 | validations:
8 | required: true
9 | attributes:
10 | label: Feature request / 功能建议
11 | description: |
12 | A brief description of the functional proposal. Links to corresponding papers and code are desirable.
13 | 对功能建议的简述。最好提供对应的论文和代码链接
14 |
15 | - type: textarea
16 | id: motivation
17 | validations:
18 | required: true
19 | attributes:
20 | label: Motivation / 动机
21 | description: |
22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here.
23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。
24 |
25 | - type: textarea
26 | id: contribution
27 | validations:
28 | required: true
29 | attributes:
30 | label: Your contribution / 您的贡献
31 | description: |
32 |
33 | Your PR link or any other link you can help with.
34 | 您的PR链接或者其他您能提供帮助的链接。
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE/pr_template.md:
--------------------------------------------------------------------------------
1 | # Raise valuable PR / 提出有价值的PR
2 |
3 | ## Caution/ 注意事项:
4 | Users should keep the following points in mind when submitting PRs:
5 |
6 | 1. The proposed PR should be about this project.
7 | 2. The proposed PR should be focused; if there are multiple ideas and optimizations, they should be split into separate PRs.
8 |
9 | 用户在提交PR时候应该注意以下几点:
10 |
11 | 1. 提出的PR应该是关于本项目的。
12 | 2. 提出的PR应该具有针对性,如果具有多个不同的想法和优化方案,应该分配到不同的PR中。
13 |
14 | ## 不应该提出的PR / PRs that should not be proposed
15 |
16 | If a developer proposes a PR that falls into any of the following categories, it may be closed or rejected.
17 | 
18 | 1. PRs that do not describe the improvement plan.
19 | 2. PRs that combine multiple issues of different types.
20 | 3. PRs that largely duplicate already existing PRs.
21 |
22 | 如果开发者提出关于以下方面的PR,则可能会被直接关闭或拒绝通过。
23 |
24 | 1. 没有说明改进方案的。
25 | 2. 多个不同类型的问题合并在一个PR中的。
26 | 3. 提出的PR与已经存在的PR高度重复的。
27 |
28 |
29 | # Check your PR / 检查您的PR
30 | - [ ] Have you read the Contributor Guidelines, Pull Request section? / 您是否阅读了贡献者指南、Pull Request 部分?
31 | - [ ] Has this been discussed/approved via a Github issue or forum? If so, add a link. / 是否通过 Github 问题或论坛讨论/批准过?如果是,请添加链接。
32 | - [ ] Did you make sure you updated the documentation with your changes? Here are the Documentation Guidelines, and here are the Documentation Formatting Tips. /您是否确保根据您的更改更新了文档?这里是文档指南,这里是文档格式化技巧。
33 | - [ ] Did you write new required tests? / 您是否编写了新的必要测试?
34 | - [ ] Does your PR address only one issue? / 您的PR是否仅针对一个问题?
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .hypothesis/
2 | __pycache__
3 | output.png
4 | fewshot-data/
5 | checkpoints/
6 | records.db
7 | server.py
8 | examples/*grounding.png
9 | archive*
10 | hostfile
11 | runs/
12 | *.idea/
13 | .DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 CogVLM team @ Zhipu AI
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | The CogVLM License
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means the CogVLM Model Team that distributes its Software.
6 |
7 | “Software” means the CogVLM model parameters made available under this license.
8 |
9 | 2. License Grant
10 |
11 | Under the terms and conditions of this license, the Licensor hereby grants you a non-exclusive, worldwide, non-transferable, non-sublicensable, revocable, royalty-free copyright license.
12 | This license permits you to use all open-source models in this repository free of charge for academic research. Users who wish to use the models for commercial purposes must register [here](https://open.bigmodel.cn/mla/form).
13 | Registered users may use the models for commercial activities free of charge, but must comply with all terms and conditions of this license.
14 | The license notice shall be included in all copies or substantial portions of the Software.
15 |
16 | 3. Restriction
17 |
18 | You will not use, copy, modify, merge, publish, distribute, reproduce, or create derivative works of the Software, in whole or in part, for any military, or illegal purposes.
19 |
20 | You will not use the Software for any act that may undermine China's national security and national unity, harm the public interest of society, or infringe upon the rights and interests of human beings.
21 |
22 | 4. Disclaimer
23 |
24 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25 |
26 | 5. Limitation of Liability
27 |
28 | EXCEPT TO THE EXTENT PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER BASED IN TORT, NEGLIGENCE, CONTRACT, LIABILITY, OR OTHERWISE WILL ANY LICENSOR BE LIABLE TO YOU FOR ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES, OR ANY OTHER COMMERCIAL LOSSES, EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
29 |
30 | 6. Dispute Resolution
31 |
32 | This license shall be governed and construed in accordance with the laws of People’s Republic of China. Any dispute arising from or in connection with this License shall be submitted to Haidian District People's Court in Beijing.
33 |
34 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at license@zhipuai.cn.
35 |
36 | 7. Llama2 and EVA-CLIP2 License
37 |
38 | For the CogVLM-17B version, the Llama2 license conditions (https://ai.meta.com/llama/license/) and the EVA license conditions (MIT, https://github.com/baaivision/EVA/blob/master/LICENSE) also apply to the model weights.
39 |
40 |
41 | 1. 定义
42 |
43 | “许可方”是指分发其软件的 CogVLM 模型团队。
44 |
45 | “软件”是指根据本许可提供的 CogVLM 模型参数。
46 |
47 | 2. 许可授予
48 |
49 | 根据本许可的条款和条件,许可方特此授予您非排他性、全球性、不可转让、不可再许可、可撤销、免版税的版权许可。
50 | 本许可允许您免费使用本仓库中的所有开源模型进行学术研究,对于希望将模型用于商业目的的用户,需在[这里](https://open.bigmodel.cn/mla/form)完成登记。
51 | 经过登记的用户可以免费使用本模型进行商业活动,但必须遵守本许可的所有条款和条件。
52 | 上述版权声明和本许可声明应包含在本软件的所有副本或重要部分中。
53 |
54 | 3.限制
55 |
56 | 您不得出于任何军事或非法目的使用、复制、修改、合并、发布、分发、复制或创建本软件的全部或部分衍生作品。
57 |
58 | 您不得利用本软件从事任何危害国家安全和国家统一、危害社会公共利益、侵犯人身权益的行为。
59 |
60 | 4.免责声明
61 |
62 | 本软件“按原样”提供,不提供任何明示或暗示的保证,包括但不限于对适销性、特定用途的适用性和非侵权性的保证。 在任何情况下,作者或版权持有人均不对任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权行为还是其他方面,由软件或软件的使用或其他交易引起、由软件引起或与之相关 软件。
63 |
64 | 5. 责任限制
65 |
66 | 除适用法律禁止的范围外,在任何情况下且根据任何法律理论,无论是基于侵权行为、疏忽、合同、责任或其他原因,任何许可方均不对您承担任何直接、间接、特殊、偶然、示范性、 或间接损害,或任何其他商业损失,即使许可人已被告知此类损害的可能性。
67 |
68 | 6.争议解决
69 |
70 | 本许可受中华人民共和国法律管辖并按其解释。 因本许可引起的或与本许可有关的任何争议应提交北京市海淀区人民法院。
71 |
72 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 license@zhipuai.cn 与我们联系。
73 |
74 | 7. Llama2 和 EVA-CLIP2 许可
75 |
76 | 针对 CogVLM-17B 版本, Llama2 许可条件 (https://ai.meta.com/llama/license/) 和 EVA 许可条件 (MIT, https://github.com/baaivision/EVA/blob/master/LICENSE) 同时适用于模型权重。
--------------------------------------------------------------------------------
/assets/WECHAT.md:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | 
4 | 扫码关注公众号,加入「ChatGLM交流群」
5 | Scan the QR code to follow the official account and join the "ChatGLM Discussion Group"
6 | 
7 | 
--------------------------------------------------------------------------------
/assets/chat-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/chat-min.png
--------------------------------------------------------------------------------
/assets/chat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/chat.png
--------------------------------------------------------------------------------
/assets/cogagent_function.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_function.jpg
--------------------------------------------------------------------------------
/assets/cogagent_function_cn.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_function_cn.jpg
--------------------------------------------------------------------------------
/assets/cogagent_main_demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/cogagent_main_demo.jpg
--------------------------------------------------------------------------------
/assets/compare-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/compare-min.png
--------------------------------------------------------------------------------
/assets/compare.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/compare.png
--------------------------------------------------------------------------------
/assets/llava-comparison-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/llava-comparison-min.png
--------------------------------------------------------------------------------
/assets/method-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/method-min.png
--------------------------------------------------------------------------------
/assets/method.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/method.png
--------------------------------------------------------------------------------
/assets/metrics-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/metrics-min.png
--------------------------------------------------------------------------------
/assets/metrics.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/metrics.png
--------------------------------------------------------------------------------
/assets/pear_grounding.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/pear_grounding.png
--------------------------------------------------------------------------------
/assets/web_demo-min.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/web_demo-min.png
--------------------------------------------------------------------------------
/assets/web_demo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/web_demo.png
--------------------------------------------------------------------------------
/assets/wechat.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/assets/wechat.jpg
--------------------------------------------------------------------------------
/basic_demo/cli_demo_hf.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a demo for using CogAgent and CogVLM from the command line (CLI).
3 | Make sure you have installed the vicuna-7b-v1.5 tokenizer (https://huggingface.co/lmsys/vicuna-7b-v1.5); the full checkpoint of the vicuna-7b-v1.5 LLM is not required.
4 | In this demo, we use the 'chat' template; you can replace it with others such as 'vqa'.
5 | We strongly suggest using a GPU with bfloat16 support; otherwise, inference will be slow.
6 | Note that only one picture can be processed per conversation, which means you cannot replace or insert another picture during the conversation.
7 | """
8 |
9 | import argparse
10 | import torch
11 |
12 | from PIL import Image
13 | from transformers import AutoModelForCausalLM, LlamaTokenizer
14 |
15 | parser = argparse.ArgumentParser()
16 | parser.add_argument("--quant", choices=[4], type=int, default=None, help='quantization bits')
17 | parser.add_argument("--from_pretrained", type=str, default="THUDM/cogagent-chat-hf", help='pretrained ckpt')
18 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
19 | parser.add_argument("--fp16", action="store_true")
20 | parser.add_argument("--bf16", action="store_true")
21 |
22 | args = parser.parse_args()
23 | MODEL_PATH = args.from_pretrained
24 | TOKENIZER_PATH = args.local_tokenizer
25 | DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
26 |
27 | tokenizer = LlamaTokenizer.from_pretrained(TOKENIZER_PATH)
28 | if args.bf16:
29 | torch_type = torch.bfloat16
30 | else:
31 | torch_type = torch.float16
32 |
33 | print("========Use torch type as:{} with device:{}========\n\n".format(torch_type, DEVICE))
34 |
35 | if args.quant:
36 | model = AutoModelForCausalLM.from_pretrained(
37 | MODEL_PATH,
38 | torch_dtype=torch_type,
39 | low_cpu_mem_usage=True,
40 | load_in_4bit=True,
41 | trust_remote_code=True
42 | ).eval()
43 | else:
44 | model = AutoModelForCausalLM.from_pretrained(
45 | MODEL_PATH,
46 | torch_dtype=torch_type,
47 | low_cpu_mem_usage=True,
48 | load_in_4bit=args.quant is not None,
49 | trust_remote_code=True
50 | ).to(DEVICE).eval()
51 |
52 | text_only_template = "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: {} ASSISTANT:"
53 |
54 | while True:
55 | image_path = input("image path >>>>> ")
56 | if image_path == '':
57 | print('You did not enter image path, the following will be a plain text conversation.')
58 | image = None
59 | text_only_first_query = True
60 | else:
61 | image = Image.open(image_path).convert('RGB')
62 |
63 | history = []
64 |
65 | while True:
66 | query = input("Human:")
67 | if query == "clear":
68 | break
69 |
70 | if image is None:
71 | if text_only_first_query:
72 | query = text_only_template.format(query)
73 | text_only_first_query = False
74 | else:
75 | old_prompt = ''
76 | for _, (old_query, response) in enumerate(history):
77 | old_prompt += old_query + " " + response + "\n"
78 | query = old_prompt + "USER: {} ASSISTANT:".format(query)
79 |
80 | if image is None:
81 | input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, template_version='base')
82 | else:
83 | input_by_model = model.build_conversation_input_ids(tokenizer, query=query, history=history, images=[image])
84 |
85 | inputs = {
86 | 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
87 | 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
88 | 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
89 | 'images': [[input_by_model['images'][0].to(DEVICE).to(torch_type)]] if image is not None else None,
90 | }
91 | if 'cross_images' in input_by_model and input_by_model['cross_images']:
92 | inputs['cross_images'] = [[input_by_model['cross_images'][0].to(DEVICE).to(torch_type)]]
93 |
94 | # add any transformers params here.
95 | gen_kwargs = {"max_length": 2048,
96 | "do_sample": False} # "temperature": 0.9
97 | with torch.no_grad():
98 | outputs = model.generate(**inputs, **gen_kwargs)
99 | outputs = outputs[:, inputs['input_ids'].shape[1]:]
100 | response = tokenizer.decode(outputs[0])
101 | response = response.split("</s>")[0]
102 | print("\nCog:", response)
103 | history.append((query, response))
104 |
--------------------------------------------------------------------------------
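The interactive loop above can be condensed into a single-turn call for quick testing. The sketch below is illustrative and not part of the repository; it mirrors the calls made in cli_demo_hf.py, assumes the demo's default model and tokenizer paths, and uses a placeholder image file named example.jpg.

# Minimal single-turn sketch based on cli_demo_hf.py (illustrative; 'example.jpg' is a placeholder path).
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("lmsys/vicuna-7b-v1.5")
model = AutoModelForCausalLM.from_pretrained(
    "THUDM/cogagent-chat-hf",
    torch_dtype=torch.bfloat16,       # switch to torch.float16 if bfloat16 is unsupported
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to("cuda").eval()

image = Image.open("example.jpg").convert("RGB")
query = "Describe this image."

# build_conversation_input_ids is provided by the remote-code model class, exactly as used in the demo above.
inp = model.build_conversation_input_ids(tokenizer, query=query, history=[], images=[image])
inputs = {
    "input_ids": inp["input_ids"].unsqueeze(0).to("cuda"),
    "token_type_ids": inp["token_type_ids"].unsqueeze(0).to("cuda"),
    "attention_mask": inp["attention_mask"].unsqueeze(0).to("cuda"),
    "images": [[inp["images"][0].to("cuda").to(torch.bfloat16)]],
}
if "cross_images" in inp and inp["cross_images"]:  # present for CogAgent, absent for CogVLM
    inputs["cross_images"] = [[inp["cross_images"][0].to("cuda").to(torch.bfloat16)]]

with torch.no_grad():
    outputs = model.generate(**inputs, max_length=2048, do_sample=False)
    outputs = outputs[:, inputs["input_ids"].shape[1]:]
print(tokenizer.decode(outputs[0]).split("</s>")[0])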
/basic_demo/cli_demo_sat.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import os, sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | import argparse
7 | from sat.model.mixins import CachedAutoregressiveMixin
8 | from sat.quantization.kernels import quantize
9 | from sat.model import AutoModel
10 |
11 |
12 | from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor
13 | from utils.models import CogAgentModel, CogVLMModel
14 |
15 | def main():
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
18 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
19 | parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling')
20 | parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
21 | parser.add_argument("--chinese", action='store_true', help='Chinese interface')
22 | parser.add_argument("--version", type=str, default="chat", choices=['chat', 'vqa', 'chat_old', 'base'], help='version of language process. if there is \"text_processor_version\" in model_config.json, this option will be overwritten')
23 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
24 |
25 | parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
26 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
27 | parser.add_argument("--fp16", action="store_true")
28 | parser.add_argument("--bf16", action="store_true")
29 | parser.add_argument("--stream_chat", action="store_true")
30 | args = parser.parse_args()
31 | rank = int(os.environ.get('RANK', 0))
32 | world_size = int(os.environ.get('WORLD_SIZE', 1))
33 | args = parser.parse_args()
34 |
35 | # load model
36 | model, model_args = AutoModel.from_pretrained(
37 | args.from_pretrained,
38 | args=argparse.Namespace(
39 | deepspeed=None,
40 | local_rank=rank,
41 | rank=rank,
42 | world_size=world_size,
43 | model_parallel_size=world_size,
44 | mode='inference',
45 | skip_init=True,
46 | use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
47 | device='cpu' if args.quant else 'cuda',
48 | **vars(args)
49 | ), overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {})
50 | model = model.eval()
51 | from sat.mpu import get_model_parallel_world_size
52 | assert world_size == get_model_parallel_world_size(), "world size must equal to model parallel size for cli_demo!"
53 |
54 | language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
55 | print("[Language processor version]:", language_processor_version)
56 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=language_processor_version)
57 | image_processor = get_image_processor(model_args.eva_args["image_size"][0])
58 | cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
59 |
60 | if args.quant:
61 | quantize(model, args.quant)
62 | if torch.cuda.is_available():
63 | model = model.cuda()
64 |
65 |
66 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
67 |
68 | text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length)
69 |
70 | if args.chinese:
71 | if rank == 0:
72 | print('欢迎使用 CogAgent-CLI ,输入图像URL或本地路径读图,继续输入内容对话,clear 重新开始,stop 终止程序')
73 | else:
74 | if rank == 0:
75 | print('Welcome to CogAgent-CLI. Enter an image URL or local file path to load an image. Continue inputting text to engage in a conversation. Type "clear" to start over, or "stop" to end the program.')
76 | with torch.no_grad():
77 | while True:
78 | history = None
79 | cache_image = None
80 | if args.chinese:
81 | if rank == 0:
82 | image_path = [input("请输入图像路径或URL: ")]
83 | else:
84 | image_path = [None]
85 | else:
86 | if rank == 0:
87 | image_path = [input("Please enter the image path or URL: ")]
88 | else:
89 | image_path = [None]
90 | if world_size > 1:
91 | torch.distributed.broadcast_object_list(image_path, 0)
92 | image_path = image_path[0]
93 | assert image_path is not None
94 |
95 | if image_path == 'stop':
96 | break
97 |
98 | if args.chinese:
99 | if rank == 0:
100 | query = [input("用户:")]
101 | else:
102 | query = [None]
103 | else:
104 | if rank == 0:
105 | query = [input("User: ")]
106 | else:
107 | query = [None]
108 | if world_size > 1:
109 | torch.distributed.broadcast_object_list(query, 0)
110 | query = query[0]
111 | assert query is not None
112 |
113 | while True:
114 | if query == "clear":
115 | break
116 | if query == "stop":
117 | sys.exit(0)
118 | try:
119 | response, history, cache_image = chat(
120 | image_path,
121 | model,
122 | text_processor_infer,
123 | image_processor,
124 | query,
125 | history=history,
126 | cross_img_processor=cross_image_processor,
127 | image=cache_image,
128 | max_length=args.max_length,
129 | top_p=args.top_p,
130 | temperature=args.temperature,
131 | top_k=args.top_k,
132 | invalid_slices=text_processor_infer.invalid_slices,
133 | args=args
134 | )
135 | except Exception as e:
136 | print(e)
137 | break
138 | if rank == 0 and not args.stream_chat:
139 | if args.chinese:
140 | print("模型:"+response)
141 | else:
142 | print("Model: "+response)
143 | image_path = None
144 | if args.chinese:
145 | if rank == 0:
146 | query = [input("用户:")]
147 | else:
148 | query = [None]
149 | else:
150 | if rank == 0:
151 | query = [input("User: ")]
152 | else:
153 | query = [None]
154 | if world_size > 1:
155 | torch.distributed.broadcast_object_list(query, 0)
156 | query = query[0]
157 | assert query is not None
158 |
159 |
160 | if __name__ == "__main__":
161 | main()
162 |
--------------------------------------------------------------------------------
/basic_demo/web_demo.py:
--------------------------------------------------------------------------------
1 | """
2 | This script is a simple web demo of the CogVLM and CogAgent models, designed for easy and quick demonstrations.
3 | For a more sophisticated user interface, users are encouraged to refer to the 'composite_demo',
4 | which is built with a more aesthetically pleasing Streamlit framework.
5 |
6 | Usage:
7 | - Use the interface to upload images and enter text prompts to interact with the models.
8 |
9 | Requirements:
10 | - Gradio (only 3.x is supported; 4.x is not) and other necessary Python dependencies must be installed.
11 | - Proper model checkpoints should be accessible as specified in the script.
12 |
13 | Note: This demo is ideal for a quick showcase of the CogVLM and CogAgent models. For a more comprehensive and interactive
14 | experience, refer to the 'composite_demo'.
15 | """
16 | import gradio as gr
17 | import os, sys
18 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19 |
20 | from PIL import Image
21 | import torch
22 | import time
23 | from sat.model.mixins import CachedAutoregressiveMixin
24 | from sat.mpu import get_model_parallel_world_size
25 | from sat.model import AutoModel
26 |
27 |
28 | from utils.utils import chat, llama2_tokenizer, llama2_text_processor_inference, get_image_processor, parse_response
29 | from utils.models import CogAgentModel, CogVLMModel
30 |
31 |
32 |
33 | DESCRIPTION = ''''''
34 |
35 | NOTES = 'This app is adapted from https://github.com/THUDM/CogVLM. It would be recommended to check out the repo if you want to see the details of our models, CogVLM & CogAgent.'
36 |
37 | MAINTENANCE_NOTICE1 = 'Hint 1: If the app reports "Something went wrong, connection error out", please turn off your proxy and retry. Hint 2: If you upload a large image (e.g. 10MB), it may take some time to upload and process. Please be patient and wait.'
38 |
39 |
40 | AGENT_NOTICE = 'Hint 1: To use Agent function, please use the prompts for agents.'
41 |
42 | GROUNDING_NOTICE = 'Hint 2: To use Grounding function, please use the prompts for grounding.'
43 |
44 |
45 |
46 |
47 | default_chatbox = [("", "Hi, What do you want to know about this image?")]
48 |
49 |
50 | model = image_processor = text_processor_infer = None
51 |
52 | is_grounding = False
53 |
54 | def process_image_without_resize(image_prompt):
55 | image = Image.open(image_prompt)
56 | # print(f"height:{image.height}, width:{image.width}")
57 | timestamp = int(time.time())
58 | file_ext = os.path.splitext(image_prompt)[1]
59 | filename_grounding = f"examples/{timestamp}_grounding{file_ext}"
60 | return image, filename_grounding
61 |
62 | from sat.quantization.kernels import quantize
63 |
64 | def load_model(args):
65 | model, model_args = AutoModel.from_pretrained(
66 | args.from_pretrained,
67 | args=argparse.Namespace(
68 | deepspeed=None,
69 | local_rank=0,
70 | rank=0,
71 | world_size=world_size,
72 | model_parallel_size=world_size,
73 | mode='inference',
74 | fp16=args.fp16,
75 | bf16=args.bf16,
76 | skip_init=True,
77 | use_gpu_initialization=True if (torch.cuda.is_available() and args.quant is None) else False,
78 | device='cpu' if args.quant else 'cuda'),
79 | overwrite_args={'model_parallel_size': world_size} if world_size != 1 else {}
80 | )
81 | model = model.eval()
82 | assert world_size == get_model_parallel_world_size(), "world size must equal to model parallel size for cli_demo!"
83 |
84 | language_processor_version = model_args.text_processor_version if 'text_processor_version' in model_args else args.version
85 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=language_processor_version)
86 | image_processor = get_image_processor(model_args.eva_args["image_size"][0])
87 | cross_image_processor = get_image_processor(model_args.cross_image_pix) if "cross_image_pix" in model_args else None
88 |
89 | if args.quant:
90 | quantize(model, args.quant)
91 | if torch.cuda.is_available():
92 | model = model.cuda()
93 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
94 |
95 | text_processor_infer = llama2_text_processor_inference(tokenizer, args.max_length, model.image_length)
96 |
97 | return model, image_processor, cross_image_processor, text_processor_infer
98 |
99 |
100 | def post(
101 | input_text,
102 | temperature,
103 | top_p,
104 | top_k,
105 | image_prompt,
106 | result_previous,
107 | hidden_image,
108 | state
109 | ):
110 | result_text = [(ele[0], ele[1]) for ele in result_previous]
111 | for i in range(len(result_text)-1, -1, -1):
112 | if result_text[i][0] == "" or result_text[i][0] == None:
113 | del result_text[i]
114 | print(f"history {result_text}")
115 |
116 | global model, image_processor, cross_image_processor, text_processor_infer, is_grounding
117 |
118 | try:
119 | with torch.no_grad():
120 | pil_img, image_path_grounding = process_image_without_resize(image_prompt)
121 | response, _, cache_image = chat(
122 | image_path="",
123 | model=model,
124 | text_processor=text_processor_infer,
125 | img_processor=image_processor,
126 | query=input_text,
127 | history=result_text,
128 | cross_img_processor=cross_image_processor,
129 | image=pil_img,
130 | max_length=2048,
131 | top_p=top_p,
132 | temperature=temperature,
133 | top_k=top_k,
134 | invalid_slices=text_processor_infer.invalid_slices if hasattr(text_processor_infer, "invalid_slices") else [],
135 | no_prompt=False,
136 | args=state['args']
137 | )
138 | except Exception as e:
139 | print("error message", e)
140 | result_text.append((input_text, 'Timeout! Please wait a few minutes and retry.'))
141 | return "", result_text, hidden_image
142 |
143 | answer = response
144 | if is_grounding:
145 | parse_response(pil_img, answer, image_path_grounding)
146 | new_answer = answer.replace(input_text, "")
147 | result_text.append((input_text, new_answer))
148 | result_text.append((None, (image_path_grounding,)))
149 | else:
150 | result_text.append((input_text, answer))
151 | print(result_text)
152 | print('finished')
153 | return "", result_text, hidden_image
154 |
155 |
156 | def clear_fn(value):
157 | return "", default_chatbox, None
158 |
159 | def clear_fn2(value):
160 | return default_chatbox
161 |
162 |
163 | def main(args):
164 | global model, image_processor, cross_image_processor, text_processor_infer, is_grounding
165 | model, image_processor, cross_image_processor, text_processor_infer = load_model(args)
166 | is_grounding = 'grounding' in args.from_pretrained
167 |
168 | gr.close_all()
169 |
170 | with gr.Blocks(css='style.css') as demo:
171 | state = gr.State({'args': args})
172 |
173 | gr.Markdown(DESCRIPTION)
174 | gr.Markdown(NOTES)
175 |
176 |
177 | with gr.Row():
178 | with gr.Column(scale=5):
179 | with gr.Group():
180 | gr.Markdown(AGENT_NOTICE)
181 | gr.Markdown(GROUNDING_NOTICE)
182 | input_text = gr.Textbox(label='Input Text', placeholder='Please enter text prompt below and press ENTER.')
183 |
184 | with gr.Row():
185 | run_button = gr.Button('Generate')
186 | clear_button = gr.Button('Clear')
187 |
188 | image_prompt = gr.Image(type="filepath", label="Image Prompt", value=None)
189 |
190 | with gr.Row():
191 | temperature = gr.Slider(maximum=1, value=0.8, minimum=0, label='Temperature')
192 | top_p = gr.Slider(maximum=1, value=0.4, minimum=0, label='Top P')
193 | top_k = gr.Slider(maximum=100, value=10, minimum=1, step=1, label='Top K')
194 |
195 | with gr.Column(scale=5):
196 | result_text = gr.components.Chatbot(label='Multi-round conversation History', value=[("", "Hi, What do you want to know about this image?")], height=600)
197 | hidden_image_hash = gr.Textbox(visible=False)
198 |
199 |
200 | gr.Markdown(MAINTENANCE_NOTICE1)
201 |
202 | print(gr.__version__)
203 | run_button.click(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state],
204 | outputs=[input_text, result_text, hidden_image_hash])
205 | input_text.submit(fn=post,inputs=[input_text, temperature, top_p, top_k, image_prompt, result_text, hidden_image_hash, state],
206 | outputs=[input_text, result_text, hidden_image_hash])
207 | clear_button.click(fn=clear_fn, inputs=clear_button, outputs=[input_text, result_text, image_prompt])
208 | image_prompt.upload(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
209 | image_prompt.clear(fn=clear_fn2, inputs=clear_button, outputs=[result_text])
210 |
211 |
212 | # demo.queue(concurrency_count=10)
213 | demo.launch()
214 |
215 |
216 | if __name__ == '__main__':
217 | import argparse
218 | parser = argparse.ArgumentParser()
219 | parser.add_argument("--max_length", type=int, default=2048, help='max length of the total sequence')
220 | parser.add_argument("--top_p", type=float, default=0.4, help='top p for nucleus sampling')
221 | parser.add_argument("--top_k", type=int, default=1, help='top k for top k sampling')
222 | parser.add_argument("--temperature", type=float, default=.8, help='temperature for sampling')
223 | parser.add_argument("--version", type=str, default="chat", choices=['chat', 'vqa', 'chat_old', 'base'], help='version of language process. if there is \"text_processor_version\" in model_config.json, this option will be overwritten')
224 | parser.add_argument("--quant", choices=[8, 4], type=int, default=None, help='quantization bits')
225 | parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
226 | parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
227 | parser.add_argument("--fp16", action="store_true")
228 | parser.add_argument("--bf16", action="store_true")
229 | parser.add_argument("--stream_chat", action="store_true")
230 | args = parser.parse_args()
231 | rank = int(os.environ.get('RANK', 0))
232 | world_size = int(os.environ.get('WORLD_SIZE', 1))
233 | args = parser.parse_args()
234 | main(args)
235 |
--------------------------------------------------------------------------------
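Because the docstring above pins web_demo.py to Gradio 3.x, a small version guard can fail fast with a clear message instead of obscure UI errors. The snippet below is an illustrative sketch, not part of the repository; it relies only on the gradio and packaging libraries.

# Illustrative version guard for web_demo.py's Gradio 3.x requirement (not part of the repository).
import gradio as gr
from packaging import version

if version.parse(gr.__version__) >= version.parse("4.0.0"):
    raise RuntimeError(
        f"web_demo.py targets Gradio 3.x, but Gradio {gr.__version__} is installed. "
        "Install a 3.x release, for example: pip install 'gradio<4'"
    )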
/composite_demo/client.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 | from threading import Thread
3 |
4 | import streamlit as st
5 | import torch
6 | import warnings
7 | import os
8 |
9 | from typing import Any, Protocol
10 | from collections.abc import Iterable
11 | from huggingface_hub.inference._text_generation import TextGenerationStreamResponse, Token
12 | from transformers import AutoTokenizer, TextIteratorStreamer, AutoModelForCausalLM
13 | from conversation import Conversation
14 |
15 | # Check if GPU supports bfloat16
16 |
17 | if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
18 | torch_type = torch.bfloat16
19 | else:
20 | torch_type = torch.float16
21 | warnings.warn("Your GPU does not support bfloat16; falling back to fp16.")
22 |
23 | # If you use all of our models (cogagent-chat, cogvlm-chat, cogvlm-grounding) and want to put them on different devices, you can configure them like this:
24 | models_info = {
25 | 'tokenizer': {
26 | 'path': os.environ.get('TOKENIZER_PATH', 'lmsys/vicuna-7b-v1.5'),
27 | },
28 | 'agent_chat': {
29 | 'path': os.environ.get('MODEL_PATH_AGENT_CHAT', 'THUDM/cogagent-chat-hf'),
30 | 'device': ['cuda:0']
31 | },
32 | 'vlm_chat': {
33 | 'path': os.environ.get('MODEL_PATH_VLM_CHAT', 'THUDM/cogvlm-chat-hf'),
34 | 'device': ['cuda:3']
35 | },
36 | 'vlm_grounding': {
37 | 'path': os.environ.get('MODEL_PATH_VLM_GROUNDING','THUDM/cogvlm-grounding-generalist-hf'),
38 | 'device': ['cuda:6']
39 | }
40 | }
41 |
42 |
43 | # If you only use one model, configure it like this:
44 | # models_info = {
45 | # 'tokenizer': {
46 | # 'path': os.environ.get('TOKENIZER_PATH', 'lmsys/vicuna-7b-v1.5'),
47 | # },
48 | # 'agent_chat': {
49 | # 'path': os.environ.get('MODEL_PATH_AGENT_CHAT', 'THUDM/cogagent-chat-hf'),
50 | # 'device': ['cuda:0']
51 | # },
52 |
53 |
54 |
55 | @st.cache_resource
56 | def get_client() -> Client:
57 | client = HFClient(models_info)
58 | return client
59 |
60 |
61 | def process_history(history: list[Conversation]):
62 | """
63 | Process the input history to extract the query and the history pairs.
64 | Args:
65 | history (list[Conversation]): A list of Conversation objects representing all conversation turns.
66 | Returns:
67 | query(str): The current user input string.
68 | history_pairs(list[(str,str)]): A list of (user, assistant) pairs.
69 | last_user_image(Image): The last user image. Only the latest image.
70 |
71 | """
72 | history_pairs = []
73 | query = ""
74 | last_user_image = None
75 |
76 | user_text = None
77 | for i, conversation in enumerate(history):
78 | if conversation.role == conversation.role.USER:
79 | user_text = conversation.content
80 | if conversation.image:
81 | last_user_image = conversation.image
82 |
83 | if i == len(history) - 1:
84 | query = conversation.content
85 |
86 | else:
87 | if user_text is not None:
88 | history_pairs.append((user_text, conversation.content))
89 | user_text = None
90 | return query, history_pairs, last_user_image
91 |
92 |
93 | class Client(Protocol):
94 | def generate_stream(self,
95 | history: list[Conversation],
96 | grounding: bool = False,
97 | model_use: str = 'agent_chat',
98 | **parameters: Any
99 | ) -> Iterable[TextGenerationStreamResponse]:
100 | ...
101 |
102 |
103 | class HFClient(Client):
104 | """
105 | The HFClient class manages the interaction with various large language models
106 | for text generation tasks. It supports handling multiple models, each designated
107 | for a specific task like chatting or grounding.
108 |
109 | Args:
110 | models_info (dict): A dictionary containing the configuration for each model.
111 | The dictionary format is:
112 | - 'tokenizer': Path and settings for the tokenizer.
113 | - 'agent_chat': Path and settings for the CogAgent-chat-18B model.
114 | - 'vlm_chat': Path and settings for the CogVLM-chat-17B model.
115 | - 'vlm_grounding': Path and settings for the CogVLM-grounding-17B model.
116 |
117 | The class loads each model based on the provided information and assigns it to the
118 | specified CUDA device. It also handles the tokenizer used across all models.
119 | """
120 | def __init__(self, models_info):
121 | self.models = {}
122 | self.tokenizer = AutoTokenizer.from_pretrained(models_info['tokenizer']['path'], trust_remote_code=True)
123 | for model_name, model_info in models_info.items():
124 | if model_name != 'tokenizer':
125 | self.models[model_name] = []
126 | for device in model_info['device']:
127 | model = AutoModelForCausalLM.from_pretrained(
128 | model_info['path'],
129 | torch_dtype=torch_type,
130 | low_cpu_mem_usage=True,
131 | trust_remote_code=True,
132 | ).to(device).eval()
133 | self.models[model_name].append(model)
134 |
135 | def select_best_gpu(self, model_name):
136 | min_memory_used = None
137 | selected_model = None
138 |
139 | for model in self.models[model_name]:
140 | device = next(model.parameters()).device
141 | mem_used = torch.cuda.memory_allocated(device=device)
142 |
143 | if min_memory_used is None or mem_used < min_memory_used:
144 | min_memory_used = mem_used
145 | selected_model = model
146 |
147 | return selected_model
148 |
149 | def generate_stream(self,
150 | history: list,
151 | grounding: bool = False,
152 | model_use: str = 'agent_chat',
153 | **parameters: Any
154 | ) -> Iterable[TextGenerationStreamResponse]:
155 | """
156 | Generates a stream of text responses based on the input history and selected model.
157 |
158 | This method facilitates a chat-like interaction with the models. Depending on the
159 | model selected and whether grounding is enabled, it alters the behavior of the text
160 | generation process.
161 |
162 | Args:
163 | history (list[Conversation]): A list of Conversation objects representing the
164 | dialogue history.
165 | grounding (bool, optional): A flag to indicate whether grounding should be used
166 | in the generation process. Defaults to False.
167 | model_use (str, optional): The key name of the model to be used for the generation.
168 | Defaults to 'agent_chat'.
169 | **parameters (Any): Additional parameters that may be required for the generation
170 | process.
171 |
172 | Yields:
173 | Iterable[TextGenerationStreamResponse]: A stream of text generation responses, each
174 | encapsulating a generated piece of text.
175 |
176 | The method selects the appropriate model based on `model_use`, processes the input
177 | history, and feeds it into the model to generate text. It uses threading to handle
178 | the generation process efficiently.
179 | """
180 | query, history, image = process_history(history)
181 | if grounding:
182 | query += "(with grounding)"
183 |
184 | model = self.select_best_gpu(model_use)
185 | device = next(model.parameters()).device
186 |
187 | # Print user input info
188 |
189 | print("\n== Input ==\n", query)
190 | print("\n==History==\n", history)
191 | print("\n== Model ==\n\n", model.config.name_or_path)
192 | print("\n== Device ==\n\n", device)
193 |
194 | input_by_model = model.build_conversation_input_ids(
195 | self.tokenizer,
196 | query=query,
197 | history=history,
198 | images=[image]
199 | )
200 | inputs = {
201 | 'input_ids': input_by_model['input_ids'].unsqueeze(0).to(device),
202 | 'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(device),
203 | 'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(device),
204 | 'images': [[input_by_model['images'][0].to(device).to(torch_type)]],
205 | }
206 |
207 | # The CogVLM models do not take a 'cross_images' input; only CogAgent does.
208 |
209 | if 'cross_images' in input_by_model and input_by_model['cross_images']:
210 | inputs['cross_images'] = [[input_by_model['cross_images'][0].to(device).to(torch_type)]]
211 |
212 | # Use TextIteratorStreamer for streaming generation, as in Hugging Face Transformers.
213 |
214 | streamer = TextIteratorStreamer(self.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True)
215 | parameters['streamer'] = streamer
216 | gen_kwargs = {**parameters, **inputs}
217 | with torch.no_grad():
218 | thread = Thread(target=model.generate, kwargs=gen_kwargs)
219 | thread.start()
220 | for next_text in streamer:
221 | yield TextGenerationStreamResponse(
222 | token=Token(
223 | id=0,
224 | logprob=0,
225 | text=next_text,
226 | special=False,
227 | )
228 | )
229 |
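
# Hedged usage sketch (illustration only, not part of the original demo flow): shows how
# the Streamlit demos drive generate_stream() on the client returned by get_client().
# "example.jpg" is a placeholder path; Conversation and Role come from conversation.py.
if __name__ == "__main__":
    from PIL import Image
    from conversation import Conversation, Role

    client = get_client()
    history = [Conversation(role=Role.USER,
                            content="Describe this image.",
                            image=Image.open("example.jpg").convert("RGB"))]
    answer = ""
    for response in client.generate_stream(history=history,
                                           model_use="agent_chat",
                                           do_sample=True,
                                           max_new_tokens=256,
                                           temperature=0.9,
                                           top_p=0.8,
                                           top_k=2):
        answer += response.token.text
    print(answer)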
--------------------------------------------------------------------------------
/composite_demo/conversation.py:
--------------------------------------------------------------------------------
1 | import requests
2 | import re
3 | import streamlit as st
4 |
5 | from dataclasses import dataclass
6 | from enum import auto, Enum
7 | from PIL.Image import Image
8 | from PIL import ImageDraw
9 | from streamlit.delta_generator import DeltaGenerator
10 |
11 |
12 | class Role(Enum):
13 | """
14 | CogVLM | CogAgent only have two roles: USER and ASSISTANT.
15 |
16 | Represents the roles in a conversation, specifically for CogVLM and CogAgent applications.
17 |
18 | There are two roles available:
19 | - USER: The user of the system, typically the one asking questions or initiating conversation.
20 | - ASSISTANT: The system or AI assistant responding to the user's queries.
21 |
22 | Methods:
23 | get_message(self):
24 | Retrieves a Streamlit chat message component based on the role. For the USER role, it
25 | returns a chat message with the name "user" and user avatar. For the ASSISTANT role,
26 | it returns a chat message with the name "assistant" and assistant avatar.
27 | """
28 |
29 | USER = auto()
30 | ASSISTANT = auto()
31 |
32 | def get_message(self):
33 |
34 | match self.value:
35 | case Role.USER.value:
36 | return st.chat_message(name="user", avatar="user")
37 | case Role.ASSISTANT.value:
38 | return st.chat_message(name="assistant", avatar="assistant")
39 | case _:
40 | st.error(f'Unexpected role: {self}')
41 |
42 |
43 | @dataclass
44 | class Conversation:
45 | """
46 | Represents a single conversation turn within a dialogue.
47 | Attributes:
48 | role (Role): The role of the speaker in the conversation (USER or ASSISTANT).
49 | content (str): The textual content of the conversation turn.
50 | image (Image, optional): An optional image associated with the conversation turn.
51 | content_show (str, optional): The content to be displayed in the WebUI. This may differ
52 | from `content` if translation or other processing is applied.
53 | translate (bool, optional): Whether to translate the content of the conversation turn.
54 |
55 | Methods:
56 | __str__(self) -> str:
57 | Returns a string representation of the conversation turn, including the role and content.
58 |
59 | show(self, placeholder: DeltaGenerator | None = None) -> str:
60 | Displays the conversation turn in the WebUI. If `placeholder` is provided, the content
61 | is shown in the specified Streamlit container. Otherwise, it uses the message style
62 | determined by the role.
63 | """
64 |
65 | role: Role = Role.USER
66 | content: str = ""
67 | image: Image | None = None
68 | content_show: str | None = None
69 | translate: bool = False
70 |
71 | def __str__(self) -> str:
72 | print(self.role, self.content)
73 | match self.role:
74 | case Role.USER | Role.ASSISTANT:
75 | return f'{self.role}\n{self.content}'
76 |
77 | def show(self, placeholder: DeltaGenerator | None = None) -> str:
78 | """
79 | Show the conversation turn in Markdown format.
80 | """
81 | if placeholder:
82 | message = placeholder
83 | else:
84 | message = self.role.get_message()
85 |
86 | # Handle translation for the Chinese WebUI
87 | if self.role == Role.USER:
88 | if self.translate:
89 | self.content = translate_baidu(self.content_show, source_lan="zh", target_lan="en")
90 | if self.content == "error":
91 | self.content_show = "Please enter your Baidu Translation API key in the translate_baidu() function"
92 | else:
93 | self.content = self.content_show
94 | if self.role == Role.ASSISTANT:
95 | if self.translate:
96 | self.content_show = translate_baidu(self.content, source_lan="en", target_lan="zh")
97 | else:
98 | self.content_show = self.content
99 |
100 | self.content_show = self.content_show.replace('\n', ' \n')
101 |
102 | message.markdown(self.content_show)
103 | if self.image:
104 | message.image(self.image)
105 |
106 |
107 | def preprocess_text(history: list[Conversation], ) -> str:
108 | """
109 | Prepares the conversation history for processing by concatenating the content of each turn.
110 | Args:
111 | history (list[Conversation]): The conversation history, a list of Conversation objects.
112 |
113 | Returns:
114 | str: A single string that concatenates the content of each conversation turn, followed by
115 | the ASSISTANT role indicator. This string is suitable for use as input to a text generation model.
116 | """
117 |
118 | prompt = ""
119 | for conversation in history:
120 | prompt += f'{conversation}'
121 | prompt += f'{Role.ASSISTANT}\n'
122 | return prompt
123 |
124 |
125 | def postprocess_text(template: str, text: str) -> str:
126 | """
127 | Post-processes the generated text by incorporating it into a given template.
128 | Args:
129 | template (str): A template string containing a placeholder for the generated text.
130 | text (str): The generated text to be incorporated into the template.
131 |
132 | Returns:
133 | str: The template with the generated text replacing the placeholder.
134 | """
135 | quoted_text = f'"{text.strip()}"'
136 | return template.replace("<TASK>", quoted_text).strip() if template != "" else text.strip()  # assumes the templates in utils.py mark the insertion point with "<TASK>"
137 |
138 |
139 | def postprocess_image(text: str, img: Image) -> tuple[str, Image | None]:
140 | """
141 | Processes the given text to identify and draw bounding boxes on the provided image.
142 | This function searches for patterns in the text that represent coordinates for bounding
143 | boxes and draws rectangles on the image at these coordinates. Each box is drawn in a
144 | different color for distinction.
145 | Args:
146 | text (str): The text containing bounding box coordinates in a specific pattern.
147 | img (Image): The image on which to draw the bounding boxes.
148 | Returns:
149 | tuple[str, Image | None]: The original text and the image with the bounding boxes
150 | drawn on it, or None for the image if no coordinates are found in the text.
151 | """
152 | colors = ["red", "green", "blue", "yellow", "purple", "orange"]
153 |
154 | # Updated pattern to match single or multiple coordinate groups
155 | pattern = r"\[\[([\d,]+(?:;[\d,]+)*)\]\]"
156 | matches = re.findall(pattern, text)
157 | draw = ImageDraw.Draw(img)
158 |
159 | if not matches:
160 | return text, None
161 |
162 | for i, match in enumerate(matches):
163 | # Splitting the matched string into individual coordinate groups
164 | coords_groups = match.split(';')
165 |
166 | # Determining the color for the current match
167 | color = colors[i % len(colors)]
168 |
169 | for coords_str in coords_groups:
170 | coords = coords_str.split(',')
171 |
172 | if len(coords) == 4: # Rectangle
173 | scaled_coords = (
174 | int(float(coords[0]) * 0.001 * img.width),
175 | int(float(coords[1]) * 0.001 * img.height),
176 | int(float(coords[2]) * 0.001 * img.width),
177 | int(float(coords[3]) * 0.001 * img.height)
178 | )
179 | draw.rectangle(scaled_coords, outline=color, width=3)
180 | elif len(coords) == 2: # Point
181 | scaled_coords = (
182 | int(float(coords[0]) * 0.001 * img.width),
183 | int(float(coords[1]) * 0.001 * img.height)
184 | )
185 | radius = 5
186 | draw.ellipse([scaled_coords[0] - radius, scaled_coords[1] - radius,
187 | scaled_coords[0] + radius, scaled_coords[1] + radius],
188 | fill=color)
189 |
190 | return text, img
191 |
192 | def translate_baidu(translate_text, source_lan, target_lan):
193 | """
194 | Translates text using Baidu's translation service (used when the input is not in English).
195 |
196 | This function sends a request to the Baidu translation API to translate the provided text
197 | from the source language to the target language.
198 |
199 | Args:
200 | translate_text (str): The text to be translated.
201 | source_lan (str): The source language code (e.g., "en" for English).
202 | target_lan (str): The target language code (e.g., "zh" for Chinese).
203 |
204 | Returns:
205 | str: The translated text. Returns "error" in case of an exception.
206 | """
207 | url = "https://aip.baidubce.com/rpc/2.0/mt/texttrans/v1?access_token="
208 | headers = {'Content-Type': 'application/json'}
209 | payload = {
210 | 'q': translate_text,
211 | 'from': source_lan,
212 | 'to': target_lan
213 | }
214 | try:
215 | r = requests.post(url, json=payload, headers=headers)
216 | result = r.json()
217 | final_translation = ''
218 |
219 | for item in result['result']['trans_result']:
220 | final_translation += item['dst'] + '\n'
221 | except Exception as e:
222 | print(e)
223 | return "error"
224 | return final_translation
225 |
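# Hedged, self-contained check of postprocess_image() above (illustration only). Grounding
# outputs embed boxes as "[[x1,y1,x2,y2]]" (or "[[x,y]]" points) on a 0-1000 scale relative
# to the image size; the image and text below are synthetic placeholders.
if __name__ == "__main__":
    from PIL import Image as PILImage

    demo_img = PILImage.new("RGB", (1000, 500), "white")
    parsed_text, boxed_img = postprocess_image("A cat sits on the mat [[100,200,300,400]]", demo_img)
    # Coordinates are scaled by 0.001 * width/height, so the rectangle drawn here spans
    # pixel coordinates (100, 100, 300, 200) on the 1000x500 canvas.
    print(parsed_text, boxed_img.size)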
--------------------------------------------------------------------------------
/composite_demo/demo_agent_cogagent.py:
--------------------------------------------------------------------------------
1 | from io import BytesIO
2 | import base64
3 | import streamlit as st
4 | import re
5 |
6 | from streamlit.delta_generator import DeltaGenerator
7 | from client import get_client
8 | from conversation import postprocess_text, Conversation, Role, postprocess_image
9 | from PIL import Image
10 | from utils import images_are_same
11 |
12 | client = get_client()
13 |
14 |
15 | def append_conversation(
16 | conversation: Conversation,
17 | history: list[Conversation],
18 | placeholder: DeltaGenerator | None = None,
19 | ) -> None:
20 | history.append(conversation)
21 | conversation.show(placeholder)
22 |
23 |
24 | def main(
25 | top_p: float = 0.8,
26 | temperature: float = 0.95,
27 | prompt_text: str = "",
28 | metadata: str = "",
29 | top_k: int = 2,
30 | max_new_tokens: int = 2048,
31 | grounding: bool = False,
32 | retry: bool = False,
33 | template: str = ""
34 | ):
35 | if 'chat_history' not in st.session_state:
36 | st.session_state.chat_history = []
37 |
38 | if prompt_text == "" and retry == False:
39 | print("\n== Clean ==\n")
40 | st.session_state.chat_history = []
41 | return
42 |
43 | history: list[Conversation] = st.session_state.chat_history
44 | for conversation in history:
45 | conversation.show()
46 |
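    # On retry, find the most recent USER turn, reuse its text as the new prompt,
    # and delete that turn and everything after it so the answer is regenerated.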
47 | if retry:
48 | print("\n== Retry ==\n")
49 | last_user_conversation_idx = None
50 | for idx, conversation in enumerate(history):
51 | if conversation.role == Role.USER:
52 | last_user_conversation_idx = idx
53 | if last_user_conversation_idx is not None:
54 | prompt_text = history[last_user_conversation_idx].content_show
55 | del history[last_user_conversation_idx:]
56 |
57 | if prompt_text:
58 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None
59 | image.thumbnail((1120, 1120))
60 | image_input = image
61 | if history and image:
62 | last_user_image = next(
63 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None)
64 | if last_user_image and images_are_same(image, last_user_image):
65 | image_input = None
66 |
67 | # Not necessary to clear history
68 | # else:
69 | # # new picture means new conversation
70 | # st.session_state.chat_history = []
71 | # history = []
72 |
73 | # Set conversation
74 | if re.search('[\u4e00-\u9fff]', prompt_text):
75 | translate = True
76 | else:
77 | translate = False
78 |
79 | user_conversation = Conversation(
80 | role=Role.USER,
81 | translate=translate,
82 | content_show=prompt_text.strip() if retry else postprocess_text(template=template,
83 | text=prompt_text.strip()),
84 | image=image_input
85 | )
86 | append_conversation(user_conversation, history)
87 | placeholder = st.empty()
88 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant")
89 | assistant_conversation = assistant_conversation.empty()
90 |
91 | # Stream the answer
92 | output_text = ''
93 | for response in client.generate_stream(
94 | model_use='agent_chat',
95 | grounding=grounding,
96 | history=history,
97 | do_sample=True,
98 | max_new_tokens=max_new_tokens,
99 | temperature=temperature,
100 | top_p=top_p,
101 | top_k=top_k,
102 | ):
103 | output_text += response.token.text
104 | assistant_conversation.markdown(output_text.strip() + '▌')
105 |
106 | ## Final Answer with image.
107 | print("\n==Output:==\n", output_text)
108 | content_output, image_output = postprocess_image(output_text, image)
109 | assistant_conversation = Conversation(
110 | role=Role.ASSISTANT,
111 | content=content_output,
112 | image=image_output,
113 | translate=translate,
114 | )
115 | append_conversation(
116 | conversation=assistant_conversation,
117 | history=history,
118 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant"),
119 | )
120 |
--------------------------------------------------------------------------------
/composite_demo/demo_chat_cogagent.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import base64
3 | import re
4 |
5 | from PIL import Image
6 | from io import BytesIO
7 | from streamlit.delta_generator import DeltaGenerator
8 | from client import get_client
9 | from utils import images_are_same
10 | from conversation import Conversation, Role, postprocess_image, postprocess_text
11 |
12 | client = get_client()
13 |
14 |
15 | def append_conversation(
16 | conversation: Conversation,
17 | history: list[Conversation],
18 | placeholder: DeltaGenerator | None = None,
19 | ) -> None:
20 | history.append(conversation)
21 | conversation.show(placeholder)
22 |
23 |
24 | def main(
25 | top_p: float = 0.8,
26 | temperature: float = 0.95,
27 | prompt_text: str = "",
28 | metadata: str = "",
29 | top_k: int = 2,
30 | max_new_tokens: int = 2048,
31 | grounding: bool = False,
32 | retry: bool = False,
33 | template: str = "",
34 | ):
35 | if 'chat_history' not in st.session_state:
36 | st.session_state.chat_history = []
37 |
38 | if prompt_text == "" and retry == False:
39 | print("\n== Clean ==\n")
40 | st.session_state.chat_history = []
41 | return
42 |
43 | history: list[Conversation] = st.session_state.chat_history
44 | for conversation in history:
45 | conversation.show()
46 | if retry:
47 | last_user_conversation_idx = None
48 | for idx, conversation in enumerate(history):
49 | if conversation.role == Role.USER:
50 | last_user_conversation_idx = idx
51 | if last_user_conversation_idx is not None:
52 | prompt_text = history[last_user_conversation_idx].content_show
53 | del history[last_user_conversation_idx:]
54 |
55 | if prompt_text:
56 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None
57 | image.thumbnail((1120, 1120))
58 | image_input = image
59 | if history and image:
60 | last_user_image = next(
61 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None)
62 | if last_user_image and images_are_same(image, last_user_image):
63 | image_input = None
64 | else:
65 | st.session_state.chat_history = []
66 | history = []
67 |
68 | # Set conversation
69 | if re.search('[\u4e00-\u9fff]', prompt_text):
70 | translate = True
71 | else:
72 | translate = False
73 |
74 | user_conversation = Conversation(
75 | role=Role.USER,
76 | translate=translate,
77 | content_show=prompt_text.strip() if retry else postprocess_text(template=template,
78 | text=prompt_text.strip()),
79 | image=image_input
80 | )
81 | append_conversation(user_conversation, history)
82 | placeholder = st.empty()
83 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant")
84 | assistant_conversation = assistant_conversation.empty()
85 |
86 | # Stream the answer
87 | output_text = ''
88 | for response in client.generate_stream(
89 | model_use='agent_chat',
90 | grounding=grounding,
91 | history=history,
92 | do_sample=True,
93 | max_new_tokens=max_new_tokens,
94 | temperature=temperature,
95 | top_p=top_p,
96 | top_k=top_k,
97 | ):
98 | output_text += response.token.text
99 | assistant_conversation.markdown(output_text.strip() + '▌')
100 |
101 | print("\n==Output:==\n", output_text)
102 | content_output, image_output = postprocess_image(output_text, image)
103 | assistant_conversation = Conversation(
104 | role=Role.ASSISTANT,
105 | content=content_output,
106 | image=image_output,
107 | translate=translate
108 | )
109 | append_conversation(
110 | conversation=assistant_conversation,
111 | history=history,
112 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant")
113 | )
114 |
--------------------------------------------------------------------------------
/composite_demo/demo_chat_cogvlm.py:
--------------------------------------------------------------------------------
1 | import streamlit as st
2 | import base64
3 | import re
4 |
5 | from PIL import Image
6 | from io import BytesIO
7 | from streamlit.delta_generator import DeltaGenerator
8 | from client import get_client
9 | from utils import images_are_same
10 | from conversation import Conversation, Role, postprocess_image, postprocess_text
11 |
12 | client = get_client()
13 |
14 |
15 | def append_conversation(
16 | conversation: Conversation,
17 | history: list[Conversation],
18 | placeholder: DeltaGenerator | None = None,
19 | ) -> None:
20 | history.append(conversation)
21 | conversation.show(placeholder)
22 |
23 |
24 | def main(
25 | top_p: float = 0.8,
26 | temperature: float = 0.95,
27 | prompt_text: str = "",
28 | metadata: str = "",
29 | top_k: int = 2,
30 | max_new_tokens: int = 2048,
31 | grounding: bool = False,
32 | retry: bool = False,
33 | template: str = "",
34 | ):
35 | if 'chat_history' not in st.session_state:
36 | st.session_state.chat_history = []
37 |
38 | if prompt_text == "" and retry == False:
39 | print("\n== Clean ==\n")
40 | st.session_state.chat_history = []
41 | return
42 |
43 | history: list[Conversation] = st.session_state.chat_history
44 | for conversation in history:
45 | conversation.show()
46 | if retry:
47 | last_user_conversation_idx = None
48 | for idx, conversation in enumerate(history):
49 | if conversation.role == Role.USER:
50 | last_user_conversation_idx = idx
51 | if last_user_conversation_idx is not None:
52 | prompt_text = history[last_user_conversation_idx].content_show
53 | del history[last_user_conversation_idx:]
54 |
55 | if prompt_text:
56 | image = Image.open(BytesIO(base64.b64decode(metadata))).convert('RGB') if metadata else None
57 | image.thumbnail((1120, 1120))
58 | image_input = image
59 | if history and image:
60 | last_user_image = next(
61 | (conv.image for conv in reversed(history) if conv.role == Role.USER and conv.image), None)
62 | if last_user_image and images_are_same(image, last_user_image):
63 | image_input = None
64 | else:
65 | st.session_state.chat_history = []
66 | history = []
67 |
68 | # Set conversation
69 | if re.search('[\u4e00-\u9fff]', prompt_text):
70 | translate = True
71 | else:
72 | translate = False
73 |
74 | user_conversation = Conversation(
75 | role=Role.USER,
76 | translate=translate,
77 | content_show=prompt_text.strip() if retry else postprocess_text(template=template,
78 | text=prompt_text.strip()),
79 | image=image_input
80 | )
81 | append_conversation(user_conversation, history)
82 | placeholder = st.empty()
83 | assistant_conversation = placeholder.chat_message(name="assistant", avatar="assistant")
84 | assistant_conversation = assistant_conversation.empty()
85 |
86 | # Stream the answer
87 | output_text = ''
88 | for response in client.generate_stream(
89 | model_use='vlm_grounding' if grounding else 'vlm_chat',
90 | grounding=False,
91 | history=history,
92 | do_sample=True,
93 | max_new_tokens=max_new_tokens,
94 | temperature=temperature,
95 | top_p=top_p,
96 | top_k=top_k,
97 | ):
98 | output_text += response.token.text
99 | assistant_conversation.markdown(output_text.strip() + '▌')
100 |
101 | print("\n==Output:==\n", output_text)
102 | content_output, image_output = postprocess_image(output_text, image)
103 | assistant_conversation = Conversation(
104 | role=Role.ASSISTANT,
105 | content=content_output,
106 | image=image_output,
107 | translate=translate
108 | )
109 | append_conversation(
110 | conversation=assistant_conversation,
111 | history=history,
112 | placeholder=placeholder.chat_message(name="assistant", avatar="assistant")
113 | )
114 |
--------------------------------------------------------------------------------
/composite_demo/main.py:
--------------------------------------------------------------------------------
1 | """
2 | This is a web demo of the chat versions of CogAgent and CogVLM.
3 |
4 | Make sure you have installed the vicuna-7b-v1.5 tokenizer (https://huggingface.co/lmsys/vicuna-7b-v1.5);
5 | a full checkpoint of the vicuna-7b-v1.5 LLM is not required.
6 |
7 | Note that only one image can be processed per conversation, which means you cannot replace or insert another image
8 | during the conversation.
9 |
10 |
11 | The models_info parameter is explained as follows:
12 | tokenizer: the vicuna-7b-v1.5 tokenizer
13 | agent_chat: the CogAgent-chat-18B model, used for conversation tasks
14 | vlm_chat: the CogVLM-chat-17B model, used for conversation tasks
15 | vlm_grounding: the CogVLM-grounding-17B model, used for grounding tasks
16 |
17 | The Web Demo routes user requests as follows:
18 | CogVLM-Chat -> grounding? - yes -> Choose a template -> CogVLM-grounding-17B
19 | - no -> CogVLM-chat-17B (without grounding)
20 |
21 | CogAgent-Chat -> CogAgent-chat-18B (QA only, without grounding)
22 |
23 | CogAgent-Agent -> CogAgent-chat-18B
24 | -> Choose a template -> grounding? - yes -> prompt + (with grounding)
25 | - no -> prompt
26 |
27 | CogAgent-vqa-hf is not included in this demo, but it can be used in the same way as CogAgent-chat-18B
28 | in the CogAgent-Chat mode.
29 | """
30 |
31 | import streamlit as st
32 |
33 | st.set_page_config(
34 | page_title="CogVLM & CogAgent Demo",
35 | page_icon=":robot:",
36 | layout='centered',
37 | initial_sidebar_state='expanded',
38 | )
39 |
40 | from enum import Enum
41 | from utils import encode_file_to_base64, templates_agent_cogagent, template_grounding_cogvlm
42 | import demo_chat_cogvlm, demo_agent_cogagent, demo_chat_cogagent
43 |
44 | st.markdown("CogAgent & CogVLM Chat Demo
", unsafe_allow_html=True)
45 | st.markdown(
46 | "更多使用方法请参考文档: https://lslfd0slxc.feishu.cn/wiki/WvQbwIJ9tiPAxGk8ywDck6yfnof \n\n 请根据文档的引导说明来尝试demo,以便理解demo的布局设计 \n",
47 | unsafe_allow_html=True)
48 |
49 |
50 | class Mode(str, Enum):
51 | CogVLM_Chat, CogAgent_Chat, CogAgent_Agent = '💬CogVLM-Chat', '🧑💻 CogAgent-Chat', '💡 CogAgent-Agent'
52 |
53 |
54 | with st.sidebar:
55 | top_p = st.slider(
56 | 'top_p', 0.0, 1.0, 0.8, step=0.01
57 | )
58 | temperature = st.slider(
59 | 'temperature', 0.01, 1.0, 0.90, step=0.01
60 | )
61 | top_k = st.slider(
62 | 'top_k', 1, 20, 5, step=1
63 | )
64 | max_new_token = st.slider(
65 | 'Output length', 1, 2048, 2048, step=1
66 | )
67 |
68 | uploaded_file = st.file_uploader("Choose an image...", type=['.jpg', '.png', '.jpeg'], accept_multiple_files=False)
69 |
70 | cols = st.columns(2)
71 | export_btn = cols[0]
72 | clear_history = cols[1].button("Clear History", use_container_width=True)
73 | retry = export_btn.button("Retry", use_container_width=True)
74 |
75 | prompt_text = st.chat_input(
76 | 'Chat with CogAgent | CogVLM',
77 | key='chat_input',
78 | )
79 |
80 | tab = st.radio(
81 | 'Mode',
82 | [mode.value for mode in Mode],
83 | horizontal=True,
84 | label_visibility='hidden',
85 | )
86 |
87 | selected_template_grounding_cogvlm = ""
88 | with st.sidebar:
89 | grounding = st.checkbox("Grounding")
90 | if tab == Mode.CogVLM_Chat or tab == Mode.CogAgent_Chat:
91 | if grounding:
92 | selected_template_grounding_cogvlm = st.selectbox("Template For Grounding", template_grounding_cogvlm)
93 |
94 | if tab == Mode.CogAgent_Agent:
95 | with st.sidebar:
96 | selected_template_agent_cogagent = st.selectbox("Template For Agent", templates_agent_cogagent)
97 |
98 | if clear_history or retry:
99 | prompt_text = ""
100 |
101 | match tab:
102 | case Mode.CogVLM_Chat:
103 | st.info("This option uses cogvlm-chat and cogvlm-grounding model.")
104 | if uploaded_file is not None:
105 | demo_chat_cogvlm.main(
106 | retry=retry,
107 | top_p=top_p,
108 | top_k=top_k,
109 | temperature=temperature,
110 | prompt_text=prompt_text,
111 | metadata=encode_file_to_base64(uploaded_file),
112 | max_new_tokens=max_new_token,
113 | grounding=grounding,
114 | template=selected_template_grounding_cogvlm
115 | )
116 | else:
117 | st.error(f'Please upload an image to start')
118 |
119 | case Mode.CogAgent_Chat:
120 | st.info("This option uses cogagent-chat model.")
121 | if uploaded_file is not None:
122 | demo_chat_cogagent.main(
123 | retry=retry,
124 | top_p=top_p,
125 | top_k=top_k,
126 | temperature=temperature,
127 | prompt_text=prompt_text,
128 | metadata=encode_file_to_base64(uploaded_file),
129 | max_new_tokens=max_new_token,
130 | grounding=grounding,
131 | template=selected_template_grounding_cogvlm
132 | )
133 | else:
134 | st.error(f'Please upload an image to start')
135 |
136 | case Mode.CogAgent_Agent:
137 | st.info("This option uses cogagent-chat model with agent template.")
138 | if uploaded_file is not None:
139 | demo_agent_cogagent.main(
140 | retry=retry,
141 | top_p=top_p,
142 | top_k=top_k,
143 | temperature=temperature,
144 | prompt_text=prompt_text,
145 | metadata=encode_file_to_base64(uploaded_file),
146 | max_new_tokens=max_new_token,
147 | grounding=grounding,
148 | template=selected_template_agent_cogagent
149 | )
150 | else:
151 | st.error(f'Please upload an image to start')
152 | case _:
153 | st.error(f'Unexpected tab: {tab}')
154 |
--------------------------------------------------------------------------------
/dataset.md:
--------------------------------------------------------------------------------
1 | # CogVLM-SFT-311K: Bilingual Visual Instruction Data in CogVLM SFT
2 |
3 | CogVLM-SFT-311K is the primary aligned corpus used in the initial training of CogVLM v1.0. The process of constructing this dataset is as follows:
4 | 1. Approximately 3500 high-quality data samples were selected from the open source [MiniGPT-4](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align), known as minigpt4-3500.
5 | 2. Minigpt4-3500 was integrated with [Llava-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) and translated into Chinese through a language model.
6 | 3. We discovered significant noise in the detailed description part of minigpt4-3500 and Llava-instruct. Thus, we corrected these Chinese corpora and retranslated them into English.
7 |
8 | ## License
9 |
10 | + Due to non-commercial agreements, we did not use these data in the bilingual version of CogVLM or any other models involving commercialization.
11 | + The dataset is licensed under Attribution-NonCommercial 4.0 International and must also abide by OpenAI's terms of use: https://openai.com/policies/terms-of-use
12 | This does not allow you to use these data for any **commercial activities**.
13 |
14 | ## Dataset Address
15 |
16 | + [CogVLM-SFT-311K](https://huggingface.co/datasets/THUDM/CogVLM-SFT-311K)
17 |
18 | ## Dataset Information
19 |
20 | The dataset contains three folders, corresponding to the mixed minigpt4-3500 and llava data, the llava single-turn conversations, and the llava multi-turn conversations. Their layout is as follows:
21 | ```
22 | .CogVLM-SFT-311K
23 | ├── llava_details-minigpt4_3500_formate
24 | ├── llava_instruction_multi_conversations_formate
25 | └── llava_instruction_single_conversation_formate
26 | ```
27 | In our open-source data, the datasets are distributed as follows:
28 | ```
29 | .llava_details-minigpt4_3500_formate
30 | ├── images
31 | │ └── 00000001.jpg
32 | └── labels
33 | └── 00000001.json
34 | ```
35 | Images are stored in the images folder, while labels containing the corresponding image description or dialogue are stored in the labels folder.
36 |
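Below is a minimal sketch of pairing each image with its label file (the folder name, the `.jpg` extension, and the label keys follow the layout described above; adjust them to your local copy):
```
import json
from pathlib import Path

root = Path("llava_details-minigpt4_3500_formate")
for label_path in sorted((root / "labels").glob("*.json")):
    image_path = root / "images" / (label_path.stem + ".jpg")
    with open(label_path, "r", encoding="utf-8") as f:
        label = json.load(f)
    # Caption files carry a "captions" list; conversation files carry a "conversations" list.
    print(image_path, list(label.keys()))
```
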
37 | ## Dataset Quantity
38 |
39 | + llava_details-minigpt4_3500_formate: 22,464 images and descriptions
40 | + llava_instruction_multi_conversations_formate: 56,673 images and multi-turn conversations
41 | + llava_instruction_single_conversation_formate: 76,634 images and single-turn conversations
42 |
43 | ## Dataset Format
44 |
45 | Caption format for image description
46 | ```
47 | {
48 | "captions": [
49 | {
50 | "role": "caption",
51 | "content": "The photograph features a beach scene with a group of people surfing in the ocean. There are ten individuals spaced out in the water, sitting or lying on various sizes of body boards. The surfers line up in a row, close to the sandy shoreline. Some of the body boards are visible floating on the surface of the water, with the surfers riding on them or nearby."
52 | }
53 | ]
54 | }
55 | ```
56 | Conversation format for image dialogue
57 | ```
58 | {
59 | "conversations": [
60 | {
61 | "role": "user",
62 | "content": "What can be inferred about the zebras' behavior and surroundings?"
63 | },
64 | {
65 | "role": "assistant",
66 | "content": "Based on the image, we can infer that the two zebras are likely seeking relief from the sun's heat, as they are standing side by side under the branches of a thorny tree. This shade-providing tree offers some respite from the sun, possibly during the hottest part of the day. The zebras are in a green field with grass, providing them with an ideal environment to graze and eat while staying near their source of shelter. This shows that the zebras' behavior is influenced by the conditions and available resources in their surroundings. It also highlights that these animals adopt strategies to adapt to the fluctuating conditions of their environment, such as cooperation and seeking shelter, to survive and thrive in their natural habitat."
67 | }
68 | ]
69 | }
70 | ```
71 |
72 | ## References
73 | This project utilizes data and concepts based on the following research papers:
74 | - Zhu, D., Chen, J., Shen, X., Li, X., & Elhoseiny, M. (2023). MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models. arXiv preprint arXiv:2304.10592.
75 | - Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning. arXiv:2304.08485.
--------------------------------------------------------------------------------
/dataset_zh.md:
--------------------------------------------------------------------------------
1 | # CogVLM-SFT-311K:CogVLM SFT 中的双语视觉指令数据集
2 |
3 | CogVLM-SFT-311K 是我们在训练 **CogVLM v1.0** 最初版本时使用的主要对齐语料库。此数据集的构建过程如下:
4 | 1. 从开源的 [MiniGPT-4](https://huggingface.co/datasets/Vision-CAIR/cc_sbu_align) 中选取了大约3500个高质量数据样本,称为 minigpt4-3500。
5 | 2. 将 minigpt4-3500 与 [Llava-Instruct-150K](https://huggingface.co/datasets/liuhaotian/LLaVA-Instruct-150K) 整合,并通过语言模型翻译获得中文部分。
6 | 3. 我们发现在 minigpt4-3500 和 Llava-instruct 的详细描述部分存在许多噪声。因此,我们纠正了这两部分的中文语料,并将纠正后的语料重新翻译成英语。
7 |
8 | ## 许可证
9 | + 由于非商业协议限制,我们没有在 CogVLM的双语版本 和其他任何 涉及商业化的模型 中使用这些数据。
10 | + 数据集许可证遵守 Attribution-NonCommercial 4.0 International,同时应遵守 OpenAI 的使用政策: https://openai.com/policies/terms-of-use
11 | 这将不允许你使用这些数据进行任何 **商业化行为**。
12 |
13 | ## 数据集地址
14 |
15 | + [CogVLM-SFT-311K](https://huggingface.co/datasets/THUDM/CogVLM-SFT-311K)
16 |
17 | ## 数据集信息
18 | 数据集共有三个文件夹,分别对应 minigpt4-3500 与 llava 混合的数据集、llava 单轮对话数据集和多轮对话数据集。其布局如下:
19 | ```
20 | .CogVLM-SFT-311K
21 | ├── llava_details-minigpt4_3500_formate
22 | ├── llava_instruction_multi_conversations_formate
23 | └── llava_instruction_single_conversation_formate
24 | ```
25 | 在我们开源的数据中,数据集按照以下格式分布
26 | ```
27 | .llava_details-minigpt4_3500_formate
28 | ├── images
29 | │ └── 00000001.jpg
30 | └── labels
31 | └── 00000001.json
32 | ```
33 | 其中,images存储图像数据,而labels存储这张图像对应的描述或对话。
34 |
35 | ## 数据集数量
36 | + llava_details-minigpt4_3500_formate 22464 张图片和描述
37 | + llava_instruction_multi_conversations_formate 56673 张图片和多轮连续对话
38 | + llava_instruction_single_conversation_formate 76634 张图片和单轮对话
39 | ## 数据集格式
40 |
41 | 图像描述 Caption 格式
42 | ```
43 | {
44 | "captions": [
45 | {
46 | "role": "caption",
47 | "content": "这张照片展示了一男一女,他们都穿着红嘿色衣服,坐在桌子旁,在参加活动时面带微笑。这对情侣正在一起摆姿势,背景中还有其他人。在整个场景中可以看到另外几个人,一些人坐在附近的桌子旁,另一些人坐得离这对夫妇更近。各种餐桌设置,如酒杯和杯子,被放置在桌子上。排在最前面的男子打着领带。"
48 | }
49 | ]
50 | }
51 | ```
52 | 图像对话 Json 格式
53 | ```
54 | {
55 | "conversations": [
56 | {
57 | "role": "user",
58 | "content": "在尝试保持这个浴室有序时可能会遇到哪些挑战?",
59 | },
60 | {
61 | "role": "assistant",
62 | "content": "在尝试保持这空间和各种物品的存在,例如杯子和手机占用了可用的空间。在图像中,浴室柜台上只有一个水槽可用,这导致多个用户可能必须共享空间放置自己的物品。这可能导致杂乱和混乱外的储物解决方案,如架子、橱柜或墙壁挂架,以保持浴室有序并最小化柜台上的杂乱。"
63 | }
64 | ]
65 | }
66 | ```
67 |
68 | ## References
69 | This project utilizes data and concepts based on the following research papers:
70 | - Zhu, D., Chen, J., Shen, X., Li, X., & Elhoseiny, M. (2023). MiniGPT-4: Enhancing Vision-Language Understanding with Advanced Large Language Models. arXiv preprint arXiv:2304.10592.
71 | - Liu, H., Li, C., Wu, Q., & Lee, Y. J. (2023). Visual Instruction Tuning. arXiv:2304.08485.
--------------------------------------------------------------------------------
/finetune_demo/evaluate_cogagent.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # export PATH=/usr/local/cuda/bin:$PATH
3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
4 |
5 | NUM_GPUS_PER_WORKER=8
6 | MP_SIZE=1
7 |
8 | script_path=$(realpath $0)
9 | script_dir=$(dirname $script_path)
10 | main_dir=$(dirname $script_dir)
11 | MODEL_TYPE="cogagent-chat"
12 | VERSION="chat"
13 | # Tip: max_length should be longer than 256 to accommodate the low-resolution image tokens
14 | MODEL_ARGS="--from_pretrained ./checkpoints/ft_cogagent_model \
15 | --max_length 400 \
16 | --local_tokenizer lmsys/vicuna-7b-v1.5 \
17 | --version $VERSION"
18 |
19 | OPTIONS_SAT="SAT_HOME=~/.sat_models"
20 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER"
21 | HOST_FILE_PATH="hostfile"
22 |
23 | train_data="./archive_split/train"
24 | test_data="./archive_split/test"
25 |
26 | gpt_options=" \
27 | --experiment-name finetune-$MODEL_TYPE \
28 | --model-parallel-size ${MP_SIZE} \
29 | --mode finetune \
30 | --train-iters 0 \
31 | --resume-dataloader \
32 | $MODEL_ARGS \
33 | --train-data ${train_data} \
34 | --test-data ${test_data} \
35 | --distributed-backend nccl \
36 | --lr-decay-style cosine \
37 | --warmup .02 \
38 | --checkpoint-activations \
39 | --save-interval 200 \
40 | --eval-interval 200 \
41 | --save "./checkpoints" \
42 | --strict-eval \
43 | --eval-batch-size 1 \
44 | --split 1. \
45 | --deepspeed_config test_config_bf16.json \
46 | --skip-init \
47 | --seed 2023
48 | "
49 |
50 |
51 |
52 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_cogagent_demo.py ${gpt_options}"
53 | echo ${run_cmd}
54 | eval ${run_cmd}
55 |
56 | set +x
--------------------------------------------------------------------------------
/finetune_demo/evaluate_cogagent_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | import sys
5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
6 |
7 | from sat import mpu, get_args, get_tokenizer
8 | from sat.training.deepspeed_training import training_main
9 | from sat.helpers import print_rank0
10 | from collections import defaultdict
11 | from functools import partial
12 |
13 | from utils.models import FineTuneTestCogAgentModel
14 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor
15 |
16 |
17 | def data_collator(examples, cross_image_processor=None):
18 | def to_tensor(value):
19 | """Converts lists or numpy arrays to tensors."""
20 | if isinstance(value, list):
21 | return torch.tensor(value)
22 | elif isinstance(value, np.ndarray):
23 | return torch.from_numpy(value)
24 | return value
25 |
26 | def concatenate_tensors(attribute, key):
27 | """Concatenates tensors for a specific attribute and key."""
28 | if attribute is None:
29 | return torch.cat([ex[key] for ex in examples if isinstance(ex[key], torch.Tensor)])
30 | else:
31 | return torch.cat([ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)])
32 |
33 | # Convert all lists and numpy arrays in examples to tensors
34 | for example in examples:
35 | for key, value in example.items():
36 | example[key] = to_tensor(value)
37 |
38 | # Extract and concatenate attributes from examples
39 | img_args = {}
40 | for attribute in ['vision', 'cross']:
41 | if attribute == 'cross' and cross_image_processor is None:
42 | continue
43 |
44 | if attribute in examples[-1]: # Using the last example as reference
45 | for key in examples[-1][attribute]:
46 | tensor_key = f"{attribute}_{key}"
47 | tensors_to_concatenate = [ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)]
48 | if tensors_to_concatenate:
49 | img_args[tensor_key] = concatenate_tensors(attribute, key)
50 | else:
51 | img_args[tensor_key] = examples[-1][attribute][key]
52 |
53 | # Remove 'vision' and 'cross' keys from examples
54 | for example in examples:
55 | example.pop('vision', None)
56 | example.pop('cross', None)
57 |
58 | # Create model_args by concatenating tensors and copying other attributes
59 | model_args = {key: concatenate_tensors(None, key)
60 | if isinstance(examples[-1][key], torch.Tensor) else examples[-1][key]
61 | for key in examples[-1]
62 | }
63 |
64 | # Merge img_args into model_args
65 | model_args.update(img_args)
66 | return model_args
67 |
68 | def broadcast_auto(data_dict):
69 | # Classify keys based on their data type
70 | tensor_keys_by_dtype = defaultdict(list)
71 | non_tensor_keys = []
72 |
73 | for key, value in data_dict.items():
74 | if isinstance(value, torch.Tensor):
75 | tensor_keys_by_dtype[value.dtype].append(key)
76 | else:
77 | non_tensor_keys.append(key)
78 |
79 | # Broadcast tensor data and collect in a new dictionary
80 | broadcasted_data = {}
81 | for dtype, keys in tensor_keys_by_dtype.items():
82 | broadcasted_data.update(mpu.broadcast_data(keys, data_dict, dtype))
83 |
84 | # Add non-tensor data to the new dictionary
85 | for key in non_tensor_keys:
86 | broadcasted_data[key] = data_dict[key]
87 |
88 | return broadcasted_data
89 |
90 | def get_batch(data_iterator, args, timers):
91 | # Broadcast data.
92 | timers('data loader').start()
93 | if data_iterator is not None:
94 | data = next(data_iterator)
95 | else:
96 | data = None
97 | timers('data loader').stop()
98 | data_b = broadcast_auto(data)
99 | for k in data_b:
100 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long:
101 | if args.fp16:
102 | data_b[k] = data_b[k].half()
103 | elif args.bf16:
104 | data_b[k] = data_b[k].bfloat16()
105 | return data_b
106 |
107 | from torch.nn import CrossEntropyLoss
108 | import numpy as np
109 |
110 | from sat.model.mixins import CachedAutoregressiveMixin
111 | from sat.generation.autoregressive_sampling import filling_sequence
112 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
113 |
114 |
115 | def chat(model, tokenizer, tokens,
116 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs):
117 | inputs = tokens.to(model.parameters().__next__().device)[0]
118 | seq = torch.cat(
119 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
120 | )
121 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id])
122 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
123 | # num_beams=num_beams, consider_end=True)
124 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask'])
125 | output = filling_sequence(
126 | model, seq,
127 | batch_size=1,
128 | strategy=strategy,
129 | get_masks_and_position_ids=get_func,
130 | **kwargs
131 | )[0] # drop memory
132 |
133 | return output
134 |
135 |
136 | def forward_step_eval(data_iterator, model, args, timers):
137 | def compute_metrics(eval_preds):
138 | preds, labels, device = eval_preds
139 | preds = preds.unsqueeze(0)
140 | if isinstance(preds, tuple):
141 | preds = preds[0]
142 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
143 | if args.ignore_pad_token_for_loss:
144 | # Replace -100 in the labels as we can't decode them.
145 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
146 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
147 |
148 | score_dict = {
149 | "acc": [],
150 | "acc_w/o_case": [],
151 | }
152 | for pred, label in zip(decoded_preds, decoded_labels):
153 | if args.rank == 0:
154 | print('pred', pred, 'label', label, flush=True)
155 | if pred == label:
156 | score_dict['acc'].append(1.)
157 | else:
158 | score_dict['acc'].append(0.)
159 | if pred.lower() == label.lower():
160 | score_dict['acc_w/o_case'].append(1.)
161 | else:
162 | score_dict['acc_w/o_case'].append(0.)
163 |
164 |
165 | for k, v in score_dict.items():
166 | score_dict[k] = float(np.mean(v))
167 | return score_dict
168 |
169 | # Get the batch.
170 | timers('batch generator').start()
171 | data_b = get_batch(
172 | data_iterator, args, timers)
173 | timers('batch generator').stop()
174 |
175 | context_len = int(data_b['context_length'][0])
176 | tokens = data_b['input_ids'][:, :context_len]
177 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len]
178 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len]
179 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len]
180 |
181 | data_b.pop('input_ids')
182 | data_b.pop('attention_mask')
183 | data_b.pop('position_ids')
184 | labels = data_b.pop('labels')
185 | qid = data_b.pop('question_id')
186 |
187 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
188 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:]
189 | # print(outputs)
190 | model.del_mixin('auto-regressive')
191 |
192 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in
193 | compute_metrics(
194 | (outputs.cpu(), labels.cpu(), outputs.device)).items()}
195 |
196 |
197 | from torch.nn import CrossEntropyLoss
198 | def forward_step(data_iterator, model, args, timers):
199 | """Forward step."""
200 |
201 | # Get the batch.
202 | timers('batch generator').start()
203 | data_b = get_batch(
204 | data_iterator, args, timers)
205 | labels = data_b.pop('labels')
206 | timers('batch generator').stop()
207 | logits = model(**data_b)[0]
208 | lm_logits = logits.to(torch.float32)
209 | # Shift so that tokens < n predict n
210 | shift_labels = labels[..., 1:].contiguous()
211 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous()
212 | # Flatten the tokens
213 | loss_fct = CrossEntropyLoss(ignore_index=-100)
214 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
215 | loss = loss.to(torch.float32)
216 |
217 | return loss, {'loss': loss}
218 |
219 | from utils.utils import ItemDataset
220 | def create_dataset_function(image_processor, text_processor, cross_image_processor, path, args):
221 | dataset = ItemDataset(image_processor, text_processor, args, path, cross_image_processor=cross_image_processor)
222 | return dataset
223 |
224 | if __name__ == '__main__':
225 | py_parser = argparse.ArgumentParser(add_help=False)
226 | py_parser.add_argument('--max_length', type=int)
227 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false')
228 | py_parser.add_argument("--version", type=str, default="chat", help='version to interact with')
229 | py_parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
230 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
231 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true')
232 | py_parser = FineTuneTestCogAgentModel.add_model_specific_args(py_parser)
233 | known, args_list = py_parser.parse_known_args()
234 | args = get_args(args_list)
235 | args = argparse.Namespace(**vars(args), **vars(known))
236 | if args.use_qlora:
237 | args.device = 'cpu'
238 |
239 | model, args = FineTuneTestCogAgentModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
240 | if args.use_qlora and torch.cuda.is_available():
241 | model = model.to('cuda')
242 | from utils.utils import llama2_tokenizer
243 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
244 | image_processor = get_image_processor(args.eva_args["image_size"][0])
245 | cross_image_processor = get_image_processor(args.cross_image_pix)
246 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)
247 |
248 | training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor, cross_image_processor), collate_fn=partial(data_collator, cross_image_processor=cross_image_processor), forward_step_eval=forward_step_eval)
--------------------------------------------------------------------------------
/finetune_demo/evaluate_cogvlm.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # export PATH=/usr/local/cuda/bin:$PATH
3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
4 |
5 | NUM_GPUS_PER_WORKER=8
6 | MP_SIZE=1
7 |
8 | script_path=$(realpath $0)
9 | script_dir=$(dirname $script_path)
10 | main_dir=$(dirname $script_dir)
11 | MODEL_TYPE="cogvlm-base-490"
12 | VERSION="base"
13 | MODEL_ARGS="--from_pretrained ./checkpoints/merged_lora_490 \
14 | --max_length 1288 \
15 | --lora_rank 10 \
16 | --use_lora \
17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \
18 | --version $VERSION"
19 | # Tip: if training models at resolution 224, you can set --max_length smaller
20 |
21 |
22 | OPTIONS_SAT="SAT_HOME=~/.sat_models"
23 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER"
24 | HOST_FILE_PATH="hostfile"
25 |
26 | train_data="./archive_split/train"
27 | test_data="./archive_split/test"
28 |
29 | gpt_options=" \
30 | --experiment-name finetune-$MODEL_TYPE \
31 | --model-parallel-size ${MP_SIZE} \
32 | --mode finetune \
33 | --train-iters 0 \
34 | --resume-dataloader \
35 | $MODEL_ARGS \
36 | --train-data ${train_data} \
37 | --test-data ${test_data} \
38 | --distributed-backend nccl \
39 | --lr-decay-style cosine \
40 | --warmup .02 \
41 | --checkpoint-activations \
42 | --save-interval 200 \
43 | --eval-interval 200 \
44 | --save "./checkpoints" \
45 | --strict-eval \
46 | --eval-batch-size 1 \
47 | --split 1. \
48 | --deepspeed_config test_config_bf16.json \
49 | --skip-init \
50 | --seed 2023
51 | "
52 |
53 |
54 |
55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} evaluate_cogvlm_demo.py ${gpt_options}"
56 | echo ${run_cmd}
57 | eval ${run_cmd}
58 |
59 | set +x
--------------------------------------------------------------------------------
/finetune_demo/evaluate_cogvlm_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from functools import partial
5 | import sys
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 | from sat import mpu, get_args, get_tokenizer
9 | from sat.training.deepspeed_training import training_main
10 | from sat.helpers import print_rank0
11 | from utils.models import FineTuneTestCogVLMModel
12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor
13 |
14 |
15 | def data_collator(examples):
16 | examples = [ex for ex in examples if len(ex) > 0] # drop {}
17 | for example in examples:
18 | for k in example:
19 | if isinstance(example[k], list):
20 | example[k] = torch.tensor(example[k])
21 | elif isinstance(example[k], np.ndarray):
22 | example[k] = torch.from_numpy(example[k])
23 | img_args = {}
24 | tmp_example = examples[0]
25 | for k in tmp_example['vision']:
26 | if type(tmp_example['vision'][k]) is torch.Tensor:
27 | img_args['vision_'+k] = torch.cat([example['vision'][k] for example in examples])
28 | else:
29 | img_args['vision_'+k] = example['vision'][k]
30 | for example in examples:
31 | example.pop('vision')
32 | if 'cross' in example:
33 | example.pop('cross')
34 |
35 | model_args = {}
36 | tmp_example = examples[0]
37 | for k in tmp_example:
38 | if type(tmp_example[k]) is torch.Tensor:
39 | model_args[k] = torch.cat([example[k] for example in examples])
40 | else:
41 | model_args[k] = tmp_example[k]
42 | model_args.update(img_args)
43 | return model_args
44 |
45 | from collections import defaultdict
46 |
47 | def broadcast_auto(data_dict):
48 | type2list = defaultdict(list)
49 | other = []
50 | for k in data_dict:
51 | if type(data_dict[k]) is torch.Tensor:
52 | type2list[data_dict[k].dtype].append(k)
53 | else:
54 | other.append(k)
55 | new_data = {}
56 | for k in type2list:
57 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k))
58 | for k in other:
59 | new_data[k] = data_dict[k]
60 | return new_data
61 |
62 | def get_batch(data_iterator, args, timers):
63 | # Broadcast data.
64 | timers('data loader').start()
65 | if data_iterator is not None:
66 | data = next(data_iterator)
67 | else:
68 | data = None
69 | timers('data loader').stop()
70 | data_b = broadcast_auto(data)
71 | for k in data_b:
72 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long:
73 | if args.fp16:
74 | data_b[k] = data_b[k].half()
75 | elif args.bf16:
76 | data_b[k] = data_b[k].bfloat16()
77 | return data_b
78 |
79 | from torch.nn import CrossEntropyLoss
80 | import numpy as np
81 |
82 | from sat.model.mixins import CachedAutoregressiveMixin
83 | from sat.generation.autoregressive_sampling import filling_sequence
84 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
85 |
86 |
87 | def chat(model, tokenizer, tokens,
88 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs):
89 | inputs = tokens.to(model.parameters().__next__().device)[0]
90 | seq = torch.cat(
91 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
92 | )
93 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id])
94 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
95 | # num_beams=num_beams, consider_end=True)
96 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask'])
97 | output = filling_sequence(
98 | model, seq,
99 | batch_size=1,
100 | strategy=strategy,
101 | get_masks_and_position_ids=get_func,
102 | **kwargs
103 | )[0] # drop memory
104 |
105 | return output
106 |
107 |
108 | def forward_step_eval(data_iterator, model, args, timers):
109 | def compute_metrics(eval_preds):
110 | preds, labels, device = eval_preds
111 | preds = preds.unsqueeze(0)
112 | if isinstance(preds, tuple):
113 | preds = preds[0]
114 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
115 | if args.ignore_pad_token_for_loss:
116 | # Replace -100 in the labels as we can't decode them.
117 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
118 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
119 |
120 | score_dict = {
121 | "acc": [],
122 | "acc_w/o_case": [],
123 | }
124 | for pred, label in zip(decoded_preds, decoded_labels):
125 | if args.rank == 0:
126 | print('pred', pred, 'label', label, flush=True)
127 | if pred == label:
128 | score_dict['acc'].append(1.)
129 | else:
130 | score_dict['acc'].append(0.)
131 | if pred.lower() == label.lower():
132 | score_dict['acc_w/o_case'].append(1.)
133 | else:
134 | score_dict['acc_w/o_case'].append(0.)
135 |
136 |
137 | for k, v in score_dict.items():
138 | score_dict[k] = float(np.mean(v))
139 | return score_dict
140 |
141 | # Get the batch.
142 | timers('batch generator').start()
143 | data_b = get_batch(
144 | data_iterator, args, timers)
145 | timers('batch generator').stop()
146 |
147 | context_len = int(data_b['context_length'][0])
148 | tokens = data_b['input_ids'][:, :context_len]
149 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len]
150 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len]
151 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len]
152 |
153 | data_b.pop('input_ids')
154 | data_b.pop('attention_mask')
155 | data_b.pop('position_ids')
156 | labels = data_b.pop('labels')
157 | qid = data_b.pop('question_id')
158 |
159 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
160 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:]
161 | # print(outputs)
162 | model.del_mixin('auto-regressive')
163 |
164 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in
165 | compute_metrics(
166 | (outputs.cpu(), labels.cpu(), outputs.device)).items()}
167 |
168 |
169 | from torch.nn import CrossEntropyLoss
170 | def forward_step(data_iterator, model, args, timers):
171 | """Forward step."""
172 |
173 | # Get the batch.
174 | timers('batch generator').start()
175 | data_b = get_batch(
176 | data_iterator, args, timers)
177 | labels = data_b.pop('labels')
178 | timers('batch generator').stop()
179 | logits = model(**data_b)[0]
180 | lm_logits = logits.to(torch.float32)
181 | # Shift so that tokens < n predict n
182 | shift_labels = labels[..., 1:].contiguous()
183 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous()
184 | # Flatten the tokens
185 | loss_fct = CrossEntropyLoss(ignore_index=-100)
186 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
187 | loss = loss.to(torch.float32)
188 |
189 | return loss, {'loss': loss}
190 |
191 | from utils.utils import ItemDataset
192 | def create_dataset_function(image_processor, text_processor, path, args):
193 | dataset = ItemDataset(image_processor, text_processor, args, path)
194 | return dataset
195 |
196 | if __name__ == '__main__':
197 | py_parser = argparse.ArgumentParser(add_help=False)
198 | py_parser.add_argument('--max_length', type=int)
199 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false')
200 | py_parser.add_argument("--version", type=str, default="chat", help='version to interact with')
201 | py_parser.add_argument("--from_pretrained", type=str, default="cogvlm-chat", help='pretrained ckpt')
202 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
203 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true')
204 | py_parser = FineTuneTestCogVLMModel.add_model_specific_args(py_parser)
205 | known, args_list = py_parser.parse_known_args()
206 | args = get_args(args_list)
207 | args = argparse.Namespace(**vars(args), **vars(known))
208 | if args.use_qlora:
209 | args.device = 'cpu'
210 |
211 | model, args = FineTuneTestCogVLMModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
212 | if args.use_qlora and torch.cuda.is_available():
213 | model = model.to('cuda')
214 | from utils.utils import llama2_tokenizer
215 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
216 | image_processor = get_image_processor(args.eva_args["image_size"][0])
217 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)
218 |
219 | training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval)
--------------------------------------------------------------------------------
/finetune_demo/finetune_cogagent_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from functools import partial
5 | import sys
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 | from sat import mpu, get_args, get_tokenizer
9 | from sat.training.deepspeed_training import training_main
10 | from sat.helpers import print_rank0
11 | from utils.models import FineTuneTrainCogAgentModel
12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor
13 |
14 | def disable_untrainable_params(self):
15 | total_trainable = 0
16 | # enable = ['vit']
17 | enable = ["encoder", "cross_attention", "linear_proj", 'mlp.vision', 'rotary.vision', 'eoi', 'boi', 'vit']
18 | if self.args.use_ptuning:
19 | enable.extend(['ptuning'])
20 | if self.args.use_lora or self.args.use_qlora:
21 | enable.extend(['matrix_A', 'matrix_B'])
22 | for n, p in self.named_parameters():
23 | flag = False
24 | for e in enable:
25 | if type(e) is tuple:
26 | if e[0].lower() in n.lower() and e[1].lower() in n.lower() and 55 > int(n[:n.find('.mlp')].split('.')[-1]) > 45:
27 | flag = True
28 | break
29 | else:
30 | if e.lower() in n.lower():
31 | flag = True
32 | break
33 | if not flag:
34 | p.requires_grad_(False)
35 | else:
36 | total_trainable += p.numel()
37 | if 'encoder' in n or 'vit' in n:
38 | p.lr_scale = 0.1
39 | print_rank0(n)
40 | print_rank0("***** Total trainable parameters: "+str(total_trainable)+" *****")
41 |
42 | FineTuneTrainCogAgentModel.disable_untrainable_params = disable_untrainable_params
43 |
44 | def data_collator(examples, cross_image_processor=None):
45 | def to_tensor(value):
46 | """Converts lists or numpy arrays to tensors."""
47 | if isinstance(value, list):
48 | return torch.tensor(value)
49 | elif isinstance(value, np.ndarray):
50 | return torch.from_numpy(value)
51 | return value
52 |
53 | def concatenate_tensors(attribute, key):
54 | """Concatenates tensors for a specific attribute and key."""
55 | if attribute is None:
56 | return torch.cat([ex[key] for ex in examples if isinstance(ex[key], torch.Tensor)])
57 | else:
58 | return torch.cat([ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)])
59 |
60 | # Convert all lists and numpy arrays in examples to tensors
61 | for example in examples:
62 | for key, value in example.items():
63 | example[key] = to_tensor(value)
64 |
65 | # Extract and concatenate attributes from examples
66 | img_args = {}
67 | for attribute in ['vision', 'cross']:
68 | if attribute == 'cross' and cross_image_processor is None:
69 | continue
70 |
71 | if attribute in examples[-1]: # Using the last example as reference
72 | for key in examples[-1][attribute]:
73 | tensor_key = f"{attribute}_{key}"
74 | tensors_to_concatenate = [ex[attribute][key] for ex in examples if isinstance(ex[attribute][key], torch.Tensor)]
75 | if tensors_to_concatenate:
76 | img_args[tensor_key] = concatenate_tensors(attribute, key)
77 | else:
78 | img_args[tensor_key] = examples[-1][attribute][key]
79 |
80 | # Remove 'vision' and 'cross' keys from examples
81 | for example in examples:
82 | example.pop('vision', None)
83 | example.pop('cross', None)
84 |
85 | # Create model_args by concatenating tensors and copying other attributes
86 | model_args = {key: concatenate_tensors(None, key)
87 | if isinstance(examples[-1][key], torch.Tensor) else examples[-1][key]
88 | for key in examples[-1]
89 | }
90 |
91 | # Merge img_args into model_args
92 | model_args.update(img_args)
93 | return model_args
94 |
95 |
96 | from collections import defaultdict
97 |
98 | def broadcast_auto(data_dict):
99 | type2list = defaultdict(list)
100 | other = []
101 | for k in data_dict:
102 | if type(data_dict[k]) is torch.Tensor:
103 | type2list[data_dict[k].dtype].append(k)
104 | else:
105 | other.append(k)
106 | new_data = {}
107 | for k in type2list:
108 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k))
109 | for k in other:
110 | new_data[k] = data_dict[k]
111 | return new_data
112 |
113 | def get_batch(data_iterator, args, timers):
114 | # Broadcast data.
115 | timers('data loader').start()
116 | if data_iterator is not None:
117 | data = next(data_iterator)
118 | else:
119 | data = None
120 | timers('data loader').stop()
121 | data_b = broadcast_auto(data)
122 | for k in data_b:
123 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long:
124 | if args.fp16:
125 | data_b[k] = data_b[k].half()
126 | elif args.bf16:
127 | data_b[k] = data_b[k].bfloat16()
128 | return data_b
129 |
130 | from torch.nn import CrossEntropyLoss
131 | import numpy as np
132 |
133 | from sat.model.mixins import CachedAutoregressiveMixin
134 | from sat.generation.autoregressive_sampling import filling_sequence
135 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
136 |
137 |
138 | def chat(model, tokenizer, tokens,
139 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs):
140 | inputs = tokens.to(model.parameters().__next__().device)[0]
141 | seq = torch.cat(
142 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
143 | )
144 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id])
145 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
146 | # num_beams=num_beams, consider_end=True)
147 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask'])
148 | output = filling_sequence(
149 | model, seq,
150 | batch_size=1,
151 | strategy=strategy,
152 | get_masks_and_position_ids=get_func,
153 | **kwargs
154 | )[0] # drop memory
155 |
156 | return output
157 |
158 |
159 | def forward_step_eval(data_iterator, model, args, timers):
160 | def compute_metrics(eval_preds):
161 | preds, labels, device = eval_preds
162 | preds = preds.unsqueeze(0)
163 | if isinstance(preds, tuple):
164 | preds = preds[0]
165 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
166 | if args.ignore_pad_token_for_loss:
167 | # Replace -100 in the labels as we can't decode them.
168 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
169 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
170 |
171 | score_dict = {
172 | "acc": [],
173 | "acc_w/o_case": [],
174 | }
175 | for pred, label in zip(decoded_preds, decoded_labels):
176 | if args.rank == 0:
177 | print('pred', pred, 'label', label, flush=True)
178 | if pred == label:
179 | score_dict['acc'].append(1.)
180 | else:
181 | score_dict['acc'].append(0.)
182 | if pred.lower() == label.lower():
183 | score_dict['acc_w/o_case'].append(1.)
184 | else:
185 | score_dict['acc_w/o_case'].append(0.)
186 |
187 |
188 | for k, v in score_dict.items():
189 | score_dict[k] = float(np.mean(v))
190 | return score_dict
191 |
192 | # Get the batch.
193 | timers('batch generator').start()
194 | data_b = get_batch(
195 | data_iterator, args, timers)
196 | timers('batch generator').stop()
197 |
198 | context_len = int(data_b['context_length'][0])
199 | tokens = data_b['input_ids'][:, :context_len]
200 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len]
201 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len]
202 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len]
203 |
204 | data_b.pop('input_ids')
205 | data_b.pop('attention_mask')
206 | data_b.pop('position_ids')
207 | labels = data_b.pop('labels')
208 | qid = data_b.pop('question_id')
209 |
210 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
211 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:]
212 | # print(outputs)
213 | model.del_mixin('auto-regressive')
214 |
215 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in
216 | compute_metrics(
217 | (outputs.cpu(), labels.cpu(), outputs.device)).items()}
218 |
219 |
220 | from torch.nn import CrossEntropyLoss
221 | def forward_step(data_iterator, model, args, timers):
222 | """Forward step."""
223 |
224 | # Get the batch.
225 | timers('batch generator').start()
226 | data_b = get_batch(
227 | data_iterator, args, timers)
228 | labels = data_b.pop('labels')
229 | timers('batch generator').stop()
230 | logits = model(**data_b)[0]
231 | lm_logits = logits.to(torch.float32)
232 | # Shift so that tokens < n predict n
233 | shift_labels = labels[..., 1:].contiguous()
234 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous()
235 | # Flatten the tokens
236 | loss_fct = CrossEntropyLoss(ignore_index=-100)
237 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
238 | loss = loss.to(torch.float32)
239 |
240 | return loss, {'loss': loss}
241 |
242 | from utils.utils import ItemDataset
243 | def create_dataset_function(image_processor, text_processor, cross_image_processor, path, args):
244 | dataset = ItemDataset(image_processor, text_processor, args, path, cross_image_processor=cross_image_processor)
245 | return dataset
246 |
247 | from sat.model.finetune.lora2 import LoraMixin
248 | from sat.model.finetune.prompt_tuning import PTuningV2Mixin
249 |
250 | if __name__ == '__main__':
251 | py_parser = argparse.ArgumentParser(add_help=False)
252 | py_parser.add_argument('--max_length', type=int)
253 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false')
254 | py_parser.add_argument("--version", type=str, default="chat", choices=["chat", "vqa"], help='version to interact with')
255 | py_parser.add_argument("--from_pretrained", type=str, default="cogagent-chat", help='pretrained ckpt')
256 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
257 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true')
258 | py_parser = FineTuneTrainCogAgentModel.add_model_specific_args(py_parser)
259 | known, args_list = py_parser.parse_known_args()
260 | args = get_args(args_list)
261 | args = argparse.Namespace(**vars(args), **vars(known))
262 | if args.use_qlora:
263 | args.device = 'cpu'
264 |
265 | model, args = FineTuneTrainCogAgentModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
266 | if args.use_ptuning: # TODO: wait for SAT updating
267 | model.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
268 |
269 | if args.use_lora:
270 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
271 | model.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
272 | elif args.use_qlora:
273 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
274 |
275 | if args.use_qlora and torch.cuda.is_available():
276 | model = model.to('cuda')
277 | from utils.utils import llama2_tokenizer
278 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
279 | image_processor = get_image_processor(args.eva_args["image_size"][0])
280 | cross_image_processor = get_image_processor(args.cross_image_pix)
281 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)
282 |
283 | model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor, cross_image_processor), collate_fn=partial(data_collator, cross_image_processor=cross_image_processor), forward_step_eval=forward_step_eval)
284 | if args.use_lora:
285 | model.get_mixin("lora").merge_lora()
286 | model.get_mixin("eva").vit_model.get_mixin("lora").merge_lora()
287 | args.use_lora = False
288 | args.save = "checkpoints/merged_lora_cogagent"
289 | from sat.training.model_io import save_checkpoint
290 | save_checkpoint(1, model, None, None, args)
--------------------------------------------------------------------------------
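
For quick reference, the collation performed by data_collator above can be summarized as: nested 'vision' (and, when present, 'cross') dictionaries are flattened into 'vision_*' / 'cross_*' batch tensors, while flat per-example tensors are concatenated along the batch dimension. The following is a minimal standalone sketch of that idea using plain torch with made-up shapes and key names; it does not import this module and is only an illustration.

import torch

examples = [
    {"input_ids": torch.ones(1, 4, dtype=torch.long),
     "vision": {"image": torch.zeros(1, 3, 224, 224)}},
    {"input_ids": torch.ones(1, 4, dtype=torch.long),
     "vision": {"image": torch.zeros(1, 3, 224, 224)}},
]

# Flatten nested 'vision' dicts into 'vision_'-prefixed keys; concat flat tensors.
batch = {"input_ids": torch.cat([ex["input_ids"] for ex in examples])}
batch.update({
    "vision_" + k: torch.cat([ex["vision"][k] for ex in examples])
    for k in examples[-1]["vision"]
})

print(batch["input_ids"].shape)     # torch.Size([2, 4])
print(batch["vision_image"].shape)  # torch.Size([2, 3, 224, 224])
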
/finetune_demo/finetune_cogagent_lora.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # export PATH=/usr/local/cuda/bin:$PATH
3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
4 |
5 | NUM_GPUS_PER_WORKER=8
6 | MP_SIZE=1
7 |
8 | script_path=$(realpath $0)
9 | script_dir=$(dirname $script_path)
10 | main_dir=$(dirname $script_dir)
11 | MODEL_TYPE="cogagent-chat"
12 | VERSION="chat"
13 | MODEL_ARGS="--from_pretrained $MODEL_TYPE \
14 | --max_length 400 \
15 | --lora_rank 50 \
16 | --use_lora \
17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \
18 | --version $VERSION"
19 | # TIP: max_length includes the low-resolution image sequence (256 tokens); see the rough budget sketch after this script
20 |
21 | OPTIONS_SAT="SAT_HOME=~/.sat_models"
22 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER"
23 | HOST_FILE_PATH="hostfile"
24 |
25 | train_data="./archive_split/train"
26 | valid_data="./archive_split/valid"
27 |
28 | gpt_options=" \
29 | --experiment-name finetune-$MODEL_TYPE \
30 | --model-parallel-size ${MP_SIZE} \
31 | --mode finetune \
32 | --train-iters 2000 \
33 | --resume-dataloader \
34 | $MODEL_ARGS \
35 | --train-data ${train_data} \
36 | --valid-data ${valid_data} \
37 | --distributed-backend nccl \
38 | --lr-decay-style cosine \
39 | --warmup .02 \
40 | --checkpoint-activations \
41 | --vit_checkpoint_activations \
42 | --save-interval 200 \
43 | --eval-interval 200 \
44 | --save "./checkpoints" \
45 | --eval-iters 10 \
46 | --eval-batch-size 1 \
47 | --split 1. \
48 | --deepspeed_config test_config_bf16.json \
49 | --skip-init \
50 | --seed 2023
51 | "
52 |
53 |
54 |
55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogagent_demo.py ${gpt_options}"
56 | echo ${run_cmd}
57 | eval ${run_cmd}
58 |
59 | set +x
--------------------------------------------------------------------------------
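
A rough reading of the --max_length 400 setting above, under the assumption that the 256 low-resolution image tokens plus the boi/eoi markers inserted by ImageMixin all count toward max_length (the high-resolution cross image goes through cross-attention and does not occupy token positions):

# Back-of-the-envelope token budget (assumption: 256 image tokens + boi + eoi).
max_length = 400
image_positions = 256 + 2
text_budget = max_length - image_positions
print(text_budget)  # 142 positions left for the text prompt and response
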
/finetune_demo/finetune_cogvlm_demo.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import argparse
4 | from functools import partial
5 | import sys
6 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
7 |
8 | from sat import mpu, get_args, get_tokenizer
9 | from sat.training.deepspeed_training import training_main
10 | from sat.helpers import print_rank0
11 | from utils.models import FineTuneTrainCogVLMModel
12 | from utils.utils import llama2_text_processor, llama2_text_processor_inference, get_image_processor
13 |
14 | def disable_untrainable_params(self):
15 | total_trainable = 0
16 | enable = [('mlp', 'vit')]
17 | if self.args.use_ptuning:
18 | enable.extend(['ptuning'])
19 | if self.args.use_lora or self.args.use_qlora:
20 | enable.extend(['matrix_A', 'matrix_B'])
21 | for n, p in self.named_parameters():
22 | flag = False
23 | for e in enable:
24 | if type(e) is tuple:
25 | if e[0].lower() in n.lower() and e[1].lower() in n.lower() and 55 > int(n[:n.find('.mlp')].split('.')[-1]) > 45:
26 | flag = True
27 | break
28 | else:
29 | if e.lower() in n.lower():
30 | flag = True
31 | break
32 | if not flag:
33 | p.requires_grad_(False)
34 | else:
35 | total_trainable += p.numel()
36 | print_rank0(n)
37 | print_rank0("***** Total trainable parameters: "+str(total_trainable)+" *****")
38 |
39 | FineTuneTrainCogVLMModel.disable_untrainable_params = disable_untrainable_params
40 |
41 | def data_collator(examples):
42 | examples = [ex for ex in examples if len(ex) > 0] # drop {}
43 | for example in examples:
44 | for k in example:
45 | if isinstance(example[k], list):
46 | example[k] = torch.tensor(example[k])
47 | elif isinstance(example[k], np.ndarray):
48 | example[k] = torch.from_numpy(example[k])
49 | img_args = {}
50 | tmp_example = examples[0]
51 | for k in tmp_example['vision']:
52 | if type(tmp_example['vision'][k]) is torch.Tensor:
53 | img_args['vision_'+k] = torch.cat([example['vision'][k] for example in examples])
54 | else:
55 |             img_args['vision_'+k] = tmp_example['vision'][k]
56 | for example in examples:
57 | example.pop('vision')
58 | if 'cross' in example:
59 | example.pop('cross')
60 |
61 | model_args = {}
62 | tmp_example = examples[0]
63 | for k in tmp_example:
64 | if type(tmp_example[k]) is torch.Tensor:
65 | model_args[k] = torch.cat([example[k] for example in examples])
66 | else:
67 | model_args[k] = tmp_example[k]
68 | model_args.update(img_args)
69 | return model_args
70 |
71 | from collections import defaultdict
72 |
73 | def broadcast_auto(data_dict):
74 | type2list = defaultdict(list)
75 | other = []
76 | for k in data_dict:
77 | if type(data_dict[k]) is torch.Tensor:
78 | type2list[data_dict[k].dtype].append(k)
79 | else:
80 | other.append(k)
81 | new_data = {}
82 | for k in type2list:
83 | new_data.update(mpu.broadcast_data(type2list[k], data_dict, k))
84 | for k in other:
85 | new_data[k] = data_dict[k]
86 | return new_data
87 |
88 | def get_batch(data_iterator, args, timers):
89 | # Broadcast data.
90 | timers('data loader').start()
91 | if data_iterator is not None:
92 | data = next(data_iterator)
93 | else:
94 | data = None
95 | timers('data loader').stop()
96 | data_b = broadcast_auto(data)
97 | for k in data_b:
98 | if type(data_b[k]) is torch.Tensor and data_b[k].dtype is not torch.int32 and data_b[k].dtype is not torch.long:
99 | if args.fp16:
100 | data_b[k] = data_b[k].half()
101 | elif args.bf16:
102 | data_b[k] = data_b[k].bfloat16()
103 | return data_b
104 |
105 | from torch.nn import CrossEntropyLoss
106 | import numpy as np
107 |
108 | from sat.model.mixins import CachedAutoregressiveMixin
109 | from sat.generation.autoregressive_sampling import filling_sequence
110 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
111 |
112 |
113 | def chat(model, tokenizer, tokens,
114 | max_length: int = 1800, num_beams=5, top_p=0.95, top_k=0, temperature=0.8, **kwargs):
115 | inputs = tokens.to(model.parameters().__next__().device)[0]
116 | seq = torch.cat(
117 | [inputs, torch.tensor([-1] * (max_length - len(inputs)), device=inputs.device)], dim=0
118 | )
119 | strategy = BaseStrategy(temperature=temperature, top_p=0.4, top_k=1, end_tokens=[tokenizer.eos_token_id])
120 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[tokenizer.eos_token_id],
121 | # num_beams=num_beams, consider_end=True)
122 | get_func = llama2_text_processor_inference.get_func(None, None, image_rope_mask=kwargs['image_rope_mask'])
123 | output = filling_sequence(
124 | model, seq,
125 | batch_size=1,
126 | strategy=strategy,
127 | get_masks_and_position_ids=get_func,
128 | **kwargs
129 | )[0] # drop memory
130 |
131 | return output
132 |
133 |
134 | def forward_step_eval(data_iterator, model, args, timers):
135 | def compute_metrics(eval_preds):
136 | preds, labels, device = eval_preds
137 | preds = preds.unsqueeze(0)
138 | if isinstance(preds, tuple):
139 | preds = preds[0]
140 | decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
141 | if args.ignore_pad_token_for_loss:
142 | # Replace -100 in the labels as we can't decode them.
143 | labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
144 | decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
145 |
146 | score_dict = {
147 | "acc": [],
148 | "acc_w/o_case": [],
149 | }
150 | for pred, label in zip(decoded_preds, decoded_labels):
151 | if args.rank == 0:
152 | print('pred', pred, 'label', label, flush=True)
153 | if pred == label:
154 | score_dict['acc'].append(1.)
155 | else:
156 | score_dict['acc'].append(0.)
157 | if pred.lower() == label.lower():
158 | score_dict['acc_w/o_case'].append(1.)
159 | else:
160 | score_dict['acc_w/o_case'].append(0.)
161 |
162 |
163 | for k, v in score_dict.items():
164 | score_dict[k] = float(np.mean(v))
165 | return score_dict
166 |
167 | # Get the batch.
168 | timers('batch generator').start()
169 | data_b = get_batch(
170 | data_iterator, args, timers)
171 | timers('batch generator').stop()
172 |
173 | context_len = int(data_b['context_length'][0])
174 | tokens = data_b['input_ids'][:, :context_len]
175 | data_b['vision_expert_mask'] = data_b['vision_expert_mask'][:, :context_len]
176 | data_b['image_embed_mask'] = data_b['image_embed_mask'][:, :context_len]
177 | data_b['image_rope_mask'] = data_b['image_rope_mask'][:, :context_len]
178 |
179 | data_b.pop('input_ids')
180 | data_b.pop('attention_mask')
181 | data_b.pop('position_ids')
182 | labels = data_b.pop('labels')
183 | qid = data_b.pop('question_id')
184 |
185 | model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
186 | outputs = chat(model, tokenizer, tokens, **data_b)[0][context_len:]
187 | # print(outputs)
188 | model.del_mixin('auto-regressive')
189 |
190 | return torch.tensor(0, device=outputs.device), {k: torch.tensor(v, device=outputs.device) for k, v in
191 | compute_metrics(
192 | (outputs.cpu(), labels.cpu(), outputs.device)).items()}
193 |
194 |
195 | from torch.nn import CrossEntropyLoss
196 | def forward_step(data_iterator, model, args, timers):
197 | """Forward step."""
198 |
199 | # Get the batch.
200 | timers('batch generator').start()
201 | data_b = get_batch(
202 | data_iterator, args, timers)
203 | labels = data_b.pop('labels')
204 | timers('batch generator').stop()
205 | logits = model(**data_b)[0]
206 | lm_logits = logits.to(torch.float32)
207 | # Shift so that tokens < n predict n
208 | shift_labels = labels[..., 1:].contiguous()
209 | shift_logits = lm_logits[..., -1-shift_labels.size(-1):-1, :].contiguous()
210 | # Flatten the tokens
211 | loss_fct = CrossEntropyLoss(ignore_index=-100)
212 | loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
213 | loss = loss.to(torch.float32)
214 |
215 | return loss, {'loss': loss}
216 |
217 | from utils.utils import ItemDataset
218 | def create_dataset_function(image_processor, text_processor, path, args):
219 | dataset = ItemDataset(image_processor, text_processor, args, path)
220 | return dataset
221 |
222 | from sat.model.finetune.lora2 import LoraMixin
223 | from sat.model.finetune.prompt_tuning import PTuningV2Mixin
224 |
225 | if __name__ == '__main__':
226 | py_parser = argparse.ArgumentParser(add_help=False)
227 | py_parser.add_argument('--max_length', type=int)
228 | py_parser.add_argument('--ignore_pad_token_for_loss', action='store_false')
229 | py_parser.add_argument("--version", type=str, default="chat_old", help='version to interact with')
230 | py_parser.add_argument("--from_pretrained", type=str, default="cogvlm-chat", help='pretrained ckpt')
231 | py_parser.add_argument("--local_tokenizer", type=str, default="lmsys/vicuna-7b-v1.5", help='tokenizer path')
232 | py_parser.add_argument("--vit_checkpoint_activations", action='store_true')
233 | py_parser = FineTuneTrainCogVLMModel.add_model_specific_args(py_parser)
234 | known, args_list = py_parser.parse_known_args()
235 | args = get_args(args_list)
236 | args = argparse.Namespace(**vars(args), **vars(known))
237 | if args.use_qlora:
238 | args.device = 'cpu'
239 |
240 | model, args = FineTuneTrainCogVLMModel.from_pretrained(args.from_pretrained, args, overwrite_args={'model_parallel_size': args.model_parallel_size} if args.model_parallel_size != 1 else {})
241 | if args.use_ptuning:
242 | model.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
243 | if args.use_lora:
244 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
245 | model.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
246 | elif args.use_qlora:
247 | model.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
248 |
249 | if args.use_qlora and torch.cuda.is_available():
250 | model = model.to('cuda')
251 | from utils.utils import llama2_tokenizer
252 | tokenizer = llama2_tokenizer(args.local_tokenizer, signal_type=args.version)
253 | image_processor = get_image_processor(args.eva_args["image_size"][0])
254 | text_processor = llama2_text_processor(tokenizer, args.max_length, args.image_length)
255 |
256 | model = training_main(args, model_cls=model, forward_step_function=forward_step, create_dataset_function=partial(create_dataset_function, image_processor, text_processor), collate_fn=data_collator, forward_step_eval=forward_step_eval)
257 | if args.use_lora:
258 | model.get_mixin("lora").merge_lora()
259 | model.get_mixin("eva").vit_model.get_mixin("lora").merge_lora()
260 | args.use_lora = False
261 | args.save = "checkpoints/merged_lora_cogvlm{}".format(args.eva_args["image_size"][0])
262 | from sat.training.model_io import save_checkpoint
263 | save_checkpoint(1, model, None, None, args)
--------------------------------------------------------------------------------
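
The forward_step above aligns logits and labels so that the logit at position t is scored against the token at position t+1, with -100 labels ignored. A standalone toy check with made-up numbers (not from the repo):

import torch
from torch.nn import CrossEntropyLoss

vocab, seq = 5, 6
lm_logits = torch.randn(1, seq, vocab)
labels = torch.tensor([[-100, -100, 2, 3, 1, -100]])  # -100 = ignored positions

shift_labels = labels[..., 1:].contiguous()                                   # labels 1..5
shift_logits = lm_logits[..., -1 - shift_labels.size(-1):-1, :].contiguous()  # logits 0..4
loss = CrossEntropyLoss(ignore_index=-100)(
    shift_logits.view(-1, vocab), shift_labels.view(-1))
print(loss)  # finite; only the three non -100 labels contribute
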
/finetune_demo/finetune_cogvlm_lora.sh:
--------------------------------------------------------------------------------
1 | #! /bin/bash
2 | # export PATH=/usr/local/cuda/bin:$PATH
3 | # export LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
4 |
5 | NUM_GPUS_PER_WORKER=8
6 | MP_SIZE=1
7 |
8 | script_path=$(realpath $0)
9 | script_dir=$(dirname $script_path)
10 | main_dir=$(dirname $script_dir)
11 | MODEL_TYPE="cogvlm-base-490"
12 | VERSION="base"
13 | MODEL_ARGS="--from_pretrained $MODEL_TYPE \
14 | --max_length 1288 \
15 | --lora_rank 10 \
16 | --use_lora \
17 | --local_tokenizer lmsys/vicuna-7b-v1.5 \
18 | --version $VERSION"
19 | # Tip: if training a model at 224 resolution, you can set a smaller --max_length
20 |
21 | OPTIONS_SAT="SAT_HOME=~/.sat_models"
22 | OPTIONS_NCCL="NCCL_DEBUG=info NCCL_IB_DISABLE=0 NCCL_NET_GDR_LEVEL=2 LOCAL_WORLD_SIZE=$NUM_GPUS_PER_WORKER"
23 | HOST_FILE_PATH="hostfile"
24 |
25 | train_data="./archive_split/train"
26 | valid_data="./archive_split/valid"
27 |
28 | gpt_options=" \
29 | --experiment-name finetune-$MODEL_TYPE \
30 | --model-parallel-size ${MP_SIZE} \
31 | --mode finetune \
32 | --train-iters 800 \
33 | --resume-dataloader \
34 | $MODEL_ARGS \
35 | --train-data ${train_data} \
36 | --valid-data ${valid_data} \
37 | --distributed-backend nccl \
38 | --lr-decay-style cosine \
39 | --warmup .02 \
40 | --checkpoint-activations \
41 | --vit_checkpoint_activations \
42 | --save-interval 200 \
43 | --eval-interval 200 \
44 | --save "./checkpoints" \
45 | --eval-iters 10 \
46 | --eval-batch-size 1 \
47 | --split 1. \
48 | --deepspeed_config test_config_bf16.json \
49 | --skip-init \
50 | --seed 2023
51 | "
52 |
53 |
54 |
55 | run_cmd="${OPTIONS_NCCL} ${OPTIONS_SAT} deepspeed --master_port 16666 --hostfile ${HOST_FILE_PATH} finetune_cogvlm_demo.py ${gpt_options}"
56 | echo ${run_cmd}
57 | eval ${run_cmd}
58 |
59 | set +x
--------------------------------------------------------------------------------
/finetune_demo/test_config_bf16.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_micro_batch_size_per_gpu": 4,
3 | "gradient_accumulation_steps": 1,
4 | "gradient_clipping": 0.1,
5 | "zero_optimization": {
6 | "stage": 2,
7 | "contiguous_gradients": false,
8 | "overlap_comm": true,
9 | "reduce_scatter": true,
10 | "reduce_bucket_size": 4e7,
11 | "allgather_bucket_size": 1e8,
12 |         "load_from_fp32_weights": false,
13 |         "offload_optimizer": {
14 |             "device": "cpu",
15 |             "pin_memory": true
16 |         }
17 |     },
18 | "zero_allow_untested_optimizer": true,
19 | "bf16": {
20 | "enabled": true
21 | },
22 | "optimizer": {
23 | "type": "Adam",
24 | "params": {
25 | "lr": 0.00001,
26 | "betas": [
27 | 0.9,
28 | 0.95
29 | ],
30 | "eps": 1e-8,
31 | "weight_decay": 5e-2
32 | }
33 | },
34 | "activation_checkpointing": {
35 | "partition_activations": false,
36 | "contiguous_memory_optimization": false,
37 | "cpu_checkpointing": false
38 | },
39 | "wall_clock_breakdown": false
40 | }
41 |
--------------------------------------------------------------------------------
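
For orientation, DeepSpeed derives the effective global batch size from this config as train_micro_batch_size_per_gpu × gradient_accumulation_steps × data-parallel world size. Assuming the 8-GPU, model-parallel-size-1 launch used in the finetune_*_lora.sh scripts:

micro_batch_per_gpu = 4      # train_micro_batch_size_per_gpu above
grad_accum_steps = 1         # gradient_accumulation_steps above
data_parallel_size = 8       # NUM_GPUS_PER_WORKER=8, MP_SIZE=1 in the launch scripts
print(micro_batch_per_gpu * grad_accum_steps * data_parallel_size)  # 32 samples per step
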
/openai_demo/demo.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/openai_demo/demo.jpg
--------------------------------------------------------------------------------
/openai_demo/openai_api_request.py:
--------------------------------------------------------------------------------
1 | """
2 | This script mimics the OpenAI API interface for CogVLM & CogAgent chat models.
3 | It demonstrates how to integrate image and text-based input to generate a response.
4 | Currently, the model can only handle a single image.
5 | Therefore, do not use this script to process multiple images in one conversation (this includes images from the conversation history).
6 | It also only works with the chat model, not the base model.
7 | """
8 | import requests
9 | import json
10 | import base64
11 |
12 | base_url = "http://127.0.0.1:8000"
13 |
14 |
15 | def create_chat_completion(model, messages, temperature=0.8, max_tokens=2048, top_p=0.8, use_stream=False):
16 | """
17 | This function sends a request to the chat API to generate a response based on the given messages.
18 |
19 | Args:
20 | model (str): The name of the model to use for generating the response.
21 | messages (list): A list of message dictionaries representing the conversation history.
22 | temperature (float): Controls randomness in response generation. Higher values lead to more random responses.
23 | max_tokens (int): The maximum length of the generated response.
24 | top_p (float): Controls diversity of response by filtering less likely options.
25 | use_stream (bool): Determines whether to use a streaming response or a single response.
26 |
27 | The function constructs a JSON payload with the specified parameters and sends a POST request to the API.
28 | It then handles the response, either as a stream (for ongoing responses) or a single message.
29 | """
30 |
31 | data = {
32 | "model": model,
33 | "messages": messages,
34 | "stream": use_stream,
35 | "max_tokens": max_tokens,
36 | "temperature": temperature,
37 | "top_p": top_p,
38 | }
39 |
40 | response = requests.post(f"{base_url}/v1/chat/completions", json=data, stream=use_stream)
41 | if response.status_code == 200:
42 | if use_stream:
43 |             # Handle the streaming response
44 | for line in response.iter_lines():
45 | if line:
46 | decoded_line = line.decode('utf-8')[6:]
47 | try:
48 | response_json = json.loads(decoded_line)
49 | content = response_json.get("choices", [{}])[0].get("delta", {}).get("content", "")
50 | print(content)
51 | except:
52 | print("Special Token:", decoded_line)
53 | else:
54 |             # Handle the non-streaming response
55 | decoded_line = response.json()
56 |             content = decoded_line.get("choices", [{}])[0].get("message", {}).get("content", "")
57 | print(content)
58 | else:
59 | print("Error:", response.status_code)
60 | return None
61 |
62 |
63 | def encode_image(image_path):
64 | """
65 | Encodes an image file into a base64 string.
66 | Args:
67 | image_path (str): The path to the image file.
68 |
69 | This function opens the specified image file, reads its content, and encodes it into a base64 string.
70 | The base64 encoding is used to send images over HTTP as text.
71 | """
72 |
73 | with open(image_path, "rb") as image_file:
74 | return base64.b64encode(image_file.read()).decode("utf-8")
75 |
76 |
77 | def simple_image_chat(use_stream=True, img_path=None):
78 | """
79 | Facilitates a simple chat interaction involving an image.
80 |
81 | Args:
82 | use_stream (bool): Specifies whether to use streaming for chat responses.
83 | img_path (str): Path to the image file to be included in the chat.
84 |
85 | This function encodes the specified image and constructs a predefined conversation involving the image.
86 | It then calls `create_chat_completion` to generate a response from the model.
87 | The conversation includes asking about the content of the image and a follow-up question.
88 | """
89 |
90 | img_url = f"data:image/jpeg;base64,{encode_image(img_path)}"
91 | messages = [
92 | {
93 | "role": "user",
94 | "content": [
95 | {
96 | "type": "text",
97 | "text": "What’s in this image?",
98 | },
99 | {
100 | "type": "image_url",
101 | "image_url": {
102 | "url": img_url
103 | },
104 | },
105 | ],
106 | },
107 | {
108 | "role": "assistant",
109 | "content": "The image displays a wooden boardwalk extending through a vibrant green grassy wetland. The sky is partly cloudy with soft, wispy clouds, indicating nice weather. Vegetation is seen on either side of the boardwalk, and trees are present in the background, suggesting that this area might be a natural reserve or park designed for ecological preservation and outdoor recreation. The boardwalk allows visitors to explore the area without disturbing the natural habitat.",
110 | },
111 | {
112 | "role": "user",
113 | "content": "Do you think this is a spring or winter photo?"
114 | },
115 | ]
116 | create_chat_completion("cogvlm-chat-17b", messages=messages, use_stream=use_stream)
117 |
118 |
119 | if __name__ == "__main__":
120 | simple_image_chat(use_stream=False, img_path="demo.jpg")
121 |
--------------------------------------------------------------------------------
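
Since requirements.txt pins openai>=1.16.0 for the openai demo, the same request can also be sent through the official client instead of raw requests. The sketch below assumes the local server exposes the /v1/chat/completions route shown above; the api_key and the base64 payload are placeholders.

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")  # placeholder key
response = client.chat.completions.create(
    model="cogvlm-chat-17b",
    messages=[{
        "role": "user",
        "content": [
            {"type": "text", "text": "What's in this image?"},
            # Replace the ellipsis with a real base64-encoded image (see encode_image above).
            {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,..."}},
        ],
    }],
    max_tokens=2048,
    temperature=0.8,
    top_p=0.8,
)
print(response.choices[0].message.content)
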
/requirements.txt:
--------------------------------------------------------------------------------
1 | SwissArmyTransformer>=0.4.9
2 | transformers>=4.36.2
3 | xformers>=0.0.22
4 | torch>=2.1.0
5 | torchvision>=0.16.2
6 | spacy>=3.6.0
7 | pillow>=10.2.0
8 | deepspeed>=0.13.1
9 | seaborn>=0.13.2
10 | loguru~=0.7.2
11 | streamlit>=1.31.0
12 | timm>=0.9.12
13 | accelerate>=0.26.1
14 | pydantic>=2.6.0
15 |
16 | # for openai demo
17 | openai>=1.16.0
18 | sse-starlette>=1.8.2
19 | fastapi>=0.110.1
20 | httpx>=0.27.0
21 | uvicorn>=0.29.0
22 | jsonlines>=4.0.0
23 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THUDM/CogVLM/f7283b2c8d26cd7f932d9a5f7f5f9307f568195d/utils/__init__.py
--------------------------------------------------------------------------------
/utils/merge_model.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | import os, sys
3 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
4 |
5 | import torch
6 | import argparse
7 | from models.cogvlm_model import FineTuneTestCogVLMModel
8 | from sat.training.model_io import save_checkpoint
9 |
10 | def main():
11 | parser = argparse.ArgumentParser()
12 | parser.add_argument("--version", type=str, default="base", help='version to interact with')
13 | parser.add_argument("--from_pretrained", type=str, default="checkpoints/merged_lora", help='pretrained ckpt')
14 | parser.add_argument("--fp16", action="store_true")
15 | parser.add_argument("--bf16", action="store_true")
16 | args = parser.parse_args()
17 | rank = int(os.environ.get('RANK', 0))
18 | world_size = int(os.environ.get('WORLD_SIZE', 1))
19 | parser = FineTuneTestCogVLMModel.add_model_specific_args(parser)
20 | args = parser.parse_args()
21 |
22 | # load model
23 | model, model_args = FineTuneTestCogVLMModel.from_pretrained(
24 | args.from_pretrained,
25 | args=argparse.Namespace(
26 | deepspeed=None,
27 | local_rank=rank,
28 | rank=rank,
29 | world_size=world_size,
30 | model_parallel_size=world_size,
31 | mode='inference',
32 | skip_init=True,
33 | use_gpu_initialization=True if torch.cuda.is_available() else False,
34 | device='cuda',
35 | **vars(args)
36 | ), url='local', overwrite_args={'model_parallel_size': 1})
37 | model = model.eval()
38 | model_args.save = './checkpoints/merged_model_{}'.format(model_args.eva_args["image_size"][0])
39 | save_checkpoint(1, model, None, None, model_args)
40 |
41 | if __name__ == "__main__":
42 | main()
43 |
--------------------------------------------------------------------------------
/utils/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cogagent_model import CogAgentModel, FineTuneTrainCogAgentModel, FineTuneTestCogAgentModel
2 | from .cogvlm_model import CogVLMModel, FineTuneTrainCogVLMModel, FineTuneTestCogVLMModel
--------------------------------------------------------------------------------
/utils/models/cogagent_model.py:
--------------------------------------------------------------------------------
1 | from sat.model.official.llama_model import LLaMAModel
2 | import json
3 | import torch
4 | from functools import partial
5 | from sat.model.base_model import BaseMixin
6 | import torch.nn as nn
7 | import numpy as np
8 | from sat.resources.urls import MODEL_URLS
9 |
10 | from .eva_clip_L_hf import Eva2LargeEncoder
11 | from .mixin import LlamaVisionExpertFCMixin, LlamaVisionExpertAttnMixin
12 |
13 |
14 | MODEL_URLS["cogagent-chat"] = "r2://cogagent-chat.zip"
15 | MODEL_URLS["cogagent-vqa"] = "r2://cogagent-vqa.zip"
16 |
17 |
18 | class GLU(nn.Module):
19 | def __init__(self, args, in_features):
20 | super().__init__()
21 | self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
22 | self.norm1 = nn.LayerNorm(args.hidden_size)
23 | self.act1 = nn.GELU()
24 | self.act2 = nn.functional.silu
25 | self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
26 | self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
27 | self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False)
28 |
29 | def forward(self, x):
30 | x = self.linear_proj(x)
31 | x = self.act1(self.norm1(x))
32 | x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
33 | x = self.dense_4h_to_h(x)
34 | return x
35 |
36 | from .eva_clip_model import EVA2CLIPModel
37 | import argparse
38 | from copy import deepcopy
39 | def override_dist_dtype_device_args(args, b={}):
40 | if args.mode == 'inference':
41 | minimal_args = argparse.Namespace(
42 | world_size=args.world_size,
43 | rank=args.rank,
44 | local_rank=args.local_rank,
45 | skip_init=args.skip_init,
46 | use_gpu_initialization=args.use_gpu_initialization,
47 | deepspeed=args.deepspeed,
48 | bf16=args.bf16,
49 | fp16=args.fp16,
50 | mode=args.mode,
51 | device=args.device
52 | )
53 | else:
54 | minimal_args = argparse.Namespace(
55 | world_size=args.world_size,
56 | rank=args.rank,
57 | local_rank=args.local_rank,
58 | skip_init=args.skip_init,
59 | use_gpu_initialization=args.use_gpu_initialization,
60 | deepspeed=args.deepspeed,
61 | bf16=args.bf16,
62 | fp16=args.fp16,
63 | mode=args.mode,
64 | checkpoint_activations=args.checkpoint_activations if not hasattr(args, 'vit_checkpoint_activations') else args.vit_checkpoint_activations,
65 | checkpoint_num_layers=args.checkpoint_num_layers,
66 | device=args.device,
67 | hidden_dropout=0.,
68 | attention_dropout=0.,
69 | )
70 | if hasattr(args, 'model_parallel_size'):
71 | b['model_parallel_size'] = args.model_parallel_size
72 | return argparse.Namespace(**deepcopy(b), **vars(minimal_args))
73 |
74 |
75 | class ExternalVisionModel(BaseMixin):
76 | '''A combination of vit and a linear projection'''
77 | def __init__(self, args, vitclass):
78 | '''
79 | args: the args to initialize the vit model
80 | vitclass: the class of VIT model, must be a subclass of BaseModel
81 | project_dim: the dimension of the projection layer
82 | default_load: the default load path for the vit model
83 | model_parallel_size: the model parallel size for the vit model
84 | '''
85 | super().__init__()
86 | self.vit = vitclass()
87 | # self.ppx = nn.Embedding(80, 1024)
88 | # self.ppy = nn.Embedding(80, 1024)
89 | # nn.init.uniform_(self.ppx.weight.data)
90 | # nn.init.uniform_(self.ppy.weight.data)
91 |
92 | # self.pos_embed = nn.Parameter(
93 | # torch.from_numpy(get_2d_sincos_pos_embed(1024, 80)).float()
94 | # )
95 | cross_image_length = (args.cross_image_pix//14)**2
96 | self.pos_embed = nn.Parameter(
97 | torch.zeros(cross_image_length, 1024)
98 | )
99 |
100 | def forward(self, *args, **kw_args):
101 | enc = self.vit(*args, **kw_args)
102 | # i = torch.arange(80, device=enc.device)
103 | # j = torch.arange(80, device=enc.device)
104 | # posx = self.ppx(i).unsqueeze(0).repeat(80, 1, 1)
105 | # posy = self.ppy(j).unsqueeze(1).repeat(1, 80, 1)
106 | # pos = (posx + posy).view(-1, 1024).unsqueeze(0)
107 |
108 | # return enc + pos + self.pos_embed.unsqueeze(0)
109 | return enc + self.pos_embed.unsqueeze(0)
110 |
111 | class ImageMixin(BaseMixin):
112 | def __init__(self, args):
113 | super().__init__()
114 | vit_args = override_dist_dtype_device_args(args, args.eva_args)
115 | self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args)))
116 | self.in_features = 1792
117 | self.linear_proj = GLU(args, self.in_features)
118 | self.image_length = args.image_length
119 | self.boi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))
120 | self.eoi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))
121 |
122 | # self.ppx = nn.Embedding(16,1792)
123 | # self.ppy = nn.Embedding(16,1792)
124 |
125 | # self.pos_embed = nn.Parameter(
126 | # torch.from_numpy(get_2d_sincos_pos_embed(1792, 16)).float()
127 | # )
128 | self.pos_embed = nn.Parameter(
129 | torch.zeros(self.image_length, 1792)
130 | )
131 |
132 | def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
133 | vision_inputs = {}
134 | for k in kw_args:
135 | if k.startswith('vision_') and k != 'vision_expert_mask':
136 | vision_inputs[k[7:]] = kw_args[k]
137 | if input_ids.shape[1] == 1 or not vision_inputs:
138 | return self.transformer.word_embeddings(input_ids)
139 | image_emb = self.vit_model(**vision_inputs)[0]
140 |
141 | # i = torch.arange(16, device=image_emb.device)
142 | # j = torch.arange(16, device=image_emb.device)
143 | # posx = self.ppx(i).unsqueeze(0).repeat(16, 1, 1)
144 | # posy = self.ppy(j).unsqueeze(1).repeat(1, 16, 1)
145 | # pos = (posx + posy).view(256, -1).unsqueeze(0)
146 | # image_emb = image_emb + pos + self.pos_embed.unsqueeze(0)
147 | image_emb = image_emb + self.pos_embed.unsqueeze(0)
148 |
149 | image_emb = self.linear_proj(image_emb)
150 |
151 | image_embed_mask = kw_args['image_embed_mask']
152 | word_embedding = self.transformer.word_embeddings(input_ids).clone()
153 | word_embedding[image_embed_mask.bool()] = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1).reshape(-1, image_emb.shape[-1])
154 |
155 | return word_embedding.contiguous()
156 |
157 | class CogAgentModel(LLaMAModel):
158 | def __init__(self, args, transformer=None, **kwargs):
159 | super().__init__(args, transformer=transformer, **kwargs)
160 | self.image_length = args.image_length
161 | self.cross_image_pix = args.cross_image_pix
162 | self.add_mixin("eva", ImageMixin(args))
163 | self.del_mixin("mlp")
164 | self.add_mixin("mlp", LlamaVisionExpertFCMixin(args.hidden_size, args.inner_hidden_size, args.num_layers, 32))
165 | self.del_mixin("rotary")
166 | self.add_mixin("rotary", LlamaVisionExpertAttnMixin(args.hidden_size, args.num_attention_heads, args.num_layers, 32))
167 |
168 | cross_model = ExternalVisionModel(args, vitclass=partial(Eva2LargeEncoder, image_size=self.cross_image_pix))
169 | # if args.mode != 'inference':
170 | # cross_model.vit.model.set_grad_checkpointing(True)
171 | self.add_mixin("encoder", cross_model)
172 |
173 | @classmethod
174 | def add_model_specific_args(cls, parser):
175 | group = parser.add_argument_group('CogAgent', 'CogAgent Configurations')
176 | group.add_argument('--image_length', type=int, default=256)
177 |         group.add_argument('--cross_image_pix', type=int, default=1120) # Standard CogAgent uses 1120; if you want to adjust this parameter, finetune the model first.
178 | group.add_argument('--eva_args', type=json.loads, default={})
179 | return super().add_model_specific_args(parser)
180 |
181 | def forward(self, input_ids, vision_expert_mask, image_embed_mask, **kwargs):
182 |
183 | cross_inputs = {}
184 | for k in kwargs:
185 | if k.startswith('cross_'):
186 | cross_inputs[k[6:]] = kwargs[k]
187 | if kwargs.get("mems_cross") is not None:
188 | kwargs['encoder_outputs'] = kwargs["mems_cross"][0]
189 | else:
190 | outputs = self.get_mixin('encoder')(**cross_inputs)
191 | kwargs['encoder_outputs'] = outputs
192 | kwargs['cross_attention_mask'] = cross_inputs['attention_mask']
193 |
194 | if input_ids.shape[1] > 1:
195 | return super().forward(input_ids=input_ids, vision_expert_mask=vision_expert_mask, image_embed_mask=image_embed_mask, **kwargs)
196 | return super().forward(input_ids=input_ids, **kwargs)
197 |
198 |
199 | class FineTuneTrainCogAgentModel(CogAgentModel):
200 | def __init__(self, args, transformer=None, **kw_args):
201 | super().__init__(args, transformer=transformer, **kw_args)
202 | self.args = args
203 |         # If you want to use model parallelism with an mp_size=1 checkpoint and also want to use LoRA,
204 |         # you have to call add_mixin after loading the model checkpoint.
205 |
206 | @classmethod
207 | def add_model_specific_args(cls, parser):
208 | group = parser.add_argument_group('CogAgent-finetune', 'CogAgent finetune Configurations')
209 | group.add_argument('--pre_seq_len', type=int, default=8)
210 | group.add_argument('--lora_rank', type=int, default=10)
211 | group.add_argument('--use_ptuning', action="store_true")
212 | group.add_argument('--use_lora', action="store_true")
213 | group.add_argument('--use_qlora', action="store_true")
214 | group.add_argument('--layer_range', nargs='+', type=int, default=None)
215 | return super().add_model_specific_args(parser)
216 |
217 |
218 | from sat.model.finetune import PTuningV2Mixin
219 | from sat.model.finetune.lora2 import LoraMixin
220 | class FineTuneTestCogAgentModel(CogAgentModel):
221 | def __init__(self, args, transformer=None, **kw_args):
222 | super().__init__(args, transformer=transformer, **kw_args)
223 | if args.use_ptuning:
224 | self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
225 | if args.use_lora:
226 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
227 | self.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
228 | elif args.use_qlora:
229 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
230 | self.args = args
231 |
232 | @classmethod
233 | def add_model_specific_args(cls, parser):
234 | group = parser.add_argument_group('CogAgent-finetune', 'CogAgent finetune Configurations')
235 | group.add_argument('--pre_seq_len', type=int, default=8)
236 | group.add_argument('--lora_rank', type=int, default=10)
237 | group.add_argument('--use_ptuning', action="store_true")
238 | group.add_argument('--use_lora', action="store_true")
239 | group.add_argument('--use_qlora', action="store_true")
240 | group.add_argument('--layer_range', nargs='+', type=int, default=None)
241 | return super().add_model_specific_args(parser)
242 |
--------------------------------------------------------------------------------
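
The core trick in ImageMixin.word_embedding_forward above is a masked scatter: positions flagged by image_embed_mask are overwritten with the [boi, image tokens, eoi] embeddings. A standalone toy with made-up sizes (not from the repo):

import torch

hidden, n_img, seq = 8, 3, 9
word_embedding = torch.zeros(1, seq, hidden)      # stand-in for token embeddings
image_emb = torch.randn(1, n_img, hidden)         # stand-in for projected ViT features
boi = torch.ones(1, 1, hidden)
eoi = torch.ones(1, 1, hidden)

image_embed_mask = torch.zeros(1, seq)
image_embed_mask[:, 2:2 + n_img + 2] = 1          # boi + image tokens + eoi span

word_embedding[image_embed_mask.bool()] = torch.cat(
    [boi, image_emb, eoi], dim=1).reshape(-1, hidden)
print(word_embedding[0, :, 0])                    # only the masked span is filled
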
/utils/models/cogvlm_model.py:
--------------------------------------------------------------------------------
1 | from sat.model.official.llama_model import LLaMAModel
2 | import json
3 | import torch
4 | from sat.model.base_model import BaseMixin
5 | import torch.nn as nn
6 | from .mixin import LlamaVisionExpertFCMixin, LlamaVisionExpertAttnMixin
7 |
8 | from sat.resources.urls import MODEL_URLS
9 |
10 | MODEL_URLS["cogvlm-base-224"] = "r2://cogvlm-base-224.zip"
11 | MODEL_URLS["cogvlm-base-490"] = "r2://cogvlm-base-490.zip"
12 | MODEL_URLS["cogvlm-chat-v1.1"] = "r2://cogvlm-chat-v1.1.zip"
13 | MODEL_URLS["cogvlm-grounding-base"] = "r2://cogvlm-grounding-base.zip"
14 | MODEL_URLS["cogvlm-grounding-generalist-v1.1"] = "r2://cogvlm-grounding-generalist-v1.1.zip"
15 |
16 |
17 | class GLU(nn.Module):
18 | def __init__(self, args, in_features):
19 | super().__init__()
20 | self.linear_proj = nn.Linear(in_features, args.hidden_size, bias=False)
21 | self.norm1 = nn.LayerNorm(args.hidden_size)
22 | self.act1 = nn.GELU()
23 | self.act2 = nn.functional.silu
24 | self.dense_h_to_4h = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
25 | self.gate_proj = nn.Linear(args.hidden_size, args.inner_hidden_size, bias=False)
26 | self.dense_4h_to_h = nn.Linear(args.inner_hidden_size, args.hidden_size, bias=False)
27 |
28 | def forward(self, x):
29 | x = self.linear_proj(x)
30 | x = self.act1(self.norm1(x))
31 | x = self.act2(self.gate_proj(x)) * self.dense_h_to_4h(x)
32 | x = self.dense_4h_to_h(x)
33 | return x
34 |
35 | from .eva_clip_model import EVA2CLIPModel
36 | import argparse
37 | from copy import deepcopy
38 | def override_dist_dtype_device_args(args, b={}):
39 | if args.mode == 'inference':
40 | minimal_args = argparse.Namespace(
41 | world_size=args.world_size,
42 | rank=args.rank,
43 | local_rank=args.local_rank,
44 | skip_init=args.skip_init,
45 | use_gpu_initialization=args.use_gpu_initialization,
46 | deepspeed=args.deepspeed,
47 | bf16=args.bf16,
48 | fp16=args.fp16,
49 | mode=args.mode,
50 | device=args.device
51 | )
52 | else:
53 | minimal_args = argparse.Namespace(
54 | world_size=args.world_size,
55 | rank=args.rank,
56 | local_rank=args.local_rank,
57 | skip_init=args.skip_init,
58 | use_gpu_initialization=args.use_gpu_initialization,
59 | deepspeed=args.deepspeed,
60 | bf16=args.bf16,
61 | fp16=args.fp16,
62 | mode=args.mode,
63 | checkpoint_activations=args.checkpoint_activations if not hasattr(args, 'vit_checkpoint_activations') else args.vit_checkpoint_activations,
64 | checkpoint_num_layers=args.checkpoint_num_layers,
65 | device=args.device,
66 | hidden_dropout=0.,
67 | attention_dropout=0.,
68 | )
69 | if hasattr(args, 'model_parallel_size'):
70 | b['model_parallel_size'] = args.model_parallel_size
71 | return argparse.Namespace(**deepcopy(b), **vars(minimal_args))
72 |
73 | class ImageMixin(BaseMixin):
74 | def __init__(self, args):
75 | super().__init__()
76 | vit_args = override_dist_dtype_device_args(args, args.eva_args)
77 | self.vit_model = EVA2CLIPModel(EVA2CLIPModel.get_args(**vars(vit_args)))
78 | self.in_features = 1792
79 | self.linear_proj = GLU(args, self.in_features)
80 | self.image_length = args.image_length
81 | self.boi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))
82 | self.eoi = nn.Parameter(torch.zeros(1, 1, args.hidden_size))
83 |
84 | def word_embedding_forward(self, input_ids, output_cross_layer, **kw_args):
85 | vision_inputs = {}
86 | for k in kw_args:
87 | if k.startswith('vision_') and k != 'vision_expert_mask':
88 | vision_inputs[k[7:]] = kw_args[k]
89 | if input_ids.shape[1] == 1 or not vision_inputs:
90 | return self.transformer.word_embeddings(input_ids)
91 | image_emb = self.vit_model(**vision_inputs)[0]
92 | image_emb = self.linear_proj(image_emb)
93 |
94 | image_embed_mask = kw_args['image_embed_mask']
95 | word_embedding = self.transformer.word_embeddings(input_ids).clone()
96 | word_embedding[image_embed_mask.bool()] = torch.cat([self.boi.repeat(len(image_emb), 1, 1), image_emb, self.eoi.repeat(len(image_emb), 1, 1)], dim=1).reshape(-1, image_emb.shape[-1])
97 | return word_embedding.contiguous()
98 |
99 |
100 | class CogVLMModel(LLaMAModel):
101 | def __init__(self, args, transformer=None, **kwargs):
102 | super().__init__(args, transformer=transformer, **kwargs)
103 | self.image_length = args.image_length
104 | self.add_mixin("eva", ImageMixin(args))
105 | self.del_mixin("mlp")
106 | self.add_mixin("mlp", LlamaVisionExpertFCMixin(args.hidden_size, args.inner_hidden_size, args.num_layers, 32))
107 | self.del_mixin("rotary")
108 | self.add_mixin("rotary", LlamaVisionExpertAttnMixin(args.hidden_size, args.num_attention_heads, args.num_layers, 32))
109 |
110 | @classmethod
111 | def add_model_specific_args(cls, parser):
112 | group = parser.add_argument_group('CogVLM', 'CogVLM Configurations')
113 | group.add_argument('--image_length', type=int, default=256)
114 | group.add_argument('--eva_args', type=json.loads, default={})
115 | return super().add_model_specific_args(parser)
116 |
117 | def forward(self, input_ids, vision_expert_mask, image_embed_mask, **kwargs):
118 | if input_ids.shape[1] > 1:
119 | return super().forward(input_ids=input_ids, vision_expert_mask=vision_expert_mask, image_embed_mask=image_embed_mask, **kwargs)
120 | return super().forward(input_ids=input_ids, **kwargs)
121 |
122 |
123 | class FineTuneTrainCogVLMModel(CogVLMModel):
124 | def __init__(self, args, transformer=None, **kw_args):
125 | super().__init__(args, transformer=transformer, **kw_args)
126 | self.args = args
127 |         # If you want to use model parallelism with an mp_size=1 checkpoint and also want to use LoRA,
128 |         # you have to call add_mixin after loading the model checkpoint.
129 |
130 | @classmethod
131 | def add_model_specific_args(cls, parser):
132 | group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations')
133 | group.add_argument('--pre_seq_len', type=int, default=8)
134 | group.add_argument('--lora_rank', type=int, default=10)
135 | group.add_argument('--use_ptuning', action="store_true")
136 | group.add_argument('--use_lora', action="store_true")
137 | group.add_argument('--use_qlora', action="store_true")
138 | group.add_argument('--layer_range', nargs='+', type=int, default=None)
139 | return super().add_model_specific_args(parser)
140 |
141 |
142 | from sat.model.finetune import PTuningV2Mixin
143 | from sat.model.finetune.lora2 import LoraMixin
144 | class FineTuneTestCogVLMModel(CogVLMModel):
145 | def __init__(self, args, transformer=None, **kw_args):
146 | super().__init__(args, transformer=transformer, **kw_args)
147 | if args.use_ptuning:
148 | self.add_mixin("ptuning", PTuningV2Mixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.pre_seq_len))
149 | if args.use_lora:
150 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range), reinit=True)
151 | self.get_mixin("eva").vit_model.add_mixin("lora", LoraMixin(args.eva_args['num_layers'], args.lora_rank, layer_range=args.layer_range), reinit=True)
152 | elif args.use_qlora:
153 | self.add_mixin("lora", LoraMixin(args.num_layers, args.lora_rank, layer_range=args.layer_range, qlora=True), reinit=True)
154 | self.args = args
155 |
156 | @classmethod
157 | def add_model_specific_args(cls, parser):
158 | group = parser.add_argument_group('CogVLM-finetune', 'CogVLM finetune Configurations')
159 | group.add_argument('--pre_seq_len', type=int, default=8)
160 | group.add_argument('--lora_rank', type=int, default=10)
161 | group.add_argument('--use_ptuning', action="store_true")
162 | group.add_argument('--use_lora', action="store_true")
163 | group.add_argument('--use_qlora', action="store_true")
164 | group.add_argument('--layer_range', nargs='+', type=int, default=None)
165 | return super().add_model_specific_args(parser)
166 |
--------------------------------------------------------------------------------
/utils/models/eva_clip_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from sat.model.base_model import BaseModel
3 | from sat.model.mixins import BaseMixin
4 | from sat.model.official.vit_model import ViTProperty, ImagePatchEmbeddingMixin, InterpolatedPositionEmbeddingMixin, gelu
5 | from sat import mpu
6 |
7 | class IdentityMixin(BaseMixin):
8 | def __init__(self):
9 | super().__init__()
10 |
11 | def final_forward(self, logits, **kwargs):
12 | return logits[:, 1:]
13 |
14 | import xformers.ops as xops
15 | class XAttn(BaseMixin):
16 | def __init__(self, head_dim):
17 | super().__init__()
18 | self.scale = head_dim ** -0.5
19 |
20 | def attention_fn(self, query_layer, key_layer, value_layer, attention_mask,
21 | attention_dropout=None, log_attention_weights=None, scaling_attention_score=True, **kwargs):
22 | dropout_p = 0. # xformers does not support dropout for eva hidden size
23 |
24 | query_layer = query_layer.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
25 | key_layer = key_layer.permute(0, 2, 1, 3)
26 | value_layer = value_layer.permute(0, 2, 1, 3)
27 |
28 | out = xops.memory_efficient_attention(
29 | query_layer, key_layer, value_layer,
30 | p=dropout_p,
31 | scale=self.scale,
32 | )
33 | return out
34 |
35 | def attention_forward(self, hidden_states, mask, **kw_args):
36 | self = self.transformer.layers[kw_args['layer_id']].attention
37 | attention_fn = self.hooks['attention_fn']
38 |
39 | mixed_raw_layer = self.query_key_value(hidden_states)
40 |
41 | B, N, C = hidden_states.shape
42 | mixed_raw_layer = mixed_raw_layer.reshape(B, N, 3, self.num_attention_heads_per_partition, -1).permute(2, 0, 3, 1, 4) # 3, B, num_heads, N, C
43 | query_layer, key_layer, value_layer = mixed_raw_layer[0], mixed_raw_layer[1], mixed_raw_layer[2]
44 |
45 | dropout_fn = self.attention_dropout if self.training else None
46 |
47 | context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
48 |
49 | context_layer = context_layer.view(B, N, -1)
50 | output = self.dense(context_layer)
51 |
52 | if self.training:
53 | output = self.output_dropout(output)
54 | return output
55 |
56 | class NewLayerForward(BaseMixin):
57 | def __init__(self):
58 | super().__init__()
59 |
60 | def layer_forward(self, hidden_states, mask, *args, **kw_args):
61 | '''
62 | hidden_states: [batch, seq_len, hidden_size]
63 | mask: [(1, 1), seq_len, seq_len]
64 | '''
65 | self = self.transformer.layers[kw_args['layer_id']]
66 |
67 | attention_input = hidden_states
68 |
69 | # Self attention.
70 | attention_output = self.input_layernorm(self.attention(attention_input, mask, **kw_args))
71 |
72 | # DropPath for attention
73 | if self.training and self.drop_path > 0.:
74 | if mpu.get_cuda_rng_tracker is not None:
75 | # drop_path must use model parallel rng tracker
76 |                 # the tracker is initialized with the seed `seed + model_parallel_rank`
77 |                 # deepspeed activation checkpointing records the model-parallel tracker states
78 | with mpu.get_cuda_rng_tracker().fork():
79 |                     # dropped samples become 0; surviving samples are scaled by 1/(1-p)
80 | random_tensor = (1-self.drop_path
81 | + torch.rand((attention_output.shape[0],), dtype=attention_output.dtype, device=attention_output.device)).floor_() / (1-self.drop_path)
82 | attention_output = random_tensor.view(-1, 1, 1) * attention_output
83 |
84 | # Residual connection.
85 | hidden_states = attention_input + attention_output
86 | mlp_input = hidden_states
87 |
88 | # MLP.
89 | mlp_output = self.post_attention_layernorm(self.mlp(mlp_input, **kw_args))
90 |
91 | # DropPath for mlp
92 | if self.training and self.drop_path > 0.:
93 | if mpu.get_cuda_rng_tracker is not None:
94 | with mpu.get_cuda_rng_tracker().fork():
95 | random_tensor = (1-self.drop_path
96 | + torch.rand((mlp_output.shape[0],), dtype=mlp_output.dtype, device=mlp_output.device)).floor_() / (1-self.drop_path)
97 | mlp_output = random_tensor.view(-1, 1, 1) * mlp_output
98 |
99 | # Second residual connection.
100 | output = mlp_input + mlp_output
101 |
102 | return output
103 |
104 | class EVA2CLIPModel(BaseModel):
105 | def __init__(self, args, transformer=None, **kwargs):
106 | property = ViTProperty(args.image_size, args.patch_size, args.pre_len, args.post_len)
107 | args.max_sequence_length = property.pre_len + property.num_patches + property.post_len
108 | if 'activation_func' not in kwargs:
109 | kwargs['activation_func'] = gelu
110 | super().__init__(args, transformer=transformer, **kwargs)
111 | self.transformer.property = property
112 | self.add_mixin("patch_embedding", ImagePatchEmbeddingMixin(args.in_channels, args.hidden_size, property))
113 | self.add_mixin("pos_embedding", InterpolatedPositionEmbeddingMixin())
114 | self.add_mixin("final", IdentityMixin())
115 | self.add_mixin("newpost", NewLayerForward())
116 | self.add_mixin("xattn", XAttn(args.hidden_size // args.num_attention_heads))
117 |
118 | @classmethod
119 | def add_model_specific_args(cls, parser):
120 | group = parser.add_argument_group('EVA2CLIP', 'EVA2CLIP Configurations')
121 | group.add_argument('--image-size', nargs='+', type=int, default=[224, 224])
122 | group.add_argument('--pre-len', type=int, default=1) # [cls] by default
123 | group.add_argument('--post-len', type=int, default=0) # empty by default, but sometimes with special tokens, such as [det] in yolos.
124 | group.add_argument('--in-channels', type=int, default=3)
125 | group.add_argument('--patch-size', type=int, default=16)
126 | return parser
127 |
128 |
--------------------------------------------------------------------------------
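NewLayerForward above switches the residual blocks to a post-layernorm layout and adds stochastic depth (DropPath): during training, each sample's attention or MLP branch is dropped with probability drop_path and the survivors are rescaled by 1/(1-p). A torch-only sketch of that keep-and-rescale trick, with made-up shapes and without the model-parallel RNG tracker:

import torch

# DropPath keep/rescale trick as used in NewLayerForward (sketch only).
drop_path = 0.2
branch_output = torch.randn(4, 10, 8)   # [batch, seq_len, hidden]

# One Bernoulli(1 - drop_path) draw per sample, realized as floor(1 - p + U[0, 1)).
random_tensor = (1 - drop_path + torch.rand(branch_output.shape[0])).floor_() / (1 - drop_path)
branch_output = random_tensor.view(-1, 1, 1) * branch_output  # dropped rows zeroed, kept rows scaled by 1/(1-p)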
/utils/models/mixin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from sat.transformer_defaults import attention_fn_default
5 | from sat.model.base_model import BaseMixin, non_conflict
6 | from sat.mpu.layers import ColumnParallelLinear, RowParallelLinear
7 | from sat.mpu.utils import split_tensor_along_last_dim
8 | from sat import mpu
9 |
10 |
11 | class LlamaVisionExpertFCMixin(BaseMixin):
12 | def __init__(self, in_features, hidden_features, num_layers=32, num_vision_layers=0, vision_layer_range=None,
13 | params_dtype=torch.float, device=torch.device('cpu')):
14 | super().__init__()
15 |
16 | self.num_layers = num_layers
17 | self.num_vision_layers = num_vision_layers
18 | if vision_layer_range is None:
19 | vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))]
20 | self.vision_layer_range = vision_layer_range
21 | self.gate_proj = nn.ModuleList([ColumnParallelLinear(
22 | in_features,
23 | hidden_features,
24 | gather_output=False,
25 | init_method=None,
26 | bias=False,
27 | params_dtype=params_dtype,
28 | module=self,
29 | name="dense_h_to_4h_gate",
30 | skip_init=True,
31 | device=device
32 | ) for i in range(num_layers)])
33 | # Trainable vision expert parameters
34 | vision_dense_h_to_4h_list = []
35 | vision_dense_4h_to_h_list = []
36 | gate_proj_list = []
37 |
38 |
39 | for i in vision_layer_range:
40 | vision_dense_h_to_4h = ColumnParallelLinear(
41 | in_features,
42 | hidden_features,
43 | gather_output=False,
44 | init_method=None,
45 | bias=False,
46 | params_dtype=params_dtype,
47 | module=self,
48 | name="vision_dense_h_to_4h",
49 | skip_init=True,
50 | device=device
51 | )
52 |
53 | # Project back to h.
54 | vision_dense_4h_to_h = RowParallelLinear(
55 | hidden_features,
56 | in_features,
57 | input_is_parallel=True,
58 | init_method=None,
59 | bias=False,
60 | params_dtype=params_dtype,
61 | module=self,
62 | name="vision_dense_4h_to_h",
63 | skip_init=True,
64 | device=device
65 | )
66 |
67 | gate_proj = ColumnParallelLinear(
68 | in_features,
69 | hidden_features,
70 | gather_output=False,
71 | init_method=None,
72 | bias=False,
73 | params_dtype=params_dtype,
74 | module=self,
75 | name="vision_gate_proj",
76 | skip_init=True,
77 | device=device
78 | )
79 |
80 | vision_dense_h_to_4h_list.append(vision_dense_h_to_4h)
81 | vision_dense_4h_to_h_list.append(vision_dense_4h_to_h)
82 | gate_proj_list.append(gate_proj)
83 |
84 | self.vision_dense_h_to_4h_list = nn.ModuleDict([
85 | (str(layer_id), vision_dense_h_to_4h)
86 | for layer_id, vision_dense_h_to_4h in zip(vision_layer_range, vision_dense_h_to_4h_list)
87 | ])
88 | self.vision_dense_4h_to_h_list = nn.ModuleDict([
89 | (str(layer_id), vision_dense_4h_to_h)
90 | for layer_id, vision_dense_4h_to_h in zip(vision_layer_range, vision_dense_4h_to_h_list)
91 | ])
92 | self.vision_gate_proj = nn.ModuleDict([
93 | (str(layer_id), gate_proj)
94 | for layer_id, gate_proj in zip(vision_layer_range, gate_proj_list)
95 | ])
96 |
97 | def mlp_forward(self, hidden_states, **kw_args):
98 | mixin_self = self
99 | self = self.transformer.layers[kw_args['layer_id']].mlp
100 | if "vision_expert_mask" in kw_args:
101 | vision_expert_mask = kw_args['vision_expert_mask']
102 | else:
103 | vision_expert_mask = None
104 |
105 | layer_id_key = str(int(kw_args['layer_id']))
106 |
107 | if kw_args['layer_id'] in mixin_self.vision_layer_range and (vision_expert_mask is not None) and vision_expert_mask.any():
108 | vision_dense_h_to_4h = mixin_self.vision_dense_h_to_4h_list[layer_id_key]
109 | vision_dense_4h_to_h = mixin_self.vision_dense_4h_to_h_list[layer_id_key]
110 | vision_gate_proj = mixin_self.vision_gate_proj[layer_id_key]
111 | output = torch.empty(hidden_states.shape, dtype=hidden_states.dtype, device=hidden_states.device)
112 |
113 | language_hidden_state = hidden_states[~vision_expert_mask.bool()]
114 | language_intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](language_hidden_state)) * self.dense_h_to_4h(language_hidden_state)
115 | output[~vision_expert_mask.bool()] = self.dense_4h_to_h(language_intermediate_parallel) # language_output
116 |
117 | vision_hidden_state = hidden_states[vision_expert_mask.bool()]
118 | vision_intermediate_parallel = vision_dense_h_to_4h(vision_hidden_state)
119 | gate_output = vision_gate_proj(vision_hidden_state)
120 |
121 | vision_intermediate_parallel *= self.activation_func(gate_output)
122 | output[vision_expert_mask.bool()] = vision_dense_4h_to_h(vision_intermediate_parallel) # vision_output
123 | else:
124 | intermediate_parallel = self.activation_func(mixin_self.gate_proj[kw_args['layer_id']](hidden_states)) * self.dense_h_to_4h(hidden_states)
125 | output = self.dense_4h_to_h(intermediate_parallel)
126 |
127 | return output.contiguous()
128 |
129 | def copy_param(self):
130 | with torch.no_grad():
131 | for i in self.vision_layer_range:
132 | self.vision_gate_proj[str(i)].weight.data.copy_(self.gate_proj[i].weight.data)
133 | self.vision_dense_4h_to_h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_4h_to_h.weight.data)
134 | self.vision_dense_h_to_4h_list[str(i)].weight.data.copy_(self.transformer.layers[i].mlp.dense_h_to_4h.weight.data)
135 |
136 | from sat.mpu import get_model_parallel_world_size
137 | from sat.mpu.utils import divide
138 | from sat.model.position_embedding.triton_rotary_embeddings import FastRotaryEmbedding
139 |
140 | class LlamaVisionExpertAttnMixin(BaseMixin):
141 | def __init__(self, hidden_size, num_heads, num_layers=28, num_vision_layers=0, use_vision_expert=True, vision_layer_range=None,
142 | params_dtype=torch.float, device=torch.device('cpu')):
143 | super().__init__()
144 |
145 | world_size = get_model_parallel_world_size()
146 | self.hidden_size = hidden_size
147 | self.num_attention_heads = num_heads
148 | self.hidden_size_per_attention_head = divide(hidden_size, num_heads)
149 | self.num_attention_heads_per_partition = divide(num_heads, world_size)
150 | self.inner_hidden_size = num_heads * self.hidden_size_per_attention_head
151 |
152 | self.rotary_emb = FastRotaryEmbedding(
153 | hidden_size // num_heads, pos_idx_in_fp32=False
154 | )
155 |
156 | self.num_vision_layers = num_vision_layers
157 | self.num_layers = num_layers
158 | if vision_layer_range is None:
159 | vision_layer_range = [i for i in range(min(num_vision_layers, num_layers))]
160 | self.vision_layer_range = vision_layer_range
161 |
162 | self.use_vision_expert = use_vision_expert
163 | # Trainable vision expert parameters
164 |
165 | if self.use_vision_expert:
166 | vision_query_key_value_list = []
167 | vision_dense_list = []
168 | for i in vision_layer_range:
169 | vision_query_key_value = ColumnParallelLinear(
170 | hidden_size,
171 | 3 * hidden_size,
172 | stride=3,
173 | gather_output=False,
174 | init_method=None,
175 | bias=False,
176 | params_dtype=params_dtype,
177 | module=self,
178 | name="vision_query_key_value",
179 | skip_init=True,
180 | device=device
181 | )
182 |
183 | vision_dense = RowParallelLinear(
184 | self.inner_hidden_size,
185 | hidden_size,
186 | input_is_parallel=True,
187 | init_method=None,
188 | bias=False,
189 | params_dtype=params_dtype,
190 | module=self,
191 | name="vision_dense",
192 | skip_init=True,
193 | device=device,
194 | final_bias=False
195 | )
196 |
197 | vision_query_key_value_list.append(vision_query_key_value)
198 | vision_dense_list.append(vision_dense)
199 |
200 | self.vision_query_key_value_list = nn.ModuleDict([
201 | (str(layer_id), vision_query_key_value)
202 | for layer_id, vision_query_key_value in zip(vision_layer_range, vision_query_key_value_list)
203 | ])
204 | self.vision_dense_list = nn.ModuleDict([
205 | (str(layer_id), vision_dense)
206 | for layer_id, vision_dense in zip(vision_layer_range, vision_dense_list)
207 | ])
208 |
209 | def attention_forward(self, hidden_states, mask, **kw_args):
210 | mixin_self = self
211 | self = self.transformer.layers[kw_args['layer_id']].attention
212 | attention_fn = attention_fn_default
213 | if 'attention_fn' in self.hooks:
214 | attention_fn = self.hooks['attention_fn']
215 | if "vision_expert_mask" in kw_args:
216 | vision_expert_mask = kw_args['vision_expert_mask']
217 | else:
218 | vision_expert_mask = None
219 |
220 | layer_id_key = str(int(kw_args['layer_id']))
221 | if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and (
222 | vision_expert_mask is not None) and vision_expert_mask.any():
223 | shape = list(hidden_states.shape)
224 | parallel_size = mpu.get_model_parallel_world_size()
225 | shape[-1] = shape[-1] * 3 // parallel_size
226 | vision_query_key_value = mixin_self.vision_query_key_value_list[layer_id_key]
227 | mixed_raw_layer = torch.empty(shape, dtype=hidden_states.dtype, device=hidden_states.device)
228 | language_hidden_states = hidden_states[~vision_expert_mask.bool()]
229 | vision_hidden_states = hidden_states[vision_expert_mask.bool()]
230 | mixed_raw_layer[~vision_expert_mask.bool()] = self.query_key_value(
231 | language_hidden_states) # language_mixed_raw_layer
232 | mixed_raw_layer[vision_expert_mask.bool()] = vision_query_key_value(
233 | vision_hidden_states) # vision_mixed_raw_layer
234 | else:
235 | mixed_raw_layer = self.query_key_value(hidden_states)
236 |
237 | (mixed_query_layer,
238 | mixed_key_layer,
239 | mixed_value_layer) = split_tensor_along_last_dim(mixed_raw_layer, 3)
240 |
241 | dropout_fn = self.attention_dropout if self.training else None
242 |
243 | query_layer = self._transpose_for_scores(mixed_query_layer)
244 | key_layer = self._transpose_for_scores(mixed_key_layer)
245 | value_layer = self._transpose_for_scores(mixed_value_layer)
246 |
247 | query_layer, key_layer = mixin_self.rotary_emb(query_layer,key_layer, kw_args['position_ids'], max_seqlen=kw_args['position_ids'].max()+1, layer_id=kw_args['layer_id'])
248 |
249 | context_layer = attention_fn(query_layer, key_layer, value_layer, mask, dropout_fn, **kw_args)
250 |
251 | context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
252 | new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
253 | context_layer = context_layer.view(*new_context_layer_shape)
254 |
255 | if mixin_self.use_vision_expert and kw_args['layer_id'] in mixin_self.vision_layer_range and (
256 | vision_expert_mask is not None) and vision_expert_mask.any():
257 | vision_dense = mixin_self.vision_dense_list[layer_id_key]
258 | parallel_size = mpu.get_model_parallel_world_size()
259 | target_shape = context_layer.shape[:-1] + (context_layer.shape[-1] * parallel_size,)
260 | output = torch.empty(target_shape, dtype=hidden_states.dtype, device=hidden_states.device)
261 | output[~vision_expert_mask.bool()] = self.dense(context_layer[~vision_expert_mask.bool()]) # language
262 | output[vision_expert_mask.bool()] = vision_dense(context_layer[vision_expert_mask.bool()]) # vision
263 | else:
264 | output = self.dense(context_layer)
265 |
266 | if self.training:
267 | output = self.output_dropout(output)
268 | return output.contiguous()
269 |
270 | def copy_param(self):
271 | with torch.no_grad():
272 | for i in self.vision_layer_range:
273 | self.vision_query_key_value_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.query_key_value.weight.data)
274 | self.vision_dense_list[str(i)].weight.data.copy_(self.transformer.layers[i].attention.dense.weight.data)
--------------------------------------------------------------------------------
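Both mixins above follow the same routing pattern: tokens flagged by vision_expert_mask go through a separate "vision expert" copy of the projection weights, the remaining tokens go through the original language weights, and the two results are scattered back into one output tensor. A self-contained sketch of that pattern with toy nn.Linear layers (the real code uses SAT's ColumnParallelLinear / RowParallelLinear held in per-layer ModuleDicts):

import torch
import torch.nn as nn

hidden = 16
language_proj = nn.Linear(hidden, hidden, bias=False)  # stands in for the language dense_h_to_4h / query_key_value
vision_proj = nn.Linear(hidden, hidden, bias=False)    # stands in for the vision-expert copy

hidden_states = torch.randn(2, 6, hidden)              # [batch, seq_len, hidden]
vision_expert_mask = torch.zeros(2, 6, dtype=torch.bool)
vision_expert_mask[:, 1:5] = True                      # pretend positions 1..4 hold image tokens

output = torch.empty_like(hidden_states)
output[~vision_expert_mask] = language_proj(hidden_states[~vision_expert_mask])  # text tokens
output[vision_expert_mask] = vision_proj(hidden_states[vision_expert_mask])      # image tokens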
/utils/split_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import shutil
3 |
4 | def find_all_files(path, suffix=".jpg"):
5 | target_files = []
6 | for cur_dir, _, files in os.walk(path, followlinks=True):
7 | for f in files:
8 | if f.endswith(suffix):
9 | target_files.append(os.path.join(cur_dir, f))
10 |     print(f'found {len(target_files)} files...')
11 | return target_files
12 |
13 | all_files = find_all_files('archive')
14 | os.makedirs("archive_split", exist_ok=True)
15 | os.makedirs("archive_split/train", exist_ok=True)
16 | os.makedirs("archive_split/valid", exist_ok=True)
17 | os.makedirs("archive_split/test", exist_ok=True)
18 |
19 | import random
20 | random.seed(2023)
21 | random.shuffle(all_files)
22 | train = all_files[:8000]
23 | valid = all_files[8000:8000+500]
24 | test = all_files[8000+500:8000+500+1500]
25 |
26 | print("building train")
27 | for file in train:
28 | shutil.move(file, os.path.join("archive_split/train", file.split("/")[-1]))
29 | print("building valid")
30 | for file in valid:
31 | shutil.move(file, os.path.join("archive_split/valid", file.split("/")[-1]))
32 | print("building test")
33 | for file in test:
34 | shutil.move(file, os.path.join("archive_split/test", file.split("/")[-1]))
35 | print("done")
--------------------------------------------------------------------------------
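split_dataset.py assumes an archive/ directory holding at least 10,000 .jpg files and moves them into an 8000/500/1500 train/valid/test split. If you would rather keep the originals in place, a copy-based variant with the same split sizes (a sketch, not part of the repo) could look like:

import os
import random
import shutil

# Hypothetical copy-based variant: same 8000/500/1500 split sizes, but the
# source files under archive/ are left untouched.
all_files = sorted(
    os.path.join(cur_dir, f)
    for cur_dir, _, files in os.walk('archive', followlinks=True)
    for f in files if f.endswith('.jpg')
)
random.seed(2023)
random.shuffle(all_files)
splits = {'train': all_files[:8000], 'valid': all_files[8000:8500], 'test': all_files[8500:10000]}
for name, files in splits.items():
    os.makedirs(os.path.join('archive_split', name), exist_ok=True)
    for file in files:
        shutil.copy(file, os.path.join('archive_split', name, os.path.basename(file)))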
/utils/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .chat import chat
2 | from .language import llama2_tokenizer, llama2_text_processor, llama2_text_processor_inference
3 | from .vision import get_image_processor
4 | from .grounding_parser import parse_response
5 | from .dataset import ItemDataset
--------------------------------------------------------------------------------
/utils/utils/chat.py:
--------------------------------------------------------------------------------
1 | # -*- encoding: utf-8 -*-
2 | '''
3 | @File : chat.py
4 | @Time : 2023/05/08 19:10:08
5 | @Author : Ming Ding
6 | @Contact : dm18@mails.tsinghua.edu.cn
7 | '''
8 |
9 | from typing import Optional, Tuple, Union, List, Callable, Dict, Any
10 | import requests
11 | from PIL import Image
12 | from io import BytesIO
13 |
14 | import torch
15 | from sat.generation.autoregressive_sampling import filling_sequence, stream_filling_sequence, get_masks_and_position_ids_default
16 | from sat.generation.sampling_strategies import BaseStrategy, BeamSearchStrategy
17 | from sat.mpu import get_model_parallel_rank
18 |
19 | def process_image(image_path, img_processor, cross_img_processor, image):
20 | if image is None:
21 | if image_path.startswith("http"):
22 | response = requests.get(image_path, timeout=10)
23 | image = Image.open(BytesIO(response.content))
24 | else:
25 | image = Image.open(image_path)
26 |
27 | if image is not None and isinstance(image, Image.Image):
28 | pil_img = image.convert('RGB')
29 | img_dict = img_processor(pil_img)
30 | cross_img_dict = cross_img_processor(pil_img) if cross_img_processor is not None else {}
31 | ret = (img_dict, pil_img, cross_img_dict)
32 | else:
33 | ret = image
34 | return ret
35 |
36 | def chat(image_path, model, text_processor, img_processor,
37 | query: str, history: List[Tuple[str, str]] = None, cross_img_processor=None, image: Image = None,
38 | max_length: int = 4096, top_p=0.95, top_k=5, temperature=0.95, repetition_penalty=1.0,
39 | invalid_slices=[], no_prompt=False, args=None
40 | ):
41 | if image is None:
42 | assert image_path is not None
43 | if not history:
44 | history = []
45 |
46 | if no_prompt:
47 | query = ''
48 | prompt = text_processor.history_to_prompt(query, history)
49 |
50 | (torch_image, pil_img, cross_image) = process_image(image_path, img_processor, cross_img_processor, image)
51 |
52 | if torch_image is not None:
53 | for k in torch_image:
54 | if type(torch_image[k]) is torch.Tensor and torch_image[k].dtype is not torch.int and torch_image[k].dtype is not torch.long:
55 | torch_image[k] = torch_image[k].to(torch.bfloat16 if args.bf16 else torch.float16)
56 | if type(torch_image[k]) is torch.Tensor:
57 | torch_image[k] = torch_image[k].to(next(model.parameters()).device)
58 |
59 | if cross_image is not None:
60 | for k in cross_image:
61 | if type(cross_image[k]) is torch.Tensor and cross_image[k].dtype is not torch.int and cross_image[k].dtype is not torch.long:
62 | cross_image[k] = cross_image[k].to(torch.bfloat16 if args.bf16 else torch.float16)
63 | if type(cross_image[k]) is torch.Tensor:
64 | cross_image[k] = cross_image[k].to(next(model.parameters()).device)
65 |
66 | inputs_dic = text_processor(prompt)
67 | for k in inputs_dic:
68 | if type(inputs_dic[k]) is torch.Tensor and inputs_dic[k].dtype is not torch.int and inputs_dic[k].dtype is not torch.long:
69 | inputs_dic[k] = inputs_dic[k].to(torch.bfloat16 if args.bf16 else torch.float16)
70 | if type(inputs_dic[k]) is torch.Tensor:
71 | inputs_dic[k] = inputs_dic[k].to(next(model.parameters()).device)
72 | input_ids = inputs_dic['input_ids'].to(model.parameters().__next__().device)[0]
73 |
74 | if max_length-len(input_ids) <= 1:
75 | response = "The prompt exceeds the context length limit, please try again."
76 | return response, history, (torch_image, pil_img)
77 |
78 | seq = torch.cat(
79 | [input_ids, torch.tensor([-1]*(max_length-len(input_ids)), device=input_ids.device)], dim=0
80 | )
81 | strategy = BaseStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id],
82 | invalid_slices=invalid_slices, repetition_penalty=repetition_penalty)
83 | # use beam search to get a better result
84 | # strategy = BeamSearchStrategy(temperature=temperature, top_p=top_p, top_k=top_k, end_tokens=[text_processor.tokenizer.eos_token_id],
85 | # num_beams=5, consider_end=True, repetition_penalty=repetition_penalty)
86 | get_func = text_processor.get_func(input_ids, **inputs_dic) if hasattr(text_processor, 'get_func') else get_masks_and_position_ids_default
87 |
88 | img_inputs = {'vision_'+k: v for k, v in torch_image.items()}
89 | if cross_image is not None:
90 | img_inputs = {**img_inputs, **{'cross_'+k:v for k,v in cross_image.items()}}
91 | inputs_dic.pop('input_ids')
92 | inputs = {**img_inputs, **inputs_dic}
93 |
94 | if args.stream_chat:
95 | filling_stream = stream_filling_sequence(
96 | model, seq,
97 | batch_size=1,
98 | get_masks_and_position_ids=get_func,
99 | strategy=strategy,
100 | **inputs
101 | )
102 | if get_model_parallel_rank() == 0:
103 | if 'chinese' in args and not args.chinese:
104 | print("Model: ", end='')
105 | else:
106 | print("模型:", end='')
107 | offset = len(text_processor.tokenizer.decode(input_ids))
108 | for tokens, mems in filling_stream:
109 | torch.cuda.empty_cache()
110 | tmp_response = text_processor.tokenizer.decode(tokens[0])
111 | if tmp_response[-1] != "�":
112 | if get_model_parallel_rank() == 0:
113 | tmp_response_offseted = tmp_response[offset:]
114 | if hasattr(text_processor, 'process_response'):
115 | tmp_response_offseted = text_processor.process_response(tmp_response_offseted)
116 | print(tmp_response_offseted, end='', flush=True)
117 | offset = len(tmp_response)
118 | if get_model_parallel_rank() == 0:
119 | print()
120 | output = strategy.finalize(tokens, mems)[0]
121 |
122 | response = text_processor.tokenizer.decode(output[0])
123 | else:
124 | output = filling_sequence(
125 | model, seq,
126 | batch_size=1,
127 | get_masks_and_position_ids=get_func,
128 | strategy=strategy,
129 | **inputs
130 | )[0] # drop memory
131 |
132 | # ---------------
133 |         # ported from inference_glm.py; more general than chat mode
134 |         # clip the -1 padding and fill the generated tokens back into seq
135 | if type(output) is not list:
136 | output_list = output.tolist()
137 | else:
138 | output_list = output
139 |
140 | response = text_processor.tokenizer.decode(output_list[0])
141 | # print('original:', response)
142 | if hasattr(text_processor, 'process_response'):
143 | response = text_processor.process_response(response)
144 | response = response.split(text_processor.sep)[-1].strip()
145 | if get_model_parallel_rank() == 0:
146 | from utils.utils.grounding_parser import parse_response
147 | parse_response(pil_img, response)
148 | history = history + [(query, response)]
149 | return response, history, (torch_image, pil_img, cross_image)
150 |
--------------------------------------------------------------------------------
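A hedged usage sketch of chat(): it assumes model, text_processor and image_processor have already been built the way basic_demo/cli_demo_sat.py builds them, and that args carries the attributes the function reads (bf16, stream_chat, chinese); none of that setup is shown here.

# Hypothetical call; model / processors / args come from a cli_demo_sat.py-style setup.
response, history, _ = chat(
    "assets/pear_grounding.png",   # local path or http(s) URL
    model,
    text_processor,
    image_processor,
    query="Where is the pear?",
    history=[],
    max_length=2048,
    top_p=0.4,
    top_k=1,
    temperature=0.8,
    args=args,
)
print(response)  # any grounded boxes are also drawn to output.png by parse_response on rank 0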
/utils/utils/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import random
4 | import logging
5 | import jsonlines
6 | from io import BytesIO
7 | from PIL import Image
8 | from torch.utils.data import Dataset
9 | from sat.helpers import print_rank0
10 |
11 | def find_all_files(path, suffix=".jpg"):
12 | target_files = []
13 | for cur_dir, _, files in os.walk(path, followlinks=True):
14 | for f in files:
15 | if f.endswith(suffix):
16 | target_files.append(os.path.join(cur_dir, f))
17 |     print_rank0(f'found {len(target_files)} files...')
18 | return target_files
19 |
20 | class ItemDataset(Dataset):
21 | def __init__(self, image_processor, text_processor, args, data_dirs, cross_image_processor=None, **kwargs):
22 | super().__init__()
23 | self.data = self.load_data(data_dirs)
24 | self.image_processor, self.text_processor, self.cross_image_processor = image_processor, text_processor, cross_image_processor
25 |
26 | def process_img(self, img):
27 | img_dict = {'vision': self.image_processor(img)}
28 | if self.cross_image_processor:
29 | img_dict.update({'cross': self.cross_image_processor(img)})
30 | return img_dict
31 |
32 | def process_text(self, answer, prompt):
33 | return self.text_processor(answer, prompt)
34 |
35 | def load_data(self, data_dir):
36 | all_files = find_all_files(data_dir, suffix=".jpg")
37 |         print_rank0(f"found {len(all_files)} samples in total...")
38 | return all_files
39 |
40 | def __len__(self):
41 | return len(self.data)
42 |
43 | def __getitem__(self, index):
44 | data = self.data[index]
45 | # img
46 | try:
47 | img = Image.open(data).convert('RGB')
48 | except Exception as e:
49 | print_rank0(e, level=logging.WARNING)
50 | return {}
51 | img_dict = self.process_img(img)
52 | # text
53 | label = data.split('/')[-1].split('.')[0]
54 | uni_key = label
55 | text_dict = self.process_text(label, "CAPTCHA:")
56 | if text_dict is None:
57 | print_rank0(f"Process text failed. Please check the max_target_length & max_source_length.\n The data is {data}", level=logging.WARNING)
58 | return {}
59 | # other attr
60 | ret = {**img_dict, **text_dict, "question_id": uni_key}
61 | return ret
--------------------------------------------------------------------------------
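ItemDataset expects a directory tree of .jpg files whose file names (minus the extension) are the target labels; every sample is paired with the fixed prompt "CAPTCHA:". A tiny sketch of the label extraction done in __getitem__, using a made-up path:

# How __getitem__ turns a file path into a training label (hypothetical path).
data = "archive_split/train/3xk7q.jpg"
label = data.split('/')[-1].split('.')[0]
print(label)  # 3xk7q -> passed to text_processor(label, "CAPTCHA:")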
/utils/utils/grounding_parser.py:
--------------------------------------------------------------------------------
1 | import seaborn as sns
2 | from PIL import Image, ImageDraw, ImageFont
3 | import matplotlib.font_manager
4 | import spacy
5 | import re
6 |
7 | nlp = spacy.load("en_core_web_sm")
8 |
9 | def draw_boxes(image, boxes, texts, output_fn='output.png'):
10 | box_width = 5
11 | color_palette = sns.color_palette("husl", len(boxes))
12 | colors = [(int(r*255), int(g*255), int(b*255)) for r, g, b in color_palette]
13 |
14 | width, height = image.size
15 | absolute_boxes = [[(int(box[0] * width), int(box[1] * height), int(box[2] * width), int(box[3] * height)) for box in b] for b in boxes]
16 |
17 | overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
18 | draw = ImageDraw.Draw(overlay)
19 | font_path = sorted(matplotlib.font_manager.findSystemFonts(fontpaths=None, fontext='ttf'))[0]
20 | font = ImageFont.truetype(font_path, size=26)
21 |
22 | for box, text, color in zip(absolute_boxes, texts, colors):
23 | for b in box:
24 | draw.rectangle(b, outline=color, width=box_width)
25 | if not text:
26 | continue
27 | splited_text = text.split('\n')
28 | num_lines = len(splited_text)
29 | text_width, text_height = font.getbbox(splited_text[0])[-2:]
30 | y_start = b[3] - text_height * num_lines - box_width
31 | if b[2] - b[0] < 100 or b[3] - b[1] < 100:
32 | y_start = b[3]
33 | for i, line in enumerate(splited_text):
34 | text_width, text_height = font.getbbox(line)[-2:]
35 | x = b[0] + box_width
36 | y = y_start + text_height * i
37 | draw.rectangle([x, y, x+text_width, y+text_height], fill=(128, 128, 128, 160))
38 | draw.text((x, y), line, font=font, fill=(255, 255, 255))
39 | img_with_overlay = Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
40 | img_with_overlay.save(output_fn)
41 |
42 | def boxstr_to_boxes(box_str):
43 | boxes = [[int(y)/1000 for y in x.split(',')] for x in box_str.split(';') if x.replace(',', '').isdigit()]
44 | return boxes
45 |
46 | def text_to_dict(text):
47 | doc = nlp(text)
48 |
49 | box_matches = list(re.finditer(r'\[\[([^\]]+)\]\]', text))
50 | box_positions = [match.start() for match in box_matches]
51 |
52 | noun_phrases = []
53 | boxes = []
54 |
55 | for match, box_position in zip(box_matches, box_positions):
56 | nearest_np_start = max([0] + [chunk.start_char for chunk in doc.noun_chunks if chunk.end_char <= box_position])
57 | noun_phrase = text[nearest_np_start:box_position].strip()
58 | if noun_phrase and noun_phrase[-1] == '?':
59 | noun_phrase = text[:box_position].strip()
60 | box_string = match.group(1)
61 |
62 | noun_phrases.append(noun_phrase)
63 | boxes.append(boxstr_to_boxes(box_string))
64 |
65 | pairs = []
66 | for noun_phrase, box_string in zip(noun_phrases, boxes):
67 | pairs.append((noun_phrase.lower(), box_string))
68 | return dict(pairs)
69 |
70 | def parse_response(img, response, output_fn='output.png'):
71 | img = img.convert('RGB')
72 | width, height = img.size
73 | ratio = min(1920 / width, 1080 / height)
74 | new_width = int(width * ratio)
75 | new_height = int(height * ratio)
76 | new_img = img.resize((new_width, new_height), Image.LANCZOS)
77 | pattern = r"\[\[(.*?)\]\]"
78 | positions = re.findall(pattern, response)
79 | boxes = [[[int(y) for y in x.split(',')] for x in pos.split(';') if x.replace(',', '').isdigit()] for pos in positions]
80 | dic = text_to_dict(response)
81 | if not dic:
82 | texts = []
83 | boxes = []
84 | else:
85 | texts, boxes = zip(*dic.items())
86 | draw_boxes(new_img, boxes, texts, output_fn=output_fn)
--------------------------------------------------------------------------------
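The parser expects grounded responses in which boxes are written as [[x1,y1,x2,y2]] (multiple boxes separated by ';'), with integer coordinates on a 0-1000 scale relative to the image. A standalone sketch of the boxstr_to_boxes conversion on an example string:

# Same conversion as boxstr_to_boxes: 0-1000 integer coordinates -> 0-1 floats.
box_str = "120,045,680,900;010,020,110,220"
boxes = [[int(y) / 1000 for y in x.split(',')]
         for x in box_str.split(';') if x.replace(',', '').isdigit()]
print(boxes)  # [[0.12, 0.045, 0.68, 0.9], [0.01, 0.02, 0.11, 0.22]]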
/utils/utils/language.py:
--------------------------------------------------------------------------------
1 | def base_history_to_prompt(self, query, history):
2 |     prompt = '<EOI>' + query
3 | return prompt
4 |
5 | def chat_history_to_prompt(self, query, history):
6 |     prompt = "<EOI> [INST] "
7 | for i, (old_query, response) in enumerate(history):
8 | prompt += old_query + " [/INST] " + response + " [INST] "
9 | prompt += query + " [/INST] "
10 | return prompt
11 |
12 | def vqa_history_to_prompt(self, query, history):
13 |     # Only single-round chat is supported in vqa mode
14 |     prompt = "<EOI>Question: "
15 | # for i, (old_query, response) in enumerate(history):
16 | # prompt += old_query + " Short answer: " + response + " Question: "
17 | prompt += query + " Short answer:"
18 | return prompt
19 |
20 | def chat_old_history_to_prompt(self, query, history):
21 |     prompt = "<EOI>Question: "
22 | for i, (old_query, response) in enumerate(history):
23 | prompt += old_query + " Answer: " + response + "\nQuestion: "
24 | prompt += query + " Answer:"
25 | return prompt
26 |
27 | _history_to_prompt = {
28 | "base": base_history_to_prompt,
29 | "chat": chat_history_to_prompt,
30 | "vqa": vqa_history_to_prompt,
31 | "chat_old": chat_old_history_to_prompt, # for cogvlm-v1.1
32 | }
33 |
34 | from transformers import LlamaTokenizer
35 |
36 | def llama2_tokenizer(tokenizer_path, signal_type="base"):
37 | tokenizer = LlamaTokenizer.from_pretrained(tokenizer_path)
38 | if tokenizer.pad_token_id is None:
39 | tokenizer.pad_token_id = 32000
40 | tokenizer.boi = "[IMG]"
41 | tokenizer.eoi = "[/IMG]"
42 | assert signal_type in ["base", "chat", "vqa", "chat_old"]
43 | tokenizer.signal_type = signal_type
44 | return tokenizer
45 |
46 | import re
47 | import numpy as np
48 | import torch
49 |
50 | class llama2_text_processor:
51 | def __init__(self, tokenizer, max_target_length=2048, image_length=257, model=None):
52 | self.tokenizer = tokenizer
53 | self.max_target_length = max_target_length
54 | self.image_length = image_length
55 |
56 | def __call__(self, caption, prompt=""):
57 |         if '<EOI>' not in prompt:
58 | prompt = self.replace_tags_with_empty(prompt)
59 | # caption = self.replace_tags_with_empty(caption)
60 | history = []
61 | prompt = self.history_to_prompt(prompt, history)
62 |
63 | input_ids = [self.tokenizer.bos_token_id]
64 |
65 |         prompt_splits = prompt.split('<EOI>')
66 |         caption_splits = caption.split('<EOI>')
67 | if len(prompt_splits) > 0:
68 | input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False))
69 | for tokens in prompt_splits[1:]:
70 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
71 | input_ids.extend(tokens_with_img)
72 | context_length = len(input_ids) + (len(prompt_splits)-1) * (self.image_length + 1)
73 | if context_length > self.max_target_length - 10:
74 | return None
75 | if len(caption_splits) > 0:
76 | input_ids.extend(self.tokenizer.encode(caption_splits[0], add_special_tokens=False))
77 | for tokens in caption_splits[1:]:
78 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
79 | input_ids.extend(tokens_with_img)
80 |
81 | if len(input_ids) > self.max_target_length - self.image_length - 5:
82 | input_ids = input_ids[:self.max_target_length - self.image_length - 5]
83 |
84 | input_ids += [self.tokenizer.eos_token_id]
85 |
86 | while -100 in input_ids:
87 | img_idx = input_ids.index(-100)
88 | input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx+1:]
89 |
90 | image_position = []
91 | while -1 in input_ids:
92 | img_idx = input_ids.index(-1)
93 | input_ids[img_idx] = 0
94 | image_position.append(img_idx)
95 |
96 | image_embed_mask = [0] * len(input_ids)
97 | vision_expert_mask = [0] * len(input_ids)
98 | image_rope_mask = [0] * len(input_ids)
99 | for idx in image_position:
100 | image_embed_mask[idx-self.image_length-1: idx+1] = [1] * (self.image_length + 2)
101 | vision_expert_mask[idx-self.image_length-1: idx] = [1] * (self.image_length + 1)
102 | image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length
103 | attention_mask = [1] * len(input_ids)
104 | labels = [-100] * context_length + input_ids[context_length:]
105 |
106 | pad_len = self.max_target_length - len(input_ids)
107 | input_ids = input_ids + [self.tokenizer.pad_token_id] * pad_len
108 | attention_mask = attention_mask + [1] * pad_len
109 | vision_expert_mask = vision_expert_mask + [0] * pad_len
110 | image_embed_mask = image_embed_mask + [0] * pad_len
111 | image_rope_mask = image_rope_mask + [0] * pad_len
112 | np_mask = np.tril(np.expand_dims(np.array(attention_mask), 0).repeat(len(attention_mask), 0))
113 | labels = labels + [-100] * pad_len
114 |
115 | for idx in image_position:
116 | labels[idx-self.image_length-1: idx+1] = [-100] * (self.image_length + 2)
117 |
118 | position_ids = []
119 | pid = -1
120 | for i in range(len(input_ids)):
121 | if image_rope_mask[i] == 0 or (i > 0 and image_rope_mask[i] != image_rope_mask[i - 1]):
122 | pid += 1
123 | position_ids.append(pid)
124 |
125 | input_ids = torch.tensor(input_ids).unsqueeze(0)
126 | labels = torch.tensor(labels).unsqueeze(0)
127 | attention_mask = torch.from_numpy(np_mask).unsqueeze(0).unsqueeze(0)
128 | image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0)
129 | vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0)
130 | image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0)
131 | position_ids = torch.tensor(position_ids).unsqueeze(0)
132 | context_length = torch.tensor(context_length).unsqueeze(0).long()
133 | return {'input_ids': input_ids, 'labels': labels, 'position_ids': position_ids, 'attention_mask': attention_mask, 'image_embed_mask': image_embed_mask,
134 | 'context_length': context_length, 'image_position': image_position, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask
135 | }
136 |
137 | def history_to_prompt(self, query, history):
138 | return _history_to_prompt[self.tokenizer.signal_type](self, query, history)
139 |
140 | def replace_tags_with_empty(self, text):
141 |         return re.sub('<pad>|<s>|</s>|<EOI>', '', text)
142 |
143 | from functools import partial
144 | def get_masks_and_position_ids(seq, image_logits_mask):
145 | tokens = seq.unsqueeze(0)
146 |
147 | attention_mask = torch.ones((1, len(seq), len(seq)), device=tokens.device)
148 | attention_mask.tril_()
149 | attention_mask.unsqueeze_(1)
150 |
151 | position_ids = []
152 | pid = -1
153 | for i in range(len(image_logits_mask[0])):
154 | if image_logits_mask[0][i] == 0 or (i > 0 and image_logits_mask[0][i] != image_logits_mask[0][i - 1]):
155 | pid += 1
156 | position_ids.append(pid)
157 | for i in range(tokens.shape[1]-image_logits_mask.shape[1]):
158 | pid += 1
159 | position_ids.append(pid)
160 | position_ids = torch.tensor(position_ids, dtype=torch.long, device=tokens.device)
161 | position_ids = position_ids.unsqueeze(0)
162 |
163 | return tokens, attention_mask, position_ids
164 |
165 | class llama2_text_processor_inference:
166 | def __init__(self, tokenizer, max_target_length=1024, image_length=257, model=None, no_prompt=False, english=True):
167 | self.tokenizer = tokenizer
168 | self.max_target_length = max_target_length
169 | self.image_length = image_length
170 | if self.tokenizer.signal_type == "chat":
171 | self.sep = "[/INST]"
172 | elif self.tokenizer.signal_type == "vqa":
173 | self.sep = " Short answer:"
174 | elif self.tokenizer.signal_type == "chat_old":
175 | self.sep = " Answer:"
176 | else:
177 | self.sep = ""
178 |
179 | self.invalid_slices = []
180 | self.no_eoi = True
181 |
182 | def __call__(self, prompt=""):
183 |         if '<EOI>' not in prompt:
184 | prompt = self.replace_tags_with_empty(prompt)
185 | # caption = self.replace_tags_with_empty(caption)
186 | history = []
187 | prompt = self.history_to_prompt(prompt, history)
188 |
189 | input_ids = [self.tokenizer.bos_token_id]
190 |
191 |         prompt_splits = prompt.split('<EOI>')
192 | if len(prompt_splits) > 0:
193 | input_ids.extend(self.tokenizer.encode(prompt_splits[0], add_special_tokens=False))
194 | for tokens in prompt_splits[1:]:
195 | tokens_with_img = [-100] + self.tokenizer.encode(tokens, add_special_tokens=False)
196 | input_ids.extend(tokens_with_img)
197 |
198 | while -100 in input_ids:
199 | img_idx = input_ids.index(-100)
200 | input_ids = input_ids[:img_idx] + [0] * (self.image_length + 1) + [-1] + input_ids[img_idx + 1:]
201 |
202 | image_position = []
203 | while -1 in input_ids:
204 | img_idx = input_ids.index(-1)
205 | input_ids[img_idx] = 0
206 | image_position.append(img_idx)
207 |
208 | image_embed_mask = [0] * len(input_ids)
209 | vision_expert_mask = [0] * len(input_ids)
210 | image_rope_mask = [0] * len(input_ids)
211 | for idx in image_position:
212 | image_embed_mask[idx - self.image_length - 1: idx + 1] = [1] * (self.image_length + 2)
213 | vision_expert_mask[idx - self.image_length - 1: idx] = [1] * (self.image_length + 1)
214 | image_rope_mask[idx - self.image_length: idx] = [1] * self.image_length
215 |
216 | input_ids = torch.tensor(input_ids).unsqueeze(0)
217 | image_embed_mask = torch.tensor(image_embed_mask).unsqueeze(0)
218 | vision_expert_mask = torch.tensor(vision_expert_mask).unsqueeze(0)
219 | image_rope_mask = torch.tensor(image_rope_mask).unsqueeze(0)
220 | return {'input_ids': input_ids, 'image_embed_mask': image_embed_mask, 'vision_expert_mask': vision_expert_mask, 'image_rope_mask': image_rope_mask}
221 |
222 | def history_to_prompt(self, query, history):
223 | return _history_to_prompt[self.tokenizer.signal_type](self, query, history)
224 |
225 | def replace_tags_with_empty(self, text):
226 |         return re.sub('<pad>|<s>|</s>|<EOI>', '', text)
227 |
228 | def process_response(self, response):
229 |         return response.replace('</s>', '')
230 |
231 | def get_func(self, inputs, **kwargs):
232 | get_func = partial(get_masks_and_position_ids, image_logits_mask=kwargs['image_rope_mask'])
233 | return get_func
--------------------------------------------------------------------------------
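For reference, the "chat" signal type builds a Llama-2 style [INST] / [/INST] prompt (the template also prepends the image placeholder, which is omitted here). A sketch mirroring chat_history_to_prompt with a made-up history:

# Mirrors the [INST]/[/INST] layout produced by chat_history_to_prompt.
history = [("What is in the image?", "A pear on a wooden table.")]
query = "Where is it?"
prompt = " [INST] "
for old_query, response in history:
    prompt += old_query + " [/INST] " + response + " [INST] "
prompt += query + " [/INST] "
print(repr(prompt))
# ' [INST] What is in the image? [/INST] A pear on a wooden table. [INST] Where is it? [/INST] '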
/utils/utils/vision.py:
--------------------------------------------------------------------------------
1 | from torchvision import transforms
2 | from torchvision.transforms.functional import InterpolationMode
3 | import torch
4 |
5 | class BlipImageEvalProcessor:
6 | def __init__(self, image_size=384, mean=None, std=None):
7 | super().__init__()
8 | if mean is None:
9 | mean = (0.48145466, 0.4578275, 0.40821073)
10 | if std is None:
11 | std = (0.26862954, 0.26130258, 0.27577711)
12 |
13 | self.normalize = transforms.Normalize(mean, std)
14 |
15 | self.transform = transforms.Compose(
16 | [
17 | transforms.Resize(
18 | (image_size, image_size), interpolation=InterpolationMode.BICUBIC
19 | ),
20 | transforms.ToTensor(),
21 | self.normalize,
22 | ]
23 | )
24 |
25 | def __call__(self, item):
26 | return self.transform(item)
27 |
28 | from functools import partial
29 |
30 | def blip2_image_processor_func_with_inputs(image_processor, image):
31 | return {'image': image_processor(image).unsqueeze(0), 'input_ids': torch.zeros(1, 1, dtype=torch.long), 'position_ids': None, 'attention_mask': torch.ones(1, 1, dtype=torch.long)}
32 |
33 | def get_image_processor(image_size):
34 | return partial(blip2_image_processor_func_with_inputs, BlipImageEvalProcessor(image_size))
--------------------------------------------------------------------------------
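A short usage sketch of get_image_processor (assuming the repo root is on PYTHONPATH and the requirements are installed): it returns a callable that resizes, normalizes and batches a PIL image into the dict the SAT model expects.

from PIL import Image
from utils.utils.vision import get_image_processor

image_processor = get_image_processor(490)              # image_size handed to BlipImageEvalProcessor
inputs = image_processor(Image.new("RGB", (640, 480)))  # any PIL image works
print(inputs['image'].shape)                            # torch.Size([1, 3, 490, 490])
print(inputs['input_ids'].shape)                        # torch.Size([1, 1]) placeholder ids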