├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yaml │ └── feature-request.yaml ├── PULL_REQUEST_TEMPLATE.md └── workflows │ └── python-lint.yml ├── .gitignore ├── .pre-commit-config.yaml ├── LICENSE ├── README.md ├── README_ja.md ├── README_zh.md ├── inference ├── cli_demo_cogview3.py ├── cli_demo_cogview4.py ├── cli_demo_cogview4_int8.py ├── gradio_web_demo.py ├── img │ ├── img_1.png │ ├── img_2.png │ └── img_3.png ├── prompt_optimize.py └── requirements.txt ├── pyproject.toml └── resources ├── CogView3_evaluation.png ├── CogView3_pipeline.jpg ├── CogView3_showcase.png ├── WECHAT.md ├── controlnet.png ├── img.png ├── logo.svg ├── showcase.png ├── train_control_model.md ├── train_control_model_zh.md └── wechat.jpg /.github/ISSUE_TEMPLATE/bug_report.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve CogView3 / 提交一个 Bug 问题报告来帮助我们改进 CogView3 开源模型 3 | body: 4 | - type: textarea 5 | id: system-info 6 | attributes: 7 | label: System Info / 系統信息 8 | description: Your operating environment / 您的运行环境信息 9 | placeholder: Includes Cuda version, Diffusers version, Python version, operating system, hardware information (if you suspect a hardware problem)... / 包括Cuda版本,Diffusers,Python版本,操作系统,硬件信息(如果您怀疑是硬件方面的问题)... 10 | validations: 11 | required: true 12 | 13 | - type: checkboxes 14 | id: information-scripts-examples 15 | attributes: 16 | label: Information / 问题信息 17 | description: 'The problem arises when using: / 问题出现在' 18 | options: 19 | - label: "The official example scripts / 官方的示例脚本" 20 | - label: "My own modified scripts / 我自己修改的脚本和任务" 21 | 22 | - type: textarea 23 | id: reproduction 24 | validations: 25 | required: true 26 | attributes: 27 | label: Reproduction / 复现过程 28 | description: | 29 | Please provide a code example that reproduces the problem you encountered, preferably with a minimal reproduction unit. 30 | If you have code snippets, error messages, stack traces, please provide them here as well. 31 | Please format your code correctly using code tags. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 32 | Do not use screenshots, as they are difficult to read and (more importantly) do not allow others to copy and paste your code. 33 | 34 | 请提供能重现您遇到的问题的代码示例,最好是最小复现单元。 35 | 如果您有代码片段、错误信息、堆栈跟踪,也请在此提供。 36 | 请使用代码标签正确格式化您的代码。请参见 https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 37 | 请勿使用截图,因为截图难以阅读,而且(更重要的是)不允许他人复制粘贴您的代码。 38 | placeholder: | 39 | Steps to reproduce the behavior/复现Bug的步骤: 40 | 41 | 1. 42 | 2. 43 | 3. 44 | 45 | - type: textarea 46 | id: expected-behavior 47 | validations: 48 | required: true 49 | attributes: 50 | label: Expected behavior / 期待表现 51 | description: "A clear and concise description of what you would expect to happen. /简单描述您期望发生的事情。" 52 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yaml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a request for a new CogView3 feature / 提交一个新的 CogView3开源模型的功能建议 3 | labels: [ "feature" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request / 功能建议 11 | description: | 12 | A brief description of the functional proposal. 
Links to corresponding papers and code are desirable. 13 | 对功能建议的简述。最好提供对应的论文和代码链接。 14 | 15 | - type: textarea 16 | id: motivation 17 | validations: 18 | required: true 19 | attributes: 20 | label: Motivation / 动机 21 | description: | 22 | Your motivation for making the suggestion. If that motivation is related to another GitHub issue, link to it here. 23 | 您提出建议的动机。如果该动机与另一个 GitHub 问题有关,请在此处提供对应的链接。 24 | 25 | - type: textarea 26 | id: contribution 27 | validations: 28 | required: true 29 | attributes: 30 | label: Your contribution / 您的贡献 31 | description: | 32 | 33 | Your PR link or any other link you can help with. 34 | 您的PR链接或者其他您能提供帮助的链接。 35 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | We welcome your contributions to this repository. To ensure elegant code style and better code quality, we have prepared the following contribution guidelines. 4 | 5 | ## What We Accept 6 | 7 | + This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks). 8 | + This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below. 9 | + This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below. 10 | 11 | ## Code Style Guide 12 | 13 | Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below: 14 | 15 | 1. Install the required dependencies: 16 | ```shell 17 | pip install ruff pre-commit 18 | ``` 19 | 2. Then, run the following command: 20 | ```shell 21 | pre-commit run --all-files 22 | ``` 23 | If your code complies with the standards, you should not see any errors. 24 | 25 | ## Naming Conventions 26 | 27 | - Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English. 28 | - Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`. 
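For example, a minimal sketch of the naming style we expect (the identifiers below are invented purely for illustration):

```python
# Avoid: opaque single-letter names
def f(a, b):
    return a * b


# Prefer: descriptive snake_case names, per PEP8
def scale_image_width(base_width: int, scale_factor: float) -> float:
    return base_width * scale_factor
```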
29 | -------------------------------------------------------------------------------- /.github/workflows/python-lint.yml: -------------------------------------------------------------------------------- 1 | name: Python Linting 2 | 3 | on: 4 | push: 5 | branches: [main] 6 | pull_request: 7 | branches: [main] 8 | 9 | jobs: 10 | lint: 11 | runs-on: ubuntu-latest 12 | steps: 13 | - uses: actions/checkout@v4 14 | 15 | - name: Set up Python 16 | uses: actions/setup-python@v5 17 | with: 18 | python-version: '3.10' 19 | cache: 'pip' 20 | 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install pre-commit 25 | 26 | - name: Run pre-commit 27 | run: pre-commit run --all-files 28 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__/ 2 | samples*/ 3 | runs/ 4 | checkpoints/ 5 | master_ip 6 | logs/ 7 | *.DS_Store 8 | .idea 9 | output* 10 | test* 11 | img 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.4.5 4 | hooks: 5 | - id: ruff 6 | args: [--fix, --respect-gitignore, --config=pyproject.toml] 7 | - id: ruff-format 8 | args: [--config=pyproject.toml] 9 | 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v4.5.0 12 | hooks: 13 | - id: trailing-whitespace 14 | - id: end-of-file-fixer 15 | - id: check-yaml 16 | - id: check-toml 17 | - id: check-case-conflict 18 | - id: check-merge-conflict 19 | - id: debug-statements 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2025 CogView Team@ZhipuAI 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CogView4 & CogView3 & CogView-3Plus 2 | 3 | [阅读中文版](./README_zh.md) 4 | [日本語で読む](./README_ja.md) 5 | 6 |
🤗 HuggingFace Space | 🤖 ModelScope Space | 🛠️ ZhipuAI MaaS (Faster)

👋 WeChat Community | 📚 CogView3 Paper
17 | 18 | ![showcase.png](resources/showcase.png) 19 | 20 | ## Project Updates 21 | 22 | - 🔥🔥 ```2025/03/24```: We are launching [CogKit](https://github.com/THUDM/CogKit), a powerful toolkit for fine-tuning and inference of the **CogView4** and **CogVideoX** series, allowing you to fully explore our multimodal generation models. 23 | - ```2025/03/04```: We've adapted and open-sourced the [diffusers](https://github.com/huggingface/diffusers) version 24 | of **CogView-4** model, which has 6B parameters, supports native Chinese input, and Chinese text-to-image generation. 25 | You can try it [online](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView4). 26 | - ```2024/10/13```: We've adapted and open-sourced the [diffusers](https://github.com/huggingface/diffusers) version of 27 | **CogView-3Plus-3B** model. You can try 28 | it [online](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView3-Plus-3B-Space). 29 | - ```2024/9/29```: We've open-sourced **CogView3** and **CogView-3Plus-3B**. **CogView3** is a text-to-image system 30 | based on cascading diffusion, using a relay diffusion framework. **CogView-3Plus** is a series of newly developed 31 | text-to-image models based on Diffusion Transformer. 32 | 33 | ## Project Plan 34 | 35 | - [X] Diffusers workflow adaptation 36 | - [X] Cog series fine-tuning kits (coming soon) 37 | - [ ] ControlNet models and training code 38 | 39 | ## Community Contributions 40 | 41 | We have collected some community projects related to this repository here. These projects are maintained by community members, and we appreciate their contributions. 42 | 43 | + [ComfyUI_CogView4_Wrapper](https://github.com/chflame163/ComfyUI_CogView4_Wrapper) - An implementation of the CogView4 project in ComfyUI. 44 | 45 | ## Model Introduction 46 | 47 | ### Model Comparison 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 |
| Model Name | CogView4 | CogView3-Plus-3B |
|------------|----------|------------------|
| Resolution | 512 <= H, W <= 2048; H * W <= 2^21; H, W mod 32 = 0 (applies to both models) | |
| Inference Precision | Only supports BF16, FP32 (applies to both models) | |
| Encoder | GLM-4-9B | T5-XXL |
| Prompt Language | Chinese, English | English |
| Prompt Length Limit | 1024 Tokens | 224 Tokens |
| Download Links | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel |
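A target resolution can be sanity-checked against the constraints above before calling the pipeline. The helper below is a minimal illustrative sketch (the function name is ours and not part of this repository):

```python
def is_valid_resolution(height: int, width: int) -> bool:
    # Constraints from the table: 512 <= H, W <= 2048,
    # H * W <= 2^21, and H, W both multiples of 32.
    return (
        512 <= height <= 2048
        and 512 <= width <= 2048
        and height * width <= 2**21
        and height % 32 == 0
        and width % 32 == 0
    )


assert is_valid_resolution(1024, 1024)      # default size used in the demos
assert not is_valid_resolution(2048, 2048)  # area exceeds 2^21
assert not is_valid_resolution(1000, 1000)  # not a multiple of 32
```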
85 | 86 | ### Memory Usage 87 | 88 | DIT models are tested with `BF16` precision and `batchsize=4`, with results shown in the table below: 89 | 90 | | Resolution | enable_model_cpu_offload OFF | enable_model_cpu_offload ON | enable_model_cpu_offload ON
Text Encoder 4bit | 91 | |-------------|------------------------------|-----------------------------|-----------------------------------------------------| 92 | | 512 * 512 | 33GB | 20GB | 13G | 93 | | 1280 * 720 | 35GB | 20GB | 13G | 94 | | 1024 * 1024 | 35GB | 20GB | 13G | 95 | | 1920 * 1280 | 39GB | 20GB | 14G | 96 | 97 | Additionally, we recommend that your device has at least `32GB` of RAM to prevent the process from being killed. 98 | 99 | ### Model Metrics 100 | 101 | We've tested on multiple benchmarks and achieved the following scores: 102 | 103 | #### DPG-Bench 104 | 105 | | Model | Overall | Global | Entity | Attribute | Relation | Other | 106 | |--------------|-----------|-----------|-----------|-----------|-----------|-----------| 107 | | SDXL | 74.65 | 83.27 | 82.43 | 80.91 | 86.76 | 80.41 | 108 | | PixArt-alpha | 71.11 | 74.97 | 79.32 | 78.60 | 82.57 | 76.96 | 109 | | SD3-Medium | 84.08 | 87.90 | **91.01** | 88.83 | 80.70 | 88.68 | 110 | | DALL-E 3 | 83.50 | **90.97** | 89.61 | 88.39 | 90.58 | 89.83 | 111 | | Flux.1-dev | 83.79 | 85.80 | 86.79 | 89.98 | 90.04 | **89.90** | 112 | | Janus-Pro-7B | 84.19 | 86.90 | 88.90 | 89.40 | 89.32 | 89.48 | 113 | | **CogView4-6B** | **85.13** | 83.85 | 90.35 | **91.17** | **91.14** | 87.29 | 114 | 115 | #### GenEval 116 | 117 | | Model | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Color attribution | 118 | |-----------------|----------|-------------|----------|----------|----------|----------|-------------------| 119 | | SDXL | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 | 120 | | PixArt-alpha | 0.48 | 0.98 | 0.50 | 0.44 | 0.80 | 0.08 | 0.07 | 121 | | SD3-Medium | 0.74 | **0.99** | **0.94** | 0.72 | 0.89 | 0.33 | 0.60 | 122 | | DALL-E 3 | 0.67 | 0.96 | 0.87 | 0.47 | 0.83 | 0.43 | 0.45 | 123 | | Flux.1-dev | 0.66 | 0.98 | 0.79 | **0.73** | 0.77 | 0.22 | 0.45 | 124 | | Janus-Pro-7B | **0.80** | **0.99** | 0.89 | 0.59 | **0.90** | **0.79** | **0.66** | 125 | | **CogView4-6B** | 0.73 | **0.99** | 0.86 | 0.66 | 0.79 | 0.48 | 0.58 | 126 | 127 | #### T2I-CompBench 128 | 129 | | Model | Color | Shape | Texture | 2D-Spatial | 3D-Spatial | Numeracy | Non-spatial Clip | Complex 3-in-1 | 130 | |-----------------|------------|------------|------------|------------|------------|------------|------------------|----------------| 131 | | SDXL | 0.5879 | 0.4687 | 0.5299 | 0.2133 | 0.3566 | 0.4988 | 0.3119 | 0.3237 | 132 | | PixArt-alpha | 0.6690 | 0.4927 | 0.6477 | 0.2064 | 0.3901 | 0.5058 | **0.3197** | 0.3433 | 133 | | SD3-Medium | **0.8132** | 0.5885 | **0.7334** | **0.3200** | **0.4084** | 0.6174 | 0.3140 | 0.3771 | 134 | | DALL-E 3 | 0.7785 | **0.6205** | 0.7036 | 0.2865 | 0.3744 | 0.5880 | 0.3003 | 0.3773 | 135 | | Flux.1-dev | 0.7572 | 0.5066 | 0.6300 | 0.2700 | 0.3992 | 0.6165 | 0.3065 | 0.3628 | 136 | | Janus-Pro-7B | 0.5145 | 0.3323 | 0.4069 | 0.1566 | 0.2753 | 0.4406 | 0.3137 | 0.3806 | 137 | | **CogView4-6B** | 0.7786 | 0.5880 | 0.6983 | 0.3075 | 0.3708 | **0.6626** | 0.3056 | **0.3869** | 138 | 139 | ## Chinese Text Accuracy Evaluation 140 | 141 | | Model | Precision | Recall | F1 Score | Pick@4 | 142 | |-----------------|------------|------------|------------|------------| 143 | | Kolors | 0.6094 | 0.1886 | 0.2880 | 0.1633 | 144 | | **CogView4-6B** | **0.6969** | **0.5532** | **0.6168** | **0.3265** | 145 | 146 | ## Inference Model 147 | 148 | ### Prompt Optimization 149 | 150 | Although CogView4 series models are trained with lengthy synthetic image descriptions, we strongly recommend using a 151 | large language 
model to rewrite prompts before text-to-image generation, which will greatly improve generation quality. 152 | 153 | We provide an [example script](inference/prompt_optimize.py). We recommend running this script to refine your prompts. 154 | Note that `CogView4` and `CogView3` models use different few-shot examples for prompt optimization. They need to be 155 | distinguished. 156 | 157 | ```shell 158 | cd inference 159 | python prompt_optimize.py --api_key "Zhipu AI API Key" --prompt {your prompt} --base_url "https://open.bigmodel.cn/api/paas/v4" --model "glm-4-plus" --cogview_version "cogview4" 160 | ``` 161 | 162 | ### Inference Model 163 | 164 | Run the model `CogView4-6B` with `BF16` precision: 165 | 166 | ```python 167 | from diffusers import CogView4Pipeline 168 | import torch 169 | 170 | pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") 171 | 172 | # Open it for reduce GPU memory usage 173 | pipe.enable_model_cpu_offload() 174 | pipe.vae.enable_slicing() 175 | pipe.vae.enable_tiling() 176 | 177 | prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background." 178 | image = pipe( 179 | prompt=prompt, 180 | guidance_scale=3.5, 181 | num_images_per_prompt=1, 182 | num_inference_steps=50, 183 | width=1024, 184 | height=1024, 185 | ).images[0] 186 | 187 | image.save("cogview4.png") 188 | ``` 189 | 190 | For more inference code, please check: 191 | 192 | 1. For using `BNB int4` to load `text encoder` and complete inference code annotations, 193 | check [here](inference/cli_demo_cogview4.py). 194 | 2. For using `TorchAO int8 or int4` to load `text encoder & transformer` and complete inference code annotations, 195 | check [here](inference/cli_demo_cogview4_int8.py). 196 | 3. For setting up a `gradio` GUI DEMO, check [here](inference/gradio_web_demo.py). 197 | 198 | 199 | ## Fine-tuning 200 | 201 | This repository does not contain fine-tuning code, but you can fine-tune using the following two approaches, including both LoRA and SFT: 202 | 203 | 1. [CogKit](https://github.com/THUDM/CogKit), our officially maintained system-level fine-tuning framework that supports CogView4 and CogVideoX. 204 | 2. [finetrainers](https://github.com/a-r-r-o-w/finetrainers), a low-memory solution that enables fine-tuning on a single RTX 4090. 205 | 3. If you want to train ControlNet models directly, you can refer to the [training code](https://github.com/huggingface/diffusers/tree/main/examples/cogview4-control) and train your own models. 206 | 207 | ## License 208 | 209 | The code in this repository and the CogView3 models are licensed under [Apache 2.0](./LICENSE). 210 | 211 | We welcome and appreciate your code contributions. You can view the contribution 212 | guidelines [here](resources/contribute.md). 
213 | -------------------------------------------------------------------------------- /README_ja.md: -------------------------------------------------------------------------------- 1 | # CogView4 & CogView3 & CogView-3Plus 2 | 3 | [Read this in English](./README.md) 4 | [阅读中文版](./README_zh.md) 5 | 6 |
🤗 HuggingFace Space | 🤖 ModelScope Space | 🛠️ ZhipuAI MaaS (Faster)

👋 WeChat Community | 📚 CogView3 Paper
17 | 18 | 19 | ![showcase.png](resources/showcase.png) 20 | 21 | ## プロジェクトの更新 22 | 23 | - 🔥🔥 ```2025/03/24```: [CogView4-6B-Control](https://huggingface.co/THUDM/CogView4-6B-Control) モデルをリリースしました![トレーニングコード](https://github.com/huggingface/diffusers/tree/main/examples/cogview4-control) を使用して、自身でトレーニングすることも可能です。 24 | さらに、**CogView4** および **CogVideoX** シリーズのファインチューニングと推論を簡単に行えるツールキット [CogKit](https://github.com/THUDM/CogKit) も公開しました。私たちのマルチモーダル生成モデルを存分に活用してください! 25 | - ```2025/03/04```: [diffusers](https://github.com/huggingface/diffusers) バージョンの **CogView-4** 26 | モデルを適応し、オープンソース化しました。このモデルは6Bのパラメータを持ち、ネイティブの中国語入力と中国語のテキストから画像生成をサポートしています。オンラインで試すことができます [こちら](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView4)。 27 | - ```2024/10/13```: [diffusers](https://github.com/huggingface/diffusers) バージョンの **CogView-3Plus-3B** 28 | モデルを適応し、オープンソース化しました。オンラインで試すことができます [こちら](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView3-Plus-3B-Space)。 29 | - ```2024/9/29```: **CogView3** と **CogView-3Plus-3B** をオープンソース化しました。**CogView3** 30 | はカスケード拡散に基づくテキストから画像生成システムで、リレーディフュージョンフレームワークを使用しています。* 31 | *CogView-3Plus** は新たに開発されたDiffusion Transformerに基づくテキストから画像生成モデルのシリーズです。 32 | 33 | ## プロジェクト計画 34 | 35 | - [X] Diffusers ワークフローの適応 36 | - [X] Cogシリーズのファインチューニングスイート (近日公開) 37 | - [ ] ControlNetモデルとトレーニングコード 38 | 39 | ## コミュニティの取り組み 40 | 41 | 本リポジトリに関連するいくつかのコミュニティプロジェクトをここにまとめました。これらのプロジェクトはコミュニティメンバーによって維持されており、彼らの貢献に感謝します。 42 | 43 | + [ComfyUI_CogView4_Wrapper](https://github.com/chflame163/ComfyUI_CogView4_Wrapper) - ComfyUI における CogView4 プロジェクトの実装。 44 | 45 | ## モデル紹介 46 | 47 | ### モデル比較 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 |
| モデル名 | CogView4 | CogView3-Plus-3B |
|---------|----------|------------------|
| 解像度 | 512 <= H, W <= 2048; H * W <= 2^21; H, W mod 32 = 0(両モデル共通) | |
| 推論精度 | BF16, FP32 のみサポート(両モデル共通) | |
| エンコーダ | GLM-4-9B | T5-XXL |
| プロンプト言語 | 中国語、英語 | 英語 |
| プロンプト長さの制限 | 1024 トークン | 224 トークン |
| ダウンロードリンク | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel |
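上の表の解像度制約は、パイプラインを呼び出す前に機械的に確認できます。以下は説明用の最小スケッチです(関数名は説明のための仮のもので、本リポジトリのコードではありません):

```python
def is_valid_resolution(height: int, width: int) -> bool:
    # 表の制約: 512 <= H, W <= 2048、H * W <= 2^21、H と W は 32 の倍数
    return (
        512 <= height <= 2048
        and 512 <= width <= 2048
        and height * width <= 2**21
        and height % 32 == 0
        and width % 32 == 0
    )
```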
85 | 86 | ### メモリ使用量 87 | 88 | DITモデルは `BF16` 精度と `batchsize=4` でテストされ、結果は以下の表に示されています: 89 | 90 | | 解像度 | enable_model_cpu_offload OFF | enable_model_cpu_offload ON | enable_model_cpu_offload ON
Text Encoder 4bit | 91 | |-------------|------------------------------|-----------------------------|-----------------------------------------------------| 92 | | 512 * 512 | 33GB | 20GB | 13G | 93 | | 1280 * 720 | 35GB | 20GB | 13G | 94 | | 1024 * 1024 | 35GB | 20GB | 13G | 95 | | 1920 * 1280 | 39GB | 20GB | 14G | 96 | 97 | さらに、プロセスが強制終了されないようにするために、少なくとも`32GB`のRAMを持つデバイスを推奨します。 98 | 99 | ### モデル指標 100 | 101 | 複数のベンチマークでテストを行い、以下のスコアを達成しました: 102 | 103 | #### DPG-Bench 104 | 105 | | モデル | 全体 | グローバル | エンティティ | 属性 | 関係 | その他 | 106 | |-----------------|-----------|-----------|-----------|-----------|-----------|-----------| 107 | | SDXL | 74.65 | 83.27 | 82.43 | 80.91 | 86.76 | 80.41 | 108 | | PixArt-alpha | 71.11 | 74.97 | 79.32 | 78.60 | 82.57 | 76.96 | 109 | | SD3-Medium | 84.08 | 87.90 | **91.01** | 88.83 | 80.70 | 88.68 | 110 | | DALL-E 3 | 83.50 | **90.97** | 89.61 | 88.39 | 90.58 | 89.83 | 111 | | Flux.1-dev | 83.79 | 85.80 | 86.79 | 89.98 | 90.04 | **89.90** | 112 | | Janus-Pro-7B | 84.19 | 86.90 | 88.90 | 89.40 | 89.32 | 89.48 | 113 | | **CogView4-6B** | **85.13** | 83.85 | 90.35 | **91.17** | **91.14** | 87.29 | 114 | 115 | #### GenEval 116 | 117 | | モデル | 全体 | 単一オブジェクト | 二つのオブジェクト | カウント | 色 | 位置 | 色の属性 | 118 | |-----------------|----------|----------|-----------|----------|----------|----------|----------| 119 | | SDXL | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 | 120 | | PixArt-alpha | 0.48 | 0.98 | 0.50 | 0.44 | 0.80 | 0.08 | 0.07 | 121 | | SD3-Medium | 0.74 | **0.99** | **0.94** | 0.72 | 0.89 | 0.33 | 0.60 | 122 | | DALL-E 3 | 0.67 | 0.96 | 0.87 | 0.47 | 0.83 | 0.43 | 0.45 | 123 | | Flux.1-dev | 0.66 | 0.98 | 0.79 | **0.73** | 0.77 | 0.22 | 0.45 | 124 | | Janus-Pro-7B | **0.80** | **0.99** | 0.89 | 0.59 | **0.90** | **0.79** | **0.66** | 125 | | **CogView4-6B** | 0.73 | **0.99** | 0.86 | 0.66 | 0.79 | 0.48 | 0.58 | 126 | 127 | #### T2I-CompBench 128 | 129 | | モデル | 色 | 形 | テクスチャ | 2D-空間 | 3D-空間 | 数量 | 非空間 Clip | 複雑な3-in-1 | 130 | |-----------------|------------|------------|------------|------------|------------|------------|------------|------------| 131 | | SDXL | 0.5879 | 0.4687 | 0.5299 | 0.2133 | 0.3566 | 0.4988 | 0.3119 | 0.3237 | 132 | | PixArt-alpha | 0.6690 | 0.4927 | 0.6477 | 0.2064 | 0.3901 | 0.5058 | **0.3197** | 0.3433 | 133 | | SD3-Medium | **0.8132** | 0.5885 | **0.7334** | **0.3200** | **0.4084** | 0.6174 | 0.3140 | 0.3771 | 134 | | DALL-E 3 | 0.7785 | **0.6205** | 0.7036 | 0.2865 | 0.3744 | 0.5880 | 0.3003 | 0.3773 | 135 | | Flux.1-dev | 0.7572 | 0.5066 | 0.6300 | 0.2700 | 0.3992 | 0.6165 | 0.3065 | 0.3628 | 136 | | Janus-Pro-7B | 0.5145 | 0.3323 | 0.4069 | 0.1566 | 0.2753 | 0.4406 | 0.3137 | 0.3806 | 137 | | **CogView4-6B** | 0.7786 | 0.5880 | 0.6983 | 0.3075 | 0.3708 | **0.6626** | 0.3056 | **0.3869** | 138 | 139 | ## 中国語テキストの正確性評価 140 | 141 | | モデル | 精度 | リコール | F1スコア | Pick@4 | 142 | |-----------------|------------|------------|------------|------------| 143 | | Kolors | 0.6094 | 0.1886 | 0.2880 | 0.1633 | 144 | | **CogView4-6B** | **0.6969** | **0.5532** | **0.6168** | **0.3265** | 145 | 146 | ## 推論モデル 147 | 148 | ### プロンプトの最適化 149 | 150 | CogView4シリーズのモデルは長文の合成画像説明でトレーニングされていますが、テキストから画像生成を行う前に大規模言語モデルを使用してプロンプトをリライトすることを強くお勧めします。これにより生成品質が大幅に向上します。 151 | 152 | [例のスクリプト](inference/prompt_optimize.py)を提供しています。このスクリプトを実行してプロンプトをリファインすることをお勧めします。 153 | `CogView4` と `CogView3` モデルのプロンプト最適化には異なるfew-shotが使用されていることに注意してください。区別が必要です。 154 | 155 | ```shell 156 | cd inference 157 | python prompt_optimize.py --api_key "Zhipu AI API Key" --prompt {your prompt} 
--base_url "https://open.bigmodel.cn/api/paas/v4" --model "glm-4-plus" --cogview_version "cogview4" 158 | ``` 159 | 160 | ### 推論モデル 161 | 162 | `BF16` の精度で `CogView4-6B` モデルを実行する: 163 | 164 | ```python 165 | from diffusers import CogView4Pipeline 166 | import torch 167 | 168 | pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") 169 | 170 | # GPUメモリ使用量を減らすために開く 171 | pipe.enable_model_cpu_offload() 172 | pipe.vae.enable_slicing() 173 | pipe.vae.enable_tiling() 174 | 175 | prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background." 176 | image = pipe( 177 | prompt=prompt, 178 | guidance_scale=3.5, 179 | num_images_per_prompt=1, 180 | num_inference_steps=50, 181 | width=1024, 182 | height=1024, 183 | ).images[0] 184 | 185 | image.save("cogview4.png") 186 | ``` 187 | 188 | より詳しい推論コードについては、以下をご確認ください: 189 | 190 | 1. `BNB int4` を使用して `text encoder` をロードし、完全な推論コードの注釈を確認するには、[こちら](inference/cli_demo_cogview4.py) をご覧ください。 191 | 2. `TorchAO int8 または int4` を使用して `text encoder & transformer` をロードし、完全な推論コードの注釈を確認するには、[こちら](inference/cli_demo_cogview4_int8.py) をご覧ください。 192 | 3. `gradio` GUI デモをセットアップするには、[こちら](inference/gradio_web_demo.py) をご覧ください。 193 | 194 | ## ファインチューニング(微調整) 195 | 196 | このリポジトリにはファインチューニング用のコードは含まれていませんが、LoRA および SFT を含む以下の 2 つの方法でファインチューニングが可能です: 197 | 198 | 1. [CogKit](https://github.com/THUDM/CogKit):CogView4 および CogVideoX のファインチューニングをサポートする、公式で保守されているシステムレベルのファインチューニングフレームワークです。 199 | 2. [finetrainers](https://github.com/a-r-r-o-w/finetrainers):低メモリ環境向けのソリューションで、RTX 4090 でのファインチューニングが可能です。 200 | 3. ControlNet モデルを直接訓練したい場合は、[トレーニングコード](https://github.com/huggingface/diffusers/tree/main/examples/cogview4-control) を参考にして自前で訓練することができます。 201 | 202 | ## ライセンス 203 | 204 | このリポジトリのコードとCogView3モデルは [Apache 2.0](./LICENSE) のライセンスの下で提供されています。 205 | 206 | コードの貢献を歓迎し、感謝します。貢献ガイドラインは [こちら](resources/contribute.md) で確認できます。 207 | -------------------------------------------------------------------------------- /README_zh.md: -------------------------------------------------------------------------------- 1 | # CogView4 & CogView3 & CogView-3Plus 2 | 3 | [Read this in English](./README.md) 4 | [日本語で読む](./README_ja.md) 5 | 6 |
🤗 CogView4 HuggingFace Space | 🤗 CogView4-Control HuggingFace Space | 🤖 CogView4 魔搭社区空间

🛠️ CogView4 智谱MaaS平台 | 🛠️ CogView4-Control 智谱MaaS平台

👋 微信社区 | 📚 CogView3 论文
21 | 22 | ![showcase.png](resources/showcase.png) 23 | 24 | ## 项目更新 25 | 26 | - 🔥🔥 ```2025/03/24```: 我们推出了 [CogKit](https://github.com/THUDM/CogKit) 工具,这是一个微调**CogView4**, **CogVideoX** 系列的微调和推理框架,一个工具包,玩转我们的多模态生成模型。 27 | - ```2025/03/04```: 我们适配和开源了 [diffusers](https://github.com/huggingface/diffusers) 版本的 **CogView-4** 28 | 模型,该模型具有6B权重,支持原生中文输入,支持中文文字绘画。你可以前往[在线体验](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView4)。 29 | - ```2024/10/13```: 我们适配和开源了 [diffusers](https://github.com/huggingface/diffusers) 版本的 **CogView-3Plus-3B** 30 | 模型。你可以前往[在线体验](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView3-Plus-3B-Space)。 31 | - ```2024/9/29```: 我们已经开源了 **CogView3** 以及 **CogView-3Plus-3B** 。**CogView3** 是一个基于级联扩散的文本生成图像系统,采用了接力扩散框架。 32 | **CogView-3Plus** 是一系列新开发的基 Diffusion Transformer 的文本生成图像模型。 33 | 34 | ## 项目计划 35 | 36 | - [X] diffusers 工作流适配 37 | - [X] Cog系列微调套件 38 | - [ ] ControlNet模型和训练代码 39 | 40 | ## 社区工作 41 | 42 | 我们将一些和本仓库相关的社区工作收录在这里。这些代码由社区成员维护,我们感谢他们的贡献。 43 | 44 | + [ComfyUI_CogView4_Wrapper](https://github.com/chflame163/ComfyUI_CogView4_Wrapper) ComfyUI 中 CogView4 项目的实现。 45 | 46 | ## 模型介绍 47 | 48 | ### 模型对比 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 |
| 模型名称 | CogView4 | CogView3-Plus-3B |
|---------|----------|------------------|
| 分辨率 | 512 <= H, W <= 2048; H * W <= 2^21; H, W mod 32 = 0(两个模型相同) | |
| 推理精度 | 仅支持 BF16, FP32(两个模型相同) | |
| 编码器 | GLM-4-9B | T5-XXL |
| 提示词语言 | 中文,English | English |
| 提示词长度上限 | 1024 Tokens | 224 Tokens |
| 下载链接 | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel | 🤗 HuggingFace / 🤖 ModelScope / 🟣 WiseModel |
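上表中的分辨率约束也可以在调用前用几行代码对齐。下面是一个示意性的小例子(函数名为演示而虚构,并非本仓库代码);它只对齐单边尺寸,H * W <= 2^21 的面积约束仍需另行检查:

```python
def snap_to_valid(size: int) -> int:
    # 将单边尺寸裁剪到 [512, 2048] 区间,并向下对齐到 32 的倍数(仅作示意)
    size = max(512, min(2048, size))
    return (size // 32) * 32


print(snap_to_valid(720))   # 704
print(snap_to_valid(3000))  # 2048
```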
87 | 88 | ### 显存占用 89 | 90 | DIT模型均使用 `BF16` 精度, `batchsize=4` 进行测试,测试结果如下表所示: 91 | 92 | | 分辨率 | enable_model_cpu_offload OFF | enable_model_cpu_offload ON | enable_model_cpu_offload ON
Text Encoder 4bit | 93 | |-------------|------------------------------|-----------------------------|-----------------------------------------------------| 94 | | 512 * 512 | 33GB | 20GB | 13G | 95 | | 1280 * 720 | 35GB | 20GB | 13G | 96 | | 1024 * 1024 | 35GB | 20GB | 13G | 97 | | 1920 * 1280 | 39GB | 20GB | 14G | 98 | 99 | 此外, 建议您的设备至少拥有`32GB`内存,以防止进程被杀。 100 | 101 | ### 模型指标 102 | 103 | 我们在多个榜单上进行了测试, 并得到了如下的成绩: 104 | 105 | #### DPG-Bench 106 | 107 | | Model | Overall | Global | Entity | Attribute | Relation | Other | 108 | |-----------------|-----------|-----------|-----------|-----------|-----------|-----------| 109 | | SDXL | 74.65 | 83.27 | 82.43 | 80.91 | 86.76 | 80.41 | 110 | | PixArt-alpha | 71.11 | 74.97 | 79.32 | 78.60 | 82.57 | 76.96 | 111 | | SD3-Medium | 84.08 | 87.90 | **91.01** | 88.83 | 80.70 | 88.68 | 112 | | DALL-E 3 | 83.50 | **90.97** | 89.61 | 88.39 | 90.58 | 89.83 | 113 | | Flux.1-dev | 83.79 | 85.80 | 86.79 | 89.98 | 90.04 | **89.90** | 114 | | Janus-Pro-7B | 84.19 | 86.90 | 88.90 | 89.40 | 89.32 | 89.48 | 115 | | **CogView4-6B** | **85.13** | 83.85 | 90.35 | **91.17** | **91.14** | 87.29 | 116 | 117 | #### GenEval 118 | 119 | | Model | Overall | Single Obj. | Two Obj. | Counting | Colors | Position | Color attribution | 120 | |-----------------|----------|-------------|----------|----------|----------|----------|-------------------| 121 | | SDXL | 0.55 | 0.98 | 0.74 | 0.39 | 0.85 | 0.15 | 0.23 | 122 | | PixArt-alpha | 0.48 | 0.98 | 0.50 | 0.44 | 0.80 | 0.08 | 0.07 | 123 | | SD3-Medium | 0.74 | **0.99** | **0.94** | 0.72 | 0.89 | 0.33 | 0.60 | 124 | | DALL-E 3 | 0.67 | 0.96 | 0.87 | 0.47 | 0.83 | 0.43 | 0.45 | 125 | | Flux.1-dev | 0.66 | 0.98 | 0.79 | **0.73** | 0.77 | 0.22 | 0.45 | 126 | | Janus-Pro-7B | **0.80** | **0.99** | 0.89 | 0.59 | **0.90** | **0.79** | **0.66** | 127 | | **CogView4-6B** | 0.73 | **0.99** | 0.86 | 0.66 | 0.79 | 0.48 | 0.58 | 128 | 129 | #### T2I-CompBench 130 | 131 | | Model | Color | Shape | Texture | 2D-Spatial | 3D-Spatial | Numeracy | Non-spatial Clip | Complex 3-in-1 | 132 | |-----------------|------------|------------|------------|------------|------------|------------|------------------|----------------| 133 | | SDXL | 0.5879 | 0.4687 | 0.5299 | 0.2133 | 0.3566 | 0.4988 | 0.3119 | 0.3237 | 134 | | PixArt-alpha | 0.6690 | 0.4927 | 0.6477 | 0.2064 | 0.3901 | 0.5058 | **0.3197** | 0.3433 | 135 | | SD3-Medium | **0.8132** | 0.5885 | **0.7334** | **0.3200** | **0.4084** | 0.6174 | 0.3140 | 0.3771 | 136 | | DALL-E 3 | 0.7785 | **0.6205** | 0.7036 | 0.2865 | 0.3744 | 0.5880 | 0.3003 | 0.3773 | 137 | | Flux.1-dev | 0.7572 | 0.5066 | 0.6300 | 0.2700 | 0.3992 | 0.6165 | 0.3065 | 0.3628 | 138 | | Janus-Pro-7B | 0.5145 | 0.3323 | 0.4069 | 0.1566 | 0.2753 | 0.4406 | 0.3137 | 0.3806 | 139 | | **CogView4-6B** | 0.7786 | 0.5880 | 0.6983 | 0.3075 | 0.3708 | **0.6626** | 0.3056 | **0.3869** | 140 | 141 | ## 中文文字准确率评测 142 | 143 | | Model | Precision | Recall | F1 Score | Pick@4 | 144 | |-----------------|------------|------------|------------|------------| 145 | | Kolors | 0.6094 | 0.1886 | 0.2880 | 0.1633 | 146 | | **CogView4-6B** | **0.6969** | **0.5532** | **0.6168** | **0.3265** | 147 | 148 | ## 推理模型 149 | 150 | ### 提示词优化 151 | 152 | 虽然 CogView4 系列模型都是通过长篇合成图像描述进行训练的,但我们强烈建议在文本生成图像之前,基于大语言模型进行提示词的重写操作,这将大大提高生成质量。 153 | 154 | 我们提供了一个 [示例脚本](inference/prompt_optimize.py)。我们建议您运行这个脚本,以实现对提示词对润色。请注意,`CogView4` 和 155 | `CogView3` 模型的提示词优化使用的few shot不同。需要区分。 156 | 157 | ```shell 158 | cd inference 159 | python prompt_optimize.py --api_key "智谱AI API Key" 
--prompt {你的提示词} --base_url "https://open.bigmodel.cn/api/paas/v4" --model "glm-4-plus" --cogview_version "cogview4" 160 | ``` 161 | 162 | ### 推理模型 163 | 164 | 以 `BF16` 的精度运行`CogView4-6B`模型: 165 | 166 | ```python 167 | from diffusers import CogView4Pipeline 168 | import torch 169 | 170 | pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B", torch_dtype=torch.bfloat16).to("cuda") 171 | 172 | # Open it for reduce GPU memory usage 173 | pipe.enable_model_cpu_offload() 174 | pipe.vae.enable_slicing() 175 | pipe.vae.enable_tiling() 176 | 177 | prompt = "A vibrant cherry red sports car sits proudly under the gleaming sun, its polished exterior smooth and flawless, casting a mirror-like reflection. The car features a low, aerodynamic body, angular headlights that gaze forward like predatory eyes, and a set of black, high-gloss racing rims that contrast starkly with the red. A subtle hint of chrome embellishes the grille and exhaust, while the tinted windows suggest a luxurious and private interior. The scene conveys a sense of speed and elegance, the car appearing as if it's about to burst into a sprint along a coastal road, with the ocean's azure waves crashing in the background." 178 | image = pipe( 179 | prompt=prompt, 180 | guidance_scale=3.5, 181 | num_images_per_prompt=1, 182 | num_inference_steps=50, 183 | width=1024, 184 | height=1024, 185 | ).images[0] 186 | 187 | image.save("cogview4.png") 188 | ``` 189 | 190 | 更多推理代码,可以参考: 191 | 192 | 1. 用 `BNB int4` 加载 `text encoder` 代码,参考[这里](inference/cli_demo_cogview4.py)。 193 | 2. 用 `TorchAO int8 or int4` 加载 `text encoder & transformer` 代码,参考[这里](inference/cli_demo_cogview4_int8.py)。 194 | 3. 使用 `gradio` 界面运行`CogView4-6B-Control`, 参考[这里](inference/gradio_web_demo.py)。 195 | 196 | 197 | ## 微调模型 198 | 199 | 本仓库没有存放微调代码,你可以通过两个方案进行微调,包括 Lora 和 SFT。 200 | 201 | 1. [CogKit](https://github.com/THUDM/CogKit), 由我们提出的系统微调框架,支持 CogView4,CogVideoX 微调,由我们进行维护。 202 | 2. [finetrainers](https://github.com/a-r-r-o-w/finetrainers), 框架采用低显存的解决方案,在4090上即可进行微调。 203 | 3. 如果你想直接训练 ControlNet模型,可以参考 [训练代码](https://github.com/huggingface/diffusers/tree/main/examples/cogview4-control) 自行训练。 204 | 205 | 206 | ## 开源协议 207 | 208 | 本仓库代码和 CogView3 模型均采用 [Apache 2.0](LICENSE) 开源协议。 209 | 210 | 我们欢迎和感谢你贡献代码,你可以在 [这里](resources/contribute.md) 查看贡献指南。 211 | -------------------------------------------------------------------------------- /inference/cli_demo_cogview3.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running the Script: 3 | To run the script, use the following command with appropriate arguments: 4 | 5 | ```bash 6 | python cli_demo.py --prompt "A beautiful sunset over a mountain" --width 1024 --height 1024 7 | ``` 8 | 9 | Additional options are available to specify the model path, guidance scale, number of inference steps, image generation type, and output paths. 
10 | """ 11 | 12 | import argparse 13 | 14 | import torch 15 | from diffusers import CogView3PlusPipeline 16 | 17 | 18 | def generate_image( 19 | prompt, model_path, guidance_scale, num_images_per_prompt, num_inference_steps, width, height, output_path, dtype 20 | ): 21 | # Load the pre-trained model with the specified precision 22 | pipe = CogView3PlusPipeline.from_pretrained(model_path, torch_dtype=dtype) 23 | 24 | # Enable CPU offloading to free up GPU memory when layers are not actively being used 25 | pipe.enable_model_cpu_offload() 26 | 27 | # Enable VAE slicing and tiling for memory optimization 28 | pipe.vae.enable_slicing() 29 | pipe.vae.enable_tiling() 30 | 31 | # Generate the image based on the prompt 32 | image = pipe( 33 | prompt=prompt, 34 | guidance_scale=guidance_scale, 35 | num_images_per_prompt=num_images_per_prompt, 36 | num_inference_steps=num_inference_steps, 37 | width=width, 38 | height=height, 39 | ).images[0] 40 | 41 | # Save the generated image to the local file system 42 | image.save(output_path) 43 | 44 | print(f"Image saved to {output_path}") 45 | 46 | 47 | if __name__ == "__main__": 48 | parser = argparse.ArgumentParser(description="Generate an image using the CogView3-Plus-3B model.") 49 | 50 | # Define arguments for prompt, model path, etc. 51 | parser.add_argument("--prompt", type=str, required=True, help="The text description for generating the image.") 52 | parser.add_argument( 53 | "--model_path", type=str, default="THUDM/CogView3-Plus-3B", help="Path to the pre-trained model." 54 | ) 55 | parser.add_argument( 56 | "--guidance_scale", type=float, default=7.0, help="The guidance scale for classifier-free guidance." 57 | ) 58 | parser.add_argument( 59 | "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt." 60 | ) 61 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.") 62 | parser.add_argument("--width", type=int, default=1024, help="Width of the generated image.") 63 | parser.add_argument("--height", type=int, default=1024, help="Height of the generated image.") 64 | parser.add_argument("--output_path", type=str, default="cogview3.png", help="Path to save the generated image.") 65 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Precision type (float16 or bfloat16).") 66 | 67 | # Parse the arguments 68 | args = parser.parse_args() 69 | 70 | # Convert dtype argument to torch dtype 71 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float16 72 | 73 | # Call the function to generate the image 74 | generate_image( 75 | prompt=args.prompt, 76 | model_path=args.model_path, 77 | guidance_scale=args.guidance_scale, 78 | num_images_per_prompt=args.num_images_per_prompt, 79 | num_inference_steps=args.num_inference_steps, 80 | width=args.width, 81 | height=args.height, 82 | output_path=args.output_path, 83 | dtype=dtype, 84 | ) 85 | -------------------------------------------------------------------------------- /inference/cli_demo_cogview4.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running the Script: 3 | 4 | ```bash 5 | python generate_image.py --prompt "A serene landscape with mountains and a river" --width 1280 --height 720 6 | ``` 7 | 8 | Additional arguments include: 9 | • --model_path: Path to the pre-trained model (e.g., THUDM/CogView4-6B). 10 | • --guidance_scale: The classifier-free guidance scale for enhancing the generated image quality. 
11 | • --num_images_per_prompt: Number of images to generate for each prompt. 12 | • --num_inference_steps: The number of denoising steps during inference. 13 | • --output_path: Path to save the generated image. 14 | • --dtype: Precision type for inference (either bfloat16 or float32,float16 will cause error with NaN). 15 | 16 | This version focuses only on the relevant arguments and features of the script, without additional details about non-existent fields. 17 | 18 | """ 19 | 20 | import argparse 21 | 22 | import torch 23 | from diffusers import CogView4Pipeline 24 | 25 | 26 | def generate_image( 27 | prompt, model_path, guidance_scale, num_images_per_prompt, num_inference_steps, width, height, output_path, dtype 28 | ): 29 | # Load the pre-trained model with the specified precision 30 | # text_encoder = GlmModel.from_pretrained(model_path + "/text_encoder", quantization_config=BitsAndBytesConfig(load_in_4bit=True), torch_dtype=dtype) 31 | # pipe = CogView4Pipeline.from_pretrained(model_path, text_encoder=text_encoder, torch_dtype=dtype) 32 | pipe = CogView4Pipeline.from_pretrained(model_path, torch_dtype=dtype) 33 | 34 | # pipe.to("cuda") 35 | pipe.enable_model_cpu_offload() 36 | pipe.vae.enable_slicing() 37 | pipe.vae.enable_tiling() 38 | # Generate the image based on the prompt 39 | image = pipe( 40 | prompt=prompt, 41 | guidance_scale=guidance_scale, 42 | num_images_per_prompt=num_images_per_prompt, 43 | num_inference_steps=num_inference_steps, 44 | width=width, 45 | height=height, 46 | ).images[0] 47 | 48 | # Save the generated image to the local file system 49 | image.save(output_path) 50 | 51 | print(f"Image saved to {output_path}") 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser(description="Generate an image using the CogView4 model.") 56 | 57 | # Define arguments for prompt, model path, etc. 58 | parser.add_argument("--prompt", type=str, required=True, help="The text description for generating the image.") 59 | parser.add_argument("--model_path", type=str, default="THUDM/CogView4-6B", help="Path to the pre-trained model.") 60 | parser.add_argument( 61 | "--guidance_scale", type=float, default=3.5, help="The guidance scale for classifier-free guidance." 62 | ) 63 | parser.add_argument( 64 | "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt." 
65 | ) 66 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.") 67 | parser.add_argument("--width", type=int, default=1280, help="Width of the generated image.") 68 | parser.add_argument("--height", type=int, default=720, help="Height of the generated image.") 69 | parser.add_argument("--output_path", type=str, default="cogview4.png", help="Path to save the generated image.") 70 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Precision type (float16 or float32).") 71 | 72 | # Parse the arguments 73 | args = parser.parse_args() 74 | 75 | # Convert dtype argument to torch dtype 76 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float32 77 | 78 | # Call the function to generate the image 79 | generate_image( 80 | prompt=args.prompt, 81 | model_path=args.model_path, 82 | guidance_scale=args.guidance_scale, 83 | num_images_per_prompt=args.num_images_per_prompt, 84 | num_inference_steps=args.num_inference_steps, 85 | width=args.width, 86 | height=args.height, 87 | output_path=args.output_path, 88 | dtype=dtype, 89 | ) 90 | -------------------------------------------------------------------------------- /inference/cli_demo_cogview4_int8.py: -------------------------------------------------------------------------------- 1 | """ 2 | Running the Script: 3 | 4 | ```bash 5 | python generate_image.py --prompt "A serene landscape with mountains and a river" --width 1280 --height 720 6 | ``` 7 | 8 | Additional arguments include: 9 | • --model_path: Path to the pre-trained model (e.g., THUDM/CogView4-6B). 10 | • --guidance_scale: The classifier-free guidance scale for enhancing the generated image quality. 11 | • --num_images_per_prompt: Number of images to generate for each prompt. 12 | • --num_inference_steps: The number of denoising steps during inference. 13 | • --output_path: Path to save the generated image. 14 | • --dtype: Precision type for inference (either bfloat16 or float32,float16 will cause error with NaN). 15 | 16 | This version focuses only on the relevant arguments and features of the script, without additional details about non-existent fields. 
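Note: the TorchAO path sketched in the comments inside `generate_image` below assumes `torchao` is installed and, as in `gradio_web_demo.py`, the additional imports `from transformers import GlmModel`, `from diffusers.models import CogView4Transformer2DModel`, and `from torchao.quantization import int8_weight_only, quantize_` before it can be uncommented.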
17 | 18 | """ 19 | 20 | import argparse 21 | 22 | import torch 23 | from diffusers import CogView4Pipeline 24 | 25 | 26 | def generate_image( 27 | prompt, model_path, guidance_scale, num_images_per_prompt, num_inference_steps, width, height, output_path, dtype 28 | ): 29 | # Load the pre-trained model with int8 30 | """text_encoder = GlmModel.from_pretrained( 31 | model_path + "/text_encoder", 32 | torch_dtype=dtype 33 | ) 34 | quantize_(text_encoder, int8_weight_only()) 35 | transformer = CogView4Transformer2DModel.from_pretrained( 36 | model_path + "/transformer", 37 | torch_dtype=dtype 38 | ) 39 | quantize_(transformer, int8_weight_only()) 40 | pipe = CogView4Pipeline.from_pretrained( 41 | model_path, 42 | text_encoder=text_encoder, 43 | transformer=transformer, 44 | torch_dtype=dtype, 45 | ).to("cuda")""" 46 | pipe = CogView4Pipeline.from_pretrained(model_path, torch_dtype=dtype).to("cuda") 47 | 48 | pipe.enable_model_cpu_offload() 49 | pipe.vae.enable_slicing() 50 | pipe.vae.enable_tiling() 51 | # Generate the image based on the prompt 52 | image = pipe( 53 | prompt=prompt, 54 | guidance_scale=guidance_scale, 55 | num_images_per_prompt=num_images_per_prompt, 56 | num_inference_steps=num_inference_steps, 57 | width=width, 58 | height=height, 59 | ).images[0] 60 | 61 | # Save the generated image to the local file system 62 | image.save(output_path) 63 | 64 | print(f"Image saved to {output_path}") 65 | 66 | 67 | if __name__ == "__main__": 68 | parser = argparse.ArgumentParser(description="Generate an image using the CogView4 model.") 69 | 70 | # Define arguments for prompt, model path, etc. 71 | parser.add_argument("--prompt", type=str, required=True, help="The text description for generating the image.") 72 | parser.add_argument("--model_path", type=str, default="THUDM/CogView4-6B", help="Path to the pre-trained model.") 73 | parser.add_argument( 74 | "--guidance_scale", type=float, default=3.5, help="The guidance scale for classifier-free guidance." 75 | ) 76 | parser.add_argument( 77 | "--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt." 
78 | ) 79 | parser.add_argument("--num_inference_steps", type=int, default=50, help="Number of denoising steps for inference.") 80 | parser.add_argument("--width", type=int, default=1280, help="Width of the generated image.") 81 | parser.add_argument("--height", type=int, default=720, help="Height of the generated image.") 82 | parser.add_argument("--output_path", type=str, default="cogview4.png", help="Path to save the generated image.") 83 | parser.add_argument("--dtype", type=str, default="bfloat16", help="Precision type (float16 or float32).") 84 | 85 | # Parse the arguments 86 | args = parser.parse_args() 87 | 88 | # Convert dtype argument to torch dtype 89 | dtype = torch.bfloat16 if args.dtype == "bfloat16" else torch.float32 90 | 91 | # Call the function to generate the image 92 | generate_image( 93 | prompt=args.prompt, 94 | model_path=args.model_path, 95 | guidance_scale=args.guidance_scale, 96 | num_images_per_prompt=args.num_images_per_prompt, 97 | num_inference_steps=args.num_inference_steps, 98 | width=args.width, 99 | height=args.height, 100 | output_path=args.output_path, 101 | dtype=dtype, 102 | ) 103 | -------------------------------------------------------------------------------- /inference/gradio_web_demo.py: -------------------------------------------------------------------------------- 1 | """ 2 | This script demonstrates how to generate an image using the CogView4-6B model within the Hugging Face Space interface. Simply interact with the Gradio interface hosted on Hugging Face CogView4 Demo at [CogView4-6B Hugging Face Space](https://huggingface.co/spaces/THUDM-HF-SPACE/CogView4) 3 | 4 | Running the Script: 5 | To run the script, use the following command with appropriate arguments: 6 | 7 | ```bash 8 | OPENAI_API_KEY="your ZhipuAI API keys" OPENAI_BASE_URL="https://open.bigmodel.cn/api/paas/v4" python gradio_web_demo.py 9 | ``` 10 | 11 | We use [glm-4-plus](https://bigmodel.cn/dev/howuse/glm-4) as the large model for prompt refinement. 
You can also choose other large models, such as GPT-4o, for refinement.
12 | 
13 | For Different GPU Memory Usage:
14 | 
15 | 12G VRAM
16 | ```
17 | MODE=1 OPENAI_API_KEY="your ZhipuAI API keys" OPENAI_BASE_URL="https://open.bigmodel.cn/api/paas/v4" python gradio_web_demo.py
18 | ```
19 | 24G VRAM 32G RAM
20 | ```
21 | MODE=2 OPENAI_API_KEY="your ZhipuAI API keys" OPENAI_BASE_URL="https://open.bigmodel.cn/api/paas/v4" python gradio_web_demo.py
22 | ```
23 | 24G VRAM 64G RAM
24 | ```
25 | MODE=3 OPENAI_API_KEY="your ZhipuAI API keys" OPENAI_BASE_URL="https://open.bigmodel.cn/api/paas/v4" python gradio_web_demo.py
26 | ```
27 | 40G VRAM 64G RAM and Larger
28 | ```
29 | OPENAI_API_KEY="your ZhipuAI API keys" OPENAI_BASE_URL="https://open.bigmodel.cn/api/paas/v4" python gradio_web_demo.py
30 | ```
31 | """
32 | 
33 | import gc
34 | import os
35 | import random
36 | import re
37 | import threading
38 | import time
39 | from datetime import datetime, timedelta
40 | 
41 | import gradio as gr
42 | import torch
43 | from diffusers import CogView4Pipeline
44 | from diffusers.models import CogView4Transformer2DModel
45 | from openai import OpenAI
46 | from torchao.quantization import int8_weight_only, quantize_
47 | from transformers import GlmModel
48 | 
49 | 
50 | total_vram_in_gb = torch.cuda.get_device_properties(0).total_memory / 1073741824  # bytes -> GiB
51 | 
52 | print(f"\033[32mCUDA版本:{torch.version.cuda}\033[0m")
53 | print(f"\033[32mPytorch版本:{torch.__version__}\033[0m")
54 | print(f"\033[32m显卡型号:{torch.cuda.get_device_name()}\033[0m")
55 | print(f"\033[32m显存大小:{total_vram_in_gb:.2f}GB\033[0m")
56 | 
57 | if torch.cuda.get_device_capability()[0] >= 8:
58 |     print("\033[32m支持BF16\033[0m")
59 |     dtype = torch.bfloat16
60 | else:
61 |     print("\033[32m不支持BF16,使用FP16\033[0m")
62 |     dtype = torch.float16
63 | 
64 | device = "cuda" if torch.cuda.is_available() else "cpu"
65 | model_path = "THUDM/CogView4-6B"
66 | mode = os.environ.get("MODE", "0")
67 | 
68 | text_encoder = None
69 | transformer = None
70 | if mode in ["1", "2"]:
71 |     text_encoder = GlmModel.from_pretrained(model_path, subfolder="text_encoder", torch_dtype=dtype)
72 |     transformer = CogView4Transformer2DModel.from_pretrained(model_path, subfolder="transformer", torch_dtype=dtype)
73 |     quantize_(text_encoder, int8_weight_only())
74 |     quantize_(transformer, int8_weight_only())
75 | 
76 | pipe = CogView4Pipeline.from_pretrained(
77 |     model_path,
78 |     text_encoder=text_encoder,
79 |     transformer=transformer,
80 |     torch_dtype=dtype,
81 | ).to(device)
82 | 
83 | if mode in ["1", "3"]:
84 |     pipe.enable_model_cpu_offload()
85 | 
86 | pipe.vae.enable_slicing()
87 | pipe.vae.enable_tiling()
88 | 
89 | 
90 | def clean_string(s):
91 |     s = s.replace("\n", " ")
92 |     s = s.strip()
93 |     s = re.sub(r"\s{2,}", " ", s)
94 |     return s
95 | 
96 | 
97 | def convert_prompt(
98 |     prompt: str,
99 |     key: str,
100 |     retry_times: int = 5,
101 | ) -> str:
102 |     if not key:
103 |         return prompt
104 |     os.environ["OPENAI_API_KEY"] = key
105 |     client = OpenAI()
106 |     prompt = clean_string(prompt)
107 |     for i in range(retry_times):
108 |         try:
109 |             response = client.chat.completions.create(
110 |                 messages=[
111 |                     {
112 |                         "role": "system",
113 |                         "content": 'You are a bilingual image description assistant that works with an image generation bot. You work with an assistant bot that will draw anything you say.\n For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" or "阳光透过树林的美丽清晨" will trigger your partner bot to output an image of a forest morning, as described.\n You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.\n There are a few rules to follow:\n - Input can be in Chinese or English. If input is in English, prompt should be written in English. If input is in Chinese, prompt should be written in Chinese.\n - You will only ever output a single image description per user request.\n - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting.\n - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.\n - Do not provide the process and explanation, just return the modified description.\n ',
114 |                     },
115 |                     {
116 |                         "role": "user",
117 |                         "content": 'Create an imaginative image descriptive caption for the user input : "An anime girl stands amidst a dense flower bush."',
118 |                     },
119 |                     {
120 |                         "role": "assistant",
121 |                         "content": "This image is a beautifully crafted digital illustration in an anime style. It features Suzuran, a character from the video game Arknights. She stands gracefully amidst a picturesque meadow with lush green grass and scattered wildflowers that gently sway in the breeze. Her attire includes a detailed outfit with layered ruffles and intricate fastenings, reflecting both elegance and functionality typical of the game's aesthetic. Suzuran holds a small bouquet of flowers delicately in her hands, adding to the serene atmosphere.\n\nProminent fox ears extend from her head, signifying her animal spirit form or persona within the game. The background showcases rolling hills covered in dense foliage under a brilliant blue sky dotted with fluffy white clouds. Golden leaves float whimsically through the air, enhancing the magical quality of the scene. A distant building partially obscured by vegetation hints at human presence without detracting from the natural beauty surrounding Suzuran. Overall, this artwork captures a moment of tranquility and charm, embodying the essence of the character and the enchanting world she inhabits.",
122 |                     },
123 |                     {
124 |                         "role": "user",
125 |                         "content": 'Create an imaginative image descriptive caption for the user input : "Draw a bright convertible car with a sense of artistic design."',
126 |                     },
127 |                     {
128 |                         "role": "assistant",
129 |                         "content": 'The image showcases a meticulously crafted Delahaye 135 M.S. Special Roadster from 1937 in a highly polished and realistic rendering style that highlights its luxurious design and impeccable details. The car\'s body is an elegant deep brown with a glossy finish, exuding sophistication and timeless beauty. Its aerodynamic, streamlined shape features smooth curves accentuated by chrome detailing on the fenders and running boards. The front grille has three prominent circular headlights, adding to its classic charm while the dual exhaust pipes are tastefully integrated into the rear fenders.\n\nThe open cockpit area reveals sumptuous tan leather seats, emphasizing both comfort and elegance for the driver and passenger. The vehicle’s number plate reads "MB 175". Photographed against a dark gradient background, the focus remains solely on this automotive masterpiece, allowing viewers to appreciate every intricate detail without distraction. This image captures not just a car but a piece of art, reflecting the luxury and craftsmanship emblematic of pre-war European automobiles.',
130 |                     },
131 |                     {
132 |                         "role": "user",
133 |                         "content": 'Create an imaginative image descriptive caption for the user input : "画一个白发、神情忧郁的动漫女孩,手里拿着一支香烟"',
134 |                     },
135 |                     {
136 |                         "role": "assistant",
137 |                         "content": "这幅图像是一幅动漫风格的插画,画中描绘了一位长发飘逸的白发女孩。她神情忧郁,双眼低垂,脸上带着微微的愁容。女孩穿着浅色外套,里面搭配深色衬衫和领带,增添了她沉静却时尚的外表。她的手靠近嘴边,似乎在拿着一支香烟,画作的说明也表示她在吸烟。背景描绘了一个冬季的城市场景,地面和建筑物上覆盖着积雪。几株稀疏的植物从雪地中探出,街边停着一辆车,似乎是空荡荡的街道,增强了场景的荒凉氛围。这幅作品整体风格细致精美,典型的现代动漫插画风格,既捕捉了角色的情感表达,也呈现了周围环境的静谧氛围",
138 |                     },
139 |                     {
140 |                         "role": "user",
141 |                         "content": 'Create an imaginative image descriptive caption for the user input : "一张红色的海报,中间写有“开门大吉”"',
142 |                     },
143 |                     {
144 |                         "role": "assistant",
145 |                         "content": "这张图片采用了海报风格,色彩鲜艳,主要以红色和金色为主,寓意吉祥如意。在画面的正中间是一块红色的大匾,上面用白色大字写着“开门大吉”,四角点缀着金色的装饰图案,显得格外喜庆。匾额上方悬挂着一对红色的灯笼,增添了节日的气氛。背景左右两侧是传统的绿色中式建筑,屋顶呈现出典型的飞檐翘角设计。底部有祥云朵朵和可爱的卡通福袋,象征着好运福气满满。整张海报传达出浓厚的节日氛围。",
146 |                     },
147 |                     {
148 |                         "role": "user",
149 |                         "content": f"Create an imaginative image descriptive caption for the user input : {prompt}",
150 |                     },
151 |                 ],
152 |                 model="glm-4-flash",
153 |                 temperature=0.01,
154 |                 top_p=0.7,
155 |                 stream=False,
156 |                 max_tokens=300,
157 |             )
158 |             prompt = response.choices[0].message.content
159 |             if prompt:
160 |                 prompt = clean_string(prompt)
161 |                 break
162 |         except Exception:
163 |             # Ignore transient API errors and retry; fall back to the original prompt if all attempts fail.
164 |             pass
165 | 
166 |     return prompt
167 | 
168 | 
169 | def delete_old_files():
170 |     while True:
171 |         now = datetime.now()
172 |         cutoff = now - timedelta(minutes=5)
173 |         os.makedirs("./gradio_tmp", exist_ok=True)
174 |         directories = ["./gradio_tmp"]
175 |         for directory in directories:
176 |             for filename in os.listdir(directory):
177 |                 file_path = os.path.join(directory, filename)
178 |                 if os.path.isfile(file_path):
179 |                     file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
180 |                     if file_mtime < cutoff:
181 |                         os.remove(file_path)
182 |         time.sleep(600)
183 | 
184 | 
185 | threading.Thread(target=delete_old_files, daemon=True).start()
186 | 
187 | 
188 | def infer(
189 |     prompt,
190 |     seed,
191 |     randomize_seed,
192 |     width,
193 |     height,
194 |     guidance_scale,
195 |     num_inference_steps,
196 |     num_images,
197 |     progress=gr.Progress(track_tqdm=True),
198 | ):
199 |     gc.collect()
200 |     if torch.cuda.is_available():
201 |         torch.cuda.empty_cache()
202 |         torch.cuda.ipc_collect()
203 | 
204 |     if randomize_seed:
205 |         seed = random.randint(0, 65536)
206 | 
207 |     images = pipe(
208 |         prompt=prompt,
209 |         guidance_scale=guidance_scale,
210 |         num_images_per_prompt=num_images,
211 |         num_inference_steps=num_inference_steps,
212 |         width=width,
213 |         height=height,
214 |         generator=torch.Generator().manual_seed(seed),
215 |     ).images
216 | 
217 |     return images, seed
218 | 
219 | 
220 | def update_max_height(width):
221 |     max_height = MAX_PIXELS // width
222 |     return gr.update(maximum=max_height)
223 | 
224 | 
225 | def update_max_width(height):
226 |     max_width = MAX_PIXELS // height
227 |     return gr.update(maximum=max_width)
228 | 
229 | 
230 | examples = [
231 |     "这是一幅充满皮克斯风格的动画渲染图像,展现了一只拟人化的粘土风格小蛇。这条快乐的小蛇身着魔术师装扮,占据了画面下方三分之一的位置,显得俏皮而生动。它的头上戴着一顶黑色羊毛材质的复古礼帽,身上穿着一件设计独特的红色棉袄,白色的毛袖增添了一抹温暖的对比。小蛇的鳞片上精心绘制了金色梅花花纹,显得既华丽又不失可爱。它的腹部和脸庞呈现洁白,与红色的身体形成鲜明对比。 这条蜿蜒的小蛇拥有可爱的塑胶手办质感,仿佛随时会从画面中跃然而出。背景是一片鲜艳的红色,地面上散布着宝箱、金蛋和红色灯笼等装饰物,营造出浓厚的节日气氛。画面的上半部分用金色连体字书写着 “Happy New Year”,庆祝新年的到来,同时也暗示了蛇年的到来,为整幅画面增添了一份节日的喜悦和祥瑞。",
232 |     "在这幅如梦似幻的画作中,一辆由云朵构成的毛绒汽车轻盈地漂浮在蔚蓝的高空之中。这辆汽车设计独特,车身完全由洁白、蓬松的云朵编织而成,每一处都散发着柔软而毛茸茸的质感。从车顶到轮胎,再到它的圆润车灯,无一不是由细腻的云丝构成,仿佛随时都可能随风轻轻摆动。车窗也是由透明的云物质构成,同样覆盖着一层细软的绒毛,让人不禁想要触摸。 这辆神奇的云朵汽车仿佛是魔法世界中的交通工具,它悬浮在夕阳映照的绚丽天空之中,周围是五彩斑斓的晚霞和悠然飘浮的云彩。夕阳的余晖洒在云朵车上,为其柔软的轮廓镀上了一层金色的光辉,使得整个场景既温馨又神秘,引人入胜。",
233 |     "A vintage red convertible with gleaming chrome finishes sits attractively under the golden hues of a setting sun, parked on a deserted cobblestone street in a charming old town. The car's polished body reflects the surrounding quaint buildings and the few early evening stars beginning to twinkle in the gentle gradient of the twilight sky. A light breeze teases the few fallen leaves near the car's pristine white-walled tires, which rest casually by the sidewalk, hinting at the leisurely pace of life in this serene setting.",
234 | ]
235 | with gr.Blocks(theme=gr.themes.Soft()) as demo:
236 |     gr.Markdown("""
237 | # CogView4-6B
238 | 
239 | 
240 | [🌐 Github](https://github.com/THUDM/CogView4) |
241 | 📜 arXiv
242 | 
243 | 
244 | ⚠️ 该演示仅供学术研究和体验使用。 (This demo is for academic research and experience only.)
245 | 
246 | 
247 |     """)
248 | 
249 |     with gr.Column():
250 |         with gr.Row():
251 |             with gr.Column():
252 |                 with gr.Row():
253 |                     prompt = gr.Text(
254 |                         label="Prompt",
255 |                         show_label=False,
256 |                         max_lines=15,
257 |                         placeholder="Enter your prompt",
258 |                         container=False,
259 |                     )
260 |                 with gr.Row():
261 |                     enhance = gr.Button("Enhance Prompt (Strongly Suggested)", scale=1)
262 |                     run_button = gr.Button("Run", scale=1)
263 |                 with gr.Row():
264 |                     num_images = gr.Number(
265 |                         label="Number of Images",
266 |                         minimum=1,
267 |                         maximum=8,
268 |                         step=1,
269 |                         value=2,
270 |                     )
271 |                     key = gr.Textbox(
272 |                         label="API Key",
273 |                         placeholder="Enter your API key",
274 |                         type="password",
275 |                         max_lines=1,
276 |                     )
277 |                 with gr.Row():
278 |                     seed = gr.Slider(
279 |                         label="Seed",
280 |                         minimum=0,
281 |                         maximum=65536,
282 |                         step=1,
283 |                         value=0,
284 |                     )
285 |                     randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
286 |                 with gr.Row():
287 |                     width = gr.Slider(
288 |                         label="Width",
289 |                         minimum=512,
290 |                         maximum=2048,
291 |                         step=32,
292 |                         value=1024,
293 |                     )
294 |                     height = gr.Slider(
295 |                         label="Height",
296 |                         minimum=512,
297 |                         maximum=2048,
298 |                         step=32,
299 |                         value=1024,
300 |                     )
301 |                 with gr.Row():
302 |                     guidance_scale = gr.Slider(
303 |                         label="Guidance scale",
304 |                         minimum=0.0,
305 |                         maximum=10.0,
306 |                         step=0.1,
307 |                         value=3.5,
308 |                     )
309 |                     num_inference_steps = gr.Slider(
310 |                         label="Number of inference steps",
311 |                         minimum=10,
312 |                         maximum=100,
313 |                         step=1,
314 |                         value=50,
315 |                     )
316 |             with gr.Column():
317 |                 result = gr.Gallery(label="Results", show_label=True)
318 | 
319 |     MAX_PIXELS = 2**21  # cap width * height at 2,097,152 pixels
320 |     enhance.click(convert_prompt, inputs=[prompt, key], outputs=[prompt])
321 |     width.change(update_max_height, inputs=[width], outputs=[height])
322 |     height.change(update_max_width, inputs=[height], outputs=[width])
323 | 
324 |     with gr.Column():
325 |         gr.Markdown("### Examples (with prompts already enhanced)")
326 |         for i, ex in enumerate(examples):
327 |             with gr.Row():
328 |                 ex_btn = gr.Button(value=ex, variant="secondary", elem_id=f"ex_btn_{i}", scale=3)
329 |                 ex_img = gr.Image(
330 |                     value=f"img/img_{i + 1}.png",
331 |                     label="Effect",
332 |                     interactive=False,
333 |                     height=130,
334 |                     width=130,
335 |                     scale=1,
336 |                 )
337 |                 ex_btn.click(fn=lambda ex=ex: ex, inputs=[], outputs=prompt)
338 | 
339 |     gr.on(
340 |         triggers=[run_button.click, prompt.submit],
341 |         fn=infer,
342 |         inputs=[prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, num_images],
343 |         outputs=[result, seed],
344 |     )
345 | 
346 | demo.queue().launch(inbrowser=True)
347 | 
-------------------------------------------------------------------------------- /inference/img/img_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/inference/img/img_1.png
-------------------------------------------------------------------------------- /inference/img/img_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/inference/img/img_2.png
-------------------------------------------------------------------------------- /inference/img/img_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/inference/img/img_3.png
-------------------------------------------------------------------------------- /inference/prompt_optimize.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import re
3 | 
4 | from openai import OpenAI
5 | 
6 | 
7 | def clean_string(s):
8 |     s = s.replace("\n", " ")
9 |     s = s.strip()
10 |     s = re.sub(r"\s{2,}", " ", s)
11 |     return s
12 | 
13 | 
14 | def convert_prompt(
15 |     api_key: str, base_url: str, prompt: str, system_instruction: str, model: str, user_assistant_pairs: list
16 | ) -> str:
17 |     client = OpenAI(api_key=api_key, base_url=base_url)
18 |     prompt = clean_string(prompt)
19 |     messages = (
20 |         [{"role": "system", "content": system_instruction}]
21 |         + user_assistant_pairs
22 |         + [
23 |             {
24 |                 "role": "user",
25 |                 "content": f'Create an imaginative image descriptive caption for the user input : "{prompt}"',
26 |             }
27 |         ]
28 |     )
29 |     response = client.chat.completions.create(
30 |         messages=messages,
31 |         model=model,
32 |         temperature=0.01,
33 |         top_p=0.7,
34 |         stream=False,
35 |         max_tokens=300,
36 |     )
37 |     prompt = response.choices[0].message.content
38 |     prompt = clean_string(prompt)
39 |     return prompt
40 | 
41 | 
42 | def get_system_instruction(cogview_version: str) -> str:
43 |     if cogview_version == "cogview3":
44 |         return """
45 | You are part of a team of bots that creates images. You work with an assistant bot that will draw anything you say.
46 | For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" will trigger your partner bot to output an image of a forest morning, as described.
47 | You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
48 | There are a few rules to follow:
49 | - Prompt should always be written in English, regardless of the input language. Please provide the prompts in English.
50 | - You will only ever output a single image description per user request.
51 | - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting.
52 | - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.
53 | - Do not provide the process and explanation, just return the modified English description. Image descriptions must be between 100-200 words. Extra words will be ignored.
54 | """
55 |     elif cogview_version == "cogview4":
56 |         return """
57 | You are a bilingual image description assistant that works with an image generation bot. You work with an assistant bot that will draw anything you say.
58 | For example, outputting "a beautiful morning in the woods with the sun peeking through the trees" or "阳光透过树林的美丽清晨" will trigger your partner bot to output an image of a forest morning, as described.
59 | You will be prompted by people looking to create detailed, amazing images. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
60 | There are a few rules to follow:
61 | - Input can be in Chinese or English. If input is in English, prompt should be written in English. If input is in Chinese, prompt should be written in Chinese.
62 | - You will only ever output a single image description per user request.
63 | - Image descriptions must be detailed and specific, including keyword categories such as subject, medium, style, additional details, color, and lighting.
64 | - When generating descriptions, focus on portraying the visual elements rather than delving into abstract psychological and emotional aspects. Provide clear and concise details that vividly depict the scene and its composition, capturing the tangible elements that make up the setting.
65 | - Do not provide the process and explanation, just return the modified description.
66 | """
67 |     else:
68 |         raise ValueError("Invalid cogview_version. Please choose either 'cogview3' or 'cogview4'.")
69 | 
70 | 
71 | def get_user_assistant_pairs(cogview_version: str) -> list:
72 |     if cogview_version == "cogview3":
73 |         return [
74 |             {
75 |                 "role": "user",
76 |                 "content": 'Create an imaginative image descriptive caption for the user input : "一个头发花白的老人"',
77 |             },
78 |             {
79 |                 "role": "assistant",
80 |                 "content": "A seasoned male with white hair and a neatly groomed beard stands confidently, donning a dark vest over a striped shirt. His hands are clasped together in front, one adorned with a ring, as he looks directly at the viewer with a composed expression. The soft lighting accentuates his features and the subtle textures of his attire, creating a portrait that exudes sophistication and a timeless elegance.",
81 |             },
82 |             {
83 |                 "role": "user",
84 |                 "content": 'Create an imaginative image descriptive caption for the user input : "画一只老鹰"',
85 |             },
86 |             {
87 |                 "role": "assistant",
88 |                 "content": "A majestic eagle with expansive brown and white wings glides through the air, its sharp yellow eyes focused intently ahead. The eagle's talons are poised and ready for hunting, as it soars over a rugged mountainous terrain dusted with snow, under a soft blue sky.",
89 |             },
90 |             {
91 |                 "role": "user",
92 |                 "content": 'Create an imaginative image descriptive caption for the user input : "画一辆摩托车"',
93 |             },
94 |             {
95 |                 "role": "assistant",
96 |                 "content": "Parked on a wet city street at night, a sleek motorcycle with a black and green design stands out. Its headlights cast a soft glow, reflecting off the puddles and highlighting its aerodynamic shape. The design is marked by sharp lines and angular features, with gold accents that shine against the dark backdrop. The motorcycle exudes an air of performance and luxury, ready to slice through the urban landscape.",
97 |             },
98 |             {
99 |                 "role": "user",
100 |                 "content": 'Create an imaginative image descriptive caption for the user input : "穿着金色盔甲的人"',
101 |             },
102 |             {
103 |                 "role": "assistant",
104 |                 "content": "A figure clad in meticulously crafted, golden armor stands with an air of quiet confidence. The armor, reminiscent of medieval knight attire, features a scalloped design with leaf-like patterns and is complemented by a black, form-fitting undergarment. The helmet, with its angular visor, adds to the intimidating presence. This armor, with its rich gold tones and intricate details, suggests a character of nobility or mythical origin, poised for valorous endeavors.",
105 |             },
106 |         ]
107 |     elif cogview_version == "cogview4":
108 |         return [
109 |             {
110 |                 "role": "user",
111 |                 "content": 'Create an imaginative image descriptive caption for the user input : "An anime girl stands amidst a dense flower bush."',
112 |             },
113 |             {
114 |                 "role": "assistant",
115 |                 "content": "This image is a beautifully crafted digital illustration in an anime style. It features Suzuran, a character from the video game Arknights. She stands gracefully amidst a picturesque meadow with lush green grass and scattered wildflowers that gently sway in the breeze. Her attire includes a detailed outfit with layered ruffles and intricate fastenings, reflecting both elegance and functionality typical of the game's aesthetic. Suzuran holds a small bouquet of flowers delicately in her hands, adding to the serene atmosphere.\n\nProminent fox ears extend from her head, signifying her animal spirit form or persona within the game. The background showcases rolling hills covered in dense foliage under a brilliant blue sky dotted with fluffy white clouds. Golden leaves float whimsically through the air, enhancing the magical quality of the scene. A distant building partially obscured by vegetation hints at human presence without detracting from the natural beauty surrounding Suzuran. Overall, this artwork captures a moment of tranquility and charm, embodying the essence of the character and the enchanting world she inhabits.",
116 |             },
117 |             {
118 |                 "role": "user",
119 |                 "content": 'Create an imaginative image descriptive caption for the user input : "Draw a bright convertible car with a sense of artistic design."',
120 |             },
121 |             {
122 |                 "role": "assistant",
123 |                 "content": 'The image showcases a meticulously crafted Delahaye 135 M.S. Special Roadster from 1937 in a highly polished and realistic rendering style that highlights its luxurious design and impeccable details. The car\'s body is an elegant deep brown with a glossy finish, exuding sophistication and timeless beauty. Its aerodynamic, streamlined shape features smooth curves accentuated by chrome detailing on the fenders and running boards. The front grille has three prominent circular headlights, adding to its classic charm while the dual exhaust pipes are tastefully integrated into the rear fenders.\n\nThe open cockpit area reveals sumptuous tan leather seats, emphasizing both comfort and elegance for the driver and passenger. The vehicle’s number plate reads "MB 175". Photographed against a dark gradient background, the focus remains solely on this automotive masterpiece, allowing viewers to appreciate every intricate detail without distraction. This image captures not just a car but a piece of art, reflecting the luxury and craftsmanship emblematic of pre-war European automobiles.',
124 |             },
125 |             {
126 |                 "role": "user",
127 |                 "content": 'Create an imaginative image descriptive caption for the user input : "画一个白发、神情忧郁的动漫女孩,手里拿着一支香烟"',
128 |             },
129 |             {
130 |                 "role": "assistant",
131 |                 "content": "这幅图像是一幅动漫风格的插画,画中描绘了一位长发飘逸的白发女孩。她神情忧郁,双眼低垂,脸上带着微微的愁容。女孩穿着浅色外套,里面搭配深色衬衫和领带,增添了她沉静却时尚的外表。她的手靠近嘴边,似乎在拿着一支香烟,画作的说明也表示她在吸烟。背景描绘了一个冬季的城市场景,地面和建筑物上覆盖着积雪。几株稀疏的植物从雪地中探出,街边停着一辆车,似乎是空荡荡的街道,增强了场景的荒凉氛围。这幅作品整体风格细致精美,典型的现代动漫插画风格,既捕捉了角色的情感表达,也呈现了周围环境的静谧氛围",
132 |             },
133 |             {
134 |                 "role": "user",
135 |                 "content": 'Create an imaginative image descriptive caption for the user input : "一张红色的海报,中间写有“开门大吉”"',
136 |             },
137 |             {
138 |                 "role": "assistant",
139 |                 "content": "这张图片采用了海报风格,色彩鲜艳,主要以红色和金色为主,寓意吉祥如意。在画面的正中间是一块红色的大匾,上面用白色大字写着“开门大吉”,四角点缀着金色的装饰图案,显得格外喜庆。匾额上方悬挂着一对红色的灯笼,增添了节日的气氛。背景左右两侧是传统的绿色中式建筑,屋顶呈现出典型的飞檐翘角设计。底部有祥云朵朵和可爱的卡通福袋,象征着好运福气满满。整张海报传达出浓厚的节日氛围。",
140 |             },
141 |         ]
142 |     else:
143 |         raise ValueError("Invalid cogview_version. Please choose either 'cogview3' or 'cogview4'.")
144 | 
145 | 
146 | if __name__ == "__main__":
147 |     parser = argparse.ArgumentParser()
148 |     parser.add_argument("--api_key", type=str, required=True, help="API key")
149 |     parser.add_argument("--prompt", type=str, required=True, help="Prompt to upsample")
150 |     parser.add_argument("--base_url", type=str, default="https://open.bigmodel.cn/api/paas/v4", help="Base URL")
151 |     parser.add_argument("--model", type=str, default="glm-4-plus", help="LLM used for upsampling")
152 |     parser.add_argument(
153 |         "--cogview_version",
154 |         type=str,
155 |         choices=["cogview3", "cogview4"],
156 |         required=True,
157 |         help="Choose the version of CogView (cogview3 or cogview4)",
158 |     )
159 |     args = parser.parse_args()
160 |     system_instruction = get_system_instruction(args.cogview_version)
161 |     user_assistant_pairs = get_user_assistant_pairs(args.cogview_version)
162 |     prompt_enhanced = convert_prompt(
163 |         api_key=args.api_key,
164 |         base_url=args.base_url,
165 |         prompt=args.prompt,
166 |         model=args.model,
167 |         system_instruction=system_instruction,
168 |         user_assistant_pairs=user_assistant_pairs,
169 |     )
170 |     print(prompt_enhanced)
171 | 
-------------------------------------------------------------------------------- /inference/requirements.txt: --------------------------------------------------------------------------------
1 | deepspeed>=0.16.3
2 | transformers>=4.49.0
3 | gradio>=5.16.0
4 | accelerate>=1.2.1
5 | git+https://github.com/huggingface/diffusers.git
6 | sentencepiece>=0.2.0
7 | torch>=2.6.0
8 | openai>=1.61.0
9 | bitsandbytes>=0.45.2
10 | torchao>=0.9.0
-------------------------------------------------------------------------------- /pyproject.toml: --------------------------------------------------------------------------------
1 | [tool.ruff]
2 | line-length = 119
3 | 
4 | [tool.ruff.lint]
5 | # Never enforce `E501` (line length violations).
6 | ignore = ["C901", "E501", "E741", "F402", "F823"]
7 | select = ["C", "E", "F", "I", "W"]
8 | 
9 | # Ignore import violations in all `__init__.py` files.
10 | [tool.ruff.lint.per-file-ignores]
11 | "__init__.py" = ["E402", "F401", "F403", "F811"]
12 | 
13 | [tool.ruff.lint.isort]
14 | lines-after-imports = 2
15 | 
16 | [tool.ruff.format]
17 | # Like Black, use double quotes for strings.
18 | quote-style = "double"
19 | 
20 | # Like Black, indent with spaces, rather than tabs.
21 | indent-style = "space"
22 | 
23 | # Like Black, respect magic trailing commas.
24 | skip-magic-trailing-comma = false 25 | 26 | # Like Black, automatically detect the appropriate line ending. 27 | line-ending = "auto" 28 | -------------------------------------------------------------------------------- /resources/CogView3_evaluation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/CogView3_evaluation.png -------------------------------------------------------------------------------- /resources/CogView3_pipeline.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/CogView3_pipeline.jpg -------------------------------------------------------------------------------- /resources/CogView3_showcase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/CogView3_showcase.png -------------------------------------------------------------------------------- /resources/WECHAT.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | 4 |

扫码关注公众号,加入「 CogView 交流群」
5 | 
6 | Scan the QR code to follow the official account and join the "CogView Discussion Group"
7 | 
-------------------------------------------------------------------------------- /resources/controlnet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/controlnet.png
-------------------------------------------------------------------------------- /resources/img.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/img.png
-------------------------------------------------------------------------------- /resources/logo.svg: --------------------------------------------------------------------------------
-------------------------------------------------------------------------------- /resources/showcase.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/showcase.png
-------------------------------------------------------------------------------- /resources/train_control_model.md: --------------------------------------------------------------------------------
1 | ## Fine-Tuning Control Models (ControlNet)
2 | 
3 | ### Principles
4 | 
5 | We increase the number of channels in the `latent` from `16` to `32`; the 16 additional channels carry the control signal.
6 | These control channels are treated as part of the `latent` and concatenated with it, so the output of the generator can be steered simply by adjusting their values.
7 | 
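8 | As a rough sketch, the concatenation described above amounts to the following (tensor names here are illustrative, not the actual variables used by the training script):
9 | 
10 | ```python
11 | import torch
12 | 
13 | batch = 1
14 | noisy_latent = torch.randn(batch, 16, 128, 128)  # the usual 16-channel image latent
15 | control_latent = torch.randn(batch, 16, 128, 128)  # latent encoding of the control image (e.g. a pose map)
16 | 
17 | # The model input simply doubles the channel dimension: 16 + 16 = 32.
18 | model_input = torch.cat([noisy_latent, control_latent], dim=1)
19 | assert model_input.shape[1] == 32
20 | ```
21 | 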
22 | ### Server Requirements
23 | 
24 | - At least one `A100 GPU` is required. Using `zero2` for training, each card can handle a batch of 8.
25 | - If you want to fine-tune the model in full, we recommend using `batchsize=128`.
26 | - A Linux operating system is required for installing `deepspeed`.
27 | 
28 | ### Preparing the Dataset
29 | 
30 | In this example, we use [open_pose_controlnet](https://huggingface.co/datasets/raulc0399/open_pose_controlnet) for fine-tuning.
31 | You can also use your own dataset, but you will need to follow the `open_pose_controlnet` dataset format or modify the dataloader accordingly.
32 | 
33 | **Note**
34 | 
35 | + All images will be resized to a fixed size. Dynamic resolutions are not supported.
36 | + Dynamic-length tokens are not supported. Tokens in each batch will be padded to the maximum length.
37 | 
38 | ### Start Training
39 | 
40 | 1. Clone the source code and install [diffusers](https://github.com/huggingface/diffusers), then navigate to the fine-tuning directory:
41 | 
42 | ```shell
43 | git clone https://github.com/huggingface/diffusers.git
44 | cd diffusers
45 | pip install -e .
46 | cd examples/cogview4-control
47 | ```
48 | 
49 | 2. Set up the `deepspeed` and `accelerate` environments.
50 | 
51 | Here is an example accelerate configuration file using zero2:
52 | 
53 | ```yaml
54 | compute_environment: LOCAL_MACHINE
55 | debug: false
56 | deepspeed_config:
57 |   # deepspeed_hostfile: hostfile # If using multi-machine multi-card training, prepare the hostfile configuration
58 |   gradient_accumulation_steps: 1
59 |   gradient_clipping: 1.0
60 |   offload_optimizer_device: cpu
61 |   offload_param_device: cpu
62 |   zero3_init_flag: false
63 |   zero_stage: 2
64 | num_machines: 1
65 | num_processes: 8 # 8 processes in total; write 16 if using two machines
66 | distributed_type: DEEPSPEED
67 | downcast_bf16: 'no'
68 | enable_cpu_affinity: false
69 | gpu_ids: all
70 | machine_rank: 0
71 | main_training_function: main
72 | mixed_precision: bf16
73 | rdzv_backend: static
74 | same_network: true
75 | tpu_env: [ ]
76 | tpu_use_cluster: false
77 | tpu_use_sudo: false
78 | use_cpu: false
79 | ```
80 | 
81 | Save this configuration as `accelerate_ds.yaml`.
82 | 
83 | 3. Run the following command to start training:
84 | 
85 | ```shell
86 | accelerate launch --config_file=accelerate_ds.yaml train_control_cogview4.py \
87 |   --pretrained_model_name_or_path="THUDM/CogView4-6B" \
88 |   --dataset_name="raulc0399/open_pose_controlnet" \
89 |   --output_dir="pose-control" \
90 |   --mixed_precision="bf16" \
91 |   --train_batch_size=1 \
92 |   --dataloader_num_workers=16 \
93 |   --gradient_accumulation_steps=1 \
94 |   --gradient_checkpointing \
95 |   --proportion_empty_prompts=0 \
96 |   --learning_rate=5e-5 \
97 |   --adam_weight_decay=1e-4 \
98 |   --report_to="wandb" \
99 |   --lr_scheduler="cosine" \
100 |   --lr_warmup_steps=1000 \
101 |   --checkpointing_steps=100 \
102 |   --max_train_steps=50000 \
103 |   --validation_steps=100 \
104 |   --validation_image "pose.jpg" \
105 |   --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" \
106 |   --offload \
107 |   --seed="0"
108 | ```
109 | 
110 | **Note**
111 | 
112 | - Training must be done using `bf16` mixed precision or `fp32`. `fp16` and `fp8` are not supported.
113 | - `validation_image` and `validation_prompt` need to be prepared and placed in the same directory. In this example, the first entry from the `open_pose_controlnet` dataset is used.
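114 | 
115 | For reference (this restates the batch-size recommendation above; it is not an extra flag of the script): the global batch size equals `train_batch_size * num_processes * gradient_accumulation_steps`, so under the example configuration (8 processes, `train_batch_size=1`) the recommended full fine-tuning batch size of 128 corresponds to:
116 | 
117 | ```shell
118 | # 1 (train_batch_size) * 8 (num_processes) * 16 (gradient_accumulation_steps) = 128
119 | accelerate launch --config_file=accelerate_ds.yaml train_control_cogview4.py \
120 |   --train_batch_size=1 \
121 |   --gradient_accumulation_steps=16 # ...keeping the remaining flags from the command above
122 | ```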
123 | 
124 | ## Using the Fine-Tuned Weights
125 | 
126 | Assuming you used the results from `10000` steps, and your model resolution is `1024`:
127 | 
128 | ### SFT
129 | 
130 | ```python
131 | import torch
132 | from controlnet_aux import CannyDetector
133 | from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
134 | from diffusers.utils import load_image
135 | 
136 | transformer = CogView4Transformer2DModel.from_pretrained(
137 |     "pose-control/checkpoint-10000/transformer", torch_dtype=torch.bfloat16
138 | ).to("cuda:0")
139 | pipe = CogView4ControlPipeline.from_pretrained(
140 |     "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
141 | ).to("cuda:0")
142 | 
143 | prompt = "two friends sitting by each other enjoying a day at the park, full hd, cinematic"
144 | control_image = load_image("pose.jpg")
145 | processor = CannyDetector()
146 | control_image = processor(
147 |     control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
148 | )
149 | image = pipe(
150 |     prompt=prompt,
151 |     control_image=control_image,
152 |     height=1024,
153 |     width=1024,
154 |     num_inference_steps=50,
155 |     guidance_scale=3.5,
156 | ).images[0]
157 | 
158 | image.save("cogview4.png")
159 | ```
-------------------------------------------------------------------------------- /resources/train_control_model_zh.md: --------------------------------------------------------------------------------
1 | ## 训练控制模型 (ControlNet)
2 | 
3 | ### 原理
4 | 
5 | 我们将 `latent` 的通道数从`16`提升到`32`,其中新增的16个通道用于承载控制信号。这些通道的值被视为`latent`的一部分并与其拼接在一起,因此只需调整这些通道的值,即可控制生成器的输出。
6 | 
7 | ### 服务器要求
8 | 
9 | - 至少需要`1`张`A100 GPU`。使用`zero2`进行训练,则每张卡可以承载8个batch。
10 | - 如果你希望全量微调模型,我们推荐使用`batchsize=128`。
11 | - 需要 Linux 操作系统,以便安装`deepspeed`。
12 | 
13 | ### 准备数据集
14 | 
15 | 在这个例子中,我们使用 [open_pose_controlnet](https://huggingface.co/datasets/raulc0399/open_pose_controlnet) 进行微调。
16 | 你也可以使用自己的数据,这通常需要你遵循`open_pose_controlnet`数据集的格式或修改 dataloader。
17 | 
18 | **请注意**
19 | 
20 | + 所有图像都会被resize成固定大小的图片,不支持动态分辨率。
21 | + 不支持动态长度的token。每个batch的token都会被pad成最大长度。
22 | 
23 | ### 开始训练
24 | 
25 | 1. 源代码安装 [diffusers](https://github.com/huggingface/diffusers) 并打开微调目录
26 | ```shell
27 | git clone https://github.com/huggingface/diffusers.git
28 | cd diffusers
29 | pip install -e .
30 | cd examples/cogview4-control
31 | ```
32 | 
33 | 2. 配置 `deepspeed` 环境和 `accelerate` 环境
34 | 
35 | 这里是一个示例的 `accelerate` 配置文件,使用`zero2`:
36 | 
37 | ```yaml
38 | compute_environment: LOCAL_MACHINE
39 | debug: false
40 | deepspeed_config:
41 |   # deepspeed_hostfile: hostfile # 如果是多机多卡训练,需要自行准备配置hostfile
42 |   gradient_accumulation_steps: 1
43 |   gradient_clipping: 1.0
44 |   offload_optimizer_device: cpu
45 |   offload_param_device: cpu
46 |   zero3_init_flag: false
47 |   zero_stage: 2
48 | num_machines: 1
49 | num_processes: 8 # 一共8个进程,如果是两台机器则要写16
50 | # main_process_port: 29500
51 | # main_process_ip: 10.250.XX.XX # 如果是多机多卡训练,需要自己设置主机IP
52 | distributed_type: DEEPSPEED
53 | downcast_bf16: 'no'
54 | enable_cpu_affinity: false
55 | gpu_ids: all
56 | machine_rank: 0
57 | main_training_function: main
58 | mixed_precision: bf16
59 | rdzv_backend: static
60 | same_network: true
61 | tpu_env: []
62 | tpu_use_cluster: false
63 | tpu_use_sudo: false
64 | use_cpu: false
65 | ```
66 | 并将这个配置文件保存为 `accelerate_ds.yaml`。
67 | 
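68 | 一个粗略的参考(这是对上文 `batchsize=128` 建议的换算,并非脚本新增的参数):全局 batch size = `train_batch_size * num_processes * gradient_accumulation_steps`。按上面的示例配置(8 个进程、`train_batch_size=1`),要达到推荐的 128,可以这样设置:
69 | 
70 | ```shell
71 | # 1 (train_batch_size) * 8 (num_processes) * 16 (gradient_accumulation_steps) = 128
72 | accelerate launch --config_file=accelerate_ds.yaml train_control_cogview4.py \
73 |   --train_batch_size=1 \
74 |   --gradient_accumulation_steps=16 # ……其余参数与下面第 3 步的命令一致
75 | ```
76 | 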
77 | 3. 运行以下命令开始训练:
78 | ```shell
79 | accelerate launch --config_file=accelerate_ds.yaml train_control_cogview4.py \
80 |   --pretrained_model_name_or_path="THUDM/CogView4-6B" \
81 |   --dataset_name="raulc0399/open_pose_controlnet" \
82 |   --output_dir="pose-control" \
83 |   --mixed_precision="bf16" \
84 |   --train_batch_size=1 \
85 |   --dataloader_num_workers=16 \
86 |   --gradient_accumulation_steps=1 \
87 |   --gradient_checkpointing \
88 |   --proportion_empty_prompts=0 \
89 |   --learning_rate=5e-5 \
90 |   --adam_weight_decay=1e-4 \
91 |   --report_to="wandb" \
92 |   --lr_scheduler="cosine" \
93 |   --lr_warmup_steps=1000 \
94 |   --checkpointing_steps=100 \
95 |   --max_train_steps=50000 \
96 |   --validation_steps=100 \
97 |   --validation_image "pose.jpg" \
98 |   --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" \
99 |   --offload \
100 |   --seed="0"
101 | ```
102 | 
103 | **请注意**
104 | 
105 | - 必须使用 `bf16` 混合精度训练或者 `fp32` 进行训练。不支持 `fp16` 和 `fp8`。
106 | - `validation_image` 和 `validation_prompt` 需要自己准备并放在同一个目录下,例子中使用的是 `open_pose_controlnet` 的第一条数据。
107 | 
108 | ## 调用微调后的权重
109 | 
110 | 这里假设你使用了`10000`步的结果,并且训练的模型分辨率是`1024`。
111 | 
112 | ### SFT
113 | 
114 | ```python
115 | import torch
116 | from controlnet_aux import CannyDetector
117 | from diffusers import CogView4ControlPipeline, CogView4Transformer2DModel
118 | from diffusers.utils import load_image
119 | 
120 | transformer = CogView4Transformer2DModel.from_pretrained(
121 |     "pose-control/checkpoint-10000/transformer", torch_dtype=torch.bfloat16
122 | ).to("cuda:0")
123 | pipe = CogView4ControlPipeline.from_pretrained(
124 |     "THUDM/CogView4-6B", transformer=transformer, torch_dtype=torch.bfloat16
125 | ).to("cuda:0")
126 | 
127 | prompt = "two friends sitting by each other enjoying a day at the park, full hd, cinematic"
128 | control_image = load_image("pose.jpg")
129 | processor = CannyDetector()
130 | control_image = processor(
131 |     control_image, low_threshold=50, high_threshold=200, detect_resolution=1024, image_resolution=1024
132 | )
133 | image = pipe(
134 |     prompt=prompt,
135 |     control_image=control_image,
136 |     height=1024,
137 |     width=1024,
138 |     num_inference_steps=50,
139 |     guidance_scale=3.5,
140 | ).images[0]
141 | 
142 | image.save("cogview4.png")
143 | ```
144 | 
-------------------------------------------------------------------------------- /resources/wechat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THUDM/CogView4/1f21551ad1a00fec276c01e8bbf98b38d9590108/resources/wechat.jpg --------------------------------------------------------------------------------