├── .gitignore ├── LICENSE ├── MODEL_LICENSE ├── README.md ├── README_CN.md ├── controlnet ├── README.md ├── annotator │ ├── __init__.py │ ├── canny │ │ └── __init__.py │ ├── dwpose │ │ ├── __init__.py │ │ ├── onnxdet.py │ │ ├── onnxpose.py │ │ ├── util.py │ │ └── wholebody.py │ ├── midas │ │ ├── LICENSE │ │ ├── __init__.py │ │ ├── api.py │ │ ├── midas │ │ │ ├── __init__.py │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── dpt_depth.py │ │ │ ├── midas_net.py │ │ │ ├── midas_net_custom.py │ │ │ ├── transforms.py │ │ │ └── vit.py │ │ └── utils.py │ └── util.py ├── assets │ ├── bird.png │ ├── dog.png │ ├── woman_1.png │ ├── woman_2.png │ ├── woman_3.png │ └── woman_4.png ├── outputs │ ├── Canny_dog.jpg │ ├── Canny_dog_condition.jpg │ ├── Canny_woman_1.jpg │ ├── Canny_woman_1_condition.jpg │ ├── Canny_woman_1_sdxl.jpg │ ├── Depth_1_condition.jpg │ ├── Depth_bird.jpg │ ├── Depth_bird_condition.jpg │ ├── Depth_bird_sdxl.jpg │ ├── Depth_ipadapter_1.jpg │ ├── Depth_ipadapter_woman_2.jpg │ ├── Depth_woman_2.jpg │ ├── Depth_woman_2_condition.jpg │ ├── Pose_woman_3.jpg │ ├── Pose_woman_3_condition.jpg │ ├── Pose_woman_3_sdxl.jpg │ ├── Pose_woman_4.jpg │ └── Pose_woman_4_condition.jpg ├── sample_controlNet.py └── sample_controlNet_ipadapter.py ├── dreambooth ├── README.md ├── default_config.yaml ├── infer_dreambooth.py ├── ktxl_test_image.png ├── train.sh └── train_dreambooth_lora.py ├── imgs ├── Kolors_paper.pdf ├── cn_all.png ├── fz_all.png ├── head_final3.png ├── logo.png ├── prompt_vis.txt ├── wechat.png ├── wz_all.png ├── zl8.png ├── 可图KOLORS模型商业授权申请书-英文版本.docx └── 可图KOLORS模型商业授权申请书.docx ├── inpainting ├── README.md ├── asset │ ├── 1.png │ ├── 1_kolors.png │ ├── 1_masked.png │ ├── 1_sdxl.png │ ├── 2.png │ ├── 2_kolors.png │ ├── 2_masked.png │ ├── 2_sdxl.png │ ├── 3.png │ ├── 3_kolors.png │ ├── 3_mask.png │ ├── 3_masked.png │ ├── 3_sdxl.png │ ├── 4.png │ ├── 4_kolors.png │ ├── 4_mask.png │ ├── 4_masked.png │ ├── 4_sdxl.png │ ├── 5.png │ ├── 5_kolors.png │ ├── 5_masked.png │ └── 5_sdxl.png └── sample_inpainting.py ├── ipadapter ├── README.md ├── asset │ ├── 1.png │ ├── 1_kolors_ip_result.jpg │ ├── 1_mj_cw_result.png │ ├── 1_sdxl_ip_result.jpg │ ├── 2.png │ ├── 2_kolors_ip_result.jpg │ ├── 2_mj_cw_result.png │ ├── 2_sdxl_ip_result.jpg │ ├── 3.png │ ├── 3_kolors_ip_result.jpg │ ├── 3_mj_cw_result.png │ ├── 3_sdxl_ip_result.jpg │ ├── 4.png │ ├── 4_kolors_ip_result.jpg │ ├── 4_mj_cw_result.png │ ├── 4_sdxl_ip_result.jpg │ ├── 5.png │ ├── 5_kolors_ip_result.jpg │ ├── 5_mj_cw_result.png │ ├── 5_sdxl_ip_result.jpg │ ├── test_ip.jpg │ └── test_ip2.png └── sample_ipadapter_plus.py ├── ipadapter_FaceID ├── README.md ├── assets │ ├── image1.png │ ├── image1_res.png │ ├── image2.png │ ├── image2_res.png │ ├── test_img1_Kolors.png │ ├── test_img1_SDXL.png │ ├── test_img1_org.png │ ├── test_img2_Kolors.png │ ├── test_img2_SDXL.png │ ├── test_img2_org.png │ ├── test_img3_Kolors.png │ ├── test_img3_SDXL.png │ ├── test_img3_org.png │ ├── test_img4_Kolors.png │ ├── test_img4_SDXL.png │ └── test_img4_org.png └── sample_ipadapter_faceid_plus.py ├── kolors ├── __init__.py ├── models │ ├── __init__.py │ ├── configuration_chatglm.py │ ├── controlnet.py │ ├── ipa_faceid_plus │ │ ├── __init__.py │ │ ├── attention_processor.py │ │ └── ipa_faceid_plus.py │ ├── modeling_chatglm.py │ ├── tokenization_chatglm.py │ └── unet_2d_condition.py └── pipelines │ ├── __init__.py │ ├── pipeline_controlnet_xl_kolors_img2img.py │ ├── pipeline_stable_diffusion_xl_chatglm_256.py │ ├── 
pipeline_stable_diffusion_xl_chatglm_256_inpainting.py │ ├── pipeline_stable_diffusion_xl_chatglm_256_ipadapter.py │ └── pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID.py ├── requirements.txt ├── scripts ├── outputs │ ├── sample_inpainting_3.jpg │ ├── sample_inpainting_4.jpg │ ├── sample_ip_test_ip.jpg │ ├── sample_ip_test_ip2.jpg │ └── sample_test.jpg ├── sample.py └── sampleui.py └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | weights 2 | *.egg-info 3 | __pycache__ 4 | *.pyc 5 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /MODEL_LICENSE: -------------------------------------------------------------------------------- 1 | 中文版 2 | 模型许可协议 3 | 模型发布日期:2024/7/6 4 | 5 | 通过点击同意或使用、复制、修改、分发、表演或展示模型作品的任何部分或元素,您将被视为已承认并接受本协议的内容,本协议立即生效。 6 | 7 | 1.定义。 8 | a. “协议”指本协议中所规定的使用、复制、分发、修改、表演和展示模型作品或其任何部分或元素的条款和条件。 9 | b. “材料”是指根据本协议提供的专有的模型和文档(及其任何部分)的统称。 10 | c. “模型”指大型语言模型、图像/视频/音频/3D 生成模型、多模态大型语言模型及其软件和算法,包括训练后的模型权重、参数(包括优化器状态)、机器学习模型代码、推理支持代码、训练支持代码、微调支持代码以及我们公开提供的前述其他元素。 11 | d. 
“输出”是指通过操作或以其他方式使用模型或模型衍生品而产生的模型或模型衍生品的信息和/或内容输出。 12 | e. “模型衍生品”包括:(i)对模型或任何模型衍生物的修改;(ii)基于模型的任何模型衍生物的作品;或(iii)通过将模型或模型的任何模型衍生物的权重、参数、操作或输出的模式转移到该模型而创建的任何其他机器学习模型,以使该模型的性能类似于模型或模型衍生物。为清楚起见,输出本身不被视为模型衍生物。 13 | f. “模型作品”包括:(i)材料;(ii)模型衍生品;及(iii)其所有衍生作品。 14 | g. “许可人”或“我们”指作品所有者或作品所有者授权的授予许可的实体,包括可能对模型和/或分发模型拥有权利的个人或实体。 15 | h.“被许可人”、“您”或“您的”是指行使本协议授予的权利和/或为任何目的和在任何使用领域使用模型作品的自然人或法人实体。 16 | i.“第三方”是指不受我们或您共同控制的个人或法人实体。 17 | 18 | 2. 许可内容。 19 | a.我们授予您非排他性的、全球性的、不可转让的、免版税的许可(在我们的知识产权或我们拥有的体现在材料中或利用材料的其他权利的范围内),允许您仅根据本协议的条款使用、复制、分发、创作衍生作品(包括模型衍生品)和对材料进行修改,并且您不得违反(或鼓励、或允许任何其他人违反)本协议的任何条款。 20 | b.在遵守本协议的前提下,您可以分发或向第三方提供模型作品,您须满足以下条件: 21 | (i)您必须向所有该模型作品或使用该作品的产品或服务的任何第三方接收者提供模型作品的来源和本协议的副本; 22 | (ii)您必须在任何修改过的文档上附加明显的声明,说明您更改了这些文档; 23 | (iii)您可以在您的修改中添加您自己的版权声明,并且,在您对该作品的使用、复制、修改、分发、表演和展示符合本协议的条款和条件的前提下,您可以为您的修改或任何此类模型衍生品的使用、复制或分发提供额外或不同的许可条款和条件。 24 | c. 附加商业条款:若您或其关联方提供的所有产品或服务的月活跃用户数在前一个自然月未超过3亿月活跃用户数,则您向许可方进行登记,将被视为获得相应的商业许可;若您或其关联方提供的所有产品或服务的月活跃用户数在前一个自然月超过3亿月活跃用户数,则您必须向许可人申请许可,许可人可自行决定向您授予许可。除非许可人另行明确授予您该等权利,否则您无权行使本协议项下的任何权利。 25 | 26 | 3.使用限制。 27 | a. 您对本模型作品的使用必须遵守适用法律法规(包括贸易合规法律法规),并遵守《服务协议》(https://kolors.kuaishou.com/agreement)。您必须将本第 3(a) 和 3(b) 条中提及的使用限制作为可执行条款纳入任何规范本模型作品使用和/或分发的协议(例如许可协议、使用条款等),并且您必须向您分发的后续用户发出通知,告知其本模型作品受本第 3(a) 和 3(b) 条中的使用限制约束。 28 | b. 您不得使用本模型作品或本模型作品的任何输出或成果来改进任何其他模型(本模型或其模型衍生品除外)。 29 | 30 | 4.知识产权。 31 | a. 我们保留材料的所有权及其相关知识产权。在遵守本协议条款和条件的前提下,对于您制作的材料的任何衍生作品和修改,您是且将是此类衍生作品和修改的所有者。 32 | b. 本协议不授予任何商标、商号、服务标记或产品名称的标识许可,除非出于描述和分发本模型作品的合理和惯常用途。 33 | c. 如果您对我们或任何个人或实体提起诉讼或其他程序(包括诉讼中的交叉索赔或反索赔),声称材料或任何输出或任何上述内容的任何部分侵犯您拥有或可许可的任何知识产权或其他权利,则根据本协议授予您的所有许可应于提起此类诉讼或其他程序之日起终止。 34 | 35 | 5. 免责声明和责任限制。 36 | a. 本模型作品及其任何输出和结果按“原样”提供,不作任何明示或暗示的保证,包括适销性、非侵权性或适用于特定用途的保证。我们不对材料及其任何输出的安全性或稳定性作任何保证,也不承担任何责任。 37 | b. 在任何情况下,我们均不对您承担任何损害赔偿责任,包括但不限于因您使用或无法使用材料或其任何输出而造成的任何直接、间接、特殊或后果性损害赔偿责任,无论该损害赔偿责任是如何造成的。 38 | c. 对于因您使用或分发模型的衍生物而引起的或与之相关的任何第三方索赔,您应提供辩护,赔偿,并使我方免受损害。 39 | 40 | 6. 存续和终止。 41 | a. 本协议期限自您接受本协议或访问材料之日起开始,并将持续完全有效,直至根据本协议条款和条件终止。 42 | b. 如果您违反本协议的任何条款或条件,我们可终止本协议。本协议终止后,您必须立即删除并停止使用本模型作品。第 4(a)、4(c)、5和 7 条在本协议终止后仍然有效。 43 | 44 | 7. 适用法律和管辖权。 45 | a. 本协议及由本协议引起的或与本协议有关的任何争议均受中华人民共和国大陆地区(仅为本协议目的,不包括香港、澳门和台湾)法律管辖,并排除冲突法的适用,且《联合国国际货物销售合同公约》不适用于本协议。 46 | b. 因本协议引起或与本协议有关的任何争议,由许可人住所地人民法院管辖。 47 | 48 | 请注意,许可证可能会更新到更全面的版本。 有关许可和版权的任何问题,请通过 kwai-kolors@kuaishou.com 与我们联系。 49 |   50 | 51 | 英文版 52 | 53 | MODEL LICENSE AGREEMENT 54 | Release Date: 2024/7/6 55 | By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Model Works, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately. 56 | 1. DEFINITIONS. 57 | a. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of the Model Works or any portion or element thereof set forth herein. 58 | b. “Materials” shall mean, collectively, Us proprietary the Model and Documentation (and any portion thereof) as made available by Us under this Agreement. 59 | c. “Model” shall mean the large language models, image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us . 60 | d. 
“Output” shall mean the information and/or content output of Model or a Model Derivative that results from operating or otherwise using Model or a Model Derivative. 61 | e. “Model Derivatives” shall mean all: (i) modifications to the Model or any Model Derivative; (ii) works based on the Model or any Model Derivative; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of the Model or any Model Derivative, to that model in order to cause that model to perform similarly to the Model or a Model Derivative, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs or a Model Derivative for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives. 62 | f. “Model Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof. 63 | g. “Licensor” , “We” or “Us” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License, including the persons or entities that may have rights in the Model and/or distributing the Model. 64 | h. “Licensee”, “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Model Works for any purpose and in any field of use. 65 | i. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You. 66 | 67 | 2. LICENSE CONTENT. 68 | a. We grant You a non-exclusive, worldwide, non-transferable and royalty-free limited license under the intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy. 69 | b. You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Model Works, provided that You meet all of the following conditions: 70 | (i) You must provide all such Third Party recipients of the Model Works or products or services using them the source of the Model and a copy of this Agreement; 71 | (ii) You must cause any modified documents to carry prominent notices stating that You changed the documents; 72 | (iii) You may add Your own copyright statement to Your modifications and, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement. 73 | c. 
additional commercial terms: If, the monthly active users of all products or services made available by or for You, or Your affiliates, does not exceed 300 million monthly active users in the preceding calendar month, Your registration with the Licensor will be deemed to have obtained the corresponding business license; If, the monthly active users of all products or services made available by or for You, or Your affiliates, is greater than 300 million monthly active users in the preceding calendar month, You must request a license from Licensor, which the Licensor may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until We otherwise expressly grants You such rights. 74 | 75 | 3. LICENSE RESTRICITIONS. 76 | a. Your use of the Model Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Service Agreement. You must include the use restrictions referenced in these Sections 3(a) and 3(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Model Works and You must provide notice to subsequent users to whom You distribute that Model Works are subject to the use restrictions in these Sections 3(a) and 3(b). 77 | b. You must not use the Model Works or any Output or results of the Model Works to improve any other large model (other than Model or Model Derivatives thereof). 78 | 4. INTELLECTUAL PROPERTY. 79 | a. We retain ownership of all intellectual property rights in and to the Model and derivatives. Conditioned upon compliance with the terms and conditions of this Agreement, with respect to any derivative works and modifications of the Materials that are made by You, You are and will be the owner of such derivative works and modifications. 80 | b. No trademark license is granted to use the trade names, trademarks, service marks, or product names of Us, except as required to fulfill notice requirements under this Agreement or as required for reasonable and customary use in describing and redistributing the Materials. 81 | c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. 82 | 5. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY. 83 | a. THE MODEL WORKS AND ANY OUTPUT AND RESULTS THERE FROM ARE PROVIDED "AS IS" WITHOUT ANY EXPRESS OR IMPLIED WARRANTY OF ANY KIND INCLUDING WARRANTIES OF MERCHANTABILITY, NONINFRINGEMENT, OR FITNESS FOR A PARTICULAR PURPOSE. WE MAKE NO WARRANTY AND ASSUME NO RESPONSIBILITY FOR THE SAFETY OR STABILITY OF THE MATERIALS AND ANY OUTPUT THEREFROM. 84 | b. IN NO EVENT SHALL WE BE LIABLE TO YOU FOR ANY DAMAGES, INCLUDING, BUT NOT LIMITED TO ANY DIRECT, OR INDIRECT, SPECIAL OR CONSEQUENTIAL DAMAGES ARISING FROM YOUR USE OR INABILITY TO USE THE MATERIALS OR ANY OUTPUT OF IT, NO MATTER HOW IT’S CAUSED. 85 | c. You will defend, indemnify and hold harmless Us from and against any claim by any third party arising out of or related to Your use or distribution of the Materials. 86 | 87 | 6. SURVIVAL AND TERMINATION. 88 | a. 
The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein. 89 | b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Model Works. Sections 4(a), 4(c), 5 and 7 shall survive the termination of this Agreement. 90 | 7. GOVERNING LAW AND JURISDICTION. 91 | a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of China (for the purpose of this agreement only, excluding Hong Kong, Macau, and Taiwan), without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement. 92 | b. Any disputes arising from or related to this Agreement shall be under the jurisdiction of the People's Court where the Licensor is located. 93 | 94 | Note that the license is subject to update to a more comprehensive version. For any questions related to the license and copyright, please contact us at kwai-kolors@kuaishou.com. 95 | -------------------------------------------------------------------------------- /controlnet/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | ## 📖 Introduction 5 | 6 | We provide three ControlNet weights and inference code based on Kolors-Basemodel: Canny, Depth and Pose. You can find some example images below. 7 | 8 | 9 | **1、ControlNet Demos** 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 |
Condition Image Prompt Result Image
全景,一只可爱的白色小狗坐在杯子里,看向镜头,动漫风格,3d渲染,辛烷值渲染。

Panorama of a cute white puppy sitting in a cup and looking towards the camera, anime style, 3d rendering, octane rendering.
新海诚风格,丰富的色彩,穿着绿色衬衫的女人站在田野里,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质。

Makoto Shinkai style, rich colors, a woman in a green shirt standing in the field, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality.
一个穿着黑色运动外套、白色内搭,上面戴着项链的女子,站在街边,背景是红色建筑和绿树,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。

A woman wearing a black sports jacket and a white top, adorned with a necklace, stands by the street, with a background of red buildings and green trees. high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K.
40 | 41 | 42 | 43 | **2、ControlNet and IP-Adapter-Plus Demos** 44 | 45 | We also support joint inference code between Kolors-IPadapter and Kolors-ControlNet. 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 |
Reference Image Condition Image Prompt Result Image
一个红色头发的女孩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质。

A girl with red hair, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality.
一个漂亮的女孩,最好的质量,超细节,8K画质。

A beautiful girl, best quality, super detail, 8K quality.
70 | 71 |
72 | 73 | 74 | 75 | 76 | ## 📊 Evaluation 77 | To evaluate the performance of the models, we compiled a test set of more than 200 images and text prompts. We invited several image experts to provide fair ratings for the generated results of the different models. The experts rated the generated images on four criteria: visual appeal, text faithfulness, conditional controllability, and overall satisfaction. Conditional controllability measures the ControlNet's ability to preserve spatial structure, while the other criteria follow the evaluation standards of the BaseModel. The specific results are summarized in the tables below, where Kolors-ControlNet achieves better performance across all criteria. 78 | 79 | **1、Canny** 80 | 81 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability | 82 | | :--------------: | :--------: | :--------: | :--------: | :--------: | 83 | | SDXL-ControlNet-Canny | 3.14 | 3.63 | 4.37 | 2.84 | 84 | | **Kolors-ControlNet-Canny** | **4.06** | **4.64** | **4.45** | **3.52** | 85 | 86 | 87 | 88 | **2、Depth** 89 | 90 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability | 91 | | :--------------: | :--------: | :--------: | :--------: | :--------: | 92 | | SDXL-ControlNet-Depth | 3.35 | 3.77 | 4.26 | 4.5 | 93 | | **Kolors-ControlNet-Depth** | **4.12** | **4.12** | **4.62** | **4.6** | 94 | 95 | 96 | 97 | **3、Pose** 98 | 99 | | Model | Average Overall Satisfaction | Average Visual Appeal | Average Text Faithfulness | Average Conditional Controllability | 100 | | :--------------: | :--------: | :--------: | :--------: | :--------: | 101 | | SDXL-ControlNet-Pose | 1.70 | 2.78 | 4.05 | 1.98 | 102 | | **Kolors-ControlNet-Pose** | **3.33** | **3.63** | **4.78** | **4.4** | 103 | 104 | 105 | *The [SDXL-ControlNet-Canny](https://huggingface.co/diffusers/controlnet-canny-sdxl-1.0) and [SDXL-ControlNet-Depth](https://huggingface.co/diffusers/controlnet-depth-sdxl-1.0) load [DreamShaper-XL](https://civitai.com/models/112902?modelVersionId=351306) as the backbone model.* 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 |
Compare Result
Condition Image Prompt Kolors-ControlNet Result SDXL-ControlNet Result
一个漂亮的女孩,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。

A beautiful girl, high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K.
一只颜色鲜艳的小鸟,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。

A colorful bird, high quality, ultra clear, colorful, ultra high resolution, best quality, 8k, HD, 4K.
一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩双手托脸,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K。

A girl wearing a purple puff-sleeve dress, with a crown and white lace gloves, is cupping her face with both hands. High quality, ultra-clear, vibrant colors, ultra-high resolution, best quality, 8k, HD, 4K.
145 | 146 | 147 | ------ 148 | 149 | 150 | ## 🛠️ Usage 151 | 152 | ### Requirements 153 | 154 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors). 155 | 156 |
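For reference, a minimal installation sketch is given below. The exact steps live in the base repository's README; the environment name and Python version here are illustrative assumptions, while `requirements.txt` and `setup.py` ship at the repository root.

```bash
# Illustrative setup mirroring the base Kolors repository (adjust to your environment).
git clone https://github.com/Kwai-Kolors/Kolors.git
cd Kolors
conda create --name kolors python=3.8 -y   # env name / Python version are assumptions
conda activate kolors
pip install -r requirements.txt            # project dependencies
python3 setup.py install                   # install the local kolors package
```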
157 | 158 | 159 | ### Weights download 160 | ```bash 161 | # Canny - ControlNet 162 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Canny --local-dir weights/Kolors-ControlNet-Canny 163 | 164 | # Depth - ControlNet 165 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Depth --local-dir weights/Kolors-ControlNet-Depth 166 | 167 | # Pose - ControlNet 168 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-ControlNet-Pose --local-dir weights/Kolors-ControlNet-Pose 169 | ``` 170 | 171 | If you intend to utilize the depth estimation network, please make sure to download its corresponding model weights. 172 | ``` 173 | huggingface-cli download lllyasviel/Annotators ./dpt_hybrid-midas-501f0c75.pt --local-dir ./controlnet/annotator/ckpts 174 | ``` 175 | 176 | Thanks to [DWPose](https://github.com/IDEA-Research/DWPose/tree/onnx?tab=readme-ov-file), you can utilize the pose estimation network. Please download the Pose model dw-ll_ucoco_384.onnx ([baidu](https://pan.baidu.com/s/1nuBjw-KKSxD_BkpmwXUJiw?pwd=28d7), [google](https://drive.google.com/file/d/12L8E2oAgZy4VACGSK9RaZBZrfgx7VTA2/view?usp=sharing)) and Det model yolox_l.onnx ([baidu](https://pan.baidu.com/s/1fpfIVpv5ypo4c1bUlzkMYQ?pwd=mjdn), [google](https://drive.google.com/file/d/1w9pXC8tT0p9ndMN-CArp1__b2GbzewWI/view?usp=sharing)). Then please put them into `controlnet/annotator/ckpts/`. 177 | 178 | 179 | ### Inference 180 | 181 | 182 | **a. Using canny ControlNet:** 183 | 184 | ```bash 185 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_1.png 一个漂亮的女孩,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Canny 186 | 187 | python ./controlnet/sample_controlNet.py ./controlnet/assets/dog.png 全景,一只可爱的白色小狗坐在杯子里,看向镜头,动漫风格,3d渲染,辛烷值渲染 Canny 188 | 189 | # The image will be saved to "controlnet/outputs/" 190 | ``` 191 | 192 | **b. Using depth ControlNet:** 193 | 194 | ```bash 195 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_2.png 新海诚风格,丰富的色彩,穿着绿色衬衫的女人站在田野里,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质 Depth 196 | 197 | python ./controlnet/sample_controlNet.py ./controlnet/assets/bird.png 一只颜色鲜艳的小鸟,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Depth 198 | 199 | # The image will be saved to "controlnet/outputs/" 200 | ``` 201 | 202 | **c. Using pose ControlNet:** 203 | 204 | ```bash 205 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_3.png 一位穿着紫色泡泡袖连衣裙、戴着皇冠和白色蕾丝手套的女孩双手托脸,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Pose 206 | 207 | python ./controlnet/sample_controlNet.py ./controlnet/assets/woman_4.png 一个穿着黑色运动外套、白色内搭,上面戴着项链的女子,站在街边,背景是红色建筑和绿树,高品质,超清晰,色彩鲜艳,超高分辨率,最佳品质,8k,高清,4K Pose 208 | 209 | # The image will be saved to "controlnet/outputs/" 210 | ``` 211 | 212 | 213 | **d. Using depth ControlNet + IP-Adapter-Plus:** 214 | 215 | If you intend to utilize the Kolors-IP-Adapter-Plus, please make sure to download its corresponding model weights. 216 | 217 | ```bash 218 | python ./controlnet/sample_controlNet_ipadapter.py ./controlnet/assets/woman_2.png ./ipadapter/asset/2.png 一个红色头发的女孩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质 Depth 219 | 220 | python ./controlnet/sample_controlNet_ipadapter.py ./ipadapter/asset/1.png ./controlnet/assets/woman_1.png 一个漂亮的女孩,最好的质量,超细节,8K画质 Depth 221 | 222 | # The image will be saved to "controlnet/outputs/" 223 | ``` 224 | 225 |
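The sample scripts above take a plain photo as the condition input; judging by the `*_condition.jpg` files in `controlnet/outputs/`, the matching annotator (Canny, MiDaS depth, or DWPose) is applied internally. If you want to inspect or reuse those condition maps yourself, the annotators bundled under `controlnet/annotator/` can be called directly. A minimal sketch follows — run it from the `controlnet/` directory so the relative `annotator/ckpts/...` checkpoint paths resolve; the thresholds and output file names are illustrative, and the depth and pose annotators expect a CUDA device:

```python
# Minimal sketch of the bundled annotators (illustrative thresholds and file names).
# Run from the controlnet/ directory so the relative annotator/ckpts/... paths resolve.
import cv2

from annotator.canny import CannyDetector
from annotator.midas import MidasDetector    # needs dpt_hybrid-midas-501f0c75.pt in annotator/ckpts
from annotator.dwpose import DWposeDetector  # needs yolox_l.onnx and dw-ll_ucoco_384.onnx in annotator/ckpts

img = cv2.imread("assets/woman_1.png")       # H x W x 3 uint8 array

canny_map = CannyDetector()(img, 100, 200)   # edge map via cv2.Canny with low/high thresholds
depth_map = MidasDetector()(img)             # uint8 depth map from the dpt_hybrid MiDaS model
_, pose_map = DWposeDetector()(img)          # second return value is the rendered skeleton image

cv2.imwrite("outputs/woman_1_canny_condition.jpg", canny_map)
cv2.imwrite("outputs/woman_1_depth_condition.jpg", depth_map)
cv2.imwrite("outputs/woman_1_pose_condition.jpg", pose_map)
```

The detector classes and their call signatures are taken from `controlnet/annotator/canny`, `controlnet/annotator/midas`, and `controlnet/annotator/dwpose` below.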
226 | 227 | 228 | ### Acknowledgments 229 | - Thanks to [ControlNet](https://github.com/lllyasviel/ControlNet) for providing the codebase. 230 | 231 |
232 | 233 | 234 | -------------------------------------------------------------------------------- /controlnet/annotator/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/annotator/__init__.py -------------------------------------------------------------------------------- /controlnet/annotator/canny/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | class CannyDetector: 5 | def __call__(self, img, low_threshold, high_threshold): 6 | return cv2.Canny(img, low_threshold, high_threshold) 7 | -------------------------------------------------------------------------------- /controlnet/annotator/dwpose/__init__.py: -------------------------------------------------------------------------------- 1 | # Openpose 2 | # Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose 3 | # 2nd Edited by https://github.com/Hzzone/pytorch-openpose 4 | # 3rd Edited by ControlNet 5 | # 4th Edited by ControlNet (added face and correct hands) 6 | 7 | import os 8 | os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE" 9 | 10 | import torch 11 | import numpy as np 12 | from . import util 13 | from .wholebody import Wholebody 14 | 15 | def draw_pose(pose, H, W): 16 | bodies = pose['bodies'] 17 | faces = pose['faces'] 18 | hands = pose['hands'] 19 | candidate = bodies['candidate'] 20 | subset = bodies['subset'] 21 | canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) 22 | 23 | canvas = util.draw_bodypose(canvas, candidate, subset) 24 | 25 | canvas = util.draw_handpose(canvas, hands) 26 | 27 | # canvas = util.draw_facepose(canvas, faces) 28 | 29 | return canvas 30 | 31 | 32 | class DWposeDetector: 33 | def __init__(self): 34 | 35 | self.pose_estimation = Wholebody() 36 | 37 | 38 | def getres(self, oriImg): 39 | out_res = {} 40 | oriImg = oriImg.copy() 41 | H, W, C = oriImg.shape 42 | with torch.no_grad(): 43 | candidate, subset = self.pose_estimation(oriImg) 44 | out_res['candidate']=candidate 45 | out_res['subset']=subset 46 | out_res['width']=W 47 | out_res['height']=H 48 | return out_res 49 | 50 | def __call__(self, oriImg): 51 | 52 | oriImg = oriImg.copy() 53 | H, W, C = oriImg.shape 54 | with torch.no_grad(): 55 | _candidate, _subset = self.pose_estimation(oriImg) 56 | 57 | subset = _subset.copy() 58 | candidate = _candidate.copy() 59 | nums, keys, locs = candidate.shape 60 | candidate[..., 0] /= float(W) 61 | candidate[..., 1] /= float(H) 62 | body = candidate[:,:18].copy() 63 | body = body.reshape(nums*18, locs) 64 | score = subset[:,:18] 65 | for i in range(len(score)): 66 | for j in range(len(score[i])): 67 | if score[i][j] > 0.3: 68 | score[i][j] = int(18*i+j) 69 | else: 70 | score[i][j] = -1 71 | 72 | un_visible = subset<0.3 73 | candidate[un_visible] = -1 74 | 75 | foot = candidate[:,18:24] 76 | 77 | faces = candidate[:,24:92] 78 | 79 | hands = candidate[:,92:113] 80 | hands = np.vstack([hands, candidate[:,113:]]) 81 | 82 | bodies = dict(candidate=body, subset=score) 83 | pose = dict(bodies=bodies, hands=hands, faces=faces) 84 | 85 | out_res = {} 86 | out_res['candidate']=candidate 87 | out_res['subset']=subset 88 | out_res['width']=W 89 | out_res['height']=H 90 | 91 | return out_res,draw_pose(pose, H, W) 92 | -------------------------------------------------------------------------------- /controlnet/annotator/dwpose/onnxdet.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import onnxruntime 5 | 6 | def nms(boxes, scores, nms_thr): 7 | """Single class NMS implemented in Numpy.""" 8 | x1 = boxes[:, 0] 9 | y1 = boxes[:, 1] 10 | x2 = boxes[:, 2] 11 | y2 = boxes[:, 3] 12 | 13 | areas = (x2 - x1 + 1) * (y2 - y1 + 1) 14 | order = scores.argsort()[::-1] 15 | 16 | keep = [] 17 | while order.size > 0: 18 | i = order[0] 19 | keep.append(i) 20 | xx1 = np.maximum(x1[i], x1[order[1:]]) 21 | yy1 = np.maximum(y1[i], y1[order[1:]]) 22 | xx2 = np.minimum(x2[i], x2[order[1:]]) 23 | yy2 = np.minimum(y2[i], y2[order[1:]]) 24 | 25 | w = np.maximum(0.0, xx2 - xx1 + 1) 26 | h = np.maximum(0.0, yy2 - yy1 + 1) 27 | inter = w * h 28 | ovr = inter / (areas[i] + areas[order[1:]] - inter) 29 | 30 | inds = np.where(ovr <= nms_thr)[0] 31 | order = order[inds + 1] 32 | 33 | return keep 34 | 35 | def multiclass_nms(boxes, scores, nms_thr, score_thr): 36 | """Multiclass NMS implemented in Numpy. Class-aware version.""" 37 | final_dets = [] 38 | num_classes = scores.shape[1] 39 | for cls_ind in range(num_classes): 40 | cls_scores = scores[:, cls_ind] 41 | valid_score_mask = cls_scores > score_thr 42 | if valid_score_mask.sum() == 0: 43 | continue 44 | else: 45 | valid_scores = cls_scores[valid_score_mask] 46 | valid_boxes = boxes[valid_score_mask] 47 | keep = nms(valid_boxes, valid_scores, nms_thr) 48 | if len(keep) > 0: 49 | cls_inds = np.ones((len(keep), 1)) * cls_ind 50 | dets = np.concatenate( 51 | [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 52 | ) 53 | final_dets.append(dets) 54 | if len(final_dets) == 0: 55 | return None 56 | return np.concatenate(final_dets, 0) 57 | 58 | def demo_postprocess(outputs, img_size, p6=False): 59 | grids = [] 60 | expanded_strides = [] 61 | strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] 62 | 63 | hsizes = [img_size[0] // stride for stride in strides] 64 | wsizes = [img_size[1] // stride for stride in strides] 65 | 66 | for hsize, wsize, stride in zip(hsizes, wsizes, strides): 67 | xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) 68 | grid = np.stack((xv, yv), 2).reshape(1, -1, 2) 69 | grids.append(grid) 70 | shape = grid.shape[:2] 71 | expanded_strides.append(np.full((*shape, 1), stride)) 72 | 73 | grids = np.concatenate(grids, 1) 74 | expanded_strides = np.concatenate(expanded_strides, 1) 75 | outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides 76 | outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides 77 | 78 | return outputs 79 | 80 | def preprocess(img, input_size, swap=(2, 0, 1)): 81 | if len(img.shape) == 3: 82 | padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 83 | else: 84 | padded_img = np.ones(input_size, dtype=np.uint8) * 114 85 | 86 | r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) 87 | resized_img = cv2.resize( 88 | img, 89 | (int(img.shape[1] * r), int(img.shape[0] * r)), 90 | interpolation=cv2.INTER_LINEAR, 91 | ).astype(np.uint8) 92 | padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img 93 | 94 | padded_img = padded_img.transpose(swap) 95 | padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) 96 | return padded_img, r 97 | 98 | def inference_detector(session, oriImg): 99 | input_shape = (640,640) 100 | img, ratio = preprocess(oriImg, input_shape) 101 | 102 | ort_inputs = {session.get_inputs()[0].name: img[None, :, :, :]} 103 | output = session.run(None, ort_inputs) 104 | predictions = 
demo_postprocess(output[0], input_shape)[0] 105 | 106 | boxes = predictions[:, :4] 107 | scores = predictions[:, 4:5] * predictions[:, 5:] 108 | 109 | boxes_xyxy = np.ones_like(boxes) 110 | boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. 111 | boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. 112 | boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. 113 | boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. 114 | boxes_xyxy /= ratio 115 | dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) 116 | if dets is not None: 117 | final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] 118 | isscore = final_scores>0.3 119 | iscat = final_cls_inds == 0 120 | isbbox = [ i and j for (i, j) in zip(isscore, iscat)] 121 | final_boxes = final_boxes[isbbox] 122 | else: 123 | final_boxes = np.array([]) 124 | 125 | return final_boxes 126 | -------------------------------------------------------------------------------- /controlnet/annotator/dwpose/util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import matplotlib 4 | import cv2 5 | 6 | 7 | eps = 0.01 8 | 9 | 10 | def smart_resize(x, s): 11 | Ht, Wt = s 12 | if x.ndim == 2: 13 | Ho, Wo = x.shape 14 | Co = 1 15 | else: 16 | Ho, Wo, Co = x.shape 17 | if Co == 3 or Co == 1: 18 | k = float(Ht + Wt) / float(Ho + Wo) 19 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) 20 | else: 21 | return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) 22 | 23 | 24 | def smart_resize_k(x, fx, fy): 25 | if x.ndim == 2: 26 | Ho, Wo = x.shape 27 | Co = 1 28 | else: 29 | Ho, Wo, Co = x.shape 30 | Ht, Wt = Ho * fy, Wo * fx 31 | if Co == 3 or Co == 1: 32 | k = float(Ht + Wt) / float(Ho + Wo) 33 | return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) 34 | else: 35 | return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) 36 | 37 | 38 | def padRightDownCorner(img, stride, padValue): 39 | h = img.shape[0] 40 | w = img.shape[1] 41 | 42 | pad = 4 * [None] 43 | pad[0] = 0 # up 44 | pad[1] = 0 # left 45 | pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down 46 | pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right 47 | 48 | img_padded = img 49 | pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) 50 | img_padded = np.concatenate((pad_up, img_padded), axis=0) 51 | pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) 52 | img_padded = np.concatenate((pad_left, img_padded), axis=1) 53 | pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) 54 | img_padded = np.concatenate((img_padded, pad_down), axis=0) 55 | pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) 56 | img_padded = np.concatenate((img_padded, pad_right), axis=1) 57 | 58 | return img_padded, pad 59 | 60 | 61 | def transfer(model, model_weights): 62 | transfered_model_weights = {} 63 | for weights_name in model.state_dict().keys(): 64 | transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] 65 | return transfered_model_weights 66 | 67 | 68 | def draw_bodypose(canvas, candidate, subset): 69 | H, W, C = canvas.shape 70 | candidate = np.array(candidate) 71 | subset = np.array(subset) 72 | 73 | stickwidth = 4 74 | 75 | limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ 76 | [10, 11], [2, 12], [12, 13], [13, 14], 
[2, 1], [1, 15], [15, 17], \ 77 | [1, 16], [16, 18], [3, 17], [6, 18]] 78 | 79 | colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ 80 | [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ 81 | [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] 82 | 83 | for i in range(17): 84 | for n in range(len(subset)): 85 | index = subset[n][np.array(limbSeq[i]) - 1] 86 | if -1 in index: 87 | continue 88 | Y = candidate[index.astype(int), 0] * float(W) 89 | X = candidate[index.astype(int), 1] * float(H) 90 | mX = np.mean(X) 91 | mY = np.mean(Y) 92 | length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 93 | angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) 94 | polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) 95 | cv2.fillConvexPoly(canvas, polygon, colors[i]) 96 | 97 | canvas = (canvas * 0.6).astype(np.uint8) 98 | 99 | for i in range(18): 100 | for n in range(len(subset)): 101 | index = int(subset[n][i]) 102 | if index == -1: 103 | continue 104 | x, y = candidate[index][0:2] 105 | x = int(x * W) 106 | y = int(y * H) 107 | cv2.circle(canvas, (int(x), int(y)), 4, colors[i], thickness=-1) 108 | 109 | return canvas 110 | 111 | 112 | def draw_handpose(canvas, all_hand_peaks): 113 | H, W, C = canvas.shape 114 | 115 | edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ 116 | [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] 117 | 118 | for peaks in all_hand_peaks: 119 | peaks = np.array(peaks) 120 | 121 | for ie, e in enumerate(edges): 122 | x1, y1 = peaks[e[0]] 123 | x2, y2 = peaks[e[1]] 124 | x1 = int(x1 * W) 125 | y1 = int(y1 * H) 126 | x2 = int(x2 * W) 127 | y2 = int(y2 * H) 128 | if x1 > eps and y1 > eps and x2 > eps and y2 > eps: 129 | cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) 130 | 131 | for i, keyponit in enumerate(peaks): 132 | x, y = keyponit 133 | x = int(x * W) 134 | y = int(y * H) 135 | if x > eps and y > eps: 136 | cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) 137 | return canvas 138 | 139 | 140 | def draw_facepose(canvas, all_lmks): 141 | H, W, C = canvas.shape 142 | for lmks in all_lmks: 143 | lmks = np.array(lmks) 144 | for lmk in lmks: 145 | x, y = lmk 146 | x = int(x * W) 147 | y = int(y * H) 148 | if x > eps and y > eps: 149 | cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) 150 | return canvas 151 | 152 | 153 | # detect hand according to body pose keypoints 154 | # please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp 155 | def handDetect(candidate, subset, oriImg): 156 | # right hand: wrist 4, elbow 3, shoulder 2 157 | # left hand: wrist 7, elbow 6, shoulder 5 158 | ratioWristElbow = 0.33 159 | detect_result = [] 160 | image_height, image_width = oriImg.shape[0:2] 161 | for person in subset.astype(int): 162 | # if any of three not detected 163 | has_left = np.sum(person[[5, 6, 7]] == -1) == 0 164 | has_right = np.sum(person[[2, 3, 4]] == -1) == 0 165 | if not (has_left or has_right): 166 | continue 167 | hands = [] 168 | #left hand 169 | if has_left: 170 | left_shoulder_index, left_elbow_index, left_wrist_index = person[[5, 6, 7]] 171 | x1, y1 = candidate[left_shoulder_index][:2] 172 | x2, y2 = candidate[left_elbow_index][:2] 173 | x3, y3 = 
candidate[left_wrist_index][:2] 174 | hands.append([x1, y1, x2, y2, x3, y3, True]) 175 | # right hand 176 | if has_right: 177 | right_shoulder_index, right_elbow_index, right_wrist_index = person[[2, 3, 4]] 178 | x1, y1 = candidate[right_shoulder_index][:2] 179 | x2, y2 = candidate[right_elbow_index][:2] 180 | x3, y3 = candidate[right_wrist_index][:2] 181 | hands.append([x1, y1, x2, y2, x3, y3, False]) 182 | 183 | for x1, y1, x2, y2, x3, y3, is_left in hands: 184 | # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox 185 | # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); 186 | # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); 187 | # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); 188 | # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); 189 | # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); 190 | x = x3 + ratioWristElbow * (x3 - x2) 191 | y = y3 + ratioWristElbow * (y3 - y2) 192 | distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) 193 | distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) 194 | width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) 195 | # x-y refers to the center --> offset to topLeft point 196 | # handRectangle.x -= handRectangle.width / 2.f; 197 | # handRectangle.y -= handRectangle.height / 2.f; 198 | x -= width / 2 199 | y -= width / 2 # width = height 200 | # overflow the image 201 | if x < 0: x = 0 202 | if y < 0: y = 0 203 | width1 = width 204 | width2 = width 205 | if x + width > image_width: width1 = image_width - x 206 | if y + width > image_height: width2 = image_height - y 207 | width = min(width1, width2) 208 | # the max hand box value is 20 pixels 209 | if width >= 20: 210 | detect_result.append([int(x), int(y), int(width), is_left]) 211 | 212 | ''' 213 | return value: [[x, y, w, True if left hand else False]]. 214 | width=height since the network require squared input. 
215 | x, y is the coordinate of top left 216 | ''' 217 | return detect_result 218 | 219 | 220 | # Written by Lvmin 221 | def faceDetect(candidate, subset, oriImg): 222 | # left right eye ear 14 15 16 17 223 | detect_result = [] 224 | image_height, image_width = oriImg.shape[0:2] 225 | for person in subset.astype(int): 226 | has_head = person[0] > -1 227 | if not has_head: 228 | continue 229 | 230 | has_left_eye = person[14] > -1 231 | has_right_eye = person[15] > -1 232 | has_left_ear = person[16] > -1 233 | has_right_ear = person[17] > -1 234 | 235 | if not (has_left_eye or has_right_eye or has_left_ear or has_right_ear): 236 | continue 237 | 238 | head, left_eye, right_eye, left_ear, right_ear = person[[0, 14, 15, 16, 17]] 239 | 240 | width = 0.0 241 | x0, y0 = candidate[head][:2] 242 | 243 | if has_left_eye: 244 | x1, y1 = candidate[left_eye][:2] 245 | d = max(abs(x0 - x1), abs(y0 - y1)) 246 | width = max(width, d * 3.0) 247 | 248 | if has_right_eye: 249 | x1, y1 = candidate[right_eye][:2] 250 | d = max(abs(x0 - x1), abs(y0 - y1)) 251 | width = max(width, d * 3.0) 252 | 253 | if has_left_ear: 254 | x1, y1 = candidate[left_ear][:2] 255 | d = max(abs(x0 - x1), abs(y0 - y1)) 256 | width = max(width, d * 1.5) 257 | 258 | if has_right_ear: 259 | x1, y1 = candidate[right_ear][:2] 260 | d = max(abs(x0 - x1), abs(y0 - y1)) 261 | width = max(width, d * 1.5) 262 | 263 | x, y = x0, y0 264 | 265 | x -= width 266 | y -= width 267 | 268 | if x < 0: 269 | x = 0 270 | 271 | if y < 0: 272 | y = 0 273 | 274 | width1 = width * 2 275 | width2 = width * 2 276 | 277 | if x + width > image_width: 278 | width1 = image_width - x 279 | 280 | if y + width > image_height: 281 | width2 = image_height - y 282 | 283 | width = min(width1, width2) 284 | 285 | if width >= 20: 286 | detect_result.append([int(x), int(y), int(width)]) 287 | 288 | return detect_result 289 | 290 | 291 | # get max index of 2d array 292 | def npmax(array): 293 | arrayindex = array.argmax(1) 294 | arrayvalue = array.max(1) 295 | i = arrayvalue.argmax() 296 | j = arrayindex[i] 297 | return i, j 298 | -------------------------------------------------------------------------------- /controlnet/annotator/dwpose/wholebody.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | 4 | import onnxruntime as ort 5 | from .onnxdet import inference_detector 6 | from .onnxpose import inference_pose 7 | 8 | class Wholebody: 9 | def __init__(self): 10 | device = 'cuda:0' 11 | providers = ['CPUExecutionProvider' 12 | ] if device == 'cpu' else ['CUDAExecutionProvider'] 13 | # providers = ['CPUExecutionProvider'] 14 | providers = ['CUDAExecutionProvider'] 15 | onnx_det = 'annotator/ckpts/yolox_l.onnx' 16 | onnx_pose = 'annotator/ckpts/dw-ll_ucoco_384.onnx' 17 | 18 | self.session_det = ort.InferenceSession(path_or_bytes=onnx_det, providers=providers) 19 | self.session_pose = ort.InferenceSession(path_or_bytes=onnx_pose, providers=providers) 20 | def __call__(self, oriImg): 21 | det_result = inference_detector(self.session_det, oriImg) 22 | keypoints, scores = inference_pose(self.session_pose, det_result, oriImg) 23 | 24 | keypoints_info = np.concatenate( 25 | (keypoints, scores[..., None]), axis=-1) 26 | # compute neck joint 27 | neck = np.mean(keypoints_info[:, [5, 6]], axis=1) 28 | # neck score when visualizing pred 29 | neck[:, 2:4] = np.logical_and( 30 | keypoints_info[:, 5, 2:4] > 0.3, 31 | keypoints_info[:, 6, 2:4] > 0.3).astype(int) 32 | new_keypoints_info = np.insert( 33 | 
keypoints_info, 17, neck, axis=1) 34 | mmpose_idx = [ 35 | 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 36 | ] 37 | openpose_idx = [ 38 | 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 39 | ] 40 | new_keypoints_info[:, openpose_idx] = \ 41 | new_keypoints_info[:, mmpose_idx] 42 | keypoints_info = new_keypoints_info 43 | 44 | keypoints, scores = keypoints_info[ 45 | ..., :2], keypoints_info[..., 2] 46 | 47 | return keypoints, scores 48 | 49 | 50 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/__init__.py: -------------------------------------------------------------------------------- 1 | # Midas Depth Estimation 2 | # From https://github.com/isl-org/MiDaS 3 | # MIT LICENSE 4 | 5 | import cv2 6 | import numpy as np 7 | import torch 8 | 9 | from einops import rearrange 10 | from .api import MiDaSInference 11 | 12 | 13 | class MidasDetector: 14 | def __init__(self): 15 | self.model = MiDaSInference(model_type="dpt_hybrid").cuda() 16 | self.rng = np.random.RandomState(0) 17 | 18 | def __call__(self, input_image): 19 | assert input_image.ndim == 3 20 | image_depth = input_image 21 | with torch.no_grad(): 22 | image_depth = torch.from_numpy(image_depth).float().cuda() 23 | image_depth = image_depth / 127.5 - 1.0 24 | image_depth = rearrange(image_depth, 'h w c -> 1 c h w') 25 | depth = self.model(image_depth)[0] 26 | 27 | depth -= torch.min(depth) 28 | depth /= torch.max(depth) 29 | depth = depth.cpu().numpy() 30 | depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) 31 | 32 | return depth_image 33 | 34 | 35 | 36 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/api.py: -------------------------------------------------------------------------------- 1 | # based on https://github.com/isl-org/MiDaS 2 | 3 | import cv2 4 | import os 5 | import torch 6 | import torch.nn as nn 7 | from torchvision.transforms import Compose 8 | 9 | from .midas.dpt_depth import DPTDepthModel 10 | from .midas.midas_net import MidasNet 11 | from .midas.midas_net_custom import MidasNet_small 12 | from .midas.transforms import Resize, NormalizeImage, PrepareForNet 13 | from annotator.util import annotator_ckpts_path 14 | 15 | 16 | ISL_PATHS = { 17 | "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), 18 | "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), 19 | "midas_v21": "", 20 | "midas_v21_small": "", 21 | } 22 | 23 | remote_model_path = "https://huggingface.co/lllyasviel/Annotators/resolve/main/dpt_hybrid-midas-501f0c75.pt" 24 | 25 | 26 | def disabled_train(self, mode=True): 27 | """Overwrite model.train with this function to make sure train/eval mode 28 | does not change anymore.""" 29 | return self 30 | 31 | 32 | def load_midas_transform(model_type): 33 | # https://github.com/isl-org/MiDaS/blob/master/run.py 34 | # load transform only 35 | if model_type == "dpt_large": # DPT-Large 36 | net_w, net_h = 384, 384 37 | resize_mode = "minimal" 38 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 39 | 40 | elif model_type == "dpt_hybrid": # DPT-Hybrid 41 | net_w, net_h = 384, 384 42 | resize_mode = "minimal" 43 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 44 | 45 | elif model_type == "midas_v21": 46 | net_w, net_h = 384, 384 47 | resize_mode = "upper_bound" 48 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 49 | 50 | elif model_type == "midas_v21_small": 51 | net_w, net_h = 256, 256 52 | resize_mode = "upper_bound" 53 | normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 54 | 55 | else: 56 | assert False, f"model_type '{model_type}' not implemented, use: --model_type large" 57 | 58 | transform = Compose( 59 | [ 60 | Resize( 61 | net_w, 62 | net_h, 63 | resize_target=None, 64 | keep_aspect_ratio=True, 65 | ensure_multiple_of=32, 66 | 
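                # (added note) keep_aspect_ratio=True together with ensure_multiple_of=32
                # constrains both output dimensions to multiples of 32, which the
                # DPT/MiDaS backbones expect (see the Resize transform in midas/transforms.py)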
resize_method=resize_mode, 67 | image_interpolation_method=cv2.INTER_CUBIC, 68 | ), 69 | normalization, 70 | PrepareForNet(), 71 | ] 72 | ) 73 | 74 | return transform 75 | 76 | 77 | def load_model(model_type): 78 | # https://github.com/isl-org/MiDaS/blob/master/run.py 79 | # load network 80 | model_path = ISL_PATHS[model_type] 81 | if model_type == "dpt_large": # DPT-Large 82 | model = DPTDepthModel( 83 | path=model_path, 84 | backbone="vitl16_384", 85 | non_negative=True, 86 | ) 87 | net_w, net_h = 384, 384 88 | resize_mode = "minimal" 89 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 90 | 91 | elif model_type == "dpt_hybrid": # DPT-Hybrid 92 | if not os.path.exists(model_path): 93 | from basicsr.utils.download_util import load_file_from_url 94 | load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) 95 | 96 | model = DPTDepthModel( 97 | path=model_path, 98 | backbone="vitb_rn50_384", 99 | non_negative=True, 100 | ) 101 | net_w, net_h = 384, 384 102 | resize_mode = "minimal" 103 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 104 | 105 | elif model_type == "midas_v21": 106 | model = MidasNet(model_path, non_negative=True) 107 | net_w, net_h = 384, 384 108 | resize_mode = "upper_bound" 109 | normalization = NormalizeImage( 110 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 111 | ) 112 | 113 | elif model_type == "midas_v21_small": 114 | model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, 115 | non_negative=True, blocks={'expand': True}) 116 | net_w, net_h = 256, 256 117 | resize_mode = "upper_bound" 118 | normalization = NormalizeImage( 119 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 120 | ) 121 | 122 | else: 123 | print(f"model_type '{model_type}' not implemented, use: --model_type large") 124 | assert False 125 | 126 | transform = Compose( 127 | [ 128 | Resize( 129 | net_w, 130 | net_h, 131 | resize_target=None, 132 | keep_aspect_ratio=True, 133 | ensure_multiple_of=32, 134 | resize_method=resize_mode, 135 | image_interpolation_method=cv2.INTER_CUBIC, 136 | ), 137 | normalization, 138 | PrepareForNet(), 139 | ] 140 | ) 141 | 142 | return model.eval(), transform 143 | 144 | 145 | class MiDaSInference(nn.Module): 146 | MODEL_TYPES_TORCH_HUB = [ 147 | "DPT_Large", 148 | "DPT_Hybrid", 149 | "MiDaS_small" 150 | ] 151 | MODEL_TYPES_ISL = [ 152 | "dpt_large", 153 | "dpt_hybrid", 154 | "midas_v21", 155 | "midas_v21_small", 156 | ] 157 | 158 | def __init__(self, model_type): 159 | super().__init__() 160 | assert (model_type in self.MODEL_TYPES_ISL) 161 | model, _ = load_model(model_type) 162 | self.model = model 163 | self.model.train = disabled_train 164 | 165 | def forward(self, x): 166 | with torch.no_grad(): 167 | prediction = self.model(x) 168 | return prediction 169 | 170 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/annotator/midas/midas/__init__.py -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None, use_vit_only=False, use_readout="ignore",): 12 | if backbone == "vitl16_384": 13 | pretrained = _make_pretrained_vitl16_384( 14 | use_pretrained, hooks=hooks, use_readout=use_readout 15 | ) 16 | scratch = _make_scratch( 17 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 18 | ) # ViT-L/16 - 85.0% Top1 (backbone) 19 | elif backbone == "vitb_rn50_384": 20 | pretrained = _make_pretrained_vitb_rn50_384( 21 | use_pretrained, 22 | hooks=hooks, 23 | use_vit_only=use_vit_only, 24 | use_readout=use_readout, 25 | ) 26 | scratch = _make_scratch( 27 | [256, 512, 768, 768], features, groups=groups, expand=expand 28 | ) # ViT-H/16 - 85.0% Top1 (backbone) 29 | elif backbone == "vitb16_384": 30 | pretrained = _make_pretrained_vitb16_384( 31 | use_pretrained, hooks=hooks, use_readout=use_readout 32 | ) 33 | scratch = _make_scratch( 34 | [96, 192, 384, 768], features, groups=groups, expand=expand 35 | ) # ViT-B/16 - 84.6% Top1 (backbone) 36 | elif backbone == "resnext101_wsl": 37 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 38 | scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # efficientnet_lite3 39 | elif backbone == "efficientnet_lite3": 40 | pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable) 41 | scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3 42 | else: 43 | print(f"Backbone '{backbone}' not implemented") 44 | assert False 45 | 46 | return pretrained, scratch 47 | 48 | 49 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 50 | scratch = nn.Module() 51 | 52 | out_shape1 = out_shape 53 | out_shape2 = out_shape 54 | out_shape3 = out_shape 55 | out_shape4 = out_shape 56 | if expand==True: 57 | out_shape1 = out_shape 58 | out_shape2 = out_shape*2 59 | out_shape3 = out_shape*4 60 | out_shape4 = out_shape*8 61 | 62 | scratch.layer1_rn = nn.Conv2d( 63 | in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 64 | ) 65 | scratch.layer2_rn = nn.Conv2d( 66 | in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 67 | ) 68 | scratch.layer3_rn = nn.Conv2d( 69 | in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 70 | ) 71 | scratch.layer4_rn = nn.Conv2d( 72 | in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups 73 | ) 74 | 75 | return scratch 76 | 77 | 78 | def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False): 79 | efficientnet = torch.hub.load( 80 | "rwightman/gen-efficientnet-pytorch", 81 | "tf_efficientnet_lite3", 82 | pretrained=use_pretrained, 83 | exportable=exportable 84 | ) 85 | return _make_efficientnet_backbone(efficientnet) 86 
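# Illustrative helper (added; not part of the upstream MiDaS file): _make_encoder
# pairs a pretrained backbone with a fresh "scratch" module whose layer{1..4}_rn
# convolutions remap the backbone's per-stage channel counts to a common feature
# width. The hook indices below follow the DPT-Hybrid configuration used elsewhere
# in this repo; treat them as an assumption if reused for other backbones.
def _example_make_dpt_hybrid_encoder(features=256):
    pretrained, scratch = _make_encoder(
        "vitb_rn50_384",
        features,
        use_pretrained=False,   # False avoids downloading backbone weights
        hooks=[0, 1, 8, 11],    # transformer blocks tapped for the four feature stages
        use_readout="project",
    )
    return pretrained, scratch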
| 87 | 88 | def _make_efficientnet_backbone(effnet): 89 | pretrained = nn.Module() 90 | 91 | pretrained.layer1 = nn.Sequential( 92 | effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2] 93 | ) 94 | pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3]) 95 | pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5]) 96 | pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9]) 97 | 98 | return pretrained 99 | 100 | 101 | def _make_resnet_backbone(resnet): 102 | pretrained = nn.Module() 103 | pretrained.layer1 = nn.Sequential( 104 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 105 | ) 106 | 107 | pretrained.layer2 = resnet.layer2 108 | pretrained.layer3 = resnet.layer3 109 | pretrained.layer4 = resnet.layer4 110 | 111 | return pretrained 112 | 113 | 114 | def _make_pretrained_resnext101_wsl(use_pretrained): 115 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 116 | return _make_resnet_backbone(resnet) 117 | 118 | 119 | 120 | class Interpolate(nn.Module): 121 | """Interpolation module. 122 | """ 123 | 124 | def __init__(self, scale_factor, mode, align_corners=False): 125 | """Init. 126 | 127 | Args: 128 | scale_factor (float): scaling 129 | mode (str): interpolation mode 130 | """ 131 | super(Interpolate, self).__init__() 132 | 133 | self.interp = nn.functional.interpolate 134 | self.scale_factor = scale_factor 135 | self.mode = mode 136 | self.align_corners = align_corners 137 | 138 | def forward(self, x): 139 | """Forward pass. 140 | 141 | Args: 142 | x (tensor): input 143 | 144 | Returns: 145 | tensor: interpolated data 146 | """ 147 | 148 | x = self.interp( 149 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners 150 | ) 151 | 152 | return x 153 | 154 | 155 | class ResidualConvUnit(nn.Module): 156 | """Residual convolution module. 157 | """ 158 | 159 | def __init__(self, features): 160 | """Init. 161 | 162 | Args: 163 | features (int): number of features 164 | """ 165 | super().__init__() 166 | 167 | self.conv1 = nn.Conv2d( 168 | features, features, kernel_size=3, stride=1, padding=1, bias=True 169 | ) 170 | 171 | self.conv2 = nn.Conv2d( 172 | features, features, kernel_size=3, stride=1, padding=1, bias=True 173 | ) 174 | 175 | self.relu = nn.ReLU(inplace=True) 176 | 177 | def forward(self, x): 178 | """Forward pass. 179 | 180 | Args: 181 | x (tensor): input 182 | 183 | Returns: 184 | tensor: output 185 | """ 186 | out = self.relu(x) 187 | out = self.conv1(out) 188 | out = self.relu(out) 189 | out = self.conv2(out) 190 | 191 | return out + x 192 | 193 | 194 | class FeatureFusionBlock(nn.Module): 195 | """Feature fusion block. 196 | """ 197 | 198 | def __init__(self, features): 199 | """Init. 200 | 201 | Args: 202 | features (int): number of features 203 | """ 204 | super(FeatureFusionBlock, self).__init__() 205 | 206 | self.resConfUnit1 = ResidualConvUnit(features) 207 | self.resConfUnit2 = ResidualConvUnit(features) 208 | 209 | def forward(self, *xs): 210 | """Forward pass. 211 | 212 | Returns: 213 | tensor: output 214 | """ 215 | output = xs[0] 216 | 217 | if len(xs) == 2: 218 | output += self.resConfUnit1(xs[1]) 219 | 220 | output = self.resConfUnit2(output) 221 | 222 | output = nn.functional.interpolate( 223 | output, scale_factor=2, mode="bilinear", align_corners=True 224 | ) 225 | 226 | return output 227 | 228 | 229 | 230 | 231 | class ResidualConvUnit_custom(nn.Module): 232 | """Residual convolution module. 233 | """ 234 | 235 | def __init__(self, features, activation, bn): 236 | """Init. 
237 | 238 | Args: 239 | features (int): number of features 240 | """ 241 | super().__init__() 242 | 243 | self.bn = bn 244 | 245 | self.groups=1 246 | 247 | self.conv1 = nn.Conv2d( 248 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 249 | ) 250 | 251 | self.conv2 = nn.Conv2d( 252 | features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups 253 | ) 254 | 255 | if self.bn==True: 256 | self.bn1 = nn.BatchNorm2d(features) 257 | self.bn2 = nn.BatchNorm2d(features) 258 | 259 | self.activation = activation 260 | 261 | self.skip_add = nn.quantized.FloatFunctional() 262 | 263 | def forward(self, x): 264 | """Forward pass. 265 | 266 | Args: 267 | x (tensor): input 268 | 269 | Returns: 270 | tensor: output 271 | """ 272 | 273 | out = self.activation(x) 274 | out = self.conv1(out) 275 | if self.bn==True: 276 | out = self.bn1(out) 277 | 278 | out = self.activation(out) 279 | out = self.conv2(out) 280 | if self.bn==True: 281 | out = self.bn2(out) 282 | 283 | if self.groups > 1: 284 | out = self.conv_merge(out) 285 | 286 | return self.skip_add.add(out, x) 287 | 288 | # return out + x 289 | 290 | 291 | class FeatureFusionBlock_custom(nn.Module): 292 | """Feature fusion block. 293 | """ 294 | 295 | def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True): 296 | """Init. 297 | 298 | Args: 299 | features (int): number of features 300 | """ 301 | super(FeatureFusionBlock_custom, self).__init__() 302 | 303 | self.deconv = deconv 304 | self.align_corners = align_corners 305 | 306 | self.groups=1 307 | 308 | self.expand = expand 309 | out_features = features 310 | if self.expand==True: 311 | out_features = features//2 312 | 313 | self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) 314 | 315 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 316 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 317 | 318 | self.skip_add = nn.quantized.FloatFunctional() 319 | 320 | def forward(self, *xs): 321 | """Forward pass. 
322 | 323 | Returns: 324 | tensor: output 325 | """ 326 | output = xs[0] 327 | 328 | if len(xs) == 2: 329 | res = self.resConfUnit1(xs[1]) 330 | output = self.skip_add.add(output, res) 331 | # output += res 332 | 333 | output = self.resConfUnit2(output) 334 | 335 | output = nn.functional.interpolate( 336 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 337 | ) 338 | 339 | output = self.out_conv(output) 340 | 341 | return output 342 | 343 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | False, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | 88 | class DPTDepthModel(DPT): 89 | def __init__(self, path=None, non_negative=True, **kwargs): 90 | features = kwargs["features"] if "features" in kwargs else 256 91 | 92 | head = nn.Sequential( 93 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 94 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 95 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 96 | nn.ReLU(True), 97 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 98 | nn.ReLU(True) if non_negative else nn.Identity(), 99 | nn.Identity(), 100 | ) 101 | 102 | 
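        # (added note) the head halves the fused feature width, upsamples by 2x,
        # projects down to 32 channels and then to a single depth channel; the
        # trailing ReLU keeps predictions non-negative when non_negative=True.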
super().__init__(head, **kwargs) 103 | 104 | if path is not None: 105 | self.load(path) 106 | 107 | def forward(self, x): 108 | return super().forward(x).squeeze(dim=1) 109 | 110 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/midas_net_custom.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_small(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True, 17 | blocks={'expand': True}): 18 | """Init. 19 | 20 | Args: 21 | path (str, optional): Path to saved model. Defaults to None. 22 | features (int, optional): Number of features. Defaults to 256. 23 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 24 | """ 25 | print("Loading weights: ", path) 26 | 27 | super(MidasNet_small, self).__init__() 28 | 29 | use_pretrained = False if path else True 30 | 31 | self.channels_last = channels_last 32 | self.blocks = blocks 33 | self.backbone = backbone 34 | 35 | self.groups = 1 36 | 37 | features1=features 38 | features2=features 39 | features3=features 40 | features4=features 41 | self.expand = False 42 | if "expand" in self.blocks and self.blocks['expand'] == True: 43 | self.expand = True 44 | features1=features 45 | features2=features*2 46 | features3=features*4 47 | features4=features*8 48 | 49 | self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable) 50 | 51 | self.scratch.activation = nn.ReLU(False) 52 | 53 | self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 54 | self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 55 | self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners) 56 | self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners) 57 | 58 | 59 | self.scratch.output_conv = nn.Sequential( 60 | nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups), 61 | Interpolate(scale_factor=2, mode="bilinear"), 62 | nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1), 63 | self.scratch.activation, 64 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 65 | nn.ReLU(True) if non_negative else nn.Identity(), 66 | nn.Identity(), 67 | ) 68 | 69 | if path: 70 | self.load(path) 71 | 72 | 73 | def forward(self, x): 74 | """Forward pass. 
75 | 76 | Args: 77 | x (tensor): input data (image) 78 | 79 | Returns: 80 | tensor: depth 81 | """ 82 | if self.channels_last==True: 83 | print("self.channels_last = ", self.channels_last) 84 | x.contiguous(memory_format=torch.channels_last) 85 | 86 | 87 | layer_1 = self.pretrained.layer1(x) 88 | layer_2 = self.pretrained.layer2(layer_1) 89 | layer_3 = self.pretrained.layer3(layer_2) 90 | layer_4 = self.pretrained.layer4(layer_3) 91 | 92 | layer_1_rn = self.scratch.layer1_rn(layer_1) 93 | layer_2_rn = self.scratch.layer2_rn(layer_2) 94 | layer_3_rn = self.scratch.layer3_rn(layer_3) 95 | layer_4_rn = self.scratch.layer4_rn(layer_4) 96 | 97 | 98 | path_4 = self.scratch.refinenet4(layer_4_rn) 99 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 100 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 101 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 102 | 103 | out = self.scratch.output_conv(path_1) 104 | 105 | return torch.squeeze(out, dim=1) 106 | 107 | 108 | 109 | def fuse_model(m): 110 | prev_previous_type = nn.Identity() 111 | prev_previous_name = '' 112 | previous_type = nn.Identity() 113 | previous_name = '' 114 | for name, module in m.named_modules(): 115 | if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU: 116 | # print("FUSED ", prev_previous_name, previous_name, name) 117 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True) 118 | elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d: 119 | # print("FUSED ", prev_previous_name, previous_name) 120 | torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True) 121 | # elif previous_type == nn.Conv2d and type(module) == nn.ReLU: 122 | # print("FUSED ", previous_name, name) 123 | # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True) 124 | 125 | prev_previous_type = previous_type 126 | prev_previous_name = previous_name 127 | previous_type = type(module) 128 | previous_name = name -------------------------------------------------------------------------------- /controlnet/annotator/midas/midas/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /controlnet/annotator/midas/utils.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth.""" 2 | import sys 3 | import re 4 | import numpy as np 5 | import cv2 6 | import torch 7 | 8 | 9 | def read_pfm(path): 10 | """Read pfm file. 
11 | 12 | Args: 13 | path (str): path to file 14 | 15 | Returns: 16 | tuple: (data, scale) 17 | """ 18 | with open(path, "rb") as file: 19 | 20 | color = None 21 | width = None 22 | height = None 23 | scale = None 24 | endian = None 25 | 26 | header = file.readline().rstrip() 27 | if header.decode("ascii") == "PF": 28 | color = True 29 | elif header.decode("ascii") == "Pf": 30 | color = False 31 | else: 32 | raise Exception("Not a PFM file: " + path) 33 | 34 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 35 | if dim_match: 36 | width, height = list(map(int, dim_match.groups())) 37 | else: 38 | raise Exception("Malformed PFM header.") 39 | 40 | scale = float(file.readline().decode("ascii").rstrip()) 41 | if scale < 0: 42 | # little-endian 43 | endian = "<" 44 | scale = -scale 45 | else: 46 | # big-endian 47 | endian = ">" 48 | 49 | data = np.fromfile(file, endian + "f") 50 | shape = (height, width, 3) if color else (height, width) 51 | 52 | data = np.reshape(data, shape) 53 | data = np.flipud(data) 54 | 55 | return data, scale 56 | 57 | 58 | def write_pfm(path, image, scale=1): 59 | """Write pfm file. 60 | 61 | Args: 62 | path (str): pathto file 63 | image (array): data 64 | scale (int, optional): Scale. Defaults to 1. 65 | """ 66 | 67 | with open(path, "wb") as file: 68 | color = None 69 | 70 | if image.dtype.name != "float32": 71 | raise Exception("Image dtype must be float32.") 72 | 73 | image = np.flipud(image) 74 | 75 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 76 | color = True 77 | elif ( 78 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 79 | ): # greyscale 80 | color = False 81 | else: 82 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 83 | 84 | file.write("PF\n" if color else "Pf\n".encode()) 85 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 86 | 87 | endian = image.dtype.byteorder 88 | 89 | if endian == "<" or endian == "=" and sys.byteorder == "little": 90 | scale = -scale 91 | 92 | file.write("%f\n".encode() % scale) 93 | 94 | image.tofile(file) 95 | 96 | 97 | def read_image(path): 98 | """Read image and output RGB image (0-1). 99 | 100 | Args: 101 | path (str): path to file 102 | 103 | Returns: 104 | array: RGB image (0-1) 105 | """ 106 | img = cv2.imread(path) 107 | 108 | if img.ndim == 2: 109 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 110 | 111 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 112 | 113 | return img 114 | 115 | 116 | def resize_image(img): 117 | """Resize image and make it fit for network. 118 | 119 | Args: 120 | img (array): image 121 | 122 | Returns: 123 | tensor: data ready for network 124 | """ 125 | height_orig = img.shape[0] 126 | width_orig = img.shape[1] 127 | 128 | if width_orig > height_orig: 129 | scale = width_orig / 384 130 | else: 131 | scale = height_orig / 384 132 | 133 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 134 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 135 | 136 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 137 | 138 | img_resized = ( 139 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 140 | ) 141 | img_resized = img_resized.unsqueeze(0) 142 | 143 | return img_resized 144 | 145 | 146 | def resize_depth(depth, width, height): 147 | """Resize depth map and bring to CPU (numpy). 
148 | 149 | Args: 150 | depth (tensor): depth 151 | width (int): image width 152 | height (int): image height 153 | 154 | Returns: 155 | array: processed depth 156 | """ 157 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 158 | 159 | depth_resized = cv2.resize( 160 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 161 | ) 162 | 163 | return depth_resized 164 | 165 | def write_depth(path, depth, bits=1): 166 | """Write depth map to pfm and png file. 167 | 168 | Args: 169 | path (str): filepath without extension 170 | depth (array): depth 171 | """ 172 | write_pfm(path + ".pfm", depth.astype(np.float32)) 173 | 174 | depth_min = depth.min() 175 | depth_max = depth.max() 176 | 177 | max_val = (2**(8*bits))-1 178 | 179 | if depth_max - depth_min > np.finfo("float").eps: 180 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 181 | else: 182 | out = np.zeros(depth.shape, dtype=depth.type) 183 | 184 | if bits == 1: 185 | cv2.imwrite(path + ".png", out.astype("uint8")) 186 | elif bits == 2: 187 | cv2.imwrite(path + ".png", out.astype("uint16")) 188 | 189 | return 190 | -------------------------------------------------------------------------------- /controlnet/annotator/util.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import cv2 5 | import os 6 | import PIL 7 | 8 | annotator_ckpts_path = os.path.join(os.path.dirname(__file__), 'ckpts') 9 | 10 | def HWC3(x): 11 | assert x.dtype == np.uint8 12 | if x.ndim == 2: 13 | x = x[:, :, None] 14 | assert x.ndim == 3 15 | H, W, C = x.shape 16 | assert C == 1 or C == 3 or C == 4 17 | if C == 3: 18 | return x 19 | if C == 1: 20 | return np.concatenate([x, x, x], axis=2) 21 | if C == 4: 22 | color = x[:, :, 0:3].astype(np.float32) 23 | alpha = x[:, :, 3:4].astype(np.float32) / 255.0 24 | y = color * alpha + 255.0 * (1.0 - alpha) 25 | y = y.clip(0, 255).astype(np.uint8) 26 | return y 27 | 28 | 29 | def resize_image(input_image, resolution, short = False, interpolation=None): 30 | if isinstance(input_image,PIL.Image.Image): 31 | mode = 'pil' 32 | W,H = input_image.size 33 | 34 | elif isinstance(input_image,np.ndarray): 35 | mode = 'cv2' 36 | H, W, _ = input_image.shape 37 | 38 | H = float(H) 39 | W = float(W) 40 | if short: 41 | k = float(resolution) / min(H, W) # k>1 放大, k<1 缩小 42 | else: 43 | k = float(resolution) / max(H, W) # k>1 放大, k<1 缩小 44 | H *= k 45 | W *= k 46 | H = int(np.round(H / 64.0)) * 64 47 | W = int(np.round(W / 64.0)) * 64 48 | 49 | if mode == 'cv2': 50 | if interpolation is None: 51 | interpolation = cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA 52 | img = cv2.resize(input_image, (W, H), interpolation=interpolation) 53 | 54 | elif mode == 'pil': 55 | if interpolation is None: 56 | interpolation = PIL.Image.LANCZOS if k > 1 else PIL.Image.BILINEAR 57 | img = input_image.resize((W, H), resample=interpolation) 58 | 59 | return img 60 | 61 | # def resize_image(input_image, resolution): 62 | # H, W, C = input_image.shape 63 | # H = float(H) 64 | # W = float(W) 65 | # k = float(resolution) / min(H, W) 66 | # H *= k 67 | # W *= k 68 | # H = int(np.round(H / 64.0)) * 64 69 | # W = int(np.round(W / 64.0)) * 64 70 | # img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA) 71 | # return img 72 | 73 | 74 | def nms(x, t, s): 75 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 76 | 77 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 78 | f2 = np.array([[0, 1, 0], [0, 
1, 0], [0, 1, 0]], dtype=np.uint8) 79 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 80 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 81 | 82 | y = np.zeros_like(x) 83 | 84 | for f in [f1, f2, f3, f4]: 85 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 86 | 87 | z = np.zeros_like(y, dtype=np.uint8) 88 | z[y > t] = 255 89 | return z 90 | 91 | 92 | def make_noise_disk(H, W, C, F): 93 | noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) 94 | noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) 95 | noise = noise[F: F + H, F: F + W] 96 | noise -= np.min(noise) 97 | noise /= np.max(noise) 98 | if C == 1: 99 | noise = noise[:, :, None] 100 | return noise 101 | 102 | 103 | def min_max_norm(x): 104 | x -= np.min(x) 105 | x /= np.maximum(np.max(x), 1e-5) 106 | return x 107 | 108 | 109 | def safe_step(x, step=2): 110 | y = x.astype(np.float32) * float(step + 1) 111 | y = y.astype(np.int32).astype(np.float32) / float(step) 112 | return y 113 | 114 | 115 | def img2mask(img, H, W, low=10, high=90): 116 | assert img.ndim == 3 or img.ndim == 2 117 | assert img.dtype == np.uint8 118 | 119 | if img.ndim == 3: 120 | y = img[:, :, random.randrange(0, img.shape[2])] 121 | else: 122 | y = img 123 | 124 | y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) 125 | 126 | if random.uniform(0, 1) < 0.5: 127 | y = 255 - y 128 | 129 | return y < np.percentile(y, random.randrange(low, high)) 130 | -------------------------------------------------------------------------------- /controlnet/assets/bird.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/bird.png -------------------------------------------------------------------------------- /controlnet/assets/dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/dog.png -------------------------------------------------------------------------------- /controlnet/assets/woman_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_1.png -------------------------------------------------------------------------------- /controlnet/assets/woman_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_2.png -------------------------------------------------------------------------------- /controlnet/assets/woman_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_3.png -------------------------------------------------------------------------------- /controlnet/assets/woman_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/assets/woman_4.png -------------------------------------------------------------------------------- /controlnet/outputs/Canny_dog.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_dog.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Canny_dog_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_dog_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Canny_woman_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Canny_woman_1_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Canny_woman_1_sdxl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Canny_woman_1_sdxl.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_1_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_1_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_bird.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_bird_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_bird_sdxl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_bird_sdxl.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_ipadapter_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_ipadapter_1.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_ipadapter_woman_2.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_ipadapter_woman_2.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_woman_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_woman_2.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Depth_woman_2_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Depth_woman_2_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Pose_woman_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Pose_woman_3_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3_condition.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Pose_woman_3_sdxl.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_3_sdxl.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Pose_woman_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_4.jpg -------------------------------------------------------------------------------- /controlnet/outputs/Pose_woman_4_condition.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/controlnet/outputs/Pose_woman_4_condition.jpg -------------------------------------------------------------------------------- /controlnet/sample_controlNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor 3 | from diffusers.utils import load_image 4 | import os,sys 5 | 6 | from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline 7 | from kolors.models.modeling_chatglm import ChatGLMModel 8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 9 | from kolors.models.controlnet import ControlNetModel 10 | 11 | from diffusers import AutoencoderKL 12 | from kolors.models.unet_2d_condition import UNet2DConditionModel 13 | 14 | from diffusers import EulerDiscreteScheduler 15 | from PIL import Image 16 | import numpy as np 17 | import cv2 18 | 19 | from annotator.midas import MidasDetector 20 | from annotator.dwpose import DWposeDetector 21 | from annotator.util import 
resize_image,HWC3 22 | 23 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 24 | 25 | def process_canny_condition( image, canny_threods=[100,200] ): 26 | np_image = image.copy() 27 | np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1]) 28 | np_image = np_image[:, :, None] 29 | np_image = np.concatenate([np_image, np_image, np_image], axis=2) 30 | np_image = HWC3(np_image) 31 | return Image.fromarray(np_image) 32 | 33 | 34 | model_midas = None 35 | def process_depth_condition_midas(img, res = 1024): 36 | h,w,_ = img.shape 37 | img = resize_image(HWC3(img), res) 38 | global model_midas 39 | if model_midas is None: 40 | model_midas = MidasDetector() 41 | 42 | result = HWC3( model_midas(img) ) 43 | result = cv2.resize( result, (w,h) ) 44 | return Image.fromarray(result) 45 | 46 | 47 | model_dwpose = None 48 | def process_dwpose_condition( image, res=1024 ): 49 | h,w,_ = image.shape 50 | img = resize_image(HWC3(image), res) 51 | global model_dwpose 52 | if model_dwpose is None: 53 | model_dwpose = DWposeDetector() 54 | out_res, out_img = model_dwpose(image) 55 | result = HWC3( out_img ) 56 | result = cv2.resize( result, (w,h) ) 57 | return Image.fromarray(result) 58 | 59 | 60 | def infer( image_path , prompt, model_type = 'Canny' ): 61 | 62 | ckpt_dir = f'{root_dir}/weights/Kolors' 63 | text_encoder = ChatGLMModel.from_pretrained( 64 | f'{ckpt_dir}/text_encoder', 65 | torch_dtype=torch.float16).half() 66 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 67 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 68 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 69 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 70 | 71 | control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}' 72 | controlnet = ControlNetModel.from_pretrained( control_path , revision=None).half() 73 | 74 | pipe = StableDiffusionXLControlNetImg2ImgPipeline( 75 | vae=vae, 76 | controlnet = controlnet, 77 | text_encoder=text_encoder, 78 | tokenizer=tokenizer, 79 | unet=unet, 80 | scheduler=scheduler, 81 | force_zeros_for_empty_prompt=False 82 | ) 83 | 84 | pipe = pipe.to("cuda") 85 | pipe.enable_model_cpu_offload() 86 | 87 | negative_prompt = 'nsfw,脸部阴影,低分辨率,jpeg伪影、模糊、糟糕,黑脸,霓虹灯' 88 | 89 | MAX_IMG_SIZE=1024 90 | controlnet_conditioning_scale = 0.7 91 | control_guidance_end = 0.9 92 | strength = 1.0 93 | 94 | basename = image_path.rsplit('/',1)[-1].rsplit('.',1)[0] 95 | 96 | init_image = Image.open( image_path ) 97 | 98 | init_image = resize_image( init_image, MAX_IMG_SIZE) 99 | if model_type == 'Canny': 100 | condi_img = process_canny_condition( np.array(init_image) ) 101 | elif model_type == 'Depth': 102 | condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMG_SIZE ) 103 | elif model_type == 'Pose': 104 | condi_img = process_dwpose_condition( np.array(init_image), MAX_IMG_SIZE) 105 | 106 | generator = torch.Generator(device="cpu").manual_seed(66) 107 | image = pipe( 108 | prompt= prompt , 109 | image = init_image, 110 | controlnet_conditioning_scale = controlnet_conditioning_scale, 111 | control_guidance_end = control_guidance_end, 112 | strength= strength , 113 | control_image = condi_img, 114 | negative_prompt= negative_prompt , 115 | num_inference_steps= 50 , 116 | guidance_scale= 6.0, 117 | num_images_per_prompt=1, 118 | generator=generator, 119 | ).images[0] 120 | 121 | condi_img.save( 
f'{root_dir}/controlnet/outputs/{model_type}_{basename}_condition.jpg' ) 122 | image.save(f'{root_dir}/controlnet/outputs/{model_type}_{basename}.jpg') 123 | 124 | 125 | if __name__ == '__main__': 126 | import fire 127 | fire.Fire(infer) 128 | 129 | 130 | -------------------------------------------------------------------------------- /controlnet/sample_controlNet_ipadapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor 3 | from diffusers.utils import load_image 4 | import os,sys 5 | 6 | from kolors.pipelines.pipeline_controlnet_xl_kolors_img2img import StableDiffusionXLControlNetImg2ImgPipeline 7 | from kolors.models.modeling_chatglm import ChatGLMModel 8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 9 | from kolors.models.controlnet import ControlNetModel 10 | 11 | from diffusers import AutoencoderKL 12 | from kolors.models.unet_2d_condition import UNet2DConditionModel 13 | 14 | from diffusers import EulerDiscreteScheduler 15 | from PIL import Image 16 | import numpy as np 17 | import cv2 18 | 19 | from annotator.midas import MidasDetector 20 | from annotator.util import resize_image,HWC3 21 | 22 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 23 | 24 | def process_canny_condition( image, canny_threods=[100,200] ): 25 | np_image = image.copy() 26 | np_image = cv2.Canny(np_image, canny_threods[0], canny_threods[1]) 27 | np_image = np_image[:, :, None] 28 | np_image = np.concatenate([np_image, np_image, np_image], axis=2) 29 | np_image = HWC3(np_image) 30 | return Image.fromarray(np_image) 31 | 32 | 33 | model_midas = None 34 | def process_depth_condition_midas(img, res = 1024): 35 | h,w,_ = img.shape 36 | img = resize_image(HWC3(img), res) 37 | global model_midas 38 | if model_midas is None: 39 | model_midas = MidasDetector() 40 | 41 | result = HWC3( model_midas(img) ) 42 | result = cv2.resize( result, (w,h) ) 43 | return Image.fromarray(result) 44 | 45 | 46 | def infer( image_path , ip_image_path, prompt, model_type = 'Canny' ): 47 | 48 | ckpt_dir = f'{root_dir}/weights/Kolors' 49 | text_encoder = ChatGLMModel.from_pretrained( 50 | f'{ckpt_dir}/text_encoder', 51 | torch_dtype=torch.float16).half() 52 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 53 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 54 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 55 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 56 | 57 | control_path = f'{root_dir}/weights/Kolors-ControlNet-{model_type}' 58 | controlnet = ControlNetModel.from_pretrained( control_path , revision=None).half() 59 | 60 | # IP-Adapter model 61 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16) 62 | ip_img_size = 336 63 | clip_image_processor = CLIPImageProcessor( size=ip_img_size, crop_size=ip_img_size ) 64 | 65 | pipe = StableDiffusionXLControlNetImg2ImgPipeline( 66 | vae=vae, 67 | controlnet = controlnet, 68 | text_encoder=text_encoder, 69 | tokenizer=tokenizer, 70 | unet=unet, 71 | scheduler=scheduler, 72 | image_encoder=image_encoder, 73 | feature_extractor=clip_image_processor, 74 | force_zeros_for_empty_prompt=False 75 | ) 76 | 77 | if hasattr(pipe.unet, 'encoder_hid_proj'): 78 | 
pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj 79 | 80 | pipe.load_ip_adapter( f'{root_dir}/weights/Kolors-IP-Adapter-Plus' , subfolder="", weight_name=["ip_adapter_plus_general.bin"]) 81 | 82 | pipe = pipe.to("cuda") 83 | pipe.enable_model_cpu_offload() 84 | 85 | negative_prompt = 'nsfw,脸部阴影,低分辨率,糟糕的解剖结构、糟糕的手,缺失手指、质量最差、低质量、jpeg伪影、模糊、糟糕,黑脸,霓虹灯' 86 | 87 | MAX_IMG_SIZE=1024 88 | controlnet_conditioning_scale = 0.5 89 | control_guidance_end = 0.9 90 | strength = 1.0 91 | ip_scale = 0.5 92 | 93 | basename = image_path.rsplit('/',1)[-1].rsplit('.',1)[0] 94 | 95 | init_image = Image.open( image_path ) 96 | init_image = resize_image( init_image, MAX_IMG_SIZE) 97 | 98 | if model_type == 'Canny': 99 | condi_img = process_canny_condition( np.array(init_image) ) 100 | elif model_type == 'Depth': 101 | condi_img = process_depth_condition_midas( np.array(init_image), MAX_IMG_SIZE ) 102 | 103 | ip_adapter_img = Image.open(ip_image_path) 104 | pipe.set_ip_adapter_scale([ ip_scale ]) 105 | 106 | generator = torch.Generator(device="cpu").manual_seed(66) 107 | image = pipe( 108 | prompt= prompt , 109 | image = init_image, 110 | controlnet_conditioning_scale = controlnet_conditioning_scale, 111 | control_guidance_end = control_guidance_end, 112 | 113 | ip_adapter_image=[ ip_adapter_img ], 114 | 115 | strength= strength , 116 | control_image = condi_img, 117 | negative_prompt= negative_prompt , 118 | num_inference_steps= 50 , 119 | guidance_scale= 5.0, 120 | num_images_per_prompt=1, 121 | generator=generator, 122 | ).images[0] 123 | 124 | image.save(f'{root_dir}/controlnet/outputs/{model_type}_ipadapter_{basename}.jpg') 125 | condi_img.save(f'{root_dir}/controlnet/outputs/{model_type}_{basename}_condition.jpg') 126 | 127 | 128 | if __name__ == '__main__': 129 | import fire 130 | fire.Fire(infer) 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | -------------------------------------------------------------------------------- /dreambooth/README.md: -------------------------------------------------------------------------------- 1 | ## 📖 Introduction 2 | We Provide LoRA training and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors), along with an IP LoRA training example. 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 |
Example Result
Prompt Result Image
ktxl狗在草地上跑。 (The ktxl dog is running on the grass.)

24 | 25 |
26 | 27 | **Our improvements** 28 | 29 | - Supporting user-defined caption files. By putting the '.txt' file with the same name as the image file in the same directory, the caption file will be automatically matched with the proper image. 30 | 31 | ## 🛠️ Usage 32 | 33 | ### Requirements 34 | 35 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors). 36 | 37 | 1. Repository Cloning and Dependency Installation 38 | 39 | ```bash 40 | apt-get install git-lfs 41 | git clone https://github.com/Kwai-Kolors/Kolors 42 | cd Kolors 43 | conda create --name kolors python=3.8 44 | conda activate kolors 45 | pip install -r requirements.txt 46 | python3 setup.py install 47 | ``` 48 | ### Training 49 | 50 | 1. First, we need to get our dataset: https://huggingface.co/datasets/diffusers/dog-example. 51 | ```python 52 | from huggingface_hub import snapshot_download 53 | 54 | local_dir = "./dog" 55 | snapshot_download( 56 | "diffusers/dog-example", 57 | local_dir=local_dir, repo_type="dataset", 58 | ignore_patterns=".gitattributes", 59 | ) 60 | ``` 61 | 62 | **___Note: To load caption files automatically during training, you can use the same name for both the image and its corresponding '.txt' caption file.___** 63 | 64 | 65 | 2. Launch the training using: 66 | ```bash 67 | sh train.sh 68 | ``` 69 | 70 | 3. Training configuration. We train the model using the default configuration in the `train.sh` file on 8 V100 GPUs, consuming a total of 27GB of memory. You can also finetune the text encoder by adding `--train_text_encoder`: 71 | ```bash 72 | MODEL_NAME="/path/base_model_path" 73 | CLASS_DIR="/path/regularization_image_path" 74 | INSTANCE_DIR="path/training_image_path" 75 | OUTPUT_DIR="./trained_models" 76 | cfg_file=./default_config.yaml 77 | 78 | accelerate launch --config_file ${cfg_file} train_dreambooth_lora.py \ 79 | --pretrained_model_name_or_path=$MODEL_NAME \ 80 | --instance_data_dir=$INSTANCE_DIR \ 81 | --output_dir=$OUTPUT_DIR \ 82 | --class_data_dir=$CLASS_DIR \ 83 | --instance_prompt="ktxl狗" \ 84 | --class_prompt="狗" \ 85 | --train_batch_size=1 \ 86 | --gradient_accumulation_steps=1 \ 87 | --learning_rate=2e-5 \ 88 | --text_encoder_lr=5e-5 \ 89 | --lr_scheduler="polynomial" \ 90 | --lr_warmup_steps=100 \ 91 | --rank=4 \ 92 | --resolution=1024 \ 93 | --max_train_steps=2000 \ 94 | --checkpointing_steps=200 \ 95 | --num_class_images=100 \ 96 | --center_crop \ 97 | --mixed_precision='fp16' \ 98 | --seed=19980818 \ 99 | --img_repeat_nums=1 \ 100 | --sample_batch_size=2 \ 101 | --use_preffix_prompt \ 102 | --gradient_checkpointing \ 103 | --train_text_encoder \ 104 | --adam_weight_decay=1e-02 \ 105 | --with_prior_preservation \ 106 | --prior_loss_weight=0.7 \ 107 | 108 | ``` 109 | 110 | 111 | **___Note: Most of our training configurations stay the same as the official [diffusers](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth) DreamBooth example.___** 112 | 113 | ### Inference 114 | ```bash 115 | python infer_dreambooth.py "ktxl狗在草地上跑" 116 | ``` 117 | -------------------------------------------------------------------------------- /dreambooth/default_config.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.5 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero_stage: 2 10 | 
reduce_scatter: false 11 | overlap_comm: true 12 | distributed_type: DEEPSPEED 13 | downcast_bf16: 'no' 14 | dynamo_backend: 'NO' 15 | fsdp_config: {} 16 | machine_rank: 0 17 | main_process_ip: 10.82.42.75 18 | main_process_port: 22280 19 | main_training_function: main 20 | megatron_lm_config: {} 21 | mixed_precision: fp16 22 | num_machines: 1 23 | num_processes: 8 24 | rdzv_backend: static 25 | same_network: true 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /dreambooth/infer_dreambooth.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline 3 | from kolors.models.modeling_chatglm import ChatGLMModel 4 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 5 | from diffusers import UNet2DConditionModel, AutoencoderKL 6 | from diffusers import EulerDiscreteScheduler 7 | 8 | from peft import ( 9 | LoraConfig, 10 | PeftModel, 11 | ) 12 | 13 | def infer(prompt): 14 | ckpt_dir = "/path/base_model_path" 15 | lora_ckpt = 'trained_models/ktxl_dog_text/checkpoint-1000/' 16 | load_text_encoder = True 17 | 18 | text_encoder = ChatGLMModel.from_pretrained( 19 | f'{ckpt_dir}/text_encoder', 20 | torch_dtype=torch.float16).half() 21 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 22 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 23 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 24 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 25 | pipe = StableDiffusionXLPipeline( 26 | vae=vae, 27 | text_encoder=text_encoder, 28 | tokenizer=tokenizer, 29 | unet=unet, 30 | scheduler=scheduler, 31 | force_zeros_for_empty_prompt=False) 32 | pipe = pipe.to("cuda") 33 | 34 | pipe.load_lora_weights(lora_ckpt, adapter_name="ktxl-lora") 35 | pipe.set_adapters(["ktxl-lora"], [0.8]) 36 | 37 | if load_text_encoder: 38 | pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, lora_ckpt) 39 | 40 | random_seed = 0 41 | generator = torch.Generator(pipe.device).manual_seed(random_seed) 42 | 43 | neg_p = '' 44 | out = pipe(prompt, generator=generator, negative_prompt=neg_p, num_inference_steps=25, width=1024, height=1024, num_images_per_prompt=1, guidance_scale=5).images[0] 45 | out.save("ktxl_test_image.png") 46 | 47 | if __name__ == '__main__': 48 | import fire 49 | fire.Fire(infer) -------------------------------------------------------------------------------- /dreambooth/ktxl_test_image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/dreambooth/ktxl_test_image.png -------------------------------------------------------------------------------- /dreambooth/train.sh: -------------------------------------------------------------------------------- 1 | MODEL_NAME="/path/base_model_path" 2 | INSTANCE_DIR='dog' 3 | CLASS_DIR="reg_dog" 4 | OUTPUT_DIR="trained_models/ktxl_dog_text" 5 | cfg_file=./default_config.yaml 6 | 7 | accelerate launch --config_file ${cfg_file} train_dreambooth_lora.py \ 8 | --pretrained_model_name_or_path=$MODEL_NAME \ 9 | --instance_data_dir=$INSTANCE_DIR \ 10 | --output_dir=$OUTPUT_DIR \ 11 | --class_data_dir=$CLASS_DIR \ 12 | --instance_prompt="ktxl狗" \ 13 | --class_prompt="狗" \ 14 | --train_batch_size=1 \ 15 | 
--gradient_accumulation_steps=1 \ 16 | --learning_rate=2e-5 \ 17 | --text_encoder_lr=5e-5 \ 18 | --lr_scheduler="polynomial" \ 19 | --lr_warmup_steps=100 \ 20 | --rank=4 \ 21 | --resolution=1024 \ 22 | --max_train_steps=1000 \ 23 | --checkpointing_steps=200 \ 24 | --num_class_images=100 \ 25 | --center_crop \ 26 | --mixed_precision='fp16' \ 27 | --seed=19980818 \ 28 | --img_repeat_nums=1 \ 29 | --sample_batch_size=2 \ 30 | --gradient_checkpointing \ 31 | --adam_weight_decay=1e-02 \ 32 | --with_prior_preservation \ 33 | --prior_loss_weight=0.7 \ 34 | --train_text_encoder \ 35 | -------------------------------------------------------------------------------- /imgs/Kolors_paper.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/Kolors_paper.pdf -------------------------------------------------------------------------------- /imgs/cn_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/cn_all.png -------------------------------------------------------------------------------- /imgs/fz_all.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/fz_all.png -------------------------------------------------------------------------------- /imgs/head_final3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/head_final3.png -------------------------------------------------------------------------------- /imgs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/logo.png -------------------------------------------------------------------------------- /imgs/prompt_vis.txt: -------------------------------------------------------------------------------- 1 | 1、视觉质量: 2 | - 一对年轻的中国情侣,皮肤白皙,穿着时尚的运动装,背景是现代的北京城市天际线。面部细节,清晰的毛孔,使用最新款的相机拍摄,特写镜头,超高画质,8K,视觉盛宴 3 | 2、中国元素 4 | - 万里长城,蜿蜒 5 | - 一张北京国家体育场(俗称“鸟巢”)的高度细节图片。图片应展示体育场复杂的钢结构,强调其独特的建筑设计。场景设定在白天,天空晴朗,突显出体育场的宏伟规模和现代感。包括周围的奥林匹克公园和一些游客,以增加场景的背景和生气。 6 | - 上海外滩 7 | 3、复杂语义理解 8 | - 满月下的街道,熙熙攘攘的行人正在享受繁华夜生活。街角摊位上,一位有着火红头发、穿着标志性天鹅绒斗篷的年轻女子,正在和脾气暴躁的老小贩讨价还价。这个脾气暴躁的小贩身材高大、老道,身着一套整洁西装,留着小胡子,用他那部蒸汽朋克式的电话兴致勃勃地交谈 9 | - 画面有四只神兽:朱雀、玄武、青龙、白虎。朱雀位于画面上方,羽毛鲜红如火,尾羽如凤凰般绚丽,翅膀展开时似燃烧的火焰。玄武居于下方,是龟蛇相缠的形象,巨龟背上盘绕着一条黑色巨蛇,龟甲上有古老的符文,蛇眼冰冷锐利。青龙位于右方,长身盘旋在天际,龙鳞碧绿如翡翠,龙须飘逸,龙角如鹿,口吐云雾。白虎居于左方,体态威猛,白色的皮毛上有黑色斑纹,双眼炯炯有神,尖牙利爪,周围是苍茫的群山和草原。 10 | - 一张高对比度的照片,熊猫骑在马上,戴着巫师帽,正在看书,马站在土墙旁的街道上,有绿草从街道的裂缝中长出来。 11 | 4、文字绘制 12 | - 一张瓢虫的照片,微距,变焦,高质量,电影,瓢虫拿着一个木牌,上面写着“我爱世界” 的文字 13 | - 一只小橘猫在弹钢琴,钢琴是黑色的牌子是“KOLORS”,猫的身影清晰的映照在钢琴上 14 | - 街边的路牌,上面写着“天道酬勤” -------------------------------------------------------------------------------- /imgs/wechat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/wechat.png -------------------------------------------------------------------------------- /imgs/wz_all.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/wz_all.png -------------------------------------------------------------------------------- /imgs/zl8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/zl8.png -------------------------------------------------------------------------------- /imgs/可图KOLORS模型商业授权申请书-英文版本.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/可图KOLORS模型商业授权申请书-英文版本.docx -------------------------------------------------------------------------------- /imgs/可图KOLORS模型商业授权申请书.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/imgs/可图KOLORS模型商业授权申请书.docx -------------------------------------------------------------------------------- /inpainting/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## 📖 Introduction 4 | 5 | We provide Kolors-Inpainting inference code and weights which were initialized with [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-Inpainting results are as follows: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 |
Inpainting Results
Original Image Masked Image Prompt Result Image
穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深

Wearing Sailor Moon's outfit, a sailor-style outfit consisting of a white tight top with a large red bow on the chest. The collar of the outfit is blue and has white stripes. She also wears a blue pleated skirt, Ultra HD, Octane Rendering, Premium Textures, 32k, High Resolution, Best Quality, Super Detail, Depth of Field
穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深

Wearing Iron Man's clothes, high-tech armor, the main colors are red and gold, and there are some silver decorations. There is a light-up round reactor device on the chest, full of futuristic technology. Ultra-clear, high-quality, ultra-realistic, high-resolution, best quality, super details, depth of field
37 | 38 | 39 | 40 |
41 | 42 | **Model details** 43 | 44 | - For inpainting, the UNet has 5 additional input channels (4 for the encoded masked image and 1 for the mask itself). The weights for the encoded masked-image channels were initialized from the non-inpainting checkpoint, while the weights for the mask channel were zero-initialized. 45 | - To improve the robustness of the inpainting model, we adopt a more diverse strategy for generating masks, including random masks, subject segmentation masks, rectangular masks, and masks based on dilation operations. 46 | 47 | 48 |
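The first bullet above means the inpainting UNet sees 9 latent channels in total: the 4 regular noisy-latent channels plus the 5 additional ones. The following minimal sketch uses illustrative tensor shapes only and the channel ordering common to diffusers-style inpainting pipelines; the real assembly happens inside `kolors/pipelines/pipeline_stable_diffusion_xl_chatglm_256_inpainting.py`:

```python
import torch

# Illustrative shapes: batch 1, a 1024x768 image gives a 128x96 latent grid (VAE downscale x8).
noisy_latents = torch.randn(1, 4, 128, 96)          # regular denoising input
mask = torch.ones(1, 1, 128, 96)                    # 1 = region to repaint, 0 = keep
masked_image_latents = torch.randn(1, 4, 128, 96)   # VAE-encoded masked input image

# 4 + 1 + 4 = 9 channels are fed to the inpainting UNet at every denoising step.
unet_input = torch.cat([noisy_latents, mask, masked_image_latents], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 128, 96])
```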
49 | 50 | 51 | ## 📊 Evaluation 52 | For evaluation, we created a test set comprising 200 masked images and text prompts. We invited several image experts to provide unbiased ratings for the generated results of different models. The experts assessed the generated images based on four criteria: visual appeal, text faithfulness, inpainting artifacts, and overall satisfaction. Inpainting artifacts measure the perceptual boundaries in the inpainting results, while the other criteria adhere to the evaluation standards of the BaseModel. The specific results are summarized in the table below, where Kolors-Inpainting achieved the highest overall satisfaction score. 53 | 54 | | Model | Average Overall Satisfaction | Average Inpainting Artifacts | Average Visual Appeal | Average Text Faithfulness | 55 | | :-----------------: | :-----------: | :-----------: | :-----------: | :-----------: | 56 | | SDXL-Inpainting | 2.573 | 1.205 | 3.000 | 4.299 | 57 | | **Kolors-Inpainting** | **3.493** | **0.204** | **3.855** | **4.346** | 58 |
59 | 60 | *The higher the scores for Average Overall Satisfaction, Average Visual Appeal, and Average Text Faithfulness, the better. Conversely, the lower the score for Average Inpainting Artifacts, the better.* 61 | 62 |
63 | The comparison results of SDXL-Inpainting and Kolors-Inpainting are as follows: 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 |
Comparison Results
Original Image Masked Image Prompt SDXL-Inpainting Result Kolors-Inpainting Result
穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深

Wearing Sailor Moon's outfit, a sailor-style outfit consisting of a white tight top with a large red bow on the chest. The collar of the outfit is blue and has white stripes. She also wears a blue pleated skirt, Ultra HD, Octane Rendering, Premium Textures, 32k, High Resolution, Best Quality, Super Detail, Depth of Field
穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深

Wearing Iron Man's clothes, high-tech armor, the main colors are red and gold, and there are some silver decorations. There is a light-up round reactor device on the chest, full of futuristic technology. Ultra-clear, high-quality, ultra-realistic, high-resolution, best quality, super details, depth of field
穿着白雪公主的衣服,经典的蓝色裙子,并且在袖口处饰有红色细节,超高清,辛烷渲染,高级质感,32k

Dressed in Snow White's classic blue skirt with red details at the cuffs, Ultra HD, Octane Rendering, Premium Textures, 32k
一只带着红色帽子的小猫咪,圆脸,大眼,极度可爱,高饱和度,立体,柔和的光线

A kitten wearing a red hat, round face, big eyes, extremely cute, high saturation, three-dimensional, soft light
这是一幅令人垂涎欲滴的火锅画面,各种美味的食材在翻滚的锅中煮着,散发出的热气和香气令人陶醉。火红的辣椒和鲜艳的辣椒油熠熠生辉,具有诱人的招人入胜之色彩。锅内肉质细腻的薄切牛肉、爽口的豆腐皮、鲍汁浓郁的金针菇、爽脆的蔬菜,融合在一起,营造出五彩斑斓的视觉呈现

This is a mouth-watering hot pot scene, with all kinds of delicious ingredients cooking in the boiling pot, emitting intoxicating heat and aroma. The fiery red peppers and bright chili oil are shining, with attractive and fascinating colors. The delicate thin-cut beef, refreshing tofu skin, enoki mushrooms with rich abalone sauce, and crisp vegetables in the pot are combined together to create a colorful visual presentation
121 | 122 | *Kolors-Inpainting employs Chinese prompts, while SDXL-Inpainting uses English prompts.* 123 | 124 | 125 | 126 | ## 🛠️ Usage 127 | 128 | ### Requirements 129 | 130 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors). 131 | 132 |
133 | 134 | 1. Repository Cloning and Dependency Installation 135 | 136 | ```bash 137 | apt-get install git-lfs 138 | git clone https://github.com/Kwai-Kolors/Kolors 139 | cd Kolors 140 | conda create --name kolors python=3.8 141 | conda activate kolors 142 | pip install -r requirements.txt 143 | python3 setup.py install 144 | ``` 145 | 146 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-Inpainting): 147 | ```bash 148 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-Inpainting --local-dir weights/Kolors-Inpainting 149 | ``` 150 | 151 | 3. Inference: 152 | ```bash 153 | python3 inpainting/sample_inpainting.py ./inpainting/asset/3.png ./inpainting/asset/3_mask.png 穿着美少女战士的衣服,一件类似于水手服风格的衣服,包括一个白色紧身上衣,前胸搭配一个大大的红色蝴蝶结。衣服的领子部分呈蓝色,并且有白色条纹。她还穿着一条蓝色百褶裙,超高清,辛烷渲染,高级质感,32k,高分辨率,最好的质量,超级细节,景深 154 | 155 | python3 inpainting/sample_inpainting.py ./inpainting/asset/4.png ./inpainting/asset/4_mask.png 穿着钢铁侠的衣服,高科技盔甲,主要颜色为红色和金色,并且有一些银色装饰。胸前有一个亮起的圆形反应堆装置,充满了未来科技感。超清晰,高质量,超逼真,高分辨率,最好的质量,超级细节,景深 156 | 157 | # The image will be saved to "scripts/outputs/" 158 | ``` 159 | 160 |
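The inference script accepts any image/mask pair, so you are not limited to the masks shipped under `inpainting/asset`. Below is a minimal sketch for building a rectangular mask with PIL; the output filename is hypothetical, and it assumes the usual diffusers convention that white pixels mark the region to repaint:

```python
from PIL import Image, ImageDraw

# Build a simple rectangular mask for one of the bundled test images.
image = Image.open("./inpainting/asset/3.png").convert("RGB")
w, h = image.size

mask = Image.new("RGB", (w, h), "black")                           # black = keep original pixels
draw = ImageDraw.Draw(mask)
draw.rectangle([w // 4, h // 2, 3 * w // 4, h - 1], fill="white")  # white = area to repaint

mask.save("./inpainting/asset/custom_mask.png")
```

The saved mask can then be passed as the second argument to `inpainting/sample_inpainting.py`, exactly as in step 3 above.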
161 | -------------------------------------------------------------------------------- /inpainting/asset/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1.png -------------------------------------------------------------------------------- /inpainting/asset/1_kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_kolors.png -------------------------------------------------------------------------------- /inpainting/asset/1_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_masked.png -------------------------------------------------------------------------------- /inpainting/asset/1_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/1_sdxl.png -------------------------------------------------------------------------------- /inpainting/asset/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2.png -------------------------------------------------------------------------------- /inpainting/asset/2_kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_kolors.png -------------------------------------------------------------------------------- /inpainting/asset/2_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_masked.png -------------------------------------------------------------------------------- /inpainting/asset/2_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/2_sdxl.png -------------------------------------------------------------------------------- /inpainting/asset/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3.png -------------------------------------------------------------------------------- /inpainting/asset/3_kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_kolors.png -------------------------------------------------------------------------------- /inpainting/asset/3_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_mask.png -------------------------------------------------------------------------------- 
/inpainting/asset/3_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_masked.png -------------------------------------------------------------------------------- /inpainting/asset/3_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/3_sdxl.png -------------------------------------------------------------------------------- /inpainting/asset/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4.png -------------------------------------------------------------------------------- /inpainting/asset/4_kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_kolors.png -------------------------------------------------------------------------------- /inpainting/asset/4_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_mask.png -------------------------------------------------------------------------------- /inpainting/asset/4_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_masked.png -------------------------------------------------------------------------------- /inpainting/asset/4_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/4_sdxl.png -------------------------------------------------------------------------------- /inpainting/asset/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5.png -------------------------------------------------------------------------------- /inpainting/asset/5_kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_kolors.png -------------------------------------------------------------------------------- /inpainting/asset/5_masked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_masked.png -------------------------------------------------------------------------------- /inpainting/asset/5_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/inpainting/asset/5_sdxl.png -------------------------------------------------------------------------------- /inpainting/sample_inpainting.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os, sys 3 | from PIL import Image 4 | 5 | from diffusers import ( 6 | AutoencoderKL, 7 | UNet2DConditionModel, 8 | EulerDiscreteScheduler 9 | ) 10 | 11 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | sys.path.append(root_dir) 13 | 14 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_inpainting import StableDiffusionXLInpaintPipeline 15 | from kolors.models.modeling_chatglm import ChatGLMModel 16 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 17 | 18 | def infer(image_path, mask_path, prompt): 19 | 20 | ckpt_dir = f'{root_dir}/weights/Kolors-Inpainting' 21 | text_encoder = ChatGLMModel.from_pretrained( 22 | f'{ckpt_dir}/text_encoder', 23 | torch_dtype=torch.float16).half() 24 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 25 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 26 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 27 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 28 | 29 | pipe = StableDiffusionXLInpaintPipeline( 30 | vae=vae, 31 | text_encoder=text_encoder, 32 | tokenizer=tokenizer, 33 | unet=unet, 34 | scheduler=scheduler 35 | ) 36 | 37 | pipe.to("cuda") 38 | pipe.enable_attention_slicing() 39 | 40 | generator = torch.Generator(device="cpu").manual_seed(603) 41 | basename = image_path.rsplit('/', 1)[-1].rsplit('.', 1)[0] 42 | image = Image.open(image_path).convert('RGB') 43 | mask_image = Image.open(mask_path).convert('RGB') 44 | 45 | result = pipe( 46 | prompt = prompt, 47 | image = image, 48 | mask_image = mask_image, 49 | height=1024, 50 | width=768, 51 | guidance_scale = 6.0, 52 | generator= generator, 53 | num_inference_steps= 25, 54 | negative_prompt = '残缺的手指,畸形的手指,畸形的手,残肢,模糊,低质量', 55 | num_images_per_prompt = 1, 56 | strength = 0.999 57 | ).images[0] 58 | result.save(f'{root_dir}/scripts/outputs/sample_inpainting_{basename}.jpg') 59 | 60 | if __name__ == '__main__': 61 | import fire 62 | fire.Fire(infer) 63 | -------------------------------------------------------------------------------- /ipadapter/README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## 📖 Introduction 4 | 5 | We provide IP-Adapter-Plus weights and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-IP-Adapter-Plus results are as follows: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 |
Example Result
Reference Image Prompt Result Image
穿着黑色T恤衫,上面中文绿色大字写着“可图”。

Wearing a black T-shirt with the Chinese characters "Ketu" written in large green letters on it.
一只可爱的小狗在奔跑。

A cute dog is running.
34 | 35 | 36 | 37 |
38 | 39 | **Our improvements** 40 | 41 | - A stronger image feature extractor. We employ the OpenAI-CLIP-336 model as the image encoder, which allows us to preserve more details in the reference images. 42 | - More diverse and high-quality training data. We construct a large-scale and high-quality training dataset inspired by the data strategies of other works. We believe that paired training data can effectively improve performance. 43 | 44 | 
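To make the first point concrete, the small sketch below (it only assumes `transformers` from the requirements and one of the bundled reference images) preprocesses a reference image with the same 336-pixel processor settings used in `ipadapter/sample_ipadapter_plus.py`, so more of the reference detail reaches the image encoder than with a typical 224-pixel setup:

```python
from PIL import Image
from transformers import CLIPImageProcessor

# Same settings as the sample script: resize and center-crop the reference image to 336x336.
processor = CLIPImageProcessor(size=336, crop_size=336)

reference = Image.open("./ipadapter/asset/test_ip.jpg")
pixel_values = processor(images=reference, return_tensors="pt").pixel_values
print(pixel_values.shape)  # torch.Size([1, 3, 336, 336])
```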
46 | 47 | 48 | ## 📊 Evaluation 49 | For evaluation, we create a test set consisting of over 200 reference images and text prompts. We invite several image experts to provide fair ratings for the generated results of different models. The experts rate the generated images based on four criteria: visual appeal, text faithfulness, image faithfulness, and overall satisfaction. Image faithfulness measures the semantic preservation ability of IP-Adapter on reference images, while the other criteria follow the evaluation standards of BaseModel. The specific results are summarized in the table below, where Kolors-IP-Adapter-Plus achieves the highest overall satisfaction score. 50 | 51 | | Model | Average Overall Satisfaction | Average Image Faithfulness | Average Visual Appeal | Average Text Faithfulness | 52 | | :--------------: | :--------: | :--------: | :--------: | :--------: | 53 | | SDXL-IP-Adapter-Plus | 2.29 | 2.64 | 3.22 | 4.02 | 54 | | Midjourney-v6-CW | 2.79 | 3.0 | 3.92 | 4.35 | 55 | | **Kolors-IP-Adapter-Plus** | **3.04** | **3.25** | **4.45** | **4.30** | 56 | 57 | *The ip_scale parameter is set to 0.3 in SDXL-IP-Adapter-Plus, while Midjourney-v6-CW utilizes the default cw scale.* 58 | 59 | ------ 60 | 61 |
62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 |
Compare Result
Reference Image Prompt Kolors-IP-Adapter-Plus Result SDXL-IP-Adapter-Plus Result Midjourney-v6-CW Result
一个看向远山的少女形象,雪山背景,采用日本浮世绘风格,混合蓝色和红色柔和调色板,高分辨率

Image of a girl looking towards distant mountains, snowy mountains background, in Japanese ukiyo-e style, mixed blue and red pastel color palette, high resolution.
一个漂亮的美女,看向远方

A beautiful lady looking into the distance.
可爱的猫咪,在花丛中,看镜头

Cute cat among flowers, looking at the camera.
站在丛林前,戴着太阳帽,高画质,高细节,高清,疯狂的细节,超高清

Standing in front of the jungle, wearing a sun hat, high quality, high detail, high definition, crazy details, ultra high definition.
做个头像,新海诚动漫风格,丰富的色彩,唯美风景,清新明亮,斑驳的光影,最好的质量,超细节,8K画质

Create an avatar, Shinkai Makoto anime style, rich colors, beautiful scenery, fresh and bright, mottled light and shadow, best quality, ultra-detailed, 8K quality.
121 | 122 | *Kolors-IP-Adapter-Plus employs Chinese prompts, while other methods use English prompts.* 123 | 124 | 125 | 126 | ## 🛠️ Usage 127 | 128 | ### Requirements 129 | 130 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors). 131 | 132 | 
133 | 134 | 1. Repository Cloning and Dependency Installation 135 | 136 | ```bash 137 | apt-get install git-lfs 138 | git clone https://github.com/Kwai-Kolors/Kolors 139 | cd Kolors 140 | conda create --name kolors python=3.8 141 | conda activate kolors 142 | pip install -r requirements.txt 143 | python3 setup.py install 144 | ``` 145 | 146 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus): 147 | ```bash 148 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-IP-Adapter-Plus --local-dir weights/Kolors-IP-Adapter-Plus 149 | ``` 150 | or 151 | ```bash 152 | git lfs clone https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-Plus weights/Kolors-IP-Adapter-Plus 153 | ``` 154 | 155 | 3. Inference: 156 | ```bash 157 | python ipadapter/sample_ipadapter_plus.py ./ipadapter/asset/test_ip.jpg "穿着黑色T恤衫,上面中文绿色大字写着“可图”" 158 | 159 | python ipadapter/sample_ipadapter_plus.py ./ipadapter/asset/test_ip2.png "一只可爱的小狗在奔跑" 160 | 161 | # The image will be saved to "scripts/outputs/" 162 | ``` 163 | 164 |
165 | 166 | 167 | **Note** 168 | 169 | The IP-Adapter-FaceID model based on Kolors will also be released soon! 170 | 171 |
172 | 173 | ### Acknowledgments 174 | - Thanks to [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) for providing the codebase. 175 |
176 | 177 | -------------------------------------------------------------------------------- /ipadapter/asset/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1.png -------------------------------------------------------------------------------- /ipadapter/asset/1_kolors_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_kolors_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/1_mj_cw_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_mj_cw_result.png -------------------------------------------------------------------------------- /ipadapter/asset/1_sdxl_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/1_sdxl_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2.png -------------------------------------------------------------------------------- /ipadapter/asset/2_kolors_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_kolors_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/2_mj_cw_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_mj_cw_result.png -------------------------------------------------------------------------------- /ipadapter/asset/2_sdxl_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/2_sdxl_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3.png -------------------------------------------------------------------------------- /ipadapter/asset/3_kolors_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_kolors_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/3_mj_cw_result.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_mj_cw_result.png -------------------------------------------------------------------------------- /ipadapter/asset/3_sdxl_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/3_sdxl_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4.png -------------------------------------------------------------------------------- /ipadapter/asset/4_kolors_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_kolors_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/4_mj_cw_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_mj_cw_result.png -------------------------------------------------------------------------------- /ipadapter/asset/4_sdxl_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/4_sdxl_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5.png -------------------------------------------------------------------------------- /ipadapter/asset/5_kolors_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_kolors_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/5_mj_cw_result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_mj_cw_result.png -------------------------------------------------------------------------------- /ipadapter/asset/5_sdxl_ip_result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/5_sdxl_ip_result.jpg -------------------------------------------------------------------------------- /ipadapter/asset/test_ip.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/test_ip.jpg -------------------------------------------------------------------------------- /ipadapter/asset/test_ip2.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter/asset/test_ip2.png -------------------------------------------------------------------------------- /ipadapter/sample_ipadapter_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor 3 | from diffusers.utils import load_image 4 | import os,sys 5 | 6 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline 7 | from kolors.models.modeling_chatglm import ChatGLMModel 8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 9 | 10 | # from diffusers import UNet2DConditionModel, AutoencoderKL 11 | from diffusers import AutoencoderKL 12 | from kolors.models.unet_2d_condition import UNet2DConditionModel 13 | 14 | from diffusers import EulerDiscreteScheduler 15 | from PIL import Image 16 | 17 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 18 | 19 | 20 | def infer( ip_img_path, prompt ): 21 | 22 | ckpt_dir = f'{root_dir}/weights/Kolors' 23 | text_encoder = ChatGLMModel.from_pretrained( 24 | f'{ckpt_dir}/text_encoder', 25 | torch_dtype=torch.float16).half() 26 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 27 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 28 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 29 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 30 | 31 | image_encoder = CLIPVisionModelWithProjection.from_pretrained( f'{root_dir}/weights/Kolors-IP-Adapter-Plus/image_encoder', ignore_mismatched_sizes=True).to(dtype=torch.float16) 32 | ip_img_size = 336 33 | clip_image_processor = CLIPImageProcessor( size=ip_img_size, crop_size=ip_img_size ) 34 | 35 | pipe = StableDiffusionXLPipeline( 36 | vae=vae, 37 | text_encoder=text_encoder, 38 | tokenizer=tokenizer, 39 | unet=unet, 40 | scheduler=scheduler, 41 | image_encoder=image_encoder, 42 | feature_extractor=clip_image_processor, 43 | force_zeros_for_empty_prompt=False 44 | ) 45 | 46 | pipe = pipe.to("cuda") 47 | pipe.enable_model_cpu_offload() 48 | 49 | if hasattr(pipe.unet, 'encoder_hid_proj'): 50 | pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj 51 | 52 | pipe.load_ip_adapter( f'{root_dir}/weights/Kolors-IP-Adapter-Plus' , subfolder="", weight_name=["ip_adapter_plus_general.bin"]) 53 | 54 | basename = ip_img_path.rsplit('/',1)[-1].rsplit('.',1)[0] 55 | ip_adapter_img = Image.open( ip_img_path ) 56 | generator = torch.Generator(device="cpu").manual_seed(66) 57 | 58 | for scale in [0.5]: 59 | pipe.set_ip_adapter_scale([ scale ]) 60 | # print(prompt) 61 | image = pipe( 62 | prompt= prompt , 63 | ip_adapter_image=[ ip_adapter_img ], 64 | negative_prompt="", 65 | height=1024, 66 | width=1024, 67 | num_inference_steps= 50, 68 | guidance_scale=5.0, 69 | num_images_per_prompt=1, 70 | generator=generator, 71 | ).images[0] 72 | image.save(f'{root_dir}/scripts/outputs/sample_ip_{basename}.jpg') 73 | 74 | 75 | if __name__ == '__main__': 76 | import fire 77 | fire.Fire(infer) 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | -------------------------------------------------------------------------------- /ipadapter_FaceID/README.md: -------------------------------------------------------------------------------- 1 | 
2 | 3 | ## 📖 Introduction 4 | 5 | We provide Kolors-IP-Adapter-FaceID-Plus module weights and inference code based on [Kolors-Basemodel](https://huggingface.co/Kwai-Kolors/Kolors). Examples of Kolors-IP-Adapter-FaceID-Plus results are as follows: 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 |
Example Results
Reference Image Prompt Result Image
穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围

Wearing an evening dress, in a starry night banquet scene, candlelight flickering, the whole scene exudes a romantic and luxurious atmosphere.
西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古

Western cowboy, cowboy hat, Red Dead Redemption, background is a western town, cactus, sunset glow, warm tones, shot with XT4 film, grain, vignette, Kodak film, retro.
31 | 32 | - Our Kolors-IP-Adapter-FaceID-Plus module is trained on a large-scale and high-quality face dataset. We use the face ID embeddings generated by [insightface](https://github.com/deepinsight/insightface) and the CLIP features of face area to keep the face ID and structure information. 33 | 34 | ## 📊 Evaluation 35 | For evaluation, we constructed a test set consisting of over 200 reference images and text prompts. We invited several image experts to provide fair ratings for the generated results of different models. The experts assessed the generated images based on five criteria: visual appeal, text faithfulness, face similarity, facial aesthetics and overall satisfaction. Visual appeal and text faithfulness are used to measure the text-to-image generation capability, adhering to the evaluation standards of BaseModel. Meanwhile, face similarity and facial aesthetics are used to evaluate the performance of the proposed Kolors-IP-Adapter-FaceID-Plus. The results are summarized in the table below, where Kolors-IP-Adapter-FaceID-Plus outperforms SDXL-IP-Adapter-FaceID-Plus across all metrics. 36 | 37 | 38 | | Model | Average Text Faithfulness | Average Visual Appeal | Average Face Similarity | Average Facial Aesthetics | Average Overall Satisfaction | 39 | | :--------------: | :--------: | :--------: | :--------: | :--------: | :--------: | 40 | | SDXL-IP-Adapter-FaceID-Plus | 4.014 | 3.455 | 3.05 | 2.584 | 2.448 | 41 | | **Kolors-IP-Adapter-FaceID-Plus** | **4.235** | **4.374** | **4.415** | **3.887** | **3.561** | 42 | ------ 43 | 44 |
45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 |
Comparison Results
Reference Image Prompt SDXL-IP-Adapter-FaceID-Plus Kolors-IP-Adapter-FaceID-Plus
古典油画风格,油彩厚重, 古典美感,历史气息

Classical oil painting style, thick oil paint, classical aesthetic, historical atmosphere.
夜晚,浪漫的海边,落日余晖洒在海面上,晚霞映照着整个海滩,头戴花环,花短袖,飘逸的头发,背景是美丽的海滩,可爱年轻的半身照,优雅梦幻,细节繁复,超逼真,高分辨率,柔和的背景,低对比度

Night, romantic seaside, sunset glow on the sea, evening glow reflecting on the whole beach, wearing a flower crown, short floral sleeves, flowing hair, background is a beautiful beach, cute young half-body portrait, elegant and dreamy, intricate details, ultra-realistic, high resolution, soft background, low contrast.
F1赛车手, 法拉利,戴着着红黑白相间的赛车手头盔,帅气的赛车手,飞舞的彩带,背景赛车车库和天花板泛光,璀璨闪光,穿红白黑相间赛车服,色调统一且明艳,面部白皙,面部特写,正视图

F1 racer, Ferrari, wearing a red, black, and white racing helmet, handsome racer, flying ribbons, background race car garage and ceiling lights, dazzling flashes, wearing red, white, and black racing suit, unified and bright color tone, fair face, facial close-up, front view.
和服,日本传统服饰,在海边的黄昏,远山的背景,在远处的烟火,柔和的灯光,长焦镜头,夜间摄影风格,凉爽的色调,浪漫的气氛,火花四溅,时尚摄影,胶片滤镜

Kimono, traditional Japanese clothing, at the seaside at dusk, distant mountain background, fireworks in the distance, soft lighting, telephoto lens, night photography style, cool tones, romantic atmosphere, sparks flying, fashion photography, film filter.
88 | 89 | *Kolors-IP-Adapter-FaceID-Plus employs Chinese prompts, while SDXL-IP-Adapter-FaceID-Plus uses English prompts.* 90 | 91 | ## 🛠️ Usage 92 | 93 | ### Requirements 94 | 95 | The dependencies and installation are basically the same as the [Kolors-BaseModel](https://huggingface.co/Kwai-Kolors/Kolors). 96 | 97 | 
98 | 99 | 1. Repository Cloning and Dependency Installation 100 | 101 | ```bash 102 | apt-get install git-lfs 103 | git clone https://github.com/Kwai-Kolors/Kolors 104 | cd Kolors 105 | conda create --name kolors python=3.8 106 | conda activate kolors 107 | pip install -r requirements.txt 108 | pip install insightface onnxruntime-gpu 109 | python3 setup.py install 110 | ``` 111 | 112 | 2. Weights download [link](https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus): 113 | ```bash 114 | huggingface-cli download --resume-download Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus --local-dir weights/Kolors-IP-Adapter-FaceID-Plus 115 | ``` 116 | or 117 | ```bash 118 | git lfs clone https://huggingface.co/Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus weights/Kolors-IP-Adapter-FaceID-Plus 119 | ``` 120 | 121 | 3. Inference: 122 | ```bash 123 | python ipadapter_FaceID/sample_ipadapter_faceid_plus.py ./ipadapter_FaceID/assets/image1.png "穿着晚礼服,在星光下的晚宴场景中,烛光闪闪,整个场景洋溢着浪漫而奢华的氛围" 124 | 125 | python ipadapter_FaceID/sample_ipadapter_faceid_plus.py ./ipadapter_FaceID/assets/image2.png "西部牛仔,牛仔帽,荒野大镖客,背景是西部小镇,仙人掌,,日落余晖, 暖色调, 使用XT4胶片拍摄, 噪点, 晕影, 柯达胶卷,复古" 126 | 127 | # The image will be saved to "scripts/outputs/" 128 | ``` 129 | 130 | ### Acknowledgments 131 | - Thanks to [insightface](https://github.com/deepinsight/insightface) for the face representations. 132 | - Thanks to [IP-Adapter](https://github.com/tencent-ailab/IP-Adapter) for the codebase. 133 |
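As a supplement to the inference step above, the sketch below is a quick pre-flight check that insightface can find a face in your reference image before you launch the sampler. It mirrors the `FaceInfoGenerator` class in `ipadapter_FaceID/sample_ipadapter_faceid_plus.py` and assumes the `antelopev2` detector weights are available under `./`, as in that script:

```python
import cv2
import numpy as np
from PIL import Image
from insightface.app import FaceAnalysis

# Same detector setup as sample_ipadapter_faceid_plus.py.
app = FaceAnalysis(name="antelopev2", root="./", providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
app.prepare(ctx_id=0, det_size=(640, 640))

img = np.array(Image.open("./ipadapter_FaceID/assets/image1.png").convert("RGB"))
faces = app.get(cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
print(f"detected {len(faces)} face(s)")  # the sampler keeps only the largest detected face
```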
134 | 135 | -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/image1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image1.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/image1_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image1_res.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/image2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image2.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/image2_res.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/image2_res.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img1_Kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_Kolors.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img1_SDXL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_SDXL.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img1_org.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img1_org.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img2_Kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_Kolors.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img2_SDXL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_SDXL.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img2_org.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img2_org.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img3_Kolors.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_Kolors.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img3_SDXL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_SDXL.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img3_org.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img3_org.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img4_Kolors.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_Kolors.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img4_SDXL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_SDXL.png -------------------------------------------------------------------------------- /ipadapter_FaceID/assets/test_img4_org.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/ipadapter_FaceID/assets/test_img4_org.png -------------------------------------------------------------------------------- /ipadapter_FaceID/sample_ipadapter_faceid_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import CLIPVisionModelWithProjection,CLIPImageProcessor 3 | from diffusers.utils import load_image 4 | import os,sys 5 | 6 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter_FaceID import StableDiffusionXLPipeline 7 | from kolors.models.modeling_chatglm import ChatGLMModel 8 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 9 | 10 | from diffusers import AutoencoderKL 11 | from kolors.models.unet_2d_condition import UNet2DConditionModel 12 | 13 | from diffusers import EulerDiscreteScheduler 14 | from PIL import Image 15 | 16 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 17 | 18 | import cv2 19 | import numpy as np 20 | import insightface 21 | from diffusers.utils import load_image 22 | from insightface.app import FaceAnalysis 23 | from insightface.data import get_image as ins_get_image 24 | 25 | class FaceInfoGenerator(): 26 | def __init__(self, root_dir = "./"): 27 | self.app = FaceAnalysis(name = 'antelopev2', root = root_dir, providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) 28 | self.app.prepare(ctx_id = 0, det_size = (640, 640)) 29 | 30 | def get_faceinfo_one_img(self, image_path): 31 | face_image = load_image(image_path) 32 | face_info = self.app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR)) 33 | 34 | if len(face_info) == 0: 35 | face_info = 
None 36 | else: 37 | face_info = sorted(face_info, key=lambda x:(x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1] # only use the maximum face 38 | return face_info 39 | 40 | def face_bbox_to_square(bbox): 41 | ## l, t, r, b to square l, t, r, b 42 | l,t,r,b = bbox 43 | cent_x = (l + r) / 2 44 | cent_y = (t + b) / 2 45 | w, h = r - l, b - t 46 | r = max(w, h) / 2 47 | 48 | l0 = cent_x - r 49 | r0 = cent_x + r 50 | t0 = cent_y - r 51 | b0 = cent_y + r 52 | 53 | return [l0, t0, r0, b0] 54 | 55 | def infer(test_image_path, text_prompt): 56 | ckpt_dir = f'{root_dir}/weights/Kolors' 57 | ip_model_dir = f'{root_dir}/weights/Kolors-IP-Adapter-FaceID-Plus' 58 | device = "cuda:0" 59 | 60 | #### base Kolors model 61 | text_encoder = ChatGLMModel.from_pretrained( f'{ckpt_dir}/text_encoder', torch_dtype = torch.float16).half() 62 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 63 | vae = AutoencoderKL.from_pretrained(f'{ckpt_dir}/vae', subfolder = "vae", revision = None) 64 | scheduler = EulerDiscreteScheduler.from_pretrained(f'{ckpt_dir}/scheduler') 65 | unet = UNet2DConditionModel.from_pretrained(f'{ckpt_dir}/unet', revision = None).half() 66 | 67 | #### clip image encoder for face structure 68 | clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(f'{ip_model_dir}/clip-vit-large-patch14-336', ignore_mismatched_sizes=True) 69 | clip_image_encoder.to(device) 70 | clip_image_processor = CLIPImageProcessor(size = 336, crop_size = 336) 71 | 72 | pipe = StableDiffusionXLPipeline( 73 | vae = vae, 74 | text_encoder = text_encoder, 75 | tokenizer = tokenizer, 76 | unet = unet, 77 | scheduler = scheduler, 78 | face_clip_encoder = clip_image_encoder, 79 | face_clip_processor = clip_image_processor, 80 | force_zeros_for_empty_prompt = False, 81 | ) 82 | pipe = pipe.to(device) 83 | pipe.enable_model_cpu_offload() 84 | 85 | pipe.load_ip_adapter_faceid_plus(f'{ip_model_dir}/ipa-faceid-plus.bin', device = device) 86 | 87 | scale = 0.8 88 | pipe.set_face_fidelity_scale(scale) 89 | 90 | #### prepare face embedding & bbox with insightface toolbox 91 | face_info_generator = FaceInfoGenerator(root_dir = "./") 92 | img = Image.open(test_image_path) 93 | face_info = face_info_generator.get_faceinfo_one_img(test_image_path) 94 | 95 | face_bbox_square = face_bbox_to_square(face_info["bbox"]) 96 | crop_image = img.crop(face_bbox_square) 97 | crop_image = crop_image.resize((336, 336)) 98 | crop_image = [crop_image] 99 | 100 | face_embeds = torch.from_numpy(np.array([face_info["embedding"]])) 101 | face_embeds = face_embeds.to(device, dtype = torch.float16) 102 | 103 | #### generate image 104 | generator = torch.Generator(device = device).manual_seed(66) 105 | image = pipe( 106 | prompt = text_prompt, 107 | negative_prompt = "", 108 | height = 1024, 109 | width = 1024, 110 | num_inference_steps= 25, 111 | guidance_scale = 5.0, 112 | num_images_per_prompt = 1, 113 | generator = generator, 114 | face_crop_image = crop_image, 115 | face_insightface_embeds = face_embeds, 116 | ).images[0] 117 | image.save(f'../scripts/outputs/test_res.png') 118 | 119 | if __name__ == '__main__': 120 | import fire 121 | fire.Fire(infer) -------------------------------------------------------------------------------- /kolors/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/__init__.py 
-------------------------------------------------------------------------------- /kolors/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/models/__init__.py -------------------------------------------------------------------------------- /kolors/models/configuration_chatglm.py: -------------------------------------------------------------------------------- 1 | from transformers import PretrainedConfig 2 | 3 | 4 | class ChatGLMConfig(PretrainedConfig): 5 | model_type = "chatglm" 6 | def __init__( 7 | self, 8 | num_layers=28, 9 | padded_vocab_size=65024, 10 | hidden_size=4096, 11 | ffn_hidden_size=13696, 12 | kv_channels=128, 13 | num_attention_heads=32, 14 | seq_length=2048, 15 | hidden_dropout=0.0, 16 | classifier_dropout=None, 17 | attention_dropout=0.0, 18 | layernorm_epsilon=1e-5, 19 | rmsnorm=True, 20 | apply_residual_connection_post_layernorm=False, 21 | post_layer_norm=True, 22 | add_bias_linear=False, 23 | add_qkv_bias=False, 24 | bias_dropout_fusion=True, 25 | multi_query_attention=False, 26 | multi_query_group_num=1, 27 | apply_query_key_layer_scaling=True, 28 | attention_softmax_in_fp32=True, 29 | fp32_residual_connection=False, 30 | quantization_bit=0, 31 | pre_seq_len=None, 32 | prefix_projection=False, 33 | **kwargs 34 | ): 35 | self.num_layers = num_layers 36 | self.vocab_size = padded_vocab_size 37 | self.padded_vocab_size = padded_vocab_size 38 | self.hidden_size = hidden_size 39 | self.ffn_hidden_size = ffn_hidden_size 40 | self.kv_channels = kv_channels 41 | self.num_attention_heads = num_attention_heads 42 | self.seq_length = seq_length 43 | self.hidden_dropout = hidden_dropout 44 | self.classifier_dropout = classifier_dropout 45 | self.attention_dropout = attention_dropout 46 | self.layernorm_epsilon = layernorm_epsilon 47 | self.rmsnorm = rmsnorm 48 | self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm 49 | self.post_layer_norm = post_layer_norm 50 | self.add_bias_linear = add_bias_linear 51 | self.add_qkv_bias = add_qkv_bias 52 | self.bias_dropout_fusion = bias_dropout_fusion 53 | self.multi_query_attention = multi_query_attention 54 | self.multi_query_group_num = multi_query_group_num 55 | self.apply_query_key_layer_scaling = apply_query_key_layer_scaling 56 | self.attention_softmax_in_fp32 = attention_softmax_in_fp32 57 | self.fp32_residual_connection = fp32_residual_connection 58 | self.quantization_bit = quantization_bit 59 | self.pre_seq_len = pre_seq_len 60 | self.prefix_projection = prefix_projection 61 | super().__init__(**kwargs) 62 | -------------------------------------------------------------------------------- /kolors/models/ipa_faceid_plus/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/models/ipa_faceid_plus/__init__.py -------------------------------------------------------------------------------- /kolors/models/ipa_faceid_plus/attention_processor.py: -------------------------------------------------------------------------------- 1 | # modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class AttnProcessor2_0(torch.nn.Module): 7 | r""" 8 | Processor for implementing scaled 
dot-product attention (enabled by default if you're using PyTorch 2.0). 9 | """ 10 | def __init__( 11 | self, 12 | hidden_size=None, 13 | cross_attention_dim=None, 14 | ): 15 | super().__init__() 16 | if not hasattr(F, "scaled_dot_product_attention"): 17 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 18 | 19 | def __call__( 20 | self, 21 | attn, 22 | hidden_states, 23 | encoder_hidden_states=None, 24 | attention_mask=None, 25 | temb=None, 26 | ): 27 | residual = hidden_states 28 | 29 | if attn.spatial_norm is not None: 30 | hidden_states = attn.spatial_norm(hidden_states, temb) 31 | 32 | input_ndim = hidden_states.ndim 33 | 34 | if input_ndim == 4: 35 | batch_size, channel, height, width = hidden_states.shape 36 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 37 | 38 | batch_size, sequence_length, _ = ( 39 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 40 | ) 41 | 42 | if attention_mask is not None: 43 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 44 | # scaled_dot_product_attention expects attention_mask shape to be 45 | # (batch, heads, source_length, target_length) 46 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 47 | 48 | if attn.group_norm is not None: 49 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 50 | 51 | query = attn.to_q(hidden_states) 52 | 53 | if encoder_hidden_states is None: 54 | encoder_hidden_states = hidden_states 55 | elif attn.norm_cross: 56 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 57 | 58 | key = attn.to_k(encoder_hidden_states) 59 | value = attn.to_v(encoder_hidden_states) 60 | 61 | inner_dim = key.shape[-1] 62 | head_dim = inner_dim // attn.heads 63 | 64 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 65 | 66 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 67 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 68 | 69 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 70 | # TODO: add support for attn.scale when we move to Torch 2.1 71 | hidden_states = F.scaled_dot_product_attention( 72 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 73 | ) 74 | 75 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 76 | hidden_states = hidden_states.to(query.dtype) 77 | 78 | # linear proj 79 | hidden_states = attn.to_out[0](hidden_states) 80 | # dropout 81 | hidden_states = attn.to_out[1](hidden_states) 82 | 83 | if input_ndim == 4: 84 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 85 | 86 | if attn.residual_connection: 87 | hidden_states = hidden_states + residual 88 | 89 | hidden_states = hidden_states / attn.rescale_output_factor 90 | 91 | return hidden_states 92 | 93 | class IPAttnProcessor2_0(torch.nn.Module): 94 | r""" 95 | Attention processor for IP-Adapater for PyTorch 2.0. 96 | Args: 97 | hidden_size (`int`): 98 | The hidden size of the attention layer. 99 | cross_attention_dim (`int`): 100 | The number of channels in the `encoder_hidden_states`. 101 | scale (`float`, defaults to 1.0): 102 | the weight scale of image prompt. 103 | num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): 104 | The context length of the image features. 
105 | """ 106 | 107 | def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4): 108 | super().__init__() 109 | 110 | if not hasattr(F, "scaled_dot_product_attention"): 111 | raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") 112 | 113 | self.hidden_size = hidden_size 114 | self.cross_attention_dim = cross_attention_dim 115 | self.scale = scale 116 | self.num_tokens = num_tokens 117 | 118 | self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 119 | self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) 120 | 121 | def __call__( 122 | self, 123 | attn, 124 | hidden_states, 125 | encoder_hidden_states=None, 126 | attention_mask=None, 127 | temb=None, 128 | ): 129 | residual = hidden_states 130 | 131 | if attn.spatial_norm is not None: 132 | hidden_states = attn.spatial_norm(hidden_states, temb) 133 | 134 | input_ndim = hidden_states.ndim 135 | 136 | if input_ndim == 4: 137 | batch_size, channel, height, width = hidden_states.shape 138 | hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) 139 | 140 | batch_size, sequence_length, _ = ( 141 | hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape 142 | ) 143 | 144 | if attention_mask is not None: 145 | attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) 146 | # scaled_dot_product_attention expects attention_mask shape to be 147 | # (batch, heads, source_length, target_length) 148 | attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) 149 | 150 | if attn.group_norm is not None: 151 | hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) 152 | 153 | query = attn.to_q(hidden_states) 154 | 155 | if encoder_hidden_states is None: 156 | encoder_hidden_states = hidden_states 157 | else: 158 | # get encoder_hidden_states, ip_hidden_states 159 | end_pos = encoder_hidden_states.shape[1] - self.num_tokens 160 | encoder_hidden_states, ip_hidden_states = encoder_hidden_states[:, :end_pos, :], encoder_hidden_states[:, end_pos:, :] 161 | if attn.norm_cross: 162 | encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) 163 | 164 | key = attn.to_k(encoder_hidden_states) 165 | value = attn.to_v(encoder_hidden_states) 166 | 167 | inner_dim = key.shape[-1] 168 | head_dim = inner_dim // attn.heads 169 | 170 | query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 171 | 172 | key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 173 | value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 174 | 175 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 176 | # TODO: add support for attn.scale when we move to Torch 2.1 177 | hidden_states = F.scaled_dot_product_attention( 178 | query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False 179 | ) 180 | 181 | hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 182 | hidden_states = hidden_states.to(query.dtype) 183 | 184 | # for ip-adapter 185 | ip_key = self.to_k_ip(ip_hidden_states) 186 | ip_value = self.to_v_ip(ip_hidden_states) 187 | 188 | ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 189 | ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) 190 | 191 | # the output of sdp = (batch, num_heads, seq_len, head_dim) 192 | # 
TODO: add support for attn.scale when we move to Torch 2.1 193 | ip_hidden_states = F.scaled_dot_product_attention( 194 | query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False 195 | ) 196 | 197 | ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) 198 | ip_hidden_states = ip_hidden_states.to(query.dtype) 199 | 200 | hidden_states = hidden_states + self.scale * ip_hidden_states 201 | 202 | # linear proj 203 | hidden_states = attn.to_out[0](hidden_states) 204 | # dropout 205 | hidden_states = attn.to_out[1](hidden_states) 206 | 207 | if input_ndim == 4: 208 | hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) 209 | 210 | if attn.residual_connection: 211 | hidden_states = hidden_states + residual 212 | 213 | hidden_states = hidden_states / attn.rescale_output_factor 214 | 215 | return hidden_states -------------------------------------------------------------------------------- /kolors/models/ipa_faceid_plus/ipa_faceid_plus.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import math 4 | 5 | def reshape_tensor(x, heads): 6 | bs, length, width = x.shape 7 | #(bs, length, width) --> (bs, length, n_heads, dim_per_head) 8 | x = x.view(bs, length, heads, -1) 9 | # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) 10 | x = x.transpose(1, 2) 11 | # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) 12 | x = x.reshape(bs, heads, length, -1) 13 | return x 14 | 15 | def FeedForward(dim, mult=4): 16 | inner_dim = int(dim * mult) 17 | return nn.Sequential( 18 | nn.LayerNorm(dim), 19 | nn.Linear(dim, inner_dim, bias=False), 20 | nn.GELU(), 21 | nn.Linear(inner_dim, dim, bias=False), 22 | ) 23 | 24 | class PerceiverAttention(nn.Module): 25 | def __init__(self, *, dim, dim_head=64, heads=8): 26 | super().__init__() 27 | self.scale = dim_head**-0.5 28 | self.dim_head = dim_head 29 | self.heads = heads 30 | inner_dim = dim_head * heads 31 | 32 | self.norm1 = nn.LayerNorm(dim) 33 | self.norm2 = nn.LayerNorm(dim) 34 | 35 | self.to_q = nn.Linear(dim, inner_dim, bias=False) 36 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) 37 | self.to_out = nn.Linear(inner_dim, dim, bias=False) 38 | 39 | def forward(self, x, latents): 40 | """ 41 | Args: 42 | x (torch.Tensor): image features 43 | shape (b, n1, D) 44 | latent (torch.Tensor): latent features 45 | shape (b, n2, D) 46 | """ 47 | x = self.norm1(x) 48 | latents = self.norm2(latents) 49 | 50 | b, l, _ = latents.shape 51 | 52 | q = self.to_q(latents) 53 | kv_input = torch.cat((x, latents), dim=-2) 54 | k, v = self.to_kv(kv_input).chunk(2, dim=-1) 55 | 56 | q = reshape_tensor(q, self.heads) 57 | k = reshape_tensor(k, self.heads) 58 | v = reshape_tensor(v, self.heads) 59 | 60 | # attention 61 | scale = 1 / math.sqrt(math.sqrt(self.dim_head)) 62 | weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards 63 | weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) 64 | out = weight @ v 65 | 66 | out = out.permute(0, 2, 1, 3).reshape(b, l, -1) 67 | 68 | return self.to_out(out) 69 | 70 | class FacePerceiverResampler(torch.nn.Module): 71 | def __init__( 72 | self, 73 | *, 74 | dim=768, 75 | depth=4, 76 | dim_head=64, 77 | heads=16, 78 | embedding_dim=1280, 79 | output_dim=768, 80 | ff_mult=4, 81 | ): 82 | super().__init__() 83 | 84 | self.proj_in = torch.nn.Linear(embedding_dim, 
dim) 85 | self.proj_out = torch.nn.Linear(dim, output_dim) 86 | self.norm_out = torch.nn.LayerNorm(output_dim) 87 | self.layers = torch.nn.ModuleList([]) 88 | for _ in range(depth): 89 | self.layers.append( 90 | torch.nn.ModuleList( 91 | [ 92 | PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), 93 | FeedForward(dim=dim, mult=ff_mult), 94 | ] 95 | ) 96 | ) 97 | 98 | def forward(self, latents, x): 99 | x = self.proj_in(x) 100 | for attn, ff in self.layers: 101 | latents = attn(x, latents) + latents 102 | latents = ff(latents) + latents 103 | latents = self.proj_out(latents) 104 | return self.norm_out(latents) 105 | 106 | class ProjPlusModel(torch.nn.Module): 107 | def __init__(self, cross_attention_dim=768, id_embeddings_dim=512, clip_embeddings_dim=1280, num_tokens=4): 108 | super().__init__() 109 | 110 | self.cross_attention_dim = cross_attention_dim 111 | self.num_tokens = num_tokens 112 | 113 | self.proj = torch.nn.Sequential( 114 | torch.nn.Linear(id_embeddings_dim, id_embeddings_dim*2), 115 | torch.nn.GELU(), 116 | torch.nn.Linear(id_embeddings_dim*2, cross_attention_dim*num_tokens), 117 | ) 118 | self.norm = torch.nn.LayerNorm(cross_attention_dim) 119 | 120 | self.perceiver_resampler = FacePerceiverResampler( 121 | dim=cross_attention_dim, 122 | depth=4, 123 | dim_head=64, 124 | heads=cross_attention_dim // 64, 125 | embedding_dim=clip_embeddings_dim, 126 | output_dim=cross_attention_dim, 127 | ff_mult=4, 128 | ) 129 | 130 | def forward(self, id_embeds, clip_embeds, shortcut = True, scale = 1.0): 131 | x = self.proj(id_embeds) 132 | x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) 133 | x = self.norm(x) 134 | out = self.perceiver_resampler(x, clip_embeds) 135 | if shortcut: 136 | out = x + scale * out 137 | return out -------------------------------------------------------------------------------- /kolors/pipelines/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/kolors/pipelines/__init__.py -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | fire 2 | triton 3 | pydantic==2.8.2 4 | accelerate==0.27.2 5 | deepspeed==0.8.1 6 | huggingface-hub==0.23.4 7 | imageio==2.25.1 8 | numpy==1.21.6 9 | omegaconf==2.3.0 10 | pandas==1.3.5 11 | Pillow==9.4.0 12 | tokenizers==0.13.2 13 | torch==1.13.1 14 | torchvision==0.14.1 15 | transformers==4.42.4 16 | xformers==0.0.16 17 | safetensors==0.3.3 18 | diffusers==0.28.2 19 | sentencepiece==0.1.99 20 | gradio==4.37.2 21 | opencv-python 22 | einops 23 | timm 24 | onnxruntime 25 | -------------------------------------------------------------------------------- /scripts/outputs/sample_inpainting_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_inpainting_3.jpg -------------------------------------------------------------------------------- /scripts/outputs/sample_inpainting_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_inpainting_4.jpg -------------------------------------------------------------------------------- /scripts/outputs/sample_ip_test_ip.jpg: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_ip_test_ip.jpg -------------------------------------------------------------------------------- /scripts/outputs/sample_ip_test_ip2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_ip_test_ip2.jpg -------------------------------------------------------------------------------- /scripts/outputs/sample_test.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kwai-Kolors/Kolors/038818d244ed103056abd10f429729a26af4d239/scripts/outputs/sample_test.jpg -------------------------------------------------------------------------------- /scripts/sample.py: -------------------------------------------------------------------------------- 1 | import os, torch 2 | # from PIL import Image 3 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline 4 | from kolors.models.modeling_chatglm import ChatGLMModel 5 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 6 | from diffusers import UNet2DConditionModel, AutoencoderKL 7 | from diffusers import EulerDiscreteScheduler 8 | 9 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 10 | 11 | def infer(prompt): 12 | ckpt_dir = f'{root_dir}/weights/Kolors' 13 | text_encoder = ChatGLMModel.from_pretrained( 14 | f'{ckpt_dir}/text_encoder', 15 | torch_dtype=torch.float16).half() 16 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 17 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half() 18 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 19 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half() 20 | pipe = StableDiffusionXLPipeline( 21 | vae=vae, 22 | text_encoder=text_encoder, 23 | tokenizer=tokenizer, 24 | unet=unet, 25 | scheduler=scheduler, 26 | force_zeros_for_empty_prompt=False) 27 | pipe = pipe.to("cuda") 28 | pipe.enable_model_cpu_offload() 29 | image = pipe( 30 | prompt=prompt, 31 | height=1024, 32 | width=1024, 33 | num_inference_steps=50, 34 | guidance_scale=5.0, 35 | num_images_per_prompt=1, 36 | generator= torch.Generator(pipe.device).manual_seed(66)).images[0] 37 | image.save(f'{root_dir}/scripts/outputs/sample_test.jpg') 38 | 39 | 40 | if __name__ == '__main__': 41 | import fire 42 | fire.Fire(infer) 43 | -------------------------------------------------------------------------------- /scripts/sampleui.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import gradio as gr 4 | # from PIL import Image 5 | from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256 import StableDiffusionXLPipeline 6 | from kolors.models.modeling_chatglm import ChatGLMModel 7 | from kolors.models.tokenization_chatglm import ChatGLMTokenizer 8 | from diffusers import UNet2DConditionModel, AutoencoderKL 9 | from diffusers import EulerDiscreteScheduler 10 | 11 | root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 12 | 13 | # Initialize global variables for models and pipeline 14 | text_encoder = None 15 | tokenizer = None 16 | vae = None 17 | scheduler = None 18 | unet = None 19 | pipe = 
None 20 | 21 | def load_models(): 22 | global text_encoder, tokenizer, vae, scheduler, unet, pipe 23 | 24 | if text_encoder is None: 25 | ckpt_dir = f'{root_dir}/weights/Kolors' 26 | 27 | # Load the text encoder on CPU (this speeds stuff up 2x) 28 | text_encoder = ChatGLMModel.from_pretrained( 29 | f'{ckpt_dir}/text_encoder', 30 | torch_dtype=torch.float16).to('cpu').half() 31 | tokenizer = ChatGLMTokenizer.from_pretrained(f'{ckpt_dir}/text_encoder') 32 | 33 | # Load the VAE and UNet on GPU 34 | vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None).half().to('cuda') 35 | scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler") 36 | unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None).half().to('cuda') 37 | 38 | # Prepare the pipeline 39 | pipe = StableDiffusionXLPipeline( 40 | vae=vae, 41 | text_encoder=text_encoder, 42 | tokenizer=tokenizer, 43 | unet=unet, 44 | scheduler=scheduler, 45 | force_zeros_for_empty_prompt=False) 46 | pipe = pipe.to("cuda") 47 | pipe.enable_model_cpu_offload() # Enable offloading to balance CPU/GPU usage 48 | 49 | def infer(prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt): 50 | load_models() 51 | 52 | if use_random_seed: 53 | seed = torch.randint(0, 2**32 - 1, (1,)).item() 54 | 55 | generator = torch.Generator(pipe.device).manual_seed(seed) 56 | images = pipe( 57 | prompt=prompt, 58 | height=height, 59 | width=width, 60 | num_inference_steps=num_inference_steps, 61 | guidance_scale=guidance_scale, 62 | num_images_per_prompt=num_images_per_prompt, 63 | generator=generator 64 | ).images 65 | 66 | saved_images = [] 67 | output_dir = f'{root_dir}/scripts/outputs' 68 | os.makedirs(output_dir, exist_ok=True) 69 | 70 | for i, image in enumerate(images): 71 | file_path = os.path.join(output_dir, 'sample_test.jpg') 72 | base_name, ext = os.path.splitext(file_path) 73 | counter = 1 74 | while os.path.exists(file_path): 75 | file_path = f"{base_name}_{counter}{ext}" 76 | counter += 1 77 | image.save(file_path) 78 | saved_images.append(file_path) 79 | 80 | return saved_images 81 | 82 | def gradio_interface(): 83 | with gr.Blocks() as demo: 84 | with gr.Row(): 85 | with gr.Column(): 86 | gr.Markdown("## Kolors: Diffusion Model Gradio Interface") 87 | prompt = gr.Textbox(label="Prompt") 88 | use_random_seed = gr.Checkbox(label="Use Random Seed", value=True) 89 | seed = gr.Slider(minimum=0, maximum=2**32 - 1, step=1, label="Seed", randomize=True, visible=False) 90 | use_random_seed.change(lambda x: gr.update(visible=not x), use_random_seed, seed) 91 | height = gr.Slider(minimum=128, maximum=2048, step=64, label="Height", value=1024) 92 | width = gr.Slider(minimum=128, maximum=2048, step=64, label="Width", value=1024) 93 | num_inference_steps = gr.Slider(minimum=1, maximum=100, step=1, label="Inference Steps", value=50) 94 | guidance_scale = gr.Slider(minimum=1.0, maximum=20.0, step=0.1, label="Guidance Scale", value=5.0) 95 | num_images_per_prompt = gr.Slider(minimum=1, maximum=10, step=1, label="Images per Prompt", value=1) 96 | btn = gr.Button("Generate Image") 97 | 98 | with gr.Column(): 99 | output_images = gr.Gallery(label="Output Images", elem_id="output_gallery") 100 | 101 | btn.click( 102 | fn=infer, 103 | inputs=[prompt, use_random_seed, seed, height, width, num_inference_steps, guidance_scale, num_images_per_prompt], 104 | outputs=output_images 105 | ) 106 | 107 | return demo 108 | 109 | if __name__ == '__main__': 110 | gradio_interface().launch() 111 
| -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="kolors", 5 | version="0.1", 6 | author="Kolors", 7 | description="The training and inference code for Kolors models.", 8 | packages=find_packages(), 9 | install_requires=[], 10 | dependency_links=[], 11 | ) 12 | --------------------------------------------------------------------------------