├── LICENSE ├── README.md ├── assets ├── CuteCat.jpeg ├── Dog.png └── Lotus.jpeg ├── bash_scripts ├── canny_controlnet_inference.sh ├── controlnet_tile_inference.sh ├── depth_controlnet_inference.sh └── lora_inference.sh ├── inference.py ├── model ├── __pycache__ │ ├── adapter.cpython-310.pyc │ └── unet_adapter.cpython-310.pyc ├── adapter.py ├── unet_adapter.py └── utils.py ├── pipeline ├── __pycache__ │ ├── pipeline_sd_xl_adapter.cpython-310.pyc │ ├── pipeline_sd_xl_adapter_controlnet.cpython-310.pyc │ └── pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc ├── pipeline_sd_xl_adapter.py ├── pipeline_sd_xl_adapter_controlnet.py └── pipeline_sd_xl_adapter_controlnet_img2img.py ├── requirements.txt └── scripts ├── __pycache__ ├── inference_controlnet.cpython-310.pyc ├── inference_ctrlnet_tile.cpython-310.pyc ├── inference_lora.cpython-310.pyc └── utils.cpython-310.pyc ├── inference_controlnet.py ├── inference_ctrlnet_tile.py ├── inference_lora.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # X-Adapter 2 | 3 | This repository is the official implementation of [X-Adapter](https://arxiv.org/abs/2312.02238). 4 | 5 | **[X-Adapter: Adding Universal Compatibility of Plugins for Upgraded Diffusion Model](https://arxiv.org/abs/2312.02238)** 6 |
7 | [Lingmin Ran](), 8 | [Xiaodong Cun](https://vinthony.github.io/academic/), 9 | [Jia-Wei Liu](https://jia-wei-liu.github.io/), 10 | [Rui Zhao](https://ruizhaocv.github.io/), 11 | [Song Zijie](), 12 | [Xintao Wang](https://xinntao.github.io/), 13 | [Jussi Keppo](https://www.jussikeppo.com/), 14 | [Mike Zheng Shou](https://sites.google.com/view/showlab) 15 |
16 | 
17 | [![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://showlab.github.io/X-Adapter/)
18 | [![arXiv](https://img.shields.io/badge/arXiv-2312.02238-b31b1b.svg)](https://arxiv.org/abs/2312.02238)
19 | 
20 | ![Overview_v7](https://github.com/showlab/X-Adapter/assets/152716091/eb41c508-826c-404f-8223-09765765823b)
21 | 
22 | X-Adapter enables plugins pretrained on an old base model (e.g., SD1.5) to work directly with an upgraded model (e.g., SDXL) without any further retraining.
23 | 

33 | 
34 | ### Thanks to @[kijai](https://github.com/kijai) for the ComfyUI implementation [here](https://github.com/kijai/ComfyUI-Diffusers-X-Adapter)! Please refer to this [tutorial](https://www.reddit.com/r/StableDiffusion/comments/1asuyiw/xadapter/) for hyperparameter settings.
35 | 
36 | ## News
37 | 
38 | - [17/02/2024] Inference code released
39 | 
40 | ## Setup
41 | 
42 | ### Requirements
43 | 
44 | ```shell
45 | conda create -n xadapter python=3.10
46 | conda activate xadapter
47 | 
48 | pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116
49 | pip install -r requirements.txt
50 | ```
51 | 
52 | Installing [xformers](https://github.com/facebookresearch/xformers) is highly recommended for higher efficiency and lower GPU memory cost.
53 | 
54 | ### Weights
55 | 
56 | **[Stable Diffusion]** [Stable Diffusion](https://arxiv.org/abs/2112.10752) is a latent text-to-image diffusion model capable of generating photo-realistic images given any text input. The pre-trained Stable Diffusion models can be downloaded from Hugging Face (e.g., [Stable Diffusion v1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5)). You can also use fine-tuned Stable Diffusion models trained on different styles (e.g., [Anything V4.0](https://huggingface.co/andite/anything-v4.0), [Redshift](https://huggingface.co/nitrosocke/redshift-diffusion), etc.).
57 | 
58 | **[ControlNet]** [ControlNet](https://github.com/lllyasviel/ControlNet) is a method to control diffusion models with spatial conditions. You can download the ControlNet family [here](https://huggingface.co/lllyasviel/ControlNet).
59 | 
60 | **[LoRA]** [LoRA](https://arxiv.org/abs/2106.09685) is a lightweight adapter for fine-tuning large-scale pretrained models. It is widely used for style or identity customization in diffusion models. You can download LoRA weights from the diffusion community (e.g., [civitai](https://civitai.com/)).
61 | 
62 | ### Checkpoint
63 | 
64 | Models can be downloaded from our [Hugging Face page](https://huggingface.co/Lingmin-Ran/X-Adapter). Put the checkpoint in the folder `./checkpoint/X-Adapter` (a Python download sketch is given below).
65 | 
66 | ## Usage
67 | 
68 | After preparing all checkpoints, we can run the inference code with different plugins. You can refer to this [tutorial](https://www.reddit.com/r/StableDiffusion/comments/1asuyiw/xadapter/) to quickly get started with X-Adapter.
69 | 
70 | ### ControlNet Inference
71 | 
72 | Set `--controlnet_canny_path` or `--controlnet_depth_path` to ControlNet's path in the bash script. The default value is its Hugging Face model card.
73 | 
74 | sh ./bash_scripts/canny_controlnet_inference.sh
75 | sh ./bash_scripts/depth_controlnet_inference.sh
76 | 
77 | ### LoRA Inference
78 | 
79 | Set `--lora_model_path` to the LoRA checkpoint in the bash script. In this example we use [MoXin](https://civitai.com/models/12597/moxin), and we put it in the folder `./checkpoint/lora`.
80 | 
81 | sh ./bash_scripts/lora_inference.sh
82 | 
83 | ### ControlNet-Tile Inference
84 | 
85 | Set `--controlnet_tile_path` to ControlNet-Tile's path in the bash script. The default value is its Hugging Face model card.
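Before running this script (or any of the scripts above), the X-Adapter checkpoint described under *Checkpoint* has to be in place, and the MoXin LoRA used in the LoRA example is fetched manually from civitai into `./checkpoint/lora`. A minimal download sketch for the X-Adapter weights, assuming the `huggingface_hub` package that ships with `diffusers` (the exact file layout on the model page may differ):

```python
# Sketch only: fetch the X-Adapter checkpoint into the folder the bash scripts expect.
from huggingface_hub import snapshot_download

# Downloads the repository contents (e.g. X_Adapter_v1.bin) into ./checkpoint/X-Adapter,
# matching the default --adapter_checkpoint path in inference.py.
snapshot_download(repo_id="Lingmin-Ran/X-Adapter", local_dir="./checkpoint/X-Adapter")
```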
86 | 87 | sh ./bash_scripts/controlnet_tile_inference.sh 88 | 89 | ## Cite 90 | If you find X-Adapter useful for your research and applications, please cite us using this BibTeX: 91 | 92 | ```bibtex 93 | @article{ran2023xadapter, 94 | title={X-Adapter: Adding Universal Compatibility of Plugins for Upgraded Diffusion Model}, 95 | author={Lingmin Ran and Xiaodong Cun and Jia-Wei Liu and Rui Zhao and Song Zijie and Xintao Wang and Jussi Keppo and Mike Zheng Shou}, 96 | journal={arXiv preprint arXiv:2312.02238}, 97 | year={2023} 98 | } 99 | ``` 100 | -------------------------------------------------------------------------------- /assets/CuteCat.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/CuteCat.jpeg -------------------------------------------------------------------------------- /assets/Dog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/Dog.png -------------------------------------------------------------------------------- /assets/Lotus.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/assets/Lotus.jpeg -------------------------------------------------------------------------------- /bash_scripts/canny_controlnet_inference.sh: -------------------------------------------------------------------------------- 1 | python inference.py --plugin_type "controlnet" \ 2 | --prompt "A cute cat, high quality, extremely detailed" \ 3 | --condition_type "canny" \ 4 | --input_image_path "./assets/CuteCat.jpeg" \ 5 | --controlnet_condition_scale_list 1.5 1.75 2.0 \ 6 | --adapter_guidance_start_list 1.00 \ 7 | --adapter_condition_scale_list 1.0 1.20 \ 8 | --height 1024 \ 9 | --width 1024 \ 10 | --height_sd1_5 512 \ 11 | --width_sd1_5 512 \ 12 | -------------------------------------------------------------------------------- /bash_scripts/controlnet_tile_inference.sh: -------------------------------------------------------------------------------- 1 | python inference.py --plugin_type "controlnet_tile" \ 2 | --prompt "best quality, extremely datailed" \ 3 | --controlnet_condition_scale_list 1.0 \ 4 | --adapter_guidance_start_list 0.7 \ 5 | --adapter_condition_scale_list 1.2 \ 6 | --input_image_path "./assets/Dog.png" \ 7 | --height 1024 \ 8 | --width 768 \ 9 | --height_sd1_5 512 \ 10 | --width_sd1_5 384 \ 11 | -------------------------------------------------------------------------------- /bash_scripts/depth_controlnet_inference.sh: -------------------------------------------------------------------------------- 1 | python inference.py --plugin_type "controlnet" \ 2 | --prompt "A colorful lotus, ink, high quality, extremely detailed" \ 3 | --condition_type "depth" \ 4 | --input_image_path "./assets/Lotus.jpeg" \ 5 | --controlnet_condition_scale_list 1.0 \ 6 | --adapter_guidance_start_list 0.80 \ 7 | --adapter_condition_scale_list 1.0 \ 8 | --height 1024 \ 9 | --width 1024 \ 10 | --height_sd1_5 512 \ 11 | --width_sd1_5 512 \ 12 | -------------------------------------------------------------------------------- /bash_scripts/lora_inference.sh: -------------------------------------------------------------------------------- 1 | python inference.py --plugin_type "lora" \ 2 | --prompt "masterpiece, best 
quality, ultra detailed, 1 girl , solo, smile, looking at viewer, holding flowers" \ 3 | --prompt_sd1_5 "masterpiece, best quality, ultra detailed, 1 girl, solo, smile, looking at viewer, holding flowers, shuimobysim, wuchangshuo, bonian, zhenbanqiao, badashanren" \ 4 | --adapter_guidance_start_list 0.95 \ 5 | --adapter_condition_scale_list 1.50 \ 6 | --seed 3943946911 \ 7 | --height 1024 \ 8 | --width 1024 \ 9 | --height_sd1_5 512 \ 10 | --width_sd1_5 512 \ 11 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import os 2 | import datetime 3 | import argparse 4 | 5 | from scripts.inference_controlnet import inference_controlnet 6 | from scripts.inference_lora import inference_lora 7 | from scripts.inference_ctrlnet_tile import inference_ctrlnet_tile 8 | 9 | 10 | def parse_args(input_args=None): 11 | parser = argparse.ArgumentParser(description="Inference setting for X-Adapter.") 12 | 13 | parser.add_argument( 14 | "--plugin_type", 15 | type=str, help='lora or controlnet', default="controlnet" 16 | ) 17 | parser.add_argument( 18 | "--controlnet_condition_scale_list", 19 | nargs='+', help='controlnet_scale', default=[1.0, 2.0] 20 | ) 21 | parser.add_argument( 22 | "--adapter_guidance_start_list", 23 | nargs='+', help='start of 2nd stage', default=[0.6, 0.65, 0.7, 0.75, 0.8] 24 | ) 25 | parser.add_argument( 26 | "--adapter_condition_scale_list", 27 | nargs='+', help='X-Adapter scale', default=[0.8, 1.0, 1.2] 28 | ) 29 | parser.add_argument( 30 | "--base_path", 31 | type=str, help='path to base model', default="runwayml/stable-diffusion-v1-5" 32 | ) 33 | parser.add_argument( 34 | "--sdxl_path", 35 | type=str, help='path to SDXL', default="stabilityai/stable-diffusion-xl-base-1.0" 36 | ) 37 | parser.add_argument( 38 | "--path_vae_sdxl", 39 | type=str, help='path to SDXL vae', default="madebyollin/sdxl-vae-fp16-fix" 40 | ) 41 | parser.add_argument( 42 | "--adapter_checkpoint", 43 | type=str, help='path to X-Adapter', default="./checkpoint/X-Adapter/X_Adapter_v1.bin" 44 | ) 45 | parser.add_argument( 46 | "--condition_type", 47 | type=str, help='condition type', default="canny" 48 | ) 49 | parser.add_argument( 50 | "--controlnet_canny_path", 51 | type=str, help='path to canny controlnet', default="lllyasviel/sd-controlnet-canny" 52 | ) 53 | parser.add_argument( 54 | "--controlnet_depth_path", 55 | type=str, help='path to depth controlnet', default="lllyasviel/sd-controlnet-depth" 56 | ) 57 | parser.add_argument( 58 | "--controlnet_tile_path", 59 | type=str, help='path to controlnet tile', default="lllyasviel/control_v11f1e_sd15_tile" 60 | ) 61 | parser.add_argument( 62 | "--lora_model_path", 63 | type=str, help='path to lora', default="./checkpoint/lora/MoXinV1.safetensors" 64 | ) 65 | parser.add_argument( 66 | "--prompt", 67 | type=str, help='SDXL prompt', default=None, required=True 68 | ) 69 | parser.add_argument( 70 | "--prompt_sd1_5", 71 | type=str, help='SD1.5 prompt', default=None 72 | ) 73 | parser.add_argument( 74 | "--negative_prompt", 75 | type=str, default=None 76 | ) 77 | parser.add_argument( 78 | "--iter_num", 79 | type=int, default=1 80 | ) 81 | parser.add_argument( 82 | "--input_image_path", 83 | type=str, default="./controlnet_test_image/CuteCat.jpeg" 84 | ) 85 | parser.add_argument( 86 | "--num_inference_steps", 87 | type=int, default=50 88 | ) 89 | parser.add_argument( 90 | "--guidance_scale", 91 | type=float, default=7.5 92 | ) 93 | 
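    # The remaining arguments set the sampling seed and the two working resolutions:
    # --width/--height give the SDXL output size, while --width_sd1_5/--height_sd1_5 give
    # the size of the SD1.5 branch that the plugin itself runs on. The bundled bash scripts
    # keep the SD1.5 size at half the SDXL size in each dimension (e.g. 512x512 for 1024x1024).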
parser.add_argument( 94 | "--seed", 95 | type=int, default=1674753452 96 | ) 97 | parser.add_argument( 98 | "--width", 99 | type=int, default=1024 100 | ) 101 | parser.add_argument( 102 | "--height", 103 | type=int, default=1024 104 | ) 105 | parser.add_argument( 106 | "--height_sd1_5", 107 | type=int, default=512 108 | ) 109 | parser.add_argument( 110 | "--width_sd1_5", 111 | type=int, default=512 112 | ) 113 | 114 | if input_args is not None: 115 | args = parser.parse_args(input_args) 116 | else: 117 | args = parser.parse_args() 118 | 119 | return args 120 | 121 | 122 | def run_inference(args): 123 | current_datetime = datetime.datetime.now() 124 | current_datetime = str(current_datetime).replace(":", "_") 125 | save_path = f"./result/{current_datetime}_lora" if args.plugin_type == "lora" else f"./result/{current_datetime}_controlnet" 126 | os.makedirs(save_path) 127 | args.save_path = save_path 128 | 129 | if args.plugin_type == "controlnet": 130 | inference_controlnet(args) 131 | elif args.plugin_type == "controlnet_tile": 132 | inference_ctrlnet_tile(args) 133 | elif args.plugin_type == "lora": 134 | inference_lora(args) 135 | else: 136 | raise NotImplementedError("not implemented yet") 137 | 138 | 139 | if __name__ == "__main__": 140 | args = parse_args() 141 | run_inference(args) 142 | -------------------------------------------------------------------------------- /model/__pycache__/adapter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/model/__pycache__/adapter.cpython-310.pyc -------------------------------------------------------------------------------- /model/__pycache__/unet_adapter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/model/__pycache__/unet_adapter.cpython-310.pyc -------------------------------------------------------------------------------- /model/adapter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from collections import OrderedDict 4 | from diffusers.models.embeddings import ( 5 | TimestepEmbedding, 6 | Timesteps, 7 | ) 8 | 9 | 10 | def conv_nd(dims, *args, **kwargs): 11 | """ 12 | Create a 1D, 2D, or 3D convolution module. 13 | """ 14 | if dims == 1: 15 | return nn.Conv1d(*args, **kwargs) 16 | elif dims == 2: 17 | return nn.Conv2d(*args, **kwargs) 18 | elif dims == 3: 19 | return nn.Conv3d(*args, **kwargs) 20 | raise ValueError(f"unsupported dimensions: {dims}") 21 | 22 | 23 | def avg_pool_nd(dims, *args, **kwargs): 24 | """ 25 | Create a 1D, 2D, or 3D average pooling module. 
26 | """ 27 | if dims == 1: 28 | return nn.AvgPool1d(*args, **kwargs) 29 | elif dims == 2: 30 | return nn.AvgPool2d(*args, **kwargs) 31 | elif dims == 3: 32 | return nn.AvgPool3d(*args, **kwargs) 33 | raise ValueError(f"unsupported dimensions: {dims}") 34 | 35 | 36 | def get_parameter_dtype(parameter: torch.nn.Module): 37 | try: 38 | params = tuple(parameter.parameters()) 39 | if len(params) > 0: 40 | return params[0].dtype 41 | 42 | buffers = tuple(parameter.buffers()) 43 | if len(buffers) > 0: 44 | return buffers[0].dtype 45 | 46 | except StopIteration: 47 | # For torch.nn.DataParallel compatibility in PyTorch 1.5 48 | 49 | def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: 50 | tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] 51 | return tuples 52 | 53 | gen = parameter._named_members(get_members_fn=find_tensor_attributes) 54 | first_tuple = next(gen) 55 | return first_tuple[1].dtype 56 | 57 | 58 | class Downsample(nn.Module): 59 | """ 60 | A downsampling layer with an optional convolution. 61 | :param channels: channels in the inputs and outputs. 62 | :param use_conv: a bool determining if a convolution is applied. 63 | :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then 64 | downsampling occurs in the inner-two dimensions. 65 | """ 66 | 67 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 68 | super().__init__() 69 | self.channels = channels 70 | self.out_channels = out_channels or channels 71 | self.use_conv = use_conv 72 | self.dims = dims 73 | stride = 2 if dims != 3 else (1, 2, 2) 74 | if use_conv: 75 | self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding) 76 | else: 77 | assert self.channels == self.out_channels 78 | from torch.nn import MaxUnpool2d 79 | self.op = MaxUnpool2d(dims, kernel_size=stride, stride=stride) 80 | 81 | def forward(self, x): 82 | assert x.shape[1] == self.channels 83 | return self.op(x) 84 | 85 | 86 | class Upsample(nn.Module): 87 | def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1): 88 | super().__init__() 89 | self.channels = channels 90 | self.out_channels = out_channels or channels 91 | self.use_conv = use_conv 92 | self.dims = dims 93 | stride = 2 if dims != 3 else (1, 2, 2) 94 | if use_conv: 95 | self.op = nn.ConvTranspose2d(self.channels, self.out_channels, 3, stride=stride, padding=1) 96 | else: 97 | assert self.channels == self.out_channels 98 | self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) 99 | 100 | def forward(self, x, output_size): 101 | assert x.shape[1] == self.channels 102 | return self.op(x, output_size) 103 | 104 | 105 | class Linear(nn.Module): 106 | def __init__(self, temb_channels, out_channels): 107 | super(Linear, self).__init__() 108 | self.linear = nn.Linear(temb_channels, out_channels) 109 | 110 | def forward(self, x): 111 | return self.linear(x) 112 | 113 | 114 | 115 | class ResnetBlock(nn.Module): 116 | 117 | def __init__(self, in_c, out_c, down, up, ksize=3, sk=False, use_conv=True, enable_timestep=False, temb_channels=None, use_norm=False): 118 | super().__init__() 119 | self.use_norm = use_norm 120 | self.enable_timestep = enable_timestep 121 | ps = ksize // 2 122 | if in_c != out_c or sk == False: 123 | self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps) 124 | else: 125 | self.in_conv = None 126 | self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1) 127 | self.act = nn.ReLU() 128 | if use_norm: 129 | self.norm1 = nn.GroupNorm(num_groups=32, 
num_channels=out_c, eps=1e-6, affine=True) 130 | self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps) 131 | if sk == False: 132 | self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps) 133 | else: 134 | self.skep = None 135 | 136 | self.down = down 137 | self.up = up 138 | if self.down: 139 | self.down_opt = Downsample(in_c, use_conv=use_conv) 140 | if self.up: 141 | self.up_opt = Upsample(in_c, use_conv=use_conv) 142 | if enable_timestep: 143 | self.timestep_proj = Linear(temb_channels, out_c) 144 | 145 | 146 | def forward(self, x, output_size=None, temb=None): 147 | if self.down == True: 148 | x = self.down_opt(x) 149 | if self.up == True: 150 | x = self.up_opt(x, output_size) 151 | if self.in_conv is not None: # edit 152 | x = self.in_conv(x) 153 | 154 | h = self.block1(x) 155 | if temb is not None: 156 | temb = self.timestep_proj(temb)[:, :, None, None] 157 | h = h + temb 158 | if self.use_norm: 159 | h = self.norm1(h) 160 | h = self.act(h) 161 | h = self.block2(h) 162 | if self.skep is not None: 163 | return h + self.skep(x) 164 | else: 165 | return h + x 166 | 167 | 168 | class Adapter_XL(nn.Module): 169 | 170 | def __init__(self, in_channels=[1280, 640, 320], out_channels=[1280, 1280, 640], nums_rb=3, ksize=3, sk=True, use_conv=False, use_zero_conv=True, 171 | enable_timestep=False, use_norm=False, temb_channels=None, fusion_type='ADD'): 172 | super(Adapter_XL, self).__init__() 173 | self.channels = in_channels 174 | self.nums_rb = nums_rb 175 | self.body = [] 176 | self.out = [] 177 | self.use_zero_conv = use_zero_conv 178 | self.fusion_type = fusion_type 179 | self.gamma = [] 180 | self.beta = [] 181 | self.norm = [] 182 | if fusion_type == "SPADE": 183 | self.use_zero_conv = False 184 | for i in range(len(self.channels)): 185 | if self.fusion_type == 'SPADE': 186 | # Corresponding to SPADE 187 | self.gamma.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0)) 188 | self.beta.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0)) 189 | self.norm.append(nn.BatchNorm2d(out_channels[i])) 190 | elif use_zero_conv: 191 | self.out.append(self.make_zero_conv(out_channels[i])) 192 | else: 193 | self.out.append(nn.Conv2d(out_channels[i], out_channels[i], 1, padding=0)) 194 | for j in range(nums_rb): 195 | if i==0: 196 | # 1280, 32, 32 -> 1280, 32, 32 197 | self.body.append( 198 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk, use_conv=use_conv, 199 | enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm)) 200 | # 1280, 32, 32 -> 1280, 32, 32 201 | elif i==1: 202 | # 640, 64, 64 -> 1280, 64, 64 203 | if j==0: 204 | self.body.append( 205 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk, 206 | use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm)) 207 | else: 208 | self.body.append( 209 | ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize,sk=sk, 210 | use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm)) 211 | else: 212 | # 320, 64, 64 -> 640, 128, 128 213 | if j==0: 214 | self.body.append( 215 | ResnetBlock(in_channels[i], out_channels[i], down=False, up=True, ksize=ksize, sk=sk, 216 | use_conv=True, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm)) 217 | # use convtranspose2d 218 | else: 219 | self.body.append( 220 | ResnetBlock(out_channels[i], out_channels[i], down=False, up=False, ksize=ksize, sk=sk, 221 | 
use_conv=use_conv, enable_timestep=enable_timestep, temb_channels=temb_channels, use_norm=use_norm)) 222 | 223 | 224 | self.body = nn.ModuleList(self.body) 225 | if self.use_zero_conv: 226 | self.zero_out = nn.ModuleList(self.out) 227 | 228 | # if self.fusion_type == 'SPADE': 229 | # self.norm = nn.ModuleList(self.norm) 230 | # self.gamma = nn.ModuleList(self.gamma) 231 | # self.beta = nn.ModuleList(self.beta) 232 | # else: 233 | # self.zero_out = nn.ModuleList(self.out) 234 | 235 | 236 | # if enable_timestep: 237 | # a = 320 238 | # 239 | # time_embed_dim = a * 4 240 | # self.time_proj = Timesteps(a, True, 0) 241 | # timestep_input_dim = a 242 | # 243 | # self.time_embedding = TimestepEmbedding( 244 | # timestep_input_dim, 245 | # time_embed_dim, 246 | # act_fn='silu', 247 | # post_act_fn=None, 248 | # cond_proj_dim=None, 249 | # ) 250 | 251 | 252 | def make_zero_conv(self, channels): 253 | 254 | return zero_module(nn.Conv2d(channels, channels, 1, padding=0)) 255 | 256 | @property 257 | def dtype(self) -> torch.dtype: 258 | """ 259 | `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). 260 | """ 261 | return get_parameter_dtype(self) 262 | 263 | def forward(self, x, t=None): 264 | # extract features 265 | features = [] 266 | b, c, _, _ = x[-1].shape 267 | if t is not None: 268 | if not torch.is_tensor(t): 269 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 270 | # This would be a good case for the `match` statement (Python 3.10+) 271 | is_mps = x[0].device.type == "mps" 272 | if isinstance(timestep, float): 273 | dtype = torch.float32 if is_mps else torch.float64 274 | else: 275 | dtype = torch.int32 if is_mps else torch.int64 276 | t = torch.tensor([t], dtype=dtype, device=x[0].device) 277 | elif len(t.shape) == 0: 278 | t = t[None].to(x[0].device) 279 | 280 | t = t.expand(b) 281 | t = self.time_proj(t) # b, 320 282 | t = t.to(dtype=x[0].dtype) 283 | t = self.time_embedding(t) # b, 1280 284 | # output_size = (b, 640, 128, 128) # last CA layer output 285 | output_size = (b, 640, (x[0].shape)[2] * 4 , (x[0].shape)[3] * 4) # last CA layer output should suit to the input size CSR 286 | 287 | for i in range(len(self.channels)): 288 | for j in range(self.nums_rb): 289 | idx = i * self.nums_rb + j 290 | if j == 0: 291 | if i < 2: 292 | out = self.body[idx](x[i], temb=t) 293 | else: 294 | out = self.body[idx](x[i], output_size=output_size, temb=t) 295 | else: 296 | out = self.body[idx](out, temb=t) 297 | if self.fusion_type == 'SPADE': 298 | out_gamma = self.gamma[i](out) 299 | out_beta = self.beta[i](out) 300 | out = [out_gamma, out_beta] 301 | else: 302 | out = self.zero_out[i](out) 303 | features.append(out) 304 | 305 | return features 306 | 307 | 308 | def zero_module(module): 309 | """ 310 | Zero out the parameters of a module and return it. 
311 | """ 312 | for p in module.parameters(): 313 | p.detach().zero_() 314 | return module 315 | 316 | 317 | if __name__=='__main__': 318 | adapter = Adapter_XL(use_zero_conv=True, 319 | enable_timestep=True, use_norm=True, temb_channels=1280, fusion_type='SPADE').cuda() 320 | x = [torch.randn(4, 1280, 32, 32).cuda(), torch.randn(4, 640, 64, 64).cuda(), torch.randn(4, 320, 64, 64).cuda()] 321 | t = torch.tensor([1,2,3,4]).cuda() 322 | result = adapter(x, t=t) 323 | for xx in result: 324 | print(xx[0].shape) 325 | print(xx[1].shape) 326 | 327 | 328 | 329 | 330 | -------------------------------------------------------------------------------- /model/unet_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | from dataclasses import dataclass 15 | from typing import Any, Dict, List, Optional, Tuple, Union 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.checkpoint 20 | 21 | from diffusers.configuration_utils import ConfigMixin, register_to_config 22 | from diffusers.loaders import UNet2DConditionLoadersMixin 23 | from diffusers.utils import BaseOutput, logging 24 | from diffusers.models.activations import get_activation 25 | from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor 26 | from diffusers.models.embeddings import ( 27 | GaussianFourierProjection, 28 | ImageHintTimeEmbedding, 29 | ImageProjection, 30 | ImageTimeEmbedding, 31 | PositionNet, 32 | TextImageProjection, 33 | TextImageTimeEmbedding, 34 | TextTimeEmbedding, 35 | TimestepEmbedding, 36 | Timesteps, 37 | ) 38 | from diffusers.models.modeling_utils import ModelMixin 39 | from diffusers.models.unet_2d_blocks import ( 40 | UNetMidBlock2DCrossAttn, 41 | UNetMidBlock2DSimpleCrossAttn, 42 | get_down_block, 43 | get_up_block, 44 | ) 45 | 46 | 47 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 48 | 49 | 50 | @dataclass 51 | class UNet2DConditionOutput(BaseOutput): 52 | """ 53 | The output of [`UNet2DConditionModel`]. 54 | 55 | Args: 56 | sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): 57 | The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 58 | """ 59 | 60 | sample: torch.FloatTensor = None 61 | hidden_states: Optional[list] = None 62 | encoder_feature: Optional[list] = None 63 | 64 | 65 | class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): 66 | r""" 67 | A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample 68 | shaped output. 69 | 70 | This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented 71 | for all models (such as downloading or saving). 
72 | 73 | Parameters: 74 | sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): 75 | Height and width of input/output sample. 76 | in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. 77 | out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 78 | center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. 79 | flip_sin_to_cos (`bool`, *optional*, defaults to `False`): 80 | Whether to flip the sin to cos in the time embedding. 81 | freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 82 | down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): 83 | The tuple of downsample blocks to use. 84 | mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): 85 | Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or 86 | `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. 87 | up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): 88 | The tuple of upsample blocks to use. 89 | only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): 90 | Whether to include self-attention in the basic transformer blocks, see 91 | [`~models.attention.BasicTransformerBlock`]. 92 | block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): 93 | The tuple of output channels for each block. 94 | layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. 95 | downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. 96 | mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. 97 | act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. 98 | norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. 99 | If `None`, normalization and activation layers is skipped in post-processing. 100 | norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. 101 | cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): 102 | The dimension of the cross attention features. 103 | transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): 104 | The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for 105 | [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], 106 | [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. 107 | encoder_hid_dim (`int`, *optional*, defaults to None): 108 | If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` 109 | dimension to `cross_attention_dim`. 110 | encoder_hid_dim_type (`str`, *optional*, defaults to `None`): 111 | If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text 112 | embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. 113 | attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 114 | num_attention_heads (`int`, *optional*): 115 | The number of attention heads. 
If not defined, defaults to `attention_head_dim` 116 | resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config 117 | for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 118 | class_embed_type (`str`, *optional*, defaults to `None`): 119 | The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, 120 | `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. 121 | addition_embed_type (`str`, *optional*, defaults to `None`): 122 | Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or 123 | "text". "text" will use the `TextTimeEmbedding` layer. 124 | addition_time_embed_dim: (`int`, *optional*, defaults to `None`): 125 | Dimension for the timestep embeddings. 126 | num_class_embeds (`int`, *optional*, defaults to `None`): 127 | Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing 128 | class conditioning with `class_embed_type` equal to `None`. 129 | time_embedding_type (`str`, *optional*, defaults to `positional`): 130 | The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 131 | time_embedding_dim (`int`, *optional*, defaults to `None`): 132 | An optional override for the dimension of the projected time embedding. 133 | time_embedding_act_fn (`str`, *optional*, defaults to `None`): 134 | Optional activation function to use only once on the time embeddings before they are passed to the rest of 135 | the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. 136 | timestep_post_act (`str`, *optional*, defaults to `None`): 137 | The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. 138 | time_cond_proj_dim (`int`, *optional*, defaults to `None`): 139 | The dimension of `cond_proj` layer in the timestep embedding. 140 | conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. 141 | conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. 142 | projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when 143 | `class_embed_type="projection"`. Required when `class_embed_type="projection"`. 144 | class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time 145 | embeddings with the class embeddings. 146 | mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): 147 | Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If 148 | `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the 149 | `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` 150 | otherwise. 
151 | """ 152 | 153 | _supports_gradient_checkpointing = True 154 | 155 | @register_to_config 156 | def __init__( 157 | self, 158 | sample_size: Optional[int] = None, 159 | in_channels: int = 4, 160 | out_channels: int = 4, 161 | center_input_sample: bool = False, 162 | flip_sin_to_cos: bool = True, 163 | freq_shift: int = 0, 164 | down_block_types: Tuple[str] = ( 165 | "CrossAttnDownBlock2D", 166 | "CrossAttnDownBlock2D", 167 | "CrossAttnDownBlock2D", 168 | "DownBlock2D", 169 | ), 170 | mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", 171 | up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), 172 | only_cross_attention: Union[bool, Tuple[bool]] = False, 173 | block_out_channels: Tuple[int] = (320, 640, 1280, 1280), 174 | layers_per_block: Union[int, Tuple[int]] = 2, 175 | downsample_padding: int = 1, 176 | mid_block_scale_factor: float = 1, 177 | act_fn: str = "silu", 178 | norm_num_groups: Optional[int] = 32, 179 | norm_eps: float = 1e-5, 180 | cross_attention_dim: Union[int, Tuple[int]] = 1280, 181 | transformer_layers_per_block: Union[int, Tuple[int]] = 1, 182 | encoder_hid_dim: Optional[int] = None, 183 | encoder_hid_dim_type: Optional[str] = None, 184 | attention_head_dim: Union[int, Tuple[int]] = 8, 185 | num_attention_heads: Optional[Union[int, Tuple[int]]] = None, 186 | dual_cross_attention: bool = False, 187 | use_linear_projection: bool = False, 188 | class_embed_type: Optional[str] = None, 189 | addition_embed_type: Optional[str] = None, 190 | addition_time_embed_dim: Optional[int] = None, 191 | num_class_embeds: Optional[int] = None, 192 | upcast_attention: bool = False, 193 | resnet_time_scale_shift: str = "default", 194 | resnet_skip_time_act: bool = False, 195 | resnet_out_scale_factor: int = 1.0, 196 | time_embedding_type: str = "positional", 197 | time_embedding_dim: Optional[int] = None, 198 | time_embedding_act_fn: Optional[str] = None, 199 | timestep_post_act: Optional[str] = None, 200 | time_cond_proj_dim: Optional[int] = None, 201 | conv_in_kernel: int = 3, 202 | conv_out_kernel: int = 3, 203 | projection_class_embeddings_input_dim: Optional[int] = None, 204 | attention_type: str = "default", 205 | class_embeddings_concat: bool = False, 206 | mid_block_only_cross_attention: Optional[bool] = None, 207 | cross_attention_norm: Optional[str] = None, 208 | addition_embed_type_num_heads=64, 209 | ): 210 | super().__init__() 211 | 212 | self.sample_size = sample_size 213 | 214 | if num_attention_heads is not None: 215 | raise ValueError( 216 | "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." 217 | ) 218 | 219 | # If `num_attention_heads` is not defined (which is the case for most models) 220 | # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. 221 | # The reason for this behavior is to correct for incorrectly named variables that were introduced 222 | # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 223 | # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking 224 | # which is why we correct for the naming here. 
225 | num_attention_heads = num_attention_heads or attention_head_dim 226 | 227 | # Check inputs 228 | if len(down_block_types) != len(up_block_types): 229 | raise ValueError( 230 | f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." 231 | ) 232 | 233 | if len(block_out_channels) != len(down_block_types): 234 | raise ValueError( 235 | f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." 236 | ) 237 | 238 | if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): 239 | raise ValueError( 240 | f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." 241 | ) 242 | 243 | if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): 244 | raise ValueError( 245 | f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 246 | ) 247 | 248 | if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): 249 | raise ValueError( 250 | f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 251 | ) 252 | 253 | if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): 254 | raise ValueError( 255 | f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." 256 | ) 257 | 258 | if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): 259 | raise ValueError( 260 | f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 261 | ) 262 | 263 | # input 264 | conv_in_padding = (conv_in_kernel - 1) // 2 265 | self.conv_in = nn.Conv2d( 266 | in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding 267 | ) 268 | 269 | # time 270 | if time_embedding_type == "fourier": 271 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 272 | if time_embed_dim % 2 != 0: 273 | raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") 274 | self.time_proj = GaussianFourierProjection( 275 | time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos 276 | ) 277 | timestep_input_dim = time_embed_dim 278 | elif time_embedding_type == "positional": 279 | time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 280 | 281 | self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) 282 | timestep_input_dim = block_out_channels[0] 283 | else: 284 | raise ValueError( 285 | f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 
286 | ) 287 | 288 | self.time_embedding = TimestepEmbedding( 289 | timestep_input_dim, 290 | time_embed_dim, 291 | act_fn=act_fn, 292 | post_act_fn=timestep_post_act, 293 | cond_proj_dim=time_cond_proj_dim, 294 | ) 295 | 296 | if encoder_hid_dim_type is None and encoder_hid_dim is not None: 297 | encoder_hid_dim_type = "text_proj" 298 | self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) 299 | logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") 300 | 301 | if encoder_hid_dim is None and encoder_hid_dim_type is not None: 302 | raise ValueError( 303 | f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 304 | ) 305 | 306 | if encoder_hid_dim_type == "text_proj": 307 | self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) 308 | elif encoder_hid_dim_type == "text_image_proj": 309 | # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much 310 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 311 | # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` 312 | self.encoder_hid_proj = TextImageProjection( 313 | text_embed_dim=encoder_hid_dim, 314 | image_embed_dim=cross_attention_dim, 315 | cross_attention_dim=cross_attention_dim, 316 | ) 317 | elif encoder_hid_dim_type == "image_proj": 318 | # Kandinsky 2.2 319 | self.encoder_hid_proj = ImageProjection( 320 | image_embed_dim=encoder_hid_dim, 321 | cross_attention_dim=cross_attention_dim, 322 | ) 323 | elif encoder_hid_dim_type is not None: 324 | raise ValueError( 325 | f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 326 | ) 327 | else: 328 | self.encoder_hid_proj = None 329 | 330 | # class embedding 331 | if class_embed_type is None and num_class_embeds is not None: 332 | self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) 333 | elif class_embed_type == "timestep": 334 | self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) 335 | elif class_embed_type == "identity": 336 | self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) 337 | elif class_embed_type == "projection": 338 | if projection_class_embeddings_input_dim is None: 339 | raise ValueError( 340 | "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" 341 | ) 342 | # The projection `class_embed_type` is the same as the timestep `class_embed_type` except 343 | # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings 344 | # 2. it projects from an arbitrary input dimension. 345 | # 346 | # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. 347 | # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. 348 | # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
349 | self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 350 | elif class_embed_type == "simple_projection": 351 | if projection_class_embeddings_input_dim is None: 352 | raise ValueError( 353 | "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" 354 | ) 355 | self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) 356 | else: 357 | self.class_embedding = None 358 | 359 | if addition_embed_type == "text": 360 | if encoder_hid_dim is not None: 361 | text_time_embedding_from_dim = encoder_hid_dim 362 | else: 363 | text_time_embedding_from_dim = cross_attention_dim 364 | 365 | self.add_embedding = TextTimeEmbedding( 366 | text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads 367 | ) 368 | elif addition_embed_type == "text_image": 369 | # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much 370 | # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use 371 | # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` 372 | self.add_embedding = TextImageTimeEmbedding( 373 | text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim 374 | ) 375 | elif addition_embed_type == "text_time": 376 | self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) 377 | self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) 378 | elif addition_embed_type == "image": 379 | # Kandinsky 2.2 380 | self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 381 | elif addition_embed_type == "image_hint": 382 | # Kandinsky 2.2 ControlNet 383 | self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) 384 | elif addition_embed_type is not None: 385 | raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") 386 | 387 | if time_embedding_act_fn is None: 388 | self.time_embed_act = None 389 | else: 390 | self.time_embed_act = get_activation(time_embedding_act_fn) 391 | 392 | self.down_blocks = nn.ModuleList([]) 393 | self.up_blocks = nn.ModuleList([]) 394 | 395 | if isinstance(only_cross_attention, bool): 396 | if mid_block_only_cross_attention is None: 397 | mid_block_only_cross_attention = only_cross_attention 398 | 399 | only_cross_attention = [only_cross_attention] * len(down_block_types) 400 | 401 | if mid_block_only_cross_attention is None: 402 | mid_block_only_cross_attention = False 403 | 404 | if isinstance(num_attention_heads, int): 405 | num_attention_heads = (num_attention_heads,) * len(down_block_types) 406 | 407 | if isinstance(attention_head_dim, int): 408 | attention_head_dim = (attention_head_dim,) * len(down_block_types) 409 | 410 | if isinstance(cross_attention_dim, int): 411 | cross_attention_dim = (cross_attention_dim,) * len(down_block_types) 412 | 413 | if isinstance(layers_per_block, int): 414 | layers_per_block = [layers_per_block] * len(down_block_types) 415 | 416 | if isinstance(transformer_layers_per_block, int): 417 | transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) 418 | 419 | if class_embeddings_concat: 420 | # The time embeddings are concatenated with the class embeddings. 
The dimension of the 421 | # time embeddings passed to the down, middle, and up blocks is twice the dimension of the 422 | # regular time embeddings 423 | blocks_time_embed_dim = time_embed_dim * 2 424 | else: 425 | blocks_time_embed_dim = time_embed_dim 426 | 427 | # down 428 | output_channel = block_out_channels[0] 429 | for i, down_block_type in enumerate(down_block_types): 430 | input_channel = output_channel 431 | output_channel = block_out_channels[i] 432 | is_final_block = i == len(block_out_channels) - 1 433 | 434 | down_block = get_down_block( 435 | down_block_type, 436 | num_layers=layers_per_block[i], 437 | transformer_layers_per_block=transformer_layers_per_block[i], 438 | in_channels=input_channel, 439 | out_channels=output_channel, 440 | temb_channels=blocks_time_embed_dim, 441 | add_downsample=not is_final_block, 442 | resnet_eps=norm_eps, 443 | resnet_act_fn=act_fn, 444 | resnet_groups=norm_num_groups, 445 | cross_attention_dim=cross_attention_dim[i], 446 | num_attention_heads=num_attention_heads[i], 447 | downsample_padding=downsample_padding, 448 | dual_cross_attention=dual_cross_attention, 449 | use_linear_projection=use_linear_projection, 450 | only_cross_attention=only_cross_attention[i], 451 | upcast_attention=upcast_attention, 452 | resnet_time_scale_shift=resnet_time_scale_shift, 453 | attention_type=attention_type, 454 | resnet_skip_time_act=resnet_skip_time_act, 455 | resnet_out_scale_factor=resnet_out_scale_factor, 456 | cross_attention_norm=cross_attention_norm, 457 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 458 | ) 459 | self.down_blocks.append(down_block) 460 | 461 | # mid 462 | if mid_block_type == "UNetMidBlock2DCrossAttn": 463 | self.mid_block = UNetMidBlock2DCrossAttn( 464 | transformer_layers_per_block=transformer_layers_per_block[-1], 465 | in_channels=block_out_channels[-1], 466 | temb_channels=blocks_time_embed_dim, 467 | resnet_eps=norm_eps, 468 | resnet_act_fn=act_fn, 469 | output_scale_factor=mid_block_scale_factor, 470 | resnet_time_scale_shift=resnet_time_scale_shift, 471 | cross_attention_dim=cross_attention_dim[-1], 472 | num_attention_heads=num_attention_heads[-1], 473 | resnet_groups=norm_num_groups, 474 | dual_cross_attention=dual_cross_attention, 475 | use_linear_projection=use_linear_projection, 476 | upcast_attention=upcast_attention, 477 | attention_type=attention_type, 478 | ) 479 | elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": 480 | self.mid_block = UNetMidBlock2DSimpleCrossAttn( 481 | in_channels=block_out_channels[-1], 482 | temb_channels=blocks_time_embed_dim, 483 | resnet_eps=norm_eps, 484 | resnet_act_fn=act_fn, 485 | output_scale_factor=mid_block_scale_factor, 486 | cross_attention_dim=cross_attention_dim[-1], 487 | attention_head_dim=attention_head_dim[-1], 488 | resnet_groups=norm_num_groups, 489 | resnet_time_scale_shift=resnet_time_scale_shift, 490 | skip_time_act=resnet_skip_time_act, 491 | only_cross_attention=mid_block_only_cross_attention, 492 | cross_attention_norm=cross_attention_norm, 493 | ) 494 | elif mid_block_type is None: 495 | self.mid_block = None 496 | else: 497 | raise ValueError(f"unknown mid_block_type : {mid_block_type}") 498 | 499 | # count how many layers upsample the images 500 | self.num_upsamplers = 0 501 | 502 | # up 503 | reversed_block_out_channels = list(reversed(block_out_channels)) 504 | reversed_num_attention_heads = list(reversed(num_attention_heads)) 505 | reversed_layers_per_block = list(reversed(layers_per_block)) 506 | 
reversed_cross_attention_dim = list(reversed(cross_attention_dim)) 507 | reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) 508 | only_cross_attention = list(reversed(only_cross_attention)) 509 | 510 | output_channel = reversed_block_out_channels[0] 511 | for i, up_block_type in enumerate(up_block_types): 512 | is_final_block = i == len(block_out_channels) - 1 513 | 514 | prev_output_channel = output_channel 515 | output_channel = reversed_block_out_channels[i] 516 | input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] 517 | 518 | # add upsample block for all BUT final layer 519 | if not is_final_block: 520 | add_upsample = True 521 | self.num_upsamplers += 1 522 | else: 523 | add_upsample = False 524 | 525 | up_block = get_up_block( 526 | up_block_type, 527 | num_layers=reversed_layers_per_block[i] + 1, 528 | transformer_layers_per_block=reversed_transformer_layers_per_block[i], 529 | in_channels=input_channel, 530 | out_channels=output_channel, 531 | prev_output_channel=prev_output_channel, 532 | temb_channels=blocks_time_embed_dim, 533 | add_upsample=add_upsample, 534 | resnet_eps=norm_eps, 535 | resnet_act_fn=act_fn, 536 | resnet_groups=norm_num_groups, 537 | cross_attention_dim=reversed_cross_attention_dim[i], 538 | num_attention_heads=reversed_num_attention_heads[i], 539 | dual_cross_attention=dual_cross_attention, 540 | use_linear_projection=use_linear_projection, 541 | only_cross_attention=only_cross_attention[i], 542 | upcast_attention=upcast_attention, 543 | resnet_time_scale_shift=resnet_time_scale_shift, 544 | attention_type=attention_type, 545 | resnet_skip_time_act=resnet_skip_time_act, 546 | resnet_out_scale_factor=resnet_out_scale_factor, 547 | cross_attention_norm=cross_attention_norm, 548 | attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, 549 | ) 550 | self.up_blocks.append(up_block) 551 | prev_output_channel = output_channel 552 | 553 | # out 554 | if norm_num_groups is not None: 555 | self.conv_norm_out = nn.GroupNorm( 556 | num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps 557 | ) 558 | 559 | self.conv_act = get_activation(act_fn) 560 | 561 | else: 562 | self.conv_norm_out = None 563 | self.conv_act = None 564 | 565 | conv_out_padding = (conv_out_kernel - 1) // 2 566 | self.conv_out = nn.Conv2d( 567 | block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding 568 | ) 569 | 570 | if attention_type == "gated": 571 | positive_len = 768 572 | if isinstance(cross_attention_dim, int): 573 | positive_len = cross_attention_dim 574 | elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): 575 | positive_len = cross_attention_dim[0] 576 | self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim) 577 | 578 | 579 | @property 580 | def attn_processors(self) -> Dict[str, AttentionProcessor]: 581 | r""" 582 | Returns: 583 | `dict` of attention processors: A dictionary containing all attention processors used in the model with 584 | indexed by its weight name. 
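
        Example (an illustrative sketch; `unet` stands for an instance of this model):

        ```py
        >>> processors = unet.attn_processors
        >>> len(processors)         # one entry per attention layer
        >>> next(iter(processors))  # keys look like "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"
        ```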
585 | """ 586 | # set recursively 587 | processors = {} 588 | 589 | def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): 590 | if hasattr(module, "set_processor"): 591 | processors[f"{name}.processor"] = module.processor 592 | 593 | for sub_name, child in module.named_children(): 594 | fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) 595 | 596 | return processors 597 | 598 | for name, module in self.named_children(): 599 | fn_recursive_add_processors(name, module, processors) 600 | 601 | return processors 602 | 603 | def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): 604 | r""" 605 | Sets the attention processor to use to compute attention. 606 | 607 | Parameters: 608 | processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): 609 | The instantiated processor class or a dictionary of processor classes that will be set as the processor 610 | for **all** `Attention` layers. 611 | 612 | If `processor` is a dict, the key needs to define the path to the corresponding cross attention 613 | processor. This is strongly recommended when setting trainable attention processors. 614 | 615 | """ 616 | count = len(self.attn_processors.keys()) 617 | 618 | if isinstance(processor, dict) and len(processor) != count: 619 | raise ValueError( 620 | f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" 621 | f" number of attention layers: {count}. Please make sure to pass {count} processor classes." 622 | ) 623 | 624 | def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): 625 | if hasattr(module, "set_processor"): 626 | if not isinstance(processor, dict): 627 | module.set_processor(processor) 628 | else: 629 | module.set_processor(processor.pop(f"{name}.processor")) 630 | 631 | for sub_name, child in module.named_children(): 632 | fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) 633 | 634 | for name, module in self.named_children(): 635 | fn_recursive_attn_processor(name, module, processor) 636 | 637 | def set_default_attn_processor(self): 638 | """ 639 | Disables custom attention processors and sets the default attention implementation. 640 | """ 641 | self.set_attn_processor(AttnProcessor()) 642 | 643 | def set_attention_slice(self, slice_size): 644 | r""" 645 | Enable sliced attention computation. 646 | 647 | When this option is enabled, the attention module splits the input tensor in slices to compute attention in 648 | several steps. This is useful for saving some memory in exchange for a small decrease in speed. 649 | 650 | Args: 651 | slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): 652 | When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If 653 | `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is 654 | provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` 655 | must be a multiple of `slice_size`. 
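
        Example (an illustrative sketch; `unet` stands for an instance of this model):

        ```py
        >>> unet.set_attention_slice("auto")  # compute attention in two steps per layer
        >>> unet.set_attention_slice(2)       # or pass an explicit slice size that divides `attention_head_dim`
        ```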
656 | """ 657 | sliceable_head_dims = [] 658 | 659 | def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): 660 | if hasattr(module, "set_attention_slice"): 661 | sliceable_head_dims.append(module.sliceable_head_dim) 662 | 663 | for child in module.children(): 664 | fn_recursive_retrieve_sliceable_dims(child) 665 | 666 | # retrieve number of attention layers 667 | for module in self.children(): 668 | fn_recursive_retrieve_sliceable_dims(module) 669 | 670 | num_sliceable_layers = len(sliceable_head_dims) 671 | 672 | if slice_size == "auto": 673 | # half the attention head size is usually a good trade-off between 674 | # speed and memory 675 | slice_size = [dim // 2 for dim in sliceable_head_dims] 676 | elif slice_size == "max": 677 | # make smallest slice possible 678 | slice_size = num_sliceable_layers * [1] 679 | 680 | slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size 681 | 682 | if len(slice_size) != len(sliceable_head_dims): 683 | raise ValueError( 684 | f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" 685 | f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." 686 | ) 687 | 688 | for i in range(len(slice_size)): 689 | size = slice_size[i] 690 | dim = sliceable_head_dims[i] 691 | if size is not None and size > dim: 692 | raise ValueError(f"size {size} has to be smaller or equal to {dim}.") 693 | 694 | # Recursively walk through all the children. 695 | # Any children which exposes the set_attention_slice method 696 | # gets the message 697 | def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): 698 | if hasattr(module, "set_attention_slice"): 699 | module.set_attention_slice(slice_size.pop()) 700 | 701 | for child in module.children(): 702 | fn_recursive_set_attention_slice(child, slice_size) 703 | 704 | reversed_slice_size = list(reversed(slice_size)) 705 | for module in self.children(): 706 | fn_recursive_set_attention_slice(module, reversed_slice_size) 707 | 708 | def _set_gradient_checkpointing(self, module, value=False): 709 | if hasattr(module, "gradient_checkpointing"): 710 | module.gradient_checkpointing = value 711 | 712 | def forward( 713 | self, 714 | sample: torch.FloatTensor, 715 | timestep: Union[torch.Tensor, float, int], 716 | encoder_hidden_states: torch.Tensor, 717 | class_labels: Optional[torch.Tensor] = None, 718 | timestep_cond: Optional[torch.Tensor] = None, 719 | attention_mask: Optional[torch.Tensor] = None, 720 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 721 | added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, 722 | down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, 723 | mid_block_additional_residual: Optional[torch.Tensor] = None, 724 | up_block_additional_residual: Optional[torch.Tensor] = None, 725 | encoder_attention_mask: Optional[torch.Tensor] = None, 726 | return_dict: bool = True, 727 | return_hidden_states: bool = False, 728 | return_encoder_feature: bool = False, 729 | return_early: bool = False, 730 | down_bridge_residuals: Optional[Tuple[torch.Tensor]] = None, 731 | fusion_guidance_scale: Optional[torch.FloatTensor] = None, 732 | fusion_type: Optional[str] = 'ADD', 733 | adapter: Optional = None 734 | ) -> Union[UNet2DConditionOutput, Tuple]: 735 | r""" 736 | The [`UNet2DConditionModel`] forward method. 
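
        Besides the standard UNet arguments documented below, this adapted forward pass accepts
        bridge-related arguments: `down_bridge_residuals` are popped and added to the sample after
        each down block; `up_block_additional_residual` entries are consumed after the mid block and
        after each up block, blended as `sample + fusion_guidance_scale * (residual - sample)` when
        `fusion_guidance_scale` is given and simply added otherwise; `return_hidden_states` forces
        `return_dict=True` and additionally returns the outputs of all but the first up block; and
        `return_encoder_feature` collects the down-block features (returned immediately after the
        down pass when `return_early` is also set).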
737 | 738 | Args: 739 | sample (`torch.FloatTensor`): 740 | The noisy input tensor with the following shape `(batch, channel, height, width)`. 741 | timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. 742 | encoder_hidden_states (`torch.FloatTensor`): 743 | The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. 744 | encoder_attention_mask (`torch.Tensor`): 745 | A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If 746 | `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, 747 | which adds large negative values to the attention scores corresponding to "discard" tokens. 748 | return_dict (`bool`, *optional*, defaults to `True`): 749 | Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain 750 | tuple. 751 | cross_attention_kwargs (`dict`, *optional*): 752 | A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. 753 | added_cond_kwargs: (`dict`, *optional*): 754 | A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that 755 | are passed along to the UNet blocks. 756 | 757 | Returns: 758 | [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: 759 | If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise 760 | a `tuple` is returned where the first element is the sample tensor. 761 | """ 762 | # By default samples have to be AT least a multiple of the overall upsampling factor. 763 | # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). 764 | # However, the upsampling interpolation output size can be forced to fit any upsampling size 765 | # on the fly if necessary. 766 | ############## bridge usage ################## 767 | if return_hidden_states: 768 | hidden_states = [] 769 | return_dict = True 770 | ############## end of bridge usage ################## 771 | 772 | 773 | 774 | default_overall_up_factor = 2**self.num_upsamplers 775 | 776 | # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` 777 | forward_upsample_size = False 778 | upsample_size = None 779 | 780 | if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): 781 | logger.info("Forward upsample size to force interpolation output size.") 782 | forward_upsample_size = True 783 | 784 | # ensure attention_mask is a bias, and give it a singleton query_tokens dimension 785 | # expects mask of shape: 786 | # [batch, key_tokens] 787 | # adds singleton query_tokens dimension: 788 | # [batch, 1, key_tokens] 789 | # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: 790 | # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) 791 | # [batch * heads, query_tokens, key_tokens] (e.g. 
xformers or classic attn) 792 | if attention_mask is not None: 793 | # assume that mask is expressed as: 794 | # (1 = keep, 0 = discard) 795 | # convert mask into a bias that can be added to attention scores: 796 | # (keep = +0, discard = -10000.0) 797 | attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 798 | attention_mask = attention_mask.unsqueeze(1) 799 | 800 | # convert encoder_attention_mask to a bias the same way we do for attention_mask 801 | if encoder_attention_mask is not None: 802 | encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 803 | encoder_attention_mask = encoder_attention_mask.unsqueeze(1) 804 | 805 | # 0. center input if necessary 806 | if self.config.center_input_sample: 807 | sample = 2 * sample - 1.0 808 | 809 | # 1. time 810 | timesteps = timestep 811 | if not torch.is_tensor(timesteps): 812 | # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can 813 | # This would be a good case for the `match` statement (Python 3.10+) 814 | is_mps = sample.device.type == "mps" 815 | if isinstance(timestep, float): 816 | dtype = torch.float32 if is_mps else torch.float64 817 | else: 818 | dtype = torch.int32 if is_mps else torch.int64 819 | timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) 820 | elif len(timesteps.shape) == 0: 821 | timesteps = timesteps[None].to(sample.device) 822 | 823 | # broadcast to batch dimension in a way that's compatible with ONNX/Core ML 824 | timesteps = timesteps.expand(sample.shape[0]) 825 | 826 | t_emb = self.time_proj(timesteps) # 2, 320 827 | 828 | # `Timesteps` does not contain any weights and will always return f32 tensors 829 | # but time_embedding might actually be running in fp16. so we need to cast here. 830 | # there might be better ways to encapsulate this. 831 | t_emb = t_emb.to(dtype=sample.dtype) 832 | 833 | emb = self.time_embedding(t_emb, timestep_cond) 834 | 835 | aug_emb = None 836 | 837 | if self.class_embedding is not None: 838 | if class_labels is None: 839 | raise ValueError("class_labels should be provided when num_class_embeds > 0") 840 | 841 | if self.config.class_embed_type == "timestep": 842 | class_labels = self.time_proj(class_labels) 843 | 844 | # `Timesteps` does not contain any weights and will always return f32 tensors 845 | # there might be better ways to encapsulate this. 
846 | class_labels = class_labels.to(dtype=sample.dtype) 847 | 848 | class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) 849 | 850 | if self.config.class_embeddings_concat: 851 | emb = torch.cat([emb, class_emb], dim=-1) 852 | else: 853 | emb = emb + class_emb 854 | 855 | if self.config.addition_embed_type == "text": 856 | aug_emb = self.add_embedding(encoder_hidden_states) 857 | elif self.config.addition_embed_type == "text_image": 858 | # Kandinsky 2.1 - style 859 | if "image_embeds" not in added_cond_kwargs: 860 | raise ValueError( 861 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 862 | ) 863 | 864 | image_embs = added_cond_kwargs.get("image_embeds") 865 | text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) 866 | aug_emb = self.add_embedding(text_embs, image_embs) 867 | elif self.config.addition_embed_type == "text_time": 868 | # SDXL - style 869 | if "text_embeds" not in added_cond_kwargs: 870 | raise ValueError( 871 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" 872 | ) 873 | text_embeds = added_cond_kwargs.get("text_embeds") 874 | if "time_ids" not in added_cond_kwargs: 875 | raise ValueError( 876 | f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" 877 | ) 878 | time_ids = added_cond_kwargs.get("time_ids") 879 | time_embeds = self.add_time_proj(time_ids.flatten()) 880 | time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) 881 | 882 | add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) 883 | add_embeds = add_embeds.to(emb.dtype) 884 | aug_emb = self.add_embedding(add_embeds) 885 | elif self.config.addition_embed_type == "image": 886 | # Kandinsky 2.2 - style 887 | if "image_embeds" not in added_cond_kwargs: 888 | raise ValueError( 889 | f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" 890 | ) 891 | image_embs = added_cond_kwargs.get("image_embeds") 892 | aug_emb = self.add_embedding(image_embs) 893 | elif self.config.addition_embed_type == "image_hint": 894 | # Kandinsky 2.2 - style 895 | if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: 896 | raise ValueError( 897 | f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" 898 | ) 899 | image_embs = added_cond_kwargs.get("image_embeds") 900 | hint = added_cond_kwargs.get("hint") 901 | aug_emb, hint = self.add_embedding(image_embs, hint) 902 | sample = torch.cat([sample, hint], dim=1) 903 | 904 | emb = emb + aug_emb if aug_emb is not None else emb 905 | 906 | if self.time_embed_act is not None: 907 | emb = self.time_embed_act(emb) 908 | 909 | if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": 910 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) 911 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": 912 | # Kadinsky 2.1 - style 913 | if "image_embeds" not in added_cond_kwargs: 914 | raise ValueError( 915 | f"{self.__class__} has the config param 
`encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 916 | ) 917 | 918 | image_embeds = added_cond_kwargs.get("image_embeds") 919 | encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) 920 | elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": 921 | # Kandinsky 2.2 - style 922 | if "image_embeds" not in added_cond_kwargs: 923 | raise ValueError( 924 | f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" 925 | ) 926 | image_embeds = added_cond_kwargs.get("image_embeds") 927 | encoder_hidden_states = self.encoder_hid_proj(image_embeds) 928 | # 2. pre-process 929 | sample = self.conv_in(sample) 930 | 931 | # 2.5 GLIGEN position net 932 | if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: 933 | cross_attention_kwargs = cross_attention_kwargs.copy() 934 | gligen_args = cross_attention_kwargs.pop("gligen") 935 | cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} 936 | 937 | # 3. down 938 | 939 | if return_encoder_feature: 940 | encoder_feature = [] 941 | 942 | is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None 943 | is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None 944 | is_bridge_encoder = down_bridge_residuals is not None 945 | is_bridge = up_block_additional_residual is not None 946 | 947 | down_block_res_samples = (sample,) 948 | 949 | 950 | 951 | for downsample_block in self.down_blocks: 952 | if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: 953 | # For t2i-adapter CrossAttnDownBlock2D 954 | additional_residuals = {} 955 | if is_adapter and len(down_block_additional_residuals) > 0: 956 | additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) 957 | 958 | sample, res_samples = downsample_block( 959 | hidden_states=sample, 960 | temb=emb, 961 | encoder_hidden_states=encoder_hidden_states, 962 | attention_mask=attention_mask, 963 | cross_attention_kwargs=cross_attention_kwargs, 964 | encoder_attention_mask=encoder_attention_mask, 965 | **additional_residuals, 966 | ) 967 | 968 | if is_bridge_encoder and len(down_bridge_residuals) > 0: 969 | sample += down_bridge_residuals.pop(0) 970 | 971 | if return_encoder_feature: 972 | encoder_feature.append(sample) 973 | else: 974 | sample, res_samples = downsample_block(hidden_states=sample, temb=emb) 975 | 976 | if is_adapter and len(down_block_additional_residuals) > 0: 977 | sample += down_block_additional_residuals.pop(0) 978 | 979 | if is_bridge_encoder and len(down_bridge_residuals) > 0: 980 | sample += down_bridge_residuals.pop(0) 981 | 982 | down_block_res_samples += res_samples 983 | 984 | 985 | if is_controlnet: 986 | new_down_block_res_samples = () 987 | 988 | for down_block_res_sample, down_block_additional_residual in zip( 989 | down_block_res_samples, down_block_additional_residuals 990 | ): 991 | down_block_res_sample = down_block_res_sample + down_block_additional_residual 992 | new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) 993 | 994 | down_block_res_samples = new_down_block_res_samples 995 | 996 | if return_encoder_feature and return_early: 997 | return encoder_feature 998 | 999 | # 4. 
mid 1000 | if self.mid_block is not None: 1001 | sample = self.mid_block( 1002 | sample, 1003 | emb, 1004 | encoder_hidden_states=encoder_hidden_states, 1005 | attention_mask=attention_mask, 1006 | cross_attention_kwargs=cross_attention_kwargs, 1007 | encoder_attention_mask=encoder_attention_mask, 1008 | ) 1009 | 1010 | if is_controlnet: 1011 | sample = sample + mid_block_additional_residual 1012 | 1013 | ################# bridge usage ################# 1014 | 1015 | if is_bridge: 1016 | if fusion_guidance_scale is not None: 1017 | sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample) 1018 | else: 1019 | sample += up_block_additional_residual.pop(0) 1020 | ################# end of bridge usage ################# 1021 | # 5. up 1022 | 1023 | for i, upsample_block in enumerate(self.up_blocks): 1024 | is_final_block = i == len(self.up_blocks) - 1 1025 | 1026 | res_samples = down_block_res_samples[-len(upsample_block.resnets) :] 1027 | down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] 1028 | 1029 | # if we have not reached the final block and need to forward the 1030 | # upsample size, we do it here 1031 | if not is_final_block and forward_upsample_size: 1032 | upsample_size = down_block_res_samples[-1].shape[2:] 1033 | 1034 | if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: 1035 | sample = upsample_block( 1036 | hidden_states=sample, 1037 | temb=emb, 1038 | res_hidden_states_tuple=res_samples, 1039 | encoder_hidden_states=encoder_hidden_states, 1040 | cross_attention_kwargs=cross_attention_kwargs, 1041 | upsample_size=upsample_size, 1042 | attention_mask=attention_mask, 1043 | encoder_attention_mask=encoder_attention_mask, 1044 | ) 1045 | else: 1046 | sample = upsample_block( 1047 | hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size 1048 | ) 1049 | 1050 | 1051 | ################# bridge usage ################# 1052 | if is_bridge and len(up_block_additional_residual) > 0: 1053 | if fusion_guidance_scale is not None: 1054 | sample = sample + fusion_guidance_scale * (up_block_additional_residual.pop(0) - sample) 1055 | else: 1056 | sample += up_block_additional_residual.pop(0) 1057 | 1058 | if return_hidden_states and i > 0: 1059 | # Collect last three up blk in SD1.5 1060 | hidden_states.append(sample) 1061 | ################# end of bridge usage ################# 1062 | 1063 | # 6. 
post-process 1064 | if self.conv_norm_out: 1065 | sample = self.conv_norm_out(sample) 1066 | sample = self.conv_act(sample) 1067 | sample = self.conv_out(sample) 1068 | 1069 | if not return_dict: 1070 | return (sample,) 1071 | 1072 | return UNet2DConditionOutput(sample=sample, hidden_states=hidden_states if return_hidden_states else None, 1073 | encoder_feature=encoder_feature if return_encoder_feature else None) 1074 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import imageio 3 | import numpy as np 4 | from typing import Union 5 | 6 | import torch 7 | import torchvision 8 | import torch.distributed as dist 9 | 10 | from safetensors import safe_open 11 | from tqdm import tqdm 12 | from einops import rearrange 13 | from model.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint 14 | # from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers -------------------------------------------------------------------------------- /pipeline/__pycache__/pipeline_sd_xl_adapter.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/pipeline/__pycache__/pipeline_sd_xl_adapter_controlnet_img2img.cpython-310.pyc -------------------------------------------------------------------------------- /pipeline/pipeline_sd_xl_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2023 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
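# NOTE: this module defines `StableDiffusionXLAdapterPipeline`, which registers two full sets of
# diffusion components -- the SDXL ones (`vae`, `text_encoder`, `text_encoder_2`, `tokenizer`,
# `tokenizer_2`, `unet`, `scheduler`) and their SD1.5 counterparts (`*_sd1_5`) -- together with an
# `Adapter_XL` bridge from `model/adapter.py`, so that both UNets can be run during sampling.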
14 | 15 | import inspect 16 | import os 17 | import PIL 18 | import numpy as np 19 | import torch.nn.functional as F 20 | from PIL import Image 21 | from typing import Any, Callable, Dict, List, Optional, Tuple, Union 22 | 23 | import torch 24 | from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer 25 | 26 | from diffusers.image_processor import VaeImageProcessor 27 | from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin 28 | # from diffusers.models import AutoencoderKL, UNet2DConditionModel 29 | from diffusers.models import AutoencoderKL 30 | from model.unet_adapter import UNet2DConditionModel 31 | 32 | from diffusers.models.attention_processor import ( 33 | AttnProcessor2_0, 34 | LoRAAttnProcessor2_0, 35 | LoRAXFormersAttnProcessor, 36 | XFormersAttnProcessor, 37 | ) 38 | from diffusers.schedulers import KarrasDiffusionSchedulers 39 | from diffusers.utils import ( 40 | is_accelerate_available, 41 | is_accelerate_version, 42 | is_invisible_watermark_available, 43 | logging, 44 | randn_tensor, 45 | replace_example_docstring, 46 | ) 47 | from diffusers.pipelines.pipeline_utils import DiffusionPipeline 48 | from diffusers.pipelines.stable_diffusion_xl import StableDiffusionXLPipelineOutput 49 | from model.adapter import Adapter_XL 50 | 51 | 52 | if is_invisible_watermark_available(): 53 | from diffusers.pipelines.stable_diffusion_xl.watermark import StableDiffusionXLWatermarker 54 | 55 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 56 | 57 | EXAMPLE_DOC_STRING = """ 58 | Examples: 59 | ```py 60 | >>> import torch 61 | >>> from diffusers import StableDiffusionXLPipeline 62 | 63 | >>> pipe = StableDiffusionXLPipeline.from_pretrained( 64 | ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 65 | ... ) 66 | >>> pipe = pipe.to("cuda") 67 | 68 | >>> prompt = "a photo of an astronaut riding a horse on mars" 69 | >>> image = pipe(prompt).images[0] 70 | ``` 71 | """ 72 | 73 | 74 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg 75 | def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): 76 | """ 77 | Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and 78 | Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 79 | """ 80 | std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) 81 | std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) 82 | # rescale the results from guidance (fixes overexposure) 83 | noise_pred_rescaled = noise_cfg * (std_text / std_cfg) 84 | # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images 85 | noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg 86 | return noise_cfg 87 | 88 | 89 | class StableDiffusionXLAdapterPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): 90 | r""" 91 | Pipeline for text-to-image generation using Stable Diffusion XL. 92 | 93 | This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the 94 | library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) 
95 | 96 | In addition the pipeline inherits the following loading methods: 97 | - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] 98 | - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] 99 | 100 | as well as the following saving methods: 101 | - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] 102 | 103 | Args: 104 | vae ([`AutoencoderKL`]): 105 | Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. 106 | text_encoder ([`CLIPTextModel`]): 107 | Frozen text-encoder. Stable Diffusion XL uses the text portion of 108 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically 109 | the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 110 | text_encoder_2 ([` CLIPTextModelWithProjection`]): 111 | Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of 112 | [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), 113 | specifically the 114 | [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) 115 | variant. 116 | tokenizer (`CLIPTokenizer`): 117 | Tokenizer of class 118 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 119 | tokenizer_2 (`CLIPTokenizer`): 120 | Second Tokenizer of class 121 | [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). 122 | unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. 123 | scheduler ([`SchedulerMixin`]): 124 | A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of 125 | [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. 
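
    Example (a minimal construction sketch; the checkpoint ids, subfolder names and the bare
    `Adapter_XL()` call below are illustrative assumptions rather than the repository's
    prescribed setup):

    ```py
    >>> from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
    >>> from diffusers import AutoencoderKL, DDPMScheduler
    >>> from model.unet_adapter import UNet2DConditionModel
    >>> from model.adapter import Adapter_XL

    >>> base_xl = "stabilityai/stable-diffusion-xl-base-1.0"
    >>> base_15 = "runwayml/stable-diffusion-v1-5"  # assumed SD1.5 checkpoint

    >>> pipe = StableDiffusionXLAdapterPipeline(
    ...     vae=AutoencoderKL.from_pretrained(base_xl, subfolder="vae"),
    ...     text_encoder=CLIPTextModel.from_pretrained(base_xl, subfolder="text_encoder"),
    ...     text_encoder_2=CLIPTextModelWithProjection.from_pretrained(base_xl, subfolder="text_encoder_2"),
    ...     tokenizer=CLIPTokenizer.from_pretrained(base_xl, subfolder="tokenizer"),
    ...     tokenizer_2=CLIPTokenizer.from_pretrained(base_xl, subfolder="tokenizer_2"),
    ...     unet=UNet2DConditionModel.from_pretrained(base_xl, subfolder="unet"),
    ...     scheduler=DDPMScheduler.from_pretrained(base_xl, subfolder="scheduler"),
    ...     vae_sd1_5=AutoencoderKL.from_pretrained(base_15, subfolder="vae"),
    ...     text_encoder_sd1_5=CLIPTextModel.from_pretrained(base_15, subfolder="text_encoder"),
    ...     tokenizer_sd1_5=CLIPTokenizer.from_pretrained(base_15, subfolder="tokenizer"),
    ...     unet_sd1_5=UNet2DConditionModel.from_pretrained(base_15, subfolder="unet"),
    ...     scheduler_sd1_5=DDPMScheduler.from_pretrained(base_15, subfolder="scheduler"),
    ...     adapter=Adapter_XL(),  # assumed default construction; trained adapter weights are loaded separately
    ... ).to("cuda")
    >>> image = pipe("a photo of an astronaut riding a horse on mars").images[0]
    ```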
126 | """ 127 | 128 | def __init__( 129 | self, 130 | vae: AutoencoderKL, 131 | text_encoder: CLIPTextModel, 132 | text_encoder_2: CLIPTextModelWithProjection, 133 | tokenizer: CLIPTokenizer, 134 | tokenizer_2: CLIPTokenizer, 135 | unet: UNet2DConditionModel, 136 | scheduler: KarrasDiffusionSchedulers, 137 | vae_sd1_5: AutoencoderKL, 138 | text_encoder_sd1_5: CLIPTextModel, 139 | tokenizer_sd1_5: CLIPTokenizer, 140 | unet_sd1_5: UNet2DConditionModel, 141 | scheduler_sd1_5: KarrasDiffusionSchedulers, 142 | adapter: Adapter_XL, 143 | force_zeros_for_empty_prompt: bool = True, 144 | add_watermarker: Optional[bool] = None, 145 | ): 146 | super().__init__() 147 | 148 | self.register_modules( 149 | vae=vae, 150 | text_encoder=text_encoder, 151 | text_encoder_2=text_encoder_2, 152 | tokenizer=tokenizer, 153 | tokenizer_2=tokenizer_2, 154 | unet=unet, 155 | scheduler=scheduler, 156 | vae_sd1_5=vae_sd1_5, 157 | text_encoder_sd1_5=text_encoder_sd1_5, 158 | tokenizer_sd1_5=tokenizer_sd1_5, 159 | unet_sd1_5=unet_sd1_5, 160 | scheduler_sd1_5=scheduler_sd1_5, 161 | adapter=adapter, 162 | ) 163 | self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) 164 | self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) 165 | self.vae_scale_factor_sd1_5 = 2 ** (len(self.vae_sd1_5.config.block_out_channels) - 1) 166 | self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) 167 | self.default_sample_size = self.unet.config.sample_size 168 | self.image_processor_sd1_5 = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor_sd1_5) 169 | 170 | add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() 171 | 172 | if add_watermarker: 173 | self.watermark = StableDiffusionXLWatermarker() 174 | else: 175 | self.watermark = None 176 | 177 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing 178 | def enable_vae_slicing(self): 179 | r""" 180 | Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to 181 | compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. 182 | """ 183 | self.vae.enable_slicing() 184 | 185 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing 186 | def disable_vae_slicing(self): 187 | r""" 188 | Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to 189 | computing decoding in one step. 190 | """ 191 | self.vae.disable_slicing() 192 | 193 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling 194 | def enable_vae_tiling(self): 195 | r""" 196 | Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to 197 | compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow 198 | processing larger images. 199 | """ 200 | self.vae.enable_tiling() 201 | 202 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling 203 | def disable_vae_tiling(self): 204 | r""" 205 | Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to 206 | computing decoding in one step. 
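
        Example (an illustrative sketch; `pipe` stands for an instance of this pipeline):

        ```py
        >>> pipe.enable_vae_tiling()   # decode very large images tile by tile
        >>> # ... run generation ...
        >>> pipe.disable_vae_tiling()  # back to single-pass decoding
        ```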
207 | """ 208 | self.vae.disable_tiling() 209 | 210 | def enable_model_cpu_offload(self, gpu_id=0): 211 | r""" 212 | Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared 213 | to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` 214 | method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with 215 | `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 216 | """ 217 | if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): 218 | from accelerate import cpu_offload_with_hook 219 | else: 220 | raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") 221 | 222 | device = torch.device(f"cuda:{gpu_id}") 223 | 224 | 225 | self.to("cpu", silence_dtype_warnings=True) 226 | torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) 227 | 228 | model_sequence = ( 229 | [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] 230 | ) 231 | model_sequence.extend([self.unet, self.vae]) 232 | 233 | model_sequence.extend([self.unet_sd1_5, self.vae_sd1_5, self.text_encoder_sd1_5, self.adapter]) 234 | 235 | hook = None 236 | for cpu_offloaded_model in model_sequence: 237 | _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) 238 | 239 | # We'll offload the last model manually. 240 | self.final_offload_hook = hook 241 | 242 | def encode_prompt( 243 | self, 244 | prompt: str, 245 | prompt_2: Optional[str] = None, 246 | device: Optional[torch.device] = None, 247 | num_images_per_prompt: int = 1, 248 | do_classifier_free_guidance: bool = True, 249 | negative_prompt: Optional[str] = None, 250 | negative_prompt_2: Optional[str] = None, 251 | prompt_embeds: Optional[torch.FloatTensor] = None, 252 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 253 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 254 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 255 | lora_scale: Optional[float] = None, 256 | ): 257 | r""" 258 | Encodes the prompt into text encoder hidden states. 259 | 260 | Args: 261 | prompt (`str` or `List[str]`, *optional*): 262 | prompt to be encoded 263 | prompt_2 (`str` or `List[str]`, *optional*): 264 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is 265 | used in both text-encoders 266 | device: (`torch.device`): 267 | torch device 268 | num_images_per_prompt (`int`): 269 | number of images that should be generated per prompt 270 | do_classifier_free_guidance (`bool`): 271 | whether to use classifier free guidance or not 272 | negative_prompt (`str` or `List[str]`, *optional*): 273 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 274 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 275 | less than `1`). 276 | negative_prompt_2 (`str` or `List[str]`, *optional*): 277 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 278 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders 279 | prompt_embeds (`torch.FloatTensor`, *optional*): 280 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 
If not 281 | provided, text embeddings will be generated from `prompt` input argument. 282 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 283 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 284 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 285 | argument. 286 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 287 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 288 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 289 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 290 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 291 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 292 | input argument. 293 | lora_scale (`float`, *optional*): 294 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 295 | """ 296 | device = device or self._execution_device 297 | 298 | # set lora scale so that monkey patched LoRA 299 | # function of text encoder can correctly access it 300 | if lora_scale is not None and isinstance(self, LoraLoaderMixin): 301 | self._lora_scale = lora_scale 302 | 303 | if prompt is not None and isinstance(prompt, str): 304 | batch_size = 1 305 | elif prompt is not None and isinstance(prompt, list): 306 | batch_size = len(prompt) 307 | else: 308 | batch_size = prompt_embeds.shape[0] 309 | 310 | # Define tokenizers and text encoders 311 | tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] 312 | text_encoders = ( 313 | [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] 314 | ) 315 | 316 | if prompt_embeds is None: 317 | prompt_2 = prompt_2 or prompt 318 | # textual inversion: procecss multi-vector tokens if necessary 319 | prompt_embeds_list = [] 320 | prompts = [prompt, prompt_2] 321 | for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): 322 | if isinstance(self, TextualInversionLoaderMixin): 323 | prompt = self.maybe_convert_prompt(prompt, tokenizer) 324 | 325 | text_inputs = tokenizer( 326 | prompt, 327 | padding="max_length", 328 | max_length=tokenizer.model_max_length, 329 | truncation=True, 330 | return_tensors="pt", 331 | ) 332 | 333 | text_input_ids = text_inputs.input_ids 334 | untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids 335 | 336 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 337 | text_input_ids, untruncated_ids 338 | ): 339 | removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1: -1]) 340 | logger.warning( 341 | "The following part of your input was truncated because CLIP can only handle sequences up to" 342 | f" {tokenizer.model_max_length} tokens: {removed_text}" 343 | ) 344 | 345 | prompt_embeds = text_encoder( 346 | text_input_ids.to(device), 347 | output_hidden_states=True, 348 | ) 349 | 350 | # We are only ALWAYS interested in the pooled output of the final text encoder 351 | pooled_prompt_embeds = prompt_embeds[0] 352 | prompt_embeds = prompt_embeds.hidden_states[-2] 353 | 354 | prompt_embeds_list.append(prompt_embeds) 355 | 356 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 357 | 358 | # get unconditional embeddings for classifier free 
guidance 359 | zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt 360 | if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: 361 | negative_prompt_embeds = torch.zeros_like(prompt_embeds) 362 | negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) 363 | elif do_classifier_free_guidance and negative_prompt_embeds is None: 364 | negative_prompt = negative_prompt or "" 365 | negative_prompt_2 = negative_prompt_2 or negative_prompt 366 | 367 | uncond_tokens: List[str] 368 | if prompt is not None and type(prompt) is not type(negative_prompt): 369 | raise TypeError( 370 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 371 | f" {type(prompt)}." 372 | ) 373 | elif isinstance(negative_prompt, str): 374 | uncond_tokens = [negative_prompt, negative_prompt_2] 375 | elif batch_size != len(negative_prompt): 376 | raise ValueError( 377 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 378 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 379 | " the batch size of `prompt`." 380 | ) 381 | else: 382 | uncond_tokens = [negative_prompt, negative_prompt_2] 383 | 384 | negative_prompt_embeds_list = [] 385 | for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): 386 | if isinstance(self, TextualInversionLoaderMixin): 387 | negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) 388 | 389 | max_length = prompt_embeds.shape[1] 390 | uncond_input = tokenizer( 391 | negative_prompt, 392 | padding="max_length", 393 | max_length=max_length, 394 | truncation=True, 395 | return_tensors="pt", 396 | ) 397 | 398 | negative_prompt_embeds = text_encoder( 399 | uncond_input.input_ids.to(device), 400 | output_hidden_states=True, 401 | ) 402 | # We are only ALWAYS interested in the pooled output of the final text encoder 403 | negative_pooled_prompt_embeds = negative_prompt_embeds[0] 404 | negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] 405 | 406 | negative_prompt_embeds_list.append(negative_prompt_embeds) 407 | 408 | negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) 409 | 410 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) 411 | bs_embed, seq_len, _ = prompt_embeds.shape 412 | # duplicate text embeddings for each generation per prompt, using mps friendly method 413 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 414 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 415 | 416 | if do_classifier_free_guidance: 417 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 418 | seq_len = negative_prompt_embeds.shape[1] 419 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) 420 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 421 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 422 | 423 | pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( 424 | bs_embed * num_images_per_prompt, -1 425 | ) 426 | if do_classifier_free_guidance: 427 | negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( 428 | bs_embed * num_images_per_prompt, -1 429 | ) 430 | 431 | 
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds 432 | 433 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs 434 | def prepare_extra_step_kwargs(self, generator, eta): 435 | # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature 436 | # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 437 | # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 438 | # and should be between [0, 1] 439 | 440 | accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) 441 | extra_step_kwargs = {} 442 | if accepts_eta: 443 | extra_step_kwargs["eta"] = eta 444 | 445 | # check if the scheduler accepts generator 446 | accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) 447 | if accepts_generator: 448 | extra_step_kwargs["generator"] = generator 449 | return extra_step_kwargs 450 | 451 | def check_inputs( 452 | self, 453 | prompt, 454 | prompt_2, 455 | height, 456 | width, 457 | callback_steps, 458 | negative_prompt=None, 459 | negative_prompt_2=None, 460 | prompt_embeds=None, 461 | negative_prompt_embeds=None, 462 | pooled_prompt_embeds=None, 463 | negative_pooled_prompt_embeds=None, 464 | ): 465 | if height % 8 != 0 or width % 8 != 0: 466 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 467 | 468 | if (callback_steps is None) or ( 469 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 470 | ): 471 | raise ValueError( 472 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 473 | f" {type(callback_steps)}." 474 | ) 475 | 476 | if prompt is not None and prompt_embeds is not None: 477 | raise ValueError( 478 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 479 | " only forward one of the two." 480 | ) 481 | elif prompt_2 is not None and prompt_embeds is not None: 482 | raise ValueError( 483 | f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 484 | " only forward one of the two." 485 | ) 486 | elif prompt is None and prompt_embeds is None: 487 | raise ValueError( 488 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 489 | ) 490 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 491 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 492 | elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): 493 | raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") 494 | 495 | if negative_prompt is not None and negative_prompt_embeds is not None: 496 | raise ValueError( 497 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 498 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 499 | ) 500 | elif negative_prompt_2 is not None and negative_prompt_embeds is not None: 501 | raise ValueError( 502 | f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" 503 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 
504 | ) 505 | 506 | if prompt_embeds is not None and negative_prompt_embeds is not None: 507 | if prompt_embeds.shape != negative_prompt_embeds.shape: 508 | raise ValueError( 509 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 510 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 511 | f" {negative_prompt_embeds.shape}." 512 | ) 513 | 514 | if prompt_embeds is not None and pooled_prompt_embeds is None: 515 | raise ValueError( 516 | "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 517 | ) 518 | 519 | if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: 520 | raise ValueError( 521 | "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." 522 | ) 523 | 524 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents 525 | def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 526 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) 527 | if isinstance(generator, list) and len(generator) != batch_size: 528 | raise ValueError( 529 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 530 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 531 | ) 532 | 533 | if latents is None: 534 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 535 | else: 536 | latents = latents.to(device) 537 | 538 | # scale the initial noise by the standard deviation required by the scheduler 539 | latents = latents * self.scheduler.init_noise_sigma 540 | return latents 541 | 542 | def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): 543 | add_time_ids = list(original_size + crops_coords_top_left + target_size) 544 | 545 | passed_add_embed_dim = ( 546 | self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim 547 | ) 548 | expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features 549 | 550 | if expected_add_embed_dim != passed_add_embed_dim: 551 | raise ValueError( 552 | f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." 
553 | ) 554 | 555 | add_time_ids = torch.tensor([add_time_ids], dtype=dtype) 556 | return add_time_ids 557 | 558 | # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae 559 | def upcast_vae(self): 560 | dtype = self.vae.dtype 561 | self.vae.to(dtype=torch.float32) 562 | use_torch_2_0_or_xformers = isinstance( 563 | self.vae.decoder.mid_block.attentions[0].processor, 564 | ( 565 | AttnProcessor2_0, 566 | XFormersAttnProcessor, 567 | LoRAXFormersAttnProcessor, 568 | LoRAAttnProcessor2_0, 569 | ), 570 | ) 571 | # if xformers or torch_2_0 is used attention block does not need 572 | # to be in float32 which can save lots of memory 573 | if use_torch_2_0_or_xformers: 574 | self.vae.post_quant_conv.to(dtype) 575 | self.vae.decoder.conv_in.to(dtype) 576 | self.vae.decoder.mid_block.to(dtype) 577 | 578 | @torch.no_grad() 579 | @replace_example_docstring(EXAMPLE_DOC_STRING) 580 | def __call__( 581 | self, 582 | prompt: Union[str, List[str]] = None, 583 | prompt_2: Optional[Union[str, List[str]]] = None, 584 | prompt_sd1_5: Optional[Union[str, List[str]]] = None, 585 | height: Optional[int] = None, 586 | width: Optional[int] = None, 587 | height_sd1_5: Optional[int] = None, 588 | width_sd1_5: Optional[int] = None, 589 | num_inference_steps: int = 50, 590 | denoising_end: Optional[float] = None, 591 | guidance_scale: float = 5.0, 592 | negative_prompt: Optional[Union[str, List[str]]] = None, 593 | negative_prompt_2: Optional[Union[str, List[str]]] = None, 594 | num_images_per_prompt: Optional[int] = 1, 595 | eta: float = 0.0, 596 | generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, 597 | latents: Optional[torch.FloatTensor] = None, 598 | latents_sd1_5: Optional[torch.FloatTensor] = None, 599 | prompt_embeds: Optional[torch.FloatTensor] = None, 600 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 601 | prompt_embeds_sd1_5: Optional[torch.FloatTensor] = None, 602 | negative_prompt_embeds_sd1_5: Optional[torch.FloatTensor] = None, 603 | pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 604 | negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, 605 | output_type: Optional[str] = "pil", 606 | return_dict: bool = True, 607 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, 608 | callback_steps: int = 1, 609 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, 610 | guidance_rescale: float = 0.0, 611 | original_size: Optional[Tuple[int, int]] = None, 612 | crops_coords_top_left: Tuple[int, int] = (0, 0), 613 | target_size: Optional[Tuple[int, int]] = None, 614 | adapter_condition_scale: Optional[float] = 1.0, 615 | adapter_guidance_start: Union[float, List[float]] = 0.5, 616 | denoising_start: Optional[float] = None, 617 | adapter_type: str = "de", # "de", "en", "en_de" 618 | fusion_guidance_scale: Optional[float] = None, 619 | enable_time_step: bool = False 620 | ): 621 | r""" 622 | Function invoked when calling the pipeline for generation. 623 | 624 | Args: 625 | prompt (`str` or `List[str]`, *optional*): 626 | The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. 627 | instead. 628 | prompt_2 (`str` or `List[str]`, *optional*): 629 | The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is 630 | used in both text-encoders 631 | height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 632 | The height in pixels of the generated image. 633 | width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): 634 | The width in pixels of the generated image. 635 | num_inference_steps (`int`, *optional*, defaults to 50): 636 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the 637 | expense of slower inference. 638 | denoising_end (`float`, *optional*): 639 | When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be 640 | completed before it is intentionally prematurely terminated. As a result, the returned sample will 641 | still retain a substantial amount of noise as determined by the discrete timesteps selected by the 642 | scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a 643 | "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image 644 | Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) 645 | guidance_scale (`float`, *optional*, defaults to 5.0): 646 | Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). 647 | `guidance_scale` is defined as `w` of equation 2. of [Imagen 648 | Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 649 | 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, 650 | usually at the expense of lower image quality. 651 | negative_prompt (`str` or `List[str]`, *optional*): 652 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 653 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 654 | less than `1`). 655 | negative_prompt_2 (`str` or `List[str]`, *optional*): 656 | The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and 657 | `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders 658 | num_images_per_prompt (`int`, *optional*, defaults to 1): 659 | The number of images to generate per prompt. 660 | eta (`float`, *optional*, defaults to 0.0): 661 | Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to 662 | [`schedulers.DDIMScheduler`], will be ignored for others. 663 | generator (`torch.Generator` or `List[torch.Generator]`, *optional*): 664 | One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) 665 | to make generation deterministic. 666 | latents (`torch.FloatTensor`, *optional*): 667 | Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image 668 | generation. Can be used to tweak the same generation with different prompts. If not provided, a latents 669 | tensor will ge generated by sampling using the supplied random `generator`. 670 | prompt_embeds (`torch.FloatTensor`, *optional*): 671 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 672 | provided, text embeddings will be generated from `prompt` input argument. 673 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 674 | Pre-generated negative text embeddings. 
Can be used to easily tweak text inputs, *e.g.* prompt 675 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 676 | argument. 677 | pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 678 | Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. 679 | If not provided, pooled text embeddings will be generated from `prompt` input argument. 680 | negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): 681 | Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 682 | weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` 683 | input argument. 684 | output_type (`str`, *optional*, defaults to `"pil"`): 685 | The output format of the generated image. Choose between 686 | [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. 687 | return_dict (`bool`, *optional*, defaults to `True`): 688 | Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead 689 | of a plain tuple. 690 | callback (`Callable`, *optional*): 691 | A function that will be called every `callback_steps` steps during inference. The function will be 692 | called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. 693 | callback_steps (`int`, *optional*, defaults to 1): 694 | The frequency at which the `callback` function will be called. If not specified, the callback will be 695 | called at every step. 696 | cross_attention_kwargs (`dict`, *optional*): 697 | A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under 698 | `self.processor` in 699 | [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 700 | guidance_rescale (`float`, *optional*, defaults to 0.0): 701 | Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are 702 | Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of 703 | [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 704 | Guidance rescale factor should fix overexposure when using zero terminal SNR. 705 | original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 706 | If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. 707 | `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as 708 | explained in section 2.2 of 709 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 710 | crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): 711 | `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position 712 | `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting 713 | `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of 714 | [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 715 | target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): 716 | For most cases, `target_size` should be set to the desired height and width of the generated image. If 717 | not specified it will default to `(width, height)`.
Part of SDXL's micro-conditioning as explained in 718 | section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 719 | 720 | Examples: 721 | 722 | Returns: 723 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: 724 | [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a 725 | `tuple`. When returning a tuple, the first element is a list with the generated images. 726 | """ 727 | # 0. Default height and width to unet 728 | height = height or self.default_sample_size * self.vae_scale_factor 729 | width = width or self.default_sample_size * self.vae_scale_factor 730 | 731 | height_sd1_5 = height_sd1_5 or self.default_sample_size_sd1_5 * self.vae_scale_factor_sd1_5 732 | width_sd1_5 = width_sd1_5 or self.default_sample_size_sd1_5 * self.vae_scale_factor_sd1_5 733 | 734 | original_size = original_size or (height, width) 735 | target_size = target_size or (height, width) 736 | 737 | # 1. Check inputs. Raise error if not correct 738 | self.check_inputs( 739 | prompt, 740 | prompt_2, 741 | height, 742 | width, 743 | callback_steps, 744 | negative_prompt, 745 | negative_prompt_2, 746 | prompt_embeds, 747 | negative_prompt_embeds, 748 | pooled_prompt_embeds, 749 | negative_pooled_prompt_embeds, 750 | ) 751 | 752 | self.check_inputs_sd1_5( 753 | prompt if prompt_sd1_5 is None else prompt_sd1_5, height_sd1_5, width_sd1_5, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds 754 | ) 755 | 756 | # 2. Define call parameters 757 | if prompt is not None and isinstance(prompt, str): 758 | batch_size = 1 759 | elif prompt is not None and isinstance(prompt, list): 760 | batch_size = len(prompt) 761 | else: 762 | batch_size = prompt_embeds.shape[0] 763 | 764 | device = torch.device('cuda') 765 | 766 | # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) 767 | # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` 768 | # corresponds to doing no classifier free guidance. 769 | do_classifier_free_guidance = guidance_scale > 1.0 770 | 771 | # 3. Encode input prompt 772 | text_encoder_lora_scale = ( 773 | cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None 774 | ) 775 | ( 776 | prompt_embeds, 777 | negative_prompt_embeds, 778 | pooled_prompt_embeds, 779 | negative_pooled_prompt_embeds, 780 | ) = self.encode_prompt( 781 | prompt=prompt, 782 | prompt_2=prompt_2, 783 | device=device, 784 | num_images_per_prompt=num_images_per_prompt, 785 | do_classifier_free_guidance=do_classifier_free_guidance, 786 | negative_prompt=negative_prompt, 787 | negative_prompt_2=negative_prompt_2, 788 | prompt_embeds=prompt_embeds, 789 | negative_prompt_embeds=negative_prompt_embeds, 790 | pooled_prompt_embeds=pooled_prompt_embeds, 791 | negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, 792 | lora_scale=text_encoder_lora_scale, 793 | ) 794 | 795 | prompt_embeds_sd1_5 = self._encode_prompt_sd1_5( 796 | prompt if prompt_sd1_5 is None else prompt_sd1_5, 797 | device, 798 | num_images_per_prompt, 799 | do_classifier_free_guidance, 800 | negative_prompt, 801 | prompt_embeds=prompt_embeds_sd1_5, 802 | negative_prompt_embeds=negative_prompt_embeds_sd1_5, 803 | lora_scale=text_encoder_lora_scale, 804 | ) 805 | # todo: implement prompt_embeds for SD1.5 806 | 807 | # 4. 
Prepare timesteps 808 | self.scheduler_sd1_5.set_timesteps(num_inference_steps, device=device) 809 | timesteps_sd1_5 = self.scheduler_sd1_5.timesteps 810 | num_inference_steps_sd1_5 = num_inference_steps 811 | 812 | self.scheduler.set_timesteps(num_inference_steps, device=device) 813 | 814 | timesteps, num_inference_steps = self.get_timesteps( 815 | num_inference_steps, adapter_guidance_start, device, denoising_start=denoising_start 816 | ) 817 | latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) 818 | 819 | # 5. Prepare latent variables 820 | num_channels_latents = self.unet.config.in_channels 821 | latents = self.prepare_latents( 822 | batch_size * num_images_per_prompt, 823 | num_channels_latents, 824 | height, 825 | width, 826 | prompt_embeds.dtype, 827 | device, 828 | generator, 829 | latents, 830 | ) 831 | 832 | num_channels_latents_sd1_5 = self.unet_sd1_5.config.in_channels 833 | latents_sd1_5 = self.prepare_latents_sd1_5( 834 | batch_size * num_images_per_prompt, 835 | num_channels_latents_sd1_5, 836 | height_sd1_5, 837 | width_sd1_5, 838 | prompt_embeds_sd1_5.dtype, 839 | device, 840 | generator, 841 | latents_sd1_5, 842 | ) 843 | 844 | # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline 845 | extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) 846 | 847 | # 7. Prepare added time ids & embeddings 848 | add_text_embeds = pooled_prompt_embeds 849 | add_time_ids = self._get_add_time_ids( 850 | original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype 851 | ) 852 | 853 | if do_classifier_free_guidance: 854 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) 855 | add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) 856 | add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) 857 | 858 | prompt_embeds = prompt_embeds.to(device) 859 | add_text_embeds = add_text_embeds.to(device) 860 | add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) 861 | 862 | # 8. 
Denoising loop 863 | num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) 864 | 865 | # 8.1 Apply denoising_end 866 | if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: 867 | discrete_timestep_cutoff = int( 868 | round( 869 | self.scheduler.config.num_train_timesteps 870 | - (denoising_end * self.scheduler.config.num_train_timesteps) 871 | ) 872 | ) 873 | num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) 874 | timesteps = timesteps[:num_inference_steps] 875 | 876 | latents_sd1_5_prior = latents_sd1_5.clone() 877 | 878 | with self.progress_bar(total=num_inference_steps_sd1_5) as progress_bar: 879 | for i, t in enumerate(timesteps_sd1_5): 880 | 881 | #################### SD1.5 forward #################### 882 | t_sd1_5 = timesteps_sd1_5[i] 883 | 884 | latent_model_input = torch.cat([latents_sd1_5_prior] * 2) if do_classifier_free_guidance else latents_sd1_5_prior 885 | latent_model_input = self.scheduler_sd1_5.scale_model_input(latent_model_input, t_sd1_5) 886 | 887 | # predict the noise residual 888 | unet_output = self.unet_sd1_5( 889 | latent_model_input, 890 | t_sd1_5, 891 | encoder_hidden_states=prompt_embeds_sd1_5, 892 | cross_attention_kwargs=cross_attention_kwargs, 893 | return_hidden_states=False 894 | ) 895 | noise_pred = unet_output.sample 896 | 897 | # perform guidance 898 | if do_classifier_free_guidance: 899 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 900 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 901 | 902 | if do_classifier_free_guidance and guidance_rescale > 0.0: 903 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 904 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 905 | 906 | # compute the previous noisy sample x_t -> x_t-1 907 | latents_sd1_5_prior = self.scheduler_sd1_5.step(noise_pred, t_sd1_5, latents_sd1_5_prior, **extra_step_kwargs, return_dict=False)[0] 908 | 909 | #################### End of SD1.5 forward #################### 910 | 911 | # call the callback, if provided 912 | if i == len(timesteps_sd1_5) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler_sd1_5.order == 0): 913 | progress_bar.update() 914 | 915 | add_noise = True if denoising_start is None else False 916 | latents = self.prepare_xl_latents_from_sd_1_5(latents_sd1_5_prior, latent_timestep, batch_size, 917 | num_images_per_prompt, height, width, prompt_embeds.dtype, device, 918 | generator=generator, add_noise=add_noise) 919 | latents_sd1_5 = self.sd1_5_add_noise(latents_sd1_5_prior, latent_timestep, generator, device, 920 | prompt_embeds.dtype) 921 | 922 | with self.progress_bar(total=num_inference_steps) as progress_bar: 923 | for i, t in enumerate(timesteps): 924 | # expand the latents if we are doing classifier free guidance 925 | 926 | #################### SD1.5 forward #################### 927 | t_sd1_5 = timesteps_sd1_5[i] 928 | # feed the SD1.5 latents to the SD1.5 UNet when classifier-free guidance is disabled 929 | latent_model_input = torch.cat([latents_sd1_5] * 2) if do_classifier_free_guidance else latents_sd1_5 930 | latent_model_input = self.scheduler_sd1_5.scale_model_input(latent_model_input, t_sd1_5) 931 | 932 | unet_output = self.unet_sd1_5( 933 | latent_model_input, 934 | t_sd1_5, 935 | encoder_hidden_states=prompt_embeds_sd1_5, 936 | cross_attention_kwargs=cross_attention_kwargs, 937 | return_hidden_states=True, 938 | return_encoder_feature=True 939 | ) 940 | noise_pred = unet_output.sample 941 | hidden_states =
unet_output.hidden_states 941 | encoder_feature = unet_output.encoder_feature 942 | 943 | 944 | # adapter forward 945 | if adapter_type == "de": 946 | down_bridge_residuals = None 947 | up_block_additional_residual = self.adapter(hidden_states, t=t_sd1_5 if enable_time_step else None) 948 | for xx in range(len(up_block_additional_residual)): 949 | up_block_additional_residual[xx] = up_block_additional_residual[xx] * adapter_condition_scale 950 | elif adapter_type == "en": 951 | up_block_additional_residual = None 952 | down_bridge_residuals = self.adapter(encoder_feature) 953 | for xx in range(len(down_bridge_residuals)): 954 | down_bridge_residuals[xx] = down_bridge_residuals[xx] * adapter_condition_scale 955 | else: 956 | dict = self.adapter(x=hidden_states, enc_x=encoder_feature) 957 | down_bridge_residuals = dict['encoder_features'] 958 | up_block_additional_residual = dict['decoder_features'] 959 | for xx in range(len(up_block_additional_residual)): 960 | up_block_additional_residual[xx] = up_block_additional_residual[xx] * adapter_condition_scale 961 | for xx in range(len(down_bridge_residuals)): 962 | down_bridge_residuals[xx] = down_bridge_residuals[xx] * adapter_condition_scale 963 | 964 | # perform guidance 965 | if do_classifier_free_guidance: 966 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 967 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 968 | 969 | if do_classifier_free_guidance and guidance_rescale > 0.0: 970 | # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf 971 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 972 | 973 | # compute the previous noisy sample x_t -> x_t-1 974 | 975 | latents_sd1_5 = self.scheduler_sd1_5.step(noise_pred, t_sd1_5, latents_sd1_5, **extra_step_kwargs, 976 | return_dict=False)[0] 977 | 978 | #################### End of SD1.5 forward #################### 979 | 980 | latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents 981 | 982 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) 983 | 984 | # predict the noise residual 985 | added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} 986 | 987 | noise_pred = self.unet( 988 | latent_model_input, 989 | t, 990 | encoder_hidden_states=prompt_embeds, 991 | cross_attention_kwargs=cross_attention_kwargs, 992 | added_cond_kwargs=added_cond_kwargs, 993 | up_block_additional_residual=up_block_additional_residual, 994 | down_bridge_residuals=down_bridge_residuals, 995 | return_dict=False, 996 | fusion_guidance_scale=fusion_guidance_scale 997 | )[0] 998 | 999 | # perform guidance 1000 | if do_classifier_free_guidance: 1001 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 1002 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 1003 | 1004 | if do_classifier_free_guidance and guidance_rescale > 0.0: 1005 | # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf 1006 | noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) 1007 | 1008 | # compute the previous noisy sample x_t -> x_t-1 1009 | latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] 1010 | 1011 | # call the callback, if provided 1012 | if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): 1013 | progress_bar.update() 1014 | if callback is not None and i % callback_steps == 0: 1015 | callback(i, t, latents) 1016 | 1017 | # make sure the VAE is in float32 mode, as it overflows in float16 1018 | if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: 1019 | self.upcast_vae() 1020 | latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) 1021 | 1022 | if not output_type == "latent": 1023 | image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] 1024 | else: 1025 | image = latents 1026 | return StableDiffusionXLPipelineOutput(images=image) 1027 | 1028 | # apply watermark if available 1029 | if self.watermark is not None: 1030 | image = self.watermark.apply_watermark(image) 1031 | 1032 | image = self.image_processor.postprocess(image, output_type=output_type) 1033 | 1034 | # Offload last model to CPU 1035 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 1036 | self.final_offload_hook.offload() 1037 | 1038 | if not return_dict: 1039 | return (image,) 1040 | 1041 | return StableDiffusionXLPipelineOutput(images=image) 1042 | 1043 | # Overrride to properly handle the loading and unloading of the additional text encoder. 1044 | def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): 1045 | # We could have accessed the unet config from `lora_state_dict()` too. We pass 1046 | # it here explicitly to be able to tell that it's coming from an SDXL 1047 | # pipeline. 1048 | state_dict, network_alphas = self.lora_state_dict( 1049 | pretrained_model_name_or_path_or_dict, 1050 | unet_config=self.unet.config, 1051 | **kwargs, 1052 | ) 1053 | self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) 1054 | 1055 | text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} 1056 | if len(text_encoder_state_dict) > 0: 1057 | self.load_lora_into_text_encoder( 1058 | text_encoder_state_dict, 1059 | network_alphas=network_alphas, 1060 | text_encoder=self.text_encoder, 1061 | prefix="text_encoder", 1062 | lora_scale=self.lora_scale, 1063 | ) 1064 | 1065 | text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." 
in k} 1066 | if len(text_encoder_2_state_dict) > 0: 1067 | self.load_lora_into_text_encoder( 1068 | text_encoder_2_state_dict, 1069 | network_alphas=network_alphas, 1070 | text_encoder=self.text_encoder_2, 1071 | prefix="text_encoder_2", 1072 | lora_scale=self.lora_scale, 1073 | ) 1074 | 1075 | @classmethod 1076 | def save_lora_weights( 1077 | self, 1078 | save_directory: Union[str, os.PathLike], 1079 | unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, 1080 | text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, 1081 | text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, 1082 | is_main_process: bool = True, 1083 | weight_name: str = None, 1084 | save_function: Callable = None, 1085 | safe_serialization: bool = True, 1086 | ): 1087 | state_dict = {} 1088 | 1089 | def pack_weights(layers, prefix): 1090 | layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers 1091 | layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} 1092 | return layers_state_dict 1093 | 1094 | state_dict.update(pack_weights(unet_lora_layers, "unet")) 1095 | 1096 | if text_encoder_lora_layers and text_encoder_2_lora_layers: 1097 | state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) 1098 | state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) 1099 | 1100 | self.write_lora_layers( 1101 | state_dict=state_dict, 1102 | save_directory=save_directory, 1103 | is_main_process=is_main_process, 1104 | weight_name=weight_name, 1105 | save_function=save_function, 1106 | safe_serialization=safe_serialization, 1107 | ) 1108 | 1109 | def _remove_text_encoder_monkey_patch(self): 1110 | self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) 1111 | self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) 1112 | 1113 | def _encode_prompt_sd1_5( 1114 | self, 1115 | prompt, 1116 | device, 1117 | num_images_per_prompt, 1118 | do_classifier_free_guidance, 1119 | negative_prompt=None, 1120 | prompt_embeds: Optional[torch.FloatTensor] = None, 1121 | negative_prompt_embeds: Optional[torch.FloatTensor] = None, 1122 | lora_scale: Optional[float] = None, 1123 | ): 1124 | r""" 1125 | Encodes the prompt into text encoder hidden states. 1126 | 1127 | Args: 1128 | prompt (`str` or `List[str]`, *optional*): 1129 | prompt to be encoded 1130 | device: (`torch.device`): 1131 | torch device 1132 | num_images_per_prompt (`int`): 1133 | number of images that should be generated per prompt 1134 | do_classifier_free_guidance (`bool`): 1135 | whether to use classifier free guidance or not 1136 | negative_prompt (`str` or `List[str]`, *optional*): 1137 | The prompt or prompts not to guide the image generation. If not defined, one has to pass 1138 | `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is 1139 | less than `1`). 1140 | prompt_embeds (`torch.FloatTensor`, *optional*): 1141 | Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not 1142 | provided, text embeddings will be generated from `prompt` input argument. 1143 | negative_prompt_embeds (`torch.FloatTensor`, *optional*): 1144 | Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt 1145 | weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input 1146 | argument. 
1147 | lora_scale (`float`, *optional*): 1148 | A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 1149 | """ 1150 | # set lora scale so that monkey patched LoRA 1151 | # function of text encoder can correctly access it 1152 | if lora_scale is not None and isinstance(self, LoraLoaderMixin): 1153 | self._lora_scale = lora_scale 1154 | 1155 | if prompt is not None and isinstance(prompt, str): 1156 | batch_size = 1 1157 | elif prompt is not None and isinstance(prompt, list): 1158 | batch_size = len(prompt) 1159 | else: 1160 | batch_size = prompt_embeds.shape[0] 1161 | 1162 | if prompt_embeds is None: 1163 | # textual inversion: procecss multi-vector tokens if necessary 1164 | if isinstance(self, TextualInversionLoaderMixin): 1165 | prompt = self.maybe_convert_prompt(prompt, self.tokenizer_sd1_5) 1166 | 1167 | text_inputs = self.tokenizer_sd1_5( 1168 | prompt, 1169 | padding="max_length", 1170 | max_length=self.tokenizer_sd1_5.model_max_length, 1171 | truncation=True, 1172 | return_tensors="pt", 1173 | ) 1174 | text_input_ids = text_inputs.input_ids 1175 | untruncated_ids = self.tokenizer_sd1_5(prompt, padding="longest", return_tensors="pt").input_ids 1176 | 1177 | if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( 1178 | text_input_ids, untruncated_ids 1179 | ): 1180 | removed_text = self.tokenizer_sd1_5.batch_decode( 1181 | untruncated_ids[:, self.tokenizer_sd1_5.model_max_length - 1: -1] 1182 | ) 1183 | logger.warning( 1184 | "The following part of your input was truncated because CLIP can only handle sequences up to" 1185 | f" {self.tokenizer_sd1_5.model_max_length} tokens: {removed_text}" 1186 | ) 1187 | 1188 | if hasattr(self.text_encoder_sd1_5.config, 1189 | "use_attention_mask") and self.text_encoder_sd1_5.config.use_attention_mask: 1190 | attention_mask = text_inputs.attention_mask.to(device) 1191 | else: 1192 | attention_mask = None 1193 | 1194 | prompt_embeds = self.text_encoder_sd1_5( 1195 | text_input_ids.to(device), 1196 | attention_mask=attention_mask, 1197 | ) 1198 | prompt_embeds = prompt_embeds[0] 1199 | 1200 | if self.text_encoder_sd1_5 is not None: 1201 | prompt_embeds_dtype = self.text_encoder_sd1_5.dtype 1202 | elif self.unet_sd1_5 is not None: 1203 | prompt_embeds_dtype = self.unet_sd1_5.dtype 1204 | else: 1205 | prompt_embeds_dtype = prompt_embeds.dtype 1206 | 1207 | prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 1208 | 1209 | bs_embed, seq_len, _ = prompt_embeds.shape 1210 | # duplicate text embeddings for each generation per prompt, using mps friendly method 1211 | prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) 1212 | prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) 1213 | 1214 | # get unconditional embeddings for classifier free guidance 1215 | if do_classifier_free_guidance and negative_prompt_embeds is None: 1216 | uncond_tokens: List[str] 1217 | if negative_prompt is None: 1218 | uncond_tokens = [""] * batch_size 1219 | elif prompt is not None and type(prompt) is not type(negative_prompt): 1220 | raise TypeError( 1221 | f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" 1222 | f" {type(prompt)}." 
1223 | ) 1224 | elif isinstance(negative_prompt, str): 1225 | uncond_tokens = [negative_prompt] 1226 | elif batch_size != len(negative_prompt): 1227 | raise ValueError( 1228 | f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" 1229 | f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" 1230 | " the batch size of `prompt`." 1231 | ) 1232 | else: 1233 | uncond_tokens = negative_prompt 1234 | 1235 | # textual inversion: procecss multi-vector tokens if necessary 1236 | if isinstance(self, TextualInversionLoaderMixin): 1237 | uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer_sd1_5) 1238 | 1239 | max_length = prompt_embeds.shape[1] 1240 | uncond_input = self.tokenizer_sd1_5( 1241 | uncond_tokens, 1242 | padding="max_length", 1243 | max_length=max_length, 1244 | truncation=True, 1245 | return_tensors="pt", 1246 | ) 1247 | 1248 | if hasattr(self.text_encoder_sd1_5.config, 1249 | "use_attention_mask") and self.text_encoder_sd1_5.config.use_attention_mask: 1250 | attention_mask = uncond_input.attention_mask.to(device) 1251 | else: 1252 | attention_mask = None 1253 | 1254 | negative_prompt_embeds = self.text_encoder_sd1_5( 1255 | uncond_input.input_ids.to(device), 1256 | attention_mask=attention_mask, 1257 | ) 1258 | negative_prompt_embeds = negative_prompt_embeds[0] 1259 | 1260 | if do_classifier_free_guidance: 1261 | # duplicate unconditional embeddings for each generation per prompt, using mps friendly method 1262 | seq_len = negative_prompt_embeds.shape[1] 1263 | 1264 | negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) 1265 | 1266 | negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) 1267 | negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) 1268 | 1269 | # For classifier free guidance, we need to do two forward passes. 1270 | # Here we concatenate the unconditional and text embeddings into a single batch 1271 | # to avoid doing two forward passes 1272 | prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) 1273 | 1274 | return prompt_embeds 1275 | 1276 | def decode_latents_sd1_5(self, latents): 1277 | warnings.warn( 1278 | "The decode_latents method is deprecated and will be removed in a future version. Please" 1279 | " use VaeImageProcessor instead", 1280 | FutureWarning, 1281 | ) 1282 | latents = 1 / self.vae_sd1_5.config.scaling_factor * latents 1283 | image = self.vae_sd1_5.decode(latents, return_dict=False)[0] 1284 | image = (image / 2 + 0.5).clamp(0, 1) 1285 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 1286 | image = image.cpu().permute(0, 2, 3, 1).float().numpy() 1287 | return image 1288 | 1289 | def check_inputs_sd1_5( 1290 | self, 1291 | prompt, 1292 | height, 1293 | width, 1294 | callback_steps, 1295 | negative_prompt=None, 1296 | prompt_embeds=None, 1297 | negative_prompt_embeds=None, 1298 | ): 1299 | if height % 8 != 0 or width % 8 != 0: 1300 | raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") 1301 | 1302 | if (callback_steps is None) or ( 1303 | callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) 1304 | ): 1305 | raise ValueError( 1306 | f"`callback_steps` has to be a positive integer but is {callback_steps} of type" 1307 | f" {type(callback_steps)}." 
1308 | ) 1309 | 1310 | if prompt is not None and prompt_embeds is not None: 1311 | raise ValueError( 1312 | f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" 1313 | " only forward one of the two." 1314 | ) 1315 | elif prompt is None and prompt_embeds is None: 1316 | raise ValueError( 1317 | "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 1318 | ) 1319 | elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): 1320 | raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") 1321 | 1322 | if negative_prompt is not None and negative_prompt_embeds is not None: 1323 | raise ValueError( 1324 | f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" 1325 | f" {negative_prompt_embeds}. Please make sure to only forward one of the two." 1326 | ) 1327 | 1328 | if prompt_embeds is not None and negative_prompt_embeds is not None: 1329 | if prompt_embeds.shape != negative_prompt_embeds.shape: 1330 | raise ValueError( 1331 | "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" 1332 | f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" 1333 | f" {negative_prompt_embeds.shape}." 1334 | ) 1335 | 1336 | def prepare_xl_latents_from_sd_1_5( 1337 | self, latent, timestep, batch_size, num_images_per_prompt, height, width, dtype, device, generator=None, 1338 | add_noise=True 1339 | ): 1340 | # sd1.5 latent -> img 1341 | image = self.vae_sd1_5.decode(latent / self.vae_sd1_5.config.scaling_factor, return_dict=False)[0] 1342 | do_denormalize = [True] * image.shape[0] 1343 | image = self.image_processor_sd1_5.postprocess(image, output_type='pil', do_denormalize=do_denormalize)[0] 1344 | image = image.resize((width, height)) 1345 | # image.save('./test_img/image_sd1_5.jpg') 1346 | # input() 1347 | 1348 | if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): 1349 | raise ValueError( 1350 | f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" 1351 | ) 1352 | 1353 | # Offload text encoder if `enable_model_cpu_offload` was enabled 1354 | if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: 1355 | self.text_encoder_2.to("cpu") 1356 | torch.cuda.empty_cache() 1357 | 1358 | image = self.image_processor.preprocess(image) 1359 | 1360 | image = image.to(device=device, dtype=dtype) 1361 | 1362 | batch_size = batch_size * num_images_per_prompt 1363 | 1364 | if image.shape[1] == 4: 1365 | init_latents = image 1366 | 1367 | else: 1368 | # make sure the VAE is in float32 mode, as it overflows in float16 1369 | if self.vae.config.force_upcast: 1370 | image = image.float() 1371 | self.vae.to(dtype=torch.float32) 1372 | 1373 | if isinstance(generator, list) and len(generator) != batch_size: 1374 | raise ValueError( 1375 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 1376 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
1377 | ) 1378 | 1379 | elif isinstance(generator, list): 1380 | init_latents = [ 1381 | self.vae.encode(image[i: i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) 1382 | ] 1383 | init_latents = torch.cat(init_latents, dim=0) 1384 | else: 1385 | init_latents = self.vae.encode(image).latent_dist.sample(generator) 1386 | 1387 | if self.vae.config.force_upcast: 1388 | self.vae.to(dtype) 1389 | 1390 | init_latents = init_latents.to(dtype) 1391 | init_latents = self.vae.config.scaling_factor * init_latents 1392 | 1393 | if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: 1394 | # expand init_latents for batch_size 1395 | additional_image_per_prompt = batch_size // init_latents.shape[0] 1396 | init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) 1397 | elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: 1398 | raise ValueError( 1399 | f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." 1400 | ) 1401 | else: 1402 | init_latents = torch.cat([init_latents], dim=0) 1403 | 1404 | if add_noise: 1405 | shape = init_latents.shape 1406 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 1407 | # get latents 1408 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) 1409 | 1410 | latents = init_latents 1411 | 1412 | return latents 1413 | 1414 | def sd1_5_add_noise(self, init_latents, timestep, generator, device, dtype): 1415 | shape = init_latents.shape 1416 | noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 1417 | # get latents 1418 | init_latents = self.scheduler.add_noise(init_latents, noise, timestep) 1419 | 1420 | image = self.vae_sd1_5.decode(init_latents / self.vae_sd1_5.config.scaling_factor, return_dict=False)[0] 1421 | do_denormalize = [True] * image.shape[0] 1422 | image = self.image_processor_sd1_5.postprocess(image, output_type='pil', do_denormalize=do_denormalize)[0] 1423 | # image.save(f'./test_img/noisy_image_sd1_5_{int(timestep)}.jpg') 1424 | 1425 | return init_latents 1426 | 1427 | def get_timesteps(self, num_inference_steps, strength, device, denoising_start=None): 1428 | # get the original timestep using init_timestep 1429 | if denoising_start is None: 1430 | init_timestep = min(int(num_inference_steps * strength), num_inference_steps) 1431 | t_start = max(num_inference_steps - init_timestep, 0) 1432 | else: 1433 | t_start = 0 1434 | 1435 | timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] 1436 | 1437 | # Strength is irrelevant if we directly request a timestep to start at; 1438 | # that is, strength is determined by the denoising_start instead. 
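        # Worked example of the branch below (assuming the scheduler's default of 1000 training timesteps):
        # with denoising_start=0.8 the cutoff is round(1000 - 0.8 * 1000) = 200, so only timesteps strictly
        # below 200 survive the filter and this pipeline denoises roughly the final 20% of the schedule.
        # With denoising_start=None, the strength path above is used instead, e.g. num_inference_steps=50
        # and strength=0.5 give t_start=25, so the SDXL stage runs the last 25 scheduler steps.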
1439 | if denoising_start is not None: 1440 | discrete_timestep_cutoff = int( 1441 | round( 1442 | self.scheduler.config.num_train_timesteps 1443 | - (denoising_start * self.scheduler.config.num_train_timesteps) 1444 | ) 1445 | ) 1446 | timesteps = list(filter(lambda ts: ts < discrete_timestep_cutoff, timesteps)) 1447 | return torch.tensor(timesteps), len(timesteps) 1448 | 1449 | return timesteps, num_inference_steps - t_start 1450 | 1451 | def prepare_latents_sd1_5(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): 1452 | shape = (batch_size, num_channels_latents, height // self.vae_scale_factor_sd1_5, width // self.vae_scale_factor_sd1_5) 1453 | if isinstance(generator, list) and len(generator) != batch_size: 1454 | raise ValueError( 1455 | f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" 1456 | f" size of {batch_size}. Make sure the batch size matches the length of the generators." 1457 | ) 1458 | 1459 | if latents is None: 1460 | latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 1461 | else: 1462 | latents = latents.to(device) 1463 | 1464 | # scale the initial noise by the standard deviation required by the scheduler 1465 | latents = latents * self.scheduler_sd1_5.init_noise_sigma 1466 | return latents 1467 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate == 0.18.0 2 | controlnet_aux == 0.0.7 3 | opencv_python_headless == 4.8.0.76 4 | dataclasses == 0.6 5 | diffusers == 0.20.0 6 | einops == 0.4.1 7 | huggingface_hub == 0.17.2 8 | imageio == 2.26.0 9 | matplotlib == 3.7.1 10 | numpy == 1.23.3 11 | safetensors == 0.3.3 12 | tqdm == 4.64.1 13 | transformers == 4.25.1 14 | Pillow == 10.2.0 -------------------------------------------------------------------------------- /scripts/__pycache__/inference_controlnet.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_controlnet.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/inference_ctrlnet_tile.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_ctrlnet_tile.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/inference_lora.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/inference_lora.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/__pycache__/utils.cpython-310.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/showlab/X-Adapter/e7348cb1b83e039dec8fe6f664266f1594e5103f/scripts/__pycache__/utils.cpython-310.pyc -------------------------------------------------------------------------------- /scripts/inference_controlnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 
import json 3 | import os 4 | import numpy as np 5 | import cv2 6 | import matplotlib 7 | from tqdm import tqdm 8 | from diffusers import DiffusionPipeline 9 | from diffusers import DPMSolverMultistepScheduler 10 | from diffusers.utils import load_image 11 | from torch import Generator 12 | from PIL import Image 13 | from packaging import version 14 | 15 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig 16 | 17 | import diffusers 18 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, T2IAdapter 19 | from diffusers.optimization import get_scheduler 20 | from diffusers.training_utils import EMAModel 21 | from diffusers.utils import check_min_version, deprecate, is_wandb_available 22 | from diffusers.utils.import_utils import is_xformers_available 23 | 24 | from model.unet_adapter import UNet2DConditionModel 25 | from model.adapter import Adapter_XL 26 | from pipeline.pipeline_sd_xl_adapter_controlnet import StableDiffusionXLAdapterControlnetPipeline 27 | from controlnet_aux import MidasDetector, CannyDetector 28 | 29 | from scripts.utils import str2float 30 | 31 | 32 | def import_model_class_from_model_name_or_path( 33 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 34 | ): 35 | text_encoder_config = PretrainedConfig.from_pretrained( 36 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 37 | ) 38 | model_class = text_encoder_config.architectures[0] 39 | 40 | if model_class == "CLIPTextModel": 41 | from transformers import CLIPTextModel 42 | 43 | return CLIPTextModel 44 | elif model_class == "CLIPTextModelWithProjection": 45 | from transformers import CLIPTextModelWithProjection 46 | 47 | return CLIPTextModelWithProjection 48 | else: 49 | raise ValueError(f"{model_class} is not supported.") 50 | 51 | 52 | def inference_controlnet(args): 53 | device = 'cuda' 54 | weight_dtype = torch.float16 55 | 56 | controlnet_condition_scale_list = str2float(args.controlnet_condition_scale_list) 57 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list) 58 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list) 59 | 60 | path = args.base_path 61 | path_sdxl = args.sdxl_path 62 | path_vae_sdxl = args.path_vae_sdxl 63 | adapter_path = args.adapter_checkpoint 64 | 65 | if args.condition_type == "canny": 66 | controlnet_path = args.controlnet_canny_path 67 | canny = CannyDetector() 68 | elif args.condition_type == "depth": 69 | controlnet_path = args.controlnet_depth_path # todo: haven't defined in args 70 | depth = MidasDetector.from_pretrained("lllyasviel/Annotators") 71 | else: 72 | raise NotImplementedError("not implemented yet") 73 | 74 | prompt = args.prompt 75 | if args.prompt_sd1_5 is None: 76 | prompt_sd1_5 = prompt 77 | else: 78 | prompt_sd1_5 = args.prompt_sd1_5 79 | 80 | if args.negative_prompt is None: 81 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 82 | else: 83 | negative_prompt = 
args.negative_prompt 84 | 85 | torch.set_grad_enabled(False) 86 | torch.backends.cudnn.benchmark = True 87 | 88 | # load controlnet 89 | controlnet = ControlNetModel.from_pretrained( 90 | controlnet_path, torch_dtype=weight_dtype 91 | ) 92 | print('successfully load controlnet') 93 | 94 | input_image = Image.open(args.input_image_path) 95 | # input_image = input_image.resize((512, 512), Image.LANCZOS) 96 | input_image = input_image.resize((args.width_sd1_5, args.height_sd1_5), Image.LANCZOS) 97 | if args.condition_type == "canny": 98 | control_image = canny(input_image) 99 | control_image.save(f'{args.save_path}/{prompt[:10]}_canny_condition.png') 100 | elif args.condition_type == "depth": 101 | control_image = depth(input_image) 102 | control_image.save(f'{args.save_path}/{prompt[:10]}_depth_condition.png') 103 | 104 | # load adapter 105 | adapter = Adapter_XL() 106 | ckpt = torch.load(adapter_path) 107 | adapter.load_state_dict(ckpt) 108 | adapter.to(weight_dtype) 109 | print('successfully load adapter') 110 | # load SD1.5 111 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained( 112 | path, subfolder="scheduler" 113 | ) 114 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained( 115 | path, subfolder="tokenizer", revision=None, torch_dtype=weight_dtype 116 | ) 117 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained( 118 | path, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype 119 | ) 120 | vae_sd1_5 = AutoencoderKL.from_pretrained( 121 | path, subfolder="vae", revision=None, torch_dtype=weight_dtype 122 | ) 123 | unet_sd1_5 = UNet2DConditionModel.from_pretrained( 124 | path, subfolder="unet", revision=None, torch_dtype=weight_dtype 125 | ) 126 | print('successfully load SD1.5') 127 | # load SDXL 128 | tokenizer_one = AutoTokenizer.from_pretrained( 129 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False, torch_dtype=weight_dtype 130 | ) 131 | tokenizer_two = AutoTokenizer.from_pretrained( 132 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False, torch_dtype=weight_dtype 133 | ) 134 | # import correct text encoder classes 135 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 136 | path_sdxl, None 137 | ) 138 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 139 | path_sdxl, None, subfolder="text_encoder_2" 140 | ) 141 | # Load scheduler and models 142 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler") 143 | text_encoder_one = text_encoder_cls_one.from_pretrained( 144 | path_sdxl, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype 145 | ) 146 | text_encoder_two = text_encoder_cls_two.from_pretrained( 147 | path_sdxl, subfolder="text_encoder_2", revision=None, torch_dtype=weight_dtype 148 | ) 149 | vae = AutoencoderKL.from_pretrained( 150 | path_vae_sdxl, revision=None, torch_dtype=weight_dtype 151 | ) 152 | unet = UNet2DConditionModel.from_pretrained( 153 | path_sdxl, subfolder="unet", revision=None, torch_dtype=weight_dtype 154 | ) 155 | print('successfully load SDXL') 156 | 157 | 158 | if is_xformers_available(): 159 | import xformers 160 | 161 | xformers_version = version.parse(xformers.__version__) 162 | if xformers_version == version.parse("0.0.16"): 163 | logger.warn( 164 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." 
165 | ) 166 | unet.enable_xformers_memory_efficient_attention() 167 | unet_sd1_5.enable_xformers_memory_efficient_attention() 168 | controlnet.enable_xformers_memory_efficient_attention() 169 | 170 | 171 | with torch.inference_mode(): 172 | gen = Generator("cuda") 173 | gen.manual_seed(args.seed) 174 | pipe = StableDiffusionXLAdapterControlnetPipeline( 175 | vae=vae, 176 | text_encoder=text_encoder_one, 177 | text_encoder_2=text_encoder_two, 178 | tokenizer=tokenizer_one, 179 | tokenizer_2=tokenizer_two, 180 | unet=unet, 181 | scheduler=noise_scheduler, 182 | vae_sd1_5=vae_sd1_5, 183 | text_encoder_sd1_5=text_encoder_sd1_5, 184 | tokenizer_sd1_5=tokenizer_sd1_5, 185 | unet_sd1_5=unet_sd1_5, 186 | scheduler_sd1_5=noise_scheduler_sd1_5, 187 | adapter=adapter, 188 | controlnet=controlnet 189 | ) 190 | 191 | pipe.enable_model_cpu_offload() 192 | 193 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 194 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config) 195 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading" 196 | pipe.unet.to(device=device, dtype=torch.float16, memory_format=torch.channels_last) 197 | 198 | for i in range(args.iter_num): 199 | for controlnet_condition_scale in controlnet_condition_scale_list: 200 | for adapter_guidance_start in adapter_guidance_start_list: 201 | for adapter_condition_scale in adapter_condition_scale_list: 202 | img = \ 203 | pipe(prompt=prompt, negative_prompt=negative_prompt, prompt_sd1_5=prompt_sd1_5, 204 | width=args.width, height=args.height, height_sd1_5=args.height_sd1_5, 205 | width_sd1_5=args.width_sd1_5, image=control_image, 206 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, 207 | num_images_per_prompt=1, generator=gen, 208 | controlnet_conditioning_scale=controlnet_condition_scale, 209 | adapter_condition_scale=adapter_condition_scale, 210 | adapter_guidance_start=adapter_guidance_start).images[0] 211 | img.save( 212 | f"{args.save_path}/{prompt[:10]}_{i}_ccs_{controlnet_condition_scale:.2f}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png") 213 | 214 | print(f"results saved in {args.save_path}") 215 | -------------------------------------------------------------------------------- /scripts/inference_ctrlnet_tile.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | from diffusers import DiffusionPipeline 8 | from diffusers import DPMSolverMultistepScheduler 9 | from diffusers.utils import load_image 10 | from torch import Generator 11 | from safetensors.torch import load_file 12 | from PIL import Image 13 | from packaging import version 14 | from huggingface_hub import HfApi 15 | from pathlib import Path 16 | 17 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig 18 | 19 | import diffusers 20 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, T2IAdapter, StableDiffusionControlNetPipeline 21 | from diffusers.optimization import get_scheduler 22 | from diffusers.training_utils import EMAModel 23 | from diffusers.utils import check_min_version, deprecate, is_wandb_available 24 | from diffusers.utils.import_utils import is_xformers_available 25 | 26 | from model.unet_adapter import UNet2DConditionModel as UNet2DConditionModel_v2 27 | from model.adapter import Adapter_XL 
28 | from pipeline.pipeline_sd_xl_adapter_controlnet_img2img import StableDiffusionXLAdapterControlnetI2IPipeline 29 | from scripts.utils import str2float 30 | 31 | def import_model_class_from_model_name_or_path( 32 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 33 | ): 34 | text_encoder_config = PretrainedConfig.from_pretrained( 35 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 36 | ) 37 | model_class = text_encoder_config.architectures[0] 38 | 39 | if model_class == "CLIPTextModel": 40 | from transformers import CLIPTextModel 41 | 42 | return CLIPTextModel 43 | elif model_class == "CLIPTextModelWithProjection": 44 | from transformers import CLIPTextModelWithProjection 45 | 46 | return CLIPTextModelWithProjection 47 | else: 48 | raise ValueError(f"{model_class} is not supported.") 49 | 50 | 51 | def resize_for_condition_image(input_image: Image, resolution: int): 52 | input_image = input_image.convert("RGB") 53 | W, H = input_image.size 54 | k = float(resolution) / min(H, W) 55 | H *= k 56 | W *= k 57 | H = int(round(H / 64.0)) * 64 58 | W = int(round(W / 64.0)) * 64 59 | img = input_image.resize((W, H), resample=Image.LANCZOS) 60 | return img 61 | 62 | 63 | def inference_ctrlnet_tile(args): 64 | device = 'cuda' 65 | weight_dtype = torch.float16 66 | 67 | controlnet_condition_scale_list = str2float(args.controlnet_condition_scale_list) 68 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list) 69 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list) 70 | 71 | path = args.base_path 72 | path_sdxl = args.sdxl_path 73 | path_vae_sdxl = args.path_vae_sdxl 74 | adapter_path = args.adapter_checkpoint 75 | controlnet_path = args.controlnet_tile_path 76 | 77 | prompt = args.prompt 78 | if args.prompt_sd1_5 is None: 79 | prompt_sd1_5 = prompt 80 | else: 81 | prompt_sd1_5 = args.prompt_sd1_5 82 | 83 | if args.negative_prompt is None: 84 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 85 | else: 86 | negative_prompt = args.negative_prompt 87 | 88 | torch.set_grad_enabled(False) 89 | torch.backends.cudnn.benchmark = True 90 | 91 | # load controlnet 92 | controlnet = ControlNetModel.from_pretrained( 93 | controlnet_path, torch_dtype=weight_dtype 94 | ) 95 | 96 | source_image = Image.open(args.input_image_path) 97 | # control_image = resize_for_condition_image(source_image, 512) 98 | input_image = source_image.convert("RGB") 99 | control_image = input_image.resize((args.width_sd1_5, args.height_sd1_5), resample=Image.LANCZOS) 100 | 101 | print('successfully load controlnet') 102 | # load adapter 103 | adapter = Adapter_XL() 104 | ckpt = torch.load(adapter_path) 105 | adapter.load_state_dict(ckpt) 106 | adapter.to(weight_dtype) 107 | print('successfully load adapter') 108 | # load SD1.5 109 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained( 110 | path, subfolder="scheduler" 111 | ) 112 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained( 113 | path, subfolder="tokenizer", revision=None, 
torch_dtype=weight_dtype 114 | ) 115 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained( 116 | path, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype 117 | ) 118 | vae_sd1_5 = AutoencoderKL.from_pretrained( 119 | path, subfolder="vae", revision=None, torch_dtype=weight_dtype 120 | ) 121 | unet_sd1_5 = UNet2DConditionModel_v2.from_pretrained( 122 | path, subfolder="unet", revision=None, torch_dtype=weight_dtype 123 | ) 124 | print('successfully load SD1.5') 125 | # load SDXL 126 | tokenizer_one = AutoTokenizer.from_pretrained( 127 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False, torch_dtype=weight_dtype 128 | ) 129 | tokenizer_two = AutoTokenizer.from_pretrained( 130 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False, torch_dtype=weight_dtype 131 | ) 132 | # import correct text encoder classes 133 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 134 | path_sdxl, None 135 | ) 136 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 137 | path_sdxl, None, subfolder="text_encoder_2" 138 | ) 139 | # Load scheduler and models 140 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler") 141 | text_encoder_one = text_encoder_cls_one.from_pretrained( 142 | path_sdxl, subfolder="text_encoder", revision=None, torch_dtype=weight_dtype 143 | ) 144 | text_encoder_two = text_encoder_cls_two.from_pretrained( 145 | path_sdxl, subfolder="text_encoder_2", revision=None, torch_dtype=weight_dtype 146 | ) 147 | vae = AutoencoderKL.from_pretrained( 148 | path_vae_sdxl, revision=None, torch_dtype=weight_dtype 149 | ) 150 | unet = UNet2DConditionModel_v2.from_pretrained( 151 | path_sdxl, subfolder="unet", revision=None, torch_dtype=weight_dtype 152 | ) 153 | print('successfully load SDXL') 154 | 155 | if is_xformers_available(): 156 | import xformers 157 | 158 | xformers_version = version.parse(xformers.__version__) 159 | if xformers_version == version.parse("0.0.16"): 160 | print( 161 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
162 | ) 163 | unet.enable_xformers_memory_efficient_attention() 164 | unet_sd1_5.enable_xformers_memory_efficient_attention() 165 | controlnet.enable_xformers_memory_efficient_attention() 166 | 167 | with torch.inference_mode(): 168 | gen = Generator(device) 169 | gen.manual_seed(args.seed) 170 | pipe = StableDiffusionXLAdapterControlnetI2IPipeline( 171 | vae=vae, 172 | text_encoder=text_encoder_one, 173 | text_encoder_2=text_encoder_two, 174 | tokenizer=tokenizer_one, 175 | tokenizer_2=tokenizer_two, 176 | unet=unet, 177 | scheduler=noise_scheduler, 178 | vae_sd1_5=vae_sd1_5, 179 | text_encoder_sd1_5=text_encoder_sd1_5, 180 | tokenizer_sd1_5=tokenizer_sd1_5, 181 | unet_sd1_5=unet_sd1_5, 182 | scheduler_sd1_5=noise_scheduler_sd1_5, 183 | adapter=adapter, 184 | controlnet=controlnet 185 | ) 186 | pipe.enable_model_cpu_offload() 187 | 188 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 189 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config) 190 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading" 191 | pipe.unet.to(device=device, dtype=weight_dtype, memory_format=torch.channels_last) 192 | 193 | 194 | for i in range(args.iter_num): 195 | for controlnet_condition_scale in controlnet_condition_scale_list: 196 | for adapter_guidance_start in adapter_guidance_start_list: 197 | for adapter_condition_scale in adapter_condition_scale_list: 198 | img = \ 199 | pipe(prompt=prompt, negative_prompt=negative_prompt, prompt_sd1_5=prompt_sd1_5, 200 | width=args.width, height=args.height, height_sd1_5=args.height_sd1_5, 201 | width_sd1_5=args.width_sd1_5, source_img=control_image, image=control_image, 202 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, 203 | num_images_per_prompt=1, generator=gen, 204 | controlnet_conditioning_scale=controlnet_condition_scale, 205 | adapter_condition_scale=adapter_condition_scale, 206 | adapter_guidance_start=adapter_guidance_start).images[0] 207 | img.save( 208 | f"{args.save_path}/{prompt[:10]}_{i}_ccs_{controlnet_condition_scale:.2f}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png") 209 | 210 | print(f"results saved in {args.save_path}") 211 | 212 | -------------------------------------------------------------------------------- /scripts/inference_lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import os 4 | import numpy as np 5 | import cv2 6 | from tqdm import tqdm 7 | from diffusers import DiffusionPipeline 8 | from diffusers import DPMSolverMultistepScheduler 9 | from diffusers.utils import load_image 10 | from torch import Generator 11 | from safetensors.torch import load_file 12 | from PIL import Image 13 | from packaging import version 14 | 15 | from transformers import CLIPTextModel, CLIPTokenizer, AutoTokenizer, PretrainedConfig 16 | 17 | import diffusers 18 | from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel, ControlNetModel, \ 19 | T2IAdapter 20 | from diffusers.optimization import get_scheduler 21 | from diffusers.training_utils import EMAModel 22 | from diffusers.utils import check_min_version, deprecate, is_wandb_available 23 | from diffusers.utils.import_utils import is_xformers_available 24 | 25 | from model.unet_adapter import UNet2DConditionModel 26 | from pipeline.pipeline_sd_xl_adapter import StableDiffusionXLAdapterPipeline 27 | from model.adapter import Adapter_XL 28 | from 
scripts.utils import str2float 29 | 30 | 31 | def import_model_class_from_model_name_or_path( 32 | pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" 33 | ): 34 | text_encoder_config = PretrainedConfig.from_pretrained( 35 | pretrained_model_name_or_path, subfolder=subfolder, revision=revision 36 | ) 37 | model_class = text_encoder_config.architectures[0] 38 | 39 | if model_class == "CLIPTextModel": 40 | from transformers import CLIPTextModel 41 | 42 | return CLIPTextModel 43 | elif model_class == "CLIPTextModelWithProjection": 44 | from transformers import CLIPTextModelWithProjection 45 | 46 | return CLIPTextModelWithProjection 47 | else: 48 | raise ValueError(f"{model_class} is not supported.") 49 | 50 | 51 | def load_lora(pipeline, lora_model_path, alpha): 52 | state_dict = load_file(lora_model_path) 53 | 54 | LORA_PREFIX_UNET = 'lora_unet' 55 | LORA_PREFIX_TEXT_ENCODER = 'lora_te' 56 | 57 | visited = [] 58 | 59 | # directly update weight in diffusers model 60 | for key in state_dict: 61 | 62 | # it is suggested to print out the key, it usually will be something like below 63 | # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" 64 | 65 | # as we have set the alpha beforehand, so just skip 66 | if '.alpha' in key or key in visited: 67 | continue 68 | 69 | if 'text' in key: 70 | layer_infos = key.split('.')[0].split(LORA_PREFIX_TEXT_ENCODER + '_')[-1].split('_') 71 | curr_layer = pipeline.text_encoder_sd1_5 72 | else: 73 | layer_infos = key.split('.')[0].split(LORA_PREFIX_UNET + '_')[-1].split('_') 74 | curr_layer = pipeline.unet_sd1_5 75 | 76 | # find the target layer 77 | temp_name = layer_infos.pop(0) 78 | while len(layer_infos) > -1: 79 | try: 80 | curr_layer = curr_layer.__getattr__(temp_name) 81 | if len(layer_infos) > 0: 82 | temp_name = layer_infos.pop(0) 83 | elif len(layer_infos) == 0: 84 | break 85 | except Exception: 86 | if len(temp_name) > 0: 87 | temp_name += '_' + layer_infos.pop(0) 88 | else: 89 | temp_name = layer_infos.pop(0) 90 | 91 | # org_forward(x) + lora_up(lora_down(x)) * multiplier 92 | pair_keys = [] 93 | if 'lora_down' in key: 94 | pair_keys.append(key.replace('lora_down', 'lora_up')) 95 | pair_keys.append(key) 96 | else: 97 | pair_keys.append(key) 98 | pair_keys.append(key.replace('lora_up', 'lora_down')) 99 | 100 | # update weight 101 | if len(state_dict[pair_keys[0]].shape) == 4: 102 | weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) 103 | weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) 104 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3) 105 | else: 106 | weight_up = state_dict[pair_keys[0]].to(torch.float32) 107 | weight_down = state_dict[pair_keys[1]].to(torch.float32) 108 | curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down) 109 | 110 | # update visited list 111 | for item in pair_keys: 112 | visited.append(item) 113 | 114 | 115 | def inference_lora(args): 116 | device = 'cuda' 117 | weight_dtype = torch.float16 118 | 119 | adapter_guidance_start_list = str2float(args.adapter_guidance_start_list) 120 | adapter_condition_scale_list = str2float(args.adapter_condition_scale_list) 121 | 122 | path = args.base_path 123 | path_sdxl = args.sdxl_path 124 | path_vae_sdxl = args.path_vae_sdxl 125 | adapter_path = args.adapter_checkpoint 126 | lora_model_path = args.lora_model_path 127 | 128 | prompt = args.prompt 129 | if args.prompt_sd1_5 is None: 130 | prompt_sd1_5 = prompt 131 | else: 132 | 
prompt_sd1_5 = args.prompt_sd1_5 133 | 134 | if args.negative_prompt is None: 135 | negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" 136 | else: 137 | negative_prompt = args.negative_prompt 138 | 139 | torch.set_grad_enabled(False) 140 | torch.backends.cudnn.benchmark = True 141 | 142 | # load adapter 143 | adapter = Adapter_XL() 144 | ckpt = torch.load(adapter_path) 145 | adapter.load_state_dict(ckpt) 146 | print('successfully load adapter') 147 | # load SD1.5 148 | noise_scheduler_sd1_5 = DDPMScheduler.from_pretrained( 149 | path, subfolder="scheduler" 150 | ) 151 | tokenizer_sd1_5 = CLIPTokenizer.from_pretrained( 152 | path, subfolder="tokenizer", revision=None 153 | ) 154 | text_encoder_sd1_5 = CLIPTextModel.from_pretrained( 155 | path, subfolder="text_encoder", revision=None 156 | ) 157 | vae_sd1_5 = AutoencoderKL.from_pretrained( 158 | path, subfolder="vae", revision=None 159 | ) 160 | unet_sd1_5 = UNet2DConditionModel.from_pretrained( 161 | path, subfolder="unet", revision=None 162 | ) 163 | print('successfully load SD1.5') 164 | # load SDXL 165 | tokenizer_one = AutoTokenizer.from_pretrained( 166 | path_sdxl, subfolder="tokenizer", revision=None, use_fast=False 167 | ) 168 | tokenizer_two = AutoTokenizer.from_pretrained( 169 | path_sdxl, subfolder="tokenizer_2", revision=None, use_fast=False 170 | ) 171 | # import correct text encoder classes 172 | text_encoder_cls_one = import_model_class_from_model_name_or_path( 173 | path_sdxl, None 174 | ) 175 | text_encoder_cls_two = import_model_class_from_model_name_or_path( 176 | path_sdxl, None, subfolder="text_encoder_2" 177 | ) 178 | # Load scheduler and models 179 | noise_scheduler = DDPMScheduler.from_pretrained(path_sdxl, subfolder="scheduler") 180 | text_encoder_one = text_encoder_cls_one.from_pretrained( 181 | path_sdxl, subfolder="text_encoder", revision=None 182 | ) 183 | text_encoder_two = text_encoder_cls_two.from_pretrained( 184 | path_sdxl, subfolder="text_encoder_2", revision=None 185 | ) 186 | vae = AutoencoderKL.from_pretrained( 187 | path_vae_sdxl, revision=None 188 | ) 189 | unet = UNet2DConditionModel.from_pretrained( 190 | path_sdxl, subfolder="unet", revision=None 191 | ) 192 | print('successfully load SDXL') 193 | 194 | if is_xformers_available(): 195 | import xformers 196 | 197 | xformers_version = version.parse(xformers.__version__) 198 | if xformers_version == version.parse("0.0.16"): 199 | print( 200 | "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
201 | ) 202 | unet.enable_xformers_memory_efficient_attention() 203 | unet_sd1_5.enable_xformers_memory_efficient_attention() 204 | 205 | with torch.inference_mode(): 206 | gen = Generator("cuda") 207 | gen.manual_seed(args.seed) 208 | 209 | pipe = StableDiffusionXLAdapterPipeline( 210 | vae=vae, 211 | text_encoder=text_encoder_one, 212 | text_encoder_2=text_encoder_two, 213 | tokenizer=tokenizer_one, 214 | tokenizer_2=tokenizer_two, 215 | unet=unet, 216 | scheduler=noise_scheduler, 217 | vae_sd1_5=vae_sd1_5, 218 | text_encoder_sd1_5=text_encoder_sd1_5, 219 | tokenizer_sd1_5=tokenizer_sd1_5, 220 | unet_sd1_5=unet_sd1_5, 221 | scheduler_sd1_5=noise_scheduler_sd1_5, 222 | adapter=adapter, 223 | ) 224 | # load lora 225 | load_lora(pipe, lora_model_path, 1) 226 | print('successfully load lora') 227 | 228 | pipe.to('cuda', weight_dtype) 229 | pipe.enable_model_cpu_offload() 230 | pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config) 231 | pipe.scheduler_sd1_5 = DPMSolverMultistepScheduler.from_config(pipe.scheduler_sd1_5.config) 232 | pipe.scheduler_sd1_5.config.timestep_spacing = "leading" 233 | 234 | for i in range(args.iter_num): 235 | for adapter_guidance_start in adapter_guidance_start_list: 236 | for adapter_condition_scale in adapter_condition_scale_list: 237 | img = \ 238 | pipe(prompt=prompt, prompt_sd1_5=prompt_sd1_5, negative_prompt=negative_prompt, width=args.width, 239 | height=args.height, height_sd1_5=args.height_sd1_5, width_sd1_5=args.width_sd1_5, 240 | num_inference_steps=args.num_inference_steps, guidance_scale=args.guidance_scale, 241 | num_images_per_prompt=1, generator=gen, 242 | adapter_guidance_start=adapter_guidance_start, 243 | adapter_condition_scale=adapter_condition_scale).images[0] 244 | img.save( 245 | f"{args.save_path}/{prompt[:10]}_{i}_ags_{adapter_guidance_start:.2f}_acs_{adapter_condition_scale:.2f}.png") 246 | print(f"results saved in {args.save_path}") 247 | 248 | 249 | 250 | 251 | 252 | 253 | -------------------------------------------------------------------------------- /scripts/utils.py: -------------------------------------------------------------------------------- 1 | def str2float(x): 2 | for i in range(len(x)): 3 | x[i] = float(x[i]) 4 | return x 5 | --------------------------------------------------------------------------------
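The inference entry points above (`inference_ctrlnet_tile`, `inference_lora`, and the ControlNet variant) all read their configuration from a single `args` object, presumably assembled by `inference.py` and the wrappers under `bash_scripts/`. As a minimal sketch only, the snippet below drives `inference_lora()` directly with `types.SimpleNamespace`; the model identifiers and file paths are placeholder assumptions rather than values taken from this repository, and the repo root must be on `PYTHONPATH` so the `model`, `pipeline`, and `scripts` packages resolve.

    # Hypothetical driver for scripts/inference_lora.py -- a sketch, not part of the repository.
    # All model identifiers and paths below are placeholder assumptions; replace them with your own.
    import os
    from types import SimpleNamespace

    from scripts.inference_lora import inference_lora  # requires the repo root on PYTHONPATH

    args = SimpleNamespace(
        base_path="runwayml/stable-diffusion-v1-5",              # assumed SD1.5 base model
        sdxl_path="stabilityai/stable-diffusion-xl-base-1.0",    # assumed SDXL base model
        path_vae_sdxl="madebyollin/sdxl-vae-fp16-fix",           # assumed fp16-safe SDXL VAE
        adapter_checkpoint="/path/to/adapter_checkpoint.bin",    # placeholder
        lora_model_path="/path/to/sd1_5_lora.safetensors",       # placeholder
        prompt="a photo of a cute cat, best quality",
        prompt_sd1_5=None,          # None -> falls back to `prompt`
        negative_prompt=None,       # None -> falls back to the built-in default
        width=1024, height=1024,
        width_sd1_5=512, height_sd1_5=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        seed=42,
        iter_num=1,
        adapter_guidance_start_list=["0.8"],   # str2float() converts each entry to float
        adapter_condition_scale_list=["1.0"],
        save_path="./results",
    )

    os.makedirs(args.save_path, exist_ok=True)
    inference_lora(args)

The tile script is driven the same way, with the additional `controlnet_tile_path`, `input_image_path`, and `controlnet_condition_scale_list` attributes filled in.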