├── .gitignore
├── LICENSE
├── MODEL_LICENSE
├── README.md
├── README_CN.md
├── controlnet
│   ├── README.md
│   ├── annotator
│   │   ├── __init__.py
│   │   ├── canny
│   │   │   └── __init__.py
│   │   ├── dwpose
│   │   │   ├── __init__.py
│   │   │   ├── onnxdet.py
│   │   │   ├── onnxpose.py
│   │   │   ├── util.py
│   │   │   └── wholebody.py
│   │   ├── midas
│   │   │   ├── LICENSE
│   │   │   ├── __init__.py
│   │   │   ├── api.py
│   │   │   ├── midas
│   │   │   │   ├── __init__.py
│   │   │   │   ├── base_model.py
│   │   │   │   ├── blocks.py
│   │   │   │   ├── dpt_depth.py
│   │   │   │   ├── midas_net.py
│   │   │   │   ├── midas_net_custom.py
│   │   │   │   ├── transforms.py
│   │   │   │   └── vit.py
│   │   │   └── utils.py
│   │   └── util.py
│   ├── assets
│   │   ├── bird.png
│   │   ├── dog.png
│   │   ├── woman_1.png
│   │   ├── woman_2.png
│   │   ├── woman_3.png
│   │   └── woman_4.png
│   ├── outputs
│   │   ├── Canny_dog.jpg
│   │   ├── Canny_dog_condition.jpg
│   │   ├── Canny_woman_1.jpg
│   │   ├── Canny_woman_1_condition.jpg
│   │   ├── Canny_woman_1_sdxl.jpg
│   │   ├── Depth_1_condition.jpg
│   │   ├── Depth_bird.jpg
│   │   ├── Depth_bird_condition.jpg
│   │   ├── Depth_bird_sdxl.jpg
│   │   ├── Depth_ipadapter_1.jpg
│   │   ├── Depth_ipadapter_woman_2.jpg
│   │   ├── Depth_woman_2.jpg
│   │   ├── Depth_woman_2_condition.jpg
│   │   ├── Pose_woman_3.jpg
│   │   ├── Pose_woman_3_condition.jpg
│   │   ├── Pose_woman_3_sdxl.jpg
│   │   ├── Pose_woman_4.jpg
│   │   └── Pose_woman_4_condition.jpg
│   ├── sample_controlNet.py
│   └── sample_controlNet_ipadapter.py
├── dreambooth
│   ├── README.md
│   ├── default_config.yaml
│   ├── infer_dreambooth.py
│   ├── ktxl_test_image.png
│   ├── train.sh
│   └── train_dreambooth_lora.py
├── imgs
│   ├── Kolors_paper.pdf
│   ├── cn_all.png
│   ├── fz_all.png
│   ├── head_final3.png
│   ├── logo.png
│   ├── prompt_vis.txt
│   ├── wechat.png
│   ├── wz_all.png
│   ├── zl8.png
│   ├── 可图KOLORS模型商业授权申请书-英文版本.docx
│   └── 可图KOLORS模型商业授权申请书.docx
├── inpainting
│   ├── README.md
│   ├── asset
│   │   ├── 1.png
│   │   ├── 1_kolors.png
│   │   ├── 1_masked.png
│   │   ├── 1_sdxl.png
│   │   ├── 2.png
│   │   ├── 2_kolors.png
│   │   ├── 2_masked.png
│   │   ├── 2_sdxl.png
│   │   ├── 3.png
│   │   ├── 3_kolors.png
│   │   ├── 3_mask.png
│   │   ├── 3_masked.png
│   │   ├── 3_sdxl.png
│   │   ├── 4.png
│   │   ├── 4_kolors.png
│   │   ├── 4_mask.png
│   │   ├── 4_masked.png
│   │   ├── 4_sdxl.png
│   │   ├── 5.png
│   │   ├── 5_kolors.png
│   │   ├── 5_masked.png
│   │   └── 5_sdxl.png
│   └── sample_inpainting.py
├── ipadapter
│   ├── README.md
│   ├── asset
│   │   ├── 1.png
│   │   ├── 1_kolors_ip_result.jpg
│   │   ├── 1_mj_cw_result.png
│   │   ├── 1_sdxl_ip_result.jpg
│   │   ├── 2.png
│   │   ├── 2_kolors_ip_result.jpg
│   │   ├── 2_mj_cw_result.png
│   │   ├── 2_sdxl_ip_result.jpg
│   │   ├── 3.png
│   │   ├── 3_kolors_ip_result.jpg
│   │   ├── 3_mj_cw_result.png
│   │   ├── 3_sdxl_ip_result.jpg
│   │   ├── 4.png
│   │   ├── 4_kolors_ip_result.jpg
│   │   ├── 4_mj_cw_result.png
│   │   ├── 4_sdxl_ip_result.jpg
│   │   ├── 5.png
│   │   ├── 5_kolors_ip_result.jpg
│   │   ├── 5_mj_cw_result.png
│   │   ├── 5_sdxl_ip_result.jpg
│   │   ├── test_ip.jpg
│   │   └── test_ip2.png
│   └── sample_ipadapter_plus.py
├── ipadapter_FaceID
│   ├── README.md
│   ├── assets
│   │   ├── image1.png
│   │   ├── image1_res.png
│   │   ├── image2.png
│   │   ├── image2_res.png
│   │   ├── test_img1_Kolors.png
│   │   ├── test_img1_SDXL.png
│   │   ├── test_img1_org.png
│   │   ├── test_img2_Kolors.png
│   │   ├── test_img2_SDXL.png
│   │   ├── test_img2_org.png
│   │   ├── test_img3_Kolors.png
│   │   ├── test_img3_SDXL.png
│   │   ├── test_img3_org.png
│   │   ├── test_img4_Kolors.png
│   │   ├── test_img4_SDXL.png
│   │   └── test_img4_org.png
│   └── sample_ipadapter_faceid_plus.py
├── kolors
│   ├── __init__.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── configuration_chatglm.py
│   │   ├── controlnet.py
│   │   ├── ipa_faceid_plus
│   │   │   ├── __init__.py
│   │   │   ├── attention_processor.py
│   │   │   └── ipa_faceid_plus.py
│   │   ├── modeling_chatglm.py
│   │   ├── tokenization_chatglm.py
│   │   └── unet_2d_condition.py
│   └── pipelines
│       ├── __init__.py
│       ├── pipeline_controlnet_xl_kolors_img2img.py
│       ├── pipeline_stable_diffusion_xl_chatglm_256.py
│       ├── pipeline_stable_diffusion_xl_chatglm_256_inpainting.py
│       ├── pipeline_stable_diffusion_xl_chatglm_256_ipadapter.py
│       └── pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID.py
├── requirements.txt
├── scripts
│   ├── outputs
│   │   ├── sample_inpainting_3.jpg
│   │   ├── sample_inpainting_4.jpg
│   │   ├── sample_ip_test_ip.jpg
│   │   ├── sample_ip_test_ip2.jpg
│   │   └── sample_test.jpg
│   ├── sample.py
│   └── sampleui.py
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | weights
2 | *.egg-info
3 | __pycache__
4 | *.pyc
5 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/MODEL_LICENSE:
--------------------------------------------------------------------------------
1 | 中文版
2 | 模型许可协议
3 | 模型发布日期:2024/7/6
4 |
5 | 通过点击同意或使用、复制、修改、分发、表演或展示模型作品的任何部分或元素,您将被视为已承认并接受本协议的内容,本协议立即生效。
6 |
7 | 1.定义。
8 | a. “协议”指本协议中所规定的使用、复制、分发、修改、表演和展示模型作品或其任何部分或元素的条款和条件。
9 | b. “材料”是指根据本协议提供的专有的模型和文档(及其任何部分)的统称。
10 | c. “模型”指大型语言模型、图像/视频/音频/3D 生成模型、多模态大型语言模型及其软件和算法,包括训练后的模型权重、参数(包括优化器状态)、机器学习模型代码、推理支持代码、训练支持代码、微调支持代码以及我们公开提供的前述其他元素。
11 | d. “输出”是指通过操作或以其他方式使用模型或模型衍生品而产生的模型或模型衍生品的信息和/或内容输出。
12 | e. “模型衍生品”包括:(i)对模型或任何模型衍生物的修改;(ii)基于模型的任何模型衍生物的作品;或(iii)通过将模型或模型的任何模型衍生物的权重、参数、操作或输出的模式转移到该模型而创建的任何其他机器学习模型,以使该模型的性能类似于模型或模型衍生物。为清楚起见,输出本身不被视为模型衍生物。
13 | f. “模型作品”包括:(i)材料;(ii)模型衍生品;及(iii)其所有衍生作品。
14 | g. “许可人”或“我们”指作品所有者或作品所有者授权的授予许可的实体,包括可能对模型和/或分发模型拥有权利的个人或实体。
15 | h.“被许可人”、“您”或“您的”是指行使本协议授予的权利和/或为任何目的和在任何使用领域使用模型作品的自然人或法人实体。
16 | i.“第三方”是指不受我们或您共同控制的个人或法人实体。
17 |
18 | 2. 许可内容。
19 | a.我们授予您非排他性的、全球性的、不可转让的、免版税的许可(在我们的知识产权或我们拥有的体现在材料中或利用材料的其他权利的范围内),允许您仅根据本协议的条款使用、复制、分发、创作衍生作品(包括模型衍生品)和对材料进行修改,并且您不得违反(或鼓励、或允许任何其他人违反)本协议的任何条款。
20 | b.在遵守本协议的前提下,您可以分发或向第三方提供模型作品,您须满足以下条件:
21 | (i)您必须向所有该模型作品或使用该作品的产品或服务的任何第三方接收者提供模型作品的来源和本协议的副本;
22 | (ii)您必须在任何修改过的文档上附加明显的声明,说明您更改了这些文档;
23 | (iii)您可以在您的修改中添加您自己的版权声明,并且,在您对该作品的使用、复制、修改、分发、表演和展示符合本协议的条款和条件的前提下,您可以为您的修改或任何此类模型衍生品的使用、复制或分发提供额外或不同的许可条款和条件。
24 | c. 附加商业条款:若您或其关联方提供的所有产品或服务的月活跃用户数在前一个自然月未超过3亿月活跃用户数,则您向许可方进行登记,将被视为获得相应的商业许可;若您或其关联方提供的所有产品或服务的月活跃用户数在前一个自然月超过3亿月活跃用户数,则您必须向许可人申请许可,许可人可自行决定向您授予许可。除非许可人另行明确授予您该等权利,否则您无权行使本协议项下的任何权利。
25 |
26 | 3.使用限制。
27 | a. 您对本模型作品的使用必须遵守适用法律法规(包括贸易合规法律法规),并遵守《服务协议》(https://kolors.kuaishou.com/agreement)。您必须将本第 3(a) 和 3(b) 条中提及的使用限制作为可执行条款纳入任何规范本模型作品使用和/或分发的协议(例如许可协议、使用条款等),并且您必须向您分发的后续用户发出通知,告知其本模型作品受本第 3(a) 和 3(b) 条中的使用限制约束。
28 | b. 您不得使用本模型作品或本模型作品的任何输出或成果来改进任何其他模型(本模型或其模型衍生品除外)。
29 |
30 | 4.知识产权。
31 | a. 我们保留材料的所有权及其相关知识产权。在遵守本协议条款和条件的前提下,对于您制作的材料的任何衍生作品和修改,您是且将是此类衍生作品和修改的所有者。
32 | b. 本协议不授予任何商标、商号、服务标记或产品名称的标识许可,除非出于描述和分发本模型作品的合理和惯常用途。
33 | c. 如果您对我们或任何个人或实体提起诉讼或其他程序(包括诉讼中的交叉索赔或反索赔),声称材料或任何输出或任何上述内容的任何部分侵犯您拥有或可许可的任何知识产权或其他权利,则根据本协议授予您的所有许可应于提起此类诉讼或其他程序之日起终止。
34 |
35 | 5. 免责声明和责任限制。
36 | a. 本模型作品及其任何输出和结果按“原样”提供,不作任何明示或暗示的保证,包括适销性、非侵权性或适用于特定用途的保证。我们不对材料及其任何输出的安全性或稳定性作任何保证,也不承担任何责任。
37 | b. 在任何情况下,我们均不对您承担任何损害赔偿责任,包括但不限于因您使用或无法使用材料或其任何输出而造成的任何直接、间接、特殊或后果性损害赔偿责任,无论该损害赔偿责任是如何造成的。
38 | c. 对于因您使用或分发模型的衍生物而引起的或与之相关的任何第三方索赔,您应提供辩护,赔偿,并使我方免受损害。
39 |
40 | 6. 存续和终止。
41 | a. 本协议期限自您接受本协议或访问材料之日起开始,并将持续完全有效,直至根据本协议条款和条件终止。
42 | b. 如果您违反本协议的任何条款或条件,我们可终止本协议。本协议终止后,您必须立即删除并停止使用本模型作品。第 4(a)、4(c)、5和 7 条在本协议终止后仍然有效。
43 |
44 | 7. 适用法律和管辖权。
45 | a. 本协议及由本协议引起的或与本协议有关的任何争议均受中华人民共和国大陆地区(仅为本协议目的,不包括香港、澳门和台湾)法律管辖,并排除冲突法的适用,且《联合国国际货物销售合同公约》不适用于本协议。
46 | b. 因本协议引起或与本协议有关的任何争议,由许可人住所地人民法院管辖。
47 |
48 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 kwai-kolors@kuaishou.com 与我们联系。
49 |
50 |
51 | 英文版
52 |
53 | MODEL LICENSE AGREEMENT
54 | Release Date: 2024/7/6
55 | By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model Works, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
56 | 1. DEFINITIONS.
57 | a. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of the Model Works or any portion or element thereof set forth herein.
58 | b. “Materials” shall mean, collectively, Our proprietary Model and Documentation (and any portion thereof) as made available by Us under this Agreement.
59 | c. “Model” shall mean the large language models, image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us .
60 | d. “Output” shall mean the information and/or content output of Model or a Model Derivative that results from operating or otherwise using Model or a Model Derivative.
61 | e. “Model Derivatives” shall mean all: (i) modifications to the Model or any Model Derivative; (ii) works based on the Model or any Model Derivative; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of the Model or any Model Derivative, to that model in order to cause that model to perform similarly to the Model or a Model Derivative, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs or a Model Derivative for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
62 | f. “Model Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
63 | g. “Licensor” , “We” or “Us” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model.
64 | h. “Licensee”, “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Model Works for any purpose and in any field of use.
65 | i. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
66 |
67 | 2. LICENSE CONTENT.
68 | a. We grant You a non-exclusive, worldwide, non-transferable and royalty-free limited license under the intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
69 | b. You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Model Works, provided that You meet all of the following conditions:
70 | (i) You must provide all such Third Party recipients of the Model Works or products or services using them the source of the Model and a copy of this Agreement;
71 | (ii) You must cause any modified documents to carry prominent notices stating that You changed the documents;
72 | (iii) You may add Your own copyright statement to Your modifications and, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement.
73 | c. additional commercial terms: If, the monthly active users of all products or services made available by or for You, or Your affiliates, does not exceed 300 million monthly active users in the preceding calendar month, Your registration with the Licensor will be deemed to have obtained the corresponding business license; If, the monthly active users of all products or services made available by or for You, or Your affiliates, is greater than 300 million monthly active users in the preceding calendar month, You must request a license from Licensor, which the Licensor may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until We otherwise expressly grants You such rights.
74 |
75 | 3. LICENSE RESTRICTIONS.
76 | a. Your use of the Model Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Service Agreement. You must include the use restrictions referenced in these Sections 3(a) and 3(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Model Works and You must provide notice to subsequent users to whom You distribute that Model Works are subject to the use restrictions in these Sections 3(a) and 3(b).
77 | b. You must not use the Model Works or any Output or results of the Model Works to improve any other large model (other than Model or Model Derivatives thereof).
78 | 4. INTELLECTUAL PROPERTY.
79 | a. We retain ownership of all intellectual property rights in and to the Model and derivatives. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by You, You are and will be the owner of such derivative works and modifications.
80 | b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials.
81 | c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed.
82 | 5. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
83 | a. THE MODEL WORKS AND ANY OUTPUT AND RESULTS THERE FROM ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM.
84 | b. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED.
85 | c. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to Your use or distribution of the Materials.
86 |
87 | 6. SURVIVAL AND TERMINATION.
88 | a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
89 | b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Model Works. Sections 4(a), 4(c), 5 and 7 shall survive the termination of this Agreement.
90 | 7. GOVERNING LAW AND JURISDICTION.
91 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China (for the purpose of this agreement only, excluding Hong Kong, Macau, and Taiwan), without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
92 | b. Any disputes arising from or related to this Agreement shall be under the jurisdiction of the People's Court where the Licensor is located.
93 |
94 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at kwai-kolors@kuaishou.com.
95 |
--------------------------------------------------------------------------------
/controlnet/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | ## 📖 Introduction
5 |
6 | We provide three ControlNet weights and inference code based on Kolors-BaseModel: Canny, Depth, and Pose. You can find some example results below.
7 |
8 |
9 | **1、ControlNet Demos**
10 |
11 | | Condition Image | Prompt | Result Image |
12 | | :---: | :---: | :---: |
13 | | *(image)* | 全景,一只可爱的白色小狗坐在杯子里,看向镜头,动漫风格,3d渲染,辛烷值渲染。 Panorama of a cute white puppy sitting in a cup and looking towards the camera, anime style, 3d rendering, octane rendering. | *(image)* |
14 | | *(image)* | 新海诚风格,丰富的色彩,穿着绿色衬衫的女人站在田野里,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质。 Makoto Shinkai style, rich colors, a woman in a green shirt standing in the field, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality. | *(image)* |
15 | | *(image)* | 一个穿着黑色运动外套、白色内搭,上面戴着项链的女子,站在街边,背景是红色建筑和绿树,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。 A woman wearing a black sports jacket and a white top, adorned with a necklace, stands by the street, with a background of red buildings and green trees. high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K. | *(image)* |
16 |
43 | **2、ControlNet and IP-Adapter-Plus Demos**
44 |
45 | We also provide joint inference code for Kolors-IP-Adapter-Plus together with Kolors-ControlNet.
46 |
47 | | Reference Image | Condition Image | Prompt | Result Image |
48 | | :---: | :---: | :---: | :---: |
49 | | *(image)* | *(image)* | 一个红色头发的女孩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质。 A girl with red hair, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality. | *(image)* |
50 | | *(image)* | *(image)* | 一个漂亮的女孩,最好的质量,超细节,8K画质。 A beautiful girl, best quality, super detail, 8K quality. | *(image)* |
51 |
76 | ## 📊 Evaluation
77 | To evaluate the performance of the models, we compiled a test set of more than 200 images and text prompts and invited several image experts to provide fair ratings of the results generated by the different models. The experts rated the generated images on four criteria: visual appeal, text faithfulness, conditional controllability, and overall satisfaction. Conditional controllability measures the ControlNet's ability to preserve spatial structure, while the other criteria follow the evaluation standards of the base model. The results are summarized in the tables below, where Kolors-ControlNet achieves better performance across all criteria.
78 |
79 | **1、Canny**
80 |
81 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability |
82 | | :--------------: | :--------: | :--------: | :--------: | :--------: |
83 | | SDXL-ControlNet-Canny | 3.14 | 3.63 | 4.37 | 2.84 |
84 | | **Kolors-ControlNet-Canny** | **4.06** | **4.64** | **4.45** | **3.52** |
85 |
86 |
87 |
88 | **2、Depth**
89 |
90 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability |
91 | | :--------------: | :--------: | :--------: | :--------: | :--------: |
92 | | SDXL-ControlNet-Depth | 3.35 | 3.77 | 4.26 | 4.5 |
93 | | **Kolors-ControlNet-Depth** | **4.12** | **4.12** | **4.62** | **4.6** |
94 |
95 |
96 |
97 | **3、Pose**
98 |
99 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability |
100 | | :--------------: | :--------: | :--------: | :--------: | :--------: |
101 | | SDXL-ControlNet-Pose | 1.70 | 2.78 | 4.05 | 1.98 |
102 | | **Kolors-ControlNet-Pose** | **3.33** | **3.63** | **4.78** | **4.4** |
103 |
104 |
105 | *The [SDXL-ControlNet-Canny](https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0) and [SDXL-ControlNet-Depth](https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0) results use [DreamShaper-XL](https://civitai.com/models/112902?modelVersionId=351306) as the backbone model.*
106 |
107 |
108 |
109 |
110 | **Compare Result**
111 |
112 | | Condition Image | Prompt | Kolors-ControlNet Result | SDXL-ControlNet Result |
113 | | :---: | :---: | :---: | :---: |
114 | | *(image)* | 一个漂亮的女孩,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。 A beautiful girl, high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K. | *(image)* | *(image)* |
115 | | *(image)* | 一只颜色鲜艳的小鸟,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。 A colorful bird, high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K. | *(image)* | *(image)* |
116 | | *(image)* | 一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩双手托脸,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。 A girl wearing a purple puff-sleeve dress, with a crown and white lace gloves, is cupping her face with both hands. High quality, ultra-clear, vibrant colors, ultra-high resolution, best quality, 8k, HD, 4K. | *(image)* | *(image)* |
117 |
147 | ------
148 |
149 |
150 | ## 🛠️ Usage
151 |
152 | ### Requirements
153 |
154 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors).
155 |
156 |
157 |
158 |
159 | ### Weights download
160 | ```bash
161 | # Canny - ControlNet
162 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Canny --local-dir weights/Kolors-ControlNet-Canny
163 |
164 | # Depth - ControlNet
165 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Depth --local-dir weights/Kolors-ControlNet-Depth
166 |
167 | # Pose - ControlNet
168 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Pose --local-dir weights/Kolors-ControlNet-Pose
169 | ```
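
If you prefer to fetch the weights from Python instead of the CLI, the same downloads can be done with `huggingface_hub`. A minimal sketch (it assumes `huggingface_hub` is available in your environment and simply mirrors the `--local-dir` layout used by the commands above):

```python
from huggingface_hub import snapshot_download

# Download each Kolors ControlNet variant into weights/<repo-name>,
# matching the layout produced by the huggingface-cli commands above.
for repo_id in (
    "Kwai-Kolors/Kolors-ControlNet-Canny",
    "Kwai-Kolors/Kolors-ControlNet-Depth",
    "Kwai-Kolors/Kolors-ControlNet-Pose",
):
    snapshot_download(repo_id=repo_id, local_dir=f"weights/{repo_id.split('/')[-1]}")
```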
170 |
171 | If you intend to utilize the depth estimation network, please make sure to download its corresponding model weights.
172 | ```bash
173 | huggingface-cli download lllyasviel/Annotators ./dpt_hybrid-midas-501f0c75.pt --local-dir ./controlnet/annotator/ckpts
174 | ```
175 |
176 | Thanks to [DWPose](https://github.com/IDEA-Research/DWPose/tree/onnx?tab=readme-ov-file), you can utilize the pose estimation network. Please download the Pose model dw-ll_ucoco_384.onnx ([baidu](https://pan.baidu.com/s/1nuBjw-KKSxD_BkpmwXUJiw?pwd=28d7), [google](https://drive.google.com/file/d/12L8E2oAgZy4VACGSK9RaZBZrfgx7VTA2/view?usp=sharing)) and Det model yolox_l.onnx ([baidu](https://pan.baidu.com/s/1fpfIVpv5ypo4c1bUlzkMYQ?pwd=mjdn), [google](https://drive.google.com/file/d/1w9pXC8tT0p9ndMN-CArp1__b2GbzewWI/view?usp=sharing)). Then please put them into `controlnet/annotator/ckpts/`.
177 |
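Once the annotator checkpoints are in place, the bundled annotators can also be called directly to produce condition images. A minimal sketch (assumptions: it is run from the `controlnet/` directory so the relative `annotator/ckpts/` paths resolve, a CUDA device is available for the depth model, and the file names and Canny thresholds below are illustrative):

```python
import cv2
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector

img = cv2.imread("assets/dog.png")  # BGR uint8 image, as loaded by OpenCV

# Canny edges need no extra weights; the thresholds below are illustrative.
canny_condition = CannyDetector()(img, low_threshold=100, high_threshold=200)
cv2.imwrite("outputs/dog_canny_condition.png", canny_condition)

# Depth uses the dpt_hybrid checkpoint referenced above and runs on the GPU.
depth_condition = MidasDetector()(img)
cv2.imwrite("outputs/dog_depth_condition.png", depth_condition)
```

The pose condition can be produced the same way with `annotator.dwpose.DWposeDetector`, which returns the raw keypoints together with the rendered skeleton image.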
178 |
179 | ### Inference
180 |
181 |
182 | **a. Using canny ControlNet:**
183 |
184 | ```bash
185 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_1.png 一个漂亮的女孩,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Canny
186 |
187 | python ./controlnet/sample_controlNet.py ./controlnet/assets/dog.png 全景,一只可爱的白色小狗坐在杯子里,看向镜头,动漫风格,3d渲染,辛烷值渲染 Canny
188 |
189 | # The image will be saved to "controlnet/outputs/"
190 | ```
191 |
192 | **b. Using depth ControlNet:**
193 |
194 | ```bash
195 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_2.png 新海诚风格,丰富的色彩,穿着绿色衬衫的女人站在田野里,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质 Depth
196 |
197 | python ./controlnet/sample_controlNet.py ./controlnet/assets/bird.png 一只颜色鲜艳的小鸟,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Depth
198 |
199 | # The image will be saved to "controlnet/outputs/"
200 | ```
201 |
202 | **c. Using pose ControlNet:**
203 |
204 | ```bash
205 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_3.png 一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩双手托脸,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Pose
206 |
207 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_4.png 一个穿着黑色运动外套、白色内搭,上面戴着项链的女子,站在街边,背景是红色建筑和绿树,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Pose
208 |
209 | # The image will be saved to "controlnet/outputs/"
210 | ```
211 |
212 |
213 | **d. Using depth ControlNet + IP-Adapter-Plus:**
214 |
215 | If you intend to utilize Kolors-IP-Adapter-Plus, please make sure to download its corresponding model weights first.
216 |
217 | ```bash
218 | python ./controlnet/sample_controlNet_ipadapter.py ./controlnet/assets/woman_2.png ./ipadapter/asset/2.png 一个红色头发的女孩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质 Depth
219 |
220 | python ./controlnet/sample_controlNet_ipadapter.py ./ipadapter/asset/1.png ./controlnet/assets/woman_1.png 一个漂亮的女孩,最好的质量,超细节,8K画质 Depth
221 |
222 | # The image will be saved to "controlnet/outputs/"
223 | ```
224 |
225 |
226 |
227 |
228 | ### Acknowledgments
229 | - Thanks to [ControlNet](https://github.com/lllyasviel/ControlNet) for providing the codebase.
230 |
231 |
232 |
233 |
234 |
--------------------------------------------------------------------------------
/controlnet/annotator/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/annotator/__init__.py
--------------------------------------------------------------------------------
/controlnet/annotator/canny/__init__.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 |
4 | class CannyDetector:
5 | def __call__(self, img, low_threshold, high_threshold):
6 | return cv2.Canny(img, low_threshold, high_threshold)
7 |
--------------------------------------------------------------------------------
/controlnet/annotator/dwpose/__init__.py:
--------------------------------------------------------------------------------
1 | # Openpose
2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose
3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose
4 | # 3rd Edited by ControlNet
5 | # 4th Edited by ControlNet (added face and correct hands)
6 |
7 | import os
8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
9 |
10 | import torch
11 | import numpy as np
12 | from . import util
13 | from .wholebody import Wholebody
14 |
15 | def draw_pose(pose, H, W):
16 | bodies = pose['bodies']
17 | faces = pose['faces']
18 | hands = pose['hands']
19 | candidate = bodies['candidate']
20 | subset = bodies['subset']
21 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8)
22 |
23 | canvas = util.draw_bodypose(canvas, candidate, subset)
24 |
25 | canvas = util.draw_handpose(canvas, hands)
26 |
27 | # canvas = util.draw_facepose(canvas, faces)
28 |
29 | return canvas
30 |
31 |
32 | class DWposeDetector:
33 | def __init__(self):
34 |
35 | self.pose_estimation = Wholebody()
36 |
37 |
38 | def getres(self, oriImg):
39 | out_res = {}
40 | oriImg = oriImg.copy()
41 | H, W, C = oriImg.shape
42 | with torch.no_grad():
43 | candidate, subset = self.pose_estimation(oriImg)
44 | out_res['candidate']=candidate
45 | out_res['subset']=subset
46 | out_res['width']=W
47 | out_res['height']=H
48 | return out_res
49 |
50 | def __call__(self, oriImg):
51 |
52 | oriImg = oriImg.copy()
53 | H, W, C = oriImg.shape
54 | with torch.no_grad():
55 | _candidate, _subset = self.pose_estimation(oriImg)
56 |
57 | subset = _subset.copy()
58 | candidate = _candidate.copy()
59 | nums, keys, locs = candidate.shape
60 | candidate[..., 0] /= float(W)
61 | candidate[..., 1] /= float(H)
62 | body = candidate[:,:18].copy()
63 | body = body.reshape(nums*18, locs)
64 | score = subset[:,:18]
65 | for i in range(len(score)):
66 | for j in range(len(score[i])):
67 | if score[i][j] > 0.3:
68 | score[i][j] = int(18*i+j)
69 | else:
70 | score[i][j] = -1
71 |
72 | un_visible = subset<0.3
73 | candidate[un_visible] = -1
74 |
75 | foot = candidate[:,18:24]
76 |
77 | faces = candidate[:,24:92]
78 |
79 | hands = candidate[:,92:113]
80 | hands = np.vstack([hands, candidate[:,113:]])
81 |
82 | bodies = dict(candidate=body, subset=score)
83 | pose = dict(bodies=bodies, hands=hands, faces=faces)
84 |
85 | out_res = {}
86 | out_res['candidate']=candidate
87 | out_res['subset']=subset
88 | out_res['width']=W
89 | out_res['height']=H
90 |
91 | return out_res,draw_pose(pose, H, W)
92 |
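# Usage sketch: a minimal example of calling DWposeDetector, assuming yolox_l.onnx and
# dw-ll_ucoco_384.onnx are placed under annotator/ckpts/ as described in controlnet/README.md,
# this is run from the controlnet/ directory with a CUDA-enabled onnxruntime, and the
# input/output file names below are illustrative:
#
#   import cv2
#   from annotator.dwpose import DWposeDetector
#
#   detector = DWposeDetector()                      # builds the YOLOX + DWPose ONNX sessions
#   keypoint_info, pose_canvas = detector(cv2.imread("assets/woman_3.png"))
#   cv2.imwrite("outputs/woman_3_pose_condition.png", pose_canvas)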
--------------------------------------------------------------------------------
/controlnet/annotator/dwpose/onnxdet.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import onnxruntime
5 |
6 | def nms(boxes, scores, nms_thr):
7 | """Single class NMS implemented in Numpy."""
8 | x1 = boxes[:, 0]
9 | y1 = boxes[:, 1]
10 | x2 = boxes[:, 2]
11 | y2 = boxes[:, 3]
12 |
13 | areas = (x2 - x1 + 1) * (y2 - y1 + 1)
14 | order = scores.argsort()[::-1]
15 |
16 | keep = []
17 | while order.size > 0:
18 | i = order[0]
19 | keep.append(i)
20 | xx1 = np.maximum(x1[i], x1[order[1:]])
21 | yy1 = np.maximum(y1[i], y1[order[1:]])
22 | xx2 = np.minimum(x2[i], x2[order[1:]])
23 | yy2 = np.minimum(y2[i], y2[order[1:]])
24 |
25 | w = np.maximum(0.0, xx2 - xx1 + 1)
26 | h = np.maximum(0.0, yy2 - yy1 + 1)
27 | inter = w * h
28 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
29 |
30 | inds = np.where(ovr <= nms_thr)[0]
31 | order = order[inds + 1]
32 |
33 | return keep
34 |
35 | def multiclass_nms(boxes, scores, nms_thr, score_thr):
36 | """Multiclass NMS implemented in Numpy. Class-aware version."""
37 | final_dets = []
38 | num_classes = scores.shape[1]
39 | for cls_ind in range(num_classes):
40 | cls_scores = scores[:, cls_ind]
41 | valid_score_mask = cls_scores > score_thr
42 | if valid_score_mask.sum() == 0:
43 | continue
44 | else:
45 | valid_scores = cls_scores[valid_score_mask]
46 | valid_boxes = boxes[valid_score_mask]
47 | keep = nms(valid_boxes, valid_scores, nms_thr)
48 | if len(keep) > 0:
49 | cls_inds = np.ones((len(keep), 1)) * cls_ind
50 | dets = np.concatenate(
51 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1
52 | )
53 | final_dets.append(dets)
54 | if len(final_dets) == 0:
55 | return None
56 | return np.concatenate(final_dets, 0)
57 |
58 | def demo_postprocess(outputs, img_size, p6=False):
59 | grids = []
60 | expanded_strides = []
61 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64]
62 |
63 | hsizes = [img_size[0] // stride for stride in strides]
64 | wsizes = [img_size[1] // stride for stride in strides]
65 |
66 | for hsize, wsize, stride in zip(hsizes, wsizes, strides):
67 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize))
68 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2)
69 | grids.append(grid)
70 | shape = grid.shape[:2]
71 | expanded_strides.append(np.full((*shape, 1), stride))
72 |
73 | grids = np.concatenate(grids, 1)
74 | expanded_strides = np.concatenate(expanded_strides, 1)
75 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides
76 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides
77 |
78 | return outputs
79 |
80 | def preprocess(img, input_size, swap=(2, 0, 1)):
81 | if len(img.shape) == 3:
82 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114
83 | else:
84 | padded_img = np.ones(input_size, dtype=np.uint8) * 114
85 |
86 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1])
87 | resized_img = cv2.resize(
88 | img,
89 | (int(img.shape[1] * r), int(img.shape[0] * r)),
90 | interpolation=cv2.INTER_LINEAR,
91 | ).astype(np.uint8)
92 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img
93 |
94 | padded_img = padded_img.transpose(swap)
95 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
96 | return padded_img, r
97 |
98 | def inference_detector(session, oriImg):
99 | input_shape = (640,640)
100 | img, ratio = preprocess(oriImg, input_shape)
101 |
102 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]}
103 | output = session.run(None, ort_inputs)
104 | predictions = demo_postprocess(output[0], input_shape)[0]
105 |
106 | boxes = predictions[:, :4]
107 | scores = predictions[:, 4:5] * predictions[:, 5:]
108 |
109 | boxes_xyxy = np.ones_like(boxes)
110 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2.
111 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2.
112 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2.
113 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2.
114 | boxes_xyxy /= ratio
115 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1)
116 | if dets is not None:
117 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5]
118 | isscore = final_scores>0.3
119 | iscat = final_cls_inds == 0
120 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)]
121 | final_boxes = final_boxes[isbbox]
122 | else:
123 | final_boxes = np.array([])
124 |
125 | return final_boxes
126 |
--------------------------------------------------------------------------------
/controlnet/annotator/dwpose/util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import matplotlib
4 | import cv2
5 |
6 |
7 | eps = 0.01
8 |
9 |
10 | def smart_resize(x, s):
11 | Ht, Wt = s
12 | if x.ndim == 2:
13 | Ho, Wo = x.shape
14 | Co = 1
15 | else:
16 | Ho, Wo, Co = x.shape
17 | if Co == 3 or Co == 1:
18 | k = float(Ht + Wt) / float(Ho + Wo)
19 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
20 | else:
21 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2)
22 |
23 |
24 | def smart_resize_k(x, fx, fy):
25 | if x.ndim == 2:
26 | Ho, Wo = x.shape
27 | Co = 1
28 | else:
29 | Ho, Wo, Co = x.shape
30 | Ht, Wt = Ho * fy, Wo * fx
31 | if Co == 3 or Co == 1:
32 | k = float(Ht + Wt) / float(Ho + Wo)
33 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4)
34 | else:
35 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2)
36 |
37 |
38 | def padRightDownCorner(img, stride, padValue):
39 | h = img.shape[0]
40 | w = img.shape[1]
41 |
42 | pad = 4 * [None]
43 | pad[0] = 0 # up
44 | pad[1] = 0 # left
45 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down
46 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right
47 |
48 | img_padded = img
49 | pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1))
50 | img_padded = np.concatenate((pad_up, img_padded), axis=0)
51 | pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1))
52 | img_padded = np.concatenate((pad_left, img_padded), axis=1)
53 | pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1))
54 | img_padded = np.concatenate((img_padded, pad_down), axis=0)
55 | pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1))
56 | img_padded = np.concatenate((img_padded, pad_right), axis=1)
57 |
58 | return img_padded, pad
59 |
60 |
61 | def transfer(model, model_weights):
62 | transfered_model_weights = {}
63 | for weights_name in model.state_dict().keys():
64 | transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])]
65 | return transfered_model_weights
66 |
67 |
68 | def draw_bodypose(canvas, candidate, subset):
69 | H, W, C = canvas.shape
70 | candidate = np.array(candidate)
71 | subset = np.array(subset)
72 |
73 | stickwidth = 4
74 |
75 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \
76 | [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \
77 | [1, 16], [16, 18], [3, 17], [6, 18]]
78 |
79 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \
80 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \
81 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]
82 |
83 | for i in range(17):
84 | for n in range(len(subset)):
85 | index = subset[n][np.array(limbSeq[i]) - 1]
86 | if -1 in index:
87 | continue
88 | Y = candidate[index.astype(int), 0] * float(W)
89 | X = candidate[index.astype(int), 1] * float(H)
90 | mX = np.mean(X)
91 | mY = np.mean(Y)
92 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5
93 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1]))
94 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1)
95 | cv2.fillConvexPoly(canvas, polygon, colors[i])
96 |
97 | canvas = (canvas * 0.6).astype(np.uint8)
98 |
99 | for i in range(18):
100 | for n in range(len(subset)):
101 | index = int(subset[n][i])
102 | if index == -1:
103 | continue
104 | x, y = candidate[index][0:2]
105 | x = int(x * W)
106 | y = int(y * H)
107 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1)
108 |
109 | return canvas
110 |
111 |
112 | def draw_handpose(canvas, all_hand_peaks):
113 | H, W, C = canvas.shape
114 |
115 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \
116 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]]
117 |
118 | for peaks in all_hand_peaks:
119 | peaks = np.array(peaks)
120 |
121 | for ie, e in enumerate(edges):
122 | x1, y1 = peaks[e[0]]
123 | x2, y2 = peaks[e[1]]
124 | x1 = int(x1 * W)
125 | y1 = int(y1 * H)
126 | x2 = int(x2 * W)
127 | y2 = int(y2 * H)
128 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps:
129 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2)
130 |
131 | for i, keypoint in enumerate(peaks):
132 | x, y = keypoint
133 | x = int(x * W)
134 | y = int(y * H)
135 | if x > eps and y > eps:
136 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1)
137 | return canvas
138 |
139 |
140 | def draw_facepose(canvas, all_lmks):
141 | H, W, C = canvas.shape
142 | for lmks in all_lmks:
143 | lmks = np.array(lmks)
144 | for lmk in lmks:
145 | x, y = lmk
146 | x = int(x * W)
147 | y = int(y * H)
148 | if x > eps and y > eps:
149 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1)
150 | return canvas
151 |
152 |
153 | # detect hand according to body pose keypoints
154 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp
155 | def handDetect(candidate, subset, oriImg):
156 | # right hand: wrist 4, elbow 3, shoulder 2
157 | # left hand: wrist 7, elbow 6, shoulder 5
158 | ratioWristElbow = 0.33
159 | detect_result = []
160 | image_height, image_width = oriImg.shape[0:2]
161 | for person in subset.astype(int):
162 | # if any of three not detected
163 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0
164 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0
165 | if not (has_left or has_right):
166 | continue
167 | hands = []
168 | #left hand
169 | if has_left:
170 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]]
171 | x1, y1 = candidate[left_shoulder_index][:2]
172 | x2, y2 = candidate[left_elbow_index][:2]
173 | x3, y3 = candidate[left_wrist_index][:2]
174 | hands.append([x1, y1, x2, y2, x3, y3, True])
175 | # right hand
176 | if has_right:
177 | right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]]
178 | x1, y1 = candidate[right_shoulder_index][:2]
179 | x2, y2 = candidate[right_elbow_index][:2]
180 | x3, y3 = candidate[right_wrist_index][:2]
181 | hands.append([x1, y1, x2, y2, x3, y3, False])
182 |
183 | for x1, y1, x2, y2, x3, y3, is_left in hands:
184 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox
185 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]);
186 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]);
187 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow);
188 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder);
189 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder);
190 | x = x3 + ratioWristElbow * (x3 - x2)
191 | y = y3 + ratioWristElbow * (y3 - y2)
192 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2)
193 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2)
194 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder)
195 | # x-y refers to the center --> offset to topLeft point
196 | # handRectangle.x -= handRectangle.width / 2.f;
197 | # handRectangle.y -= handRectangle.height / 2.f;
198 | x -= width / 2
199 | y -= width / 2 # width = height
200 | # overflow the image
201 | if x < 0: x = 0
202 | if y < 0: y = 0
203 | width1 = width
204 | width2 = width
205 | if x + width > image_width: width1 = image_width - x
206 | if y + width > image_height: width2 = image_height - y
207 | width = min(width1, width2)
208 | # keep only hand boxes that are at least 20 pixels wide
209 | if width >= 20:
210 | detect_result.append([int(x), int(y), int(width), is_left])
211 |
212 | '''
213 | return value: [[x, y, w, True if left hand else False]].
214 | width=height since the network require squared input.
215 | x, y is the coordinate of top left
216 | '''
217 | return detect_result
218 |
219 |
220 | # Written by Lvmin
221 | def faceDetect(candidate, subset, oriImg):
222 | # left right eye ear 14 15 16 17
223 | detect_result = []
224 | image_height, image_width = oriImg.shape[0:2]
225 | for person in subset.astype(int):
226 | has_head = person[0] > -1
227 | if not has_head:
228 | continue
229 |
230 | has_left_eye = person[14] > -1
231 | has_right_eye = person[15] > -1
232 | has_left_ear = person[16] > -1
233 | has_right_ear = person[17] > -1
234 |
235 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear):
236 | continue
237 |
238 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]]
239 |
240 | width = 0.0
241 | x0, y0 = candidate[head][:2]
242 |
243 | if has_left_eye:
244 | x1, y1 = candidate[left_eye][:2]
245 | d = max(abs(x0 - x1), abs(y0 - y1))
246 | width = max(width, d * 3.0)
247 |
248 | if has_right_eye:
249 | x1, y1 = candidate[right_eye][:2]
250 | d = max(abs(x0 - x1), abs(y0 - y1))
251 | width = max(width, d * 3.0)
252 |
253 | if has_left_ear:
254 | x1, y1 = candidate[left_ear][:2]
255 | d = max(abs(x0 - x1), abs(y0 - y1))
256 | width = max(width, d * 1.5)
257 |
258 | if has_right_ear:
259 | x1, y1 = candidate[right_ear][:2]
260 | d = max(abs(x0 - x1), abs(y0 - y1))
261 | width = max(width, d * 1.5)
262 |
263 | x, y = x0, y0
264 |
265 | x -= width
266 | y -= width
267 |
268 | if x < 0:
269 | x = 0
270 |
271 | if y < 0:
272 | y = 0
273 |
274 | width1 = width * 2
275 | width2 = width * 2
276 |
277 | if x + width > image_width:
278 | width1 = image_width - x
279 |
280 | if y + width > image_height:
281 | width2 = image_height - y
282 |
283 | width = min(width1, width2)
284 |
285 | if width >= 20:
286 | detect_result.append([int(x), int(y), int(width)])
287 |
288 | return detect_result
289 |
290 |
291 | # get max index of 2d array
292 | def npmax(array):
293 | arrayindex = array.argmax(1)
294 | arrayvalue = array.max(1)
295 | i = arrayvalue.argmax()
296 | j = arrayindex[i]
297 | return i, j
298 |
--------------------------------------------------------------------------------
/controlnet/annotator/dwpose/wholebody.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | import onnxruntime as ort
5 | from .onnxdet import inference_detector
6 | from .onnxpose import inference_pose
7 |
8 | class Wholebody:
9 | def __init__(self):
10 | device = 'cuda:0'
11 | providers = ['CPUExecutionProvider'
12 | ] if device == 'cpu' else ['CUDAExecutionProvider']
13 | # providers = ['CPUExecutionProvider']
14 | providers = ['CUDAExecutionProvider']
15 | onnx_det = 'annotator/ckpts/yolox_l.onnx'
16 | onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx'
17 |
18 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers)
19 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers)
20 | def __call__(self, oriImg):
21 | det_result = inference_detector(self.session_det, oriImg)
22 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg)
23 |
24 | keypoints_info = np.concatenate(
25 | (keypoints, scores[..., None]), axis=-1)
26 | # compute neck joint
27 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1)
28 | # neck score when visualizing pred
29 | neck[:, 2:4] = np.logical_and(
30 | keypoints_info[:, 5, 2:4] > 0.3,
31 | keypoints_info[:, 6, 2:4] > 0.3).astype(int)
32 | new_keypoints_info = np.insert(
33 | keypoints_info, 17, neck, axis=1)
34 | mmpose_idx = [
35 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3
36 | ]
37 | openpose_idx = [
38 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17
39 | ]
40 | new_keypoints_info[:, openpose_idx] = \
41 | new_keypoints_info[:, mmpose_idx]
42 | keypoints_info = new_keypoints_info
43 |
44 | keypoints, scores = keypoints_info[
45 | ..., :2], keypoints_info[..., 2]
46 |
47 | return keypoints, scores
48 |
49 |
50 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/__init__.py:
--------------------------------------------------------------------------------
1 | # Midas Depth Estimation
2 | # From https://github.com/isl-org/MiDaS
3 | # MIT LICENSE
4 |
5 | import cv2
6 | import numpy as np
7 | import torch
8 |
9 | from einops import rearrange
10 | from .api import MiDaSInference
11 |
12 |
13 | class MidasDetector:
14 | def __init__(self):
15 | self.model = MiDaSInference(model_type="dpt_hybrid").cuda()
16 | self.rng = np.random.RandomState(0)
17 |
18 | def __call__(self, input_image):
19 | assert input_image.ndim == 3
20 | image_depth = input_image
21 | with torch.no_grad():
22 | image_depth = torch.from_numpy(image_depth).float().cuda()
23 | image_depth = image_depth / 127.5 - 1.0
24 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w')
25 | depth = self.model(image_depth)[0]
26 |
27 | depth -= torch.min(depth)
28 | depth /= torch.max(depth)
29 | depth = depth.cpu().numpy()
30 | depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8)
31 |
32 | return depth_image
33 |
34 |
35 |
36 |
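A hypothetical usage sketch (assumes a CUDA device and that the dpt_hybrid checkpoint is available under annotator/ckpts, which api.py downloads on first use; the image path is illustrative):

    import cv2
    from annotator.midas import MidasDetector

    midas = MidasDetector()
    img_rgb = cv2.cvtColor(cv2.imread('assets/bird.png'), cv2.COLOR_BGR2RGB)
    depth_map = midas(img_rgb)  # uint8 HxW map; larger values roughly correspond to closer surfaces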
--------------------------------------------------------------------------------
/controlnet/annotator/midas/api.py:
--------------------------------------------------------------------------------
1 | # based on https://github.com/isl-org/MiDaS
2 |
3 | import cv2
4 | import os
5 | import torch
6 | import torch.nn as nn
7 | from torchvision.transforms import Compose
8 |
9 | from .midas.dpt_depth import DPTDepthModel
10 | from .midas.midas_net import MidasNet
11 | from .midas.midas_net_custom import MidasNet_small
12 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet
13 | from annotator.util import annotator_ckpts_path
14 |
15 |
16 | ISL_PATHS = {
17 | "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"),
18 | "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"),
19 | "midas_v21": "",
20 | "midas_v21_small": "",
21 | }
22 |
23 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt"
24 |
25 |
26 | def disabled_train(self, mode=True):
27 | """Overwrite model.train with this function to make sure train/eval mode
28 | does not change anymore."""
29 | return self
30 |
31 |
32 | def load_midas_transform(model_type):
33 | # https://github.com/isl-org/MiDaS/blob/master/run.py
34 | # load transform only
35 | if model_type == "dpt_large": # DPT-Large
36 | net_w, net_h = 384, 384
37 | resize_mode = "minimal"
38 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
39 |
40 | elif model_type == "dpt_hybrid": # DPT-Hybrid
41 | net_w, net_h = 384, 384
42 | resize_mode = "minimal"
43 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
44 |
45 | elif model_type == "midas_v21":
46 | net_w, net_h = 384, 384
47 | resize_mode = "upper_bound"
48 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
49 |
50 | elif model_type == "midas_v21_small":
51 | net_w, net_h = 256, 256
52 | resize_mode = "upper_bound"
53 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
54 |
55 | else:
56 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large"
57 |
58 | transform = Compose(
59 | [
60 | Resize(
61 | net_w,
62 | net_h,
63 | resize_target=None,
64 | keep_aspect_ratio=True,
65 | ensure_multiple_of=32,
66 | resize_method=resize_mode,
67 | image_interpolation_method=cv2.INTER_CUBIC,
68 | ),
69 | normalization,
70 | PrepareForNet(),
71 | ]
72 | )
73 |
74 | return transform
75 |
76 |
77 | def load_model(model_type):
78 | # https://github.com/isl-org/MiDaS/blob/master/run.py
79 | # load network
80 | model_path = ISL_PATHS[model_type]
81 | if model_type == "dpt_large": # DPT-Large
82 | model = DPTDepthModel(
83 | path=model_path,
84 | backbone="vitl16_384",
85 | non_negative=True,
86 | )
87 | net_w, net_h = 384, 384
88 | resize_mode = "minimal"
89 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
90 |
91 | elif model_type == "dpt_hybrid": # DPT-Hybrid
92 | if not os.path.exists(model_path):
93 | from basicsr.utils.download_util import load_file_from_url
94 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path)
95 |
96 | model = DPTDepthModel(
97 | path=model_path,
98 | backbone="vitb_rn50_384",
99 | non_negative=True,
100 | )
101 | net_w, net_h = 384, 384
102 | resize_mode = "minimal"
103 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
104 |
105 | elif model_type == "midas_v21":
106 | model = MidasNet(model_path, non_negative=True)
107 | net_w, net_h = 384, 384
108 | resize_mode = "upper_bound"
109 | normalization = NormalizeImage(
110 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
111 | )
112 |
113 | elif model_type == "midas_v21_small":
114 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
115 | non_negative=True, blocks={'expand': True})
116 | net_w, net_h = 256, 256
117 | resize_mode = "upper_bound"
118 | normalization = NormalizeImage(
119 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
120 | )
121 |
122 | else:
123 | print(f"model_type '{model_type}' not implemented, use: --model_type large")
124 | assert False
125 |
126 | transform = Compose(
127 | [
128 | Resize(
129 | net_w,
130 | net_h,
131 | resize_target=None,
132 | keep_aspect_ratio=True,
133 | ensure_multiple_of=32,
134 | resize_method=resize_mode,
135 | image_interpolation_method=cv2.INTER_CUBIC,
136 | ),
137 | normalization,
138 | PrepareForNet(),
139 | ]
140 | )
141 |
142 | return model.eval(), transform
143 |
144 |
145 | class MiDaSInference(nn.Module):
146 | MODEL_TYPES_TORCH_HUB = [
147 | "DPT_Large",
148 | "DPT_Hybrid",
149 | "MiDaS_small"
150 | ]
151 | MODEL_TYPES_ISL = [
152 | "dpt_large",
153 | "dpt_hybrid",
154 | "midas_v21",
155 | "midas_v21_small",
156 | ]
157 |
158 | def __init__(self, model_type):
159 | super().__init__()
160 | assert (model_type in self.MODEL_TYPES_ISL)
161 | model, _ = load_model(model_type)
162 | self.model = model
163 | self.model.train = disabled_train
164 |
165 | def forward(self, x):
166 | with torch.no_grad():
167 | prediction = self.model(x)
168 | return prediction
169 |
170 |
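A sketch of driving MiDaSInference directly, without the MidasDetector wrapper (assumes the dpt_hybrid checkpoint is already downloaded and a CUDA device is available; the random tensor stands in for a preprocessed image batch):

    import torch

    model = MiDaSInference(model_type="dpt_hybrid").cuda()
    x = torch.randn(1, 3, 384, 384, device="cuda")  # stand-in for a normalized image batch
    depth = model(x)  # inverse-depth map with the same spatial size as the input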
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/annotator/midas/midas/__init__.py
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/base_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | class BaseModel(torch.nn.Module):
5 | def load(self, path):
6 | """Load model from file.
7 |
8 | Args:
9 | path (str): file path
10 | """
11 | parameters = torch.load(path, map_location=torch.device('cpu'))
12 |
13 | if "optimizer" in parameters:
14 | parameters = parameters["model"]
15 |
16 | self.load_state_dict(parameters)
17 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/blocks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from .vit import (
5 | _make_pretrained_vitb_rn50_384,
6 | _make_pretrained_vitl16_384,
7 | _make_pretrained_vitb16_384,
8 | forward_vit,
9 | )
10 |
11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",):
12 | if backbone == "vitl16_384":
13 | pretrained = _make_pretrained_vitl16_384(
14 | use_pretrained, hooks=hooks, use_readout=use_readout
15 | )
16 | scratch = _make_scratch(
17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand
18 | ) # ViT-L/16 - 85.0% Top1 (backbone)
19 | elif backbone == "vitb_rn50_384":
20 | pretrained = _make_pretrained_vitb_rn50_384(
21 | use_pretrained,
22 | hooks=hooks,
23 | use_vit_only=use_vit_only,
24 | use_readout=use_readout,
25 | )
26 | scratch = _make_scratch(
27 | [256, 512, 768, 768], features, groups=groups, expand=expand
28 |         ) # ViT-B/16 + ResNet-50 hybrid (backbone)
29 | elif backbone == "vitb16_384":
30 | pretrained = _make_pretrained_vitb16_384(
31 | use_pretrained, hooks=hooks, use_readout=use_readout
32 | )
33 | scratch = _make_scratch(
34 | [96, 192, 384, 768], features, groups=groups, expand=expand
35 | ) # ViT-B/16 - 84.6% Top1 (backbone)
36 | elif backbone == "resnext101_wsl":
37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
38 |         scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
39 | elif backbone == "efficientnet_lite3":
40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
42 | else:
43 | print(f"Backbone '{backbone}' not implemented")
44 | assert False
45 |
46 | return pretrained, scratch
47 |
48 |
49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False):
50 | scratch = nn.Module()
51 |
52 | out_shape1 = out_shape
53 | out_shape2 = out_shape
54 | out_shape3 = out_shape
55 | out_shape4 = out_shape
56 | if expand==True:
57 | out_shape1 = out_shape
58 | out_shape2 = out_shape*2
59 | out_shape3 = out_shape*4
60 | out_shape4 = out_shape*8
61 |
62 | scratch.layer1_rn = nn.Conv2d(
63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
64 | )
65 | scratch.layer2_rn = nn.Conv2d(
66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
67 | )
68 | scratch.layer3_rn = nn.Conv2d(
69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
70 | )
71 | scratch.layer4_rn = nn.Conv2d(
72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
73 | )
74 |
75 | return scratch
76 |
77 |
78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
79 | efficientnet = torch.hub.load(
80 | "rwightman/gen-efficientnet-pytorch",
81 | "tf_efficientnet_lite3",
82 | pretrained=use_pretrained,
83 | exportable=exportable
84 | )
85 | return _make_efficientnet_backbone(efficientnet)
86 |
87 |
88 | def _make_efficientnet_backbone(effnet):
89 | pretrained = nn.Module()
90 |
91 | pretrained.layer1 = nn.Sequential(
92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
93 | )
94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
97 |
98 | return pretrained
99 |
100 |
101 | def _make_resnet_backbone(resnet):
102 | pretrained = nn.Module()
103 | pretrained.layer1 = nn.Sequential(
104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
105 | )
106 |
107 | pretrained.layer2 = resnet.layer2
108 | pretrained.layer3 = resnet.layer3
109 | pretrained.layer4 = resnet.layer4
110 |
111 | return pretrained
112 |
113 |
114 | def _make_pretrained_resnext101_wsl(use_pretrained):
115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
116 | return _make_resnet_backbone(resnet)
117 |
118 |
119 |
120 | class Interpolate(nn.Module):
121 | """Interpolation module.
122 | """
123 |
124 | def __init__(self, scale_factor, mode, align_corners=False):
125 | """Init.
126 |
127 | Args:
128 | scale_factor (float): scaling
129 | mode (str): interpolation mode
130 | """
131 | super(Interpolate, self).__init__()
132 |
133 | self.interp = nn.functional.interpolate
134 | self.scale_factor = scale_factor
135 | self.mode = mode
136 | self.align_corners = align_corners
137 |
138 | def forward(self, x):
139 | """Forward pass.
140 |
141 | Args:
142 | x (tensor): input
143 |
144 | Returns:
145 | tensor: interpolated data
146 | """
147 |
148 | x = self.interp(
149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
150 | )
151 |
152 | return x
153 |
154 |
155 | class ResidualConvUnit(nn.Module):
156 | """Residual convolution module.
157 | """
158 |
159 | def __init__(self, features):
160 | """Init.
161 |
162 | Args:
163 | features (int): number of features
164 | """
165 | super().__init__()
166 |
167 | self.conv1 = nn.Conv2d(
168 | features, features, kernel_size=3, stride=1, padding=1, bias=True
169 | )
170 |
171 | self.conv2 = nn.Conv2d(
172 | features, features, kernel_size=3, stride=1, padding=1, bias=True
173 | )
174 |
175 | self.relu = nn.ReLU(inplace=True)
176 |
177 | def forward(self, x):
178 | """Forward pass.
179 |
180 | Args:
181 | x (tensor): input
182 |
183 | Returns:
184 | tensor: output
185 | """
186 | out = self.relu(x)
187 | out = self.conv1(out)
188 | out = self.relu(out)
189 | out = self.conv2(out)
190 |
191 | return out + x
192 |
193 |
194 | class FeatureFusionBlock(nn.Module):
195 | """Feature fusion block.
196 | """
197 |
198 | def __init__(self, features):
199 | """Init.
200 |
201 | Args:
202 | features (int): number of features
203 | """
204 | super(FeatureFusionBlock, self).__init__()
205 |
206 | self.resConfUnit1 = ResidualConvUnit(features)
207 | self.resConfUnit2 = ResidualConvUnit(features)
208 |
209 | def forward(self, *xs):
210 | """Forward pass.
211 |
212 | Returns:
213 | tensor: output
214 | """
215 | output = xs[0]
216 |
217 | if len(xs) == 2:
218 | output += self.resConfUnit1(xs[1])
219 |
220 | output = self.resConfUnit2(output)
221 |
222 | output = nn.functional.interpolate(
223 | output, scale_factor=2, mode="bilinear", align_corners=True
224 | )
225 |
226 | return output
227 |
228 |
229 |
230 |
231 | class ResidualConvUnit_custom(nn.Module):
232 | """Residual convolution module.
233 | """
234 |
235 | def __init__(self, features, activation, bn):
236 | """Init.
237 |
238 | Args:
239 | features (int): number of features
240 | """
241 | super().__init__()
242 |
243 | self.bn = bn
244 |
245 | self.groups=1
246 |
247 | self.conv1 = nn.Conv2d(
248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
249 | )
250 |
251 | self.conv2 = nn.Conv2d(
252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
253 | )
254 |
255 | if self.bn==True:
256 | self.bn1 = nn.BatchNorm2d(features)
257 | self.bn2 = nn.BatchNorm2d(features)
258 |
259 | self.activation = activation
260 |
261 | self.skip_add = nn.quantized.FloatFunctional()
262 |
263 | def forward(self, x):
264 | """Forward pass.
265 |
266 | Args:
267 | x (tensor): input
268 |
269 | Returns:
270 | tensor: output
271 | """
272 |
273 | out = self.activation(x)
274 | out = self.conv1(out)
275 | if self.bn==True:
276 | out = self.bn1(out)
277 |
278 | out = self.activation(out)
279 | out = self.conv2(out)
280 | if self.bn==True:
281 | out = self.bn2(out)
282 |
283 | if self.groups > 1:
284 | out = self.conv_merge(out)
285 |
286 | return self.skip_add.add(out, x)
287 |
288 | # return out + x
289 |
290 |
291 | class FeatureFusionBlock_custom(nn.Module):
292 | """Feature fusion block.
293 | """
294 |
295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True):
296 | """Init.
297 |
298 | Args:
299 | features (int): number of features
300 | """
301 | super(FeatureFusionBlock_custom, self).__init__()
302 |
303 | self.deconv = deconv
304 | self.align_corners = align_corners
305 |
306 | self.groups=1
307 |
308 | self.expand = expand
309 | out_features = features
310 | if self.expand==True:
311 | out_features = features//2
312 |
313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
314 |
315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
317 |
318 | self.skip_add = nn.quantized.FloatFunctional()
319 |
320 | def forward(self, *xs):
321 | """Forward pass.
322 |
323 | Returns:
324 | tensor: output
325 | """
326 | output = xs[0]
327 |
328 | if len(xs) == 2:
329 | res = self.resConfUnit1(xs[1])
330 | output = self.skip_add.add(output, res)
331 | # output += res
332 |
333 | output = self.resConfUnit2(output)
334 |
335 | output = nn.functional.interpolate(
336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners
337 | )
338 |
339 | output = self.out_conv(output)
340 |
341 | return output
342 |
343 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/dpt_depth.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .base_model import BaseModel
6 | from .blocks import (
7 | FeatureFusionBlock,
8 | FeatureFusionBlock_custom,
9 | Interpolate,
10 | _make_encoder,
11 | forward_vit,
12 | )
13 |
14 |
15 | def _make_fusion_block(features, use_bn):
16 | return FeatureFusionBlock_custom(
17 | features,
18 | nn.ReLU(False),
19 | deconv=False,
20 | bn=use_bn,
21 | expand=False,
22 | align_corners=True,
23 | )
24 |
25 |
26 | class DPT(BaseModel):
27 | def __init__(
28 | self,
29 | head,
30 | features=256,
31 | backbone="vitb_rn50_384",
32 | readout="project",
33 | channels_last=False,
34 | use_bn=False,
35 | ):
36 |
37 | super(DPT, self).__init__()
38 |
39 | self.channels_last = channels_last
40 |
41 | hooks = {
42 | "vitb_rn50_384": [0, 1, 8, 11],
43 | "vitb16_384": [2, 5, 8, 11],
44 | "vitl16_384": [5, 11, 17, 23],
45 | }
46 |
47 | # Instantiate backbone and reassemble blocks
48 | self.pretrained, self.scratch = _make_encoder(
49 | backbone,
50 | features,
51 |             False, # use_pretrained: set True to initialize the backbone with ImageNet weights
52 | groups=1,
53 | expand=False,
54 | exportable=False,
55 | hooks=hooks[backbone],
56 | use_readout=readout,
57 | )
58 |
59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
63 |
64 | self.scratch.output_conv = head
65 |
66 |
67 | def forward(self, x):
68 | if self.channels_last == True:
69 | x.contiguous(memory_format=torch.channels_last)
70 |
71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x)
72 |
73 | layer_1_rn = self.scratch.layer1_rn(layer_1)
74 | layer_2_rn = self.scratch.layer2_rn(layer_2)
75 | layer_3_rn = self.scratch.layer3_rn(layer_3)
76 | layer_4_rn = self.scratch.layer4_rn(layer_4)
77 |
78 | path_4 = self.scratch.refinenet4(layer_4_rn)
79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
82 |
83 | out = self.scratch.output_conv(path_1)
84 |
85 | return out
86 |
87 |
88 | class DPTDepthModel(DPT):
89 | def __init__(self, path=None, non_negative=True, **kwargs):
90 | features = kwargs["features"] if "features" in kwargs else 256
91 |
92 | head = nn.Sequential(
93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
96 | nn.ReLU(True),
97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
98 | nn.ReLU(True) if non_negative else nn.Identity(),
99 | nn.Identity(),
100 | )
101 |
102 | super().__init__(head, **kwargs)
103 |
104 | if path is not None:
105 | self.load(path)
106 |
107 | def forward(self, x):
108 | return super().forward(x).squeeze(dim=1)
109 |
110 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/midas_net.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=256, non_negative=True):
17 | """Init.
18 |
19 | Args:
20 | path (str, optional): Path to saved model. Defaults to None.
21 | features (int, optional): Number of features. Defaults to 256.
22 |             non_negative (bool, optional): If True, clamp the output to be non-negative. Defaults to True.
23 | """
24 | print("Loading weights: ", path)
25 |
26 | super(MidasNet, self).__init__()
27 |
28 | use_pretrained = False if path is None else True
29 |
30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31 |
32 | self.scratch.refinenet4 = FeatureFusionBlock(features)
33 | self.scratch.refinenet3 = FeatureFusionBlock(features)
34 | self.scratch.refinenet2 = FeatureFusionBlock(features)
35 | self.scratch.refinenet1 = FeatureFusionBlock(features)
36 |
37 | self.scratch.output_conv = nn.Sequential(
38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39 | Interpolate(scale_factor=2, mode="bilinear"),
40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41 | nn.ReLU(True),
42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43 | nn.ReLU(True) if non_negative else nn.Identity(),
44 | )
45 |
46 | if path:
47 | self.load(path)
48 |
49 | def forward(self, x):
50 | """Forward pass.
51 |
52 | Args:
53 | x (tensor): input data (image)
54 |
55 | Returns:
56 | tensor: depth
57 | """
58 |
59 | layer_1 = self.pretrained.layer1(x)
60 | layer_2 = self.pretrained.layer2(layer_1)
61 | layer_3 = self.pretrained.layer3(layer_2)
62 | layer_4 = self.pretrained.layer4(layer_3)
63 |
64 | layer_1_rn = self.scratch.layer1_rn(layer_1)
65 | layer_2_rn = self.scratch.layer2_rn(layer_2)
66 | layer_3_rn = self.scratch.layer3_rn(layer_3)
67 | layer_4_rn = self.scratch.layer4_rn(layer_4)
68 |
69 | path_4 = self.scratch.refinenet4(layer_4_rn)
70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73 |
74 | out = self.scratch.output_conv(path_1)
75 |
76 | return torch.squeeze(out, dim=1)
77 |
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/midas_net_custom.py:
--------------------------------------------------------------------------------
1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets.
2 | This file contains code that is adapted from
3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4 | """
5 | import torch
6 | import torch.nn as nn
7 |
8 | from .base_model import BaseModel
9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10 |
11 |
12 | class MidasNet_small(BaseModel):
13 | """Network for monocular depth estimation.
14 | """
15 |
16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17 | blocks={'expand': True}):
18 | """Init.
19 |
20 | Args:
21 | path (str, optional): Path to saved model. Defaults to None.
22 |             features (int, optional): Number of features. Defaults to 64.
23 |             backbone (str, optional): Backbone network for encoder. Defaults to efficientnet_lite3.
24 | """
25 | print("Loading weights: ", path)
26 |
27 | super(MidasNet_small, self).__init__()
28 |
29 | use_pretrained = False if path else True
30 |
31 | self.channels_last = channels_last
32 | self.blocks = blocks
33 | self.backbone = backbone
34 |
35 | self.groups = 1
36 |
37 | features1=features
38 | features2=features
39 | features3=features
40 | features4=features
41 | self.expand = False
42 | if "expand" in self.blocks and self.blocks['expand'] == True:
43 | self.expand = True
44 | features1=features
45 | features2=features*2
46 | features3=features*4
47 | features4=features*8
48 |
49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50 |
51 | self.scratch.activation = nn.ReLU(False)
52 |
53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57 |
58 |
59 | self.scratch.output_conv = nn.Sequential(
60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61 | Interpolate(scale_factor=2, mode="bilinear"),
62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63 | self.scratch.activation,
64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65 | nn.ReLU(True) if non_negative else nn.Identity(),
66 | nn.Identity(),
67 | )
68 |
69 | if path:
70 | self.load(path)
71 |
72 |
73 | def forward(self, x):
74 | """Forward pass.
75 |
76 | Args:
77 | x (tensor): input data (image)
78 |
79 | Returns:
80 | tensor: depth
81 | """
82 | if self.channels_last==True:
83 | print("self.channels_last = ", self.channels_last)
84 | x.contiguous(memory_format=torch.channels_last)
85 |
86 |
87 | layer_1 = self.pretrained.layer1(x)
88 | layer_2 = self.pretrained.layer2(layer_1)
89 | layer_3 = self.pretrained.layer3(layer_2)
90 | layer_4 = self.pretrained.layer4(layer_3)
91 |
92 | layer_1_rn = self.scratch.layer1_rn(layer_1)
93 | layer_2_rn = self.scratch.layer2_rn(layer_2)
94 | layer_3_rn = self.scratch.layer3_rn(layer_3)
95 | layer_4_rn = self.scratch.layer4_rn(layer_4)
96 |
97 |
98 | path_4 = self.scratch.refinenet4(layer_4_rn)
99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102 |
103 | out = self.scratch.output_conv(path_1)
104 |
105 | return torch.squeeze(out, dim=1)
106 |
107 |
108 |
109 | def fuse_model(m):
110 | prev_previous_type = nn.Identity()
111 | prev_previous_name = ''
112 | previous_type = nn.Identity()
113 | previous_name = ''
114 | for name, module in m.named_modules():
115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116 | # print("FUSED ", prev_previous_name, previous_name, name)
117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119 | # print("FUSED ", prev_previous_name, previous_name)
120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122 | # print("FUSED ", previous_name, name)
123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124 |
125 | prev_previous_type = previous_type
126 | prev_previous_name = previous_name
127 | previous_type = type(module)
128 | previous_name = name
--------------------------------------------------------------------------------
/controlnet/annotator/midas/midas/transforms.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import math
4 |
5 |
6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7 |     """Resize the sample to be at least the given size. Keeps aspect ratio.
8 |
9 | Args:
10 | sample (dict): sample
11 | size (tuple): image size
12 |
13 | Returns:
14 | tuple: new size
15 | """
16 | shape = list(sample["disparity"].shape)
17 |
18 | if shape[0] >= size[0] and shape[1] >= size[1]:
19 | return sample
20 |
21 | scale = [0, 0]
22 | scale[0] = size[0] / shape[0]
23 | scale[1] = size[1] / shape[1]
24 |
25 | scale = max(scale)
26 |
27 | shape[0] = math.ceil(scale * shape[0])
28 | shape[1] = math.ceil(scale * shape[1])
29 |
30 | # resize
31 | sample["image"] = cv2.resize(
32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33 | )
34 |
35 | sample["disparity"] = cv2.resize(
36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37 | )
38 | sample["mask"] = cv2.resize(
39 | sample["mask"].astype(np.float32),
40 | tuple(shape[::-1]),
41 | interpolation=cv2.INTER_NEAREST,
42 | )
43 | sample["mask"] = sample["mask"].astype(bool)
44 |
45 | return tuple(shape)
46 |
47 |
48 | class Resize(object):
49 | """Resize sample to given size (width, height).
50 | """
51 |
52 | def __init__(
53 | self,
54 | width,
55 | height,
56 | resize_target=True,
57 | keep_aspect_ratio=False,
58 | ensure_multiple_of=1,
59 | resize_method="lower_bound",
60 | image_interpolation_method=cv2.INTER_AREA,
61 | ):
62 | """Init.
63 |
64 | Args:
65 | width (int): desired output width
66 | height (int): desired output height
67 | resize_target (bool, optional):
68 | True: Resize the full sample (image, mask, target).
69 | False: Resize image only.
70 | Defaults to True.
71 | keep_aspect_ratio (bool, optional):
72 | True: Keep the aspect ratio of the input sample.
73 | Output sample might not have the given width and height, and
74 | resize behaviour depends on the parameter 'resize_method'.
75 | Defaults to False.
76 | ensure_multiple_of (int, optional):
77 | Output width and height is constrained to be multiple of this parameter.
78 | Defaults to 1.
79 | resize_method (str, optional):
80 | "lower_bound": Output will be at least as large as the given size.
81 |                 "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
82 |                 "minimal": Scale as little as possible. (Output size might be smaller than given size.)
83 | Defaults to "lower_bound".
84 | """
85 | self.__width = width
86 | self.__height = height
87 |
88 | self.__resize_target = resize_target
89 | self.__keep_aspect_ratio = keep_aspect_ratio
90 | self.__multiple_of = ensure_multiple_of
91 | self.__resize_method = resize_method
92 | self.__image_interpolation_method = image_interpolation_method
93 |
94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96 |
97 | if max_val is not None and y > max_val:
98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99 |
100 | if y < min_val:
101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102 |
103 | return y
104 |
105 | def get_size(self, width, height):
106 | # determine new height and width
107 | scale_height = self.__height / height
108 | scale_width = self.__width / width
109 |
110 | if self.__keep_aspect_ratio:
111 | if self.__resize_method == "lower_bound":
112 | # scale such that output size is lower bound
113 | if scale_width > scale_height:
114 | # fit width
115 | scale_height = scale_width
116 | else:
117 | # fit height
118 | scale_width = scale_height
119 | elif self.__resize_method == "upper_bound":
120 | # scale such that output size is upper bound
121 | if scale_width < scale_height:
122 | # fit width
123 | scale_height = scale_width
124 | else:
125 | # fit height
126 | scale_width = scale_height
127 | elif self.__resize_method == "minimal":
128 |                 # scale as little as possible
129 | if abs(1 - scale_width) < abs(1 - scale_height):
130 | # fit width
131 | scale_height = scale_width
132 | else:
133 | # fit height
134 | scale_width = scale_height
135 | else:
136 | raise ValueError(
137 | f"resize_method {self.__resize_method} not implemented"
138 | )
139 |
140 | if self.__resize_method == "lower_bound":
141 | new_height = self.constrain_to_multiple_of(
142 | scale_height * height, min_val=self.__height
143 | )
144 | new_width = self.constrain_to_multiple_of(
145 | scale_width * width, min_val=self.__width
146 | )
147 | elif self.__resize_method == "upper_bound":
148 | new_height = self.constrain_to_multiple_of(
149 | scale_height * height, max_val=self.__height
150 | )
151 | new_width = self.constrain_to_multiple_of(
152 | scale_width * width, max_val=self.__width
153 | )
154 | elif self.__resize_method == "minimal":
155 | new_height = self.constrain_to_multiple_of(scale_height * height)
156 | new_width = self.constrain_to_multiple_of(scale_width * width)
157 | else:
158 | raise ValueError(f"resize_method {self.__resize_method} not implemented")
159 |
160 | return (new_width, new_height)
161 |
162 | def __call__(self, sample):
163 | width, height = self.get_size(
164 | sample["image"].shape[1], sample["image"].shape[0]
165 | )
166 |
167 | # resize sample
168 | sample["image"] = cv2.resize(
169 | sample["image"],
170 | (width, height),
171 | interpolation=self.__image_interpolation_method,
172 | )
173 |
174 | if self.__resize_target:
175 | if "disparity" in sample:
176 | sample["disparity"] = cv2.resize(
177 | sample["disparity"],
178 | (width, height),
179 | interpolation=cv2.INTER_NEAREST,
180 | )
181 |
182 | if "depth" in sample:
183 | sample["depth"] = cv2.resize(
184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185 | )
186 |
187 | sample["mask"] = cv2.resize(
188 | sample["mask"].astype(np.float32),
189 | (width, height),
190 | interpolation=cv2.INTER_NEAREST,
191 | )
192 | sample["mask"] = sample["mask"].astype(bool)
193 |
194 | return sample
195 |
196 |
197 | class NormalizeImage(object):
198 |     """Normalize image by given mean and std.
199 | """
200 |
201 | def __init__(self, mean, std):
202 | self.__mean = mean
203 | self.__std = std
204 |
205 | def __call__(self, sample):
206 | sample["image"] = (sample["image"] - self.__mean) / self.__std
207 |
208 | return sample
209 |
210 |
211 | class PrepareForNet(object):
212 | """Prepare sample for usage as network input.
213 | """
214 |
215 | def __init__(self):
216 | pass
217 |
218 | def __call__(self, sample):
219 | image = np.transpose(sample["image"], (2, 0, 1))
220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221 |
222 | if "mask" in sample:
223 | sample["mask"] = sample["mask"].astype(np.float32)
224 | sample["mask"] = np.ascontiguousarray(sample["mask"])
225 |
226 | if "disparity" in sample:
227 | disparity = sample["disparity"].astype(np.float32)
228 | sample["disparity"] = np.ascontiguousarray(disparity)
229 |
230 | if "depth" in sample:
231 | depth = sample["depth"].astype(np.float32)
232 | sample["depth"] = np.ascontiguousarray(depth)
233 |
234 | return sample
235 |
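An illustrative sketch of composing these transforms by hand, mirroring the dpt_hybrid settings used in annotator/midas/api.py (the image path is illustrative):

    import cv2
    from torchvision.transforms import Compose

    transform = Compose([
        Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
               ensure_multiple_of=32, resize_method="minimal",
               image_interpolation_method=cv2.INTER_CUBIC),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        PrepareForNet(),
    ])

    img = cv2.cvtColor(cv2.imread("assets/dog.png"), cv2.COLOR_BGR2RGB) / 255.0
    net_input = transform({"image": img})["image"]  # float32 CHW array, ready for torch.from_numpy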
--------------------------------------------------------------------------------
/controlnet/annotator/midas/utils.py:
--------------------------------------------------------------------------------
1 | """Utils for monoDepth."""
2 | import sys
3 | import re
4 | import numpy as np
5 | import cv2
6 | import torch
7 |
8 |
9 | def read_pfm(path):
10 | """Read pfm file.
11 |
12 | Args:
13 | path (str): path to file
14 |
15 | Returns:
16 | tuple: (data, scale)
17 | """
18 | with open(path, "rb") as file:
19 |
20 | color = None
21 | width = None
22 | height = None
23 | scale = None
24 | endian = None
25 |
26 | header = file.readline().rstrip()
27 | if header.decode("ascii") == "PF":
28 | color = True
29 | elif header.decode("ascii") == "Pf":
30 | color = False
31 | else:
32 | raise Exception("Not a PFM file: " + path)
33 |
34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii"))
35 | if dim_match:
36 | width, height = list(map(int, dim_match.groups()))
37 | else:
38 | raise Exception("Malformed PFM header.")
39 |
40 | scale = float(file.readline().decode("ascii").rstrip())
41 | if scale < 0:
42 | # little-endian
43 | endian = "<"
44 | scale = -scale
45 | else:
46 | # big-endian
47 | endian = ">"
48 |
49 | data = np.fromfile(file, endian + "f")
50 | shape = (height, width, 3) if color else (height, width)
51 |
52 | data = np.reshape(data, shape)
53 | data = np.flipud(data)
54 |
55 | return data, scale
56 |
57 |
58 | def write_pfm(path, image, scale=1):
59 | """Write pfm file.
60 |
61 | Args:
62 |         path (str): path to file
63 | image (array): data
64 | scale (int, optional): Scale. Defaults to 1.
65 | """
66 |
67 | with open(path, "wb") as file:
68 | color = None
69 |
70 | if image.dtype.name != "float32":
71 | raise Exception("Image dtype must be float32.")
72 |
73 | image = np.flipud(image)
74 |
75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
76 | color = True
77 | elif (
78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1
79 | ): # greyscale
80 | color = False
81 | else:
82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
83 |
84 |         file.write(("PF\n" if color else "Pf\n").encode())  # encode the whole header token for the binary stream
85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0]))
86 |
87 | endian = image.dtype.byteorder
88 |
89 | if endian == "<" or endian == "=" and sys.byteorder == "little":
90 | scale = -scale
91 |
92 | file.write("%f\n".encode() % scale)
93 |
94 | image.tofile(file)
95 |
96 |
97 | def read_image(path):
98 | """Read image and output RGB image (0-1).
99 |
100 | Args:
101 | path (str): path to file
102 |
103 | Returns:
104 | array: RGB image (0-1)
105 | """
106 | img = cv2.imread(path)
107 |
108 | if img.ndim == 2:
109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
110 |
111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
112 |
113 | return img
114 |
115 |
116 | def resize_image(img):
117 |     """Resize image and make it suitable as network input.
118 |
119 | Args:
120 | img (array): image
121 |
122 | Returns:
123 | tensor: data ready for network
124 | """
125 | height_orig = img.shape[0]
126 | width_orig = img.shape[1]
127 |
128 | if width_orig > height_orig:
129 | scale = width_orig / 384
130 | else:
131 | scale = height_orig / 384
132 |
133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int)
134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int)
135 |
136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA)
137 |
138 | img_resized = (
139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float()
140 | )
141 | img_resized = img_resized.unsqueeze(0)
142 |
143 | return img_resized
144 |
145 |
146 | def resize_depth(depth, width, height):
147 | """Resize depth map and bring to CPU (numpy).
148 |
149 | Args:
150 | depth (tensor): depth
151 | width (int): image width
152 | height (int): image height
153 |
154 | Returns:
155 | array: processed depth
156 | """
157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu")
158 |
159 | depth_resized = cv2.resize(
160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC
161 | )
162 |
163 | return depth_resized
164 |
165 | def write_depth(path, depth, bits=1):
166 | """Write depth map to pfm and png file.
167 |
168 | Args:
169 | path (str): filepath without extension
170 | depth (array): depth
171 | """
172 | write_pfm(path + ".pfm", depth.astype(np.float32))
173 |
174 | depth_min = depth.min()
175 | depth_max = depth.max()
176 |
177 | max_val = (2**(8*bits))-1
178 |
179 | if depth_max - depth_min > np.finfo("float").eps:
180 | out = max_val * (depth - depth_min) / (depth_max - depth_min)
181 | else:
182 |         out = np.zeros(depth.shape, dtype=depth.dtype)
183 |
184 | if bits == 1:
185 | cv2.imwrite(path + ".png", out.astype("uint8"))
186 | elif bits == 2:
187 | cv2.imwrite(path + ".png", out.astype("uint16"))
188 |
189 | return
190 |
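A round-trip sketch for the PFM helpers above (the output filename is hypothetical, the data random):

    import numpy as np

    depth = np.random.rand(4, 5).astype(np.float32)
    write_pfm("example.pfm", depth)
    data, scale = read_pfm("example.pfm")
    assert data.shape == (4, 5) and np.allclose(data, depth)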
--------------------------------------------------------------------------------
/controlnet/annotator/util.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import numpy as np
4 | import cv2
5 | import os
6 | import PIL
7 |
8 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts')
9 |
10 | def HWC3(x):
11 | assert x.dtype == np.uint8
12 | if x.ndim == 2:
13 | x = x[:, :, None]
14 | assert x.ndim == 3
15 | H, W, C = x.shape
16 | assert C == 1 or C == 3 or C == 4
17 | if C == 3:
18 | return x
19 | if C == 1:
20 | return np.concatenate([x, x, x], axis=2)
21 | if C == 4:
22 | color = x[:, :, 0:3].astype(np.float32)
23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0
24 | y = color * alpha + 255.0 * (1.0 - alpha)
25 | y = y.clip(0, 255).astype(np.uint8)
26 | return y
27 |
28 |
29 | def resize_image(input_image, resolution, short = False, interpolation=None):
30 | if isinstance(input_image,PIL.Image.Image):
31 | mode = 'pil'
32 | W,H = input_image.size
33 |
34 | elif isinstance(input_image,np.ndarray):
35 | mode = 'cv2'
36 | H, W, _ = input_image.shape
37 |
38 | H = float(H)
39 | W = float(W)
40 | if short:
41 |         k = float(resolution) / min(H, W)  # k>1 upscales, k<1 downscales
42 |     else:
43 |         k = float(resolution) / max(H, W)  # k>1 upscales, k<1 downscales
44 | H *= k
45 | W *= k
46 | H = int(np.round(H / 64.0)) * 64
47 | W = int(np.round(W / 64.0)) * 64
48 |
49 | if mode == 'cv2':
50 | if interpolation is None:
51 | interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA
52 | img = cv2.resize(input_image, (W, H), interpolation=interpolation)
53 |
54 | elif mode == 'pil':
55 | if interpolation is None:
56 | interpolation = PIL.Image.LANCZOS if k > 1 else PIL.Image.BILINEAR
57 | img = input_image.resize((W, H), resample=interpolation)
58 |
59 | return img
60 |
61 | # def resize_image(input_image, resolution):
62 | # H, W, C = input_image.shape
63 | # H = float(H)
64 | # W = float(W)
65 | # k = float(resolution) / min(H, W)
66 | # H *= k
67 | # W *= k
68 | # H = int(np.round(H / 64.0)) * 64
69 | # W = int(np.round(W / 64.0)) * 64
70 | # img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
71 | # return img
72 |
73 |
74 | def nms(x, t, s):
75 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
76 |
77 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
78 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
79 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
80 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
81 |
82 | y = np.zeros_like(x)
83 |
84 | for f in [f1, f2, f3, f4]:
85 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
86 |
87 | z = np.zeros_like(y, dtype=np.uint8)
88 | z[y > t] = 255
89 | return z
90 |
91 |
92 | def make_noise_disk(H, W, C, F):
93 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
94 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
95 | noise = noise[F: F + H, F: F + W]
96 | noise -= np.min(noise)
97 | noise /= np.max(noise)
98 | if C == 1:
99 | noise = noise[:, :, None]
100 | return noise
101 |
102 |
103 | def min_max_norm(x):
104 | x -= np.min(x)
105 | x /= np.maximum(np.max(x), 1e-5)
106 | return x
107 |
108 |
109 | def safe_step(x, step=2):
110 | y = x.astype(np.float32) * float(step + 1)
111 | y = y.astype(np.int32).astype(np.float32) / float(step)
112 | return y
113 |
114 |
115 | def img2mask(img, H, W, low=10, high=90):
116 | assert img.ndim == 3 or img.ndim == 2
117 | assert img.dtype == np.uint8
118 |
119 | if img.ndim == 3:
120 | y = img[:, :, random.randrange(0, img.shape[2])]
121 | else:
122 | y = img
123 |
124 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
125 |
126 | if random.uniform(0, 1) < 0.5:
127 | y = 255 - y
128 |
129 | return y < np.percentile(y, random.randrange(low, high))
130 |
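A quick illustration of resize_image's rounding behavior (a sketch; the input size is arbitrary):

    import numpy as np

    img = np.zeros((720, 1280, 3), dtype=np.uint8)
    out = resize_image(img, 1024)  # long side scaled towards 1024
    print(out.shape)               # (576, 1024, 3): both sides snapped to multiples of 64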
--------------------------------------------------------------------------------
/controlnet/assets/bird.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/bird.png
--------------------------------------------------------------------------------
/controlnet/assets/dog.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/dog.png
--------------------------------------------------------------------------------
/controlnet/assets/woman_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_1.png
--------------------------------------------------------------------------------
/controlnet/assets/woman_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_2.png
--------------------------------------------------------------------------------
/controlnet/assets/woman_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_3.png
--------------------------------------------------------------------------------
/controlnet/assets/woman_4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_4.png
--------------------------------------------------------------------------------
/controlnet/outputs/Canny_dog.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_dog.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Canny_dog_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_dog_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Canny_woman_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Canny_woman_1_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Canny_woman_1_sdxl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1_sdxl.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_1_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_1_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_bird.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_bird_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_bird_sdxl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird_sdxl.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_ipadapter_1.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_ipadapter_1.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_ipadapter_woman_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_ipadapter_woman_2.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_woman_2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_woman_2.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Depth_woman_2_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_woman_2_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Pose_woman_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Pose_woman_3_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3_condition.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Pose_woman_3_sdxl.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3_sdxl.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Pose_woman_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_4.jpg
--------------------------------------------------------------------------------
/controlnet/outputs/Pose_woman_4_condition.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_4_condition.jpg
--------------------------------------------------------------------------------
/controlnet/sample_controlNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
3 | from diffusers.utils import load_image
4 | import os,sys
5 |
6 | from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
7 | from kolors.models.modeling_chatglm import ChatGLMModel
8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9 | from kolors.models.controlnet import ControlNetModel
10 |
11 | from diffusers import AutoencoderKL
12 | from kolors.models.unet_2d_condition import UNet2DConditionModel
13 |
14 | from diffusers import EulerDiscreteScheduler
15 | from PIL import Image
16 | import numpy as np
17 | import cv2
18 |
19 | from annotator.midas import MidasDetector
20 | from annotator.dwpose import DWposeDetector
21 | from annotator.util import resize_image,HWC3
22 |
23 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
24 |
25 | def process_canny_condition( image, canny_threods=[100,200] ):
26 | np_image = image.copy()
27 | np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1])
28 | np_image = np_image[:, :, None]
29 | np_image = np.concatenate([np_image, np_image, np_image], axis=2)
30 | np_image = HWC3(np_image)
31 | return Image.fromarray(np_image)
32 |
33 |
34 | model_midas = None
35 | def process_depth_condition_midas(img, res = 1024):
36 | h,w,_ = img.shape
37 | img = resize_image(HWC3(img), res)
38 | global model_midas
39 | if model_midas is None:
40 | model_midas = MidasDetector()
41 |
42 | result = HWC3( model_midas(img) )
43 | result = cv2.resize( result, (w,h) )
44 | return Image.fromarray(result)
45 |
46 |
47 | model_dwpose = None
48 | def process_dwpose_condition( image, res=1024 ):
49 | h,w,_ = image.shape
50 | img = resize_image(HWC3(image), res)
51 | global model_dwpose
52 | if model_dwpose is None:
53 | model_dwpose = DWposeDetector()
54 | out_res, out_img = model_dwpose(image)
55 | result = HWC3( out_img )
56 | result = cv2.resize( result, (w,h) )
57 | return Image.fromarray(result)
58 |
59 |
60 | def infer( image_path , prompt, model_type = 'Canny' ):
61 |
62 | ckpt_dir = f'{root_dir}/weights/Kolors'
63 | text_encoder = ChatGLMModel.from_pretrained(
64 | f'{ckpt_dir}/text_encoder',
65 | torch_dtype=torch.float16).half()
66 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
67 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
68 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
69 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
70 |
71 | control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}'
72 | controlnet = ControlNetModel.from_pretrained( control_path , revision=None).half()
73 |
74 | pipe = StableDiffusionXLControlNetImg2ImgPipeline(
75 | vae=vae,
76 | controlnet = controlnet,
77 | text_encoder=text_encoder,
78 | tokenizer=tokenizer,
79 | unet=unet,
80 | scheduler=scheduler,
81 | force_zeros_for_empty_prompt=False
82 | )
83 |
84 | pipe = pipe.to("cuda")
85 | pipe.enable_model_cpu_offload()
86 |
87 |     negative_prompt = 'nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯'  # "nsfw, facial shadows, low resolution, JPEG artifacts, blurry, poor quality, dark face, neon lights"
88 |
89 | MAX_IMG_SIZE=1024
90 | controlnet_conditioning_scale = 0.7
91 | control_guidance_end = 0.9
92 | strength = 1.0
93 |
94 | basename = image_path.rsplit('/',1)[-1].rsplit('.',1)[0]
95 |
96 | init_image = Image.open( image_path )
97 |
98 | init_image = resize_image( init_image, MAX_IMG_SIZE)
99 | if model_type == 'Canny':
100 | condi_img = process_canny_condition( np.array(init_image) )
101 | elif model_type == 'Depth':
102 | condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMG_SIZE )
103 | elif model_type == 'Pose':
104 | condi_img = process_dwpose_condition( np.array(init_image), MAX_IMG_SIZE)
105 |
106 | generator = torch.Generator(device="cpu").manual_seed(66)
107 | image = pipe(
108 | prompt= prompt ,
109 | image = init_image,
110 | controlnet_conditioning_scale = controlnet_conditioning_scale,
111 | control_guidance_end = control_guidance_end,
112 | strength= strength ,
113 | control_image = condi_img,
114 | negative_prompt= negative_prompt ,
115 | num_inference_steps= 50 ,
116 | guidance_scale= 6.0,
117 | num_images_per_prompt=1,
118 | generator=generator,
119 | ).images[0]
120 |
121 | condi_img.save( f'{root_dir}/controlnet/outputs/{model_type}_{basename}_condition.jpg' )
122 | image.save(f'{root_dir}/controlnet/outputs/{model_type}_{basename}.jpg')
123 |
124 |
125 | if __name__ == '__main__':
126 | import fire
127 | fire.Fire(infer)
128 |
129 |
130 |
--------------------------------------------------------------------------------
/controlnet/sample_controlNet_ipadapter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
3 | from diffusers.utils import load_image
4 | import os,sys
5 |
6 | from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline
7 | from kolors.models.modeling_chatglm import ChatGLMModel
8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9 | from kolors.models.controlnet import ControlNetModel
10 |
11 | from diffusers import AutoencoderKL
12 | from kolors.models.unet_2d_condition import UNet2DConditionModel
13 |
14 | from diffusers import EulerDiscreteScheduler
15 | from PIL import Image
16 | import numpy as np
17 | import cv2
18 |
19 | from annotator.midas import MidasDetector
20 | from annotator.util import resize_image,HWC3
21 |
22 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
23 |
24 | def process_canny_condition( image, canny_thresholds=[100,200] ):
25 |     np_image = image.copy()
26 |     np_image = cv2.Canny(np_image, canny_thresholds[0], canny_thresholds[1])
27 | np_image = np_image[:, :, None]
28 | np_image = np.concatenate([np_image, np_image, np_image], axis=2)
29 | np_image = HWC3(np_image)
30 | return Image.fromarray(np_image)
31 |
32 |
33 | model_midas = None
34 | def process_depth_condition_midas(img, res = 1024):
35 | h,w,_ = img.shape
36 | img = resize_image(HWC3(img), res)
37 | global model_midas
38 | if model_midas is None:
39 | model_midas = MidasDetector()
40 |
41 | result = HWC3( model_midas(img) )
42 | result = cv2.resize( result, (w,h) )
43 | return Image.fromarray(result)
44 |
45 |
46 | def infer( image_path , ip_image_path, prompt, model_type = 'Canny' ):
47 |
48 | ckpt_dir = f'{root_dir}/weights/Kolors'
49 | text_encoder = ChatGLMModel.from_pretrained(
50 | f'{ckpt_dir}/text_encoder',
51 | torch_dtype=torch.float16).half()
52 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
53 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
54 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
55 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
56 |
57 | control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}'
58 | controlnet = ControlNetModel.from_pretrained( control_path , revision=None).half()
59 |
60 | # IP-Adapter model
61 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16)
62 | ip_img_size = 336
63 | clip_image_processor = CLIPImageProcessor( size=ip_img_size, crop_size=ip_img_size )
64 |
65 | pipe = StableDiffusionXLControlNetImg2ImgPipeline(
66 | vae=vae,
67 | controlnet = controlnet,
68 | text_encoder=text_encoder,
69 | tokenizer=tokenizer,
70 | unet=unet,
71 | scheduler=scheduler,
72 | image_encoder=image_encoder,
73 | feature_extractor=clip_image_processor,
74 | force_zeros_for_empty_prompt=False
75 | )
76 |
77 | if hasattr(pipe.unet, 'encoder_hid_proj'):
78 | pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
79 |
80 | pipe.load_ip_adapter( f'{root_dir}/weights/Kolors-IP-Adapter-Plus' , subfolder="", weight_name=["ip_adapter_plus_general.bin"])
81 |
82 | pipe = pipe.to("cuda")
83 | pipe.enable_model_cpu_offload()
84 |
85 | negative_prompt = 'nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯'
86 |
87 | MAX_IMG_SIZE=1024
88 | controlnet_conditioning_scale = 0.5
89 | control_guidance_end = 0.9
90 | strength = 1.0
91 | ip_scale = 0.5
92 |
93 | basename = image_path.rsplit('/',1)[-1].rsplit('.',1)[0]
94 |
95 | init_image = Image.open( image_path )
96 | init_image = resize_image( init_image, MAX_IMG_SIZE)
97 |
98 | if model_type == 'Canny':
99 | condi_img = process_canny_condition( np.array(init_image) )
100 | elif model_type == 'Depth':
101 | condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMG_SIZE )
102 |
103 | ip_adapter_img = Image.open(ip_image_path)
104 | pipe.set_ip_adapter_scale([ ip_scale ])
105 |
106 | generator = torch.Generator(device="cpu").manual_seed(66)
107 | image = pipe(
108 | prompt= prompt ,
109 | image = init_image,
110 | controlnet_conditioning_scale = controlnet_conditioning_scale,
111 | control_guidance_end = control_guidance_end,
112 |
113 | ip_adapter_image=[ ip_adapter_img ],
114 |
115 | strength= strength ,
116 | control_image = condi_img,
117 | negative_prompt= negative_prompt ,
118 | num_inference_steps= 50 ,
119 | guidance_scale= 5.0,
120 | num_images_per_prompt=1,
121 | generator=generator,
122 | ).images[0]
123 |
124 | image.save(f'{root_dir}/controlnet/outputs/{model_type}_ipadapter_{basename}.jpg')
125 | condi_img.save(f'{root_dir}/controlnet/outputs/{model_type}_{basename}_condition.jpg')
126 |
127 |
128 | if __name__ == '__main__':
129 | import fire
130 | fire.Fire(infer)
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
--------------------------------------------------------------------------------
/dreambooth/README.md:
--------------------------------------------------------------------------------
1 | ## 📖 Introduction
2 | We provide LoRA training and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors), along with an IP LoRA training example.
3 |
4 |
5 |
6 | **Example Result**
7 |
8 | | Prompt | Result Image |
9 | | :---: | :---: |
10 | | ktxl狗在草地上跑。 The ktxl dog runs on the grass. | ![ktxl_test_image](ktxl_test_image.png) |
26 |
27 | **Our improvements**
28 |
29 | - Support for user-defined caption files: place a `.txt` file with the same name as an image in the same directory, and the caption will be automatically matched to that image.
30 |
31 | ## 🛠️ Usage
32 |
33 | ### Requirements
34 |
35 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors).
36 |
37 | 1. Repository Cloning and Dependency Installation
38 |
39 | ```bash
40 | apt-get install git-lfs
41 | git clone https://github.com/Kwai-Kolors/Kolors
42 | cd Kolors
43 | conda create --name kolors python=3.8
44 | conda activate kolors
45 | pip install -r requirements.txt
46 | python3 setup.py install
47 | ```
48 | ### Training
49 |
50 | 1. First, download the example dataset from https://huggingface.co/datasets/diffusers/dog-example:
51 | ```python
52 | from huggingface_hub import snapshot_download
53 |
54 | local_dir = "./dog"
55 | snapshot_download(
56 | "diffusers/dog-example",
57 | local_dir=local_dir, repo_type="dataset",
58 | ignore_patterns=".gitattributes",
59 | )
60 | ```
61 |
62 | **___Note: To have caption files loaded automatically during training, give each image and its corresponding `.txt` caption file the same name.___**
63 |
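The matching logic can be thought of as follows. This is a minimal sketch for illustration only; the helper name, the fallback prompt, and the example path are assumptions, not the repository's actual dataloader:

```python
from pathlib import Path

def load_caption(image_path: str, fallback_prompt: str = "ktxl狗") -> str:
    """Return the same-named .txt caption if it exists, else fall back to the instance prompt."""
    caption_file = Path(image_path).with_suffix(".txt")
    if caption_file.exists():
        return caption_file.read_text(encoding="utf-8").strip()
    return fallback_prompt

# e.g. ./dog/dog_01.jpeg (hypothetical file) is captioned by ./dog/dog_01.txt if that file exists
print(load_caption("./dog/dog_01.jpeg"))
```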
64 |
65 | 2. Launch the training using:
66 | ```bash
67 | sh train.sh
68 | ```
69 |
70 | 3. Training configuration. We train the model using the default configuration in the `train.sh` file on 8 V100 GPUs, consuming a total of 27 GB of GPU memory. You can also fine-tune the text encoder by adding `--train_text_encoder`:
71 | ```bash
72 | MODEL_NAME="/path/base_model_path"
73 | CLASS_DIR="/path/regularization_image_path"
74 | INSTANCE_DIR="path/training_image_path"
75 | OUTPUT_DIR="./trained_models"
76 | cfg_file=./default_config.yaml
77 |
78 | accelerate launch --config_file ${cfg_file} train_dreambooth_lora.py \
79 | --pretrained_model_name_or_path=$MODEL_NAME \
80 | --instance_data_dir=$INSTANCE_DIR \
81 | --output_dir=$OUTPUT_DIR \
82 | --class_data_dir=$CLASS_DIR \
83 | --instance_prompt="ktxl狗" \
84 | --class_prompt="狗" \
85 | --train_batch_size=1 \
86 | --gradient_accumulation_steps=1 \
87 | --learning_rate=2e-5 \
88 | --text_encoder_lr=5e-5 \
89 | --lr_scheduler="polynomial" \
90 | --lr_warmup_steps=100 \
91 | --rank=4 \
92 | --resolution=1024 \
93 | --max_train_steps=2000 \
94 | --checkpointing_steps=200 \
95 | --num_class_images=100 \
96 | --center_crop \
97 | --mixed_precision='fp16' \
98 | --seed=19980818 \
99 | --img_repeat_nums=1 \
100 | --sample_batch_size=2 \
101 | --use_preffix_prompt \
102 | --gradient_checkpointing \
103 | --train_text_encoder \
104 | --adam_weight_decay=1e-02 \
105 | --with_prior_preservation \
106 | --prior_loss_weight=0.7 \
107 |
108 | ```
109 |
110 |
111 | **___Note: Most of our training configurations are the same as the official [diffusers](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) DreamBooth example.___**
112 |
113 | ### Inference
114 | ```bash
115 | python infer_dreambooth.py "ktxl狗在草地上跑"
116 | ```
117 |
--------------------------------------------------------------------------------
/dreambooth/default_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config:
3 | deepspeed_multinode_launcher: standard
4 | gradient_accumulation_steps: 1
5 | gradient_clipping: 1.5
6 | offload_optimizer_device: none
7 | offload_param_device: none
8 | zero3_init_flag: true
9 | zero_stage: 2
10 | reduce_scatter: false
11 | overlap_comm: true
12 | distributed_type: DEEPSPEED
13 | downcast_bf16: 'no'
14 | dynamo_backend: 'NO'
15 | fsdp_config: {}
16 | machine_rank: 0
17 | main_process_ip: 10.82.42.75
18 | main_process_port: 22280
19 | main_training_function: main
20 | megatron_lm_config: {}
21 | mixed_precision: fp16
22 | num_machines: 1
23 | num_processes: 8
24 | rdzv_backend: static
25 | same_network: true
26 | use_cpu: false
27 |
--------------------------------------------------------------------------------
/dreambooth/infer_dreambooth.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
3 | from kolors.models.modeling_chatglm import ChatGLMModel
4 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
5 | from diffusers import UNet2DConditionModel, AutoencoderKL
6 | from diffusers import EulerDiscreteScheduler
7 |
8 | from peft import (
9 | LoraConfig,
10 | PeftModel,
11 | )
12 |
13 | def infer(prompt):
14 | ckpt_dir = "/path/base_model_path"
15 | lora_ckpt = 'trained_models/ktxl_dog_text/checkpoint-1000/'
16 | load_text_encoder = True
17 |
18 | text_encoder = ChatGLMModel.from_pretrained(
19 | f'{ckpt_dir}/text_encoder',
20 | torch_dtype=torch.float16).half()
21 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
22 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
23 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
24 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
25 | pipe = StableDiffusionXLPipeline(
26 | vae=vae,
27 | text_encoder=text_encoder,
28 | tokenizer=tokenizer,
29 | unet=unet,
30 | scheduler=scheduler,
31 | force_zeros_for_empty_prompt=False)
32 | pipe = pipe.to("cuda")
33 |
34 | pipe.load_lora_weights(lora_ckpt, adapter_name="ktxl-lora")
35 | pipe.set_adapters(["ktxl-lora"], [0.8])
36 |
37 | if load_text_encoder:
38 | pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, lora_ckpt)
39 |
40 | random_seed = 0
41 | generator = torch.Generator(pipe.device).manual_seed(random_seed)
42 |
43 | neg_p = ''
44 | out = pipe(prompt, generator=generator, negative_prompt=neg_p, num_inference_steps=25, width=1024, height=1024, num_images_per_prompt=1, guidance_scale=5).images[0]
45 | out.save("ktxl_test_image.png")
46 |
47 | if __name__ == '__main__':
48 | import fire
49 | fire.Fire(infer)
--------------------------------------------------------------------------------
/dreambooth/ktxl_test_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/dreambooth/ktxl_test_image.png
--------------------------------------------------------------------------------
/dreambooth/train.sh:
--------------------------------------------------------------------------------
1 | MODEL_NAME="/path/base_model_path"
2 | INSTANCE_DIR='dog'
3 | CLASS_DIR="reg_dog"
4 | OUTPUT_DIR="trained_models/ktxl_dog_text"
5 | cfg_file=./default_config.yaml
6 |
7 | accelerate launch --config_file ${cfg_file} train_dreambooth_lora.py \
8 | --pretrained_model_name_or_path=$MODEL_NAME \
9 | --instance_data_dir=$INSTANCE_DIR \
10 | --output_dir=$OUTPUT_DIR \
11 | --class_data_dir=$CLASS_DIR \
12 | --instance_prompt="ktxl狗" \
13 | --class_prompt="狗" \
14 | --train_batch_size=1 \
15 | --gradient_accumulation_steps=1 \
16 | --learning_rate=2e-5 \
17 | --text_encoder_lr=5e-5 \
18 | --lr_scheduler="polynomial" \
19 | --lr_warmup_steps=100 \
20 | --rank=4 \
21 | --resolution=1024 \
22 | --max_train_steps=1000 \
23 | --checkpointing_steps=200 \
24 | --num_class_images=100 \
25 | --center_crop \
26 | --mixed_precision='fp16' \
27 | --seed=19980818 \
28 | --img_repeat_nums=1 \
29 | --sample_batch_size=2 \
30 | --gradient_checkpointing \
31 | --adam_weight_decay=1e-02 \
32 | --with_prior_preservation \
33 | --prior_loss_weight=0.7 \
34 | --train_text_encoder \
35 |
--------------------------------------------------------------------------------
/imgs/Kolors_paper.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/Kolors_paper.pdf
--------------------------------------------------------------------------------
/imgs/cn_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/cn_all.png
--------------------------------------------------------------------------------
/imgs/fz_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/fz_all.png
--------------------------------------------------------------------------------
/imgs/head_final3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/head_final3.png
--------------------------------------------------------------------------------
/imgs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/logo.png
--------------------------------------------------------------------------------
/imgs/prompt_vis.txt:
--------------------------------------------------------------------------------
1 | 1、视觉质量:
2 | - 一对年轻的中国情侣,皮肤白皙,穿着时尚的运动装,背景是现代的北京城市天际线。面部细节,清晰的毛孔,使用最新款的相机拍摄,特写镜头,超高画质,8K,视觉盛宴
3 | 2、中国元素
4 | - 万里长城,蜿蜒
5 | - 一张北京国家体育场(俗称“鸟巢”)的高度细节图片。图片应展示体育场复杂的钢结构,强调其独特的建筑设计。场景设定在白天,天空晴朗,突显出体育场的宏伟规模和现代感。包括周围的奥林匹克公园和一些游客,以增加场景的背景和生气。
6 | - 上海外滩
7 | 3、复杂语义理解
8 | - 满月下的街道,熙熙攘攘的行人正在享受繁华夜生活。街角摊位上,一位有着火红头发、穿着标志性天鹅绒斗篷的年轻女子,正在和脾气暴躁的老小贩讨价还价。这个脾气暴躁的小贩身材高大、老道,身着一套整洁西装,留着小胡子,用他那部蒸汽朋克式的电话兴致勃勃地交谈
9 | - 画面有四只神兽:朱雀、玄武、青龙、白虎。朱雀位于画面上方,羽毛鲜红如火,尾羽如凤凰般绚丽,翅膀展开时似燃烧的火焰。玄武居于下方,是龟蛇相缠的形象,巨龟背上盘绕着一条黑色巨蛇,龟甲上有古老的符文,蛇眼冰冷锐利。青龙位于右方,长身盘旋在天际,龙鳞碧绿如翡翠,龙须飘逸,龙角如鹿,口吐云雾。白虎居于左方,体态威猛,白色的皮毛上有黑色斑纹,双眼炯炯有神,尖牙利爪,周围是苍茫的群山和草原。
10 | - 一张高对比度的照片,熊猫骑在马上,戴着巫师帽,正在看书,马站在土墙旁的街道上,有绿草从街道的裂缝中长出来。
11 | 4、文字绘制
12 | - 一张瓢虫的照片,微距,变焦,高质量,电影,瓢虫拿着一个木牌,上面写着“我爱世界” 的文字
13 | - 一只小橘猫在弹钢琴,钢琴是黑色的牌子是“KOLORS”,猫的身影清晰的映照在钢琴上
14 | - 街边的路牌,上面写着“天道酬勤”
--------------------------------------------------------------------------------
/imgs/wechat.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/wechat.png
--------------------------------------------------------------------------------
/imgs/wz_all.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/wz_all.png
--------------------------------------------------------------------------------
/imgs/zl8.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/zl8.png
--------------------------------------------------------------------------------
/imgs/可图KOLORS模型商业授权申请书-英文版本.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/可图KOLORS模型商业授权申请书-英文版本.docx
--------------------------------------------------------------------------------
/imgs/可图KOLORS模型商业授权申请书.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/可图KOLORS模型商业授权申请书.docx
--------------------------------------------------------------------------------
/inpainting/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## 📖 Introduction
4 |
5 | We provide Kolors-Inpainting inference code and weights, which were initialized from [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-Inpainting results are as follows:
6 |
7 |
8 |
9 |
10 |
11 | **Inpainting Results**
12 |
13 | | Original Image | Masked Image | Prompt | Result Image |
14 | | :---: | :---: | :---: | :---: |
15 | | *(image)* | *(image)* | 穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深 Wearing Sailor Moon's outfit, a sailor-style outfit consisting of a white tight top with a large red bow on the chest. The collar of the outfit is blue and has white stripes. She also wears a blue pleated skirt, Ultra HD, Octane Rendering, Premium Textures, 32k, High Resolution, Best Quality, Super Detail, Depth of Field | *(image)* |
16 | | *(image)* | *(image)* | 穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深 Wearing Iron Man's clothes, high-tech armor, the main colors are red and gold, and there are some silver decorations. There is a light-up round reactor device on the chest, full of futuristic technology. Ultra-clear, high-quality, ultra-realistic, high-resolution, best quality, super details, depth of field | *(image)* |
41 |
42 | **Model details**
43 |
44 | - For inpainting, the UNet has 5 additional input channels (4 for the encoded masked image and 1 for the mask itself). The weights for the encoded masked-image channels were initialized from the non-inpainting checkpoint, while the weights for the mask channel were zero-initialized (see the sketch after this list).
45 | - To improve the robustness of the inpainting model, we adopt a more diverse strategy for generating masks, including random masks, subject segmentation masks, rectangular masks, and masks based on dilation operations.
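
For intuition only, the sketch below shows how such a 9-channel UNet input could be assembled. The tensor names and the channel ordering follow the usual diffusers inpainting convention and are assumptions, not code from this repository:

```python
import torch

# 4 latent channels of the noisy image + 1 mask channel + 4 latent channels of the masked image
noisy_latents = torch.randn(1, 4, 128, 128)         # VAE latents being denoised
mask = torch.rand(1, 1, 128, 128)                    # inpainting mask, resized to latent resolution
masked_image_latents = torch.randn(1, 4, 128, 128)   # VAE latents of the masked input image

unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 128, 128]) -> 4 base + 5 additional channels
```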
46 |
47 |
48 |
49 |
50 |
51 | ## 📊 Evaluation
52 | For evaluation, we created a test set comprising 200 masked images and text prompts. We invited several image experts to provide unbiased ratings for the generated results of different models. The experts assessed the generated images based on four criteria: visual appeal, text faithfulness, inpainting artifacts, and overall satisfaction. The inpainting-artifacts criterion measures how perceptible the boundaries of the inpainted region are, while the other criteria adhere to the evaluation standards of the BaseModel. The specific results are summarized in the table below, where Kolors-Inpainting achieved the highest overall satisfaction score.
53 |
54 | | Model | Average Overall Satisfaction | Average Inpainting Artifacts | Average Visual Appeal | Average Text Faithfulness |
55 | | :-----------------: | :-----------: | :-----------: | :-----------: | :-----------: |
56 | | SDXL-Inpainting | 2.573 | 1.205 | 3.000 | 4.299 |
57 | | **Kolors-Inpainting** | **3.493** | **0.204** | **3.855** | **4.346** |
58 |
59 |
60 | *The higher the scores for Average Overall Satisfaction, Average Visual Appeal, and Average Text Faithfulness, the better. Conversely, the lower the score for Average Inpainting Artifacts, the better.*
61 |
62 |
63 | The comparison results of SDXL-Inpainting and Kolors-Inpainting are as follows:
64 |
65 |
66 | **Comparison Results**
67 |
68 | | Original Image | Masked Image | Prompt | SDXL-Inpainting Result | Kolors-Inpainting Result |
69 | | :---: | :---: | :---: | :---: | :---: |
70 | | *(image)* | *(image)* | 穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深 Wearing Sailor Moon's outfit, a sailor-style outfit consisting of a white tight top with a large red bow on the chest. The collar of the outfit is blue and has white stripes. She also wears a blue pleated skirt, Ultra HD, Octane Rendering, Premium Textures, 32k, High Resolution, Best Quality, Super Detail, Depth of Field | *(image)* | *(image)* |
71 | | *(image)* | *(image)* | 穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深 Wearing Iron Man's clothes, high-tech armor, the main colors are red and gold, and there are some silver decorations. There is a light-up round reactor device on the chest, full of futuristic technology. Ultra-clear, high-quality, ultra-realistic, high-resolution, best quality, super details, depth of field | *(image)* | *(image)* |
72 | | *(image)* | *(image)* | 穿着白雪公主的衣服,经典的蓝色裙子,并且在袖口处饰有红色细节,超高清,辛烷渲染,高级质感,32k Dressed in Snow White's classic blue skirt with red details at the cuffs, Ultra HD, Octane Rendering, Premium Textures, 32k | *(image)* | *(image)* |
73 | | *(image)* | *(image)* | 一只带着红色帽子的小猫咪,圆脸,大眼,极度可爱,高饱和度,立体,柔和的光线 A kitten wearing a red hat, round face, big eyes, extremely cute, high saturation, three-dimensional, soft light | *(image)* | *(image)* |
74 | | *(image)* | *(image)* | 这是一幅令人垂涎欲滴的火锅画面,各种美味的食材在翻滚的锅中煮着,散发出的热气和香气令人陶醉。火红的辣椒和鲜艳的辣椒油熠熠生辉,具有诱人的招人入胜之色彩。锅内肉质细腻的薄切牛肉、爽口的豆腐皮、鲍汁浓郁的金针菇、爽脆的蔬菜,融合在一起,营造出五彩斑斓的视觉呈现 This is a mouth-watering hot pot scene, with all kinds of delicious ingredients cooking in the boiling pot, emitting intoxicating heat and aroma. The fiery red peppers and bright chili oil are shining, with attractive and fascinating colors. The delicate thin-cut beef, refreshing tofu skin, enoki mushrooms with rich abalone sauce, and crisp vegetables in the pot are combined together to create a colorful visual presentation | *(image)* | *(image)* |
122 | *Kolors-Inpainting employs Chinese prompts, while SDXL-Inpainting uses English prompts.*
123 |
124 |
125 |
126 | ## 🛠️ Usage
127 |
128 | ### Requirements
129 |
130 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors).
131 |
132 |
133 |
134 | 1. Repository Cloning and Dependency Installation
135 |
136 | ```bash
137 | apt-get install git-lfs
138 | git clone https://github.com/Kwai-Kolors/Kolors
139 | cd Kolors
140 | conda create --name kolors python=3.8
141 | conda activate kolors
142 | pip install -r requirements.txt
143 | python3 setup.py install
144 | ```
145 |
146 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-Inpainting):
147 | ```bash
148 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-Inpainting --local-dir weights/Kolors-Inpainting
149 | ```
150 |
151 | 3. Inference:
152 | ```bash
153 | python3 inpainting/sample_inpainting.py ./inpainting/asset/3.png ./inpainting/asset/3_mask.png 穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深
154 |
155 | python3 inpainting/sample_inpainting.py ./inpainting/asset/4.png ./inpainting/asset/4_mask.png 穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深
156 |
157 | # The image will be saved to "scripts/outputs/"
158 | ```
159 |
160 |
161 |
--------------------------------------------------------------------------------
/inpainting/asset/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1.png
--------------------------------------------------------------------------------
/inpainting/asset/1_kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_kolors.png
--------------------------------------------------------------------------------
/inpainting/asset/1_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_masked.png
--------------------------------------------------------------------------------
/inpainting/asset/1_sdxl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_sdxl.png
--------------------------------------------------------------------------------
/inpainting/asset/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2.png
--------------------------------------------------------------------------------
/inpainting/asset/2_kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_kolors.png
--------------------------------------------------------------------------------
/inpainting/asset/2_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_masked.png
--------------------------------------------------------------------------------
/inpainting/asset/2_sdxl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_sdxl.png
--------------------------------------------------------------------------------
/inpainting/asset/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3.png
--------------------------------------------------------------------------------
/inpainting/asset/3_kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_kolors.png
--------------------------------------------------------------------------------
/inpainting/asset/3_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_mask.png
--------------------------------------------------------------------------------
/inpainting/asset/3_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_masked.png
--------------------------------------------------------------------------------
/inpainting/asset/3_sdxl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_sdxl.png
--------------------------------------------------------------------------------
/inpainting/asset/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4.png
--------------------------------------------------------------------------------
/inpainting/asset/4_kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_kolors.png
--------------------------------------------------------------------------------
/inpainting/asset/4_mask.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_mask.png
--------------------------------------------------------------------------------
/inpainting/asset/4_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_masked.png
--------------------------------------------------------------------------------
/inpainting/asset/4_sdxl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_sdxl.png
--------------------------------------------------------------------------------
/inpainting/asset/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5.png
--------------------------------------------------------------------------------
/inpainting/asset/5_kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_kolors.png
--------------------------------------------------------------------------------
/inpainting/asset/5_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_masked.png
--------------------------------------------------------------------------------
/inpainting/asset/5_sdxl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_sdxl.png
--------------------------------------------------------------------------------
/inpainting/sample_inpainting.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os, sys
3 | from PIL import Image
4 |
5 | from diffusers import (
6 | AutoencoderKL,
7 | UNet2DConditionModel,
8 | EulerDiscreteScheduler
9 | )
10 |
11 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12 | sys.path.append(root_dir)
13 |
14 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline
15 | from kolors.models.modeling_chatglm import ChatGLMModel
16 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
17 |
18 | def infer(image_path, mask_path, prompt):
19 |
20 | ckpt_dir = f'{root_dir}/weights/Kolors-Inpainting'
21 | text_encoder = ChatGLMModel.from_pretrained(
22 | f'{ckpt_dir}/text_encoder',
23 | torch_dtype=torch.float16).half()
24 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
25 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
26 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
27 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
28 |
29 | pipe = StableDiffusionXLInpaintPipeline(
30 | vae=vae,
31 | text_encoder=text_encoder,
32 | tokenizer=tokenizer,
33 | unet=unet,
34 | scheduler=scheduler
35 | )
36 |
37 | pipe.to("cuda")
38 | pipe.enable_attention_slicing()
39 |
40 | generator = torch.Generator(device="cpu").manual_seed(603)
41 | basename = image_path.rsplit('/', 1)[-1].rsplit('.', 1)[0]
42 | image = Image.open(image_path).convert('RGB')
43 | mask_image = Image.open(mask_path).convert('RGB')
44 |
45 | result = pipe(
46 | prompt = prompt,
47 | image = image,
48 | mask_image = mask_image,
49 | height=1024,
50 | width=768,
51 | guidance_scale = 6.0,
52 | generator= generator,
53 | num_inference_steps= 25,
54 | negative_prompt = '残缺的手指,畸形的手指,畸形的手,残肢,模糊,低质量',
55 | num_images_per_prompt = 1,
56 | strength = 0.999
57 | ).images[0]
58 | result.save(f'{root_dir}/scripts/outputs/sample_inpainting_{basename}.jpg')
59 |
60 | if __name__ == '__main__':
61 | import fire
62 | fire.Fire(infer)
63 |
--------------------------------------------------------------------------------
/ipadapter/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## 📖 Introduction
4 |
5 | We provide IP-Adapter-Plus weights and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-IP-Adapter-Plus results are as follows:
6 |
7 |
8 |
9 |
10 |
11 | **Example Result**
12 |
13 | | Reference Image | Prompt | Result Image |
14 | | :---: | :---: | :---: |
15 | | *(image)* | 穿着黑色T恤衫,上面中文绿色大字写着“可图”。 Wearing a black T-shirt with the Chinese characters "Ketu" written in large green letters on it. | *(image)* |
16 | | *(image)* | 一只可爱的小狗在奔跑。A cute dog is running. | *(image)* |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 | **Our improvements**
40 |
41 | - A stronger image feature extractor. We employ the OpenAI-CLIP-336 model as the image encoder, which allows us to preserve more details in the reference images (see the sketch after this list).
42 | - More diverse and high-quality training data: We construct a large-scale and high-quality training dataset inspired by the data strategies of other works. We believe that paired training data can effectively improve performance.
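
As a minimal sketch of the first point, the snippet below mirrors how `ipadapter/sample_ipadapter_plus.py` in this repository sets up the 336-resolution image encoder and processor; the local weight path is the one used elsewhere in this README and is assumed to have been downloaded already:

```python
import torch
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor

# 336x336 OpenAI CLIP vision tower used as the IP-Adapter image encoder
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    'weights/Kolors-IP-Adapter-Plus/image_encoder',
    ignore_mismatched_sizes=True,
).to(dtype=torch.float16)

ip_img_size = 336
clip_image_processor = CLIPImageProcessor(size=ip_img_size, crop_size=ip_img_size)
```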
43 |
44 |
45 |
46 |
47 |
48 | ## 📊 Evaluation
49 | For evaluation, we create a test set consisting of over 200 reference images and text prompts. We invite several image experts to provide fair ratings for the generated results of different models. The experts rate the generated images based on four criteria: visual appeal, text faithfulness, image faithfulness, and overall satisfaction. Image faithfulness measures the semantic preservation ability of IP-Adapter on reference images, while the other criteria follow the evaluation standards of BaseModel. The specific results are summarized in the table below, where Kolors-IP-Adapter-Plus achieves the highest overall satisfaction score.
50 |
51 | | Model | Average Overall Satisfaction | Average Image Faithfulness | Average Visual Appeal | Average Text Faithfulness |
52 | | :--------------: | :--------: | :--------: | :--------: | :--------: |
53 | | SDXL-IP-Adapter-Plus | 2.29 | 2.64 | 3.22 | 4.02 |
54 | | Midjourney-v6-CW | 2.79 | 3.0 | 3.92 | 4.35 |
55 | | **Kolors-IP-Adapter-Plus** | **3.04** | **3.25** | **4.45** | **4.30** |
56 |
57 | *The ip_scale parameter is set to 0.3 in SDXL-IP-Adapter-Plus, while Midjourney-v6-CW utilizes the default cw scale.*
58 |
59 | ------
60 |
61 |
62 |
63 |
64 |
65 |
66 | **Compare Result**
67 |
68 | | Reference Image | Prompt | Kolors-IP-Adapter-Plus Result | SDXL-IP-Adapter-Plus Result | Midjourney-v6-CW Result |
69 | | :---: | :---: | :---: | :---: | :---: |
70 | | *(image)* | 一个看向远山的少女形象,雪山背景,采用日本浮世绘风格,混合蓝色和红色柔和调色板,高分辨率 Image of a girl looking towards distant mountains, snowy mountains background, in Japanese ukiyo-e style, mixed blue and red pastel color palette, high resolution. | *(image)* | *(image)* | *(image)* |
71 | | *(image)* | 一个漂亮的美女,看向远方 A beautiful lady looking into the distance. | *(image)* | *(image)* | *(image)* |
72 | | *(image)* | 可爱的猫咪,在花丛中,看镜头 Cute cat among flowers, looking at the camera. | *(image)* | *(image)* | *(image)* |
73 | | *(image)* | 站在丛林前,戴着太阳帽,高画质,高细节,高清,疯狂的细节,超高清 Standing in front of the jungle, wearing a sun hat, high quality, high detail, high definition, crazy details, ultra high definition. | *(image)* | *(image)* | *(image)* |
74 | | *(image)* | 做个头像,新海诚动漫风格,丰富的色彩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质 Create an avatar, Shinkai Makoto anime style, rich colors, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality. | *(image)* | *(image)* | *(image)* |
116 |
117 |
118 |
119 |
120 |
121 |
122 | *Kolors-IP-Adapter-Plus employs Chinese prompts, while other methods use English prompts.*
123 |
124 |
125 |
126 | ## 🛠️ Usage
127 |
128 | ### Requirements
129 |
130 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors).
131 |
132 |
133 |
134 | 1. Repository Cloning and Dependency Installation
135 |
136 | ```bash
137 | apt-get install git-lfs
138 | git clone https://github.com/Kwai-Kolors/Kolors
139 | cd Kolors
140 | conda create --name kolors python=3.8
141 | conda activate kolors
142 | pip install -r requirements.txt
143 | python3 setup.py install
144 | ```
145 |
146 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus):
147 | ```bash
148 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-IP-Adapter-Plus --local-dir weights/Kolors-IP-Adapter-Plus
149 | ```
150 | or
151 | ```bash
152 | git lfs clone https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus weights/Kolors-IP-Adapter-Plus
153 | ```
154 |
155 | 3. Inference:
156 | ```bash
157 | python ipadapter/sample_ipadapter_plus.py ./ipadapter/asset/test_ip.jpg "穿着黑色T恤衫,上面中文绿色大字写着“可图”"
158 |
159 | python ipadapter/sample_ipadapter_plus.py ./ipadapter/asset/test_ip2.png "一只可爱的小狗在奔跑"
160 |
161 | # The image will be saved to "scripts/outputs/"
162 | ```
163 |
164 |
165 |
166 |
167 | **Note**
168 |
169 | The IP-Adapter-FaceID model based on Kolors will also be released soon!
170 |
171 |
172 |
173 | ### Acknowledgments
174 | - Thanks to [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) for providing the codebase.
175 |
176 |
177 |
--------------------------------------------------------------------------------
/ipadapter/asset/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1.png
--------------------------------------------------------------------------------
/ipadapter/asset/1_kolors_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_kolors_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/1_mj_cw_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_mj_cw_result.png
--------------------------------------------------------------------------------
/ipadapter/asset/1_sdxl_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_sdxl_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2.png
--------------------------------------------------------------------------------
/ipadapter/asset/2_kolors_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_kolors_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/2_mj_cw_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_mj_cw_result.png
--------------------------------------------------------------------------------
/ipadapter/asset/2_sdxl_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_sdxl_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3.png
--------------------------------------------------------------------------------
/ipadapter/asset/3_kolors_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_kolors_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/3_mj_cw_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_mj_cw_result.png
--------------------------------------------------------------------------------
/ipadapter/asset/3_sdxl_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_sdxl_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4.png
--------------------------------------------------------------------------------
/ipadapter/asset/4_kolors_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_kolors_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/4_mj_cw_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_mj_cw_result.png
--------------------------------------------------------------------------------
/ipadapter/asset/4_sdxl_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_sdxl_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5.png
--------------------------------------------------------------------------------
/ipadapter/asset/5_kolors_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_kolors_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/5_mj_cw_result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_mj_cw_result.png
--------------------------------------------------------------------------------
/ipadapter/asset/5_sdxl_ip_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_sdxl_ip_result.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/test_ip.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/test_ip.jpg
--------------------------------------------------------------------------------
/ipadapter/asset/test_ip2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/test_ip2.png
--------------------------------------------------------------------------------
/ipadapter/sample_ipadapter_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
3 | from diffusers.utils import load_image
4 | import os,sys
5 |
6 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
7 | from kolors.models.modeling_chatglm import ChatGLMModel
8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9 |
10 | # from diffusers import UNet2DConditionModel, AutoencoderKL
11 | from diffusers import AutoencoderKL
12 | from kolors.models.unet_2d_condition import UNet2DConditionModel
13 |
14 | from diffusers import EulerDiscreteScheduler
15 | from PIL import Image
16 |
17 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
18 |
19 |
20 | def infer( ip_img_path, prompt ):
21 |
22 | ckpt_dir = f'{root_dir}/weights/Kolors'
23 | text_encoder = ChatGLMModel.from_pretrained(
24 | f'{ckpt_dir}/text_encoder',
25 | torch_dtype=torch.float16).half()
26 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
27 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
28 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
29 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
30 |
31 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16)
32 | ip_img_size = 336
33 | clip_image_processor = CLIPImageProcessor( size=ip_img_size, crop_size=ip_img_size )
34 |
35 | pipe = StableDiffusionXLPipeline(
36 | vae=vae,
37 | text_encoder=text_encoder,
38 | tokenizer=tokenizer,
39 | unet=unet,
40 | scheduler=scheduler,
41 | image_encoder=image_encoder,
42 | feature_extractor=clip_image_processor,
43 | force_zeros_for_empty_prompt=False
44 | )
45 |
46 | pipe = pipe.to("cuda")
47 | pipe.enable_model_cpu_offload()
48 |
49 | if hasattr(pipe.unet, 'encoder_hid_proj'):
50 | pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj
51 |
52 | pipe.load_ip_adapter( f'{root_dir}/weights/Kolors-IP-Adapter-Plus' , subfolder="", weight_name=["ip_adapter_plus_general.bin"])
53 |
54 | basename = ip_img_path.rsplit('/',1)[-1].rsplit('.',1)[0]
55 | ip_adapter_img = Image.open( ip_img_path )
56 | generator = torch.Generator(device="cpu").manual_seed(66)
57 |
58 | for scale in [0.5]:
59 | pipe.set_ip_adapter_scale([ scale ])
60 | # print(prompt)
61 | image = pipe(
62 | prompt= prompt ,
63 | ip_adapter_image=[ ip_adapter_img ],
64 | negative_prompt="",
65 | height=1024,
66 | width=1024,
67 | num_inference_steps= 50,
68 | guidance_scale=5.0,
69 | num_images_per_prompt=1,
70 | generator=generator,
71 | ).images[0]
72 | image.save(f'{root_dir}/scripts/outputs/sample_ip_{basename}.jpg')
73 |
74 |
75 | if __name__ == '__main__':
76 | import fire
77 | fire.Fire(infer)
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
--------------------------------------------------------------------------------
/ipadapter_FaceID/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## 📖 Introduction
4 |
5 | We provide Kolors-IP-Adapter-FaceID-Plus module weights and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-IP-Adapter-FaceID-Plus results are as follows:
6 |
7 |
8 |
9 |
10 | **Example Results**
11 |
12 | | Reference Image | Prompt | Result Image |
13 | | :---: | :---: | :---: |
14 | | *(image)* | 穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围 Wearing an evening dress, in a starry night banquet scene, candlelight flickering, the whole scene exudes a romantic and luxurious atmosphere. | *(image)* |
15 | | *(image)* | 西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古 Western cowboy, cowboy hat, Red Dead Redemption, background is a western town, cactus, sunset glow, warm tones, shot with XT4 film, grain, vignette, Kodak film, retro. | *(image)* |
29 |
30 |
31 |
32 | - Our Kolors-IP-Adapter-FaceID-Plus module is trained on a large-scale, high-quality face dataset. We use the face ID embeddings generated by [insightface](https://github.com/deepinsight/insightface) together with CLIP features of the face area to preserve the face ID and structure information (a minimal sketch of the ID-embedding extraction follows).
33 |
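For illustration only, here is a sketch of extracting a face ID embedding and a face crop with insightface; the model pack name (`buffalo_l`) and the exact preprocessing are assumptions rather than this repository's code (see `ipadapter_FaceID/sample_ipadapter_faceid_plus.py` for the actual pipeline):

```python
import cv2
from insightface.app import FaceAnalysis

app = FaceAnalysis(name="buffalo_l", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

img = cv2.imread("./ipadapter_FaceID/assets/image1.png")
faces = app.get(img)                       # detect faces in the reference image
face = faces[0]
face_id_embed = face.normed_embedding      # 512-d ID embedding fed to the adapter
x1, y1, x2, y2 = face.bbox.astype(int)
face_crop = img[y1:y2, x1:x2]              # face area, later encoded with the CLIP image encoder
```
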
34 | ## 📊 Evaluation
35 | For evaluation, we constructed a test set consisting of over 200 reference images and text prompts. We invited several image experts to provide fair ratings for the generated results of different models. The experts assessed the generated images based on five criteria: visual appeal, text faithfulness, face similarity, facial aesthetics and overall satisfaction. Visual appeal and text faithfulness are used to measure the text-to-image generation capability, adhering to the evaluation standards of BaseModel. Meanwhile, face similarity and facial aesthetics are used to evaluate the performance of the proposed Kolors-IP-Adapter-FaceID-Plus. The results are summarized in the table below, where Kolors-IP-Adapter-FaceID-Plus outperforms SDXL-IP-Adapter-FaceID-Plus across all metrics.
36 |
37 |
38 | | Model | Average Text Faithfulness | Average Visual Appeal | Average Face Similarity | Average Facial Aesthetics | Average Overall Satisfaction |
39 | | :--------------: | :--------: | :--------: | :--------: | :--------: | :--------: |
40 | | SDXL-IP-Adapter-FaceID-Plus | 4.014 | 3.455 | 3.05 | 2.584 | 2.448 |
41 | | **Kolors-IP-Adapter-FaceID-Plus** | **4.235** | **4.374** | **4.415** | **3.887** | **3.561** |
42 | ------
43 |
44 |
45 |
46 |
47 |
48 |
49 | **Comparison Results**
50 |
51 | | Reference Image | Prompt | SDXL-IP-Adapter-FaceID-Plus | Kolors-IP-Adapter-FaceID-Plus |
52 | | :---: | :---: | :---: | :---: |
53 | | *(image)* | 古典油画风格,油彩厚重, 古典美感,历史气息 Classical oil painting style, thick oil paint, classical aesthetic, historical atmosphere. | *(image)* | *(image)* |
54 | | *(image)* | 夜晚,浪漫的海边,落日余晖洒在海面上,晚霞映照着整个海滩,头戴花环,花短袖,飘逸的头发,背景是美丽的海滩,可爱年轻的半身照,优雅梦幻,细节繁复,超逼真,高分辨率,柔和的背景,低对比度 Night, romantic seaside, sunset glow on the sea, evening glow reflecting on the whole beach, wearing a flower crown, short floral sleeves, flowing hair, background is a beautiful beach, cute young half-body portrait, elegant and dreamy, intricate details, ultra-realistic, high resolution, soft background, low contrast. | *(image)* | *(image)* |
55 | | *(image)* | F1赛车手, 法拉利,戴着着红黑白相间的赛车手头盔,帅气的赛车手,飞舞的彩带,背景赛车车库和天花板泛光,璀璨闪光,穿红白黑相间赛车服,色调统一且明艳,面部白皙,面部特写,正视图 F1 racer, Ferrari, wearing a red, black, and white racing helmet, handsome racer, flying ribbons, background race car garage and ceiling lights, dazzling flashes, wearing red, white, and black racing suit, unified and bright color tone, fair face, facial close-up, front view. | *(image)* | *(image)* |
56 | | *(image)* | 和服,日本传统服饰,在海边的黄昏,远山的背景,在远处的烟火,柔和的灯光,长焦镜头,夜间摄影风格,凉爽的色调,浪漫的气氛,火花四溅,时尚摄影,胶片滤镜 Kimono, traditional Japanese clothing, at the seaside at dusk, distant mountain background, fireworks in the distance, soft lighting, telephoto lens, night photography style, cool tones, romantic atmosphere, sparks flying, fashion photography, film filter. | *(image)* | *(image)* |
85 |
86 |
87 |
88 |
89 | *Kolors-IP-Adapter-FaceID-Plus employs Chinese prompts, while SDXL-IP-Adapter-FaceID-Plus uses English prompts.*
90 |
91 | ## 🛠️ Usage
92 |
93 | ### Requirements
94 |
95 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors).
96 |
97 |
98 |
99 | 1. Repository Cloning and Dependency Installation
100 |
101 | ```bash
102 | apt-get install git-lfs
103 | git clone https://github.com/Kwai-Kolors/Kolors
104 | cd Kolors
105 | conda create --name kolors python=3.8
106 | conda activate kolors
107 | pip install -r requirements.txt
108 | pip install insightface onnxruntime-gpu
109 | python3 setup.py install
110 | ```
111 |
112 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus):
113 | ```bash
114 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus --local-dir weights/Kolors-IP-Adapter-FaceID-Plus
115 | ```
116 | or
117 | ```bash
118 | git lfs clone https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus weights/Kolors-IP-Adapter-FaceID-Plus
119 | ```
120 |
121 | 3. Inference:
122 | ```bash
123 | python ipadapter_FaceID/sample_ipadapter_faceid_plus.py ./ipadapter_FaceID/assets/image1.png "穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围"
124 |
125 | python ipadapter_FaceID/sample_ipadapter_faceid_plus.py ./ipadapter_FaceID/assets/image2.png "西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古"
126 |
127 | # The image will be saved to "scripts/outputs/"
128 | ```
129 |
130 | ### Acknowledgments
131 | - Thanks to [insightface](https://github.com/deepinsight/insightface) for the face representations.
132 | - Thanks to [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) for the codebase.
133 |
134 |
135 |
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/image1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image1.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/image1_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image1_res.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/image2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image2.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/image2_res.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image2_res.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img1_Kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_Kolors.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img1_SDXL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_SDXL.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img1_org.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_org.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img2_Kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_Kolors.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img2_SDXL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_SDXL.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img2_org.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_org.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img3_Kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_Kolors.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img3_SDXL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_SDXL.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img3_org.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_org.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img4_Kolors.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_Kolors.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img4_SDXL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_SDXL.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/assets/test_img4_org.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_org.png
--------------------------------------------------------------------------------
/ipadapter_FaceID/sample_ipadapter_faceid_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor
3 | from diffusers.utils import load_image
4 | import os,sys
5 |
6 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline
7 | from kolors.models.modeling_chatglm import ChatGLMModel
8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
9 |
10 | from diffusers import AutoencoderKL
11 | from kolors.models.unet_2d_condition import UNet2DConditionModel
12 |
13 | from diffusers import EulerDiscreteScheduler
14 | from PIL import Image
15 |
16 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
17 |
18 | import cv2
19 | import numpy as np
20 | import insightface
21 | from diffusers.utils import load_image
22 | from insightface.app import FaceAnalysis
23 | from insightface.data import get_image as ins_get_image
24 |
25 | class FaceInfoGenerator():
26 | def __init__(self, root_dir = "./"):
27 | self.app = FaceAnalysis(name = 'antelopev2', root = root_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
28 | self.app.prepare(ctx_id = 0, det_size = (640, 640))
29 |
30 | def get_faceinfo_one_img(self, image_path):
31 | face_image = load_image(image_path)
32 | face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
33 |
34 | if len(face_info) == 0:
35 | face_info = None
36 | else:
37 | face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face
38 | return face_info
39 |
40 | def face_bbox_to_square(bbox):
41 | ## l, t, r, b to square l, t, r, b
42 | l,t,r,b = bbox
43 | cent_x = (l + r) / 2
44 | cent_y = (t + b) / 2
45 | w, h = r - l, b - t
46 | r = max(w, h) / 2
47 |
48 | l0 = cent_x - r
49 | r0 = cent_x + r
50 | t0 = cent_y - r
51 | b0 = cent_y + r
52 |
53 | return [l0, t0, r0, b0]
54 |
55 | def infer(test_image_path, text_prompt):
56 | ckpt_dir = f'{root_dir}/weights/Kolors'
57 | ip_model_dir = f'{root_dir}/weights/Kolors-IP-Adapter-FaceID-Plus'
58 | device = "cuda:0"
59 |
60 | #### base Kolors model
61 | text_encoder = ChatGLMModel.from_pretrained( f'{ckpt_dir}/text_encoder', torch_dtype = torch.float16).half()
62 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
63 |     vae = AutoencoderKL.from_pretrained(f'{ckpt_dir}/vae', revision = None)  # path already points at the vae folder, so no subfolder argument
64 | scheduler = EulerDiscreteScheduler.from_pretrained(f'{ckpt_dir}/scheduler')
65 | unet = UNet2DConditionModel.from_pretrained(f'{ckpt_dir}/unet', revision = None).half()
66 |
67 | #### clip image encoder for face structure
68 | clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ip_model_dir}/clip-vit-large-patch14-336', ignore_mismatched_sizes=True)
69 | clip_image_encoder.to(device)
70 | clip_image_processor = CLIPImageProcessor(size = 336, crop_size = 336)
71 |
72 | pipe = StableDiffusionXLPipeline(
73 | vae = vae,
74 | text_encoder = text_encoder,
75 | tokenizer = tokenizer,
76 | unet = unet,
77 | scheduler = scheduler,
78 | face_clip_encoder = clip_image_encoder,
79 | face_clip_processor = clip_image_processor,
80 | force_zeros_for_empty_prompt = False,
81 | )
82 | pipe = pipe.to(device)
83 | pipe.enable_model_cpu_offload()
84 |
85 | pipe.load_ip_adapter_faceid_plus(f'{ip_model_dir}/ipa-faceid-plus.bin', device = device)
86 |
87 | scale = 0.8
88 | pipe.set_face_fidelity_scale(scale)
89 |
90 | #### prepare face embedding & bbox with insightface toolbox
91 | face_info_generator = FaceInfoGenerator(root_dir = "./")
92 | img = Image.open(test_image_path)
93 | face_info = face_info_generator.get_faceinfo_one_img(test_image_path)
94 |
95 | face_bbox_square = face_bbox_to_square(face_info["bbox"])
96 | crop_image = img.crop(face_bbox_square)
97 | crop_image = crop_image.resize((336, 336))
98 | crop_image = [crop_image]
99 |
100 | face_embeds = torch.from_numpy(np.array([face_info["embedding"]]))
101 | face_embeds = face_embeds.to(device, dtype = torch.float16)
102 |
103 | #### generate image
104 | generator = torch.Generator(device = device).manual_seed(66)
105 | image = pipe(
106 | prompt = text_prompt,
107 | negative_prompt = "",
108 | height = 1024,
109 | width = 1024,
110 | num_inference_steps= 25,
111 | guidance_scale = 5.0,
112 | num_images_per_prompt = 1,
113 | generator = generator,
114 | face_crop_image = crop_image,
115 | face_insightface_embeds = face_embeds,
116 | ).images[0]
117 |     image.save(f'{root_dir}/scripts/outputs/test_res.png')  # outputs land in scripts/outputs/, as stated in the README
118 |
119 | if __name__ == '__main__':
120 | import fire
121 | fire.Fire(infer)
--------------------------------------------------------------------------------
/kolors/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/__init__.py
--------------------------------------------------------------------------------
/kolors/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/models/__init__.py
--------------------------------------------------------------------------------
/kolors/models/configuration_chatglm.py:
--------------------------------------------------------------------------------
1 | from transformers import PretrainedConfig
2 |
3 |
4 | class ChatGLMConfig(PretrainedConfig):
5 | model_type = "chatglm"
6 | def __init__(
7 | self,
8 | num_layers=28,
9 | padded_vocab_size=65024,
10 | hidden_size=4096,
11 | ffn_hidden_size=13696,
12 | kv_channels=128,
13 | num_attention_heads=32,
14 | seq_length=2048,
15 | hidden_dropout=0.0,
16 | classifier_dropout=None,
17 | attention_dropout=0.0,
18 | layernorm_epsilon=1e-5,
19 | rmsnorm=True,
20 | apply_residual_connection_post_layernorm=False,
21 | post_layer_norm=True,
22 | add_bias_linear=False,
23 | add_qkv_bias=False,
24 | bias_dropout_fusion=True,
25 | multi_query_attention=False,
26 | multi_query_group_num=1,
27 | apply_query_key_layer_scaling=True,
28 | attention_softmax_in_fp32=True,
29 | fp32_residual_connection=False,
30 | quantization_bit=0,
31 | pre_seq_len=None,
32 | prefix_projection=False,
33 | **kwargs
34 | ):
35 | self.num_layers = num_layers
36 | self.vocab_size = padded_vocab_size
37 | self.padded_vocab_size = padded_vocab_size
38 | self.hidden_size = hidden_size
39 | self.ffn_hidden_size = ffn_hidden_size
40 | self.kv_channels = kv_channels
41 | self.num_attention_heads = num_attention_heads
42 | self.seq_length = seq_length
43 | self.hidden_dropout = hidden_dropout
44 | self.classifier_dropout = classifier_dropout
45 | self.attention_dropout = attention_dropout
46 | self.layernorm_epsilon = layernorm_epsilon
47 | self.rmsnorm = rmsnorm
48 | self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
49 | self.post_layer_norm = post_layer_norm
50 | self.add_bias_linear = add_bias_linear
51 | self.add_qkv_bias = add_qkv_bias
52 | self.bias_dropout_fusion = bias_dropout_fusion
53 | self.multi_query_attention = multi_query_attention
54 | self.multi_query_group_num = multi_query_group_num
55 | self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
56 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32
57 | self.fp32_residual_connection = fp32_residual_connection
58 | self.quantization_bit = quantization_bit
59 | self.pre_seq_len = pre_seq_len
60 | self.prefix_projection = prefix_projection
61 | super().__init__(**kwargs)
62 |
--------------------------------------------------------------------------------
/kolors/models/ipa_faceid_plus/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/models/ipa_faceid_plus/__init__.py
--------------------------------------------------------------------------------
/kolors/models/ipa_faceid_plus/attention_processor.py:
--------------------------------------------------------------------------------
1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | class AttnProcessor2_0(torch.nn.Module):
7 | r"""
8 | Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
9 | """
10 | def __init__(
11 | self,
12 | hidden_size=None,
13 | cross_attention_dim=None,
14 | ):
15 | super().__init__()
16 | if not hasattr(F, "scaled_dot_product_attention"):
17 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
18 |
19 | def __call__(
20 | self,
21 | attn,
22 | hidden_states,
23 | encoder_hidden_states=None,
24 | attention_mask=None,
25 | temb=None,
26 | ):
27 | residual = hidden_states
28 |
29 | if attn.spatial_norm is not None:
30 | hidden_states = attn.spatial_norm(hidden_states, temb)
31 |
32 | input_ndim = hidden_states.ndim
33 |
34 | if input_ndim == 4:
35 | batch_size, channel, height, width = hidden_states.shape
36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
37 |
38 | batch_size, sequence_length, _ = (
39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
40 | )
41 |
42 | if attention_mask is not None:
43 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
44 | # scaled_dot_product_attention expects attention_mask shape to be
45 | # (batch, heads, source_length, target_length)
46 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
47 |
48 | if attn.group_norm is not None:
49 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
50 |
51 | query = attn.to_q(hidden_states)
52 |
53 | if encoder_hidden_states is None:
54 | encoder_hidden_states = hidden_states
55 | elif attn.norm_cross:
56 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
57 |
58 | key = attn.to_k(encoder_hidden_states)
59 | value = attn.to_v(encoder_hidden_states)
60 |
61 | inner_dim = key.shape[-1]
62 | head_dim = inner_dim // attn.heads
63 |
64 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
65 |
66 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
67 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
68 |
69 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
70 | # TODO: add support for attn.scale when we move to Torch 2.1
71 | hidden_states = F.scaled_dot_product_attention(
72 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
73 | )
74 |
75 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
76 | hidden_states = hidden_states.to(query.dtype)
77 |
78 | # linear proj
79 | hidden_states = attn.to_out[0](hidden_states)
80 | # dropout
81 | hidden_states = attn.to_out[1](hidden_states)
82 |
83 | if input_ndim == 4:
84 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
85 |
86 | if attn.residual_connection:
87 | hidden_states = hidden_states + residual
88 |
89 | hidden_states = hidden_states / attn.rescale_output_factor
90 |
91 | return hidden_states
92 |
93 | class IPAttnProcessor2_0(torch.nn.Module):
94 | r"""
95 |     Attention processor for IP-Adapter for PyTorch 2.0.
96 | Args:
97 | hidden_size (`int`):
98 | The hidden size of the attention layer.
99 | cross_attention_dim (`int`):
100 | The number of channels in the `encoder_hidden_states`.
101 | scale (`float`, defaults to 1.0):
102 |             the weight scale of the image prompt.
103 |         num_tokens (`int`, defaults to 4; should be 16 for IP-Adapter-Plus):
104 | The context length of the image features.
105 | """
106 |
107 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
108 | super().__init__()
109 |
110 | if not hasattr(F, "scaled_dot_product_attention"):
111 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
112 |
113 | self.hidden_size = hidden_size
114 | self.cross_attention_dim = cross_attention_dim
115 | self.scale = scale
116 | self.num_tokens = num_tokens
117 |
118 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
119 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
120 |
121 | def __call__(
122 | self,
123 | attn,
124 | hidden_states,
125 | encoder_hidden_states=None,
126 | attention_mask=None,
127 | temb=None,
128 | ):
129 | residual = hidden_states
130 |
131 | if attn.spatial_norm is not None:
132 | hidden_states = attn.spatial_norm(hidden_states, temb)
133 |
134 | input_ndim = hidden_states.ndim
135 |
136 | if input_ndim == 4:
137 | batch_size, channel, height, width = hidden_states.shape
138 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
139 |
140 | batch_size, sequence_length, _ = (
141 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
142 | )
143 |
144 | if attention_mask is not None:
145 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
146 | # scaled_dot_product_attention expects attention_mask shape to be
147 | # (batch, heads, source_length, target_length)
148 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
149 |
150 | if attn.group_norm is not None:
151 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
152 |
153 | query = attn.to_q(hidden_states)
154 |
155 | if encoder_hidden_states is None:
156 | encoder_hidden_states = hidden_states
157 | else:
158 | # get encoder_hidden_states, ip_hidden_states
159 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens
160 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :]
161 | if attn.norm_cross:
162 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
163 |
164 | key = attn.to_k(encoder_hidden_states)
165 | value = attn.to_v(encoder_hidden_states)
166 |
167 | inner_dim = key.shape[-1]
168 | head_dim = inner_dim // attn.heads
169 |
170 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
171 |
172 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
173 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
174 |
175 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
176 | # TODO: add support for attn.scale when we move to Torch 2.1
177 | hidden_states = F.scaled_dot_product_attention(
178 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
179 | )
180 |
181 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
182 | hidden_states = hidden_states.to(query.dtype)
183 |
184 | # for ip-adapter
185 | ip_key = self.to_k_ip(ip_hidden_states)
186 | ip_value = self.to_v_ip(ip_hidden_states)
187 |
188 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
189 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
190 |
191 | # the output of sdp = (batch, num_heads, seq_len, head_dim)
192 | # TODO: add support for attn.scale when we move to Torch 2.1
193 | ip_hidden_states = F.scaled_dot_product_attention(
194 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
195 | )
196 |
197 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
198 | ip_hidden_states = ip_hidden_states.to(query.dtype)
199 |
200 | hidden_states = hidden_states + self.scale * ip_hidden_states
201 |
202 | # linear proj
203 | hidden_states = attn.to_out[0](hidden_states)
204 | # dropout
205 | hidden_states = attn.to_out[1](hidden_states)
206 |
207 | if input_ndim == 4:
208 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
209 |
210 | if attn.residual_connection:
211 | hidden_states = hidden_states + residual
212 |
213 | hidden_states = hidden_states / attn.rescale_output_factor
214 |
215 | return hidden_states
--------------------------------------------------------------------------------
/kolors/models/ipa_faceid_plus/ipa_faceid_plus.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import math
4 |
5 | def reshape_tensor(x, heads):
6 | bs, length, width = x.shape
7 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
8 | x = x.view(bs, length, heads, -1)
9 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
10 | x = x.transpose(1, 2)
11 |     # shape stays (bs, n_heads, length, dim_per_head); the reshape below does not merge dimensions
12 | x = x.reshape(bs, heads, length, -1)
13 | return x
14 |
15 | def FeedForward(dim, mult=4):
16 | inner_dim = int(dim * mult)
17 | return nn.Sequential(
18 | nn.LayerNorm(dim),
19 | nn.Linear(dim, inner_dim, bias=False),
20 | nn.GELU(),
21 | nn.Linear(inner_dim, dim, bias=False),
22 | )
23 |
24 | class PerceiverAttention(nn.Module):
25 | def __init__(self, *, dim, dim_head=64, heads=8):
26 | super().__init__()
27 | self.scale = dim_head**-0.5
28 | self.dim_head = dim_head
29 | self.heads = heads
30 | inner_dim = dim_head * heads
31 |
32 | self.norm1 = nn.LayerNorm(dim)
33 | self.norm2 = nn.LayerNorm(dim)
34 |
35 | self.to_q = nn.Linear(dim, inner_dim, bias=False)
36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
37 | self.to_out = nn.Linear(inner_dim, dim, bias=False)
38 |
39 | def forward(self, x, latents):
40 | """
41 | Args:
42 | x (torch.Tensor): image features
43 | shape (b, n1, D)
44 |             latents (torch.Tensor): latent features
45 | shape (b, n2, D)
46 | """
47 | x = self.norm1(x)
48 | latents = self.norm2(latents)
49 |
50 | b, l, _ = latents.shape
51 |
52 | q = self.to_q(latents)
53 | kv_input = torch.cat((x, latents), dim=-2)
54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1)
55 |
56 | q = reshape_tensor(q, self.heads)
57 | k = reshape_tensor(k, self.heads)
58 | v = reshape_tensor(v, self.heads)
59 |
60 | # attention
61 | scale = 1 / math.sqrt(math.sqrt(self.dim_head))
62 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
63 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
64 | out = weight @ v
65 |
66 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
67 |
68 | return self.to_out(out)
69 |
70 | class FacePerceiverResampler(torch.nn.Module):
71 | def __init__(
72 | self,
73 | *,
74 | dim=768,
75 | depth=4,
76 | dim_head=64,
77 | heads=16,
78 | embedding_dim=1280,
79 | output_dim=768,
80 | ff_mult=4,
81 | ):
82 | super().__init__()
83 |
84 | self.proj_in = torch.nn.Linear(embedding_dim, dim)
85 | self.proj_out = torch.nn.Linear(dim, output_dim)
86 | self.norm_out = torch.nn.LayerNorm(output_dim)
87 | self.layers = torch.nn.ModuleList([])
88 | for _ in range(depth):
89 | self.layers.append(
90 | torch.nn.ModuleList(
91 | [
92 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
93 | FeedForward(dim=dim, mult=ff_mult),
94 | ]
95 | )
96 | )
97 |
98 | def forward(self, latents, x):
99 | x = self.proj_in(x)
100 | for attn, ff in self.layers:
101 | latents = attn(x, latents) + latents
102 | latents = ff(latents) + latents
103 | latents = self.proj_out(latents)
104 | return self.norm_out(latents)
105 |
106 | class ProjPlusModel(torch.nn.Module):
107 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4):
108 | super().__init__()
109 |
110 | self.cross_attention_dim = cross_attention_dim
111 | self.num_tokens = num_tokens
112 |
113 | self.proj = torch.nn.Sequential(
114 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2),
115 | torch.nn.GELU(),
116 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens),
117 | )
118 | self.norm = torch.nn.LayerNorm(cross_attention_dim)
119 |
120 | self.perceiver_resampler = FacePerceiverResampler(
121 | dim=cross_attention_dim,
122 | depth=4,
123 | dim_head=64,
124 | heads=cross_attention_dim // 64,
125 | embedding_dim=clip_embeddings_dim,
126 | output_dim=cross_attention_dim,
127 | ff_mult=4,
128 | )
129 |
130 | def forward(self, id_embeds, clip_embeds, shortcut = True, scale = 1.0):
131 | x = self.proj(id_embeds)
132 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
133 | x = self.norm(x)
134 | out = self.perceiver_resampler(x, clip_embeds)
135 | if shortcut:
136 | out = x + scale * out
137 | return out
--------------------------------------------------------------------------------
/kolors/pipelines/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/pipelines/__init__.py
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | fire
2 | triton
3 | pydantic==2.8.2
4 | accelerate==0.27.2
5 | deepspeed==0.8.1
6 | huggingface-hub==0.23.4
7 | imageio==2.25.1
8 | numpy==1.21.6
9 | omegaconf==2.3.0
10 | pandas==1.3.5
11 | Pillow==9.4.0
12 | tokenizers==0.13.2
13 | torch==1.13.1
14 | torchvision==0.14.1
15 | transformers==4.42.4
16 | xformers==0.0.16
17 | safetensors==0.3.3
18 | diffusers==0.28.2
19 | sentencepiece==0.1.99
20 | gradio==4.37.2
21 | opencv-python
22 | einops
23 | timm
24 | onnxruntime
25 |
--------------------------------------------------------------------------------
/scripts/outputs/sample_inpainting_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_inpainting_3.jpg
--------------------------------------------------------------------------------
/scripts/outputs/sample_inpainting_4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_inpainting_4.jpg
--------------------------------------------------------------------------------
/scripts/outputs/sample_ip_test_ip.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_ip_test_ip.jpg
--------------------------------------------------------------------------------
/scripts/outputs/sample_ip_test_ip2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_ip_test_ip2.jpg
--------------------------------------------------------------------------------
/scripts/outputs/sample_test.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_test.jpg
--------------------------------------------------------------------------------
/scripts/sample.py:
--------------------------------------------------------------------------------
1 | import os, torch
2 | # from PIL import Image
3 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
4 | from kolors.models.modeling_chatglm import ChatGLMModel
5 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
6 | from diffusers import UNet2DConditionModel, AutoencoderKL
7 | from diffusers import EulerDiscreteScheduler
8 |
9 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10 |
11 | def infer(prompt):
12 | ckpt_dir = f'{root_dir}/weights/Kolors'
13 | text_encoder = ChatGLMModel.from_pretrained(
14 | f'{ckpt_dir}/text_encoder',
15 | torch_dtype=torch.float16).half()
16 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
17 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half()
18 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
19 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half()
20 | pipe = StableDiffusionXLPipeline(
21 | vae=vae,
22 | text_encoder=text_encoder,
23 | tokenizer=tokenizer,
24 | unet=unet,
25 | scheduler=scheduler,
26 | force_zeros_for_empty_prompt=False)
27 | pipe = pipe.to("cuda")
28 | pipe.enable_model_cpu_offload()
29 | image = pipe(
30 | prompt=prompt,
31 | height=1024,
32 | width=1024,
33 | num_inference_steps=50,
34 | guidance_scale=5.0,
35 | num_images_per_prompt=1,
36 | generator= torch.Generator(pipe.device).manual_seed(66)).images[0]
37 | image.save(f'{root_dir}/scripts/outputs/sample_test.jpg')
38 |
39 |
40 | if __name__ == '__main__':
41 | import fire
42 | fire.Fire(infer)
43 |
--------------------------------------------------------------------------------
/scripts/sampleui.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import gradio as gr
4 | # from PIL import Image
5 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline
6 | from kolors.models.modeling_chatglm import ChatGLMModel
7 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer
8 | from diffusers import UNet2DConditionModel, AutoencoderKL
9 | from diffusers import EulerDiscreteScheduler
10 |
11 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12 |
13 | # Initialize global variables for models and pipeline
14 | text_encoder = None
15 | tokenizer = None
16 | vae = None
17 | scheduler = None
18 | unet = None
19 | pipe = None
20 |
21 | def load_models():
22 | global text_encoder, tokenizer, vae, scheduler, unet, pipe
23 |
24 | if text_encoder is None:
25 | ckpt_dir = f'{root_dir}/weights/Kolors'
26 |
27 | # Load the text encoder on CPU (this speeds stuff up 2x)
28 | text_encoder = ChatGLMModel.from_pretrained(
29 | f'{ckpt_dir}/text_encoder',
30 | torch_dtype=torch.float16).to('cpu').half()
31 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder')
32 |
33 | # Load the VAE and UNet on GPU
34 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to('cuda')
35 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")
36 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to('cuda')
37 |
38 | # Prepare the pipeline
39 | pipe = StableDiffusionXLPipeline(
40 | vae=vae,
41 | text_encoder=text_encoder,
42 | tokenizer=tokenizer,
43 | unet=unet,
44 | scheduler=scheduler,
45 | force_zeros_for_empty_prompt=False)
46 | pipe = pipe.to("cuda")
47 | pipe.enable_model_cpu_offload() # Enable offloading to balance CPU/GPU usage
48 |
49 | def infer(prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt):
50 | load_models()
51 |
52 | if use_random_seed:
53 | seed = torch.randint(0, 2**32 - 1, (1,)).item()
54 |
55 | generator = torch.Generator(pipe.device).manual_seed(seed)
56 | images = pipe(
57 | prompt=prompt,
58 | height=height,
59 | width=width,
60 | num_inference_steps=num_inference_steps,
61 | guidance_scale=guidance_scale,
62 | num_images_per_prompt=num_images_per_prompt,
63 | generator=generator
64 | ).images
65 |
66 | saved_images = []
67 | output_dir = f'{root_dir}/scripts/outputs'
68 | os.makedirs(output_dir, exist_ok=True)
69 |
70 | for i, image in enumerate(images):
71 | file_path = os.path.join(output_dir, 'sample_test.jpg')
72 | base_name, ext = os.path.splitext(file_path)
73 | counter = 1
74 | while os.path.exists(file_path):
75 | file_path = f"{base_name}_{counter}{ext}"
76 | counter += 1
77 | image.save(file_path)
78 | saved_images.append(file_path)
79 |
80 | return saved_images
81 |
82 | def gradio_interface():
83 | with gr.Blocks() as demo:
84 | with gr.Row():
85 | with gr.Column():
86 | gr.Markdown("## Kolors: Diffusion Model Gradio Interface")
87 | prompt = gr.Textbox(label="Prompt")
88 | use_random_seed = gr.Checkbox(label="Use Random Seed", value=True)
89 | seed = gr.Slider(minimum=0, maximum=2**32 - 1, step=1, label="Seed", randomize=True, visible=False)
90 | use_random_seed.change(lambda x: gr.update(visible=not x), use_random_seed, seed)
91 | height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=1024)
92 | width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=1024)
93 | num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=50)
94 | guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0)
95 | num_images_per_prompt = gr.Slider(minimum=1, maximum=10, step=1, label="Images per Prompt", value=1)
96 | btn = gr.Button("Generate Image")
97 |
98 | with gr.Column():
99 | output_images = gr.Gallery(label="Output Images", elem_id="output_gallery")
100 |
101 | btn.click(
102 | fn=infer,
103 | inputs=[prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt],
104 | outputs=output_images
105 | )
106 |
107 | return demo
108 |
109 | if __name__ == '__main__':
110 | gradio_interface().launch()
111 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="kolors",
5 | version="0.1",
6 | author="Kolors",
7 | description="The training and inference code for Kolors models.",
8 | packages=find_packages(),
9 | install_requires=[],
10 | dependency_links=[],
11 | )
12 |
--------------------------------------------------------------------------------